Skip to content

Commit 360268a

Browse files
committed
enhance(viz): skip frame decoding when using mp4 asset as-is
1 parent aa66e11 commit 360268a

File tree

1 file changed

+164
-60
lines changed

1 file changed

+164
-60
lines changed

src/opentau/scripts/visualize_dataset.py

Lines changed: 164 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -218,15 +218,6 @@ def visualize_dataset(
218218

219219
repo_id = dataset.repo_id
220220

221-
logging.info("Loading dataloader")
222-
episode_sampler = EpisodeSampler(dataset, episode_index)
223-
dataloader = torch.utils.data.DataLoader(
224-
dataset,
225-
num_workers=num_workers,
226-
batch_size=batch_size,
227-
sampler=episode_sampler,
228-
)
229-
230221
logging.info("Starting Rerun")
231222

232223
if mode not in ["local", "distant"]:
@@ -299,52 +290,133 @@ def visualize_dataset(
299290
"Failed to log AssetVideo for %s (%s). Falling back to frame logging.", key, video_path
300291
)
301292

293+
# Fast path: when every camera stream is logged as AssetVideo, avoid dataset.__getitem__,
294+
# which would decode video frames for each sample.
295+
can_skip_decode = len(dataset.meta.camera_keys) == len(video_asset_keys)
296+
297+
logging.info("Loading iteration source")
298+
row_indices: list[int] | None = None
299+
no_transform_ds = None
300+
dataloader = None
301+
has_action = False
302+
has_observation_state = False
303+
has_next_done = False
304+
has_next_reward = False
305+
has_next_success = False
306+
if can_skip_decode:
307+
logging.info("Using metadata-only iteration path (no frame decoding).")
308+
epi_idx = dataset.epi2idx[episode_index]
309+
from_idx = int(dataset.episode_data_index["from"][epi_idx].item())
310+
to_idx = int(dataset.episode_data_index["to"][epi_idx].item())
311+
row_indices = list(range(from_idx, to_idx))
312+
no_transform_ds = dataset.hf_dataset.with_transform(None).with_format("numpy")
313+
no_transform_columns = set(no_transform_ds.column_names)
314+
has_action = "action" in no_transform_columns
315+
has_observation_state = "observation.state" in no_transform_columns
316+
has_next_done = "next.done" in no_transform_columns
317+
has_next_reward = "next.reward" in no_transform_columns
318+
has_next_success = "next.success" in no_transform_columns
319+
else:
320+
logging.info("Loading dataloader")
321+
episode_sampler = EpisodeSampler(dataset, episode_index)
322+
dataloader = torch.utils.data.DataLoader(
323+
dataset,
324+
num_workers=num_workers,
325+
batch_size=batch_size,
326+
sampler=episode_sampler,
327+
)
328+
302329
logging.info("Logging to Rerun")
303330
episode_start_ts: float | None = None
304331

305-
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
306-
# iterate over the batch
307-
for i in range(len(batch["index"])):
308-
frame_index = batch["frame_index"][i].item()
309-
timestamp_s = batch["timestamp"][i].item()
310-
_rr_set_sequence("frame_index", frame_index)
311-
_rr_set_seconds("timestamp", timestamp_s)
312-
if episode_start_ts is None:
313-
episode_start_ts = timestamp_s
314-
episode_video_t = max(0.0, timestamp_s - episode_start_ts)
315-
316-
# display each camera image
317-
for key in dataset.meta.camera_keys:
318-
if key in video_asset_keys:
319-
rr.log(key, rr.VideoFrameReference(seconds=episode_video_t, video_reference=key))
320-
else:
321-
# TODO(rcadene): add `.compress()`? is it lossless?
322-
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
323-
324-
# display each dimension of action space (e.g. actuators command)
325-
if "action" in batch:
326-
for dim_idx, val in enumerate(batch["action"][i]):
327-
rr.log(f"action/{dim_idx}", _rr_scalar(val.item()))
328-
329-
# display each dimension of observed state space (e.g. agent position in joint space)
330-
if "observation.state" in batch:
331-
states = batch["observation.state"][i]
332-
for dim_idx, val in enumerate(states):
333-
jnt_name = joint_names[dim_idx] if dim_idx < len(joint_names) else str(dim_idx)
334-
rr.log(f"state/{jnt_name}", _rr_scalar(val.item()))
335-
if jnt_name in urdf_joints:
336-
joint = urdf_joints[jnt_name]
337-
transform = joint.compute_transform(float(val))
338-
rr.log("URDF", transform)
339-
340-
if "next.done" in batch:
341-
rr.log("next.done", _rr_scalar(batch["next.done"][i].item()))
342-
343-
if "next.reward" in batch:
344-
rr.log("next.reward", _rr_scalar(batch["next.reward"][i].item()))
345-
346-
if "next.success" in batch:
347-
rr.log("next.success", _rr_scalar(batch["next.success"][i].item()))
332+
if can_skip_decode:
333+
assert row_indices is not None
334+
assert no_transform_ds is not None
335+
total_batches = max(1, (len(row_indices) + batch_size - 1) // batch_size)
336+
for start in tqdm.tqdm(range(0, len(row_indices), batch_size), total=total_batches):
337+
batch_indices = row_indices[start : start + batch_size]
338+
batch = no_transform_ds.select(batch_indices)
339+
340+
for i in range(len(batch["index"])):
341+
frame_index = int(np.asarray(batch["frame_index"][i]).item())
342+
timestamp_s = float(np.asarray(batch["timestamp"][i]).item())
343+
_rr_set_sequence("frame_index", frame_index)
344+
_rr_set_seconds("timestamp", timestamp_s)
345+
if episode_start_ts is None:
346+
episode_start_ts = timestamp_s
347+
episode_video_t = max(0.0, timestamp_s - episode_start_ts)
348+
349+
for key in dataset.meta.camera_keys:
350+
if key in video_asset_keys:
351+
rr.log(key, rr.VideoFrameReference(seconds=episode_video_t, video_reference=key))
352+
353+
if has_action:
354+
for dim_idx, val in enumerate(np.asarray(batch["action"][i]).reshape(-1)):
355+
rr.log(f"action/{dim_idx}", _rr_scalar(float(val)))
356+
357+
if has_observation_state:
358+
states = np.asarray(batch["observation.state"][i]).reshape(-1)
359+
for dim_idx, val in enumerate(states):
360+
jnt_name = joint_names[dim_idx] if dim_idx < len(joint_names) else str(dim_idx)
361+
rr.log(f"state/{jnt_name}", _rr_scalar(float(val)))
362+
if jnt_name in urdf_joints:
363+
joint = urdf_joints[jnt_name]
364+
transform = joint.compute_transform(float(val))
365+
rr.log("URDF", transform)
366+
367+
if has_next_done:
368+
rr.log("next.done", _rr_scalar(float(np.asarray(batch["next.done"][i]).item())))
369+
370+
if has_next_reward:
371+
rr.log("next.reward", _rr_scalar(float(np.asarray(batch["next.reward"][i]).item())))
372+
373+
if has_next_success:
374+
rr.log("next.success", _rr_scalar(float(np.asarray(batch["next.success"][i]).item())))
375+
else:
376+
assert dataloader is not None
377+
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
378+
# iterate over the batch
379+
for i in range(len(batch["index"])):
380+
frame_index = batch["frame_index"][i].item()
381+
timestamp_s = batch["timestamp"][i].item()
382+
_rr_set_sequence("frame_index", frame_index)
383+
_rr_set_seconds("timestamp", timestamp_s)
384+
if episode_start_ts is None:
385+
episode_start_ts = timestamp_s
386+
episode_video_t = max(0.0, timestamp_s - episode_start_ts)
387+
388+
# display each camera image
389+
for key in dataset.meta.camera_keys:
390+
if key in video_asset_keys:
391+
rr.log(key, rr.VideoFrameReference(seconds=episode_video_t, video_reference=key))
392+
else:
393+
# TODO(rcadene): add `.compress()`? is it lossless?
394+
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
395+
396+
# display each dimension of action space (e.g. actuators command)
397+
if "action" in batch:
398+
for dim_idx, val in enumerate(batch["action"][i]):
399+
rr.log(f"action/{dim_idx}", _rr_scalar(val.item()))
400+
401+
# display each dimension of observed state space (e.g. agent position in joint space)
402+
if "observation.state" in batch:
403+
states = batch["observation.state"][i]
404+
for dim_idx, val in enumerate(states):
405+
jnt_name = joint_names[dim_idx] if dim_idx < len(joint_names) else str(dim_idx)
406+
rr.log(f"state/{jnt_name}", _rr_scalar(val.item()))
407+
if jnt_name in urdf_joints:
408+
joint = urdf_joints[jnt_name]
409+
transform = joint.compute_transform(float(val))
410+
rr.log("URDF", transform)
411+
412+
if "next.done" in batch:
413+
rr.log("next.done", _rr_scalar(batch["next.done"][i].item()))
414+
415+
if "next.reward" in batch:
416+
rr.log("next.reward", _rr_scalar(batch["next.reward"][i].item()))
417+
418+
if "next.success" in batch:
419+
rr.log("next.success", _rr_scalar(batch["next.success"][i].item()))
348420

349421
if mode == "local" and save:
350422
# save .rrd locally
@@ -445,12 +517,14 @@ def parse_args() -> dict:
445517
)
446518
parser.add_argument(
447519
"--tolerance-s",
520+
"--tolerance",
521+
dest="tolerance_s",
448522
type=float,
449523
default=1e-4,
450524
help=(
451525
"Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
452526
"This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
453-
"If not given, defaults to 1e-4."
527+
"If not given, defaults to 1e-4. `--tolerance` is kept as an alias."
454528
),
455529
)
456530
parser.add_argument(
@@ -499,13 +573,43 @@ def main():
499573
kwargs["urdf"] = None
500574

501575
logging.info("Loading dataset")
502-
dataset = LeRobotDataset(
503-
create_mock_train_config(),
504-
repo_id,
505-
root=root,
506-
tolerance_s=tolerance_s,
507-
standardize=False,
508-
)
576+
tolerance_schedule = [tolerance_s]
577+
for candidate in [1e-3, 3e-3, 1e-2]:
578+
if candidate > tolerance_schedule[-1]:
579+
tolerance_schedule.append(candidate)
580+
581+
dataset = None
582+
last_timestamp_error = None
583+
for tol in tolerance_schedule:
584+
try:
585+
dataset = LeRobotDataset(
586+
create_mock_train_config(),
587+
repo_id,
588+
root=root,
589+
tolerance_s=tol,
590+
standardize=False,
591+
)
592+
if tol != tolerance_s:
593+
logging.warning(
594+
"Dataset timestamp check required relaxed tolerance. "
595+
"Requested=%s, using=%s for visualization.",
596+
tolerance_s,
597+
tol,
598+
)
599+
break
600+
except ValueError as e:
601+
# Visualization should be resilient to small timestamp quantization jitter.
602+
if "timestamps unexpectedly violate the tolerance" not in str(e):
603+
raise
604+
last_timestamp_error = e
605+
logging.warning(
606+
"Timestamp sync check failed with tolerance_s=%s. Retrying with a looser tolerance.",
607+
tol,
608+
)
609+
610+
if dataset is None:
611+
assert last_timestamp_error is not None
612+
raise last_timestamp_error
509613

510614
visualize_dataset(dataset, **kwargs)
511615

0 commit comments

Comments (0)