@@ -192,40 +192,61 @@ def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[st
192192 }
193193
194194
195- def compute_episode_stats (episode_data : dict [str , list [str ] | np .ndarray ], features : dict ) -> dict :
195+ def compute_episode_stats (
196+ episode_data : dict [str , list [str ] | np .ndarray ],
197+ features : dict ,
198+ skip_video_stats : bool = False ,
199+ ) -> dict :
196200 """Compute statistics for a single episode.
197201
198- For image/video features, samples and downsamples images before computing stats.
202+ For image/video features, samples and downsamples images before computing stats
203+ (unless skip_video_stats is True, in which case placeholder stats are used).
199204 For other features, computes stats directly on the array data.
200205
201206 Args:
202207 episode_data: Dictionary mapping feature names to their data (arrays or image paths).
203208 features: Dictionary of feature specifications with 'dtype' keys.
209+ skip_video_stats: If True, do not compute real stats for image/video features;
210+ instead use placeholder stats (min=0, max=1, mean=0.5, std=0.5, count from data)
211+ so the output format remains valid.
204212
205213 Returns:
206214 Dictionary mapping feature names to their statistics (min, max, mean, std, count).
207- Image statistics are normalized to [0, 1] range.
215+ Image statistics are normalized to [0, 1] range (or placeholders when skip_video_stats) .
208216 """
209217 ep_stats = {}
210218 for key , data in episode_data .items ():
211219 if features [key ]["dtype" ] == "string" :
212220 continue # HACK: we should receive np.arrays of strings
213221 elif features [key ]["dtype" ] in ["image" , "video" ]:
214- ep_ft_array = sample_images (data ) # data is a list of image paths
215- axes_to_reduce = (0 , 2 , 3 ) # keep channel dim
216- keepdims = True
222+ if skip_video_stats :
223+ # Placeholder stats: shape (3, 1, 1) for min/max/mean/std, count from length
224+ n_frames = len (data ) if isinstance (data , list ) else data .shape [0 ]
225+ shape = features [key ]["shape" ]
226+ # Expected shape for video is (C, H, W) e.g. (3, H, W)
227+ c = shape [0 ] if len (shape ) >= 3 else 3
228+ ep_stats [key ] = {
229+ "min" : np .zeros ((c , 1 , 1 ), dtype = np .float64 ),
230+ "max" : np .ones ((c , 1 , 1 ), dtype = np .float64 ),
231+ "mean" : np .full ((c , 1 , 1 ), 0.5 , dtype = np .float64 ),
232+ "std" : np .full ((c , 1 , 1 ), 0.5 , dtype = np .float64 ),
233+ "count" : np .array ([n_frames ]),
234+ }
235+ else :
236+ image_paths = data .tolist () if isinstance (data , np .ndarray ) else data
237+ ep_ft_array = sample_images (image_paths ) # image_paths is list[str]
238+ axes_to_reduce = (0 , 2 , 3 ) # keep channel dim
239+ keepdims = True
240+ ep_stats [key ] = get_feature_stats (ep_ft_array , axis = axes_to_reduce , keepdims = keepdims )
241+ # normalize and remove batch dim for images
242+ ep_stats [key ] = {
243+ k : v if k == "count" else np .squeeze (v / 255.0 , axis = 0 ) for k , v in ep_stats [key ].items ()
244+ }
217245 else :
218- ep_ft_array = data # data is already a np.ndarray
219- axes_to_reduce = 0 # compute stats over the first axis
220- keepdims = data .ndim == 1 # keep as np.array
221-
222- ep_stats [key ] = get_feature_stats (ep_ft_array , axis = axes_to_reduce , keepdims = keepdims )
223-
224- # finally, we normalize and remove batch dim for images
225- if features [key ]["dtype" ] in ["image" , "video" ]:
226- ep_stats [key ] = {
227- k : v if k == "count" else np .squeeze (v / 255.0 , axis = 0 ) for k , v in ep_stats [key ].items ()
228- }
246+ ep_ft_array = data if isinstance (data , np .ndarray ) else np .asarray (data )
247+ axes_to_reduce = (0 ,) # compute stats over the first axis
248+ keepdims = ep_ft_array .ndim == 1 # keep as np.array
249+ ep_stats [key ] = get_feature_stats (ep_ft_array , axis = axes_to_reduce , keepdims = keepdims )
229250
230251 return ep_stats
231252
0 commit comments