Skip to content

Commit b74a47c

Browse files
committed
Update rewards
1 parent 55c32be commit b74a47c

File tree

4 files changed

+242
-185
lines changed

4 files changed

+242
-185
lines changed

apps/grpo/qwen3_1_7b.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# Global configuration
44
group_size: 8
5-
batch_size: 16
5+
batch_size: 8
66
max_req_tokens: 512
77
max_res_tokens: 512
88
model: "Qwen/Qwen3-1.7B"

src/forge/data/rewards.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import re
8-
from typing import Optional
98

109
from forge.interfaces import Reward
1110

@@ -71,11 +70,13 @@ def __init__(self, partial_reward: float = 0.2, full_reward: float = 1.0):
7170
self._THINK_TAG_ATTEMPT_RE = re.compile(r"<\s*/?\s*think\s*>", re.IGNORECASE)
7271

7372
def __call__(self, prompt: str, response: str, target: str | None = None) -> float:
74-
matches = self._THINK_BLOCK_RE.findall(response or "")
73+
"""Compute thinking reward."""
74+
if not response:
75+
return 0.0
76+
77+
matches = self._THINK_BLOCK_RE.findall(response)
7578
has_well_formed = any(len(re.sub(r"\s+", "", m)) >= 1 for m in matches)
76-
has_attempt = bool(self._THINK_TAG_ATTEMPT_RE.search(response or "")) or bool(
77-
matches
78-
)
79+
has_attempt = bool(self._THINK_TAG_ATTEMPT_RE.search(response)) or bool(matches)
7980
if has_well_formed:
8081
return self.full_reward
8182
elif has_attempt:

tests/unit_tests/rl/test_math_reward.py

Lines changed: 108 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import unittest
8-
from unittest import mock
98

109
from forge.data.rewards import MathReward
1110

@@ -36,6 +35,13 @@ def test_to_float_valid_numbers(self):
3635
self.assertEqual(self.reward._to_float("0"), 0.0)
3736
self.assertEqual(self.reward._to_float(" 123.45 "), 123.45)
3837

38+
def test_to_float_with_currency_and_formatting(self):
39+
"""Test _to_float with currency symbols and commas."""
40+
self.assertEqual(self.reward._to_float("$42"), 42.0)
41+
self.assertEqual(self.reward._to_float("$1,000"), 1000.0)
42+
self.assertEqual(self.reward._to_float("1,234.56"), 1234.56)
43+
self.assertEqual(self.reward._to_float("$ 42.50 "), 42.5)
44+
3945
def test_to_float_invalid_inputs(self):
4046
"""Test _to_float with invalid inputs."""
4147
self.assertIsNone(self.reward._to_float("abc"))
@@ -48,154 +54,140 @@ def test_to_float_edge_cases(self):
4854
"""Test _to_float with edge cases."""
4955
self.assertEqual(self.reward._to_float("1e6"), 1000000.0)
5056
self.assertEqual(self.reward._to_float("-1.5e-3"), -0.0015)
51-
self.assertEqual(self.reward._to_float("inf"), float("inf"))
52-
self.assertEqual(self.reward._to_float("-inf"), float("-inf"))
53-
54-
def test_extract_number_gsm8k_format(self):
55-
"""Test _extract_number with GSM8K style format."""
56-
self.assertEqual(self.reward._extract_number("#### 42"), 42.0)
57-
self.assertEqual(self.reward._extract_number("#### -3.14"), -3.14)
58-
self.assertEqual(self.reward._extract_number("Some text #### 123.45"), 123.45)
59-
60-
def test_extract_number_answer_patterns(self):
61-
"""Test _extract_number with various answer patterns."""
62-
self.assertEqual(self.reward._extract_number("The answer is 42"), 42.0)
63-
self.assertEqual(self.reward._extract_number("answer is 3.14"), 3.14)
64-
self.assertEqual(self.reward._extract_number("Answer: 123"), 123.0)
65-
self.assertEqual(self.reward._extract_number("Result: -5.5"), -5.5)
66-
67-
def test_extract_number_equals_pattern(self):
68-
"""Test _extract_number with equals sign patterns."""
69-
self.assertEqual(self.reward._extract_number("x = 42."), 42.0)
70-
self.assertEqual(self.reward._extract_number("The result = 3.14"), 3.14)
71-
self.assertEqual(self.reward._extract_number("calculation = -7.5."), -7.5)
72-
73-
def test_extract_number_end_of_text(self):
74-
"""Test _extract_number with numbers at end of text."""
75-
self.assertEqual(self.reward._extract_number("The final result is 42."), 42.0)
76-
self.assertEqual(self.reward._extract_number("We get 3.14"), 3.14)
77-
self.assertEqual(self.reward._extract_number("Answer: -5.5."), -5.5)
78-
79-
def test_extract_number_fallback_pattern(self):
80-
"""Test _extract_number with fallback pattern (any number)."""
81-
self.assertEqual(self.reward._extract_number("There are 42 items"), 42.0)
82-
self.assertEqual(self.reward._extract_number("Cost is $3.14 per item"), 3.14)
83-
self.assertEqual(self.reward._extract_number("Temperature: -5.5 degrees"), -5.5)
84-
85-
def test_extract_number_multiple_matches(self):
86-
"""Test _extract_number returns the last match when multiple numbers exist."""
87-
# Should return the last match from the pattern
88-
self.assertEqual(
89-
self.reward._extract_number("First 10, then 20, finally 30"), 30.0
90-
)
91-
self.assertEqual(
92-
self.reward._extract_number("#### 5 but actually #### 10"), 10.0
93-
)
9457

