|
24 | 24 | except ImportError: |
25 | 25 | pytest.skip("nemo_automodel not available", allow_module_level=True) |
26 | 26 |
|
| 27 | +from nemo_rl.algorithms.loss.interfaces import LossInputType |
27 | 28 | from nemo_rl.distributed.batched_data_dict import BatchedDataDict |
28 | 29 | from nemo_rl.models.automodel.data import ( |
29 | 30 | ProcessedInputs, |
@@ -63,6 +64,7 @@ def mock_model(): |
63 | 64 | def mock_loss_fn(): |
64 | 65 | loss_fn = MagicMock() |
65 | 66 | loss_fn.return_value = (torch.tensor(0.5), {"loss": 0.5}) |
| 67 | + loss_fn.input_type = LossInputType.LOGIT |
66 | 68 | return loss_fn |
67 | 69 |
|
68 | 70 |
|
@@ -310,10 +312,10 @@ def test_basic_loss_computation( |
310 | 312 |
|
311 | 313 | # Verify loss function was called |
312 | 314 | mock_loss_fn.assert_called_once() |
313 | | - call_args = mock_loss_fn.call_args[0] |
314 | | - assert torch.is_tensor(call_args[0]) # logits |
315 | | - assert call_args[2] == global_valid_seqs # global_valid_seqs |
316 | | - assert call_args[3] == global_valid_toks # global_valid_toks |
| 315 | + call_kwargs = mock_loss_fn.call_args[1] |
| 316 | + assert torch.is_tensor(call_kwargs["logits"]) |
| 317 | + assert call_kwargs["global_valid_seqs"] == global_valid_seqs |
| 318 | + assert call_kwargs["global_valid_toks"] == global_valid_toks |
317 | 319 |
|
318 | 320 | @patch("nemo_rl.models.automodel.train.SequencePackingLossWrapper") |
319 | 321 | def test_loss_with_sequence_packing( |
@@ -1896,10 +1898,12 @@ def forward(self, input_ids, **kwargs): |
1896 | 1898 | ) |
1897 | 1899 |
|
1898 | 1900 | # Create loss function that returns requires_grad tensor |
1899 | | - def loss_fn(logits, mb, global_valid_seqs, global_valid_toks): |
| 1901 | + def loss_fn(logits, data, global_valid_seqs, global_valid_toks): |
1900 | 1902 | loss = logits.mean() |
1901 | 1903 | return loss, {"loss": loss.item()} |
1902 | 1904 |
|
| 1905 | + loss_fn.input_type = LossInputType.LOGIT |
| 1906 | + |
1903 | 1907 | # Create loss post-processor |
1904 | 1908 | loss_post_processor = LossPostProcessor( |
1905 | 1909 | loss_fn=loss_fn, |
|
0 commit comments