Skip to content

Commit 6933b5d

Browse files
committed
fix: properly segment videos and images
1 parent b05378c commit 6933b5d

File tree

2 files changed

+258
-6
lines changed

2 files changed

+258
-6
lines changed

src/opentau/scripts/segment_lerobot_dataset.py

Lines changed: 151 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import argparse
3232
import math
3333
import shutil
34+
import subprocess
3435
from copy import deepcopy
3536
from pathlib import Path
3637
from typing import Any, cast
@@ -42,8 +43,8 @@
4243
from opentau.datasets.compute_stats import compute_episode_stats
4344
from opentau.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
4445
from opentau.datasets.utils import (
46+
DEFAULT_IMAGE_PATH,
4547
EPISODES_PATH,
46-
EPISODES_STATS_PATH,
4748
TASKS_PATH,
4849
append_jsonlines,
4950
write_episode_stats,
@@ -125,6 +126,127 @@ def _to_numpy_for_stats(column: pa.ChunkedArray) -> np.ndarray:
125126
return np.asarray(column.to_pylist())
126127

127128

129+
def _trim_video_segment(src_video_path: Path, dst_video_path: Path, start_frame: int, end_frame: int) -> None:
130+
"""Trim a source video to the requested frame interval.
131+
132+
Args:
133+
src_video_path: Source episode video path.
134+
dst_video_path: Output path for the trimmed segment video.
135+
start_frame: Inclusive start frame index.
136+
end_frame: Exclusive end frame index.
137+
138+
Raises:
139+
RuntimeError: If ffmpeg is unavailable or the trim command fails.
140+
"""
141+
if shutil.which("ffmpeg") is None:
142+
raise RuntimeError("ffmpeg is required to trim segmented videos but was not found in PATH.")
143+
144+
# Trim by exact frame indices and reset timeline to start at zero.
145+
vf = f"trim=start_frame={start_frame}:end_frame={end_frame},setpts=PTS-STARTPTS"
146+
cmd = [
147+
"ffmpeg",
148+
"-hide_banner",
149+
"-loglevel",
150+
"error",
151+
"-y",
152+
"-i",
153+
str(src_video_path),
154+
"-vf",
155+
vf,
156+
"-an",
157+
str(dst_video_path),
158+
]
159+
result = subprocess.run(cmd, capture_output=True, text=True)
160+
if result.returncode != 0:
161+
raise RuntimeError(
162+
f"Failed to trim video segment {start_frame}:{end_frame} from '{src_video_path}'. "
163+
f"ffmpeg stderr: {result.stderr.strip()}"
164+
)
165+
166+
167+
def _copy_segment_images_and_rewrite_column(
168+
image_cells: list[Any],
169+
input_root: Path,
170+
output_root: Path,
171+
image_key: str,
172+
output_episode_index: int,
173+
source_episode_index: int,
174+
source_segment_start: int,
175+
) -> list[Any]:
176+
"""Copy image files for a segment and rewrite per-row image references.
177+
178+
Args:
179+
image_cells: Image column values from the sliced source table.
180+
input_root: Source dataset root path.
181+
output_root: Output dataset root path.
182+
image_key: Feature key for this image stream.
183+
output_episode_index: Output episode index receiving this segment.
184+
source_episode_index: Source episode index for image path fallback.
185+
source_segment_start: Start frame index of this segment in source episode.
186+
187+
Returns:
188+
New image column values with updated file paths for copied images.
189+
190+
Raises:
191+
FileNotFoundError: If a referenced source image file does not exist.
192+
"""
193+
rewritten_cells: list[Any] = []
194+
for frame_index, cell in enumerate(image_cells):
195+
rel_dst = DEFAULT_IMAGE_PATH.format(
196+
image_key=image_key,
197+
episode_index=output_episode_index,
198+
frame_index=frame_index,
199+
)
200+
dst_path = output_root / rel_dst
201+
dst_path.parent.mkdir(parents=True, exist_ok=True)
202+
203+
if isinstance(cell, dict):
204+
image_bytes = cell.get("bytes")
205+
if isinstance(image_bytes, (bytes, bytearray)) and len(image_bytes) > 0:
206+
dst_path.write_bytes(bytes(image_bytes))
207+
new_cell = dict(cell)
208+
new_cell["path"] = str(dst_path)
209+
rewritten_cells.append(new_cell)
210+
continue
211+
212+
src_path: Path | None = None
213+
if isinstance(cell, str):
214+
src_path = Path(cell)
215+
elif isinstance(cell, dict):
216+
path_val = cell.get("path")
217+
if isinstance(path_val, str) and path_val:
218+
src_path = Path(path_val)
219+
220+
# Embedded-image rows may not require copying when path is empty.
221+
if src_path is None:
222+
rewritten_cells.append(cell)
223+
continue
224+
225+
if not src_path.is_absolute():
226+
src_path = input_root / src_path
227+
if not src_path.is_file():
228+
# Fallback to canonical image location under input root.
229+
source_frame_index = source_segment_start + frame_index
230+
src_path = input_root / DEFAULT_IMAGE_PATH.format(
231+
image_key=image_key,
232+
episode_index=source_episode_index,
233+
frame_index=source_frame_index,
234+
)
235+
if not src_path.is_file():
236+
raise FileNotFoundError(f"Missing source image for key '{image_key}': {src_path}")
237+
238+
shutil.copy2(src_path, dst_path)
239+
240+
if isinstance(cell, str):
241+
rewritten_cells.append(str(dst_path))
242+
else:
243+
new_cell = dict(cell)
244+
new_cell["path"] = str(dst_path)
245+
rewritten_cells.append(new_cell)
246+
247+
return rewritten_cells
248+
249+
128250
def segment_dataset(
129251
input_root: Path,
130252
output_root: Path,
@@ -139,6 +261,12 @@ def segment_dataset(
139261
episode_id: Source episode index to slice.
140262
segments: List of ``(start, end)`` frame ranges in ``[start, end)`` form.
141263
264+
Notes:
265+
For visual features (``dtype`` in ``{"image", "video"}``), per-episode
266+
statistics (``min``, ``max``, ``mean``, ``std``) are inherited from the
267+
source episode statistics and only the ``count`` is updated to the segment
268+
length. They are not recomputed from the segmented visual data.
269+
142270
Raises:
143271
ValueError: If inputs are invalid, source files are missing, or segment
144272
ranges are out of bounds.
@@ -207,6 +335,26 @@ def segment_dataset(
207335
if col_idx >= 0:
208336
seg_table = seg_table.set_column(col_idx, key, arr)
209337

338+
# For image-based datasets, copy only the segment frames and rewrite image references.
339+
image_keys = [k for k, ft in source_meta.features.items() if ft["dtype"] == "image"]
340+
for image_key in image_keys:
341+
if image_key not in seg_table.column_names:
342+
continue
343+
col_idx = seg_table.schema.get_field_index(image_key)
344+
image_cells = seg_table.column(image_key).to_pylist()
345+
rewritten = _copy_segment_images_and_rewrite_column(
346+
image_cells=image_cells,
347+
input_root=input_root,
348+
output_root=output_root,
349+
image_key=image_key,
350+
output_episode_index=output_episode_index,
351+
source_episode_index=episode_id,
352+
source_segment_start=start,
353+
)
354+
seg_table = seg_table.set_column(
355+
col_idx, image_key, pa.array(rewritten, type=seg_table.schema.field(image_key).type)
356+
)
357+
210358
episode_chunk = output_episode_index // chunks_size
211359
output_parquet_path = output_root / source_meta.data_path.format(
212360
episode_chunk=episode_chunk,
@@ -267,15 +415,15 @@ def segment_dataset(
267415
src_video_path = input_root / source_meta.get_video_file_path(episode_id, video_key)
268416
if not src_video_path.is_file():
269417
raise ValueError(f"Missing source video for key '{video_key}': {src_video_path}")
270-
for output_episode_index in range(len(segments)):
418+
for output_episode_index, (start, end) in enumerate(segments):
271419
episode_chunk = output_episode_index // chunks_size
272420
dst_video_path = output_root / video_path_template_str.format(
273421
episode_chunk=episode_chunk,
274422
video_key=video_key,
275423
episode_index=output_episode_index,
276424
)
277425
dst_video_path.parent.mkdir(parents=True, exist_ok=True)
278-
shutil.copy2(src_video_path, dst_video_path)
426+
_trim_video_segment(src_video_path, dst_video_path, start, end)
279427

280428
for episode in output_episodes:
281429
append_jsonlines(episode, output_root / EPISODES_PATH)
@@ -290,9 +438,6 @@ def segment_dataset(
290438
info["splits"] = {"train": f"0:{total_episodes}"}
291439
write_json(info, output_root / "meta" / "info.json")
292440

293-
# Ensure expected meta files exist and are explicit outputs.
294-
_ = output_root / EPISODES_STATS_PATH
295-
296441

297442
def main() -> None:
298443
"""CLI entry point."""

tests/datasets/test_segment_lerobot_dataset.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from pathlib import Path
1818
from typing import Any
19+
from unittest.mock import patch
1920

2021
import numpy as np
2122
import pyarrow.parquet as pq
@@ -25,6 +26,24 @@
2526
from opentau.scripts.segment_lerobot_dataset import segment_dataset
2627

2728

29+
def _extract_image_path(cell: Any) -> str | None:
30+
"""Extract image path from a parquet image cell.
31+
32+
Args:
33+
cell: A parquet image value (string path or dict with `path`/`bytes`).
34+
35+
Returns:
36+
Image path string if present, otherwise None.
37+
"""
38+
if isinstance(cell, str):
39+
return cell
40+
if isinstance(cell, dict):
41+
path = cell.get("path")
42+
if isinstance(path, str) and path:
43+
return path
44+
return None
45+
46+
2847
def test_segment_lerobot_v21_dataset(tmp_path: Path, empty_lerobot_dataset_factory: Any) -> None:
2948
"""Validate baseline segmentation behavior for v2.1 input.
3049
@@ -207,3 +226,91 @@ def test_segment_lerobot_non_consecutive_and_overlapping_ranges(
207226
assert [float(x) for x in ep0["state"]] == [float(i) for i in range(0, 10)]
208227
assert [float(x) for x in ep1["state"]] == [float(i) for i in range(18, 23)]
209228
assert [float(x) for x in ep2["state"]] == [float(i) for i in range(5, 15)]
229+
230+
231+
def test_segment_lerobot_copies_image_files_for_segments(
232+
tmp_path: Path, empty_lerobot_dataset_factory: Any
233+
) -> None:
234+
"""Ensure segmented datasets copy and rewrite image file references.
235+
236+
Args:
237+
tmp_path: Temporary directory fixture provided by pytest.
238+
empty_lerobot_dataset_factory: Fixture that creates a writable dataset.
239+
"""
240+
input_root = tmp_path / "source_image_dataset"
241+
output_root = tmp_path / "segmented_image_dataset"
242+
image_key = "observation.images.camera"
243+
244+
features = {
245+
"state": {"dtype": "float32", "shape": (1,), "names": None},
246+
"actions": {"dtype": "float32", "shape": (1,), "names": None},
247+
image_key: {"dtype": "image", "shape": (3, 8, 8), "names": ["channel", "height", "width"]},
248+
}
249+
dataset = empty_lerobot_dataset_factory(root=input_root, features=features, use_videos=False)
250+
for i in range(8):
251+
dataset.add_frame(
252+
{
253+
"state": np.array([float(i)], dtype=np.float32),
254+
"actions": np.array([float(i) + 1.0], dtype=np.float32),
255+
"observation.images.camera": np.full((8, 8, 3), i / 8.0, dtype=np.float32),
256+
"task": "image task",
257+
}
258+
)
259+
dataset.save_episode()
260+
261+
segment_dataset(
262+
input_root=input_root,
263+
output_root=output_root,
264+
episode_id=0,
265+
segments=[(1, 4), (4, 8)],
266+
)
267+
268+
out_meta = LeRobotDatasetMetadata(repo_id=output_root.name, root=output_root)
269+
ep0 = pq.read_table(output_root / out_meta.get_data_file_path(0)).to_pydict()
270+
ep1 = pq.read_table(output_root / out_meta.get_data_file_path(1)).to_pydict()
271+
272+
ep0_paths = [_extract_image_path(cell) for cell in ep0[image_key]]
273+
ep1_paths = [_extract_image_path(cell) for cell in ep1[image_key]]
274+
assert all(path is not None for path in ep0_paths)
275+
assert all(path is not None for path in ep1_paths)
276+
277+
for frame_idx, path in enumerate(ep0_paths):
278+
assert path is not None
279+
expected = output_root / f"images/{image_key}/episode_000000/frame_{frame_idx:06d}.png"
280+
assert Path(path) == expected
281+
assert expected.is_file()
282+
283+
for frame_idx, path in enumerate(ep1_paths):
284+
assert path is not None
285+
expected = output_root / f"images/{image_key}/episode_000001/frame_{frame_idx:06d}.png"
286+
assert Path(path) == expected
287+
assert expected.is_file()
288+
289+
290+
def test_trim_video_segment_uses_frame_range_filter(tmp_path: Path) -> None:
291+
"""Ensure ffmpeg trim command uses frame-range segmentation.
292+
293+
Args:
294+
tmp_path: Temporary directory fixture provided by pytest.
295+
"""
296+
src = tmp_path / "src.mp4"
297+
dst = tmp_path / "dst.mp4"
298+
src.write_bytes(b"fake")
299+
300+
with (
301+
patch("opentau.scripts.segment_lerobot_dataset.shutil.which", return_value="/usr/bin/ffmpeg"),
302+
patch("opentau.scripts.segment_lerobot_dataset.subprocess.run") as run_mock,
303+
):
304+
run_mock.return_value.returncode = 0
305+
run_mock.return_value.stderr = ""
306+
307+
from opentau.scripts.segment_lerobot_dataset import _trim_video_segment
308+
309+
_trim_video_segment(src, dst, 5, 15)
310+
311+
assert run_mock.call_count == 1
312+
cmd = run_mock.call_args.args[0]
313+
assert "ffmpeg" in cmd[0]
314+
assert "-vf" in cmd
315+
vf_expr = cmd[cmd.index("-vf") + 1]
316+
assert "trim=start_frame=5:end_frame=15" in vf_expr

0 commit comments

Comments
 (0)