 import torch.nn.functional as F


-def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
+def compute_logprobs(
+    logits: torch.Tensor,
+    input_ids: torch.Tensor,
+    temperature: float = 1.0,
+    align: bool = True,
+) -> torch.Tensor:
     """
-    A memory-efficient implementation of the common `log_softmax -> gather` operation.
+    Computes the log probabilities of the input tokens given the model logits and temperature.
+    Always converts inputs to fp32 for numerical stability.

-    This function is equivalent to the following naive implementation:
-    ```python
-    logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
-    ```
+    This function handles two common usage patterns:

-    Args:
-        logits (`torch.Tensor`):
-            Logits tensor of shape `(..., num_classes)`.
-        index (`torch.Tensor`):
-            Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.
+    **Pattern 1: Pre-aligned logits (align=False)**
+    Use when logits are already aligned with input_ids, typically when:
+    - You pass input_ids to the model: model(input_ids) -> logits
+    - The model outputs logits[:, i] that predict target_ids[:, i]
+    - logits.shape[1] == input_ids.shape[1]

-    Returns:
-        `torch.Tensor`:
-            Gathered log probabilities with the same shape as `index`.
-    """
-    if logits.dtype in [torch.float32, torch.float64]:
-        selected_logits = torch.gather(
-            logits, dim=-1, index=index.unsqueeze(-1)
-        ).squeeze(-1)
-        # loop to reduce peak mem consumption
-        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
-        per_token_logps = (
-            selected_logits - logsumexp_values
-        )  # log_softmax(x_i) = x_i - logsumexp(x)
-    else:
-        # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach
-        per_token_logps = []
-        for row_logits, row_labels in zip(
-            logits, index
-        ):  # loop to reduce peak mem consumption
-            row_logps = F.log_softmax(row_logits, dim=-1)
-            row_per_token_logps = row_logps.gather(
-                dim=-1, index=row_labels.unsqueeze(-1)
-            ).squeeze(-1)
-            per_token_logps.append(row_per_token_logps)
-        per_token_logps = torch.stack(per_token_logps)
-    return per_token_logps
+    Example:
+        >>> input_ids = torch.tensor([[1, 2, 3, 4]])  # Model input
+        >>> target_ids = torch.tensor([[2, 3, 4, 5]])  # Shifted by 1 (next-token prediction)
+        >>> logits = model(input_ids)  # Shape: [1, 4, vocab_size]
+        >>> # logits already aligned: logits[:, i] predicts target_ids[:, i]
+        >>> logprobs = compute_logprobs(logits, target_ids, align=False)

+    **Pattern 2: Full-sequence logits needing alignment (align=True, default)**
+    Use when you have logits for the full sequence but only want log probs for a subset
+    (e.g., just the response tokens, not the prompt). The function will:
+    - Slice logits to match the length of input_ids
+    - Take logits[:, -input_ids.shape[1] - 1 : -1] to get the positions that predict input_ids

-def compute_logprobs(
-    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
-) -> torch.Tensor:
-    """
-    Computes the log probabilities of the input tokens given the model logits and temperature.
-    Always converts inputs to fp32 for numerical stability
+    Example:
+        >>> # Full sequence passed to model: [prompt + response]
+        >>> full_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])  # Prompt + response
+        >>> logits = model(full_input_ids)  # Shape: [1, 6, vocab_size]
+        >>> # Only want log probs for response tokens
+        >>> response_tokens = torch.tensor([[4, 5, 6]])  # Just the response
+        >>> logprobs = compute_logprobs(logits, response_tokens, align=True)
+        >>> # Function slices logits[:, -4:-1] to get logits that predict tokens [4, 5, 6]
+
+    The alignment logic ensures that when you have a full sequence but only want log
+    probabilities for the response portion, you don't need to re-run the model. This
+    is a key optimization in RL training where the prompt remains constant.

     Args:
         logits (`torch.Tensor`):
             The model output logits of shape `(batch_size, sequence_length, vocab_size)`.
         input_ids (`torch.Tensor`):
-            The input token ids of shape `(batch_size, target_sequence_length)`.
+            The target token ids of shape `(batch_size, target_sequence_length)`.
+            These are the tokens for which you want to compute log probabilities.
         temperature (`float`, *optional*, defaults to 1.0):
             The temperature value for scaling logits before computing log probabilities.
+            Higher values make the distribution more uniform, lower values more peaked.
+        align (`bool`, *optional*, defaults to True):
+            If True (default), align logits with input_ids by slicing to extract the
+            relevant positions from a longer sequence (Pattern 2).
+            If False, assume logits are already aligned with input_ids (Pattern 1).

     Returns:
-        logprobs: [batch, seq_len] log probabilities for each token
+        torch.Tensor: Log probabilities of shape `(batch_size, target_sequence_length)`.
+            Each element [b, i] is the log probability of input_ids[b, i] given the
+            corresponding logits.
+
+    Note:
+        This function uses cross_entropy instead of log_softmax + gather for better
+        numerical stability, especially important for fp16/bf16 training.
     """
-    # Ignore the last token from logits because it predicts the next token (-1)
-    # And align logits with the input tokens length.
-    logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)
+    # Align logits with input_ids if requested
+    if align:
+        # Drop the last position (it predicts the token after the sequence) and
+        # slice so each remaining logit predicts the matching input token.
+        logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)
+
     scaled_logits = logits / temperature

     # Cast up to fp32 for numerical stability
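The hunk is cut off here, just before the fp32 cast and the actual gather. As a rough sketch only: assuming the docstring's Note is accurate about using `cross_entropy` rather than `log_softmax` + gather, the remaining step might look like the following. The helper name `gather_logprobs` is hypothetical, not from the patch, and the real body may differ:

```python
import torch
import torch.nn.functional as F


def gather_logprobs(scaled_logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    # F.cross_entropy returns -log p(target); negate to recover log probs.
    # Flatten to [batch * seq_len, vocab] since cross_entropy expects 2D input.
    logprobs = -F.cross_entropy(
        scaled_logits.float().flatten(0, 1),  # cast up to fp32 for stability
        input_ids.flatten(),
        reduction="none",
    )
    return logprobs.view_as(input_ids)
```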
 | 
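To see how the two calling patterns relate end to end, here is a minimal usage sketch. The toy `model` (an embedding plus a linear head) and the token values are stand-ins invented for illustration; it assumes `compute_logprobs` from this patch behaves exactly as the new docstring describes:

```python
import torch

# Toy stand-in for a language model: embedding + linear head, so that
# model(input_ids) yields logits of shape [batch, seq_len, vocab_size].
torch.manual_seed(0)
vocab_size = 16
embed = torch.nn.Embedding(vocab_size, 8)
head = torch.nn.Linear(8, vocab_size)


def model(input_ids: torch.Tensor) -> torch.Tensor:
    return head(embed(input_ids))


full_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])  # prompt + response
response_tokens = torch.tensor([[4, 5, 6]])          # response only

logits = model(full_input_ids)  # [1, 6, vocab_size]

# Pattern 2 (default): compute_logprobs slices the full-sequence logits itself.
lp_auto = compute_logprobs(logits, response_tokens, align=True)

# Pattern 1: slice manually, then pass pre-aligned logits.
pre_aligned = logits[:, -response_tokens.size(1) - 1 : -1, :]
lp_manual = compute_logprobs(pre_aligned, response_tokens, align=False)

assert torch.allclose(lp_auto, lp_manual)  # both routes agree
```

Keeping both routes through one function makes it cheap to check that the internal slicing matches a manual slice, which is exactly the invariant the Pattern 2 docstring example relies on.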