
Commit 12f2f35

- Introduce `_current_file_start_frame` to track the number of frames in the current parquet file more reliably (#2280)
- Add a regression test covering this logic in `test_datasets.py`
1 parent a024d33 · commit 12f2f35

File tree

2 files changed (+99, -1 lines)
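
Why this matters: the old size estimate divided the size of the whole parquet file by the frame count of only the most recent episode, so once a file held more than one episode the per-frame size was overestimated. A minimal arithmetic sketch with made-up numbers (the variable names mirror the diff below; this is not the actual dataset state):

# One parquet file already holds two 10-frame episodes.
global_frame_index = 20             # total frames written so far
latest_ep_dataset_from_index = 10   # where the *latest episode* starts
current_file_start_frame = 0        # where the *current file* starts (new tracking)

# Old: counts only the last episode's frames.
frames_old = global_frame_index - latest_ep_dataset_from_index  # 10
# New: counts every frame in the current file.
frames_new = global_frame_index - current_file_start_frame      # 20

# av_size_per_frame = latest_size_in_mb / frames_in_current_file
# With frames_old, the estimate here is twice the true per-frame size,
# so file rotation would presumably trigger earlier than intended.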

src/lerobot/datasets/lerobot_dataset.py

Lines changed: 6 additions & 1 deletion

@@ -686,6 +686,7 @@ def __init__(
         self.episode_buffer = None
         self.writer = None
         self.latest_episode = None
+        self._current_file_start_frame = None  # Track the starting frame index of the current parquet file

         self.root.mkdir(exist_ok=True, parents=True)

@@ -1232,6 +1233,7 @@ def _save_episode_data(self, episode_buffer: dict) -> dict:
             # Initialize indices and frame count for a new dataset made of the first episode data
             chunk_idx, file_idx = 0, 0
             global_frame_index = 0
+            self._current_file_start_frame = 0
             # However, if the episodes already exists
             # It means we are resuming recording, so we need to load the latest episode
             # Update the indices to avoid overwriting the latest episode
@@ -1243,6 +1245,7 @@ def _save_episode_data(self, episode_buffer: dict) -> dict:

                 # When resuming, move to the next file
                 chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
+                self._current_file_start_frame = global_frame_index
         else:
             # Retrieve information from the latest parquet file
             latest_ep = self.latest_episode
@@ -1253,7 +1256,7 @@ def _save_episode_data(self, episode_buffer: dict) -> dict:
             latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
             latest_size_in_mb = get_file_size_in_mb(latest_path)

-            frames_in_current_file = global_frame_index - latest_ep["dataset_from_index"]
+            frames_in_current_file = global_frame_index - self._current_file_start_frame
             av_size_per_frame = (
                 latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
             )
@@ -1267,6 +1270,7 @@ def _save_episode_data(self, episode_buffer: dict) -> dict:
                 chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
                 self._close_writer()
                 self._writer_closed_for_reading = False
+                self._current_file_start_frame = global_frame_index

         ep_dict["data/chunk_index"] = chunk_idx
         ep_dict["data/file_index"] = file_idx
@@ -1473,6 +1477,7 @@ def create(
         obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
         obj.writer = None
         obj.latest_episode = None
+        obj._current_file_start_frame = None
         # Initialize tracking for incremental recording
         obj._lazy_loading = False
         obj._recorded_frames = 0
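
Taken together, the three assignment sites keep `_current_file_start_frame` aligned with the parquet file currently being written. A paraphrased sketch of that lifecycle, with `resuming_recording` and `file_size_limit_reached` as hypothetical stand-ins for the conditions in the real method (not the actual control flow):

# Paraphrase of the assignments in _save_episode_data:
if latest_episode is None:
    # Fresh dataset: the first parquet file starts at frame 0.
    current_file_start_frame = 0
elif resuming_recording:
    # Resuming: recording moves to a new file starting at the current global frame.
    current_file_start_frame = global_frame_index
elif file_size_limit_reached:
    # Rotation: the size estimate says the file is full, so a new file begins here.
    current_file_start_frame = global_frame_index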

tests/datasets/test_datasets.py

Lines changed: 93 additions & 0 deletions

@@ -1199,3 +1199,96 @@ def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory):
         expected_to = (ep_idx + 1) * frames_per_episode
         assert ep_metadata["dataset_from_index"] == expected_from
         assert ep_metadata["dataset_to_index"] == expected_to
+
+
+def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_factory):
+    """Regression test for bug where frames_in_current_file only counted frames from last episode instead of all frames in current file."""
+    features = {
+        "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
+        "action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]},
+    }
+
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
+    dataset.meta.update_chunk_settings(data_files_size_in_mb=100)
+
+    assert dataset._current_file_start_frame is None
+
+    frames_per_episode = 10
+    for _ in range(frames_per_episode):
+        dataset.add_frame(
+            {
+                "observation.state": torch.randn(2),
+                "action": torch.randn(2),
+                "task": "task_0",
+            }
+        )
+    dataset.save_episode()
+
+    assert dataset._current_file_start_frame == 0
+    assert dataset.meta.total_episodes == 1
+    assert dataset.meta.total_frames == frames_per_episode
+
+    for _ in range(frames_per_episode):
+        dataset.add_frame(
+            {
+                "observation.state": torch.randn(2),
+                "action": torch.randn(2),
+                "task": "task_1",
+            }
+        )
+    dataset.save_episode()
+
+    assert dataset._current_file_start_frame == 0
+    assert dataset.meta.total_episodes == 2
+    assert dataset.meta.total_frames == 2 * frames_per_episode
+
+    ep1_chunk = dataset.latest_episode["data/chunk_index"]
+    ep1_file = dataset.latest_episode["data/file_index"]
+    assert ep1_chunk == 0
+    assert ep1_file == 0
+
+    for _ in range(frames_per_episode):
+        dataset.add_frame(
+            {
+                "observation.state": torch.randn(2),
+                "action": torch.randn(2),
+                "task": "task_2",
+            }
+        )
+    dataset.save_episode()
+
+    assert dataset._current_file_start_frame == 0
+    assert dataset.meta.total_episodes == 3
+    assert dataset.meta.total_frames == 3 * frames_per_episode
+
+    ep2_chunk = dataset.latest_episode["data/chunk_index"]
+    ep2_file = dataset.latest_episode["data/file_index"]
+    assert ep2_chunk == 0
+    assert ep2_file == 0
+
+    dataset.finalize()
+
+    from lerobot.datasets.utils import load_episodes
+
+    dataset.meta.episodes = load_episodes(dataset.root)
+    assert dataset.meta.episodes is not None
+
+    for ep_idx in range(3):
+        ep_metadata = dataset.meta.episodes[ep_idx]
+        assert ep_metadata["data/chunk_index"] == 0
+        assert ep_metadata["data/file_index"] == 0
+
+        expected_from = ep_idx * frames_per_episode
+        expected_to = (ep_idx + 1) * frames_per_episode
+        assert ep_metadata["dataset_from_index"] == expected_from
+        assert ep_metadata["dataset_to_index"] == expected_to
+
+    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
+    assert len(loaded_dataset) == 3 * frames_per_episode
+    assert loaded_dataset.meta.total_episodes == 3
+    assert loaded_dataset.meta.total_frames == 3 * frames_per_episode
+
+    for idx in range(len(loaded_dataset)):
+        frame = loaded_dataset[idx]
+        expected_ep = idx // frames_per_episode
+        assert frame["episode_index"].item() == expected_ep
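
To run only the new regression test, the standard pytest node-id syntax should work:

pytest tests/datasets/test_datasets.py::test_frames_in_current_file_calculation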
