diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index 454389d46f..9c37bc9ca7 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -94,6 +94,7 @@ def __init__( seed: int = 42, rng: np.random.Generator | None = None, shuffle: bool = True, + data_files: str | None = None, ): """Initialize a StreamingLeRobotDataset. @@ -112,6 +113,8 @@ def __init__( seed (int, optional): Reproducibility random seed. rng (np.random.Generator | None, optional): Random number generator. shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True. + data_files (str | None, optional): Pattern to match data files. If None, auto-detects parquet files. + Defaults to None. """ super().__init__() self.repo_id = repo_id @@ -128,6 +131,7 @@ def __init__( self.streaming = streaming self.buffer_size = buffer_size + self.data_files = data_files # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown) self.video_decoder_cache = None @@ -153,7 +157,7 @@ def __init__( self.repo_id if not self.streaming_from_local else str(self.root), split="train", streaming=self.streaming, - data_files="data/*/*.parquet", + data_files=self.data_files if self.data_files is not None else None, revision=self.revision, )