Commit dc9dce4

refactor: refactor loss function (#1920)

Signed-off-by: Yuki Huang <yukih@nvidia.com>

1 parent 7684dc2 commit dc9dce4

42 files changed: +924 −793 lines changed

docs/design-docs/loss-functions.md

Lines changed: 2 additions & 3 deletions

@@ -17,14 +17,13 @@ $$
 
 which is, in general, not equivalent to the full-batch loss. To fix this, we need each microbatch to have information about how many tokens are in the other microbatches in the global batch.
 
-In NeMo RL, this information is passed to the loss function directly. Each loss function is expected to fall into one of two categories, token-level or sequence-level, which is an attribute of the loss function itself (see [loss_functions.py](../../nemo_rl/algorithms/loss_functions.py) for some examples). The policy then uses this information to compute the global normalization factor using the full batch (for token-level losses, this is the total number of tokens in the batch. For sequence-level losses, this is the number of valid sequences in the batch). The normalization factor is then passed to the loss function, which uses it to normalize the microbatch loss. To get the loss for the global batch, the policy simply sums across all microbatch losses.
+In NeMo RL, this information is passed to the loss function directly. Each loss function is expected to fall into one of two categories, token-level or sequence-level, which is an attribute of the loss function itself (see [loss_functions.py](../../nemo_rl/algorithms/loss/loss_functions.py) for some examples). The policy then uses this information to compute the global normalization factor using the full batch (for token-level losses, this is the total number of tokens in the batch. For sequence-level losses, this is the number of valid sequences in the batch). The normalization factor is then passed to the loss function, which uses it to normalize the microbatch loss. To get the loss for the global batch, the policy simply sums across all microbatch losses.
 
 For our simple example above, this would look like:
 
 ```{testcode}
 import torch
-from nemo_rl.algorithms.interfaces import LossFunction
-from nemo_rl.algorithms.loss_functions import LossType
+from nemo_rl.algorithms.loss.interfaces import LossFunction, LossType
 from nemo_rl.distributed.batched_data_dict import BatchedDataDict
 
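The global-normalization scheme described in the paragraph above can be checked numerically. The following is a minimal sketch with plain tensors and hypothetical values (the real code uses BatchedDataDict and the loss-function classes linked above):

```python
import torch

# Hypothetical numbers for illustration: a global batch of six token losses,
# with the last token masked out as padding.
token_losses = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
token_mask = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 0.0])

# Full-batch (token-level) loss: mean over valid tokens.
full_batch_loss = (token_losses * token_mask).sum() / token_mask.sum()

# The policy computes the normalization factor once from the full batch...
global_valid_toks = token_mask.sum()

# ...then each microbatch is normalized by the *global* token count, and the
# policy sums across microbatches.
microbatch_loss_sum = torch.tensor(0.0)
for mb_losses, mb_mask in zip(token_losses.split(3), token_mask.split(3)):
    microbatch_loss_sum += (mb_losses * mb_mask).sum() / global_valid_toks

# Summing globally-normalized microbatch losses recovers the full-batch loss.
assert torch.allclose(microbatch_loss_sum, full_batch_loss)
```

Normalizing each microbatch by its own token count instead would weight microbatches unevenly, which is exactly the mismatch the design above avoids.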

docs/guides/grpo.md

Lines changed: 1 addition & 1 deletion

@@ -343,7 +343,7 @@ The function, [grpo_train](../../nemo_rl/algorithms/grpo.py), contains the core
 RL generations typically produce highly variable sequence lengths, which result in a significant amount of padding if approached naively. We address this with Sequence Packing and Dynamic Batching, which are techniques to reduce the amount of padding required. You can read more about these in the [design doc](../design-docs/sequence-packing-and-dynamic-batching.md).
 
 ## Loss
-We use the [ClippedPGLossFn](../../nemo_rl/algorithms/loss_functions.py) to calculate the loss for GRPO. Formally,
+We use the [ClippedPGLossFn](../../nemo_rl/algorithms/loss/loss_functions.py) to calculate the loss for GRPO. Formally,
 
 $$
 L(\theta) = E_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big) \Big] - \beta D_{\text{KL}} (\pi_\theta \| \pi_\text{ref})
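Ignoring the KL term, the clipped surrogate in this formula can be sketched in a few lines of PyTorch. This is an illustrative stand-in, not the actual ClippedPGLossFn; the function name, shapes, and values are assumptions:

```python
import torch

def clipped_pg_loss(logprobs: torch.Tensor,
                    old_logprobs: torch.Tensor,
                    advantages: torch.Tensor,
                    eps: float = 0.2) -> torch.Tensor:
    """Clipped policy-gradient surrogate (KL penalty omitted)."""
    # Probability ratio pi_theta(x) / pi_theta_old(x), computed in log space.
    ratio = torch.exp(logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages
    # Negate: the formula is a maximization objective, optimizers minimize.
    return -torch.min(unclipped, clipped).mean()

# Toy values: the first token's ratio is 1, so it contributes A_t directly;
# the second token's ratio exp(-0.5) ~ 0.61 falls below 1 - eps, so with its
# negative advantage the clipped branch (-0.8) is selected by the min.
loss = clipped_pg_loss(
    logprobs=torch.tensor([-1.0, -2.0]),
    old_logprobs=torch.tensor([-1.0, -1.5]),
    advantages=torch.tensor([1.0, -1.0]),
)
```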

docs/guides/prorlv2.md

Lines changed: 2 additions & 2 deletions

@@ -106,7 +106,7 @@ loss_fn:
 
 This keeps PPO/GRPO-style clipping behavior but allows a larger expansion region than the contraction region, which can help exploration and reduce early collapse.
 
-- **Implementation**: `ClippedPGLossFn` documents decoupled clipping in [`nemo_rl/algorithms/loss_functions.py`](../../nemo_rl/algorithms/loss_functions.py).
+- **Implementation**: `ClippedPGLossFn` documents decoupled clipping in [`nemo_rl/algorithms/loss/loss_functions.py`](../../nemo_rl/algorithms/loss/loss_functions.py).
 
 ## Loss: Token-level Policy Gradient
 
@@ -153,7 +153,7 @@ loss_fn:
 - `"icepop"`: set weights outside \([min, max]\) to zero (filter outliers)
 - `"seq-mask-tis"`: sequence-level geometric-mean mask + non-truncated token-level IS correction (see below)
 
-- **Implementation**: see `ClippedPGLossFn` init-time checks and logic in [`nemo_rl/algorithms/loss_functions.py`](../../nemo_rl/algorithms/loss_functions.py).
+- **Implementation**: see `ClippedPGLossFn` init-time checks and logic in [`nemo_rl/algorithms/loss/loss_functions.py`](../../nemo_rl/algorithms/loss/loss_functions.py).
 
 ### Seq-mask-tis: Sequence-level Geometric-Mean Mask
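The decoupled clipping referenced above (a larger expansion region than contraction region) can be sketched as follows. The function name and epsilon values here are illustrative, not the actual `loss_fn` config keys:

```python
import torch

def decoupled_clipped_pg(ratio: torch.Tensor,
                         advantages: torch.Tensor,
                         eps_low: float = 0.2,
                         eps_high: float = 0.28) -> torch.Tensor:
    """Clipped PG surrogate with separate lower/upper clip bounds (sketch)."""
    clipped = torch.clamp(ratio, 1.0 - eps_low, 1.0 + eps_high) * advantages
    return -torch.min(ratio * advantages, clipped).mean()

ratio = torch.tensor([1.25])
adv = torch.tensor([1.0])

# With a wider expansion region (eps_high=0.28), a ratio of 1.25 is not
# clipped; under symmetric clipping (eps_high=0.2) it would be capped at 1.2,
# damping the update for this positive-advantage token.
asymmetric = decoupled_clipped_pg(ratio, adv, eps_low=0.2, eps_high=0.28)
symmetric = decoupled_clipped_pg(ratio, adv, eps_low=0.2, eps_high=0.2)
```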

nemo_rl/algorithms/distillation.py

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 from nemo_rl.algorithms.grpo import _should_use_async_rollouts, refit_policy_generation
-from nemo_rl.algorithms.loss_functions import (
+from nemo_rl.algorithms.loss import (
     DistillationLossConfig,
     DistillationLossDataDict,
     DistillationLossFn,

nemo_rl/algorithms/dpo.py

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@
 from torchdata.stateful_dataloader import StatefulDataLoader
 from transformers import AutoTokenizer
 
-from nemo_rl.algorithms.loss_functions import DPOLossFn
+from nemo_rl.algorithms.loss import DPOLossFn
 from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed
 from nemo_rl.data import DataConfig
 from nemo_rl.data.collate_fn import preference_collate_fn

nemo_rl/algorithms/grpo.py

Lines changed: 2 additions & 2 deletions

@@ -31,12 +31,12 @@
     GRPOAdvantageEstimator,
     ReinforcePlusPlusAdvantageEstimator,
 )
-from nemo_rl.algorithms.interfaces import LossFunction
-from nemo_rl.algorithms.loss_functions import (
+from nemo_rl.algorithms.loss import (
     ClippedPGLossConfig,
     ClippedPGLossDataDict,
     ClippedPGLossFn,
 )
+from nemo_rl.algorithms.loss.interfaces import LossFunction
 from nemo_rl.algorithms.reward_functions import (
     RewardShapingConfig,
     apply_reward_shaping,
Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo_rl.algorithms.loss.loss_functions import (
+    ClippedPGLossConfig,
+    ClippedPGLossDataDict,
+    ClippedPGLossFn,
+    DistillationLossConfig,
+    DistillationLossDataDict,
+    DistillationLossFn,
+    DPOLossConfig,
+    DPOLossDataDict,
+    DPOLossFn,
+    NLLLossFn,
+    PreferenceLossDataDict,
+    PreferenceLossFn,
+)
+from nemo_rl.algorithms.loss.utils import prepare_loss_input
+from nemo_rl.algorithms.loss.wrapper import (
+    SequencePackingLossWrapper,
+    wrap_loss_fn_with_input_preparation,
+)
+
+__all__ = [
+    "ClippedPGLossConfig",
+    "ClippedPGLossDataDict",
+    "ClippedPGLossFn",
+    "DistillationLossConfig",
+    "DistillationLossDataDict",
+    "DistillationLossFn",
+    "DPOLossConfig",
+    "DPOLossDataDict",
+    "DPOLossFn",
+    "NLLLossFn",
+    "PreferenceLossDataDict",
+    "PreferenceLossFn",
+    "prepare_loss_input",
+    "SequencePackingLossWrapper",
+    "wrap_loss_fn_with_input_preparation",
+]
Lines changed: 13 additions & 7 deletions

@@ -25,6 +25,12 @@ class LossType(enum.Enum):
     SEQUENCE_LEVEL = "sequence_level"
 
 
+class LossInputType(enum.Enum):
+    LOGIT = "logit"
+    LOGPROB = "logprob"
+    DISTILLATION = "distillation"
+
+
 class LossFunction(Protocol):
     """Signature for loss functions used in reinforcement learning algorithms.
 
@@ -33,33 +39,33 @@ class LossFunction(Protocol):
     """
 
     loss_type: LossType
+    input_type: LossInputType
 
     def __call__(
         self,
-        next_token_logits: torch.Tensor,
         data: BatchedDataDict,
         global_valid_seqs: torch.Tensor,
         global_valid_toks: torch.Tensor,
+        **kwargs: Any,
     ) -> tuple[torch.Tensor, dict[str, Any]]:
         """Compute loss and metrics from logprobs and other data.
 
         Args:
-            next_token_logits: Logits from the model, typically with shape [batch_size, seq_len, vocab_size].
-                For each position (b, i), contains the logit distribution over the entire vocabulary
-                for predicting the next token (at position i+1). For example, if processing "The cat sat on",
-                then next_token_logits[b, 3] would contain the logits for predicting the word
-                that follows "on".
             data: Dictionary containing all relevant data for loss computation
                 such as rewards, values, actions, advantages, masks, and other
                 algorithm-specific information needed for the particular loss calculation.
             global_valid_seqs: torch.Tensor
-                this tensor should contain the number of valid sequences in the microbatch.
+                This tensor should contain the number of valid sequences in the microbatch.
                 It's used for global normalization for losses/metrics that are computed at the sequence level
                 and needs to be aggregated across all microbatches.
             global_valid_toks: torch.Tensor
                 This tensor should contain the number of valid tokens in the microbatch.
                 It's used for global normalization for losses/metrics that are computed at the token level
                 and needs to be aggregated across all microbatches.
+            **kwargs: Loss function input, which varies by input_type:
+                - For LossInputType.LOGPROB: next_token_logprobs (torch.Tensor)
+                - For LossInputType.LOGIT: logits (torch.Tensor)
+                - For LossInputType.DISTILLATION: student_topk_logprobs, teacher_topk_logprobs, H_all (torch.Tensor)
 
         Returns:
             tuple: (loss, metrics)
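To make the new protocol concrete, here is a hypothetical token-level loss conforming to the signature in the diff above. The enums are re-declared locally so the sketch runs standalone, and a plain dict stands in for BatchedDataDict; none of this is the actual NeMo RL implementation:

```python
import enum
from typing import Any

import torch

# Local stand-ins for the enums in nemo_rl.algorithms.loss.interfaces so this
# sketch is self-contained.
class LossType(enum.Enum):
    TOKEN_LEVEL = "token_level"
    SEQUENCE_LEVEL = "sequence_level"

class LossInputType(enum.Enum):
    LOGIT = "logit"
    LOGPROB = "logprob"
    DISTILLATION = "distillation"

class SimpleNLLLoss:
    """Hypothetical token-level loss matching the LossFunction signature."""

    loss_type = LossType.TOKEN_LEVEL
    input_type = LossInputType.LOGPROB  # expects next_token_logprobs in kwargs

    def __call__(
        self,
        data: dict[str, torch.Tensor],  # stand-in for BatchedDataDict
        global_valid_seqs: torch.Tensor,
        global_valid_toks: torch.Tensor,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, dict[str, Any]]:
        logprobs = kwargs["next_token_logprobs"]
        mask = data["token_mask"]
        # Normalize by the *global* valid-token count, not the microbatch's,
        # so summing across microbatches yields the full-batch loss.
        loss = -(logprobs * mask).sum() / global_valid_toks
        return loss, {"num_valid_tokens": global_valid_toks.item()}

loss_fn = SimpleNLLLoss()
loss, metrics = loss_fn(
    data={"token_mask": torch.tensor([1.0, 1.0, 0.0])},
    global_valid_seqs=torch.tensor(1.0),
    global_valid_toks=torch.tensor(2.0),
    next_token_logprobs=torch.tensor([-0.5, -1.5, -9.0]),
)
```

The policy can inspect `input_type` to decide which tensors to pass through `**kwargs`, which is what lets logit-based, logprob-based, and distillation losses share one call signature.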
