@@ -218,15 +218,6 @@ def visualize_dataset(
218218
219219 repo_id = dataset .repo_id
220220
221- logging .info ("Loading dataloader" )
222- episode_sampler = EpisodeSampler (dataset , episode_index )
223- dataloader = torch .utils .data .DataLoader (
224- dataset ,
225- num_workers = num_workers ,
226- batch_size = batch_size ,
227- sampler = episode_sampler ,
228- )
229-
230221 logging .info ("Starting Rerun" )
231222
232223 if mode not in ["local" , "distant" ]:
@@ -299,52 +290,133 @@ def visualize_dataset(
299290 "Failed to log AssetVideo for %s (%s). Falling back to frame logging." , key , video_path
300291 )
301292
293+ # Fast path: when every camera stream is logged as AssetVideo, avoid dataset.__getitem__,
294+ # which would decode video frames for each sample.
295+ can_skip_decode = len (dataset .meta .camera_keys ) == len (video_asset_keys )
296+
297+ logging .info ("Loading iteration source" )
298+ row_indices : list [int ] | None = None
299+ no_transform_ds = None
300+ dataloader = None
301+ has_action = False
302+ has_observation_state = False
303+ has_next_done = False
304+ has_next_reward = False
305+ has_next_success = False
306+ if can_skip_decode :
307+ logging .info ("Using metadata-only iteration path (no frame decoding)." )
308+ epi_idx = dataset .epi2idx [episode_index ]
309+ from_idx = int (dataset .episode_data_index ["from" ][epi_idx ].item ())
310+ to_idx = int (dataset .episode_data_index ["to" ][epi_idx ].item ())
311+ row_indices = list (range (from_idx , to_idx ))
312+ no_transform_ds = dataset .hf_dataset .with_transform (None ).with_format ("numpy" )
313+ no_transform_columns = set (no_transform_ds .column_names )
314+ has_action = "action" in no_transform_columns
315+ has_observation_state = "observation.state" in no_transform_columns
316+ has_next_done = "next.done" in no_transform_columns
317+ has_next_reward = "next.reward" in no_transform_columns
318+ has_next_success = "next.success" in no_transform_columns
319+ else :
320+ logging .info ("Loading dataloader" )
321+ episode_sampler = EpisodeSampler (dataset , episode_index )
322+ dataloader = torch .utils .data .DataLoader (
323+ dataset ,
324+ num_workers = num_workers ,
325+ batch_size = batch_size ,
326+ sampler = episode_sampler ,
327+ )
328+
302329 logging .info ("Logging to Rerun" )
303330 episode_start_ts : float | None = None
304331
305- for batch in tqdm .tqdm (dataloader , total = len (dataloader )):
306- # iterate over the batch
307- for i in range (len (batch ["index" ])):
308- frame_index = batch ["frame_index" ][i ].item ()
309- timestamp_s = batch ["timestamp" ][i ].item ()
310- _rr_set_sequence ("frame_index" , frame_index )
311- _rr_set_seconds ("timestamp" , timestamp_s )
312- if episode_start_ts is None :
313- episode_start_ts = timestamp_s
314- episode_video_t = max (0.0 , timestamp_s - episode_start_ts )
315-
316- # display each camera image
317- for key in dataset .meta .camera_keys :
318- if key in video_asset_keys :
319- rr .log (key , rr .VideoFrameReference (seconds = episode_video_t , video_reference = key ))
320- else :
321- # TODO(rcadene): add `.compress()`? is it lossless?
322- rr .log (key , rr .Image (to_hwc_uint8_numpy (batch [key ][i ])))
323-
324- # display each dimension of action space (e.g. actuators command)
325- if "action" in batch :
326- for dim_idx , val in enumerate (batch ["action" ][i ]):
327- rr .log (f"action/{ dim_idx } " , _rr_scalar (val .item ()))
328-
329- # display each dimension of observed state space (e.g. agent position in joint space)
330- if "observation.state" in batch :
331- states = batch ["observation.state" ][i ]
332- for dim_idx , val in enumerate (states ):
333- jnt_name = joint_names [dim_idx ] if dim_idx < len (joint_names ) else str (dim_idx )
334- rr .log (f"state/{ jnt_name } " , _rr_scalar (val .item ()))
335- if jnt_name in urdf_joints :
336- joint = urdf_joints [jnt_name ]
337- transform = joint .compute_transform (float (val ))
338- rr .log ("URDF" , transform )
339-
340- if "next.done" in batch :
341- rr .log ("next.done" , _rr_scalar (batch ["next.done" ][i ].item ()))
342-
343- if "next.reward" in batch :
344- rr .log ("next.reward" , _rr_scalar (batch ["next.reward" ][i ].item ()))
345-
346- if "next.success" in batch :
347- rr .log ("next.success" , _rr_scalar (batch ["next.success" ][i ].item ()))
332+ if can_skip_decode :
333+ assert row_indices is not None
334+ assert no_transform_ds is not None
335+ total_batches = max (1 , (len (row_indices ) + batch_size - 1 ) // batch_size )
336+ for start in tqdm .tqdm (range (0 , len (row_indices ), batch_size ), total = total_batches ):
337+ batch_indices = row_indices [start : start + batch_size ]
338+ batch = no_transform_ds .select (batch_indices )
339+
340+ for i in range (len (batch ["index" ])):
341+ frame_index = int (np .asarray (batch ["frame_index" ][i ]).item ())
342+ timestamp_s = float (np .asarray (batch ["timestamp" ][i ]).item ())
343+ _rr_set_sequence ("frame_index" , frame_index )
344+ _rr_set_seconds ("timestamp" , timestamp_s )
345+ if episode_start_ts is None :
346+ episode_start_ts = timestamp_s
347+ episode_video_t = max (0.0 , timestamp_s - episode_start_ts )
348+
349+ for key in dataset .meta .camera_keys :
350+ if key in video_asset_keys :
351+ rr .log (key , rr .VideoFrameReference (seconds = episode_video_t , video_reference = key ))
352+
353+ if has_action :
354+ for dim_idx , val in enumerate (np .asarray (batch ["action" ][i ]).reshape (- 1 )):
355+ rr .log (f"action/{ dim_idx } " , _rr_scalar (float (val )))
356+
357+ if has_observation_state :
358+ states = np .asarray (batch ["observation.state" ][i ]).reshape (- 1 )
359+ for dim_idx , val in enumerate (states ):
360+ jnt_name = joint_names [dim_idx ] if dim_idx < len (joint_names ) else str (dim_idx )
361+ rr .log (f"state/{ jnt_name } " , _rr_scalar (float (val )))
362+ if jnt_name in urdf_joints :
363+ joint = urdf_joints [jnt_name ]
364+ transform = joint .compute_transform (float (val ))
365+ rr .log ("URDF" , transform )
366+
367+ if has_next_done :
368+ rr .log ("next.done" , _rr_scalar (float (np .asarray (batch ["next.done" ][i ]).item ())))
369+
370+ if has_next_reward :
371+ rr .log ("next.reward" , _rr_scalar (float (np .asarray (batch ["next.reward" ][i ]).item ())))
372+
373+ if has_next_success :
374+ rr .log ("next.success" , _rr_scalar (float (np .asarray (batch ["next.success" ][i ]).item ())))
375+ else :
376+ assert dataloader is not None
377+ for batch in tqdm .tqdm (dataloader , total = len (dataloader )):
378+ # iterate over the batch
379+ for i in range (len (batch ["index" ])):
380+ frame_index = batch ["frame_index" ][i ].item ()
381+ timestamp_s = batch ["timestamp" ][i ].item ()
382+ _rr_set_sequence ("frame_index" , frame_index )
383+ _rr_set_seconds ("timestamp" , timestamp_s )
384+ if episode_start_ts is None :
385+ episode_start_ts = timestamp_s
386+ episode_video_t = max (0.0 , timestamp_s - episode_start_ts )
387+
388+ # display each camera image
389+ for key in dataset .meta .camera_keys :
390+ if key in video_asset_keys :
391+ rr .log (key , rr .VideoFrameReference (seconds = episode_video_t , video_reference = key ))
392+ else :
393+ # TODO(rcadene): add `.compress()`? is it lossless?
394+ rr .log (key , rr .Image (to_hwc_uint8_numpy (batch [key ][i ])))
395+
396+ # display each dimension of action space (e.g. actuators command)
397+ if "action" in batch :
398+ for dim_idx , val in enumerate (batch ["action" ][i ]):
399+ rr .log (f"action/{ dim_idx } " , _rr_scalar (val .item ()))
400+
401+ # display each dimension of observed state space (e.g. agent position in joint space)
402+ if "observation.state" in batch :
403+ states = batch ["observation.state" ][i ]
404+ for dim_idx , val in enumerate (states ):
405+ jnt_name = joint_names [dim_idx ] if dim_idx < len (joint_names ) else str (dim_idx )
406+ rr .log (f"state/{ jnt_name } " , _rr_scalar (val .item ()))
407+ if jnt_name in urdf_joints :
408+ joint = urdf_joints [jnt_name ]
409+ transform = joint .compute_transform (float (val ))
410+ rr .log ("URDF" , transform )
411+
412+ if "next.done" in batch :
413+ rr .log ("next.done" , _rr_scalar (batch ["next.done" ][i ].item ()))
414+
415+ if "next.reward" in batch :
416+ rr .log ("next.reward" , _rr_scalar (batch ["next.reward" ][i ].item ()))
417+
418+ if "next.success" in batch :
419+ rr .log ("next.success" , _rr_scalar (batch ["next.success" ][i ].item ()))
348420
349421 if mode == "local" and save :
350422 # save .rrd locally
@@ -445,12 +517,14 @@ def parse_args() -> dict:
445517 )
446518 parser .add_argument (
447519 "--tolerance-s" ,
520+ "--tolerance" ,
521+ dest = "tolerance_s" ,
448522 type = float ,
449523 default = 1e-4 ,
450524 help = (
451525 "Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
452526 "This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
453- "If not given, defaults to 1e-4."
527+ "If not given, defaults to 1e-4. `--tolerance` is kept as an alias. "
454528 ),
455529 )
456530 parser .add_argument (
@@ -499,13 +573,43 @@ def main():
499573 kwargs ["urdf" ] = None
500574
501575 logging .info ("Loading dataset" )
502- dataset = LeRobotDataset (
503- create_mock_train_config (),
504- repo_id ,
505- root = root ,
506- tolerance_s = tolerance_s ,
507- standardize = False ,
508- )
576+ tolerance_schedule = [tolerance_s ]
577+ for candidate in [1e-3 , 3e-3 , 1e-2 ]:
578+ if candidate > tolerance_schedule [- 1 ]:
579+ tolerance_schedule .append (candidate )
580+
581+ dataset = None
582+ last_timestamp_error = None
583+ for tol in tolerance_schedule :
584+ try :
585+ dataset = LeRobotDataset (
586+ create_mock_train_config (),
587+ repo_id ,
588+ root = root ,
589+ tolerance_s = tol ,
590+ standardize = False ,
591+ )
592+ if tol != tolerance_s :
593+ logging .warning (
594+ "Dataset timestamp check required relaxed tolerance. "
595+ "Requested=%s, using=%s for visualization." ,
596+ tolerance_s ,
597+ tol ,
598+ )
599+ break
600+ except ValueError as e :
601+ # Visualization should be resilient to small timestamp quantization jitter.
602+ if "timestamps unexpectedly violate the tolerance" not in str (e ):
603+ raise
604+ last_timestamp_error = e
605+ logging .warning (
606+ "Timestamp sync check failed with tolerance_s=%s. Retrying with a looser tolerance." ,
607+ tol ,
608+ )
609+
610+ if dataset is None :
611+ assert last_timestamp_error is not None
612+ raise last_timestamp_error
509613
510614 visualize_dataset (dataset , ** kwargs )
511615