Skip to content

Commit 4fa9ea0

Browse files
committed
fix: rename samples to seq_lens and fix typos
- Rename variable 'samples' to 'seq_lens' for clarity since it holds sequence lengths - Remove unused assignment (dead code) - Fix typo 'mirobatches' -> 'microbatches' in comments
1 parent 5be0387 commit 4fa9ea0

File tree

1 file changed

+7
-9
lines changed
  • slime/backends/megatron_utils

1 file changed

+7
-9
lines changed

slime/backends/megatron_utils/data.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -308,14 +308,14 @@ def _generate_data_iterator(rollout_data, micro_batch_size, micro_batch_indices=
308308
data_iterator = _generate_data_iterator(rollout_data, args.micro_batch_size)
309309
else:
310310
assert args.max_tokens_per_gpu is not None
311-
# calculate the number of mirobatches for each step
312-
samples = rollout_data["total_lengths"]
313-
assert len(samples) == num_local_samples
311+
# calculate the number of microbatches for each step
312+
seq_lens = rollout_data["total_lengths"]
313+
assert len(seq_lens) == num_local_samples
314314
num_microbatches = []
315315
for i in range(num_steps_per_rollout):
316316
start, end = i * num_local_gbs, (i + 1) * num_local_gbs
317317
num_microbatches.append(
318-
get_minimum_num_micro_batch_size(samples[start:end], args.max_tokens_per_gpu * cp_size)
318+
get_minimum_num_micro_batch_size(seq_lens[start:end], args.max_tokens_per_gpu * cp_size)
319319
)
320320

321321
num_microbatches = torch.tensor(num_microbatches, dtype=torch.int, device=torch.cuda.current_device())
@@ -330,14 +330,12 @@ def _generate_data_iterator(rollout_data, micro_batch_size, micro_batch_indices=
330330

331331
num_microbatches = num_microbatches.tolist()
332332

333-
# balance the each micro batch
334-
samples = rollout_data["total_lengths"]
335-
# balance the number of mirobatches across steps
333+
# balance the number of microbatches across steps
336334
micro_batch_indices = []
337335
for i, num_mbs in enumerate(num_microbatches):
338336
start, end = i * num_local_gbs, (i + 1) * num_local_gbs
339-
samples = rollout_data["total_lengths"][start:end]
340-
partitions = get_seqlen_balanced_partitions(samples, num_mbs, equal_size=False)
337+
seq_lens = rollout_data["total_lengths"][start:end]
338+
partitions = get_seqlen_balanced_partitions(seq_lens, num_mbs, equal_size=False)
341339
for j in range(num_mbs):
342340
for k in range(len(partitions[j])):
343341
partitions[j][k] += start

0 commit comments

Comments (0)