Skip to content

Commit 8d84099

Browse files
authored
Add ComputeLobprob Tests (#244)
* Basic test + Move compute_logprobs to util ops * docstring * Rebase + add empty response test * Update math
1 parent 97a33e4 commit 8d84099

File tree

4 files changed

+142
-13
lines changed

4 files changed

+142
-13
lines changed

apps/grpo/main.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from forge.observability.metric_actors import get_or_create_metric_logger
3333
from forge.observability.metrics import record_metric, Reduce
3434
from forge.observability.perf_tracker import Tracer
35-
from forge.util.ops import selective_log_softmax
35+
from forge.util.ops import compute_logprobs
3636
from monarch.actor import endpoint
3737
from omegaconf import DictConfig
3838
from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -137,16 +137,6 @@ def collate(batches: list[list[Episode]]):
137137
return inputs, targets
138138

139139

140-
def compute_logprobs(
141-
logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
142-
) -> torch.Tensor:
143-
context_length = logits.shape[1] - input_ids.shape[1]
144-
logits = logits[:, context_length - 1 : -1].to(input_ids.device)
145-
scaled_logits = logits / temperature
146-
logprobs = selective_log_softmax(scaled_logits, input_ids)
147-
return logprobs
148-
149-
150140
def simple_grpo_loss(
151141
logits: torch.Tensor,
152142
response: torch.Tensor,
@@ -155,7 +145,12 @@ def simple_grpo_loss(
155145
padding_mask: torch.Tensor,
156146
beta: float = 0.1,
157147
) -> torch.Tensor:
158-
logprobs = compute_logprobs(logits, response)
148+
"""
149+
Example GRPO Loss Function for RLTrainer
150+
"""
151+
logprobs: torch.Tensor = compute_logprobs(logits, response)
152+
153+
# Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`
159154
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
160155
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
161156
per_token_loss = -(per_token_policy_loss - beta * kl)

src/forge/util/ops.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,26 @@ def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Te
4949
per_token_logps.append(row_per_token_logps)
5050
per_token_logps = torch.stack(per_token_logps)
5151
return per_token_logps
52+
53+
54+
def compute_logprobs(
55+
logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
56+
) -> torch.Tensor:
57+
"""
58+
Computes the log probabilities of the input tokens given the model logits and temperature.
59+
60+
Args:
61+
logits (`torch.Tensor`):
62+
The model output logits of shape `(batch_size, sequence_length, vocab_size)`.
63+
input_ids (`torch.Tensor`):
64+
The input token ids of shape `(batch_size, target_sequence_length)`.
65+
temperature (`float`, *optional*, defaults to 1.0):
66+
The temperature value for scaling logits before computing log probabilities.
67+
68+
"""
69+
# Ignore the last token from logits because it predicts the next token (-1)
70+
# And align logits with the input tokens length.
71+
logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)
72+
scaled_logits = logits / temperature
73+
logprobs = selective_log_softmax(scaled_logits, input_ids)
74+
return logprobs
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import pytest
8+
import torch
9+
import torch.nn.functional as F
10+
from forge.util.ops import compute_logprobs
11+
12+
13+
def _textbook_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor):
14+
# Helper: Textbook Log Softmax
15+
log_probs = F.log_softmax(logits, dim=-1)
16+
return torch.gather(log_probs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
17+
18+
19+
class TestComputeLogprobs:
20+
def test_single_batch_item(self):
21+
"""Test with single batch item."""
22+
# Shape: (1, 2, 3)
23+
logits = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
24+
# Shape: (1, 1)
25+
input_ids = torch.tensor([[1]])
26+
result = compute_logprobs(logits, input_ids)
27+
28+
# Manual calculation
29+
expected_logits = torch.tensor([[[1.0, 2.0, 3.0]]])
30+
expected = _textbook_log_softmax(expected_logits, input_ids)
31+
32+
assert torch.allclose(result, expected, atol=1e-5)
33+
assert result.shape == (1, 1)
34+
35+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
36+
37+
# Shape: (1, 3, 3)
38+
logits = torch.tensor([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]])
39+
# Shape: (1, 2)
40+
input_ids = torch.tensor([[2, 0]])
41+
result = compute_logprobs(logits, input_ids)
42+
43+
# Manual calculation
44+
expected_logits = torch.tensor([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]])
45+
expected = _textbook_log_softmax(expected_logits, input_ids)
46+
47+
assert torch.allclose(result, expected, atol=1e-5)
48+
assert result.shape == (1, 2)
49+
50+
@pytest.mark.timeout(10)
51+
def test_multi_batch(self):
52+
"""Test with multiple batch items."""
53+
# Shape: (2, 2, 3)
54+
logits = torch.tensor(
55+
[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[0.5, 1.5, 2.5], [3.5, 4.5, 5.5]]]
56+
)
57+
# Shape: (2, 1)
58+
input_ids = torch.tensor([[1], [2]])
59+
result = compute_logprobs(logits, input_ids)
60+
61+
# Manual calculation
62+
expected_logits = torch.tensor([[[1.0, 2.0, 3.0]], [[0.5, 1.5, 2.5]]])
63+
expected = _textbook_log_softmax(expected_logits, input_ids)
64+
65+
assert torch.allclose(result, expected, atol=1e-5)
66+
assert result.shape == (2, 1)
67+
68+
@pytest.mark.timeout(10)
69+
def test_temperature(self):
70+
"""Test with different temperature values."""
71+
batch_size, seq_len, vocab_size = 2, 4, 6
72+
logits = torch.randn(batch_size, seq_len, vocab_size)
73+
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len - 1))
74+
75+
# Manual calculation with temperature scaling
76+
def _manual(temperature: float):
77+
expected_logits = logits[:, 0:-1] / temperature
78+
return _textbook_log_softmax(expected_logits, input_ids)
79+
80+
temperatures = [1.0, 2.0, 4.5]
81+
for temperature in temperatures:
82+
result = compute_logprobs(logits, input_ids, temperature=temperature)
83+
expected = _manual(temperature)
84+
assert torch.allclose(result, expected, atol=1e-5)
85+
assert result.shape == input_ids.shape
86+
87+
@pytest.mark.timeout(10)
88+
def test_edge_cases(self):
89+
"""Test edge cases."""
90+
# Test with very large values (numerical stability)
91+
logits = torch.tensor([[[1000.0, 2000.0], [1500.0, 2500.0]]])
92+
input_ids = torch.tensor([[0]])
93+
result = compute_logprobs(logits, input_ids)
94+
# Should not be NaN or inf
95+
assert torch.isfinite(result).all()
96+
97+
# Test with very small values
98+
logits = torch.tensor([[[-1000.0, -2000.0], [-1500.0, -2500.0]]])
99+
input_ids = torch.tensor([[1]])
100+
result = compute_logprobs(logits, input_ids)
101+
# Should not be NaN or inf
102+
assert torch.isfinite(result).all()
103+
104+
def test_compute_logprobs_empty_response(self):
105+
"""Test logprobs computation with empty response."""
106+
batch_size, seq_len, vocab_size = 1, 5, 1000
107+
logits = torch.randn(batch_size, seq_len, vocab_size)
108+
input_ids = torch.tensor([[]])
109+
110+
result = compute_logprobs(logits, input_ids)
111+
assert result.shape == (batch_size, 0)

tests/unit_tests/util/test_ops.py renamed to tests/unit_tests/util/test_selective_log_softmax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from forge.util.ops import selective_log_softmax
1111

1212

13-
class TestOps:
13+
class TestSelectiveLogSoftmax:
1414
@pytest.mark.timeout(10)
1515
def test_basic_2d(self):
1616
"""Test basic 2D case."""

0 commit comments

Comments
 (0)