
Commit a6af867

Remove selective_log_softmax and unify with compute_logprobs

This PR removes the `selective_log_softmax` function and consolidates all log
probability computation through the existing `compute_logprobs` function.

Changes:
- Removed `selective_log_softmax` from src/forge/util/ops.py
- Added an `align` parameter to `compute_logprobs` to handle both usage patterns:
  - align=False: for pre-aligned logits (when the model was called with input_ids)
  - align=True: for extracting a subset from full-sequence logits
- Updated all references to use `compute_logprobs` with the appropriate align flag
- Added a comprehensive docstring explaining both usage patterns with examples
- Removed test_selective_log_softmax.py (functionality now covered by compute_logprobs)

The consolidation improves maintainability by routing all log probability
computation through a single function, with the `align` parameter covering the
two common usage patterns in RL training.

1 parent 2559be6 · commit a6af867
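To make the two call-site patterns concrete before the diffs below, here is a minimal runnable sketch. The toy `model` (an `nn.Embedding` plus `nn.Linear`) and all token values are illustrative stand-ins, not code from this repository; only `compute_logprobs` and its `align` flag come from the commit itself:

```python
import torch
import torch.nn as nn

from forge.util.ops import compute_logprobs

torch.manual_seed(0)
vocab_size = 8

# Toy stand-in for a causal LM: embed tokens, project to vocab logits.
embed = nn.Embedding(vocab_size, 16)
head = nn.Linear(16, vocab_size)


def model(ids: torch.Tensor) -> torch.Tensor:
    return head(embed(ids))  # [batch, seq_len, vocab_size]


# Pattern 1 (align=False): logits already line up with the shifted targets.
input_ids = torch.tensor([[1, 2, 3, 4]])
target_ids = torch.tensor([[2, 3, 4, 5]])  # next-token targets
logits = model(input_ids)                  # [1, 4, vocab_size]
lp1 = compute_logprobs(logits, target_ids, align=False)  # [1, 4]

# Pattern 2 (align=True, the default): full-sequence logits, response-only targets.
full_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])  # prompt + response
full_logits = model(full_ids)                  # [1, 6, vocab_size]
response_ids = torch.tensor([[4, 5, 6]])       # just the response tokens
lp2 = compute_logprobs(full_logits, response_ids)  # sliced internally -> [1, 3]
```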

File tree

4 files changed: +60 −145 lines

src/forge/losses/reinforce_loss.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -7,7 +7,7 @@
 import torch
 from torch import nn

-from forge.util.ops import selective_log_softmax
+from forge.util.ops import compute_logprobs


 class ReinforceLoss(nn.Module):
@@ -29,7 +29,7 @@ def __init__(self):
     def forward(
         self, trainer_logits, target_ids, target_mask, target_weights, target_log_probs
     ):
-        trainer_log_probs = selective_log_softmax(trainer_logits, target_ids)
+        trainer_log_probs = compute_logprobs(trainer_logits, target_ids, align=False)
         target_mask = target_mask.detach()
         target_weights = target_weights
         target_mask_sum = target_mask.sum()
```
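Note on the `align=False` choice here: `trainer_logits` come from running the model on an input of the same length as `target_ids`, so logits and targets already line up position for position. A small sketch of why the default would not work at this call site (toy shapes; the failure mode follows from the slicing shown in the ops.py diff below):

```python
import torch

from forge.util.ops import compute_logprobs

trainer_logits = torch.randn(1, 4, 9)    # already aligned with the 4 targets
target_ids = torch.randint(0, 9, (1, 4))

ok = compute_logprobs(trainer_logits, target_ids, align=False)  # [1, 4]

# With align=True, the slice logits[:, -5:-1] would clamp to logits[:, 0:3],
# leaving 3 positions for 4 targets -- misaligned, so align=False is required.
```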

src/forge/util/ops.py

Lines changed: 54 additions & 48 deletions
````diff
@@ -8,70 +8,76 @@
 import torch.nn.functional as F


-def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
+def compute_logprobs(
+    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0, align: bool = True
+) -> torch.Tensor:
     """
-    A memory-efficient implementation of the common `log_softmax -> gather` operation.
+    Computes the log probabilities of the input tokens given the model logits and temperature.
+    Always converts inputs to fp32 for numerical stability.

-    This function is equivalent to the following naive implementation:
-    ```python
-    logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
-    ```
+    This function handles two common usage patterns:

-    Args:
-        logits (`torch.Tensor`):
-            Logits tensor of shape `(..., num_classes)`.
-        index (`torch.Tensor`):
-            Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.
+    **Pattern 1: Pre-aligned logits (align=False)**
+    Use when logits are already aligned with input_ids, typically when you:
+    - Pass input_ids to the model: model(input_ids) -> logits
+    - The model outputs logits[i] that predict target_ids[i]
+    - logits.shape[1] == input_ids.shape[1]

-    Returns:
-        `torch.Tensor`:
-            Gathered log probabilities with the same shape as `index`.
-    """
-    if logits.dtype in [torch.float32, torch.float64]:
-        selected_logits = torch.gather(
-            logits, dim=-1, index=index.unsqueeze(-1)
-        ).squeeze(-1)
-        # loop to reduce peak mem consumption
-        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
-        per_token_logps = (
-            selected_logits - logsumexp_values
-        )  # log_softmax(x_i) = x_i - logsumexp(x)
-    else:
-        # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach
-        per_token_logps = []
-        for row_logits, row_labels in zip(
-            logits, index
-        ):  # loop to reduce peak mem consumption
-            row_logps = F.log_softmax(row_logits, dim=-1)
-            row_per_token_logps = row_logps.gather(
-                dim=-1, index=row_labels.unsqueeze(-1)
-            ).squeeze(-1)
-            per_token_logps.append(row_per_token_logps)
-        per_token_logps = torch.stack(per_token_logps)
-    return per_token_logps
+    Example:
+        >>> input_ids = torch.tensor([[1, 2, 3, 4]])   # Model input
+        >>> target_ids = torch.tensor([[2, 3, 4, 5]])  # Shifted by 1 (next-token prediction)
+        >>> logits = model(input_ids)  # Shape: [1, 4, vocab_size]
+        >>> # logits already aligned: logits[:, i] predicts target_ids[:, i]
+        >>> logprobs = compute_logprobs(logits, target_ids, align=False)

+    **Pattern 2: Full-sequence logits needing alignment (align=True, default)**
+    Use when you have logits for the full sequence but only want log probs for a subset
+    (e.g., just the response tokens, not the prompt). The function will:
+    - Slice logits to match the length of input_ids
+    - Take logits[:, -len(input_ids)-1:-1] to get positions that predict input_ids

-def compute_logprobs(
-    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
-) -> torch.Tensor:
-    """
-    Computes the log probabilities of the input tokens given the model logits and temperature.
-    Always converts inputs to fp32 for numerical stability
+    Example:
+        >>> # Full sequence passed to model: [prompt + response]
+        >>> full_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])  # Prompt + response
+        >>> logits = model(full_input_ids)  # Shape: [1, 6, vocab_size]
+        >>> # Only want log probs for response tokens
+        >>> response_tokens = torch.tensor([[4, 5, 6]])  # Just the response
+        >>> logprobs = compute_logprobs(logits, response_tokens, align=True)
+        >>> # Function slices logits[:, -4:-1] to get logits that predict tokens [4, 5, 6]
+
+    The alignment logic ensures that when you have a full sequence but only want log
+    probabilities for the response portion, you don't need to re-run the model. This
+    is a key optimization in RL training where the prompt remains constant.

     Args:
         logits (`torch.Tensor`):
             The model output logits of shape `(batch_size, sequence_length, vocab_size)`.
         input_ids (`torch.Tensor`):
-            The input token ids of shape `(batch_size, target_sequence_length)`.
+            The target token ids of shape `(batch_size, target_sequence_length)`.
+            These are the tokens for which you want to compute log probabilities.
         temperature (`float`, *optional*, defaults to 1.0):
             The temperature value for scaling logits before computing log probabilities.
+            Higher values make the distribution more uniform, lower values more peaked.
+        align (`bool`, *optional*, defaults to True):
+            If True (default), align logits with input_ids by slicing to extract the
+            relevant positions from a longer sequence (Pattern 2).
+            If False, assume logits are already aligned with input_ids (Pattern 1).

     Returns:
-        logprobs: [batch, seq_len] log probabilities for each token
+        torch.Tensor: Log probabilities of shape `(batch_size, target_sequence_length)`.
+            Each element [b, i] is the log probability of input_ids[b, i] given the
+            corresponding logits.
+
+    Note:
+        This function uses cross_entropy instead of log_softmax + gather for better
+        numerical stability, especially important for fp16/bf16 training.
     """
-    # Ignore the last token from logits because it predicts the next token (-1)
-    # And align logits with the input tokens length.
-    logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)
+    # Align logits with input_ids if requested
+    if align:
+        # Ignore the last token from logits because it predicts the next token (-1)
+        # And align logits with the input tokens length.
+        logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)
+
     scaled_logits = logits / temperature

     # Cast up to fp32 for numerical stability
````
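The hunk ends at the fp32 comment, so the function's tail is not shown in this commit. Given the docstring's note about using `cross_entropy`, the unchanged remainder plausibly looks like this sketch (an assumption; those lines sit outside the diff context):

```python
    # Hypothetical continuation of compute_logprobs, for orientation only.
    batch_size, seq_len, vocab_size = scaled_logits.shape
    # F.cross_entropy returns -log p(target); negate to recover log probabilities.
    logprobs = -F.cross_entropy(
        scaled_logits.reshape(-1, vocab_size).float(),  # cast up to fp32
        input_ids.reshape(-1).to(scaled_logits.device),
        reduction="none",
    )
    return logprobs.reshape(batch_size, seq_len)
```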

tests/sandbox/toy_rl/sumdigits.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -25,7 +25,7 @@

 from forge.observability.metrics import record_metric, Reduce
 from forge.util.config import parse
-from forge.util.ops import selective_log_softmax
+from forge.util.ops import compute_logprobs
 from monarch.actor import endpoint
 from omegaconf import DictConfig

@@ -241,7 +241,8 @@ async def forward(self, episode: Episode) -> torch.Tensor:
         with torch.inference_mode():
             logits = self.model(input_ids=input_ids, attention_mask=mask).logits

-        return selective_log_softmax(logits, target_ids).squeeze(0)
+        log_probs = compute_logprobs(logits, target_ids, align=False)
+        return log_probs.squeeze(0)


 @dataclass
@@ -325,7 +326,7 @@ def train_step(self, episodes: list[Episode]) -> float:
         # Forward pass
         logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits

-        trainer_log_probs = selective_log_softmax(logits, target_ids)
+        trainer_log_probs = compute_logprobs(logits, target_ids, align=False)
         # Compute loss only on response tokens
         # loss = self.loss(logits, target_ids, loss_masks, weights, sampling_log_probs)
         loss = self.loss(trainer_log_probs, ref_logprobs, weights, loss_masks)
```
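The `squeeze(0)` after the `compute_logprobs` call in `forward` reflects that this path scores one episode at a time; a shape sketch with toy values (assuming batch size 1, as the single-episode signature suggests):

```python
import torch

log_probs = torch.randn(1, 7)     # compute_logprobs output: [batch=1, seq_len]
per_token = log_probs.squeeze(0)  # -> [seq_len], one log prob per target token
```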

tests/unit_tests/util/test_selective_log_softmax.py

Lines changed: 0 additions & 92 deletions
This file was deleted.
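With test_selective_log_softmax.py removed, the `align=False` path now carries the old function's contract. A regression test sketch one might add under the compute_logprobs tests (hypothetical; the deleted file's contents are not shown in this commit):

```python
import torch

from forge.util.ops import compute_logprobs


def test_align_false_matches_log_softmax_gather():
    torch.manual_seed(0)
    logits = torch.randn(2, 4, 7)
    targets = torch.randint(0, 7, (2, 4))

    # Naive reference: log_softmax followed by gather at the target indices.
    expected = torch.gather(
        logits.log_softmax(-1), dim=-1, index=targets.unsqueeze(-1)
    ).squeeze(-1)
    actual = compute_logprobs(logits, targets, align=False)

    assert actual.shape == targets.shape
    torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-5)
```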
