src/lerobot/datasets/aggregate.py: 64 changes (52 additions, 12 deletions)
@@ -107,6 +107,7 @@ def update_meta_data(
     dst_meta,
     meta_idx,
     data_idx,
+    data_mapping,
     videos_idx,
 ):
     """Updates metadata DataFrame with new chunk, file, and timestamp indices.
@@ -119,6 +120,8 @@ def update_meta_data(
         dst_meta: Destination dataset metadata.
         meta_idx: Dictionary containing current metadata chunk and file indices.
         data_idx: Dictionary containing current data chunk and file indices.
+        data_mapping: Mapping of source `(chunk_index, file_index)` pairs to
+            their destination equivalents.
         videos_idx: Dictionary containing current video indices and timestamps.
 
     Returns:
@@ -127,8 +130,29 @@ def update_meta_data(
 
     df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
     df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
-    df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
-    df["data/file_index"] = df["data/file_index"] + data_idx["file"]
+    if data_mapping:
+        df["_orig_data_chunk"] = df["data/chunk_index"].copy()
+        df["_orig_data_file"] = df["data/file_index"].copy()
+
+        new_chunk_indices: list[int] = []
+        new_file_indices: list[int] = []
+        for orig_chunk, orig_file in zip(
+            df["_orig_data_chunk"], df["_orig_data_file"], strict=False
+        ):
+            mapped = data_mapping.get((int(orig_chunk), int(orig_file)))
+            if mapped is None:
+                new_chunk_indices.append(int(orig_chunk) + data_idx["chunk"])
+                new_file_indices.append(int(orig_file) + data_idx["file"])
+            else:
+                new_chunk_indices.append(mapped[0])
+                new_file_indices.append(mapped[1])
+
+        df["data/chunk_index"] = new_chunk_indices
+        df["data/file_index"] = new_file_indices
+        df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"])
+    else:
+        df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
+        df["data/file_index"] = df["data/file_index"] + data_idx["file"]
     for key, video_idx in videos_idx.items():
         # Store original video file indices before updating
         orig_chunk_col = f"videos/{key}/chunk_index"
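Note for reviewers: the new branch applies a per-row rule of "use the recorded destination if one exists, otherwise shift by the running offsets". A minimal sketch of that rule, with assumed offsets and mapping values that are not part of this patch:

```python
import pandas as pd

data_idx = {"chunk": 2, "file": 0}   # assumed running offsets
data_mapping = {(0, 0): (2, 5)}      # assumed: source file (0, 0) landed at (2, 5)

df = pd.DataFrame({"data/chunk_index": [0, 1], "data/file_index": [0, 3]})

def remap(chunk: int, file: int) -> tuple[int, int]:
    # Prefer the exact destination recorded during the copy; otherwise fall
    # back to the old behavior of shifting by the running offsets.
    return data_mapping.get((chunk, file), (chunk + data_idx["chunk"], file + data_idx["file"]))

df["data/chunk_index"], df["data/file_index"] = zip(
    *(remap(int(c), int(f)) for c, f in zip(df["data/chunk_index"], df["data/file_index"]))
)
print(df.to_dict("list"))  # {'data/chunk_index': [2, 3], 'data/file_index': [5, 3]}
```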
@@ -237,9 +261,13 @@ def aggregate_datasets(
 
     for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
         videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
-        data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
+        data_idx, data_mapping = aggregate_data(
+            src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size
+        )
 
-        meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
+        meta_idx = aggregate_metadata(
+            src_meta, dst_meta, meta_idx, data_idx, data_mapping, videos_idx
+        )
 
         dst_meta.info["total_episodes"] += src_meta.total_episodes
         dst_meta.info["total_frames"] += src_meta.total_frames
@@ -355,8 +383,10 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
         data_idx: Dictionary tracking data chunk and file indices.
 
     Returns:
-        dict: Updated data_idx with current chunk and file indices.
+        tuple[dict, dict]: Updated ``data_idx`` and a mapping from source
+            ``(chunk_index, file_index)`` pairs to their destination counterparts.
     """
+    data_mapping: dict[tuple[int, int], tuple[int, int]] = {}
     unique_chunk_file_ids = {
         (c, f)
         for c, f in zip(
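For reference, the shape of `data_mapping` (illustrative values only): several source files can collapse into one destination file when appends stay under the size limit, which is exactly the case plain offset arithmetic cannot express.

```python
# source (chunk_index, file_index) -> destination (chunk_index, file_index)
data_mapping: dict[tuple[int, int], tuple[int, int]] = {
    (0, 0): (3, 1),  # appended into an existing destination file
    (0, 1): (3, 1),  # merged into the same destination file
    (0, 2): (3, 2),  # a size-triggered rotation started a new file
}
```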
@@ -373,7 +403,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
         df = pd.read_parquet(src_path)
         df = update_data_df(df, src_meta, dst_meta)
 
-        data_idx = append_or_create_parquet_file(
+        data_idx, dst_chunk_idx, dst_file_idx = append_or_create_parquet_file(
             df,
             src_path,
             data_idx,
@@ -383,11 +413,12 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
             contains_images=len(dst_meta.image_keys) > 0,
             aggr_root=dst_meta.root,
         )
+        data_mapping[(src_chunk_idx, src_file_idx)] = (dst_chunk_idx, dst_file_idx)
 
-    return data_idx
+    return data_idx, data_mapping
 
 
-def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
+def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, data_mapping, videos_idx):
     """Aggregates metadata from a source dataset into the destination dataset.
 
     Reads source metadata files, updates all indices and timestamps,
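A toy reduction of the handshake introduced here (self-contained, no lerobot imports; all names and values are stand-ins): `aggregate_data` records where each source file's rows actually landed, and `aggregate_metadata` resolves episodes through that record instead of re-deriving offsets.

```python
def toy_aggregate_data(src_files, data_idx):
    # Pretend every source file is appended into the current destination file.
    mapping = {src: (data_idx["chunk"], data_idx["file"]) for src in src_files}
    return data_idx, mapping

def toy_aggregate_metadata(episode_files, mapping):
    # Point each episode at the file its rows actually landed in.
    return [mapping[f] for f in episode_files]

data_idx = {"chunk": 3, "file": 7}  # assumed destination cursor
_, mapping = toy_aggregate_data([(0, 0), (0, 1)], data_idx)
print(toy_aggregate_metadata([(0, 0), (0, 1)], mapping))  # [(3, 7), (3, 7)]
```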
@@ -398,6 +429,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
         dst_meta: Destination dataset metadata.
         meta_idx: Dictionary tracking metadata chunk and file indices.
         data_idx: Dictionary tracking data chunk and file indices.
+        data_mapping: Mapping from source data files to destination data files.
         videos_idx: Dictionary tracking video indices and timestamps.
 
     Returns:
@@ -421,10 +453,11 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
             dst_meta,
             meta_idx,
             data_idx,
+            data_mapping,
             videos_idx,
         )
 
-        meta_idx = append_or_create_parquet_file(
+        meta_idx, _, _ = append_or_create_parquet_file(
             df,
             src_path,
             meta_idx,
@@ -468,17 +501,20 @@ def append_or_create_parquet_file(
         aggr_root: Root path for the aggregated dataset.
 
     Returns:
-        dict: Updated index dictionary with current chunk and file indices.
+        tuple[dict, int, int]: Updated index dictionary along with the chunk and
+            file indices where ``df`` was written.
     """
     dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
+    target_chunk = idx["chunk"]
+    target_file = idx["file"]
 
     if not dst_path.exists():
         dst_path.parent.mkdir(parents=True, exist_ok=True)
         if contains_images:
             to_parquet_with_hf_images(df, dst_path)
         else:
             df.to_parquet(dst_path)
-        return idx
+        return idx, target_chunk, target_file
 
     src_size = get_parquet_file_size_in_mb(src_path)
     dst_size = get_parquet_file_size_in_mb(dst_path)
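The returned pair matters because the landing spot can differ from the pre-call cursor once the size check forces a rotation. A toy reduction of that decision (assumptions throughout, with the real rotation logic simplified to a file-index bump):

```python
def append_or_rotate(idx: dict, would_exceed_limit: bool) -> tuple[dict, int, int]:
    if would_exceed_limit:
        # Simplified stand-in for the real rotation: move to the next file.
        idx = {"chunk": idx["chunk"], "file": idx["file"] + 1}
    # Capture where the write actually happens, after any rotation.
    return idx, idx["chunk"], idx["file"]

idx = {"chunk": 0, "file": 0}
idx, chunk, file_idx = append_or_rotate(idx, would_exceed_limit=True)
assert (chunk, file_idx) == (0, 1)  # the caller records the real landing spot
```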
@@ -489,17 +525,21 @@ def append_or_create_parquet_file(
         new_path.parent.mkdir(parents=True, exist_ok=True)
         final_df = df
         target_path = new_path
+        target_chunk = idx["chunk"]
+        target_file = idx["file"]
     else:
         existing_df = pd.read_parquet(dst_path)
         final_df = pd.concat([existing_df, df], ignore_index=True)
         target_path = dst_path
+        target_chunk = idx["chunk"]
+        target_file = idx["file"]
 
     if contains_images:
         to_parquet_with_hf_images(final_df, target_path)
     else:
         final_df.to_parquet(target_path)
 
-    return idx
+    return idx, target_chunk, target_file
 
 
 def finalize_aggregation(aggr_meta, all_metadata):