5 changes: 2 additions & 3 deletions docs/design-docs/loss-functions.md
@@ -17,14 +17,13 @@ $$

which is, in general, not equivalent to the full-batch loss. To fix this, we need each microbatch to have information about how many tokens are in the other microbatches in the global batch.

In NeMo RL, this information is passed to the loss function directly. Each loss function is expected to fall into one of two categories, token-level or sequence-level, which is an attribute of the loss function itself (see [loss_functions.py](../../nemo_rl/algorithms/loss_functions.py) for some examples). The policy then uses this information to compute the global normalization factor over the full batch (for token-level losses, this is the total number of tokens in the batch; for sequence-level losses, it is the number of valid sequences in the batch). The normalization factor is then passed to the loss function, which uses it to normalize the microbatch loss. To get the loss for the global batch, the policy simply sums across all microbatch losses.
In NeMo RL, this information is passed to the loss function directly. Each loss function is expected to fall into one of two categories, token-level or sequence-level, which is an attribute of the loss function itself (see [loss_functions.py](../../nemo_rl/algorithms/loss/loss_functions.py) for some examples). The policy then uses this information to compute the global normalization factor over the full batch (for token-level losses, this is the total number of tokens in the batch; for sequence-level losses, it is the number of valid sequences in the batch). The normalization factor is then passed to the loss function, which uses it to normalize the microbatch loss. To get the loss for the global batch, the policy simply sums across all microbatch losses.

For our simple example above, this would look like:

```{testcode}
import torch
from nemo_rl.algorithms.interfaces import LossFunction
from nemo_rl.algorithms.loss_functions import LossType
from nemo_rl.algorithms.loss.interfaces import LossFunction, LossType
from nemo_rl.distributed.batched_data_dict import BatchedDataDict


```
2 changes: 1 addition & 1 deletion docs/guides/grpo.md
@@ -343,7 +343,7 @@ The function, [grpo_train](../../nemo_rl/algorithms/grpo.py), contains the core
RL generations typically produce highly variable sequence lengths, which result in a significant amount of padding if approached naively. We address this with Sequence Packing and Dynamic Batching, which are techniques to reduce the amount of padding required. You can read more about these in the [design doc](../design-docs/sequence-packing-and-dynamic-batching.md).

## Loss
We use the [ClippedPGLossFn](../../nemo_rl/algorithms/loss_functions.py) to calculate the loss for GRPO. Formally,
We use the [ClippedPGLossFn](../../nemo_rl/algorithms/loss/loss_functions.py) to calculate the loss for GRPO. Formally,

$$
L(\theta) = E_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big) \Big] - \beta D_{\text{KL}} (\pi_\theta \| \pi_\text{ref})
$$
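The clipped surrogate inside the expectation can be sketched per token in plain Python. This is a toy scalar version with the KL penalty term omitted, not NeMo RL's implementation; `eps` plays the role of $\varepsilon$:

```python
import math

def clipped_pg_objective(logprob, logprob_old, advantage, eps=0.2):
    """Per-token PPO/GRPO clipped surrogate (KL penalty omitted)."""
    ratio = math.exp(logprob - logprob_old)  # pi_theta(x) / pi_theta_old(x)
    unclipped = ratio * advantage
    clipped = max(1.0 - eps, min(ratio, 1.0 + eps)) * advantage
    # min() keeps the more pessimistic (smaller) objective, bounding the
    # size of the policy update a single sample can drive.
    return min(unclipped, clipped)

# A ratio of 1.5 with eps=0.2 is clipped to 1.2 for a positive advantage:
print(clipped_pg_objective(math.log(1.5), 0.0, advantage=1.0))  # 1.2
```

The actual training loss is the negative of this objective (plus the KL penalty), averaged with the global normalization described above.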
4 changes: 2 additions & 2 deletions docs/guides/prorlv2.md
@@ -106,7 +106,7 @@ loss_fn:

This keeps PPO/GRPO-style clipping behavior but allows a larger expansion region than the contraction region, which can help exploration and reduce early collapse.

- **Implementation**: `ClippedPGLossFn` documents decoupled clipping in [`nemo_rl/algorithms/loss_functions.py`](../../nemo_rl/algorithms/loss_functions.py).
- **Implementation**: `ClippedPGLossFn` documents decoupled clipping in [`nemo_rl/algorithms/loss/loss_functions.py`](../../nemo_rl/algorithms/loss/loss_functions.py).
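Decoupled clipping can be sketched as follows. The parameter names `eps_low`/`eps_high` and their values are illustrative, not the actual config keys or defaults used by `ClippedPGLossFn`:

```python
def decoupled_clip_objective(ratio, advantage, eps_low=0.2, eps_high=0.28):
    """Clipped surrogate with decoupled bounds: the expansion side
    (1 + eps_high) may be looser than the contraction side (1 - eps_low)."""
    clipped_ratio = max(1.0 - eps_low, min(ratio, 1.0 + eps_high))
    return min(ratio * advantage, clipped_ratio * advantage)

# With these illustrative values, a positive-advantage token's probability may
# grow by up to 28% before clipping, but the penalty for shrinking it is still
# bounded at -20%:
print(decoupled_clip_objective(1.25, 1.0))   # 1.25 (inside the wider upper bound)
print(decoupled_clip_objective(0.7, -1.0))   # -0.8 (lower bound still applies)
```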

## Loss: Token-level Policy Gradient

@@ -153,7 +153,7 @@ loss_fn:
- `"icepop"`: set weights outside \([min, max]\) to zero (filter outliers)
- `"seq-mask-tis"`: sequence-level geometric-mean mask + non-truncated token-level IS correction (see below)

- **Implementation**: see `ClippedPGLossFn` init-time checks and logic in [`nemo_rl/algorithms/loss_functions.py`](../../nemo_rl/algorithms/loss_functions.py).
- **Implementation**: see `ClippedPGLossFn` init-time checks and logic in [`nemo_rl/algorithms/loss/loss_functions.py`](../../nemo_rl/algorithms/loss/loss_functions.py).
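The per-token weight corrections above can be sketched as a small dispatch function. The `"icepop"` behavior follows the description above; the `"truncate"` mode and the bound values shown here are assumptions for illustration, not NeMo RL's config defaults:

```python
def apply_is_mode(weight, mode, w_min=0.5, w_max=2.0):
    """Per-token importance-sampling weight correction (bounds illustrative).

    "truncate": clamp the weight into [w_min, w_max] (assumed companion mode).
    "icepop":   zero the weight outside [w_min, w_max], filtering outlier tokens.
    """
    if mode == "truncate":
        return min(max(weight, w_min), w_max)
    if mode == "icepop":
        return weight if w_min <= weight <= w_max else 0.0
    raise ValueError(f"unknown mode: {mode}")

print(apply_is_mode(3.0, "truncate"))  # 2.0 (clamped to the upper bound)
print(apply_is_mode(3.0, "icepop"))    # 0.0 (outlier token dropped entirely)
```

The practical difference: truncation keeps the gradient signal at a bounded magnitude, while icepop removes the token's contribution entirely.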

### Seq-mask-tis: Sequence-level Geometric-Mean Mask

2 changes: 1 addition & 1 deletion nemo_rl/algorithms/distillation.py
@@ -24,7 +24,7 @@
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from nemo_rl.algorithms.grpo import _should_use_async_rollouts, refit_policy_generation
from nemo_rl.algorithms.loss_functions import (
from nemo_rl.algorithms.loss import (
DistillationLossConfig,
DistillationLossDataDict,
DistillationLossFn,
)
2 changes: 1 addition & 1 deletion nemo_rl/algorithms/dpo.py
@@ -23,7 +23,7 @@
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoTokenizer

from nemo_rl.algorithms.loss_functions import DPOLossFn
from nemo_rl.algorithms.loss import DPOLossFn
from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed
from nemo_rl.data import DataConfig
from nemo_rl.data.collate_fn import preference_collate_fn
4 changes: 2 additions & 2 deletions nemo_rl/algorithms/grpo.py
@@ -31,12 +31,12 @@
GRPOAdvantageEstimator,
ReinforcePlusPlusAdvantageEstimator,
)
from nemo_rl.algorithms.interfaces import LossFunction
from nemo_rl.algorithms.loss_functions import (
from nemo_rl.algorithms.loss import (
ClippedPGLossConfig,
ClippedPGLossDataDict,
ClippedPGLossFn,
)
from nemo_rl.algorithms.loss.interfaces import LossFunction
from nemo_rl.algorithms.reward_functions import (
RewardShapingConfig,
apply_reward_shaping,
)
51 changes: 51 additions & 0 deletions nemo_rl/algorithms/loss/__init__.py
@@ -0,0 +1,51 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_rl.algorithms.loss.loss_functions import (
ClippedPGLossConfig,
ClippedPGLossDataDict,
ClippedPGLossFn,
DistillationLossConfig,
DistillationLossDataDict,
DistillationLossFn,
DPOLossConfig,
DPOLossDataDict,
DPOLossFn,
NLLLossFn,
PreferenceLossDataDict,
PreferenceLossFn,
)
from nemo_rl.algorithms.loss.utils import prepare_loss_input
from nemo_rl.algorithms.loss.wrapper import (
SequencePackingLossWrapper,
wrap_loss_fn_with_input_preparation,
)

__all__ = [
"ClippedPGLossConfig",
"ClippedPGLossDataDict",
"ClippedPGLossFn",
"DistillationLossConfig",
"DistillationLossDataDict",
"DistillationLossFn",
"DPOLossConfig",
"DPOLossDataDict",
"DPOLossFn",
"NLLLossFn",
"PreferenceLossDataDict",
"PreferenceLossFn",
"prepare_loss_input",
"SequencePackingLossWrapper",
"wrap_loss_fn_with_input_preparation",
]
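The `__init__.py` above uses the standard re-export pattern, so call sites can import from the package root instead of naming the internal module. A self-contained sketch of the same pattern, using a hypothetical `demo_rl` package built in a temp directory (not the real NeMo RL package):

```python
import os
import sys
import tempfile

root = tempfile.mkdtemp()
pkg = os.path.join(root, "demo_rl", "algorithms", "loss")
os.makedirs(pkg)

# Regular packages need an __init__.py at every level.
for d in (os.path.join(root, "demo_rl"),
          os.path.join(root, "demo_rl", "algorithms")):
    open(os.path.join(d, "__init__.py"), "w").close()

# A stand-in for loss_functions.py with one exported class.
with open(os.path.join(pkg, "loss_functions.py"), "w") as f:
    f.write("class ClippedPGLossFn: ...\n")

# The subpackage __init__.py re-exports the symbol and declares __all__.
with open(os.path.join(pkg, "__init__.py"), "w") as f:
    f.write("from demo_rl.algorithms.loss.loss_functions import ClippedPGLossFn\n"
            '__all__ = ["ClippedPGLossFn"]\n')

sys.path.insert(0, root)
# Callers import from the package root, mirroring the updated call sites above.
from demo_rl.algorithms.loss import ClippedPGLossFn
print(ClippedPGLossFn.__name__)  # ClippedPGLossFn
```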
@@ -25,6 +25,12 @@ class LossType(enum.Enum):
SEQUENCE_LEVEL = "sequence_level"


class LossInputType(enum.Enum):
LOGIT = "logit"
LOGPROB = "logprob"
DISTILLATION = "distillation"


class LossFunction(Protocol):
"""Signature for loss functions used in reinforcement learning algorithms.

@@ -33,33 +39,33 @@ class LossFunction(Protocol):
"""

loss_type: LossType
input_type: LossInputType

def __call__(
self,
next_token_logits: torch.Tensor,
data: BatchedDataDict,
global_valid_seqs: torch.Tensor,
global_valid_toks: torch.Tensor,
**kwargs: Any,
) -> tuple[torch.Tensor, dict[str, Any]]:
"""Compute loss and metrics from logprobs and other data.

Args:
next_token_logits: Logits from the model, typically with shape [batch_size, seq_len, vocab_size].
For each position (b, i), contains the logit distribution over the entire vocabulary
for predicting the next token (at position i+1). For example, if processing "The cat sat on",
then next_token_logits[b, 3] would contain the logits for predicting the word
that follows "on".
data: Dictionary containing all relevant data for loss computation
such as rewards, values, actions, advantages, masks, and other
algorithm-specific information needed for the particular loss calculation.
global_valid_seqs: torch.Tensor
this tensor should contain the number of valid sequences in the microbatch.
This tensor should contain the number of valid sequences in the microbatch.
It's used for global normalization for losses/metrics that are computed at the sequence level
and needs to be aggregated across all microbatches.
global_valid_toks: torch.Tensor
This tensor should contain the number of valid tokens in the microbatch.
It's used for global normalization for losses/metrics that are computed at the token level
and needs to be aggregated across all microbatches.
**kwargs: Loss function input, which varies by input_type:
- For LossInputType.LOGPROB: next_token_logprobs (torch.Tensor)
- For LossInputType.LOGIT: logits (torch.Tensor)
- For LossInputType.DISTILLATION: student_topk_logprobs, teacher_topk_logprobs, H_all (torch.Tensor)

Returns:
tuple: (loss, metrics)
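The protocol shape above can be made concrete with a toy token-level loss. Plain Python lists stand in for tensors and the enums are re-declared locally, so this is a structural sketch rather than NeMo RL's `NLLLossFn`:

```python
import enum

class LossType(enum.Enum):
    TOKEN_LEVEL = "token_level"
    SEQUENCE_LEVEL = "sequence_level"

class LossInputType(enum.Enum):
    LOGIT = "logit"
    LOGPROB = "logprob"
    DISTILLATION = "distillation"

class ToyNLLLoss:
    """Token-level loss; input_type selects next_token_logprobs as the kwarg."""
    loss_type = LossType.TOKEN_LEVEL
    input_type = LossInputType.LOGPROB

    def __call__(self, data, global_valid_seqs, global_valid_toks,
                 *, next_token_logprobs):
        # Masked NLL summed over this microbatch, normalized by the *global*
        # token count, so summing microbatch losses recovers the full-batch loss.
        masked = [-lp * m for lp, m in zip(next_token_logprobs, data["token_mask"])]
        loss = sum(masked) / global_valid_toks
        return loss, {"num_valid_tokens": sum(data["token_mask"])}

loss_fn = ToyNLLLoss()
loss, metrics = loss_fn(
    {"token_mask": [1, 1, 0]}, global_valid_seqs=1, global_valid_toks=2,
    next_token_logprobs=[-0.5, -1.5, -9.0],
)
print(loss)  # 1.0 — the padded third token contributes nothing
```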