95-
def test_extract_number_no_match(self):
96-
"""Test _extract_number when no numbers are found."""
97-
self.assertIsNone(self.reward._extract_number("No numbers here"))
98-
self.assertIsNone(self.reward._extract_number(""))
99-
self.assertIsNone(self.reward._extract_number("Just text"))
58+
def test_call_correct_answer_in_tags(self):
59+
"""Test __call__ with correct answers in <answer></answer> tags."""
60+
self.assertEqual(self.reward("prompt", "<answer>42</answer>", "42"), 1.0)
61+
self.assertEqual(self.reward("prompt", "<answer>3.14</answer>", "3.14"), 1.0)
62+
self.assertEqual(self.reward("prompt", "<answer>-5.5</answer>", "-5.5"), 1.0)
10063

101-
def test_extract_number_case_insensitive(self):
102-
"""Test _extract_number is case insensitive."""
103-
self.assertEqual(self.reward._extract_number("THE ANSWER IS 42"), 42.0)
104-
self.assertEqual(self.reward._extract_number("Answer: 3.14"), 3.14)
105-
self.assertEqual(self.reward._extract_number("RESULT: 123"), 123.0)
64+
def test_call_answer_tags_with_whitespace(self):
65+
"""Test __call__ with answer tags containing whitespace."""
66+
self.assertEqual(self.reward("prompt", "<answer> 42 </answer>", "42"), 1.0)
67+
self.assertEqual(
68+
self.reward("prompt", "<answer>\n3.14\n</answer>", "3.14"), 1.0
69+
)
10670

107-
def test_call_correct_answer(self):
108-
"""Test __call__ with correct answers."""
109-
self.assertEqual(self.reward("prompt", "The answer is 42", "42"), 1.0)
110-
self.assertEqual(self.reward("prompt", "#### 3.14", "3.14"), 1.0)
111-
self.assertEqual(self.reward("prompt", "Result: -5.5", "-5.5"), 1.0)
71+
def test_call_answer_tags_with_complex_content(self):
72+
"""Test __call__ with complex content in answer tags."""
73+
response = """
74+
Let me solve this step by step:
75+
First, I calculate 2 + 3 = 5
76+
Then, I multiply by 4: 5 * 4 = 20
77+
Finally, I subtract 8: 20 - 8 = 12
78+
<answer>12</answer>
79+
"""
80+
self.assertEqual(self.reward("prompt", response, "12"), 1.0)
11281

11382
def test_call_within_tolerance(self):
11483
"""Test __call__ with answers within tolerance."""
11584
# Default tolerance is 1e-6
116-
self.assertEqual(self.reward("prompt", "42.0000001", "42"), 1.0)
117-
self.assertEqual(self.reward("prompt", "3.1400001", "3.14"), 1.0)
118-
119-
# Custom tolerance
120-
self.assertEqual(self.custom_reward("prompt", "42.0001", "42"), 1.0)
121-
self.assertEqual(self.custom_reward("prompt", "3.141", "3.14"), 1.0)
122-
123-
def test_call_outside_tolerance(self):
124-
"""Test __call__ with answers outside tolerance."""
125-
self.assertEqual(self.reward("prompt", "42.1", "42"), 0.0)
126-
self.assertEqual(self.reward("prompt", "3.15", "3.14"), 0.0)
127-
self.assertEqual(self.custom_reward("prompt", "42.01", "42"), 0.0)
128-
129-
def test_call_invalid_target(self):
130-
"""Test __call__ with invalid target values."""
13185
self.assertEqual(
132-
self.reward("prompt", "42", "invalid"), self.reward.partial_credit
86+
self.reward("prompt", "<answer>42.0000001</answer>", "42"), 1.0
13387
)
134-
self.assertEqual(self.reward("prompt", "42", ""), self.reward.partial_credit)
13588
self.assertEqual(
136-
self.reward("prompt", "42", "not a number"), self.reward.partial_credit
89+
self.reward("prompt", "<answer>3.1400001</answer>", "3.14"), 1.0
13790
)
13891

