-
Notifications
You must be signed in to change notification settings - Fork 26
GRPO Loss basic unit tests #168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from torch import nn | ||
|
||
|
||
class SimpleGRPOLoss(nn.Module): | ||
"""Simplified GRPO Loss for simplified single step updates | ||
Inspired by the Hugging Face TRL implementation: | ||
https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1624. | ||
""" | ||
|
||
def __init__(self, beta: float = 0.1): | ||
super().__init__() | ||
self.beta = beta | ||
|
||
def forward(self, logprobs, ref_logprobs, advantages, padding_mask): | ||
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1 | ||
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages | ||
per_token_loss = -(per_token_policy_loss - self.beta * kl) | ||
loss = ( | ||
((per_token_loss * padding_mask).sum(dim=1)) | ||
/ (padding_mask.sum(dim=1).clamp(min=1.0)) | ||
).mean() | ||
return loss |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
pytest==7.3.2 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,240 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import pytest | ||
import torch | ||
from forge.losses.grpo_loss import SimpleGRPOLoss | ||
|
||
|
||
class TestSimpleGRPOLoss: | ||
@pytest.fixture | ||
def loss_fn(self): | ||
"""Create a GRPO loss instance with default beta.""" | ||
return SimpleGRPOLoss(beta=0.1) | ||
|
||
@pytest.fixture | ||
def sample_data(self): | ||
"""Create sample input data for testing.""" | ||
batch_size, seq_len = 2, 4 | ||
|
||
# Create log probabilities (should be negative) | ||
logprobs = torch.log(torch.rand(batch_size, seq_len) * 0.9 + 0.1) | ||
ref_logprobs = torch.log(torch.rand(batch_size, seq_len) * 0.9 + 0.1) | ||
|
||
# Create advantages (can be positive or negative) | ||
advantages = torch.randn(batch_size, seq_len) | ||
|
||
# Create padding mask (1 for valid tokens, 0 for padding) | ||
padding_mask = torch.ones(batch_size, seq_len) | ||
padding_mask[0, -1] = 0 # Add some padding | ||
padding_mask[1, -2:] = 0 # Add more padding | ||
|
||
return logprobs, ref_logprobs, advantages, padding_mask | ||
|
||
@pytest.mark.timeout(10) | ||
@pytest.mark.asyncio | ||
def test_forward_basic(self, loss_fn, sample_data): | ||
"""Test basic forward pass.""" | ||
logprobs, ref_logprobs, advantages, padding_mask = sample_data | ||
|
||
loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask) | ||
|
||
# Loss should be a scalar | ||
assert loss.dim() == 0 | ||
assert torch.isfinite(loss) | ||
assert not torch.isnan(loss) | ||
|
||
@pytest.mark.timeout(10) | ||
@pytest.mark.asyncio | ||
def test_output_shape(self, loss_fn): | ||
"""Test output shape for different input sizes.""" | ||
for batch_size in [1, 3, 8]: | ||
for seq_len in [1, 10, 32]: | ||
logprobs = torch.randn(batch_size, seq_len) | ||
ref_logprobs = torch.randn(batch_size, seq_len) | ||
advantages = torch.randn(batch_size, seq_len) | ||
padding_mask = torch.ones(batch_size, seq_len) | ||
|
||
loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask) | ||
assert loss.shape == torch.Size([]) | ||
|
||
@pytest.mark.timeout(10) | ||
@pytest.mark.asyncio | ||
def test_gradient_flow(self, loss_fn, sample_data): | ||
"""Test that gradients flow through logprobs.""" | ||
logprobs, ref_logprobs, advantages, padding_mask = sample_data | ||
logprobs.requires_grad_(True) | ||
|
||
loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask) | ||
loss.backward() | ||
|
||
assert logprobs.grad is not None | ||
assert not torch.isnan(logprobs.grad).any() | ||
assert torch.isfinite(logprobs.grad).all() | ||
|
||
@pytest.mark.timeout(10) | ||
@pytest.mark.asyncio | ||
def test_no_gradient_to_ref_logprobs(self, loss_fn, sample_data): | ||
"""Test that gradients don't flow to reference logprobs.""" | ||
logprobs, ref_logprobs, advantages, padding_mask = sample_data | ||
ref_logprobs.requires_grad_(True) | ||
|
||
loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask) | ||
loss.backward() | ||
|
||
# ref_logprobs should receive gradients (it's used in KL computation) | ||
assert ref_logprobs.grad is not None | ||
|
||
@pytest.mark.timeout(10) | ||
@pytest.mark.asyncio | ||
def test_padding_mask_effect(self, loss_fn): | ||
"""Test that padding mask correctly ignores padded tokens.""" | ||
batch_size, seq_len = 2, 4 | ||
|
||
logprobs = torch.randn(batch_size, seq_len) | ||
ref_logprobs = torch.randn(batch_size, seq_len) | ||
advantages = torch.randn(batch_size, seq_len) | ||
|
||
# Test with full mask | ||
full_mask = torch.ones(batch_size, seq_len) | ||
loss_full = loss_fn(logprobs, ref_logprobs, advantages, full_mask) | ||
|
||
# Test with partial mask | ||
partial_mask = torch.ones(batch_size, seq_len) | ||
partial_mask[:, -1] = 0 # Mask last token | ||
loss_partial = loss_fn(logprobs, ref_logprobs, advantages, partial_mask) | ||
|
||
# Losses should be different | ||
assert not torch.allclose(loss_full, loss_partial) | ||
|
||
@pytest.mark.timeout(10) | ||
@pytest.mark.asyncio | ||
def test_beta_parameter_effect(self, sample_data): | ||
"""Test that different beta values produce different losses.""" | ||
logprobs, ref_logprobs, advantages, padding_mask = sample_data | ||
|
||
loss_fn_1 = SimpleGRPOLoss(beta=0.1) | ||
loss_fn_2 = SimpleGRPOLoss(beta=0.5) | ||
|
||
loss_1 = loss_fn_1(logprobs, ref_logprobs, advantages, padding_mask) | ||
loss_2 = loss_fn_2(logprobs, ref_logprobs, advantages, padding_mask) | ||
|
||
# Different beta should produce different losses (unless KL is zero) | ||
# This test might be flaky if KL happens to be very small | ||
if not torch.allclose(ref_logprobs, logprobs, atol=1e-6): | ||
assert not torch.allclose(loss_1, loss_2, atol=1e-6) | ||
|
||
@pytest.mark.timeout(10) | ||
@pytest.mark.asyncio | ||
def test_zero_advantages(self, loss_fn): | ||
"""Test behavior with zero advantages.""" | ||
batch_size, seq_len = 2, 4 | ||
|
||
logprobs = torch.randn(batch_size, seq_len) | ||
ref_logprobs = torch.randn(batch_size, seq_len) | ||
advantages = torch.zeros(batch_size, seq_len) | ||
padding_mask = torch.ones(batch_size, seq_len) | ||
|
||
loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask) | ||
|
||
# With zero advantages, loss should only depend on KL term | ||
assert torch.isfinite(loss) | ||
|
||
@pytest.mark.timeout(10) | ||
@pytest.mark.asyncio | ||
def test_identical_policies(self, loss_fn): | ||
"""Test behavior when current and reference policies are identical.""" | ||
batch_size, seq_len = 2, 4 | ||
|
||
logprobs = torch.randn(batch_size, seq_len) | ||
ref_logprobs = logprobs.clone() # Identical policies | ||
advantages = torch.randn(batch_size, seq_len) | ||
padding_mask = torch.ones(batch_size, seq_len) | ||
|
||
loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask) | ||
|
||
# KL should be approximately zero for identical policies | ||
assert torch.isfinite(loss) | ||
|
||
@pytest.mark.timeout(10) | ||
@pytest.mark.asyncio | ||
def test_extreme_values(self, loss_fn): | ||
"""Test with extreme but valid values.""" | ||
batch_size, seq_len = 2, 3 | ||
|
||
# Large negative log probabilities (very low probabilities) | ||
logprobs = torch.full((batch_size, seq_len), -10.0) | ||
ref_logprobs = torch.full((batch_size, seq_len), -5.0) | ||
|
||
# Large advantages | ||
advantages = torch.full((batch_size, seq_len), 10.0) | ||
padding_mask = torch.ones(batch_size, seq_len) | ||
|
||
loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask) | ||
|
||
assert torch.isfinite(loss) | ||
assert not torch.isnan(loss) | ||
|
||
@pytest.mark.timeout(10) | ||
@pytest.mark.asyncio | ||
def test_numerical_stability(self, loss_fn): | ||
"""Test numerical stability with edge cases.""" | ||
batch_size, seq_len = 1, 2 | ||
|
||
# Test with very similar log probabilities | ||
logprobs = torch.tensor([[0.0, -1e-8]]) | ||
ref_logprobs = torch.tensor([[1e-8, 0.0]]) | ||
advantages = torch.tensor([[1.0, -1.0]]) | ||
padding_mask = torch.ones(batch_size, seq_len) | ||
|
||
loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask) | ||
|
||
assert torch.isfinite(loss) | ||
|
||
@pytest.mark.timeout(10) | ||
@pytest.mark.asyncio | ||
def test_all_masked_sequence(self, loss_fn): | ||
"""Test behavior when entire sequence is masked.""" | ||
batch_size, seq_len = 1, 3 | ||
|
||
logprobs = torch.randn(batch_size, seq_len) | ||
ref_logprobs = torch.randn(batch_size, seq_len) | ||
advantages = torch.randn(batch_size, seq_len) | ||
padding_mask = torch.zeros(batch_size, seq_len) # All masked | ||
|
||
loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask) | ||
|
||
# Should handle division by zero gracefully due to clamp(min=1.0) | ||
assert torch.isfinite(loss) | ||
|
||
@pytest.mark.timeout(10) | ||
@pytest.mark.asyncio | ||
def test_mathematical_correctness(self, loss_fn): | ||
"""Test mathematical correctness with simpler verification.""" | ||
# Test with known simple case | ||
logprobs = torch.tensor([[0.0]]) # log(1.0) = 0 | ||
ref_logprobs = torch.tensor([[0.0]]) # Same as current | ||
advantages = torch.tensor([[1.0]]) | ||
padding_mask = torch.ones(1, 1) | ||
|
||
loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask) | ||
|
||
# When logprobs == ref_logprobs, KL should be 0 | ||
# Loss should be -(1.0 * 1.0 - beta * 0) = -1.0 | ||
expected_loss = torch.tensor(-1.0) | ||
assert torch.allclose(loss, expected_loss, atol=1e-6) | ||
|
||
# Test symmetry: swapping positive and negative advantages | ||
advantages_pos = torch.tensor([[2.0]]) | ||
advantages_neg = torch.tensor([[-2.0]]) | ||
|
||
loss_pos = loss_fn(logprobs, ref_logprobs, advantages_pos, padding_mask) | ||
loss_neg = loss_fn(logprobs, ref_logprobs, advantages_neg, padding_mask) | ||
|
||
# Should be symmetric around some center point | ||
assert torch.isfinite(loss_pos) | ||
assert torch.isfinite(loss_neg) | ||
assert loss_pos != loss_neg # Should be different |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@joecummings
I noticed that
logprobs - logprobs.detach()
will always be zero, sincelogprobs.detach()
is just logprobs with no gradient. That means torch.exp(0) is always 1, so this term simplifies to just advantages.Is there a specific reason for writing it this way? Or is it a leftover from a more general case (like multi-step or importance sampling)? Just wanted to check in case I’m missing some context!
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, this is just a direct translation of the code from TRL for ease of correctness testing: https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1664
They keep this term in for importance sampling (swapping out the second term for old logprobs).
I defer to you on whether or not to keep this expression for now :)