3131import argparse
3232import math
3333import shutil
34+ import subprocess
3435from copy import deepcopy
3536from pathlib import Path
3637from typing import Any , cast
4243from opentau .datasets .compute_stats import compute_episode_stats
4344from opentau .datasets .lerobot_dataset import CODEBASE_VERSION , LeRobotDatasetMetadata
4445from opentau .datasets .utils import (
46+ DEFAULT_IMAGE_PATH ,
4547 EPISODES_PATH ,
46- EPISODES_STATS_PATH ,
4748 TASKS_PATH ,
4849 append_jsonlines ,
4950 write_episode_stats ,
@@ -125,6 +126,127 @@ def _to_numpy_for_stats(column: pa.ChunkedArray) -> np.ndarray:
125126 return np .asarray (column .to_pylist ())
126127
127128
129+ def _trim_video_segment (src_video_path : Path , dst_video_path : Path , start_frame : int , end_frame : int ) -> None :
130+ """Trim a source video to the requested frame interval.
131+
132+ Args:
133+ src_video_path: Source episode video path.
134+ dst_video_path: Output path for the trimmed segment video.
135+ start_frame: Inclusive start frame index.
136+ end_frame: Exclusive end frame index.
137+
138+ Raises:
139+ RuntimeError: If ffmpeg is unavailable or the trim command fails.
140+ """
141+ if shutil .which ("ffmpeg" ) is None :
142+ raise RuntimeError ("ffmpeg is required to trim segmented videos but was not found in PATH." )
143+
144+ # Trim by exact frame indices and reset timeline to start at zero.
145+ vf = f"trim=start_frame={ start_frame } :end_frame={ end_frame } ,setpts=PTS-STARTPTS"
146+ cmd = [
147+ "ffmpeg" ,
148+ "-hide_banner" ,
149+ "-loglevel" ,
150+ "error" ,
151+ "-y" ,
152+ "-i" ,
153+ str (src_video_path ),
154+ "-vf" ,
155+ vf ,
156+ "-an" ,
157+ str (dst_video_path ),
158+ ]
159+ result = subprocess .run (cmd , capture_output = True , text = True )
160+ if result .returncode != 0 :
161+ raise RuntimeError (
162+ f"Failed to trim video segment { start_frame } :{ end_frame } from '{ src_video_path } '. "
163+ f"ffmpeg stderr: { result .stderr .strip ()} "
164+ )
165+
166+
167+ def _copy_segment_images_and_rewrite_column (
168+ image_cells : list [Any ],
169+ input_root : Path ,
170+ output_root : Path ,
171+ image_key : str ,
172+ output_episode_index : int ,
173+ source_episode_index : int ,
174+ source_segment_start : int ,
175+ ) -> list [Any ]:
176+ """Copy image files for a segment and rewrite per-row image references.
177+
178+ Args:
179+ image_cells: Image column values from the sliced source table.
180+ input_root: Source dataset root path.
181+ output_root: Output dataset root path.
182+ image_key: Feature key for this image stream.
183+ output_episode_index: Output episode index receiving this segment.
184+ source_episode_index: Source episode index for image path fallback.
185+ source_segment_start: Start frame index of this segment in source episode.
186+
187+ Returns:
188+ New image column values with updated file paths for copied images.
189+
190+ Raises:
191+ FileNotFoundError: If a referenced source image file does not exist.
192+ """
193+ rewritten_cells : list [Any ] = []
194+ for frame_index , cell in enumerate (image_cells ):
195+ rel_dst = DEFAULT_IMAGE_PATH .format (
196+ image_key = image_key ,
197+ episode_index = output_episode_index ,
198+ frame_index = frame_index ,
199+ )
200+ dst_path = output_root / rel_dst
201+ dst_path .parent .mkdir (parents = True , exist_ok = True )
202+
203+ if isinstance (cell , dict ):
204+ image_bytes = cell .get ("bytes" )
205+ if isinstance (image_bytes , (bytes , bytearray )) and len (image_bytes ) > 0 :
206+ dst_path .write_bytes (bytes (image_bytes ))
207+ new_cell = dict (cell )
208+ new_cell ["path" ] = str (dst_path )
209+ rewritten_cells .append (new_cell )
210+ continue
211+
212+ src_path : Path | None = None
213+ if isinstance (cell , str ):
214+ src_path = Path (cell )
215+ elif isinstance (cell , dict ):
216+ path_val = cell .get ("path" )
217+ if isinstance (path_val , str ) and path_val :
218+ src_path = Path (path_val )
219+
220+ # Embedded-image rows may not require copying when path is empty.
221+ if src_path is None :
222+ rewritten_cells .append (cell )
223+ continue
224+
225+ if not src_path .is_absolute ():
226+ src_path = input_root / src_path
227+ if not src_path .is_file ():
228+ # Fallback to canonical image location under input root.
229+ source_frame_index = source_segment_start + frame_index
230+ src_path = input_root / DEFAULT_IMAGE_PATH .format (
231+ image_key = image_key ,
232+ episode_index = source_episode_index ,
233+ frame_index = source_frame_index ,
234+ )
235+ if not src_path .is_file ():
236+ raise FileNotFoundError (f"Missing source image for key '{ image_key } ': { src_path } " )
237+
238+ shutil .copy2 (src_path , dst_path )
239+
240+ if isinstance (cell , str ):
241+ rewritten_cells .append (str (dst_path ))
242+ else :
243+ new_cell = dict (cell )
244+ new_cell ["path" ] = str (dst_path )
245+ rewritten_cells .append (new_cell )
246+
247+ return rewritten_cells
248+
249+
128250def segment_dataset (
129251 input_root : Path ,
130252 output_root : Path ,
@@ -139,6 +261,12 @@ def segment_dataset(
139261 episode_id: Source episode index to slice.
140262 segments: List of ``(start, end)`` frame ranges in ``[start, end)`` form.
141263
264+ Notes:
265+ For visual features (``dtype`` in ``{"image", "video"}``), per-episode
266+ statistics (``min``, ``max``, ``mean``, ``std``) are inherited from the
267+ source episode statistics and only the ``count`` is updated to the segment
268+ length. They are not recomputed from the segmented visual data.
269+
142270 Raises:
143271 ValueError: If inputs are invalid, source files are missing, or segment
144272 ranges are out of bounds.
@@ -207,6 +335,26 @@ def segment_dataset(
207335 if col_idx >= 0 :
208336 seg_table = seg_table .set_column (col_idx , key , arr )
209337
338+ # For image-based datasets, copy only the segment frames and rewrite image references.
339+ image_keys = [k for k , ft in source_meta .features .items () if ft ["dtype" ] == "image" ]
340+ for image_key in image_keys :
341+ if image_key not in seg_table .column_names :
342+ continue
343+ col_idx = seg_table .schema .get_field_index (image_key )
344+ image_cells = seg_table .column (image_key ).to_pylist ()
345+ rewritten = _copy_segment_images_and_rewrite_column (
346+ image_cells = image_cells ,
347+ input_root = input_root ,
348+ output_root = output_root ,
349+ image_key = image_key ,
350+ output_episode_index = output_episode_index ,
351+ source_episode_index = episode_id ,
352+ source_segment_start = start ,
353+ )
354+ seg_table = seg_table .set_column (
355+ col_idx , image_key , pa .array (rewritten , type = seg_table .schema .field (image_key ).type )
356+ )
357+
210358 episode_chunk = output_episode_index // chunks_size
211359 output_parquet_path = output_root / source_meta .data_path .format (
212360 episode_chunk = episode_chunk ,
@@ -267,15 +415,15 @@ def segment_dataset(
267415 src_video_path = input_root / source_meta .get_video_file_path (episode_id , video_key )
268416 if not src_video_path .is_file ():
269417 raise ValueError (f"Missing source video for key '{ video_key } ': { src_video_path } " )
270- for output_episode_index in range ( len ( segments ) ):
418+ for output_episode_index , ( start , end ) in enumerate ( segments ):
271419 episode_chunk = output_episode_index // chunks_size
272420 dst_video_path = output_root / video_path_template_str .format (
273421 episode_chunk = episode_chunk ,
274422 video_key = video_key ,
275423 episode_index = output_episode_index ,
276424 )
277425 dst_video_path .parent .mkdir (parents = True , exist_ok = True )
278- shutil . copy2 (src_video_path , dst_video_path )
426+ _trim_video_segment (src_video_path , dst_video_path , start , end )
279427
280428 for episode in output_episodes :
281429 append_jsonlines (episode , output_root / EPISODES_PATH )
@@ -290,9 +438,6 @@ def segment_dataset(
290438 info ["splits" ] = {"train" : f"0:{ total_episodes } " }
291439 write_json (info , output_root / "meta" / "info.json" )
292440
293- # Ensure expected meta files exist and are explicit outputs.
294- _ = output_root / EPISODES_STATS_PATH
295-
296441
297442def main () -> None :
298443 """CLI entry point."""
0 commit comments