Merged
23 changes: 1 addition & 22 deletions apps/grpo/main.py
@@ -24,10 +24,10 @@
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown
from forge.data.rewards import MathReward, ThinkingReward
from forge.losses.grpo_loss import SimpleGRPOLoss
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from omegaconf import DictConfig
from torch import nn
from torchstore.state_dict_utils import DELIM
from torchtitan.config.job_config import Model as TitanJobModelConfig
from transformers import AutoModelForCausalLM
@@ -49,27 +49,6 @@ def compute_logprobs(
return logprobs


class SimpleGRPOLoss(nn.Module):
"""Simplified GRPO Loss for simplified single step updates
Inspired by the Hugging Face TRL implementation:
https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1624.
"""

def __init__(self, beta: float = 0.1):
super().__init__()
self.beta = beta

def forward(self, logprobs, ref_logprobs, advantages, padding_mask):
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
per_token_loss = -(per_token_policy_loss - self.beta * kl)
loss = (
((per_token_loss * padding_mask).sum(dim=1))
/ (padding_mask.sum(dim=1).clamp(min=1.0))
).mean()
return loss


@dataclass
class Episode:
# TODO: add additional layer for multi-turn
5 changes: 5 additions & 0 deletions src/forge/losses/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
29 changes: 29 additions & 0 deletions src/forge/losses/grpo_loss.py
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn


class SimpleGRPOLoss(nn.Module):
"""Simplified GRPO Loss for simplified single step updates
Inspired by the Hugging Face TRL implementation:
https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1624.
"""

def __init__(self, beta: float = 0.1):
super().__init__()
self.beta = beta

def forward(self, logprobs, ref_logprobs, advantages, padding_mask):
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
Contributor (author):

@joecummings

I noticed that logprobs - logprobs.detach() is always zero in value, since logprobs.detach() is just logprobs with gradient tracking removed. That means torch.exp(0) is always 1, so in the forward pass this term simplifies to just advantages.

Is there a specific reason for writing it this way? Or is it a leftover from a more general case (like multi-step or importance sampling)? Just wanted to check in case I’m missing some context!

@joecummings (Member), Sep 17, 2025:

Yep, this is just a direct translation of the code from TRL for ease of correctness testing: https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1664

They keep this term in for importance sampling (swapping out the second term for old logprobs).

I defer to you on whether or not to keep this expression for now :)

per_token_loss = -(per_token_policy_loss - self.beta * kl)
loss = (
((per_token_loss * padding_mask).sum(dim=1))
/ (padding_mask.sum(dim=1).clamp(min=1.0))
).mean()
return loss
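
To make the thread above concrete, here is a small standalone check (an editorial sketch, not part of the PR): the ratio torch.exp(logprobs - logprobs.detach()) is identically 1 in the forward pass yet still carries gradient, and the KL term exp(d) - d - 1 is the nonnegative estimator often called "k3" in Schulman's note on approximating KL divergence.

```python
import torch

logprobs = torch.tensor([-1.0, -2.0], requires_grad=True)
ref_logprobs = torch.tensor([-1.5, -1.5])
advantages = torch.tensor([0.5, -0.5])

# (1) The ratio is identically 1 in the forward pass...
ratio = torch.exp(logprobs - logprobs.detach())
assert torch.equal(ratio, torch.ones(2))

# ...but it is not a constant to autograd: d(ratio)/d(logprobs) = 1 elementwise,
# so the policy term backpropagates exactly `advantages` into logprobs.
(ratio * advantages).sum().backward()
assert torch.allclose(logprobs.grad, advantages)

# (2) The KL term exp(d) - d - 1, with d = ref_logprobs - logprobs, is zero
# when the two policies agree and positive otherwise (convexity of exp).
d = ref_logprobs - logprobs.detach()
kl = torch.exp(d) - d - 1
assert (kl >= 0).all()
```

In other words, the forward value of the policy term is just advantages, while its backward pass scales the logprob gradients by the advantages, which is exactly the REINFORCE-style gradient.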
1 change: 0 additions & 1 deletion tests/README.md
@@ -17,7 +17,6 @@ Ensure you have all development dependencies installed:

```bash
pip install -r dev-requirements.txt
pip install -r requirements.txt
```

### Running Integration Tests
1 change: 1 addition & 0 deletions tests/requirements-dev.txt
@@ -0,0 +1 @@
pytest==7.3.2
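
A plausible way to run the new unit tests locally (not spelled out in the PR; assumes the pytest-timeout plugin, which provides the @pytest.mark.timeout marker used in the test file below):

```bash
pip install -r tests/requirements-dev.txt pytest-timeout
pytest tests/unit_tests/losses/test_grpo_loss.py -v
```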
240 changes: 240 additions & 0 deletions tests/unit_tests/losses/test_grpo_loss.py
@@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from forge.losses.grpo_loss import SimpleGRPOLoss


class TestSimpleGRPOLoss:
@pytest.fixture
def loss_fn(self):
"""Create a GRPO loss instance with default beta."""
return SimpleGRPOLoss(beta=0.1)

@pytest.fixture
def sample_data(self):
"""Create sample input data for testing."""
batch_size, seq_len = 2, 4

# Create log probabilities (should be negative)
logprobs = torch.log(torch.rand(batch_size, seq_len) * 0.9 + 0.1)
ref_logprobs = torch.log(torch.rand(batch_size, seq_len) * 0.9 + 0.1)

# Create advantages (can be positive or negative)
advantages = torch.randn(batch_size, seq_len)

# Create padding mask (1 for valid tokens, 0 for padding)
padding_mask = torch.ones(batch_size, seq_len)
padding_mask[0, -1] = 0 # Add some padding
padding_mask[1, -2:] = 0 # Add more padding

return logprobs, ref_logprobs, advantages, padding_mask

@pytest.mark.timeout(10)
def test_forward_basic(self, loss_fn, sample_data):
"""Test basic forward pass."""
logprobs, ref_logprobs, advantages, padding_mask = sample_data

loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask)

# Loss should be a scalar
assert loss.dim() == 0
assert torch.isfinite(loss)
assert not torch.isnan(loss)

@pytest.mark.timeout(10)
def test_output_shape(self, loss_fn):
"""Test output shape for different input sizes."""
for batch_size in [1, 3, 8]:
for seq_len in [1, 10, 32]:
logprobs = torch.randn(batch_size, seq_len)
ref_logprobs = torch.randn(batch_size, seq_len)
advantages = torch.randn(batch_size, seq_len)
padding_mask = torch.ones(batch_size, seq_len)

loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask)
assert loss.shape == torch.Size([])

@pytest.mark.timeout(10)
def test_gradient_flow(self, loss_fn, sample_data):
"""Test that gradients flow through logprobs."""
logprobs, ref_logprobs, advantages, padding_mask = sample_data
logprobs.requires_grad_(True)

loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask)
loss.backward()

assert logprobs.grad is not None
assert not torch.isnan(logprobs.grad).any()
assert torch.isfinite(logprobs.grad).all()

@pytest.mark.timeout(10)
def test_ref_logprobs_receive_gradients(self, loss_fn, sample_data):
"""Test that reference logprobs receive gradients (they are not detached in the KL term)."""
logprobs, ref_logprobs, advantages, padding_mask = sample_data
ref_logprobs.requires_grad_(True)

loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask)
loss.backward()

# ref_logprobs should receive gradients (it's used in KL computation)
assert ref_logprobs.grad is not None

@pytest.mark.timeout(10)
def test_padding_mask_effect(self, loss_fn):
"""Test that padding mask correctly ignores padded tokens."""
batch_size, seq_len = 2, 4

logprobs = torch.randn(batch_size, seq_len)
ref_logprobs = torch.randn(batch_size, seq_len)
advantages = torch.randn(batch_size, seq_len)

# Test with full mask
full_mask = torch.ones(batch_size, seq_len)
loss_full = loss_fn(logprobs, ref_logprobs, advantages, full_mask)

# Test with partial mask
partial_mask = torch.ones(batch_size, seq_len)
partial_mask[:, -1] = 0 # Mask last token
loss_partial = loss_fn(logprobs, ref_logprobs, advantages, partial_mask)

# Losses should be different
assert not torch.allclose(loss_full, loss_partial)

@pytest.mark.timeout(10)
def test_beta_parameter_effect(self, sample_data):
"""Test that different beta values produce different losses."""
logprobs, ref_logprobs, advantages, padding_mask = sample_data

loss_fn_1 = SimpleGRPOLoss(beta=0.1)
loss_fn_2 = SimpleGRPOLoss(beta=0.5)

loss_1 = loss_fn_1(logprobs, ref_logprobs, advantages, padding_mask)
loss_2 = loss_fn_2(logprobs, ref_logprobs, advantages, padding_mask)

# Different beta should produce different losses (unless KL is zero)
# This test might be flaky if KL happens to be very small
if not torch.allclose(ref_logprobs, logprobs, atol=1e-6):
assert not torch.allclose(loss_1, loss_2, atol=1e-6)

@pytest.mark.timeout(10)
def test_zero_advantages(self, loss_fn):
"""Test behavior with zero advantages."""
batch_size, seq_len = 2, 4

logprobs = torch.randn(batch_size, seq_len)
ref_logprobs = torch.randn(batch_size, seq_len)
advantages = torch.zeros(batch_size, seq_len)
padding_mask = torch.ones(batch_size, seq_len)

loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask)

# With zero advantages, loss should only depend on KL term
assert torch.isfinite(loss)

@pytest.mark.timeout(10)
def test_identical_policies(self, loss_fn):
"""Test behavior when current and reference policies are identical."""
batch_size, seq_len = 2, 4

logprobs = torch.randn(batch_size, seq_len)
ref_logprobs = logprobs.clone() # Identical policies
advantages = torch.randn(batch_size, seq_len)
padding_mask = torch.ones(batch_size, seq_len)

loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask)

# KL is exactly zero for identical policies, so the loss reduces to the policy term
assert torch.isfinite(loss)

@pytest.mark.timeout(10)
def test_extreme_values(self, loss_fn):
"""Test with extreme but valid values."""
batch_size, seq_len = 2, 3

# Large negative log probabilities (very low probabilities)
logprobs = torch.full((batch_size, seq_len), -10.0)
ref_logprobs = torch.full((batch_size, seq_len), -5.0)

# Large advantages
advantages = torch.full((batch_size, seq_len), 10.0)
padding_mask = torch.ones(batch_size, seq_len)

loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask)

assert torch.isfinite(loss)
assert not torch.isnan(loss)

@pytest.mark.timeout(10)
def test_numerical_stability(self, loss_fn):
"""Test numerical stability with edge cases."""
batch_size, seq_len = 1, 2

# Test with very similar log probabilities
logprobs = torch.tensor([[0.0, -1e-8]])
ref_logprobs = torch.tensor([[1e-8, 0.0]])
advantages = torch.tensor([[1.0, -1.0]])
padding_mask = torch.ones(batch_size, seq_len)

loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask)

assert torch.isfinite(loss)

@pytest.mark.timeout(10)
def test_all_masked_sequence(self, loss_fn):
"""Test behavior when entire sequence is masked."""
batch_size, seq_len = 1, 3

logprobs = torch.randn(batch_size, seq_len)
ref_logprobs = torch.randn(batch_size, seq_len)
advantages = torch.randn(batch_size, seq_len)
padding_mask = torch.zeros(batch_size, seq_len) # All masked

loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask)

# Should handle division by zero gracefully due to clamp(min=1.0)
assert torch.isfinite(loss)

@pytest.mark.timeout(10)
def test_mathematical_correctness(self, loss_fn):
"""Test mathematical correctness with simpler verification."""
# Test with known simple case
logprobs = torch.tensor([[0.0]]) # log(1.0) = 0
ref_logprobs = torch.tensor([[0.0]]) # Same as current
advantages = torch.tensor([[1.0]])
padding_mask = torch.ones(1, 1)

loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask)

# When logprobs == ref_logprobs, KL should be 0
# Loss should be -(1.0 * 1.0 - beta * 0) = -1.0
expected_loss = torch.tensor(-1.0)
assert torch.allclose(loss, expected_loss, atol=1e-6)

# Test symmetry: swapping positive and negative advantages
advantages_pos = torch.tensor([[2.0]])
advantages_neg = torch.tensor([[-2.0]])

loss_pos = loss_fn(logprobs, ref_logprobs, advantages_pos, padding_mask)
loss_neg = loss_fn(logprobs, ref_logprobs, advantages_neg, padding_mask)

# With zero KL, the loss is just -advantage, so the two losses are exact mirrors
assert torch.isfinite(loss_pos)
assert torch.isfinite(loss_neg)
assert torch.allclose(loss_pos, -loss_neg)
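
On joecummings's last point in the review thread: in a multi-step setting, the detached logprobs would be replaced by the logprobs recorded when the samples were generated, turning the ratio into a true importance weight. A hypothetical sketch of that variant (illustrative only, not part of this PR; the class name and old_logprobs argument are assumptions):

```python
import torch
from torch import nn


class ImportanceSampledGRPOLoss(nn.Module):
    """Hypothetical multi-step variant (illustrative, not in this PR): the
    detached logprobs are replaced by logprobs recorded from the behavior
    policy that generated the samples, so the ratio is a real importance
    weight rather than a constant 1."""

    def __init__(self, beta: float = 0.1):
        super().__init__()
        self.beta = beta

    def forward(self, logprobs, old_logprobs, ref_logprobs, advantages, padding_mask):
        # Same k3-style KL penalty against the reference policy as SimpleGRPOLoss.
        kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
        # Importance weight: drifts away from 1 once the learner has been
        # updated past the snapshot that produced the data.
        ratio = torch.exp(logprobs - old_logprobs)
        per_token_loss = -(ratio * advantages - self.beta * kl)
        return (
            (per_token_loss * padding_mask).sum(dim=1)
            / padding_mask.sum(dim=1).clamp(min=1.0)
        ).mean()
```

Setting old_logprobs = logprobs.detach() recovers SimpleGRPOLoss exactly.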