139-
def test_call_invalid_response(self):
140-
"""Test __call__ with invalid response values."""
92+
# Custom tolerance
14193
self.assertEqual(
142-
self.reward("prompt", "no number", "42"), self.reward.partial_credit
94+
self.custom_reward("prompt", "<answer>42.0001</answer>", "42"), 1.0
14395
)
144-
self.assertEqual(self.reward("prompt", "", "42"), self.reward.partial_credit)
14596
self.assertEqual(
146-
self.reward("prompt", "just text", "42"), self.reward.partial_credit
97+
self.custom_reward("prompt", "<answer>3.141</answer>", "3.14"), 1.0
98+
)
99+
100+
def test_call_outside_tolerance(self):
101+
"""Test __call__ with answers outside tolerance."""
102+
self.assertEqual(self.reward("prompt", "<answer>42.1</answer>", "42"), 0.0)
103+
self.assertEqual(self.reward("prompt", "<answer>3.15</answer>", "3.14"), 0.0)
104+
self.assertEqual(
105+
self.custom_reward("prompt", "<answer>42.01</answer>", "42"), 0.0
147106
)
148107

149-
def test_call_both_invalid(self):
150-
"""Test __call__ with both invalid target and response."""
108+
def test_call_partial_credit_target_in_response(self):
109+
"""Test __call__ with partial credit when target appears in response."""
110+
response = "The calculation shows 42 but I put <answer>43</answer>"
111+
self.assertEqual(self.reward("prompt", response, "42"), 0.1)
112+
113+
response = "Let me work through this: 42 + 1 = 43. <answer>43</answer>"
114+
self.assertEqual(self.reward("prompt", response, "42"), 0.1)
115+
116+
def test_call_partial_credit_custom_value(self):
117+
"""Test __call__ with custom partial credit value."""
118+
response = "The calculation shows 42 but I put <answer>43</answer>"
119+
self.assertEqual(self.custom_reward("prompt", response, "42"), 0.2)
120+
121+
def test_call_no_partial_credit_with_answer_tags(self):
122+
"""Test __call__ doesn't give partial credit if target is only in answer tags."""
123+
response = "Let me solve this. <answer>42</answer>"
124+
# Target 100 is not elsewhere in response, so no partial credit
125+
self.assertEqual(self.reward("prompt", response, "100"), 0.0)
126+
127+
def test_call_integer_target_formatting(self):
128+
"""Test __call__ with integer targets formatted correctly."""
129+
# Integer targets should be formatted without decimal point
130+
response = "I calculated and got 117 as the answer. <answer>118</answer>"
131+
self.assertEqual(self.reward("prompt", response, "117"), 0.1)
132+
133+
# Should work with 117.0 in target too
134+
self.assertEqual(self.reward("prompt", response, "117.0"), 0.1)
135+
136+
def test_call_float_target_formatting(self):
137+
"""Test __call__ with float targets."""
138+
response = "I calculated and got 3.14 as the answer. <answer>3.15</answer>"
139+
self.assertEqual(self.reward("prompt", response, "3.14"), 0.1)
140+
141+
def test_call_invalid_target(self):
142+
"""Test __call__ with invalid target values."""
143+
self.assertEqual(self.reward("prompt", "<answer>42</answer>", "invalid"), 0.0)
144+
self.assertEqual(self.reward("prompt", "<answer>42</answer>", ""), 0.0)
151145
self.assertEqual(
152-
self.reward("prompt", "no number", "invalid"), self.reward.partial_credit
146+
self.reward("prompt", "<answer>42</answer>", "not a number"), 0.0
153147
)
154-
self.assertEqual(self.reward("prompt", "", ""), self.reward.partial_credit)
155148

