@@ -43,30 +43,29 @@ class LossFunction(Protocol):
 
     def __call__(
         self,
-        next_token_logits: torch.Tensor,
         data: BatchedDataDict,
         global_valid_seqs: torch.Tensor,
         global_valid_toks: torch.Tensor,
+        **kwargs: Any,
     ) -> tuple[torch.Tensor, dict[str, Any]]:
         """Compute loss and metrics from logprobs and other data.
 
         Args:
-            next_token_logits: Logits from the model, typically with shape [batch_size, seq_len, vocab_size].
-                For each position (b, i), contains the logit distribution over the entire vocabulary
-                for predicting the next token (at position i+1). For example, if processing "The cat sat on",
-                then next_token_logits[b, 3] would contain the logits for predicting the word
-                that follows "on".
             data: Dictionary containing all relevant data for loss computation
                 such as rewards, values, actions, advantages, masks, and other
                 algorithm-specific information needed for the particular loss calculation.
             global_valid_seqs: torch.Tensor
-                this tensor should contain the number of valid sequences in the microbatch.
+                This tensor should contain the number of valid sequences in the microbatch.
                 It's used for global normalization for losses/metrics that are computed at the sequence level
                 and needs to be aggregated across all microbatches.
             global_valid_toks: torch.Tensor
                 This tensor should contain the number of valid tokens in the microbatch.
                 It's used for global normalization for losses/metrics that are computed at the token level
                 and needs to be aggregated across all microbatches.
+            **kwargs: Loss function input, which varies by input_type:
+                - For LossInputType.LOGPROB: next_token_logprobs (torch.Tensor)
+                - For LossInputType.LOGIT: logits (torch.Tensor)
+                - For LossInputType.DISTILLATION: student_topk_logprobs, teacher_topk_logprobs, H_all (torch.Tensor)
 
         Returns:
             tuple: (loss, metrics)