Skip to content

Commit 76a425c

Browse files
Fix: check_cached_episodes doesn't check whether the requested episode videos were downloaded (#2296)
* In `check_cached_episodes_sufficient`, check whether all the requested video files are downloaded
* Optimize loop over the video paths
* Revert example num_workers
* Apply suggestion from @Copilot
  Co-authored-by: Copilot <[email protected]>
  Signed-off-by: Michel Aractingi <[email protected]>
* Set num_workers to zero in example
* Style nit
* Reintroduce Copilot optimization

---------

Signed-off-by: Michel Aractingi <[email protected]>
Co-authored-by: Copilot <[email protected]>
1 parent df71f3c commit 76a425c

File tree

2 files changed

+25
-16
lines changed

2 files changed

+25
-16
lines changed

examples/dataset/load_lerobot_dataset.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -132,17 +132,15 @@
132132
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
133133
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
134134

135-
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
136-
# PyTorch datasets.
137-
dataloader = torch.utils.data.DataLoader(
138-
dataset,
139-
num_workers=4,
140-
batch_size=32,
141-
shuffle=True,
142-
)
143-
144-
for batch in dataloader:
145-
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
146-
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
147-
print(f"{batch['action'].shape=}") # (32, 64, c)
148-
break
135+
if __name__ == "__main__":
136+
dataloader = torch.utils.data.DataLoader(
137+
dataset,
138+
num_workers=4,
139+
batch_size=32,
140+
shuffle=True,
141+
)
142+
for batch in dataloader:
143+
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
144+
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
145+
print(f"{batch['action'].shape=}") # (32, 64, c)
146+
break

src/lerobot/datasets/lerobot_dataset.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ def load_hf_dataset(self) -> datasets.Dataset:
837837
return hf_dataset
838838

839839
def _check_cached_episodes_sufficient(self) -> bool:
840-
"""Check if the cached dataset contains all requested episodes."""
840+
"""Check if the cached dataset contains all requested episodes and their video files."""
841841
if self.hf_dataset is None or len(self.hf_dataset) == 0:
842842
return False
843843

@@ -856,7 +856,18 @@ def _check_cached_episodes_sufficient(self) -> bool:
856856
requested_episodes = set(self.episodes)
857857

858858
# Check if all requested episodes are available in cached data
859-
return requested_episodes.issubset(available_episodes)
859+
if not requested_episodes.issubset(available_episodes):
860+
return False
861+
862+
# Check if all required video files exist
863+
if len(self.meta.video_keys) > 0:
864+
for ep_idx in requested_episodes:
865+
for vid_key in self.meta.video_keys:
866+
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
867+
if not video_path.exists():
868+
return False
869+
870+
return True
860871

861872
def create_hf_dataset(self) -> datasets.Dataset:
862873
features = get_hf_features_from_features(self.features)

0 commit comments

Comments (0)