@@ -308,14 +308,14 @@ def _generate_data_iterator(rollout_data, micro_batch_size, micro_batch_indices=
308308 data_iterator = _generate_data_iterator (rollout_data , args .micro_batch_size )
309309 else :
310310 assert args .max_tokens_per_gpu is not None
311- # calculate the number of mirobatches for each step
312- samples = rollout_data ["total_lengths" ]
313- assert len (samples ) == num_local_samples
311+ # calculate the number of microbatches for each step
312+ seq_lens = rollout_data ["total_lengths" ]
313+ assert len (seq_lens ) == num_local_samples
314314 num_microbatches = []
315315 for i in range (num_steps_per_rollout ):
316316 start , end = i * num_local_gbs , (i + 1 ) * num_local_gbs
317317 num_microbatches .append (
318- get_minimum_num_micro_batch_size (samples [start :end ], args .max_tokens_per_gpu * cp_size )
318+ get_minimum_num_micro_batch_size (seq_lens [start :end ], args .max_tokens_per_gpu * cp_size )
319319 )
320320
321321 num_microbatches = torch .tensor (num_microbatches , dtype = torch .int , device = torch .cuda .current_device ())
@@ -330,14 +330,12 @@ def _generate_data_iterator(rollout_data, micro_batch_size, micro_batch_indices=
330330
331331 num_microbatches = num_microbatches .tolist ()
332332
333- # balance the each micro batch
334- samples = rollout_data ["total_lengths" ]
335- # balance the number of mirobatches across steps
333+ # balance the number of microbatches across steps
336334 micro_batch_indices = []
337335 for i , num_mbs in enumerate (num_microbatches ):
338336 start , end = i * num_local_gbs , (i + 1 ) * num_local_gbs
339- samples = rollout_data ["total_lengths" ][start :end ]
340- partitions = get_seqlen_balanced_partitions (samples , num_mbs , equal_size = False )
337+ seq_lens = rollout_data ["total_lengths" ][start :end ]
338+ partitions = get_seqlen_balanced_partitions (seq_lens , num_mbs , equal_size = False )
341339 for j in range (num_mbs ):
342340 for k in range (len (partitions [j ])):
343341 partitions [j ][k ] += start
0 commit comments