
Commit ff3290e

Remove selective_log_softmax and unify with compute_logprobs (#464)
1 parent e24b173 commit ff3290e

5 files changed: 116 additions, 145 deletions

src/forge/losses/reinforce_loss.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 import torch
 from torch import nn
 
-from forge.util.ops import selective_log_softmax
+from forge.util.ops import compute_logprobs
 
 
 class ReinforceLoss(nn.Module):
@@ -29,7 +29,7 @@ def __init__(self):
     def forward(
         self, trainer_logits, target_ids, target_mask, target_weights, target_log_probs
     ):
-        trainer_log_probs = selective_log_softmax(trainer_logits, target_ids)
+        trainer_log_probs = compute_logprobs(trainer_logits, target_ids, align=False)
         target_mask = target_mask.detach()
         target_weights = target_weights
         target_mask_sum = target_mask.sum()
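
Since the trainer logits here are already position-aligned with target_ids, align=False makes compute_logprobs a drop-in replacement. A minimal sketch of the intended equivalence with the textbook log_softmax -> gather that selective_log_softmax implemented (illustrative shapes; assumes an environment where forge.util.ops is importable):

```python
import torch

from forge.util.ops import compute_logprobs

logits = torch.randn(2, 5, 11)  # (batch, seq, vocab), pre-aligned with targets
target_ids = torch.randint(0, 11, (2, 5))

# Textbook reference: full log_softmax over the vocab, then gather the targets.
reference = torch.gather(
    logits.log_softmax(dim=-1), dim=-1, index=target_ids.unsqueeze(-1)
).squeeze(-1)

assert torch.allclose(
    compute_logprobs(logits, target_ids, align=False), reference, atol=1e-5
)
```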

src/forge/util/ops.py

Lines changed: 57 additions & 48 deletions
@@ -8,70 +8,79 @@
 import torch.nn.functional as F
 
 
-def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
+def compute_logprobs(
+    logits: torch.Tensor,
+    input_ids: torch.Tensor,
+    temperature: float = 1.0,
+    align: bool = True,
+) -> torch.Tensor:
     """
-    A memory-efficient implementation of the common `log_softmax -> gather` operation.
+    Computes the log probabilities of the input tokens given the model logits and temperature.
+    Always converts inputs to fp32 for numerical stability.
 
-    This function is equivalent to the following naive implementation:
-    ```python
-    logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
-    ```
+    This function handles two common usage patterns:
 
-    Args:
-        logits (`torch.Tensor`):
-            Logits tensor of shape `(..., num_classes)`.
-        index (`torch.Tensor`):
-            Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.
+    **Pattern 1: Pre-aligned logits (align=False)**
+    Use when logits are already aligned with input_ids, typically when you:
+    - Pass input_ids to the model: model(input_ids) -> logits
+    - The model outputs logits[i] that predict target_ids[i]
+    - logits.shape[1] == input_ids.shape[1]
 
-    Returns:
-        `torch.Tensor`:
-            Gathered log probabilities with the same shape as `index`.
-    """
-    if logits.dtype in [torch.float32, torch.float64]:
-        selected_logits = torch.gather(
-            logits, dim=-1, index=index.unsqueeze(-1)
-        ).squeeze(-1)
-        # loop to reduce peak mem consumption
-        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
-        per_token_logps = (
-            selected_logits - logsumexp_values
-        )  # log_softmax(x_i) = x_i - logsumexp(x)
-    else:
-        # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach
-        per_token_logps = []
-        for row_logits, row_labels in zip(
-            logits, index
-        ):  # loop to reduce peak mem consumption
-            row_logps = F.log_softmax(row_logits, dim=-1)
-            row_per_token_logps = row_logps.gather(
-                dim=-1, index=row_labels.unsqueeze(-1)
-            ).squeeze(-1)
-            per_token_logps.append(row_per_token_logps)
-        per_token_logps = torch.stack(per_token_logps)
-    return per_token_logps
+    Example:
+        >>> input_ids = torch.tensor([[1, 2, 3, 4]])  # Model input
+        >>> target_ids = torch.tensor([[2, 3, 4, 5]])  # Shifted by 1 (next-token prediction)
+        >>> logits = model(input_ids)  # Shape: [1, 4, vocab_size]
+        >>> # logits already aligned: logits[:, i] predicts target_ids[:, i]
+        >>> logprobs = compute_logprobs(logits, target_ids, align=False)
 
+    **Pattern 2: Full-sequence logits needing alignment (align=True, default)**
+    Use when you have logits for the full sequence but only want log probs for a subset
+    (e.g., just the response tokens, not the prompt). The function will:
+    - Slice logits to match the length of input_ids
+    - Take logits[:, -len(input_ids)-1:-1] to get positions that predict input_ids
 
-def compute_logprobs(
-    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
-) -> torch.Tensor:
-    """
-    Computes the log probabilities of the input tokens given the model logits and temperature.
-    Always converts inputs to fp32 for numerical stability
+    Example:
+        >>> # Full sequence passed to model: [prompt + response]
+        >>> full_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])  # Prompt + response
+        >>> logits = model(full_input_ids)  # Shape: [1, 6, vocab_size]
+        >>> # Only want log probs for response tokens
+        >>> response_tokens = torch.tensor([[4, 5, 6]])  # Just the response
+        >>> logprobs = compute_logprobs(logits, response_tokens, align=True)
+        >>> # Function slices logits[:, -4:-1] to get logits that predict tokens [4, 5, 6]
+
+    The alignment logic ensures that when you have a full sequence but only want log
+    probabilities for the response portion, you don't need to re-run the model. This
+    is a key optimization in RL training where the prompt remains constant.
 
     Args:
         logits (`torch.Tensor`):
             The model output logits of shape `(batch_size, sequence_length, vocab_size)`.
         input_ids (`torch.Tensor`):
-            The input token ids of shape `(batch_size, target_sequence_length)`.
+            The target token ids of shape `(batch_size, target_sequence_length)`.
+            These are the tokens for which you want to compute log probabilities.
         temperature (`float`, *optional*, defaults to 1.0):
             The temperature value for scaling logits before computing log probabilities.
+            Higher values make the distribution more uniform, lower values more peaked.
+        align (`bool`, *optional*, defaults to True):
+            If True (default), align logits with input_ids by slicing to extract the
+            relevant positions from a longer sequence (Pattern 2).
+            If False, assume logits are already aligned with input_ids (Pattern 1).
 
     Returns:
-        logprobs: [batch, seq_len] log probabilities for each token
+        torch.Tensor: Log probabilities of shape `(batch_size, target_sequence_length)`.
+            Each element [b, i] is the log probability of input_ids[b, i] given the
+            corresponding logits.
+
+    Note:
+        This function uses cross_entropy instead of log_softmax + gather for better
+        numerical stability, especially important for fp16/bf16 training.
     """
-    # Ignore the last token from logits because it predicts the next token (-1)
-    # And align logits with the input tokens length.
-    logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)
+    # Align logits with input_ids if requested
+    if align:
+        # Ignore the last token from logits because it predicts the next token (-1)
+        # And align logits with the input tokens length.
+        logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)
+
     scaled_logits = logits / temperature
 
     # Cast up to fp32 for numerical stability
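
The hunk cuts off before the unchanged tail of the function. Going by the last visible comment and the docstring's Note that cross_entropy replaces log_softmax + gather, the remainder presumably looks something like this sketch (`_cross_entropy_tail` is a hypothetical stand-in name, not the verbatim source):

```python
import torch
import torch.nn.functional as F


def _cross_entropy_tail(scaled_logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical tail of compute_logprobs: per-token fp32 cross-entropy."""
    # Cast up to fp32 for numerical stability
    scaled_logits = scaled_logits.float()
    # cross_entropy wants (batch, vocab, seq) against targets of shape (batch, seq)
    # and returns the negative log-likelihood, so negate to get log probabilities.
    nll = F.cross_entropy(scaled_logits.transpose(1, 2), input_ids, reduction="none")
    return -nll
```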

tests/sandbox/toy_rl/sumdigits.py

Lines changed: 4 additions & 3 deletions
@@ -25,7 +25,7 @@
 
 from forge.observability.metrics import record_metric, Reduce
 from forge.util.config import parse
-from forge.util.ops import selective_log_softmax
+from forge.util.ops import compute_logprobs
 from monarch.actor import endpoint
 from omegaconf import DictConfig
 
@@ -241,7 +241,8 @@ async def forward(self, episode: Episode) -> torch.Tensor:
         with torch.inference_mode():
             logits = self.model(input_ids=input_ids, attention_mask=mask).logits
 
-        return selective_log_softmax(logits, target_ids).squeeze(0)
+        log_probs = compute_logprobs(logits, target_ids, align=False)
+        return log_probs.squeeze(0)
 
 
 @dataclass
@@ -325,7 +326,7 @@ def train_step(self, episodes: list[Episode]) -> float:
         # Forward pass
         logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits
 
-        trainer_log_probs = selective_log_softmax(logits, target_ids)
+        trainer_log_probs = compute_logprobs(logits, target_ids, align=False)
         # Compute loss only on response tokens
         # loss = self.loss(logits, target_ids, loss_masks, weights, sampling_log_probs)
         loss = self.loss(trainer_log_probs, ref_logprobs, weights, loss_masks)
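
Both call sites pass align=False, which per the new docstring (Pattern 1) assumes target_ids is a one-token shift of the model input so the logits are already position-aligned. A small runnable sketch of that contract, with random logits standing in for the model:

```python
import torch

from forge.util.ops import compute_logprobs

vocab_size = 32
tokens = torch.tensor([[3, 7, 8, 9, 2]])
input_ids, target_ids = tokens[:, :-1], tokens[:, 1:]  # logits[i] predicts target_ids[i]

logits = torch.randn(1, input_ids.size(1), vocab_size)  # stand-in for model(...).logits
log_probs = compute_logprobs(logits, target_ids, align=False)
assert log_probs.shape == target_ids.shape  # (1, 4)
```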

tests/unit_tests/util/test_compute_logprobs.py renamed to tests/unit_tests/util/test_ops.py

Lines changed: 53 additions & 0 deletions
@@ -109,3 +109,56 @@ def test_compute_logprobs_empty_response(self):
 
         result = compute_logprobs(logits, input_ids)
         assert result.shape == (batch_size, 0)
+
+    @pytest.mark.timeout(10)
+    def test_align_parameter_false(self):
+        """Test with align=False (pre-aligned logits)."""
+        # When align=False, logits are already aligned with input_ids
+        # logits[:, i] predicts input_ids[:, i]
+        batch_size, seq_len, vocab_size = 2, 3, 5
+        logits = torch.randn(batch_size, seq_len, vocab_size)
+        input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
+
+        result = compute_logprobs(logits, input_ids, align=False)
+
+        # Manual calculation without slicing
+        expected = _textbook_log_softmax(logits, input_ids)
+
+        assert torch.allclose(result, expected, atol=1e-5)
+        assert result.shape == input_ids.shape
+
+    @pytest.mark.timeout(10)
+    def test_align_parameter_true(self):
+        """Test with align=True (default, needs slicing)."""
+        # When align=True, logits need to be sliced to align with input_ids
+        batch_size, full_seq_len, vocab_size = 2, 6, 5
+        logits = torch.randn(batch_size, full_seq_len, vocab_size)
+
+        # We want log probs for just the last 3 tokens
+        target_len = 3
+        input_ids = torch.randint(0, vocab_size, (batch_size, target_len))
+
+        result = compute_logprobs(logits, input_ids, align=True)
+
+        # Manual calculation: align=True slices logits[:, -target_len-1:-1]
+        sliced_logits = logits[:, -target_len - 1 : -1, :]
+        expected = _textbook_log_softmax(sliced_logits, input_ids)
+
+        assert torch.allclose(result, expected, atol=1e-5)
+        assert result.shape == input_ids.shape
+
+    @pytest.mark.timeout(10)
+    def test_align_comparison(self):
+        """Test that align=True properly slices logits."""
+        batch_size, seq_len, vocab_size = 1, 4, 10
+        logits = torch.randn(batch_size, seq_len, vocab_size)
+        input_ids = torch.randint(0, vocab_size, (batch_size, 2))
+
+        result_aligned = compute_logprobs(logits, input_ids, align=True)
+
+        # Manually slice the same way align=True does
+        sliced_logits = logits[:, -input_ids.size(1) - 1 : -1, :]
+        result_manual = compute_logprobs(sliced_logits, input_ids, align=False)
+
+        # Both should give the same result
+        assert torch.allclose(result_aligned, result_manual, atol=1e-5)
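
The new tests compare against a `_textbook_log_softmax` helper defined earlier in the file, outside this hunk. It is presumably the naive reference along these lines (a sketch, not the verbatim helper):

```python
import torch


def _textbook_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Naive reference: full log_softmax over the vocab, then gather the targets."""
    logprobs = logits.float().log_softmax(dim=-1)
    return torch.gather(logprobs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
```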

tests/unit_tests/util/test_selective_log_softmax.py

Lines changed: 0 additions & 92 deletions
This file was deleted.
