@@ -95,8 +95,8 @@ def _convert_trajectory(self, trajectory: Dict[str, Any], episode_idx: int) -> L
9595 else :
9696 return [] # No valid data found
9797
98- # DiffusionPolicy expects sequences, so we need horizon=16 for actions
99- horizon = 16
98+ # DiffusionPolicy expects sequences with full prediction horizon
99+ horizon = 16 # This should match DiffusionPolicy's horizon (not n_action_steps)
100100 timesteps = []
101101
102102 # Create training samples with action sequences
@@ -138,19 +138,22 @@ def _convert_trajectory(self, trajectory: Dict[str, Any], episode_idx: int) -> L
138138 action_is_pad_sequence .append (False )
139139 else :
140140 # Pad with zeros
141- action_sequence .append (torch .zeros (2 , dtype = torch .float32 )) # Assuming 2D actions
141+ action_dim = action_data .shape [0 ] if hasattr (action_data , 'shape' ) else 2
142+ action_sequence .append (torch .zeros (action_dim , dtype = torch .float32 ))
142143 action_is_pad_sequence .append (True )
143144 else :
144145 # Pad with zeros when we run out of actions
145- action_sequence .append (torch .zeros (2 , dtype = torch .float32 )) # Assuming 2D actions
146+ action_dim = action_sequence [0 ].shape [0 ] if action_sequence else 2
147+ action_sequence .append (torch .zeros (action_dim , dtype = torch .float32 ))
146148 action_is_pad_sequence .append (True )
147149
148150 # Stack into sequence tensors
149151 timestep ['action' ] = torch .stack (action_sequence ) # Shape: [horizon, action_dim]
150152 timestep ['action_is_pad' ] = torch .tensor (action_is_pad_sequence , dtype = torch .bool ) # Shape: [horizon]
151153 else :
152- # No action data at all
153- timestep ['action' ] = torch .zeros (horizon , 2 , dtype = torch .float32 ) # Shape: [horizon, action_dim]
154+ # No action data at all - use default action dimension
155+ default_action_dim = 2 # You should adjust this to match your robot's action space
156+ timestep ['action' ] = torch .zeros (horizon , default_action_dim , dtype = torch .float32 ) # Shape: [horizon, action_dim]
154157 timestep ['action_is_pad' ] = torch .ones (horizon , dtype = torch .bool ) # All padded
155158
156159 timesteps .append (timestep )
@@ -176,10 +179,30 @@ def _add_image_observation_sequences(self, timestep: Dict[str, torch.Tensor], tr
176179 # Make a copy to ensure the array is writable
177180 image_data = image_data .copy ()
178181 # Convert to tensor, ensure it's in CHW format
179- if len (image_data .shape ) == 3 and image_data .shape [2 ] == 3 : # HWC format
180- image_tensor = torch .from_numpy (image_data ).permute (2 , 0 , 1 ).float () / 255.0
181- else : # Already in CHW format
182- image_tensor = torch .from_numpy (image_data ).float () / 255.0
182+ if len (image_data .shape ) == 3 :
183+ # Check if it's HWC format (height, width, channels)
184+ if image_data .shape [2 ] == 3 : # HWC format
185+ image_tensor = torch .from_numpy (image_data ).permute (2 , 0 , 1 ).float () / 255.0
186+ elif image_data .shape [0 ] == 3 : # Already CHW format
187+ image_tensor = torch .from_numpy (image_data ).float () / 255.0
188+ else :
189+ # Unknown format, assume HWC and convert
190+ image_tensor = torch .from_numpy (image_data ).permute (2 , 0 , 1 ).float () / 255.0
191+ else :
192+ # Handle 2D images by adding channel dimension
193+ if len (image_data .shape ) == 2 :
194+ image_tensor = torch .from_numpy (image_data ).unsqueeze (0 ).float () / 255.0
195+ else :
196+ # Fallback: try to reshape to CHW format
197+ image_tensor = torch .from_numpy (image_data ).float () / 255.0
198+ if image_tensor .dim () == 1 :
199+ # Try to reshape to square image
200+ size = int (np .sqrt (image_tensor .shape [0 ] / 3 ))
201+ if size * size * 3 == image_tensor .shape [0 ]:
202+ image_tensor = image_tensor .view (3 , size , size )
203+ else :
204+ # Create placeholder if can't reshape
205+ image_tensor = torch .zeros (3 , 96 , 96 , dtype = torch .float32 )
183206 image_sequence .append (image_tensor )
184207 else :
185208 # Create a placeholder image if no image data
@@ -211,17 +234,37 @@ def _add_image_observations(self, timestep: Dict[str, torch.Tensor], trajectory:
211234 # Make a copy to ensure the array is writable
212235 image_data = image_data .copy ()
213236 # Convert to tensor, ensure it's in CHW format
214- if len (image_data .shape ) == 3 and image_data .shape [2 ] == 3 : # HWC format
215- image_tensor = torch .from_numpy (image_data ).permute (2 , 0 , 1 ).float () / 255.0
216- else : # Already in CHW format
217- image_tensor = torch .from_numpy (image_data ).float () / 255.0
237+ if len (image_data .shape ) == 3 :
238+ # Check if it's HWC format (height, width, channels)
239+ if image_data .shape [2 ] == 3 : # HWC format
240+ image_tensor = torch .from_numpy (image_data ).permute (2 , 0 , 1 ).float () / 255.0
241+ elif image_data .shape [0 ] == 3 : # Already CHW format
242+ image_tensor = torch .from_numpy (image_data ).float () / 255.0
243+ else :
244+ # Unknown format, assume HWC and convert
245+ image_tensor = torch .from_numpy (image_data ).permute (2 , 0 , 1 ).float () / 255.0
246+ else :
247+ # Handle 2D images by adding channel dimension
248+ if len (image_data .shape ) == 2 :
249+ image_tensor = torch .from_numpy (image_data ).unsqueeze (0 ).float () / 255.0
250+ else :
251+ # Fallback: try to reshape to CHW format
252+ image_tensor = torch .from_numpy (image_data ).float () / 255.0
253+ if image_tensor .dim () == 1 :
254+ # Try to reshape to square image
255+ size = int (np .sqrt (image_tensor .shape [0 ] / 3 ))
256+ if size * size * 3 == image_tensor .shape [0 ]:
257+ image_tensor = image_tensor .view (3 , size , size )
258+ else :
259+ # Create placeholder if can't reshape
260+ image_tensor = torch .zeros (3 , 96 , 96 , dtype = torch .float32 )
218261 timestep ['observation.image' ] = image_tensor
219262 else :
220263 # Create a placeholder image if no image data
221- timestep ['observation.image' ] = torch .zeros (3 , 64 , 64 , dtype = torch .float32 )
264+ timestep ['observation.image' ] = torch .zeros (3 , 96 , 96 , dtype = torch .float32 )
222265 else :
223266 # Create a placeholder image if frame is out of range
224- timestep ['observation.image' ] = torch .zeros (3 , 64 , 64 , dtype = torch .float32 )
267+ timestep ['observation.image' ] = torch .zeros (3 , 96 , 96 , dtype = torch .float32 )
225268
226269 def get_torch_dataset (self ) -> torch_data .Dataset :
227270 """Get PyTorch dataset."""
@@ -296,6 +339,10 @@ def get_dataset_stats(self) -> Dict[str, Dict[str, torch.Tensor]]:
296339 if all_actions :
297340 try :
298341 actions = torch .stack (all_actions )
342+ # Transpose actions from [samples, horizon, action_dim] to [samples, action_dim, horizon]
343+ # to match the expected format for DiffusionPolicy
344+ if len (actions .shape ) == 3 :
345+ actions = actions .transpose (1 , 2 ) # [samples, action_dim, horizon]
299346 stats ['action' ] = {
300347 'mean' : actions .mean (dim = 0 ),
301348 'std' : actions .std (dim = 0 ),
0 commit comments