Skip to content

Commit a6800af

Browse files
ashors1 and ananthsub
authored and committed
fix unit test
Signed-off-by: ashors1 <ashors@nvidia.com>
1 parent e596d1b commit a6800af

File tree

1 file changed

+30
-13
lines changed

1 file changed

+30
-13
lines changed

tests/unit/algorithms/test_sequence_packing_gradients.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ def __init__(self, cp_size):
4141

4242
def test_sequence_packing_gradients(self):
4343
from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank
44-
from nemo_rl.models.megatron.common import (
45-
forward_step_arbitrary_loss,
44+
from nemo_rl.models.megatron.train import (
45+
forward_with_post_processing_fn,
46+
LossPostProcessor,
4647
)
4748
from nemo_rl.models.megatron.data import (
4849
_pack_sequences_for_megatron,
@@ -289,7 +290,7 @@ def make_packed_logits(logits):
289290
packed_grad, baseline_grad_store, atol=1e-5, rtol=1e-5
290291
)
291292

292-
# test 3: with forward_step_arbitrary_loss
293+
# test 3: with forward_with_post_processing_fn
293294
# reset grad
294295
baseline_logits.grad.zero_()
295296
packed_logits = make_packed_logits(baseline_logits)
@@ -307,15 +308,24 @@ def forward(
307308
):
308309
return self.logits
309310

310-
class MockMcoreState:
311-
def __init__(self):
312-
# context that does nothing, but supports both with straggler_timer and with straggler_timer(bdata=True)
313-
from contextlib import nullcontext
311+
cfg = {
312+
"sequence_packing": {"enabled": True},
313+
"dynamic_batching": {"enabled": False},
314+
"megatron_cfg": {
315+
"tensor_model_parallel_size": 1,
316+
"sequence_parallel": False,
317+
"pipeline_model_parallel_size": 1,
318+
"context_parallel_size": cp_size,
319+
},
320+
}
314321

315-
class DummyStragglerTimer:
316-
def __call__(self, *args, **kwargs):
317-
return nullcontext()
322+
post_processor = LossPostProcessor(
323+
loss_fn=base_loss_fn,
324+
cfg=cfg,
325+
cp_normalize=True,
326+
)
318327

328+
<<<<<<< HEAD
319329
def __enter__(self):
320330
return self
321331

@@ -342,16 +352,23 @@ def __exit__(self, exc_type, exc_val, exc_tb):
342352
"context_parallel_size": cp_size,
343353
},
344354
},
355+
=======
356+
output_tensor, wrapped_loss_fn = forward_with_post_processing_fn(
357+
data_iterator=make_processed_microbatch_iterator(
358+
iter([packed_data_dict]),
359+
cfg=cfg,
360+
>>>>>>> a11ae1b2e (fix unit test)
345361
seq_length_key="input_lengths",
346362
pad_individual_seqs_to_multiple_of=pad_to_multiple,
347363
pad_packed_seq_to_multiple_of=1,
348364
straggler_timer=mock_mcore_state.straggler_timer,
349365
pad_full_seq_to=max_seq_len * batch_size if cp_size > 1 else None,
350366
),
351367
model=MockModel(),
352-
loss_fn=base_loss_fn,
353-
pack_sequences=True,
354-
cp_normalize=True,
368+
cfg=cfg,
369+
post_processing_fn=post_processor,
370+
global_valid_seqs=global_valid_seqs,
371+
global_valid_toks=global_valid_toks,
355372
)
356373
loss, metrics = wrapped_loss_fn(output_tensor)
357374

0 commit comments

Comments (0)