You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/design-docs/loss-functions.md
+2-3Lines changed: 2 additions & 3 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -17,14 +17,13 @@ $$
17
17
18
18
which is, in general, not equivalent to the full-batch loss. To fix this, we need each microbatch to have information about how many tokens are in the other microbatches in the global batch.
19
19
20
-
In NeMo RL, this information is passed to the loss function directly. Each loss function is expected to fall into one of two categories, token-level or sequence-level, which is an attribute of the loss function itself (see [loss_functions.py](../../nemo_rl/algorithms/loss_functions.py) for some examples). The policy then uses this information to compute the global normalization factor using the full batch (for token-level losses, this is the total number of tokens in the batch. For sequence-level losses, this is the number of valid sequences in the batch). The normalization factor is then passed to the loss function, which uses it to normalize the microbatch loss. To get the loss for the global batch, the policy simply sums across all microbatch losses.
20
+
In NeMo RL, this information is passed to the loss function directly. Each loss function is expected to fall into one of two categories, token-level or sequence-level, which is an attribute of the loss function itself (see [loss_functions.py](../../nemo_rl/algorithms/loss/loss_functions.py) for some examples). The policy then uses this information to compute the global normalization factor using the full batch (for token-level losses, this is the total number of tokens in the batch. For sequence-level losses, this is the number of valid sequences in the batch). The normalization factor is then passed to the loss function, which uses it to normalize the microbatch loss. To get the loss for the global batch, the policy simply sums across all microbatch losses.
21
21
22
22
For our simple example above, this would look like:
23
23
24
24
```{testcode}
25
25
import torch
26
-
from nemo_rl.algorithms.interfaces import LossFunction
27
-
from nemo_rl.algorithms.loss_functions import LossType
26
+
from nemo_rl.algorithms.loss.interfaces import LossFunction, LossType
28
27
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
Copy file name to clipboardExpand all lines: docs/guides/grpo.md
+1-1Lines changed: 1 addition & 1 deletion
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -343,7 +343,7 @@ The function, [grpo_train](../../nemo_rl/algorithms/grpo.py), contains the core
343
343
RL generations typically produce highly variable sequence lengths, which result in a significant amount of padding if approached naively. We address this with Sequence Packing and Dynamic Batching, which are techniques to reduce the amount of padding required. You can read more about these in the [design doc](../design-docs/sequence-packing-and-dynamic-batching.md).
344
344
345
345
## Loss
346
-
We use the [ClippedPGLossFn](../../nemo_rl/algorithms/loss_functions.py) to calculate the loss for GRPO. Formally,
346
+
We use the [ClippedPGLossFn](../../nemo_rl/algorithms/loss/loss_functions.py) to calculate the loss for GRPO. Formally,
Copy file name to clipboardExpand all lines: docs/guides/prorlv2.md
+2-2Lines changed: 2 additions & 2 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -106,7 +106,7 @@ loss_fn:
106
106
107
107
This keeps PPO/GRPO-style clipping behavior but allows a larger expansion region than the contraction region, which can help exploration and reduce early collapse.
108
108
109
-
- **Implementation**: `ClippedPGLossFn` documents decoupled clipping in [`nemo_rl/algorithms/loss_functions.py`](../../nemo_rl/algorithms/loss_functions.py).
109
+
- **Implementation**: `ClippedPGLossFn` documents decoupled clipping in [`nemo_rl/algorithms/loss/loss_functions.py`](../../nemo_rl/algorithms/loss/loss_functions.py).
110
110
111
111
## Loss: Token-level Policy Gradient
112
112
@@ -153,7 +153,7 @@ loss_fn:
153
153
- `"icepop"`: set weights outside \([min, max]\) to zero (filter outliers)
154
154
- `"seq-mask-tis"`: sequence-level geometric-mean mask + non-truncated token-level IS correction (see below)
155
155
156
-
- **Implementation**: see `ClippedPGLossFn` init-time checks and logic in [`nemo_rl/algorithms/loss_functions.py`](../../nemo_rl/algorithms/loss_functions.py).
156
+
- **Implementation**: see `ClippedPGLossFn` init-time checks and logic in [`nemo_rl/algorithms/loss/loss_functions.py`](../../nemo_rl/algorithms/loss/loss_functions.py).
0 commit comments