Skip to content
Open
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 38 additions & 36 deletions nemo_rl/distributed/batched_data_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,8 +537,9 @@ def _get_padded_seqlen(seqlen: int) -> int:

aggregated_shards = [SlicedDataDict() for _ in range(shards)]

# Group data by shard position across all chunks
# Group data by shard position across all chunks.
for shard_idx in range(shards):
shard_ranges: list[tuple[int, int]] = []
for chunk_idx in range(num_chunks):
# Calculate indices for this particular sub-shard within the chunk
chunk_start = chunk_idx * batch_size
Expand All @@ -549,41 +550,42 @@ def _get_padded_seqlen(seqlen: int) -> int:
# or if shard_end calculation goes beyond total_batch_size
shard_start = min(shard_start, total_batch_size)
shard_end = min(shard_end, total_batch_size)
indices = torch.arange(shard_start, shard_end)

for k in data:
if k not in aggregated_shards[shard_idx]:
# First time seeing this key for this shard, initialize it
if torch.is_tensor(data[k]):
aggregated_shards[shard_idx][k] = data[k][indices].clone()
elif isinstance(data[k], PackedTensor):
aggregated_shards[shard_idx][k] = data[k].slice(
indices.tolist()
)
else:
aggregated_shards[shard_idx][k] = [
data[k][i] for i in indices
]
else:
# Append to existing data - concatenate tensors or extend lists
if torch.is_tensor(data[k]):
aggregated_shards[shard_idx][k] = torch.cat(
[
aggregated_shards[shard_idx][k],
data[k][indices].clone(),
]
)
elif isinstance(data[k], PackedTensor):
aggregated_shards[shard_idx][k] = PackedTensor.concat(
[
aggregated_shards[shard_idx][k],
data[k].slice(indices.tolist()),
]
)
else:
aggregated_shards[shard_idx][k].extend(
[data[k][i] for i in indices]
)

if shard_start < shard_end:
shard_ranges.append((shard_start, shard_end))

for k, v in data.items():
if torch.is_tensor(v):
# Pre-allocate and copy each chunk exactly once.
rows = sum(end - start for start, end in shard_ranges)
shard_tensor = torch.empty(
(rows, *v.shape[1:]),
dtype=v.dtype,
device=v.device,
)
offset = 0
for start, end in shard_ranges:
span = end - start
shard_tensor[offset : offset + span].copy_(v[start:end])
offset += span
aggregated_shards[shard_idx][k] = shard_tensor
elif isinstance(v, PackedTensor):
            # PackedTensor slices are collected per chunk, then concatenated once.
packed_slices = [
v.slice(list(range(start, end))) for start, end in shard_ranges
]
if packed_slices:
aggregated_shards[shard_idx][k] = (
PackedTensor.concat(packed_slices)
if len(packed_slices) > 1
else packed_slices[0]
)
else:
shard_values = []
for start, end in shard_ranges:
shard_values.extend([v[i] for i in range(start, end)])

aggregated_shards[shard_idx][k] = shard_values

        # map inputs to microbatches such that the total number of tokens
        # (including padding tokens) in a microbatch is as close to 'max_tokens_per_microbatch'
Expand Down
Loading