diff --git a/src/forge/losses/reinforce_loss.py b/src/forge/losses/reinforce_loss.py index c2dedd530..62c7e579a 100644 --- a/src/forge/losses/reinforce_loss.py +++ b/src/forge/losses/reinforce_loss.py @@ -7,7 +7,7 @@ import torch from torch import nn -from forge.util.ops import selective_log_softmax +from forge.util.ops import compute_logprobs class ReinforceLoss(nn.Module): @@ -29,7 +29,7 @@ def __init__(self): def forward( self, trainer_logits, target_ids, target_mask, target_weights, target_log_probs ): - trainer_log_probs = selective_log_softmax(trainer_logits, target_ids) + trainer_log_probs = compute_logprobs(trainer_logits, target_ids, align=False) target_mask = target_mask.detach() target_weights = target_weights target_mask_sum = target_mask.sum() diff --git a/src/forge/util/ops.py b/src/forge/util/ops.py index a65b86e96..f7152f065 100644 --- a/src/forge/util/ops.py +++ b/src/forge/util/ops.py @@ -8,70 +8,79 @@ import torch.nn.functional as F -def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor: +def compute_logprobs( + logits: torch.Tensor, + input_ids: torch.Tensor, + temperature: float = 1.0, + align: bool = True, +) -> torch.Tensor: """ - A memory-efficient implementation of the common `log_softmax -> gather` operation. + Computes the log probabilities of the input tokens given the model logits and temperature. + Always converts inputs to fp32 for numerical stability. - This function is equivalent to the following naive implementation: - ```python - logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1) - ``` + This function handles two common usage patterns: - Args: - logits (`torch.Tensor`): - Logits tensor of shape `(..., num_classes)`. - index (`torch.Tensor`): - Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output. + **Pattern 1: Pre-aligned logits (align=False)** + Use when logits are already aligned with input_ids, typically when you: + - Pass input_ids to the model: model(input_ids) -> logits + - The model outputs logits[i] that predict target_ids[i] + - logits.shape[1] == input_ids.shape[1] - Returns: - `torch.Tensor`: - Gathered log probabilities with the same shape as `index`. 
- """ - if logits.dtype in [torch.float32, torch.float64]: - selected_logits = torch.gather( - logits, dim=-1, index=index.unsqueeze(-1) - ).squeeze(-1) - # loop to reduce peak mem consumption - logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) - per_token_logps = ( - selected_logits - logsumexp_values - ) # log_softmax(x_i) = x_i - logsumexp(x) - else: - # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach - per_token_logps = [] - for row_logits, row_labels in zip( - logits, index - ): # loop to reduce peak mem consumption - row_logps = F.log_softmax(row_logits, dim=-1) - row_per_token_logps = row_logps.gather( - dim=-1, index=row_labels.unsqueeze(-1) - ).squeeze(-1) - per_token_logps.append(row_per_token_logps) - per_token_logps = torch.stack(per_token_logps) - return per_token_logps + Example: + >>> input_ids = torch.tensor([[1, 2, 3, 4]]) # Model input + >>> target_ids = torch.tensor([[2, 3, 4, 5]]) # Shifted by 1 (next-token prediction) + >>> logits = model(input_ids) # Shape: [1, 4, vocab_size] + >>> # logits already aligned: logits[:, i] predicts target_ids[:, i] + >>> logprobs = compute_logprobs(logits, target_ids, align=False) + **Pattern 2: Full-sequence logits needing alignment (align=True, default)** + Use when you have logits for the full sequence but only want log probs for a subset + (e.g., just the response tokens, not the prompt). The function will: + - Slice logits to match the length of input_ids + - Take logits[:, -len(input_ids)-1:-1] to get positions that predict input_ids -def compute_logprobs( - logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0 -) -> torch.Tensor: - """ - Computes the log probabilities of the input tokens given the model logits and temperature. - Always converts inputs to fp32 for numerical stability + Example: + >>> # Full sequence passed to model: [prompt + response] + >>> full_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]]) # Prompt + response + >>> logits = model(full_input_ids) # Shape: [1, 6, vocab_size] + >>> # Only want log probs for response tokens + >>> response_tokens = torch.tensor([[4, 5, 6]]) # Just the response + >>> logprobs = compute_logprobs(logits, response_tokens, align=True) + >>> # Function slices logits[:, -4:-1] to get logits that predict tokens [4, 5, 6] + + The alignment logic ensures that when you have a full sequence but only want log + probabilities for the response portion, you don't need to re-run the model. This + is a key optimization in RL training where the prompt remains constant. Args: logits (`torch.Tensor`): The model output logits of shape `(batch_size, sequence_length, vocab_size)`. input_ids (`torch.Tensor`): - The input token ids of shape `(batch_size, target_sequence_length)`. + The target token ids of shape `(batch_size, target_sequence_length)`. + These are the tokens for which you want to compute log probabilities. temperature (`float`, *optional*, defaults to 1.0): The temperature value for scaling logits before computing log probabilities. + Higher values make the distribution more uniform, lower values more peaked. + align (`bool`, *optional*, defaults to True): + If True (default), align logits with input_ids by slicing to extract the + relevant positions from a longer sequence (Pattern 2). + If False, assume logits are already aligned with input_ids (Pattern 1). 
Returns: - logprobs: [batch, seq_len] log probabilities for each token + torch.Tensor: Log probabilities of shape `(batch_size, target_sequence_length)`. + Each element [b, i] is the log probability of input_ids[b, i] given the + corresponding logits. + + Note: + This function uses cross_entropy instead of log_softmax + gather for better + numerical stability, especially important for fp16/bf16 training. """ - # Ignore the last token from logits because it predicts the next token (-1) - # And align logits with the input tokens length. - logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device) + # Align logits with input_ids if requested + if align: + # Ignore the last token from logits because it predicts the next token (-1) + # And align logits with the input tokens length. + logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device) + scaled_logits = logits / temperature # Cast up to fp32 for numerical stability diff --git a/tests/sandbox/toy_rl/sumdigits.py b/tests/sandbox/toy_rl/sumdigits.py index 01a0f3936..56b669ce4 100644 --- a/tests/sandbox/toy_rl/sumdigits.py +++ b/tests/sandbox/toy_rl/sumdigits.py @@ -25,7 +25,7 @@ from forge.observability.metrics import record_metric, Reduce from forge.util.config import parse -from forge.util.ops import selective_log_softmax +from forge.util.ops import compute_logprobs from monarch.actor import endpoint from omegaconf import DictConfig @@ -241,7 +241,8 @@ async def forward(self, episode: Episode) -> torch.Tensor: with torch.inference_mode(): logits = self.model(input_ids=input_ids, attention_mask=mask).logits - return selective_log_softmax(logits, target_ids).squeeze(0) + log_probs = compute_logprobs(logits, target_ids, align=False) + return log_probs.squeeze(0) @dataclass @@ -325,7 +326,7 @@ def train_step(self, episodes: list[Episode]) -> float: # Forward pass logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits - trainer_log_probs = selective_log_softmax(logits, target_ids) + trainer_log_probs = compute_logprobs(logits, target_ids, align=False) # Compute loss only on response tokens # loss = self.loss(logits, target_ids, loss_masks, weights, sampling_log_probs) loss = self.loss(trainer_log_probs, ref_logprobs, weights, loss_masks) diff --git a/tests/unit_tests/util/test_compute_logprobs.py b/tests/unit_tests/util/test_ops.py similarity index 65% rename from tests/unit_tests/util/test_compute_logprobs.py rename to tests/unit_tests/util/test_ops.py index c4e3bffcb..2f224743a 100644 --- a/tests/unit_tests/util/test_compute_logprobs.py +++ b/tests/unit_tests/util/test_ops.py @@ -109,3 +109,56 @@ def test_compute_logprobs_empty_response(self): result = compute_logprobs(logits, input_ids) assert result.shape == (batch_size, 0) + + @pytest.mark.timeout(10) + def test_align_parameter_false(self): + """Test with align=False (pre-aligned logits).""" + # When align=False, logits are already aligned with input_ids + # logits[:, i] predicts input_ids[:, i] + batch_size, seq_len, vocab_size = 2, 3, 5 + logits = torch.randn(batch_size, seq_len, vocab_size) + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) + + result = compute_logprobs(logits, input_ids, align=False) + + # Manual calculation without slicing + expected = _textbook_log_softmax(logits, input_ids) + + assert torch.allclose(result, expected, atol=1e-5) + assert result.shape == input_ids.shape + + @pytest.mark.timeout(10) + def test_align_parameter_true(self): + """Test with align=True (default, needs slicing).""" + # When 
align=True, logits need to be sliced to align with input_ids + batch_size, full_seq_len, vocab_size = 2, 6, 5 + logits = torch.randn(batch_size, full_seq_len, vocab_size) + + # We want log probs for just the last 3 tokens + target_len = 3 + input_ids = torch.randint(0, vocab_size, (batch_size, target_len)) + + result = compute_logprobs(logits, input_ids, align=True) + + # Manual calculation: align=True slices logits[:, -target_len-1:-1] + sliced_logits = logits[:, -target_len - 1 : -1, :] + expected = _textbook_log_softmax(sliced_logits, input_ids) + + assert torch.allclose(result, expected, atol=1e-5) + assert result.shape == input_ids.shape + + @pytest.mark.timeout(10) + def test_align_comparison(self): + """Test that align=True properly slices logits.""" + batch_size, seq_len, vocab_size = 1, 4, 10 + logits = torch.randn(batch_size, seq_len, vocab_size) + input_ids = torch.randint(0, vocab_size, (batch_size, 2)) + + result_aligned = compute_logprobs(logits, input_ids, align=True) + + # Manually slice the same way align=True does + sliced_logits = logits[:, -input_ids.size(1) - 1 : -1, :] + result_manual = compute_logprobs(sliced_logits, input_ids, align=False) + + # Both should give the same result + assert torch.allclose(result_aligned, result_manual, atol=1e-5) diff --git a/tests/unit_tests/util/test_selective_log_softmax.py b/tests/unit_tests/util/test_selective_log_softmax.py deleted file mode 100644 index 4ca94f2c3..000000000 --- a/tests/unit_tests/util/test_selective_log_softmax.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import pytest -import torch -import torch.nn.functional as F -from forge.util.ops import selective_log_softmax - - -class TestSelectiveLogSoftmax: - @pytest.mark.timeout(10) - def test_basic_2d(self): - """Test basic 2D case.""" - logits = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - index = torch.tensor([0, 2]) # Select positions 0 and 2 - result = selective_log_softmax(logits, index) - # Compare with torch's implementation - expected = torch.gather( - F.log_softmax(logits, dim=-1), dim=-1, index=index.unsqueeze(-1) - ).squeeze(-1) - assert torch.allclose(result, expected, atol=1e-5) - assert result.shape == (2,) # Same shape as index - - @pytest.mark.timeout(10) - def test_single_row(self): - """Test with single row.""" - logits = torch.tensor([[1.0, 2.0, 3.0]]) - index = torch.tensor([1]) # Select middle element - result = selective_log_softmax(logits, index) - # Manual calculation: log_softmax then select index 1 - log_probs = F.log_softmax(logits, dim=-1) - expected = log_probs[0, 1] - assert torch.allclose(result, expected) - assert result.shape == (1,) - - @pytest.mark.timeout(10) - def test_different_dtypes(self): - """Test with different data types.""" - logits_f32 = torch.randn(2, 4, dtype=torch.float32) - logits_bf16 = torch.randn(2, 4, dtype=torch.bfloat16) - index = torch.tensor([0, 3]) - result_f32 = selective_log_softmax(logits_f32, index) - result_bf16 = selective_log_softmax(logits_bf16, index) - # Check output dtypes match input dtypes - assert result_f32.dtype == torch.float32 - assert result_bf16.dtype == torch.bfloat16 - # Check shapes - assert result_f32.shape == (2,) - assert result_bf16.shape == (2,) - - @pytest.mark.timeout(10) - def test_3d_tensor(self): - """Test with 3D tensor.""" - batch, seq, vocab = 2, 3, 5 - logits = 
torch.randn(batch, seq, vocab) - index = torch.randint(0, vocab, (batch, seq)) - result = selective_log_softmax(logits, index) - # Should have same shape as index - assert result.shape == (batch, seq) - # All values should be negative (log probabilities) - assert (result <= 0).all() - - @pytest.mark.timeout(10) - def test_known_values(self): - """Test with known values for manual verification.""" - # Simple case where we can calculate by hand - logits = torch.tensor([[0.0, 0.0]]) # Equal logits - index = torch.tensor([0]) - result = selective_log_softmax(logits, index) - # log_softmax of [0, 0] gives [-log(2), -log(2)] - # Selecting index 0 should give -log(2) - expected = -torch.log(torch.tensor(2.0)) - assert torch.allclose(result, expected, atol=1e-6) - - @pytest.mark.timeout(10) - def test_edge_cases(self): - """Test edge cases.""" - # Test with single class - logits = torch.tensor([[5.0]]) - index = torch.tensor([0]) - result = selective_log_softmax(logits, index) - # log_softmax of single element is 0 - assert torch.allclose(result, torch.tensor([0.0])) - # Test with large values (numerical stability) - logits = torch.tensor([[100.0, 200.0]]) - index = torch.tensor([1]) - result = selective_log_softmax(logits, index) - # Should not be NaN or inf - assert torch.isfinite(result).all()
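
Note for reviewers: the `ops.py` hunk above is truncated after the temperature scaling, so below is a minimal sketch of the contract that the new docstring and tests pin down. It assumes the cross_entropy formulation mentioned in the docstring's Note; the name `reference_compute_logprobs` is hypothetical and the body is not the exact implementation of `forge.util.ops.compute_logprobs`.

```python
# Illustrative sketch only: `reference_compute_logprobs` is a hypothetical name,
# and the body assumes the cross_entropy formulation described in the docstring.
import torch
import torch.nn.functional as F


def reference_compute_logprobs(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    temperature: float = 1.0,
    align: bool = True,
) -> torch.Tensor:
    if align:
        # Keep only the positions that predict input_ids; the final position is
        # dropped because it predicts the token after the end of the sequence.
        logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)
    # Temperature-scale and cast up to fp32 for numerical stability.
    scaled = (logits / temperature).float()
    # cross_entropy returns -log p(target); negate to recover per-token log probs.
    nll = F.cross_entropy(
        scaled.reshape(-1, scaled.size(-1)),
        input_ids.reshape(-1),
        reduction="none",
    )
    return -nll.reshape(input_ids.shape)
```

Call sites migrating from `selective_log_softmax(logits, index)` keep the old pre-aligned semantics by passing `align=False`, as done in `ReinforceLoss.forward` and `sumdigits.py` above. With the default `align=True`, the result should match running the helper on pre-sliced logits, which is exactly the equivalence exercised by `test_align_comparison`.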