Skip to content

Commit 0eddfa0

Browse files
committed
fix unit test
Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent aa9b1f4 commit 0eddfa0

File tree

3 files changed

+15
-11
lines changed

3 files changed

+15
-11
lines changed

nemo_rl/algorithms/loss/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
     "NLLLossFn",
    "PreferenceLossDataDict",
    "PreferenceLossFn",
-    "SequencePackingLossWrapper",
    "prepare_loss_input",
+    "SequencePackingLossWrapper",
    "wrap_loss_fn_with_input_preparation",
]

tests/unit/models/automodel/test_automodel_train.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
 except ImportError:
     pytest.skip("nemo_automodel not available", allow_module_level=True)

+from nemo_rl.algorithms.loss.interfaces import LossInputType
 from nemo_rl.distributed.batched_data_dict import BatchedDataDict
 from nemo_rl.models.automodel.data import (
     ProcessedInputs,
@@ -63,6 +64,7 @@ def mock_model():
 def mock_loss_fn():
     loss_fn = MagicMock()
     loss_fn.return_value = (torch.tensor(0.5), {"loss": 0.5})
+    loss_fn.input_type = LossInputType.LOGIT
     return loss_fn


@@ -310,10 +312,10 @@ def test_basic_loss_computation(

     # Verify loss function was called
     mock_loss_fn.assert_called_once()
-    call_args = mock_loss_fn.call_args[0]
-    assert torch.is_tensor(call_args[0])  # logits
-    assert call_args[2] == global_valid_seqs  # global_valid_seqs
-    assert call_args[3] == global_valid_toks  # global_valid_toks
+    call_kwargs = mock_loss_fn.call_args[1]
+    assert torch.is_tensor(call_kwargs["logits"])
+    assert call_kwargs["global_valid_seqs"] == global_valid_seqs
+    assert call_kwargs["global_valid_toks"] == global_valid_toks

 @patch("nemo_rl.models.automodel.train.SequencePackingLossWrapper")
 def test_loss_with_sequence_packing(
@@ -1896,10 +1898,12 @@ def forward(self, input_ids, **kwargs):
     )

     # Create loss function that returns requires_grad tensor
-    def loss_fn(logits, mb, global_valid_seqs, global_valid_toks):
+    def loss_fn(logits, data, global_valid_seqs, global_valid_toks):
         loss = logits.mean()
         return loss, {"loss": loss.item()}

+    loss_fn.input_type = LossInputType.LOGIT
+
     # Create loss post-processor
     loss_post_processor = LossPostProcessor(
         loss_fn=loss_fn,

tests/unit/models/generation/test_vllm_generation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -894,13 +894,13 @@ async def run_hf_train_process(
 @pytest.mark.parametrize(
     ("async_engine", "cpu_offload", "vllm_precision", "enable_lora"),
     [
-        # (True, False, "bfloat16", False),
-        # (False, True, "bfloat16", False),
-        # (True, False, "fp8", False),
-        # (False, True, "fp8", False),
+        (True, False, "bfloat16", False),
+        (False, True, "bfloat16", False),
+        (True, False, "fp8", False),
+        (False, True, "fp8", False),
         # LoRA tests
         (False, False, "bfloat16", True),
-        # (True, False, "bfloat16", True),
+        (True, False, "bfloat16", True),
     ],
 )
 async def test_vllm_generation_with_hf_training_colocated(

0 commit comments

Comments (0)