
Commit b8fe110

reformat file
1 parent c8fa61a commit b8fe110

File tree

1 file changed: +62 additions, -39 deletions


examples/reinforcement_learning/diffusion_policy.py

Lines changed: 62 additions & 39 deletions
@@ -8,7 +8,21 @@
 from diffusers import DDPMScheduler, UNet1DModel
 
 
-add_safe_globals([multiarray._reconstruct, np.ndarray, np.dtype, np.dtype(np.float32).type, np.dtype(np.float64).type, np.dtype(np.int32).type, np.dtype(np.int64).type, type(np.dtype(np.float32)), type(np.dtype(np.float64)), type(np.dtype(np.int32)), type(np.dtype(np.int64))])
+add_safe_globals(
+    [
+        multiarray._reconstruct,
+        np.ndarray,
+        np.dtype,
+        np.dtype(np.float32).type,
+        np.dtype(np.float64).type,
+        np.dtype(np.int32).type,
+        np.dtype(np.int64).type,
+        type(np.dtype(np.float32)),
+        type(np.dtype(np.float64)),
+        type(np.dtype(np.int32)),
+        type(np.dtype(np.int64)),
+    ]
+)
 
 """
 An example of using HuggingFace's diffusers library for diffusion policy,
@@ -19,6 +33,7 @@
 then outputs a sequence of 16 (x,y) positions for the robot arm to follow.
 """
 
+
 class ObservationEncoder(nn.Module):
     """
     Converts raw robot observations (positions/angles) into a more compact representation
@@ -32,13 +47,11 @@ class ObservationEncoder(nn.Module):
 
     def __init__(self, state_dim):
         super().__init__()
-        self.net = nn.Sequential(
-            nn.Linear(state_dim, 512),
-            nn.ReLU(),
-            nn.Linear(512, 256)
-        )
+        self.net = nn.Sequential(nn.Linear(state_dim, 512), nn.ReLU(), nn.Linear(512, 256))
+
+    def forward(self, x):
+        return self.net(x)
 
-    def forward(self, x): return self.net(x)
 
 class ObservationProjection(nn.Module):
     """
@@ -50,16 +63,18 @@ class ObservationProjection(nn.Module):
     - Output: 32 contextual information values for the diffusion model
       Shape: (batch_size, 32)
     """
+
     def __init__(self):
         super().__init__()
         self.weight = nn.Parameter(torch.randn(32, 512))
         self.bias = nn.Parameter(torch.zeros(32))
 
-    def forward(self, x): # pad 256-dim input to 512-dim with zeros
+    def forward(self, x):  # pad 256-dim input to 512-dim with zeros
         if x.size(-1) == 256:
             x = torch.cat([x, torch.zeros(*x.shape[:-1], 256, device=x.device)], dim=-1)
         return nn.functional.linear(x, self.weight, self.bias)
 
+
 class DiffusionPolicy:
     """
     Implements diffusion policy for generating robot arm trajectories.
@@ -69,11 +84,15 @@ class DiffusionPolicy:
     The model expects observations in pixel coordinates (0-512 range) and block angle in radians.
     It generates trajectories as sequences of (x,y) coordinates also in the 0-512 range.
     """
+
    def __init__(self, state_dim=5, device="cpu"):
         self.device = device
 
         # define valid ranges for inputs/outputs
-        self.stats = {'obs': {'min': torch.zeros(5), 'max': torch.tensor([512, 512, 512, 512, 2*np.pi])}, 'action': {'min': torch.zeros(2), 'max': torch.full((2,), 512)}}
+        self.stats = {
+            "obs": {"min": torch.zeros(5), "max": torch.tensor([512, 512, 512, 512, 2 * np.pi])},
+            "action": {"min": torch.zeros(2), "max": torch.full((2,), 512)},
+        }
 
         self.obs_encoder = ObservationEncoder(state_dim).to(device)
         self.obs_projection = ObservationProjection().to(device)
@@ -82,34 +101,36 @@ def __init__(self, state_dim=5, device="cpu"):
         # takes in concatenated action (2 channels) and context (32 channels) = 34 channels
         # outputs predicted action (2 channels for x,y coordinates)
         self.model = UNet1DModel(
-            sample_size=16, # length of trajectory sequence
+            sample_size=16,  # length of trajectory sequence
             in_channels=34,
             out_channels=2,
-            layers_per_block=2, # number of layers per each UNet block
-            block_out_channels=(128,), # number of output neurons per layer in each block
-            down_block_types=("DownBlock1D",), # reduce the resolution of data
-            up_block_types=("UpBlock1D",) # increase the resolution of data
+            layers_per_block=2,  # number of layers per each UNet block
+            block_out_channels=(128,),  # number of output neurons per layer in each block
+            down_block_types=("DownBlock1D",),  # reduce the resolution of data
+            up_block_types=("UpBlock1D",),  # increase the resolution of data
         ).to(device)
 
         # noise scheduler that controls the denoising process
         self.noise_scheduler = DDPMScheduler(
-            num_train_timesteps=100, # number of denoising steps
-            beta_schedule="squaredcos_cap_v2" # type of noise schedule
+            num_train_timesteps=100,  # number of denoising steps
+            beta_schedule="squaredcos_cap_v2",  # type of noise schedule
         )
 
         # load pre-trained weights from HuggingFace
-        checkpoint = torch.load(hf_hub_download("dorsar/diffusion_policy", "push_tblock.pt"), weights_only=True, map_location=device)
-        self.model.load_state_dict(checkpoint['model_state_dict'])
-        self.obs_encoder.load_state_dict(checkpoint['encoder_state_dict'])
-        self.obs_projection.load_state_dict(checkpoint['projection_state_dict'])
+        checkpoint = torch.load(
+            hf_hub_download("dorsar/diffusion_policy", "push_tblock.pt"), weights_only=True, map_location=device
+        )
+        self.model.load_state_dict(checkpoint["model_state_dict"])
+        self.obs_encoder.load_state_dict(checkpoint["encoder_state_dict"])
+        self.obs_projection.load_state_dict(checkpoint["projection_state_dict"])
 
     # scales data to [-1, 1] range for neural network processing
     def normalize_data(self, data, stats):
-        return ((data - stats['min']) / (stats['max'] - stats['min'])) * 2 - 1
+        return ((data - stats["min"]) / (stats["max"] - stats["min"])) * 2 - 1
 
     # converts normalized data back to original range
     def unnormalize_data(self, ndata, stats):
-        return ((ndata + 1) / 2) * (stats['max'] - stats['min']) + stats['min']
+        return ((ndata + 1) / 2) * (stats["max"] - stats["min"]) + stats["min"]
 
     @torch.no_grad()
     def predict(self, observation):
@@ -130,7 +151,7 @@ def predict(self, observation):
         process that gradually denoises random trajectories into smooth, purposeful movements.
         """
         observation = observation.to(self.device)
-        normalized_obs = self.normalize_data(observation, self.stats['obs'])
+        normalized_obs = self.normalize_data(observation, self.stats["obs"])
 
         # encode the observation into context values for the diffusion model
         cond = self.obs_projection(self.obs_encoder(normalized_obs))
@@ -141,38 +162,40 @@ def predict(self, observation):
         action = torch.randn((observation.shape[0], 2, 16), device=self.device)
 
         # denoise
-        # at each step `t`, the current noisy trajectory (`action`) & conditioning info (context) are
-        # fed into the model to predict a denoised trajectory, then uses self.noise_scheduler.step to
-        # apply this prediction & slightly reduce the noise in `action` more
+        # at each step `t`, the current noisy trajectory (`action`) & conditioning info (context) are
+        # fed into the model to predict a denoised trajectory, then uses self.noise_scheduler.step to
+        # apply this prediction & slightly reduce the noise in `action` more
 
         self.noise_scheduler.set_timesteps(100)
         for t in self.noise_scheduler.timesteps:
             model_output = self.model(torch.cat([action, cond], dim=1), t)
-            action = self.noise_scheduler.step(
-                model_output.sample, t, action
-            ).prev_sample
+            action = self.noise_scheduler.step(model_output.sample, t, action).prev_sample
 
         action = action.transpose(1, 2)  # reshape to [batch, 16, 2]
-        action = self.unnormalize_data(action, self.stats['action']) # scale back to coordinates
+        action = self.unnormalize_data(action, self.stats["action"])  # scale back to coordinates
         return action
 
+
 if __name__ == "__main__":
     policy = DiffusionPolicy()
 
     # sample of a single observation
     # robot arm starts in center, block is slightly left and up, rotated 90 degrees
-    obs = torch.tensor([[
-        256.0, # robot arm x position (middle of screen)
-        256.0, # robot arm y position (middle of screen)
-        200.0, # block x position
-        300.0, # block y position
-        np.pi/2 # block angle (90 degrees)
-    ]])
+    obs = torch.tensor(
+        [
+            [
+                256.0,  # robot arm x position (middle of screen)
+                256.0,  # robot arm y position (middle of screen)
+                200.0,  # block x position
+                300.0,  # block y position
+                np.pi / 2,  # block angle (90 degrees)
+            ]
+        ]
+    )
 
     action = policy.predict(obs)
 
-    print("Action shape:", action.shape) # should be [1, 16, 2] - one trajectory of 16 x,y positions
+    print("Action shape:", action.shape)  # should be [1, 16, 2] - one trajectory of 16 x,y positions
     print("\nPredicted trajectory:")
     for i, (x, y) in enumerate(action[0]):
         print(f"Step {i:2d}: x={x:6.1f}, y={y:6.1f}")
-
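Note on the add_safe_globals block at the top of this diff: it exists so the pre-trained checkpoint can be loaded with torch.load(..., weights_only=True), which only unpickles allowlisted globals. The required imports sit outside the changed hunks, so the sketch below assumes their import paths (numpy.core.multiarray, torch.serialization, huggingface_hub) and abbreviates the allowlist; it is a minimal illustration, not part of the commit.

import numpy as np
from numpy.core import multiarray  # assumed import; provides _reconstruct, used when unpickling arrays

import torch
from torch.serialization import add_safe_globals  # assumed import path (available in PyTorch >= 2.4)
from huggingface_hub import hf_hub_download

# Allowlist the numpy globals referenced by the checkpoint pickle so that
# torch.load with weights_only=True can rebuild numpy arrays safely.
# Abbreviated here; the file itself also allowlists the float64/int32/int64 dtype entries.
add_safe_globals([multiarray._reconstruct, np.ndarray, np.dtype, np.dtype(np.float32).type])

checkpoint = torch.load(
    hf_hub_download("dorsar/diffusion_policy", "push_tblock.pt"),
    weights_only=True,
    map_location="cpu",
)
print(checkpoint.keys())  # per the diff: model_state_dict, encoder_state_dict, projection_state_dict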