@@ -44,7 +44,10 @@ def test_sequence_packing_gradients(self):
     from nemo_rl.models.megatron.common import (
         forward_step_arbitrary_loss,
     )
-    from nemo_rl.models.megatron.data import _pack_sequences_for_megatron
+    from nemo_rl.models.megatron.data import (
+        _pack_sequences_for_megatron,
+        make_processed_microbatch_iterator,
+    )
 
     # Initialize process group
     torch.distributed.init_process_group(backend="nccl")
@@ -325,13 +328,27 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         MockMcoreState(),
         global_valid_seqs,
         global_valid_toks,
-        data_iterator=iter([packed_data_dict]),
+        data_iterator=make_processed_microbatch_iterator(
+            iter([packed_data_dict]),
+            cfg={
+                "sequence_packing": {"enabled": True},
+                "dynamic_batching": {"enabled": False},
+                "megatron_cfg": {
+                    "tensor_model_parallel_size": 1,
+                    "sequence_parallel": False,
+                    "pipeline_model_parallel_size": 1,
+                    "context_parallel_size": cp_size,
+                },
+            },
+
+
+            seq_length_key="input_lengths",
+            pad_individual_seqs_to_multiple_of=pad_to_multiple,
+            pad_full_seq_to=max_seq_len * batch_size if cp_size > 1 else None,
+        ),
         model=MockModel(),
         loss_fn=base_loss_fn,
         pack_sequences=True,
-        seq_length_key="input_lengths",
-        pad_individual_seqs_to_multiple_of=pad_to_multiple,
-        pad_full_seq_to=max_seq_len * batch_size if cp_size > 1 else None,
         cp_normalize=True,
     )
     loss, metrics = wrapped_loss_fn(output_tensor)
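
Note for reviewers: the padding/packing arguments (`seq_length_key`, `pad_individual_seqs_to_multiple_of`, `pad_full_seq_to`) move off `forward_step_arbitrary_loss` and onto `make_processed_microbatch_iterator`, which now preprocesses each microbatch before the forward step sees it. Below is a minimal usage sketch assuming only the call shape visible in this diff; the contents of `packed_data_dict` (the `input_ids` key, the tensor shapes) and the literal padding values are illustrative, not taken from the repo.

```python
import torch

from nemo_rl.models.megatron.data import make_processed_microbatch_iterator

# Illustrative microbatch: the only key this diff confirms is "input_lengths",
# which the test passes as seq_length_key. "input_ids" is an assumption.
packed_data_dict = {
    "input_ids": torch.randint(0, 1000, (2, 16)),
    "input_lengths": torch.tensor([16, 12]),
}

cfg = {
    "sequence_packing": {"enabled": True},
    "dynamic_batching": {"enabled": False},
    "megatron_cfg": {
        "tensor_model_parallel_size": 1,
        "sequence_parallel": False,
        "pipeline_model_parallel_size": 1,
        "context_parallel_size": 1,  # the test passes cp_size; 1 for a single-rank sketch
    },
}

# Wrap a plain iterator of microbatch dicts; the wrapper handles padding and
# packing according to cfg, so the forward step no longer takes those arguments.
data_iterator = make_processed_microbatch_iterator(
    iter([packed_data_dict]),
    cfg=cfg,
    seq_length_key="input_lengths",
    pad_individual_seqs_to_multiple_of=8,  # illustrative; the test uses pad_to_multiple
    pad_full_seq_to=None,  # the test only sets this when cp_size > 1
)
```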