
Commit 658d87b

Author: Your Name
Message: commit before working on other stuff
1 parent 05077b7 · commit 658d87b

File tree: 2 files changed (+98, -56 lines)

examples/lerobot/robodm_training_pipeline.py (63 additions, 16 deletions)
```diff
@@ -95,8 +95,8 @@ def _convert_trajectory(self, trajectory: Dict[str, Any], episode_idx: int) -> L
         else:
             return []  # No valid data found

-        # DiffusionPolicy expects sequences, so we need horizon=16 for actions
-        horizon = 16
+        # DiffusionPolicy expects sequences with full prediction horizon
+        horizon = 16  # This should match DiffusionPolicy's horizon (not n_action_steps)
         timesteps = []

         # Create training samples with action sequences
```
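Note on the horizon comment above: in a diffusion policy, `horizon` is the length of the full predicted action sequence, while `n_action_steps` is the (smaller) number of predicted actions actually executed per inference call, so the dataset must pad every sample out to the full `horizon`. A minimal sketch of the bookkeeping (the numeric values are assumptions matching LeRobot's documented diffusion defaults):

```python
# Sequence-length bookkeeping for DiffusionPolicy-style training (assumed defaults).
n_obs_steps = 2     # observation frames fed to the model per sample
horizon = 16        # length of the predicted action sequence; the dataset pads to this
n_action_steps = 8  # predicted actions actually executed before re-planning

# Each training sample needs `action` of shape [horizon, action_dim] plus an
# `action_is_pad` mask of shape [horizon]; the executed steps must fit inside
# the predicted window that follows the observation history.
assert n_action_steps <= horizon - n_obs_steps + 1
```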
```diff
@@ -138,19 +138,22 @@ def _convert_trajectory(self, trajectory: Dict[str, Any], episode_idx: int) -> L
                         action_is_pad_sequence.append(False)
                     else:
                         # Pad with zeros
-                        action_sequence.append(torch.zeros(2, dtype=torch.float32))  # Assuming 2D actions
+                        action_dim = action_data.shape[0] if hasattr(action_data, 'shape') else 2
+                        action_sequence.append(torch.zeros(action_dim, dtype=torch.float32))
                         action_is_pad_sequence.append(True)
                 else:
                     # Pad with zeros when we run out of actions
-                    action_sequence.append(torch.zeros(2, dtype=torch.float32))  # Assuming 2D actions
+                    action_dim = action_sequence[0].shape[0] if action_sequence else 2
+                    action_sequence.append(torch.zeros(action_dim, dtype=torch.float32))
                     action_is_pad_sequence.append(True)

             # Stack into sequence tensors
             timestep['action'] = torch.stack(action_sequence)  # Shape: [horizon, action_dim]
             timestep['action_is_pad'] = torch.tensor(action_is_pad_sequence, dtype=torch.bool)  # Shape: [horizon]
         else:
-            # No action data at all
-            timestep['action'] = torch.zeros(horizon, 2, dtype=torch.float32)  # Shape: [horizon, action_dim]
+            # No action data at all - use default action dimension
+            default_action_dim = 2  # You should adjust this to match your robot's action space
+            timestep['action'] = torch.zeros(horizon, default_action_dim, dtype=torch.float32)  # Shape: [horizon, action_dim]
             timestep['action_is_pad'] = torch.ones(horizon, dtype=torch.bool)  # All padded

         timesteps.append(timestep)
```
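The padding logic in this hunk can be read as one standalone routine. The helper below is a hypothetical sketch (its name and signature are invented here, not code from the repo) showing the same pad-and-mask behavior:

```python
import torch

def pad_action_sequence(actions: list[torch.Tensor], horizon: int,
                        default_action_dim: int = 2) -> tuple[torch.Tensor, torch.Tensor]:
    """Pad a list of per-step action tensors out to `horizon` steps.

    Returns (action, action_is_pad) with shapes [horizon, action_dim] and [horizon].
    Hypothetical helper mirroring the diff's inline logic.
    """
    action_dim = actions[0].shape[0] if actions else default_action_dim
    padded, is_pad = [], []
    for t in range(horizon):
        if t < len(actions):
            padded.append(actions[t].float())
            is_pad.append(False)
        else:
            padded.append(torch.zeros(action_dim, dtype=torch.float32))
            is_pad.append(True)
    return torch.stack(padded), torch.tensor(is_pad, dtype=torch.bool)

# Example: 10 real actions padded to a horizon of 16 -> the last 6 entries are masked.
action, action_is_pad = pad_action_sequence([torch.randn(2) for _ in range(10)], horizon=16)
assert action.shape == (16, 2) and action_is_pad.sum() == 6
```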
```diff
@@ -176,10 +179,30 @@ def _add_image_observation_sequences(self, timestep: Dict[str, torch.Tensor], tr
                 # Make a copy to ensure the array is writable
                 image_data = image_data.copy()
                 # Convert to tensor, ensure it's in CHW format
-                if len(image_data.shape) == 3 and image_data.shape[2] == 3:  # HWC format
-                    image_tensor = torch.from_numpy(image_data).permute(2, 0, 1).float() / 255.0
-                else:  # Already in CHW format
-                    image_tensor = torch.from_numpy(image_data).float() / 255.0
+                if len(image_data.shape) == 3:
+                    # Check if it's HWC format (height, width, channels)
+                    if image_data.shape[2] == 3:  # HWC format
+                        image_tensor = torch.from_numpy(image_data).permute(2, 0, 1).float() / 255.0
+                    elif image_data.shape[0] == 3:  # Already CHW format
+                        image_tensor = torch.from_numpy(image_data).float() / 255.0
+                    else:
+                        # Unknown format, assume HWC and convert
+                        image_tensor = torch.from_numpy(image_data).permute(2, 0, 1).float() / 255.0
+                else:
+                    # Handle 2D images by adding channel dimension
+                    if len(image_data.shape) == 2:
+                        image_tensor = torch.from_numpy(image_data).unsqueeze(0).float() / 255.0
+                    else:
+                        # Fallback: try to reshape to CHW format
+                        image_tensor = torch.from_numpy(image_data).float() / 255.0
+                        if image_tensor.dim() == 1:
+                            # Try to reshape to square image
+                            size = int(np.sqrt(image_tensor.shape[0] / 3))
+                            if size * size * 3 == image_tensor.shape[0]:
+                                image_tensor = image_tensor.view(3, size, size)
+                            else:
+                                # Create placeholder if can't reshape
+                                image_tensor = torch.zeros(3, 96, 96, dtype=torch.float32)
                 image_sequence.append(image_tensor)
             else:
                 # Create a placeholder image if no image data
```
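The same HWC/CHW branching is duplicated in `_add_image_observations` in the next hunk. As a reading aid, here is a hypothetical helper condensing the conversion rules (the 96x96 placeholder size matches the diff; the helper name and everything else are assumptions):

```python
import numpy as np
import torch

def to_chw_float(image_data: np.ndarray, placeholder_hw: int = 96) -> torch.Tensor:
    """Convert an image array to a float CHW tensor in [0, 1].

    Hypothetical helper condensing the diff's branching: HWC -> CHW via permute,
    CHW passes through, 2D gets a channel axis, flat arrays are reshaped to a
    square image when possible, otherwise a zero placeholder is returned.
    """
    t = torch.from_numpy(image_data.copy()).float() / 255.0
    if t.dim() == 3:
        return t.permute(2, 0, 1) if t.shape[2] == 3 else t
    if t.dim() == 2:
        return t.unsqueeze(0)
    if t.dim() == 1:
        side = int(np.sqrt(t.shape[0] / 3))
        if side * side * 3 == t.shape[0]:
            return t.view(3, side, side)
    return torch.zeros(3, placeholder_hw, placeholder_hw)

# Example: an HWC uint8 frame becomes a [3, 96, 96] float tensor.
frame = (np.random.rand(96, 96, 3) * 255).astype(np.uint8)
assert to_chw_float(frame).shape == (3, 96, 96)
```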
```diff
@@ -211,17 +234,37 @@ def _add_image_observations(self, timestep: Dict[str, torch.Tensor], trajectory:
                 # Make a copy to ensure the array is writable
                 image_data = image_data.copy()
                 # Convert to tensor, ensure it's in CHW format
-                if len(image_data.shape) == 3 and image_data.shape[2] == 3:  # HWC format
-                    image_tensor = torch.from_numpy(image_data).permute(2, 0, 1).float() / 255.0
-                else:  # Already in CHW format
-                    image_tensor = torch.from_numpy(image_data).float() / 255.0
+                if len(image_data.shape) == 3:
+                    # Check if it's HWC format (height, width, channels)
+                    if image_data.shape[2] == 3:  # HWC format
+                        image_tensor = torch.from_numpy(image_data).permute(2, 0, 1).float() / 255.0
+                    elif image_data.shape[0] == 3:  # Already CHW format
+                        image_tensor = torch.from_numpy(image_data).float() / 255.0
+                    else:
+                        # Unknown format, assume HWC and convert
+                        image_tensor = torch.from_numpy(image_data).permute(2, 0, 1).float() / 255.0
+                else:
+                    # Handle 2D images by adding channel dimension
+                    if len(image_data.shape) == 2:
+                        image_tensor = torch.from_numpy(image_data).unsqueeze(0).float() / 255.0
+                    else:
+                        # Fallback: try to reshape to CHW format
+                        image_tensor = torch.from_numpy(image_data).float() / 255.0
+                        if image_tensor.dim() == 1:
+                            # Try to reshape to square image
+                            size = int(np.sqrt(image_tensor.shape[0] / 3))
+                            if size * size * 3 == image_tensor.shape[0]:
+                                image_tensor = image_tensor.view(3, size, size)
+                            else:
+                                # Create placeholder if can't reshape
+                                image_tensor = torch.zeros(3, 96, 96, dtype=torch.float32)
                 timestep['observation.image'] = image_tensor
             else:
                 # Create a placeholder image if no image data
-                timestep['observation.image'] = torch.zeros(3, 64, 64, dtype=torch.float32)
+                timestep['observation.image'] = torch.zeros(3, 96, 96, dtype=torch.float32)
         else:
             # Create a placeholder image if frame is out of range
-            timestep['observation.image'] = torch.zeros(3, 64, 64, dtype=torch.float32)
+            timestep['observation.image'] = torch.zeros(3, 96, 96, dtype=torch.float32)

     def get_torch_dataset(self) -> torch_data.Dataset:
         """Get PyTorch dataset."""
```
```diff
@@ -296,6 +339,10 @@ def get_dataset_stats(self) -> Dict[str, Dict[str, torch.Tensor]]:
         if all_actions:
             try:
                 actions = torch.stack(all_actions)
+                # Transpose actions from [samples, horizon, action_dim] to [samples, action_dim, horizon]
+                # to match the expected format for DiffusionPolicy
+                if len(actions.shape) == 3:
+                    actions = actions.transpose(1, 2)  # [samples, action_dim, horizon]
                 stats['action'] = {
                     'mean': actions.mean(dim=0),
                     'std': actions.std(dim=0),
```
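This transpose mirrors the collate-side transpose added in run_pipeline.py below, keeping the normalization stats and the batched actions in the same [action_dim, horizon] layout. A quick sketch of the shape bookkeeping (2-D actions and a horizon of 16 are assumptions):

```python
import torch

# 100 samples, horizon 16, 2-D actions, as stacked from the dataset.
actions = torch.randn(100, 16, 2)      # [samples, horizon, action_dim]
actions = actions.transpose(1, 2)      # [samples, action_dim, horizon]

stats = {'mean': actions.mean(dim=0),  # [action_dim, horizon]
         'std': actions.std(dim=0)}    # [action_dim, horizon]
assert stats['mean'].shape == (2, 16)
```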

examples/lerobot/run_pipeline.py (35 additions, 40 deletions)
```diff
@@ -149,64 +149,58 @@ def run_complete_pipeline(dataset_name: str, num_episodes: int = None,
     policy_features = pipeline.get_policy_features()
     dataset_stats = pipeline.get_dataset_stats()

+
     # Create policy configuration
     cfg = DiffusionConfig(
         input_features=policy_features['input_features'],
         output_features=policy_features['output_features'],
-        crop_shape=None  # Disable cropping since our images are 96x96
+        crop_shape=None,  # Disable cropping since our images are 96x96
+        horizon=16  # Match the horizon used in RoboDM data generation
     )

+
     # Create and setup policy
     policy = DiffusionPolicy(cfg, dataset_stats=dataset_stats)
     policy.train()
     policy.to(device)

-    # Setup training with custom collate function for DiffusionPolicy
+    # Use observation sequence collate function for DiffusionPolicy
+    from torch.utils.data import default_collate
+
     def collate_fn(batch):
-        """Custom collate function that creates observation sequences for DiffusionPolicy."""
-        result = {}
+        """Collate function for DiffusionPolicy training with RoboDM data."""
+        if not batch:
+            return {}
+
+        # Use default collate for everything
+        collated = default_collate(batch)
+
         batch_size = len(batch)
         n_obs_steps = 2  # DiffusionPolicy default

-        # Stack all non-sequence keys normally
-        for key in batch[0].keys():
-            if key not in ['observation.image', 'observation.state']:
-                values = [item[key] for item in batch if item[key] is not None]
-                if values and all(isinstance(v, torch.Tensor) for v in values):
-                    try:
-                        result[key] = torch.stack(values)
-                    except RuntimeError:
-                        result[key] = values[0].unsqueeze(0).repeat(len(batch), *([1] * (values[0].dim())))
-                elif values:
-                    result[key] = values[0] if len(values) == 1 else values
+        # Create observation sequences for DiffusionPolicy
+        if 'observation.image' in collated:
+            # Images: [B, C, H, W] -> [B, T, C, H, W]
+            images = collated['observation.image']
+            # Create temporal sequence by repeating current observation
+            image_seq = images.unsqueeze(1).repeat(1, n_obs_steps, 1, 1, 1)
+            collated['observation.image'] = image_seq

-        # Handle observation sequences specially
-        if 'observation.image' in batch[0]:
-            # Create observation.images with proper sequence format
-            images = []
-            for i in range(batch_size):
-                # Get current observation
-                current_obs = batch[i]['observation.image']
-                # For simplicity, repeat current observation for n_obs_steps
-                # In a proper implementation, you'd track actual historical observations
-                obs_sequence = current_obs.unsqueeze(0).repeat(n_obs_steps, 1, 1, 1)  # [n_obs_steps, C, H, W]
-                obs_sequence = obs_sequence.unsqueeze(1)  # Add camera dim: [n_obs_steps, 1, C, H, W]
-                images.append(obs_sequence)
-            result['observation.images'] = torch.stack(images)  # [B, n_obs_steps, 1, C, H, W]
+        if 'observation.state' in collated:
+            # States: [B, state_dim] -> [B, T, state_dim]
+            states = collated['observation.state']
+            state_seq = states.unsqueeze(1).repeat(1, n_obs_steps, 1)
+            collated['observation.state'] = state_seq

-        if 'observation.state' in batch[0]:
-            # Create observation.state sequence
-            states = []
-            for i in range(batch_size):
-                current_state = batch[i]['observation.state']
-                # Repeat current state for n_obs_steps
-                state_sequence = current_state.unsqueeze(0).repeat(n_obs_steps, 1)  # [n_obs_steps, state_dim]
-                states.append(state_sequence)
-            result['observation.state'] = torch.stack(states)  # [B, n_obs_steps, state_dim]
+        if 'action' in collated:
+            # Actions: [B, horizon, action_dim] -> [B, action_dim, horizon]
+            if collated['action'].ndim == 3:
+                collated['action'] = collated['action'].transpose(1, 2)

-        return result
+        return collated

-    dataloader = pipeline.get_dataloader(batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
+    dataloader = pipeline.get_dataloader(batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=collate_fn)
     optimizer = torch.optim.Adam(policy.parameters(), lr=lr)

     # Training loop
```
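A quick way to sanity-check what this collate function produces, using dummy samples (all shapes here are assumptions: 96x96 RGB images, 2-D state and actions, horizon 16); the transforms are inlined so the snippet runs standalone:

```python
import torch
from torch.utils.data import default_collate

n_obs_steps, horizon = 2, 16
samples = [{
    'observation.image': torch.rand(3, 96, 96),
    'observation.state': torch.rand(2),
    'action': torch.rand(horizon, 2),
    'action_is_pad': torch.zeros(horizon, dtype=torch.bool),
} for _ in range(4)]

collated = default_collate(samples)
# Repeat the current observation to fake an n_obs_steps history, as collate_fn does.
collated['observation.image'] = collated['observation.image'].unsqueeze(1).repeat(1, n_obs_steps, 1, 1, 1)
collated['observation.state'] = collated['observation.state'].unsqueeze(1).repeat(1, n_obs_steps, 1)
collated['action'] = collated['action'].transpose(1, 2)

assert collated['observation.image'].shape == (4, n_obs_steps, 3, 96, 96)
assert collated['observation.state'].shape == (4, n_obs_steps, 2)
assert collated['action'].shape == (4, 2, horizon)  # [B, action_dim, horizon]
```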
```diff
@@ -220,6 +214,7 @@ def collate_fn(batch):
             batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v)
                      for k, v in batch.items()}

+
             loss, _ = policy.forward(batch)
             loss.backward()
             optimizer.step()
```
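For context, one full optimization step around the lines shown would look like the sketch below; whether the unchanged surrounding code already calls `optimizer.zero_grad()` is not visible in this hunk, so it is included here as standard PyTorch practice:

```python
# One training step, assuming `policy`, `optimizer`, `device`, and a `batch`
# collated by the collate_fn above (all names taken from this file).
batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v)
         for k, v in batch.items()}

optimizer.zero_grad()            # clear gradients accumulated by the previous step
loss, _ = policy.forward(batch)  # policy.forward returns (loss, outputs) here
loss.backward()
optimizer.step()
```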
```diff
@@ -317,7 +312,7 @@ def main():
     # Dataset arguments
     parser.add_argument("--dataset", type=str, default="lerobot/pusht",
                         help="LeRobot dataset name (e.g., lerobot/pusht)")
-    parser.add_argument("--num_episodes", type=int, default=50,
-                        help="Number of episodes to convert (default: 50)")
+    parser.add_argument("--num_episodes", type=int, default=5,
+                        help="Number of episodes to convert (default: 5)")
     parser.add_argument("--robodm_data_dir", type=str, default=None,
                         help="Directory containing existing RoboDM data (skips ingestion)")
```
