4 changes: 2 additions & 2 deletions src/forge/losses/reinforce_loss.py
@@ -7,7 +7,7 @@
import torch
from torch import nn

-from forge.util.ops import selective_log_softmax
+from forge.util.ops import compute_logprobs


class ReinforceLoss(nn.Module):
@@ -29,7 +29,7 @@ def __init__(self):
    def forward(
        self, trainer_logits, target_ids, target_mask, target_weights, target_log_probs
    ):
-        trainer_log_probs = selective_log_softmax(trainer_logits, target_ids)
+        trainer_log_probs = compute_logprobs(trainer_logits, target_ids, align=False)
        target_mask = target_mask.detach()
        target_weights = target_weights
        target_mask_sum = target_mask.sum()
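As a quick illustration of the swap above (not part of the diff): with `align=False`, `compute_logprobs` should reproduce the old `selective_log_softmax` behavior, i.e. a plain log-softmax followed by a gather at the target indices. A minimal sketch, assuming `forge` is installed and `compute_logprobs` behaves as its docstring in this PR describes:

```python
# Not part of the diff: checks that align=False reproduces the old
# selective_log_softmax semantics (log_softmax -> gather on aligned logits).
import torch

from forge.util.ops import compute_logprobs

batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)
target_ids = torch.randint(0, vocab, (batch, seq_len))

# Naive reference: gather the log-softmax at the target indices.
reference = torch.gather(
    logits.log_softmax(dim=-1), dim=-1, index=target_ids.unsqueeze(-1)
).squeeze(-1)

# align=False: logits[:, i] already predicts target_ids[:, i] here.
result = compute_logprobs(logits, target_ids, align=False)
torch.testing.assert_close(result, reference, atol=1e-5, rtol=1e-5)
```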
102 changes: 54 additions & 48 deletions src/forge/util/ops.py
@@ -8,70 +8,76 @@
import torch.nn.functional as F


-def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
+def compute_logprobs(
+    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0, align: bool = True
+) -> torch.Tensor:
    """
-    A memory-efficient implementation of the common `log_softmax -> gather` operation.
+    Computes the log probabilities of the input tokens given the model logits and temperature.
+    Always converts inputs to fp32 for numerical stability.

-    This function is equivalent to the following naive implementation:
-    ```python
-    logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
-    ```
+    This function handles two common usage patterns:

-    Args:
-        logits (`torch.Tensor`):
-            Logits tensor of shape `(..., num_classes)`.
-        index (`torch.Tensor`):
-            Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.
+    **Pattern 1: Pre-aligned logits (align=False)**
+    Use when logits are already aligned with input_ids, typically when you:
+    - Pass input_ids to the model: model(input_ids) -> logits
+    - The model outputs logits[i] that predict target_ids[i]
+    - logits.shape[1] == input_ids.shape[1]

-    Returns:
-        `torch.Tensor`:
-            Gathered log probabilities with the same shape as `index`.
-    """
-    if logits.dtype in [torch.float32, torch.float64]:
-        selected_logits = torch.gather(
-            logits, dim=-1, index=index.unsqueeze(-1)
-        ).squeeze(-1)
-        # loop to reduce peak mem consumption
-        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
-        per_token_logps = (
-            selected_logits - logsumexp_values
-        )  # log_softmax(x_i) = x_i - logsumexp(x)
-    else:
-        # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach
-        per_token_logps = []
-        for row_logits, row_labels in zip(
-            logits, index
-        ):  # loop to reduce peak mem consumption
-            row_logps = F.log_softmax(row_logits, dim=-1)
-            row_per_token_logps = row_logps.gather(
-                dim=-1, index=row_labels.unsqueeze(-1)
-            ).squeeze(-1)
-            per_token_logps.append(row_per_token_logps)
-        per_token_logps = torch.stack(per_token_logps)
-    return per_token_logps
+    Example:
+        >>> input_ids = torch.tensor([[1, 2, 3, 4]])  # Model input
+        >>> target_ids = torch.tensor([[2, 3, 4, 5]])  # Shifted by 1 (next-token prediction)
+        >>> logits = model(input_ids)  # Shape: [1, 4, vocab_size]
+        >>> # logits already aligned: logits[:, i] predicts target_ids[:, i]
+        >>> logprobs = compute_logprobs(logits, target_ids, align=False)
+
+    **Pattern 2: Full-sequence logits needing alignment (align=True, default)**
+    Use when you have logits for the full sequence but only want log probs for a subset
+    (e.g., just the response tokens, not the prompt). The function will:
+    - Slice logits to match the length of input_ids
+    - Take logits[:, -len(input_ids)-1:-1] to get positions that predict input_ids

-def compute_logprobs(
-    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
-) -> torch.Tensor:
-    """
-    Computes the log probabilities of the input tokens given the model logits and temperature.
-    Always converts inputs to fp32 for numerical stability
+    Example:
+        >>> # Full sequence passed to model: [prompt + response]
+        >>> full_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])  # Prompt + response
+        >>> logits = model(full_input_ids)  # Shape: [1, 6, vocab_size]
+        >>> # Only want log probs for response tokens
+        >>> response_tokens = torch.tensor([[4, 5, 6]])  # Just the response
+        >>> logprobs = compute_logprobs(logits, response_tokens, align=True)
+        >>> # Function slices logits[:, -4:-1] to get logits that predict tokens [4, 5, 6]
+
+    The alignment logic ensures that when you have a full sequence but only want log
+    probabilities for the response portion, you don't need to re-run the model. This
+    is a key optimization in RL training where the prompt remains constant.

    Args:
        logits (`torch.Tensor`):
            The model output logits of shape `(batch_size, sequence_length, vocab_size)`.
        input_ids (`torch.Tensor`):
-            The input token ids of shape `(batch_size, target_sequence_length)`.
+            The target token ids of shape `(batch_size, target_sequence_length)`.
+            These are the tokens for which you want to compute log probabilities.
        temperature (`float`, *optional*, defaults to 1.0):
            The temperature value for scaling logits before computing log probabilities.
+            Higher values make the distribution more uniform, lower values more peaked.
+        align (`bool`, *optional*, defaults to True):
+            If True (default), align logits with input_ids by slicing to extract the
+            relevant positions from a longer sequence (Pattern 2).
+            If False, assume logits are already aligned with input_ids (Pattern 1).

    Returns:
-        logprobs: [batch, seq_len] log probabilities for each token
+        torch.Tensor: Log probabilities of shape `(batch_size, target_sequence_length)`.
+            Each element [b, i] is the log probability of input_ids[b, i] given the
+            corresponding logits.
+
+    Note:
+        This function uses cross_entropy instead of log_softmax + gather for better
+        numerical stability, especially important for fp16/bf16 training.
    """
-    # Ignore the last token from logits because it predicts the next token (-1)
-    # And align logits with the input tokens length.
-    logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)
+    # Align logits with input_ids if requested
+    if align:
+        # Ignore the last token from logits because it predicts the next token (-1)
+        # And align logits with the input tokens length.
+        logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)

    scaled_logits = logits / temperature

    # Cast up to fp32 for numerical stability
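To make the two docstring patterns concrete (again, not part of the diff), here is a sketch of how the default `align=True` slice relates to the pre-aligned `align=False` path; random logits stand in for a real model, and `forge` is assumed installed:

```python
# Not part of the diff: shows how the default align=True slice relates to
# the pre-aligned align=False path. Random logits stand in for a model.
import torch

from forge.util.ops import compute_logprobs

vocab = 11
full_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])  # prompt + response
response_tokens = torch.tensor([[4, 5, 6]])          # response only
logits = torch.randn(1, full_input_ids.size(1), vocab)  # stand-in for model(...).logits

# Pattern 2: compute_logprobs slices logits[:, -4:-1] internally.
aligned = compute_logprobs(logits, response_tokens)  # align=True is the default

# Pattern 1 on a hand-made slice should agree with the aligned result.
manual_slice = logits[:, -response_tokens.size(1) - 1 : -1]
manual = compute_logprobs(manual_slice, response_tokens, align=False)
torch.testing.assert_close(aligned, manual)
```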
7 changes: 4 additions & 3 deletions tests/sandbox/toy_rl/sumdigits.py
@@ -25,7 +25,7 @@

from forge.observability.metrics import record_metric, Reduce
from forge.util.config import parse
-from forge.util.ops import selective_log_softmax
+from forge.util.ops import compute_logprobs
from monarch.actor import endpoint
from omegaconf import DictConfig

@@ -241,7 +241,8 @@ async def forward(self, episode: Episode) -> torch.Tensor:
        with torch.inference_mode():
            logits = self.model(input_ids=input_ids, attention_mask=mask).logits

-        return selective_log_softmax(logits, target_ids).squeeze(0)
+        log_probs = compute_logprobs(logits, target_ids, align=False)
+        return log_probs.squeeze(0)


@dataclass
@@ -325,7 +326,7 @@ def train_step(self, episodes: list[Episode]) -> float:
        # Forward pass
        logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits

-        trainer_log_probs = selective_log_softmax(logits, target_ids)
+        trainer_log_probs = compute_logprobs(logits, target_ids, align=False)
        # Compute loss only on response tokens
        # loss = self.loss(logits, target_ids, loss_masks, weights, sampling_log_probs)
        loss = self.loss(trainer_log_probs, ref_logprobs, weights, loss_masks)
92 changes: 0 additions & 92 deletions tests/unit_tests/util/test_selective_log_softmax.py

This file was deleted.
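The deleted file covered `selective_log_softmax`; equivalent coverage for `compute_logprobs` could be ported along these lines. A hypothetical sketch, not part of this PR: the file location (`tests/unit_tests/util/test_ops.py`), the test name, and the output-dtype/tolerance choices for bf16 are all assumptions.

```python
# Hypothetical port of the deleted coverage to compute_logprobs; the file
# location, test name, and bf16 tolerances are assumptions, not from this PR.
import pytest
import torch

from forge.util.ops import compute_logprobs


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_compute_logprobs_matches_naive_gather(dtype):
    torch.manual_seed(0)
    logits = torch.randn(2, 5, 11, dtype=dtype)
    target_ids = torch.randint(0, 11, (2, 5))

    # Naive fp32 reference: log_softmax -> gather on pre-aligned logits.
    expected = torch.gather(
        logits.float().log_softmax(dim=-1), dim=-1, index=target_ids.unsqueeze(-1)
    ).squeeze(-1)

    actual = compute_logprobs(logits, target_ids, align=False)
    tol = 1e-5 if dtype == torch.float32 else 1e-2
    torch.testing.assert_close(actual.float(), expected, atol=tol, rtol=tol)
```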
