@@ -41,8 +41,9 @@ def __init__(self, cp_size):
4141
4242 def test_sequence_packing_gradients (self ):
4343 from nemo_rl .distributed .model_utils import _get_tokens_on_this_cp_rank
44- from nemo_rl .models .megatron .common import (
45- forward_step_arbitrary_loss ,
44+ from nemo_rl .models .megatron .train import (
45+ forward_with_post_processing_fn ,
46+ LossPostProcessor ,
4647 )
4748 from nemo_rl .models .megatron .data import (
4849 _pack_sequences_for_megatron ,
@@ -289,7 +290,7 @@ def make_packed_logits(logits):
289290 packed_grad , baseline_grad_store , atol = 1e-5 , rtol = 1e-5
290291 )
291292
292- # test 3: with forward_step_arbitrary_loss
293+ # test 3: with forward_with_post_processing_fn
293294 # reset grad
294295 baseline_logits .grad .zero_ ()
295296 packed_logits = make_packed_logits (baseline_logits )
@@ -307,15 +308,24 @@ def forward(
307308 ):
308309 return self .logits
309310
310- class MockMcoreState :
311- def __init__ (self ):
312- # context that does nothing, but supports both with straggler_timer and with straggler_timer(bdata=True)
313- from contextlib import nullcontext
311+ cfg = {
312+ "sequence_packing" : {"enabled" : True },
313+ "dynamic_batching" : {"enabled" : False },
314+ "megatron_cfg" : {
315+ "tensor_model_parallel_size" : 1 ,
316+ "sequence_parallel" : False ,
317+ "pipeline_model_parallel_size" : 1 ,
318+ "context_parallel_size" : cp_size ,
319+ },
320+ }
314321
315- class DummyStragglerTimer :
316- def __call__ (self , * args , ** kwargs ):
317- return nullcontext ()
322+ post_processor = LossPostProcessor (
323+ loss_fn = base_loss_fn ,
324+ cfg = cfg ,
325+ cp_normalize = True ,
326+ )
318327
328+ < << << << HEAD
319329 def __enter__ (self ):
320330 return self
321331
@@ -342,16 +352,23 @@ def __exit__(self, exc_type, exc_val, exc_tb):
342352 "context_parallel_size" : cp_size ,
343353 },
344354 },
355+ == == == =
356+ output_tensor , wrapped_loss_fn = forward_with_post_processing_fn (
357+ data_iterator = make_processed_microbatch_iterator (
358+ iter ([packed_data_dict ]),
359+ cfg = cfg ,
360+ >> >> >> > a11ae1b2e (fix unit test )
345361 seq_length_key = "input_lengths" ,
346362 pad_individual_seqs_to_multiple_of = pad_to_multiple ,
347363 pad_packed_seq_to_multiple_of = 1 ,
348364 straggler_timer = mock_mcore_state .straggler_timer ,
349365 pad_full_seq_to = max_seq_len * batch_size if cp_size > 1 else None ,
350366 ),
351367 model = MockModel (),
352- loss_fn = base_loss_fn ,
353- pack_sequences = True ,
354- cp_normalize = True ,
368+ cfg = cfg ,
369+ post_processing_fn = post_processor ,
370+ global_valid_seqs = global_valid_seqs ,
371+ global_valid_toks = global_valid_toks ,
355372 )
356373 loss , metrics = wrapped_loss_fn (output_tensor )
357374
0 commit comments