|
8 | 8 | import torch.nn.functional as F |
9 | 9 |
|
10 | 10 |
|
11 | | -def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor: |
| 11 | +def compute_logprobs( |
| 12 | + logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0, align: bool = True |
| 13 | +) -> torch.Tensor: |
12 | 14 | """ |
13 | | - A memory-efficient implementation of the common `log_softmax -> gather` operation. |
| 15 | + Computes the log probabilities of the input tokens given the model logits and temperature. |
| 16 | + Always converts inputs to fp32 for numerical stability. |
14 | 17 |
|
15 | | - This function is equivalent to the following naive implementation: |
16 | | - ```python |
17 | | - logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1) |
18 | | - ``` |
| 18 | + This function handles two common usage patterns: |
19 | 19 |
|
20 | | - Args: |
21 | | - logits (`torch.Tensor`): |
22 | | - Logits tensor of shape `(..., num_classes)`. |
23 | | - index (`torch.Tensor`): |
24 | | - Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output. |
| 20 | + **Pattern 1: Pre-aligned logits (align=False)** |
| 21 | + Use when logits are already aligned with input_ids, typically when you: |
| 22 | + - Pass input_ids to the model: model(input_ids) -> logits |
| 23 | + - The model outputs logits[i] that predict target_ids[i] |
| 24 | + - logits.shape[1] == input_ids.shape[1] |
25 | 25 |
|
26 | | - Returns: |
27 | | - `torch.Tensor`: |
28 | | - Gathered log probabilities with the same shape as `index`. |
29 | | - """ |
30 | | - if logits.dtype in [torch.float32, torch.float64]: |
31 | | - selected_logits = torch.gather( |
32 | | - logits, dim=-1, index=index.unsqueeze(-1) |
33 | | - ).squeeze(-1) |
34 | | - # loop to reduce peak mem consumption |
35 | | - logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) |
36 | | - per_token_logps = ( |
37 | | - selected_logits - logsumexp_values |
38 | | - ) # log_softmax(x_i) = x_i - logsumexp(x) |
39 | | - else: |
40 | | - # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach |
41 | | - per_token_logps = [] |
42 | | - for row_logits, row_labels in zip( |
43 | | - logits, index |
44 | | - ): # loop to reduce peak mem consumption |
45 | | - row_logps = F.log_softmax(row_logits, dim=-1) |
46 | | - row_per_token_logps = row_logps.gather( |
47 | | - dim=-1, index=row_labels.unsqueeze(-1) |
48 | | - ).squeeze(-1) |
49 | | - per_token_logps.append(row_per_token_logps) |
50 | | - per_token_logps = torch.stack(per_token_logps) |
51 | | - return per_token_logps |
| 26 | + Example: |
| 27 | + >>> input_ids = torch.tensor([[1, 2, 3, 4]]) # Model input |
| 28 | + >>> target_ids = torch.tensor([[2, 3, 4, 5]]) # Shifted by 1 (next-token prediction) |
| 29 | + >>> logits = model(input_ids) # Shape: [1, 4, vocab_size] |
| 30 | + >>> # logits already aligned: logits[:, i] predicts target_ids[:, i] |
| 31 | + >>> logprobs = compute_logprobs(logits, target_ids, align=False) |
52 | 32 |
|
| 33 | + **Pattern 2: Full-sequence logits needing alignment (align=True, default)** |
| 34 | + Use when you have logits for the full sequence but only want log probs for a subset |
| 35 | + (e.g., just the response tokens, not the prompt). The function will: |
| 36 | + - Slice logits to match the length of input_ids |
| 37 | + - Take logits[:, -input_ids.size(1) - 1 : -1] to get the positions that predict input_ids
53 | 38 |
|
54 | | -def compute_logprobs( |
55 | | - logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0 |
56 | | -) -> torch.Tensor: |
57 | | - """ |
58 | | - Computes the log probabilities of the input tokens given the model logits and temperature. |
59 | | - Always converts inputs to fp32 for numerical stability |
| 39 | + Example: |
| 40 | + >>> # Full sequence passed to model: [prompt + response] |
| 41 | + >>> full_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]]) # Prompt + response |
| 42 | + >>> logits = model(full_input_ids) # Shape: [1, 6, vocab_size] |
| 43 | + >>> # Only want log probs for response tokens |
| 44 | + >>> response_tokens = torch.tensor([[4, 5, 6]]) # Just the response |
| 45 | + >>> logprobs = compute_logprobs(logits, response_tokens, align=True) |
| 46 | + >>> # Function slices logits[:, -4:-1] to get logits that predict tokens [4, 5, 6] |
| 47 | +
|
| 48 | + The alignment logic ensures that when you have a full sequence but only want log |
| 49 | + probabilities for the response portion, you don't need to re-run the model. This |
| 50 | + is a key optimization in RL training where the prompt remains constant. |
60 | 51 |
|
61 | 52 | Args: |
62 | 53 | logits (`torch.Tensor`): |
63 | 54 | The model output logits of shape `(batch_size, sequence_length, vocab_size)`. |
64 | 55 | input_ids (`torch.Tensor`): |
65 | | - The input token ids of shape `(batch_size, target_sequence_length)`. |
| 56 | + The target token ids of shape `(batch_size, target_sequence_length)`. |
| 57 | + These are the tokens for which you want to compute log probabilities. |
66 | 58 | temperature (`float`, *optional*, defaults to 1.0): |
67 | 59 | The temperature value for scaling logits before computing log probabilities. |
| 60 | + Higher values make the distribution more uniform, lower values more peaked. |
| 61 | + align (`bool`, *optional*, defaults to True): |
| 62 | + If True (default), align logits with input_ids by slicing to extract the |
| 63 | + relevant positions from a longer sequence (Pattern 2). |
| 64 | + If False, assume logits are already aligned with input_ids (Pattern 1). |
68 | 65 |
|
69 | 66 | Returns: |
70 | | - logprobs: [batch, seq_len] log probabilities for each token |
| 67 | + torch.Tensor: Log probabilities of shape `(batch_size, target_sequence_length)`. |
| 68 | + Each element [b, i] is the log probability of input_ids[b, i] given the |
| 69 | + corresponding logits. |
| 70 | +
|
| 71 | + Note: |
| 72 | + This function uses cross_entropy instead of log_softmax + gather for better |
| 73 | + numerical stability, especially important for fp16/bf16 training. |
71 | 74 | """ |
72 | | - # Ignore the last token from logits because it predicts the next token (-1) |
73 | | - # And align logits with the input tokens length. |
74 | | - logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device) |
| 75 | + # Align logits with input_ids if requested |
| 76 | + if align: |
| 77 | + # Ignore the last token from logits because it predicts the next token (-1) |
| 78 | + # And align logits with the length of input_ids.
| 79 | + logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device) |
| 80 | + |
75 | 81 | scaled_logits = logits / temperature |
76 | 82 |
|
77 | 83 | # Cast up to fp32 for numerical stability |
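The diff is truncated right after the fp32 cast, so the cross_entropy-based computation mentioned in the docstring's Note is not visible here. Below is a minimal sketch of how such a per-token log-probability step typically looks for `(batch, seq, vocab)` logits and `(batch, seq)` target ids; it is an assumption for illustration, not the commit's actual code.

```python
import torch
import torch.nn.functional as F


def per_token_logprobs_sketch(scaled_logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    # Sketch only (assumed implementation): cross_entropy returns -log p(target),
    # so negating it yields the per-token log probabilities.
    losses = F.cross_entropy(
        scaled_logits.float().flatten(0, 1),  # (batch * seq, vocab), cast to fp32
        input_ids.flatten(),                  # (batch * seq,)
        reduction="none",
    )
    return -losses.view(input_ids.shape)      # (batch, seq) log probabilities
```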
|
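For reference, a short usage sketch of the two patterns described in the docstring, using random tensors in place of a real model. It assumes `compute_logprobs` from the diff above is in scope; the commented import path and the toy shapes are placeholders.

```python
import torch

# Placeholder import; adjust to wherever compute_logprobs is defined in your project.
# from logprobs_utils import compute_logprobs

batch, prompt_len, resp_len, vocab = 2, 5, 3, 32

# Pattern 1: logits already aligned with the targets (align=False).
aligned_logits = torch.randn(batch, resp_len, vocab)
target_ids = torch.randint(0, vocab, (batch, resp_len))
lp1 = compute_logprobs(aligned_logits, target_ids, align=False)  # shape: (2, 3)

# Pattern 2: full-sequence logits, response-only log probs (align=True, default).
full_logits = torch.randn(batch, prompt_len + resp_len, vocab)
response_ids = torch.randint(0, vocab, (batch, resp_len))
lp2 = compute_logprobs(full_logits, response_ids)                # shape: (2, 3)
```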