55# LICENSE file in the root directory of this source tree.
66
77import unittest
8- from unittest import mock
98
109from forge .data .rewards import MathReward
1110
@@ -36,6 +35,13 @@ def test_to_float_valid_numbers(self):
3635 self .assertEqual (self .reward ._to_float ("0" ), 0.0 )
3736 self .assertEqual (self .reward ._to_float (" 123.45 " ), 123.45 )
3837
38+ def test_to_float_with_currency_and_formatting (self ):
39+ """Test _to_float with currency symbols and commas."""
40+ self .assertEqual (self .reward ._to_float ("$42" ), 42.0 )
41+ self .assertEqual (self .reward ._to_float ("$1,000" ), 1000.0 )
42+ self .assertEqual (self .reward ._to_float ("1,234.56" ), 1234.56 )
43+ self .assertEqual (self .reward ._to_float ("$ 42.50 " ), 42.5 )
44+
3945 def test_to_float_invalid_inputs (self ):
4046 """Test _to_float with invalid inputs."""
4147 self .assertIsNone (self .reward ._to_float ("abc" ))
@@ -48,154 +54,140 @@ def test_to_float_edge_cases(self):
4854 """Test _to_float with edge cases."""
4955 self .assertEqual (self .reward ._to_float ("1e6" ), 1000000.0 )
5056 self .assertEqual (self .reward ._to_float ("-1.5e-3" ), - 0.0015 )
51- self .assertEqual (self .reward ._to_float ("inf" ), float ("inf" ))
52- self .assertEqual (self .reward ._to_float ("-inf" ), float ("-inf" ))
53-
54- def test_extract_number_gsm8k_format (self ):
55- """Test _extract_number with GSM8K style format."""
56- self .assertEqual (self .reward ._extract_number ("#### 42" ), 42.0 )
57- self .assertEqual (self .reward ._extract_number ("#### -3.14" ), - 3.14 )
58- self .assertEqual (self .reward ._extract_number ("Some text #### 123.45" ), 123.45 )
59-
60- def test_extract_number_answer_patterns (self ):
61- """Test _extract_number with various answer patterns."""
62- self .assertEqual (self .reward ._extract_number ("The answer is 42" ), 42.0 )
63- self .assertEqual (self .reward ._extract_number ("answer is 3.14" ), 3.14 )
64- self .assertEqual (self .reward ._extract_number ("Answer: 123" ), 123.0 )
65- self .assertEqual (self .reward ._extract_number ("Result: -5.5" ), - 5.5 )
66-
67- def test_extract_number_equals_pattern (self ):
68- """Test _extract_number with equals sign patterns."""
69- self .assertEqual (self .reward ._extract_number ("x = 42." ), 42.0 )
70- self .assertEqual (self .reward ._extract_number ("The result = 3.14" ), 3.14 )
71- self .assertEqual (self .reward ._extract_number ("calculation = -7.5." ), - 7.5 )
72-
73- def test_extract_number_end_of_text (self ):
74- """Test _extract_number with numbers at end of text."""
75- self .assertEqual (self .reward ._extract_number ("The final result is 42." ), 42.0 )
76- self .assertEqual (self .reward ._extract_number ("We get 3.14" ), 3.14 )
77- self .assertEqual (self .reward ._extract_number ("Answer: -5.5." ), - 5.5 )
78-
79- def test_extract_number_fallback_pattern (self ):
80- """Test _extract_number with fallback pattern (any number)."""
81- self .assertEqual (self .reward ._extract_number ("There are 42 items" ), 42.0 )
82- self .assertEqual (self .reward ._extract_number ("Cost is $3.14 per item" ), 3.14 )
83- self .assertEqual (self .reward ._extract_number ("Temperature: -5.5 degrees" ), - 5.5 )
84-
85- def test_extract_number_multiple_matches (self ):
86- """Test _extract_number returns the last match when multiple numbers exist."""
87- # Should return the last match from the pattern
88- self .assertEqual (
89- self .reward ._extract_number ("First 10, then 20, finally 30" ), 30.0
90- )
91- self .assertEqual (
92- self .reward ._extract_number ("#### 5 but actually #### 10" ), 10.0
93- )
9457
95- def test_extract_number_no_match (self ):
96- """Test _extract_number when no numbers are found ."""
97- self .assertIsNone (self .reward . _extract_number ( "No numbers here" ) )
98- self .assertIsNone (self .reward . _extract_number ( "" ) )
99- self .assertIsNone (self .reward . _extract_number ( "Just text" ) )
58+ def test_call_correct_answer_in_tags (self ):
59+ """Test __call__ with correct answers in <answer></answer> tags ."""
60+ self .assertEqual (self .reward ( "prompt" , "<answer>42</answer>" , "42" ), 1.0 )
61+ self .assertEqual (self .reward ( "prompt" , "<answer>3.14</answer>" , "3.14" ), 1.0 )
62+ self .assertEqual (self .reward ( "prompt" , "<answer>-5.5</answer>" , "-5.5" ), 1.0 )
10063
101- def test_extract_number_case_insensitive (self ):
102- """Test _extract_number is case insensitive."""
103- self .assertEqual (self .reward ._extract_number ("THE ANSWER IS 42" ), 42.0 )
104- self .assertEqual (self .reward ._extract_number ("Answer: 3.14" ), 3.14 )
105- self .assertEqual (self .reward ._extract_number ("RESULT: 123" ), 123.0 )
64+ def test_call_answer_tags_with_whitespace (self ):
65+ """Test __call__ with answer tags containing whitespace."""
66+ self .assertEqual (self .reward ("prompt" , "<answer> 42 </answer>" , "42" ), 1.0 )
67+ self .assertEqual (
68+ self .reward ("prompt" , "<answer>\n 3.14\n </answer>" , "3.14" ), 1.0
69+ )
10670
107- def test_call_correct_answer (self ):
108- """Test __call__ with correct answers."""
109- self .assertEqual (self .reward ("prompt" , "The answer is 42" , "42" ), 1.0 )
110- self .assertEqual (self .reward ("prompt" , "#### 3.14" , "3.14" ), 1.0 )
111- self .assertEqual (self .reward ("prompt" , "Result: -5.5" , "-5.5" ), 1.0 )
71+ def test_call_answer_tags_with_complex_content (self ):
72+ """Test __call__ with complex content in answer tags."""
73+ response = """
74+ Let me solve this step by step:
75+ First, I calculate 2 + 3 = 5
76+ Then, I multiply by 4: 5 * 4 = 20
77+ Finally, I subtract 8: 20 - 8 = 12
78+ <answer>12</answer>
79+ """
80+ self .assertEqual (self .reward ("prompt" , response , "12" ), 1.0 )
11281
11382 def test_call_within_tolerance (self ):
11483 """Test __call__ with answers within tolerance."""
11584 # Default tolerance is 1e-6
116- self .assertEqual (self .reward ("prompt" , "42.0000001" , "42" ), 1.0 )
117- self .assertEqual (self .reward ("prompt" , "3.1400001" , "3.14" ), 1.0 )
118-
119- # Custom tolerance
120- self .assertEqual (self .custom_reward ("prompt" , "42.0001" , "42" ), 1.0 )
121- self .assertEqual (self .custom_reward ("prompt" , "3.141" , "3.14" ), 1.0 )
122-
123- def test_call_outside_tolerance (self ):
124- """Test __call__ with answers outside tolerance."""
125- self .assertEqual (self .reward ("prompt" , "42.1" , "42" ), 0.0 )
126- self .assertEqual (self .reward ("prompt" , "3.15" , "3.14" ), 0.0 )
127- self .assertEqual (self .custom_reward ("prompt" , "42.01" , "42" ), 0.0 )
128-
129- def test_call_invalid_target (self ):
130- """Test __call__ with invalid target values."""
13185 self .assertEqual (
132- self .reward ("prompt" , "42 " , "invalid " ), self . reward . partial_credit
86+ self .reward ("prompt" , "<answer>42.0000001</answer> " , "42 " ), 1.0
13387 )
134- self .assertEqual (self .reward ("prompt" , "42" , "" ), self .reward .partial_credit )
13588 self .assertEqual (
136- self .reward ("prompt" , "42 " , "not a number " ), self . reward . partial_credit
89+ self .reward ("prompt" , "<answer>3.1400001</answer> " , "3.14 " ), 1.0
13790 )
13891
139- def test_call_invalid_response (self ):
140- """Test __call__ with invalid response values."""
92+ # Custom tolerance
14193 self .assertEqual (
142- self .reward ("prompt" , "no number " , "42" ), self . reward . partial_credit
94+ self .custom_reward ("prompt" , "<answer>42.0001</answer> " , "42" ), 1.0
14395 )
144- self .assertEqual (self .reward ("prompt" , "" , "42" ), self .reward .partial_credit )
14596 self .assertEqual (
146- self .reward ("prompt" , "just text" , "42" ), self .reward .partial_credit
97+ self .custom_reward ("prompt" , "<answer>3.141</answer>" , "3.14" ), 1.0
98+ )
99+
100+ def test_call_outside_tolerance (self ):
101+ """Test __call__ with answers outside tolerance."""
102+ self .assertEqual (self .reward ("prompt" , "<answer>42.1</answer>" , "42" ), 0.0 )
103+ self .assertEqual (self .reward ("prompt" , "<answer>3.15</answer>" , "3.14" ), 0.0 )
104+ self .assertEqual (
105+ self .custom_reward ("prompt" , "<answer>42.01</answer>" , "42" ), 0.0
147106 )
148107
149- def test_call_both_invalid (self ):
150- """Test __call__ with both invalid target and response."""
108+ def test_call_partial_credit_target_in_response (self ):
109+ """Test __call__ with partial credit when target appears in response."""
110+ response = "The calculation shows 42 but I put <answer>43</answer>"
111+ self .assertEqual (self .reward ("prompt" , response , "42" ), 0.1 )
112+
113+ response = "Let me work through this: 42 + 1 = 43. <answer>43</answer>"
114+ self .assertEqual (self .reward ("prompt" , response , "42" ), 0.1 )
115+
116+ def test_call_partial_credit_custom_value (self ):
117+ """Test __call__ with custom partial credit value."""
118+ response = "The calculation shows 42 but I put <answer>43</answer>"
119+ self .assertEqual (self .custom_reward ("prompt" , response , "42" ), 0.2 )
120+
121+ def test_call_no_partial_credit_with_answer_tags (self ):
122+ """Test __call__ doesn't give partial credit if target is only in answer tags."""
123+ response = "Let me solve this. <answer>42</answer>"
124+ # Target 100 is not elsewhere in response, so no partial credit
125+ self .assertEqual (self .reward ("prompt" , response , "100" ), 0.0 )
126+
127+ def test_call_integer_target_formatting (self ):
128+ """Test __call__ with integer targets formatted correctly."""
129+ # Integer targets should be formatted without decimal point
130+ response = "I calculated and got 117 as the answer. <answer>118</answer>"
131+ self .assertEqual (self .reward ("prompt" , response , "117" ), 0.1 )
132+
133+ # Should work with 117.0 in target too
134+ self .assertEqual (self .reward ("prompt" , response , "117.0" ), 0.1 )
135+
136+ def test_call_float_target_formatting (self ):
137+ """Test __call__ with float targets."""
138+ response = "I calculated and got 3.14 as the answer. <answer>3.15</answer>"
139+ self .assertEqual (self .reward ("prompt" , response , "3.14" ), 0.1 )
140+
141+ def test_call_invalid_target (self ):
142+ """Test __call__ with invalid target values."""
143+ self .assertEqual (self .reward ("prompt" , "<answer>42</answer>" , "invalid" ), 0.0 )
144+ self .assertEqual (self .reward ("prompt" , "<answer>42</answer>" , "" ), 0.0 )
151145 self .assertEqual (
152- self .reward ("prompt" , "no number " , "invalid " ), self . reward . partial_credit
146+ self .reward ("prompt" , "<answer>42</answer> " , "not a number " ), 0.0
153147 )
154- self .assertEqual (self .reward ("prompt" , "" , "" ), self .reward .partial_credit )
155148
156- def test_call_custom_partial_credit (self ):
157- """Test __call__ uses custom partial credit value."""
158- self .assertEqual (self .custom_reward ("prompt" , "no number" , "42" ), 0.2 )
159- self .assertEqual (self .custom_reward ("prompt" , "42" , "invalid" ), 0.2 )
149+ def test_call_no_answer_tags (self ):
150+ """Test __call__ with response that has no answer tags."""
151+ # Should still check for partial credit
152+ self .assertEqual (self .reward ("prompt" , "The answer is 42" , "42" ), 0.1 )
153+ self .assertEqual (self .reward ("prompt" , "No matching number" , "42" ), 0.0 )
154+
155+ def test_call_invalid_answer_in_tags (self ):
156+ """Test __call__ with invalid answer in tags."""
157+ response = "<answer>not a number</answer> but 42 is correct"
158+ self .assertEqual (self .reward ("prompt" , response , "42" ), 0.1 )
160159
161160 def test_call_zero_values (self ):
162161 """Test __call__ with zero values."""
163- self .assertEqual (self .reward ("prompt" , "0 " , "0" ), 1.0 )
164- self .assertEqual (self .reward ("prompt" , "The answer is 0 " , "0. 0" ), 1.0 )
162+ self .assertEqual (self .reward ("prompt" , "<answer>0</answer> " , "0" ), 1.0 )
163+ self .assertEqual (self .reward ("prompt" , "< answer>0.0</answer> " , "0" ), 1.0 )
165164
166165 def test_call_negative_values (self ):
167166 """Test __call__ with negative values."""
168- self .assertEqual (self .reward ("prompt" , "-42" , "-42" ), 1.0 )
169- self .assertEqual (self .reward ("prompt" , "#### -3.14" , "-3.14" ), 1.0 )
170- self .assertEqual (self .reward ("prompt" , "-5" , "-4.9" ), 0.0 )
167+ self .assertEqual (self .reward ("prompt" , "<answer>-42</answer>" , "-42" ), 1.0 )
168+ self .assertEqual (self .reward ("prompt" , "<answer>-3.14</answer>" , "-3.14" ), 1.0 )
171169
172170 def test_call_large_numbers (self ):
173171 """Test __call__ with large numbers."""
174- self .assertEqual (self .reward ("prompt" , "1000000" , "1000000" ), 1.0 )
175- self .assertEqual (self .reward ("prompt" , "1e6" , "1000000" ), 1.0 )
176- self .assertEqual (self .reward ("prompt" , "1000001" , "1000000" ), 0.0 )
172+ self .assertEqual (
173+ self .reward ("prompt" , "<answer>1000000</answer>" , "1000000" ), 1.0
174+ )
175+ self .assertEqual (self .reward ("prompt" , "<answer>1e6</answer>" , "1000000" ), 1.0 )
177176
178177 def test_call_small_numbers (self ):
179178 """Test __call__ with very small numbers."""
180- self .assertEqual (self .reward ("prompt" , "0.000001" , "0.000001" ), 1.0 )
181- self .assertEqual (self .reward ("prompt" , "1e-6" , "0.000001" ), 1.0 )
182-
183- def test_call_complex_response_text (self ):
184- """Test __call__ with complex response text containing multiple elements."""
185- response = """
186- Let me solve this step by step:
187- First, I calculate 2 + 3 = 5
188- Then, I multiply by 4: 5 * 4 = 20
189- Finally, I subtract 8: 20 - 8 = 12
190- #### 12
191- """
192- self .assertEqual (self .reward ("prompt" , response , "12" ), 1.0 )
179+ self .assertEqual (
180+ self .reward ("prompt" , "<answer>0.000001</answer>" , "0.000001" ), 1.0
181+ )
182+ self .assertEqual (
183+ self .reward ("prompt" , "<answer>1e-6</answer>" , "0.000001" ), 1.0
184+ )
193185
194- def test_call_with_units_and_formatting (self ):
195- """Test __call__ with responses containing units and formatting ."""
196- self . assertEqual ( self . reward ( "prompt" , "The cost is $42.50" , "42.5" ), 1.0 )
197- self .assertEqual (self .reward ("prompt" , "Distance: 3.14 meters" , "3.14 " ), 1.0 )
198- self .assertEqual (self .reward ("prompt" , "Temperature is -5.5°C" , "-5.5 " ), 1.0 )
186+ def test_call_multiple_answer_tags (self ):
187+ """Test __call__ with multiple answer tags (should use first one) ."""
188+ response = "First answer: <answer>42</answer> Second: <answer>43</answer>"
189+ self .assertEqual (self .reward ("prompt" , response , "42 " ), 1.0 )
190+ self .assertEqual (self .reward ("prompt" , response , "43 " ), 0.1 )
199191
200192
201193if __name__ == "__main__" :
0 commit comments