Skip to content

Commit e363ebc

Browse files
committed
address comments
Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent 3784566 commit e363ebc

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed

nemo_rl/algorithms/loss/interfaces.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,30 +43,29 @@ class LossFunction(Protocol):
4343

4444
def __call__(
4545
self,
46-
next_token_logits: torch.Tensor,
4746
data: BatchedDataDict,
4847
global_valid_seqs: torch.Tensor,
4948
global_valid_toks: torch.Tensor,
49+
**kwargs: Any,
5050
) -> tuple[torch.Tensor, dict[str, Any]]:
5151
"""Compute loss and metrics from logprobs and other data.
5252
5353
Args:
54-
next_token_logits: Logits from the model, typically with shape [batch_size, seq_len, vocab_size].
55-
For each position (b, i), contains the logit distribution over the entire vocabulary
56-
for predicting the next token (at position i+1). For example, if processing "The cat sat on",
57-
then next_token_logits[b, 3] would contain the logits for predicting the word
58-
that follows "on".
5954
data: Dictionary containing all relevant data for loss computation
6055
such as rewards, values, actions, advantages, masks, and other
6156
algorithm-specific information needed for the particular loss calculation.
6257
global_valid_seqs: torch.Tensor
63-
this tensor should contain the number of valid sequences in the microbatch.
58+
This tensor should contain the number of valid sequences in the microbatch.
6459
It's used for global normalization for losses/metrics that are computed at the sequence level
6560
and need to be aggregated across all microbatches.
6661
global_valid_toks: torch.Tensor
6762
This tensor should contain the number of valid tokens in the microbatch.
6863
It's used for global normalization for losses/metrics that are computed at the token level
6964
and need to be aggregated across all microbatches.
65+
**kwargs: Loss function input, which varies by input_type:
66+
- For LossInputType.LOGPROB: next_token_logprobs (torch.Tensor)
67+
- For LossInputType.LOGIT: logits (torch.Tensor)
68+
- For LossInputType.DISTILLATION: student_topk_logprobs, teacher_topk_logprobs, H_all (torch.Tensor)
7069
7170
Returns:
7271
tuple: (loss, metrics)

nemo_rl/algorithms/loss/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ def prepare_loss_input(
3838
logits: Logits from the model.
3939
data: Microbatch data.
4040
loss_fn: Loss function.
41+
vocab_parallel_rank: Vocab parallel rank.
42+
vocab_parallel_group: Vocab parallel group.
43+
context_parallel_group: Context parallel group.
44+
45+
vocab_parallel_rank, vocab_parallel_group, and context_parallel_group are only used by the Megatron policy worker.
4146
4247
Returns:
4348
Loss input.

nemo_rl/algorithms/loss/wrapper.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,22 @@ def __init__(
3535
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
3636
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
3737
):
38+
"""Wrap a loss function to handle sequence packing.
39+
40+
Args:
41+
loss_fn: Loss function.
42+
prepare_fn: Prepare function.
43+
cu_seqlens_q: Unpadded cumulative sequence lengths for queries.
44+
cu_seqlens_q_padded: Padded cumulative sequence lengths for queries.
45+
vocab_parallel_rank: Vocab parallel rank.
46+
vocab_parallel_group: Vocab parallel group.
47+
context_parallel_group: Context parallel group.
48+
49+
vocab_parallel_rank, vocab_parallel_group, and context_parallel_group are only used by the Megatron policy worker.
50+
51+
Returns:
52+
Sequence packing loss wrapper.
53+
"""
3854
self.loss_fn = loss_fn
3955
self.prepare_fn = prepare_fn
4056
self.cu_seqlens_q = cu_seqlens_q

0 commit comments

Comments
 (0)