|
1 | | -# Copyright 2025 Google LLC |
2 | | -# |
3 | | -# Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | -# you may not use this file except in compliance with the License. |
5 | | -# You may obtain a copy of the License at |
6 | | -# |
7 | | -# http://www.apache.org/licenses/LICENSE-2.0 |
8 | | -# |
9 | | -# Unless required by applicable law or agreed to in writing, software |
10 | | -# distributed under the License is distributed on an "AS IS" BASIS, |
11 | | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | -# See the License for the specific language governing permissions and |
13 | | -# limitations under the License. |
14 | | - |
15 | | -import dataclasses |
16 | | -import inspect |
17 | | -from typing import Any, List |
| 1 | +import os |
| 2 | +import unittest |
18 | 3 | from unittest import mock |
| 4 | + |
19 | 5 | from absl import logging |
20 | | -from absl.testing import absltest |
21 | | -from absl.testing import parameterized |
22 | 6 | import numpy as np |
23 | | -import numpy.testing as npt |
24 | 7 | from tunix.rl import algorithm_config as algo_config_lib |
25 | 8 | from tunix.rl import reward_manager |
26 | 9 |
|
27 | 10 |
|
28 | | -# --- Test Reward Functions --- |
29 | | -def len_reward( |
30 | | - prompts: List[str], completions: List[str], **kwargs: Any |
31 | | -) -> List[float]: |
32 | | - del prompts, kwargs # Unused |
33 | | - res = [float(len(c)) for c in completions] |
34 | | - return res |
35 | | - |
36 | | - |
37 | | -len_reward.__name__ = "len_reward" |
38 | | - |
39 | | - |
40 | | -def prompt_len_reward( |
41 | | - prompts: List[str], |
42 | | - completions: List[str], |
43 | | - custom_param: float = 1.0, |
44 | | - **kwargs: Any, |
45 | | -) -> List[float]: |
46 | | - del completions, kwargs # Unused |
47 | | - res = [custom_param * len(p) for p in prompts] |
48 | | - return res |
49 | | - |
50 | | - |
51 | | -prompt_len_reward.__name__ = "prompt_len_reward" |
52 | | - |
53 | | - |
54 | | -def nan_reward( |
55 | | - prompts: List[str], completions: List[str], **kwargs: Any |
56 | | -) -> List[float]: |
57 | | - del completions, kwargs # Unused |
58 | | - return [np.nan] * len(prompts) |
def reward_fn1(prompts, completions, **kwargs):
  """Test reward fn: returns a constant reward of 1.0 per prompt."""
  del completions, kwargs  # Unused.
  return [1.0 for _ in prompts]
59 | 14 |
|
60 | 15 |
|
61 | | -nan_reward.__name__ = "nan_reward" |
def reward_fn2(prompts, completions, **kwargs):
  """Test reward fn: returns a constant reward of 2.0 per prompt."""
  del completions, kwargs  # Unused.
  return [2.0 for _ in prompts]
62 | 19 |
|
63 | 20 |
|
64 | | -@dataclasses.dataclass(slots=True, kw_only=True) |
65 | | -class TestAlgoConfig(algo_config_lib.AlgorithmConfig): |
66 | | - """Test Algorithm Config.""" |
class RewardManagerTest(unittest.TestCase):
  """Tests for SequenceRewardManager metric preparation and debug logging."""

  @mock.patch("tunix.rl.reward_manager.asdict")
  def test_prepare_log_metrics_and_log_one_example(self, mock_asdict):
    """Checks returned log_metrics and the per-example debug log output.

    `asdict` is patched so the MagicMock algo_config serializes to an empty
    dict (a real MagicMock is not a dataclass and would make asdict fail —
    presumably the manager calls asdict(algo_config) internally; confirm
    against reward_manager).
    """
    mock_asdict.return_value = {}
    # Two reward fns: constant 1.0 and constant 2.0 per prompt, so the
    # expected per-example sum is 3.0, mean 1.5, min 1.0, max 2.0.
    reward_fns = [reward_fn1, reward_fn2]
    algo_config = mock.MagicMock(spec=algo_config_lib.AlgorithmConfig)
    reward_manager_instance = reward_manager.SequenceRewardManager(
        reward_fns=reward_fns, algo_config=algo_config
    )
    prompts = ["prompt1", "prompt2"]
    completions = ["completion1", "completion2"]

    # TUNIX_DEBUG_REWARDS=1 is expected to trigger the one-example debug
    # logging path; absl's logging.info is patched to capture those calls.
    with (
        mock.patch.dict(os.environ, {"TUNIX_DEBUG_REWARDS": "1"}),
        mock.patch("absl.logging.info") as mock_log_info,
    ):
      rewards_info = reward_manager_instance(prompts, completions)
      log_metrics = rewards_info["log_metrics"]

      # Returned metrics: raw inputs plus per-fn and aggregate reward rows.
      # Each metric value is indexed [0] — TODO confirm: looks like values
      # are (data, ...) tuples with the array in slot 0.
      self.assertIn("prompts", log_metrics)
      self.assertIn("completions", log_metrics)
      self.assertIn("rewards/sum", log_metrics)
      np.testing.assert_allclose(log_metrics["rewards/sum"][0], [3.0, 3.0])
      self.assertIn("rewards/reward_fn1", log_metrics)
      np.testing.assert_allclose(
          log_metrics["rewards/reward_fn1"][0], [1.0, 1.0]
      )
      self.assertIn("rewards/reward_fn2", log_metrics)
      np.testing.assert_allclose(
          log_metrics["rewards/reward_fn2"][0], [2.0, 2.0]
      )

      # Check that _log_one_example was called and logged correctly:
      # only the FIRST example (prompt1/completion1) is logged, with each
      # metric rendered as a string via the "%s:\t%s" lazy format.
      mock_log_info.assert_any_call("======= example rewards =======")
      mock_log_info.assert_any_call("%s:\t%s", "prompts", "prompt1")
      mock_log_info.assert_any_call("%s:\t%s", "completions", "completion1")
      mock_log_info.assert_any_call("%s:\t%s", "rewards/sum", "3.0")
      mock_log_info.assert_any_call("%s:\t%s", "rewards/mean", "1.5")
      mock_log_info.assert_any_call("%s:\t%s", "rewards/min", "1.0")
      mock_log_info.assert_any_call("%s:\t%s", "rewards/max", "2.0")
      mock_log_info.assert_any_call("%s:\t%s", "rewards/reward_fn1", "1.0")
      mock_log_info.assert_any_call("%s:\t%s", "rewards/reward_fn2", "2.0")
      mock_log_info.assert_any_call("=======================")
240 | 65 |
|
241 | 66 |
|
if __name__ == "__main__":
  # Allow running this test module directly with `python <file>`.
  unittest.main()
0 commit comments