
Conversation

@Ritesh1905 (Contributor) commented Sep 17, 2025

Adds basic unit tests for the GRPO loss.

(forge) [[email protected] /data/users/rithesh/forge/tests (rithesh/grpo_tests)]$ pytest unit_tests/losses/test_grpo_loss.py -v
=========================================================================== test session starts ============================================================================
platform linux -- Python 3.10.18, pytest-7.3.2, pluggy-1.6.0 -- /home/rithesh/.conda/envs/forge/bin/python3.10
cachedir: .pytest_cache
rootdir: /data/users/rithesh/forge
configfile: pyproject.toml
plugins: typeguard-4.4.4, anyio-4.10.0
collected 12 items                                                                                                                                                         

unit_tests/losses/test_grpo_loss.py::TestSimpleGRPOLoss::test_forward_basic PASSED                                                                                   [  8%]
unit_tests/losses/test_grpo_loss.py::TestSimpleGRPOLoss::test_output_shape PASSED                                                                                    [ 16%]
unit_tests/losses/test_grpo_loss.py::TestSimpleGRPOLoss::test_gradient_flow PASSED                                                                                   [ 25%]
unit_tests/losses/test_grpo_loss.py::TestSimpleGRPOLoss::test_no_gradient_to_ref_logprobs PASSED                                                                     [ 33%]
unit_tests/losses/test_grpo_loss.py::TestSimpleGRPOLoss::test_padding_mask_effect PASSED                                                                             [ 41%]
unit_tests/losses/test_grpo_loss.py::TestSimpleGRPOLoss::test_beta_parameter_effect PASSED                                                                           [ 50%]
unit_tests/losses/test_grpo_loss.py::TestSimpleGRPOLoss::test_zero_advantages PASSED                                                                                 [ 58%]
unit_tests/losses/test_grpo_loss.py::TestSimpleGRPOLoss::test_identical_policies PASSED                                                                              [ 66%]
unit_tests/losses/test_grpo_loss.py::TestSimpleGRPOLoss::test_extreme_values PASSED                                                                                  [ 75%]
unit_tests/losses/test_grpo_loss.py::TestSimpleGRPOLoss::test_numerical_stability PASSED                                                                             [ 83%]
unit_tests/losses/test_grpo_loss.py::TestSimpleGRPOLoss::test_all_masked_sequence PASSED                                                                             [ 91%]
unit_tests/losses/test_grpo_loss.py::TestSimpleGRPOLoss::test_mathematical_correctness PASSED                                                                        [100%]
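
As a rough illustration of what a case like test_identical_policies checks (a sketch; the actual test file may differ): when the policy and reference produce identical logprobs, the KL term in the forward snippet below is exactly zero.

import torch

def test_identical_policies_sketch():
    # With identical policies, r = ref_logprobs - logprobs = 0,
    # so the estimator exp(r) - r - 1 vanishes token-wise.
    logprobs = torch.randn(2, 5)
    ref_logprobs = logprobs.clone()
    r = ref_logprobs - logprobs
    kl = torch.exp(r) - r - 1
    assert torch.allclose(kl, torch.zeros_like(kl))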

@meta-cla bot added the CLA Signed label Sep 17, 2025
@Ritesh1905 changed the title from "GRPO basic unit tests" to "GRPO Loss basic unit tests" Sep 17, 2025
@Ritesh1905 marked this pull request as ready for review September 17, 2025 01:52

def forward(self, logprobs, ref_logprobs, advantages, padding_mask):
    # Low-variance k3 estimator of KL(policy || ref): exp(r) - r - 1, with r = ref_logprobs - logprobs
    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
    # Importance-sampling ratio; evaluates to 1 here since both terms use the current logprobs
    per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
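
For context, a hedged sketch of how this per-token loss is typically reduced in the TRL-style formulation, with a beta-weighted KL penalty and a padding-masked mean; the class name is illustrative and the exact forge reduction may differ:

import torch
import torch.nn as nn

class SimpleGRPOLossSketch(nn.Module):
    # Sketch only: `beta` weights the KL penalty against the policy term.
    def __init__(self, beta: float = 0.1):
        super().__init__()
        self.beta = beta

    def forward(self, logprobs, ref_logprobs, advantages, padding_mask):
        kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
        per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
        per_token_loss = -(per_token_policy_loss - self.beta * kl)
        # Mean over valid tokens per sequence, then over the batch;
        # the clamp avoids division by zero for fully masked rows.
        seq_loss = (per_token_loss * padding_mask).sum(dim=1) / padding_mask.sum(dim=1).clamp(min=1.0)
        return seq_loss.mean()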

@Ritesh1905 (Contributor, Author) commented:

@joecummings

I noticed that logprobs - logprobs.detach() always evaluates to zero, since logprobs.detach() is just logprobs with gradient tracking removed. That means torch.exp(0) is always 1, so in the forward pass this term evaluates to just advantages (though the detach keeps the gradient path through logprobs).

Is there a specific reason for writing it this way? Or is it a leftover from a more general case (like multi-step or importance sampling)? Just wanted to check in case I’m missing some context!
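
For reference, a minimal check of the value/gradient behavior in question (tensor values are illustrative):

import torch

logprobs = torch.tensor([-1.2, -0.7], requires_grad=True)
advantages = torch.tensor([0.5, -0.3])

# The forward value is exactly 1 per token...
ratio = torch.exp(logprobs - logprobs.detach())

# ...but the expression still differentiates through logprobs:
(ratio * advantages).sum().backward()
print(logprobs.grad)  # tensor([ 0.5000, -0.3000]), i.e. equal to advantages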

@joecummings (Member) commented Sep 17, 2025:

Yep, this is just a direct translation of the code from TRL for ease of correctness testing: https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1664

They keep this term in for importance sampling (swapping out the second term for old logprobs).
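
Roughly, the general form looks like the sketch below, where old_per_token_logps is a hypothetical stand-in for the logprobs recorded when the rollout was generated:

import torch

def policy_term(logprobs, old_per_token_logps, advantages):
    # PPO-style importance-sampling ratio between the current policy
    # and the policy that generated the rollout.
    ratio = torch.exp(logprobs - old_per_token_logps)
    return ratio * advantages

# With one optimization step per rollout, old_per_token_logps equals
# logprobs.detach(), so the ratio is 1 and the term reduces to advantages.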

I defer to you on whether or not to keep this expression for now :)

@allenwang28 (Contributor) left a comment:

awesome @Ritesh1905, thank you!

@allenwang28 merged commit 636e758 into main Sep 17, 2025
5 checks passed
@Ritesh1905 deleted the rithesh/grpo_tests branch October 7, 2025 17:30