156-
def test_call_custom_partial_credit(self):
157-
"""Test __call__ uses custom partial credit value."""
158-
self.assertEqual(self.custom_reward("prompt", "no number", "42"), 0.2)
159-
self.assertEqual(self.custom_reward("prompt", "42", "invalid"), 0.2)
149+
def test_call_no_answer_tags(self):
150+
"""Test __call__ with response that has no answer tags."""
151+
# Should still check for partial credit
152+
self.assertEqual(self.reward("prompt", "The answer is 42", "42"), 0.1)
153+
self.assertEqual(self.reward("prompt", "No matching number", "42"), 0.0)
154+
155+
def test_call_invalid_answer_in_tags(self):
156+
"""Test __call__ with invalid answer in tags."""
157+
response = "<answer>not a number</answer> but 42 is correct"
158+
self.assertEqual(self.reward("prompt", response, "42"), 0.1)
160159

161160
def test_call_zero_values(self):
162161
"""Test __call__ with zero values."""
163-
self.assertEqual(self.reward("prompt", "0", "0"), 1.0)
164-
self.assertEqual(self.reward("prompt", "The answer is 0", "0.0"), 1.0)
162+
self.assertEqual(self.reward("prompt", "<answer>0</answer>", "0"), 1.0)
163+
self.assertEqual(self.reward("prompt", "<answer>0.0</answer>", "0"), 1.0)
165164

166165
def test_call_negative_values(self):
167166
"""Test __call__ with negative values."""
168-
self.assertEqual(self.reward("prompt", "-42", "-42"), 1.0)
169-
self.assertEqual(self.reward("prompt", "#### -3.14", "-3.14"), 1.0)
170-
self.assertEqual(self.reward("prompt", "-5", "-4.9"), 0.0)
167+
self.assertEqual(self.reward("prompt", "<answer>-42</answer>", "-42"), 1.0)
168+
self.assertEqual(self.reward("prompt", "<answer>-3.14</answer>", "-3.14"), 1.0)
171169

172170
def test_call_large_numbers(self):
173171
"""Test __call__ with large numbers."""
174-
self.assertEqual(self.reward("prompt", "1000000", "1000000"), 1.0)
175-
self.assertEqual(self.reward("prompt", "1e6", "1000000"), 1.0)
176-
self.assertEqual(self.reward("prompt", "1000001", "1000000"), 0.0)
172+
self.assertEqual(
173+
self.reward("prompt", "<answer>1000000</answer>", "1000000"), 1.0
174+
)
175+
self.assertEqual(self.reward("prompt", "<answer>1e6</answer>", "1000000"), 1.0)
177176

178177
def test_call_small_numbers(self):
179178
"""Test __call__ with very small numbers."""
180-
self.assertEqual(self.reward("prompt", "0.000001", "0.000001"), 1.0)
181-
self.assertEqual(self.reward("prompt", "1e-6", "0.000001"), 1.0)
182-
183-
def test_call_complex_response_text(self):
184-
"""Test __call__ with complex response text containing multiple elements."""
185-
response = """
186-
Let me solve this step by step:
187-
First, I calculate 2 + 3 = 5
188-
Then, I multiply by 4: 5 * 4 = 20
189-
Finally, I subtract 8: 20 - 8 = 12
190-
#### 12
191-
"""
192-
self.assertEqual(self.reward("prompt", response, "12"), 1.0)
179+
self.assertEqual(
180+
self.reward("prompt", "<answer>0.000001</answer>", "0.000001"), 1.0
181+
)
182+
self.assertEqual(
183+
self.reward("prompt", "<answer>1e-6</answer>", "0.000001"), 1.0
184+
)
193185

194-
def test_call_with_units_and_formatting(self):
195-
"""Test __call__ with responses containing units and formatting."""
196-
self.assertEqual(self.reward("prompt", "The cost is $42.50", "42.5"), 1.0)
197-
self.assertEqual(self.reward("prompt", "Distance: 3.14 meters", "3.14"), 1.0)
198-
self.assertEqual(self.reward("prompt", "Temperature is -5.5°C", "-5.5"), 1.0)
186+
def test_call_multiple_answer_tags(self):
187+
"""Test __call__ with multiple answer tags (should use first one)."""
188+
response = "First answer: <answer>42</answer> Second: <answer>43</answer>"
189+
self.assertEqual(self.reward("prompt", response, "42"), 1.0)
190+
self.assertEqual(self.reward("prompt", response, "43"), 0.1)
199191

200192

201193
if __name__ == "__main__":

0 commit comments

Comments
 (0)