From da21e1df5fcf19018e0e99e446215a831c580c9d Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 21 Aug 2025 16:10:55 -0700 Subject: [PATCH 1/4] Add reward interface, math reward, unit tests --- src/forge/data/rewards/math.py | 56 ++++++ src/forge/interfaces.py | 13 +- tests/unit_tests/data/test_math_reward.py | 202 ++++++++++++++++++++++ 3 files changed, 269 insertions(+), 2 deletions(-) create mode 100644 src/forge/data/rewards/math.py create mode 100644 tests/unit_tests/data/test_math_reward.py diff --git a/src/forge/data/rewards/math.py b/src/forge/data/rewards/math.py new file mode 100644 index 000000000..06dd6cbc6 --- /dev/null +++ b/src/forge/data/rewards/math.py @@ -0,0 +1,56 @@ +import re +from typing import Optional + +from forge.interfaces import Reward + + +class MathReward(Reward): + """Reward class for evaluating math correctness.""" + + def __init__(self, tolerance: float = 1e-6, partial_credit: float = 0.1): + self.tolerance = tolerance + self.partial_credit = partial_credit + + def _to_float(self, text) -> Optional[float]: + """Safely parse a string into a float, or return None if invalid.""" + if text is None: + return None + try: + return float(str(text).strip()) + except (ValueError, TypeError): + return None + + def _extract_number(self, text: str) -> Optional[float]: + """Try to extract a numeric answer from text.""" + number_pattern = r"([+-]?\d+(?:\.\d+)?(?:e[+-]?\d+)?)" + patterns = [ + r"####\s*" + number_pattern, + r"(?:the\s+)?answer\s+is\s*" + number_pattern, + r"(?:answer:|result:)\s*" + number_pattern, + r"\$" + number_pattern, # currency + number_pattern, # fallback + r"=\s*" + number_pattern + r"\s*(?:\.|$)", + r"\b" + number_pattern + r"\s*(?:\.|$)", + ] + text = text.lower().strip() + for pattern in patterns: + matches = re.findall(pattern, text) + if matches: + return self._to_float(matches[-1]) + return None + + def __call__(self, prompt: str, response: str, target: str) -> float: + """Compute math correctness reward.""" + # Parse expected + expected_answer = self._to_float(target) + + # Parse response + model_answer = self._extract_number(response) + + # Scoring + if expected_answer is None or model_answer is None: + return self.partial_credit # Partial credit for attempting + + if abs(expected_answer - model_answer) < self.tolerance: + return 1.0 # Correct answer + return 0.0 # Incorrect answer diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index f19f379cb..b485fc791 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -7,10 +7,10 @@ from abc import ABC, abstractmethod from typing import Any -from monarch.actor import Actor, endpoint - from forge.types import Action, Message, Observation, State +from monarch.actor import Actor, endpoint + class Transform(ABC): """Abstract base class for observation transforms. @@ -150,3 +150,12 @@ def tokenize_messages( tuple[list[int], list[bool]]: The list of token ids and the list of masks. """ pass + + +class Reward(ABC): + """Abstract base class for reward models.""" + + @abstractmethod + def __call__(self, observation: Observation) -> float: + """Compute a reward for an observation.""" + pass diff --git a/tests/unit_tests/data/test_math_reward.py b/tests/unit_tests/data/test_math_reward.py new file mode 100644 index 000000000..a109492dd --- /dev/null +++ b/tests/unit_tests/data/test_math_reward.py @@ -0,0 +1,202 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from unittest import mock + +from forge.data.rewards.math import MathReward + + +class TestMathReward(unittest.TestCase): + def setUp(self): + """Set up test fixtures before each test method.""" + self.reward = MathReward() + self.custom_reward = MathReward(tolerance=1e-3, partial_credit=0.2) + + def test_init_default_values(self): + """Test MathReward initialization with default values.""" + reward = MathReward() + self.assertEqual(reward.tolerance, 1e-6) + self.assertEqual(reward.partial_credit, 0.1) + + def test_init_custom_values(self): + """Test MathReward initialization with custom values.""" + reward = MathReward(tolerance=1e-3, partial_credit=0.2) + self.assertEqual(reward.tolerance, 1e-3) + self.assertEqual(reward.partial_credit, 0.2) + + def test_to_float_valid_numbers(self): + """Test _to_float with valid numeric strings.""" + self.assertEqual(self.reward._to_float("42"), 42.0) + self.assertEqual(self.reward._to_float("3.14"), 3.14) + self.assertEqual(self.reward._to_float("-5.5"), -5.5) + self.assertEqual(self.reward._to_float("0"), 0.0) + self.assertEqual(self.reward._to_float(" 123.45 "), 123.45) + + def test_to_float_invalid_inputs(self): + """Test _to_float with invalid inputs.""" + self.assertIsNone(self.reward._to_float("abc")) + self.assertIsNone(self.reward._to_float("")) + self.assertIsNone(self.reward._to_float("12.34.56")) + self.assertIsNone(self.reward._to_float("not a number")) + self.assertIsNone(self.reward._to_float(None)) + + def test_to_float_edge_cases(self): + """Test _to_float with edge cases.""" + self.assertEqual(self.reward._to_float("1e6"), 1000000.0) + self.assertEqual(self.reward._to_float("-1.5e-3"), -0.0015) + self.assertEqual(self.reward._to_float("inf"), float("inf")) + self.assertEqual(self.reward._to_float("-inf"), float("-inf")) + + def test_extract_number_gsm8k_format(self): + """Test _extract_number with GSM8K style format.""" + self.assertEqual(self.reward._extract_number("#### 42"), 42.0) + self.assertEqual(self.reward._extract_number("#### -3.14"), -3.14) + self.assertEqual(self.reward._extract_number("Some text #### 123.45"), 123.45) + + def test_extract_number_answer_patterns(self): + """Test _extract_number with various answer patterns.""" + self.assertEqual(self.reward._extract_number("The answer is 42"), 42.0) + self.assertEqual(self.reward._extract_number("answer is 3.14"), 3.14) + self.assertEqual(self.reward._extract_number("Answer: 123"), 123.0) + self.assertEqual(self.reward._extract_number("Result: -5.5"), -5.5) + + def test_extract_number_equals_pattern(self): + """Test _extract_number with equals sign patterns.""" + self.assertEqual(self.reward._extract_number("x = 42."), 42.0) + self.assertEqual(self.reward._extract_number("The result = 3.14"), 3.14) + self.assertEqual(self.reward._extract_number("calculation = -7.5."), -7.5) + + def test_extract_number_end_of_text(self): + """Test _extract_number with numbers at end of text.""" + self.assertEqual(self.reward._extract_number("The final result is 42."), 42.0) + self.assertEqual(self.reward._extract_number("We get 3.14"), 3.14) + self.assertEqual(self.reward._extract_number("Answer: -5.5."), -5.5) + + def test_extract_number_fallback_pattern(self): + """Test _extract_number with fallback pattern (any number).""" + self.assertEqual(self.reward._extract_number("There are 42 items"), 42.0) + self.assertEqual(self.reward._extract_number("Cost is $3.14 per item"), 3.14) + self.assertEqual(self.reward._extract_number("Temperature: -5.5 degrees"), -5.5) + + def test_extract_number_multiple_matches(self): + """Test _extract_number returns the last match when multiple numbers exist.""" + # Should return the last match from the pattern + self.assertEqual( + self.reward._extract_number("First 10, then 20, finally 30"), 30.0 + ) + self.assertEqual( + self.reward._extract_number("#### 5 but actually #### 10"), 10.0 + ) + + def test_extract_number_no_match(self): + """Test _extract_number when no numbers are found.""" + self.assertIsNone(self.reward._extract_number("No numbers here")) + self.assertIsNone(self.reward._extract_number("")) + self.assertIsNone(self.reward._extract_number("Just text")) + + def test_extract_number_case_insensitive(self): + """Test _extract_number is case insensitive.""" + self.assertEqual(self.reward._extract_number("THE ANSWER IS 42"), 42.0) + self.assertEqual(self.reward._extract_number("Answer: 3.14"), 3.14) + self.assertEqual(self.reward._extract_number("RESULT: 123"), 123.0) + + def test_call_correct_answer(self): + """Test __call__ with correct answers.""" + self.assertEqual(self.reward("prompt", "The answer is 42", "42"), 1.0) + self.assertEqual(self.reward("prompt", "#### 3.14", "3.14"), 1.0) + self.assertEqual(self.reward("prompt", "Result: -5.5", "-5.5"), 1.0) + + def test_call_within_tolerance(self): + """Test __call__ with answers within tolerance.""" + # Default tolerance is 1e-6 + self.assertEqual(self.reward("prompt", "42.0000001", "42"), 1.0) + self.assertEqual(self.reward("prompt", "3.1400001", "3.14"), 1.0) + + # Custom tolerance + self.assertEqual(self.custom_reward("prompt", "42.0001", "42"), 1.0) + self.assertEqual(self.custom_reward("prompt", "3.141", "3.14"), 1.0) + + def test_call_outside_tolerance(self): + """Test __call__ with answers outside tolerance.""" + self.assertEqual(self.reward("prompt", "42.1", "42"), 0.0) + self.assertEqual(self.reward("prompt", "3.15", "3.14"), 0.0) + self.assertEqual(self.custom_reward("prompt", "42.01", "42"), 0.0) + + def test_call_invalid_target(self): + """Test __call__ with invalid target values.""" + self.assertEqual( + self.reward("prompt", "42", "invalid"), self.reward.partial_credit + ) + self.assertEqual(self.reward("prompt", "42", ""), self.reward.partial_credit) + self.assertEqual( + self.reward("prompt", "42", "not a number"), self.reward.partial_credit + ) + + def test_call_invalid_response(self): + """Test __call__ with invalid response values.""" + self.assertEqual( + self.reward("prompt", "no number", "42"), self.reward.partial_credit + ) + self.assertEqual(self.reward("prompt", "", "42"), self.reward.partial_credit) + self.assertEqual( + self.reward("prompt", "just text", "42"), self.reward.partial_credit + ) + + def test_call_both_invalid(self): + """Test __call__ with both invalid target and response.""" + self.assertEqual( + self.reward("prompt", "no number", "invalid"), self.reward.partial_credit + ) + self.assertEqual(self.reward("prompt", "", ""), self.reward.partial_credit) + + def test_call_custom_partial_credit(self): + """Test __call__ uses custom partial credit value.""" + self.assertEqual(self.custom_reward("prompt", "no number", "42"), 0.2) + self.assertEqual(self.custom_reward("prompt", "42", "invalid"), 0.2) + + def test_call_zero_values(self): + """Test __call__ with zero values.""" + self.assertEqual(self.reward("prompt", "0", "0"), 1.0) + self.assertEqual(self.reward("prompt", "The answer is 0", "0.0"), 1.0) + + def test_call_negative_values(self): + """Test __call__ with negative values.""" + self.assertEqual(self.reward("prompt", "-42", "-42"), 1.0) + self.assertEqual(self.reward("prompt", "#### -3.14", "-3.14"), 1.0) + self.assertEqual(self.reward("prompt", "-5", "-4.9"), 0.0) + + def test_call_large_numbers(self): + """Test __call__ with large numbers.""" + self.assertEqual(self.reward("prompt", "1000000", "1000000"), 1.0) + self.assertEqual(self.reward("prompt", "1e6", "1000000"), 1.0) + self.assertEqual(self.reward("prompt", "1000001", "1000000"), 0.0) + + def test_call_small_numbers(self): + """Test __call__ with very small numbers.""" + self.assertEqual(self.reward("prompt", "0.000001", "0.000001"), 1.0) + self.assertEqual(self.reward("prompt", "1e-6", "0.000001"), 1.0) + + def test_call_complex_response_text(self): + """Test __call__ with complex response text containing multiple elements.""" + response = """ + Let me solve this step by step: + First, I calculate 2 + 3 = 5 + Then, I multiply by 4: 5 * 4 = 20 + Finally, I subtract 8: 20 - 8 = 12 + #### 12 + """ + self.assertEqual(self.reward("prompt", response, "12"), 1.0) + + def test_call_with_units_and_formatting(self): + """Test __call__ with responses containing units and formatting.""" + self.assertEqual(self.reward("prompt", "The cost is $42.50", "42.5"), 1.0) + self.assertEqual(self.reward("prompt", "Distance: 3.14 meters", "3.14"), 1.0) + self.assertEqual(self.reward("prompt", "Temperature is -5.5°C", "-5.5"), 1.0) + + +if __name__ == "__main__": + unittest.main() From fee1ad5853c4930964bca09700ea7348c4c39e56 Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 28 Aug 2025 10:20:51 -0700 Subject: [PATCH 2/4] log trainer metrics --- apps/grpo/main.py | 55 ++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 7fd10736f..5c34b2035 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -112,11 +112,24 @@ def __init__( self.model.parameters(), lr=self.learning_rate ) + # Initialize metrics storage + self.log_dict = {} + self.logger.info(f"Model initialized on {self.device}") + @endpoint + async def get_metrics(self): + """Return metrics dict for external logger to log.""" + return self.log_dict.copy() + @endpoint async def train_step(self, batch: list[Episode]): total_loss = 0.0 + total_kl_loss = 0.0 + total_pg_loss = 0.0 + total_ratio_mean = 0.0 + total_ratio_std = 0.0 + total_response_len = 0.0 num_groups_processed = 0 for episode in batch: @@ -170,6 +183,17 @@ async def train_step(self, batch: list[Episode]): # Total GRPO loss loss = pg_loss + kl_penalty total_loss += loss.item() + total_kl_loss += kl_penalty.item() + total_pg_loss += pg_loss.item() + total_ratio_mean += ratio.detach().float().cpu().numpy().mean() + total_ratio_std += ratio.detach().float().cpu().numpy().std() + + # Calculate average response length for this episode + episode_response_len = sum( + len(response) for response in response_texts + ) / len(response_texts) + total_response_len += episode_response_len + num_groups_processed += len(groups) self.optimizer.zero_grad() @@ -180,7 +204,28 @@ async def train_step(self, batch: list[Episode]): self.optimizer.step() - avg_loss = total_loss / len(batch) if batch else 0.0 + # Compute averaged metrics across the batch + if batch: + avg_loss = total_loss / len(batch) + avg_kl_loss = total_kl_loss / len(batch) + avg_pg_loss = total_pg_loss / len(batch) + avg_ratio_mean = total_ratio_mean / len(batch) + avg_ratio_std = total_ratio_std / len(batch) + avg_response_len = total_response_len / len(batch) + else: + avg_loss = avg_kl_loss = avg_pg_loss = avg_ratio_mean = avg_ratio_std = ( + avg_response_len + ) = 0.0 + + # Store averaged metrics for external logging + self.log_dict = { + "loss/total": avg_loss, + "loss/kl": avg_kl_loss, + "loss/policy": avg_pg_loss, + "metrics/ratio_mean": avg_ratio_mean, + "metrics/ratio_std": avg_ratio_std, + "metrics/response_len": avg_response_len, + } return {"loss": avg_loss, "groups_processed": num_groups_processed} @@ -460,7 +505,7 @@ async def continuous_rollouts(): print( f"Generated {rollout_count} rollouts w/ average reward {avg_reward}" ) - logger.log("reward/rollout", avg_reward, rollout_count) + logger.log("metrics/reward_per_rollout", avg_reward, rollout_count) async def continuous_training(): training_step = 0 @@ -476,7 +521,11 @@ async def continuous_training(): if training_result: loss_value = training_result.get("loss", 0.0) print(f"Latest loss: {loss_value}") - logger.log("loss/training_step", loss_value, training_step) + + # Get and log detailed metrics + metrics = await trainer.get_metrics.choose() + logger.log_dict(metrics, training_step) + # await trainer.update_weights(policy) print("Starting GRPO training loops...") From 68a796c48452d19f372877c1140bfcbfe35f2fae Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 28 Aug 2025 10:21:55 -0700 Subject: [PATCH 3/4] remove file --- tests/unit_tests/data/test_math_reward.py | 202 ---------------------- 1 file changed, 202 deletions(-) delete mode 100644 tests/unit_tests/data/test_math_reward.py diff --git a/tests/unit_tests/data/test_math_reward.py b/tests/unit_tests/data/test_math_reward.py deleted file mode 100644 index a109492dd..000000000 --- a/tests/unit_tests/data/test_math_reward.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import unittest -from unittest import mock - -from forge.data.rewards.math import MathReward - - -class TestMathReward(unittest.TestCase): - def setUp(self): - """Set up test fixtures before each test method.""" - self.reward = MathReward() - self.custom_reward = MathReward(tolerance=1e-3, partial_credit=0.2) - - def test_init_default_values(self): - """Test MathReward initialization with default values.""" - reward = MathReward() - self.assertEqual(reward.tolerance, 1e-6) - self.assertEqual(reward.partial_credit, 0.1) - - def test_init_custom_values(self): - """Test MathReward initialization with custom values.""" - reward = MathReward(tolerance=1e-3, partial_credit=0.2) - self.assertEqual(reward.tolerance, 1e-3) - self.assertEqual(reward.partial_credit, 0.2) - - def test_to_float_valid_numbers(self): - """Test _to_float with valid numeric strings.""" - self.assertEqual(self.reward._to_float("42"), 42.0) - self.assertEqual(self.reward._to_float("3.14"), 3.14) - self.assertEqual(self.reward._to_float("-5.5"), -5.5) - self.assertEqual(self.reward._to_float("0"), 0.0) - self.assertEqual(self.reward._to_float(" 123.45 "), 123.45) - - def test_to_float_invalid_inputs(self): - """Test _to_float with invalid inputs.""" - self.assertIsNone(self.reward._to_float("abc")) - self.assertIsNone(self.reward._to_float("")) - self.assertIsNone(self.reward._to_float("12.34.56")) - self.assertIsNone(self.reward._to_float("not a number")) - self.assertIsNone(self.reward._to_float(None)) - - def test_to_float_edge_cases(self): - """Test _to_float with edge cases.""" - self.assertEqual(self.reward._to_float("1e6"), 1000000.0) - self.assertEqual(self.reward._to_float("-1.5e-3"), -0.0015) - self.assertEqual(self.reward._to_float("inf"), float("inf")) - self.assertEqual(self.reward._to_float("-inf"), float("-inf")) - - def test_extract_number_gsm8k_format(self): - """Test _extract_number with GSM8K style format.""" - self.assertEqual(self.reward._extract_number("#### 42"), 42.0) - self.assertEqual(self.reward._extract_number("#### -3.14"), -3.14) - self.assertEqual(self.reward._extract_number("Some text #### 123.45"), 123.45) - - def test_extract_number_answer_patterns(self): - """Test _extract_number with various answer patterns.""" - self.assertEqual(self.reward._extract_number("The answer is 42"), 42.0) - self.assertEqual(self.reward._extract_number("answer is 3.14"), 3.14) - self.assertEqual(self.reward._extract_number("Answer: 123"), 123.0) - self.assertEqual(self.reward._extract_number("Result: -5.5"), -5.5) - - def test_extract_number_equals_pattern(self): - """Test _extract_number with equals sign patterns.""" - self.assertEqual(self.reward._extract_number("x = 42."), 42.0) - self.assertEqual(self.reward._extract_number("The result = 3.14"), 3.14) - self.assertEqual(self.reward._extract_number("calculation = -7.5."), -7.5) - - def test_extract_number_end_of_text(self): - """Test _extract_number with numbers at end of text.""" - self.assertEqual(self.reward._extract_number("The final result is 42."), 42.0) - self.assertEqual(self.reward._extract_number("We get 3.14"), 3.14) - self.assertEqual(self.reward._extract_number("Answer: -5.5."), -5.5) - - def test_extract_number_fallback_pattern(self): - """Test _extract_number with fallback pattern (any number).""" - self.assertEqual(self.reward._extract_number("There are 42 items"), 42.0) - self.assertEqual(self.reward._extract_number("Cost is $3.14 per item"), 3.14) - self.assertEqual(self.reward._extract_number("Temperature: -5.5 degrees"), -5.5) - - def test_extract_number_multiple_matches(self): - """Test _extract_number returns the last match when multiple numbers exist.""" - # Should return the last match from the pattern - self.assertEqual( - self.reward._extract_number("First 10, then 20, finally 30"), 30.0 - ) - self.assertEqual( - self.reward._extract_number("#### 5 but actually #### 10"), 10.0 - ) - - def test_extract_number_no_match(self): - """Test _extract_number when no numbers are found.""" - self.assertIsNone(self.reward._extract_number("No numbers here")) - self.assertIsNone(self.reward._extract_number("")) - self.assertIsNone(self.reward._extract_number("Just text")) - - def test_extract_number_case_insensitive(self): - """Test _extract_number is case insensitive.""" - self.assertEqual(self.reward._extract_number("THE ANSWER IS 42"), 42.0) - self.assertEqual(self.reward._extract_number("Answer: 3.14"), 3.14) - self.assertEqual(self.reward._extract_number("RESULT: 123"), 123.0) - - def test_call_correct_answer(self): - """Test __call__ with correct answers.""" - self.assertEqual(self.reward("prompt", "The answer is 42", "42"), 1.0) - self.assertEqual(self.reward("prompt", "#### 3.14", "3.14"), 1.0) - self.assertEqual(self.reward("prompt", "Result: -5.5", "-5.5"), 1.0) - - def test_call_within_tolerance(self): - """Test __call__ with answers within tolerance.""" - # Default tolerance is 1e-6 - self.assertEqual(self.reward("prompt", "42.0000001", "42"), 1.0) - self.assertEqual(self.reward("prompt", "3.1400001", "3.14"), 1.0) - - # Custom tolerance - self.assertEqual(self.custom_reward("prompt", "42.0001", "42"), 1.0) - self.assertEqual(self.custom_reward("prompt", "3.141", "3.14"), 1.0) - - def test_call_outside_tolerance(self): - """Test __call__ with answers outside tolerance.""" - self.assertEqual(self.reward("prompt", "42.1", "42"), 0.0) - self.assertEqual(self.reward("prompt", "3.15", "3.14"), 0.0) - self.assertEqual(self.custom_reward("prompt", "42.01", "42"), 0.0) - - def test_call_invalid_target(self): - """Test __call__ with invalid target values.""" - self.assertEqual( - self.reward("prompt", "42", "invalid"), self.reward.partial_credit - ) - self.assertEqual(self.reward("prompt", "42", ""), self.reward.partial_credit) - self.assertEqual( - self.reward("prompt", "42", "not a number"), self.reward.partial_credit - ) - - def test_call_invalid_response(self): - """Test __call__ with invalid response values.""" - self.assertEqual( - self.reward("prompt", "no number", "42"), self.reward.partial_credit - ) - self.assertEqual(self.reward("prompt", "", "42"), self.reward.partial_credit) - self.assertEqual( - self.reward("prompt", "just text", "42"), self.reward.partial_credit - ) - - def test_call_both_invalid(self): - """Test __call__ with both invalid target and response.""" - self.assertEqual( - self.reward("prompt", "no number", "invalid"), self.reward.partial_credit - ) - self.assertEqual(self.reward("prompt", "", ""), self.reward.partial_credit) - - def test_call_custom_partial_credit(self): - """Test __call__ uses custom partial credit value.""" - self.assertEqual(self.custom_reward("prompt", "no number", "42"), 0.2) - self.assertEqual(self.custom_reward("prompt", "42", "invalid"), 0.2) - - def test_call_zero_values(self): - """Test __call__ with zero values.""" - self.assertEqual(self.reward("prompt", "0", "0"), 1.0) - self.assertEqual(self.reward("prompt", "The answer is 0", "0.0"), 1.0) - - def test_call_negative_values(self): - """Test __call__ with negative values.""" - self.assertEqual(self.reward("prompt", "-42", "-42"), 1.0) - self.assertEqual(self.reward("prompt", "#### -3.14", "-3.14"), 1.0) - self.assertEqual(self.reward("prompt", "-5", "-4.9"), 0.0) - - def test_call_large_numbers(self): - """Test __call__ with large numbers.""" - self.assertEqual(self.reward("prompt", "1000000", "1000000"), 1.0) - self.assertEqual(self.reward("prompt", "1e6", "1000000"), 1.0) - self.assertEqual(self.reward("prompt", "1000001", "1000000"), 0.0) - - def test_call_small_numbers(self): - """Test __call__ with very small numbers.""" - self.assertEqual(self.reward("prompt", "0.000001", "0.000001"), 1.0) - self.assertEqual(self.reward("prompt", "1e-6", "0.000001"), 1.0) - - def test_call_complex_response_text(self): - """Test __call__ with complex response text containing multiple elements.""" - response = """ - Let me solve this step by step: - First, I calculate 2 + 3 = 5 - Then, I multiply by 4: 5 * 4 = 20 - Finally, I subtract 8: 20 - 8 = 12 - #### 12 - """ - self.assertEqual(self.reward("prompt", response, "12"), 1.0) - - def test_call_with_units_and_formatting(self): - """Test __call__ with responses containing units and formatting.""" - self.assertEqual(self.reward("prompt", "The cost is $42.50", "42.5"), 1.0) - self.assertEqual(self.reward("prompt", "Distance: 3.14 meters", "3.14"), 1.0) - self.assertEqual(self.reward("prompt", "Temperature is -5.5°C", "-5.5"), 1.0) - - -if __name__ == "__main__": - unittest.main() From 830d2dc593dc39fb8ea56aed7dc9bfd847d2fe36 Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 28 Aug 2025 10:31:50 -0700 Subject: [PATCH 4/4] log advantages --- apps/grpo/main.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 5c34b2035..c5df455c4 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -130,6 +130,7 @@ async def train_step(self, batch: list[Episode]): total_ratio_mean = 0.0 total_ratio_std = 0.0 total_response_len = 0.0 + total_advantages = 0.0 num_groups_processed = 0 for episode in batch: @@ -194,6 +195,10 @@ async def train_step(self, batch: list[Episode]): ) / len(response_texts) total_response_len += episode_response_len + # Calculate average advantages for this episode + episode_advantages = advantages_tensor.detach().float().cpu().numpy().mean() + total_advantages += episode_advantages + num_groups_processed += len(groups) self.optimizer.zero_grad() @@ -212,10 +217,11 @@ async def train_step(self, batch: list[Episode]): avg_ratio_mean = total_ratio_mean / len(batch) avg_ratio_std = total_ratio_std / len(batch) avg_response_len = total_response_len / len(batch) + avg_advantages = total_advantages / len(batch) else: avg_loss = avg_kl_loss = avg_pg_loss = avg_ratio_mean = avg_ratio_std = ( avg_response_len - ) = 0.0 + ) = avg_advantages = 0.0 # Store averaged metrics for external logging self.log_dict = { @@ -225,6 +231,7 @@ async def train_step(self, batch: list[Episode]): "metrics/ratio_mean": avg_ratio_mean, "metrics/ratio_std": avg_ratio_std, "metrics/response_len": avg_response_len, + "metrics/advantages": avg_advantages, } return {"loss": avg_loss, "groups_processed": num_groups_processed} @@ -505,7 +512,7 @@ async def continuous_rollouts(): print( f"Generated {rollout_count} rollouts w/ average reward {avg_reward}" ) - logger.log("metrics/reward_per_rollout", avg_reward, rollout_count) + logger.log("metrics/reward_per_ten_rollout", avg_reward, rollout_count) async def continuous_training(): training_step = 0