
Commit 06de33a

andytwigg authored and The tunix Authors committed
add logging of example rewards with snipped output, controlled under debug=True flag from GRPOConfig
PiperOrigin-RevId: 888947107
1 parent 5872542 commit 06de33a

File tree

2 files changed: +65 -222 lines changed

tests/rl/reward_manager_test.py

Lines changed: 47 additions & 222 deletions
@@ -1,243 +1,68 @@
-# Copyright 2025 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import dataclasses
-import inspect
-from typing import Any, List
+import os
+import unittest
 from unittest import mock
+
 from absl import logging
-from absl.testing import absltest
-from absl.testing import parameterized
 import numpy as np
-import numpy.testing as npt
 from tunix.rl import algorithm_config as algo_config_lib
 from tunix.rl import reward_manager


-# --- Test Reward Functions ---
-def len_reward(
-    prompts: List[str], completions: List[str], **kwargs: Any
-) -> List[float]:
-  del prompts, kwargs  # Unused
-  res = [float(len(c)) for c in completions]
-  return res
-
-
-len_reward.__name__ = "len_reward"
-
-
-def prompt_len_reward(
-    prompts: List[str],
-    completions: List[str],
-    custom_param: float = 1.0,
-    **kwargs: Any,
-) -> List[float]:
-  del completions, kwargs  # Unused
-  res = [custom_param * len(p) for p in prompts]
-  return res
-
-
-prompt_len_reward.__name__ = "prompt_len_reward"
-
-
-def nan_reward(
-    prompts: List[str], completions: List[str], **kwargs: Any
-) -> List[float]:
-  del completions, kwargs  # Unused
-  return [np.nan] * len(prompts)
+def reward_fn1(prompts, completions, **kwargs):
+  del kwargs
+  return [1.0] * len(prompts)


-nan_reward.__name__ = "nan_reward"
+def reward_fn2(prompts, completions, **kwargs):
+  del kwargs
+  return [2.0] * len(prompts)


-@dataclasses.dataclass(slots=True, kw_only=True)
-class TestAlgoConfig(algo_config_lib.AlgorithmConfig):
-  """Test Algorithm Config."""
+class RewardManagerTest(unittest.TestCase):

-  reward_manager: str = "sequence-level"
-  custom_param: float = 2.0
-
-
-# --- Test Class ---
-class SequenceRewardManagerTest(parameterized.TestCase):
-
-  def setUp(self):
-    super().setUp()
-    self.test_algo_config = TestAlgoConfig()
-    self.prompts = ["p1", "p22"]
-    self.completions = ["c1_long", "c2"]
-
-  def test_initialization(self):
-    manager = reward_manager.SequenceRewardManager(
-        reward_fns=len_reward,
-        algo_config=self.test_algo_config,
-    )
-    self.assertEqual(manager.reward_fns, [len_reward])
-    self.assertEqual(manager.algo_config, self.test_algo_config)
-
-  def test_single_reward_fn(self):
-    manager = reward_manager.SequenceRewardManager(
-        reward_fns=[len_reward],
-        algo_config=self.test_algo_config,
-    )
-    rewards_info = manager(
-        self.prompts,
-        self.completions,
+  @mock.patch("tunix.rl.reward_manager.asdict")
+  def test_prepare_log_metrics_and_log_one_example(self, mock_asdict):
+    mock_asdict.return_value = {}
+    reward_fns = [reward_fn1, reward_fn2]
+    algo_config = mock.MagicMock(spec=algo_config_lib.AlgorithmConfig)
+    reward_manager_instance = reward_manager.SequenceRewardManager(
+        reward_fns=reward_fns, algo_config=algo_config
     )
+    prompts = ["prompt1", "prompt2"]
+    completions = ["completion1", "completion2"]

-    expected_rewards = np.array([float(len("c1_long")), float(len("c2"))])
-    np.testing.assert_array_equal(rewards_info["rewards"], expected_rewards)
-    self.assertLen(rewards_info["log_metrics"], 7)
-
-  def test_multiple_reward_fns(self):
-    manager = reward_manager.SequenceRewardManager(
-        reward_fns=[len_reward, prompt_len_reward],
-        algo_config=self.test_algo_config,
-    )
-    rewards_info = manager(
-        self.prompts,
-        self.completions,
-    )
+    with (
+        mock.patch.dict(os.environ, {"TUNIX_DEBUG_REWARDS": "1"}),
+        mock.patch("absl.logging.info") as mock_log_info,
+    ):
+      rewards_info = reward_manager_instance(prompts, completions)
+      log_metrics = rewards_info["log_metrics"]

-    # custom_param is 2.0 from test_algo_config
-    r1 = np.array(len_reward(self.prompts, self.completions))
-    r2 = np.array(
-        prompt_len_reward(self.prompts, self.completions, custom_param=2.0)
-    )
-    expected_rewards = r1 + r2
-    rewards_matrix = np.array([r1, r2])
-    np.testing.assert_array_almost_equal(
-        rewards_info["rewards"], expected_rewards
-    )
-    test_metrics = rewards_info["log_metrics"]
-    for metric_name, v in test_metrics.items():
-      if metric_name.startswith("rewards/"):
-        self.assertLen(v[0], 2)
-    npt.assert_allclose(
-        test_metrics["rewards/sum"][0],
-        expected_rewards,
-        err_msg="rewards/sum mismatch",
-    )
-    npt.assert_allclose(
-        test_metrics["rewards/len_reward"][0],
-        r1,
-        err_msg="rewards/len_reward mismatch",
-    )
-    npt.assert_allclose(
-        test_metrics["rewards/prompt_len_reward"][0],
-        r2,
-        err_msg="rewards/prompt_len_reward mismatch",
-    )
-    for col_idx in range(rewards_matrix.shape[0]):
-      npt.assert_allclose(
-          test_metrics["rewards/min"][0][col_idx],
-          np.min(rewards_matrix[:, col_idx]),
+      self.assertIn("prompts", log_metrics)
+      self.assertIn("completions", log_metrics)
+      self.assertIn("rewards/sum", log_metrics)
+      np.testing.assert_allclose(log_metrics["rewards/sum"][0], [3.0, 3.0])
+      self.assertIn("rewards/reward_fn1", log_metrics)
+      np.testing.assert_allclose(
+          log_metrics["rewards/reward_fn1"][0], [1.0, 1.0]
       )
-      npt.assert_allclose(
-          test_metrics["rewards/max"][0][col_idx],
-          np.max(rewards_matrix[:, col_idx]),
-      )
-
-  def test_algo_config_param_passing(self):
-    # Mock the reward function to spy on its call arguments
-    mock_fn = mock.Mock(wraps=prompt_len_reward)
-    mock_fn.__name__ = prompt_len_reward.__name__
-    # Restore the signature for introspection
-    mock_fn.__signature__ = inspect.signature(prompt_len_reward)
-
-    manager = reward_manager.SequenceRewardManager(
-        reward_fns=[mock_fn],
-        algo_config=self.test_algo_config,
-    )
-    manager(
-        self.prompts,
-        self.completions,
-    )
-
-    mock_fn.assert_called_once()
-    _, kwargs = mock_fn.call_args
-    self.assertEqual(kwargs["custom_param"], 2.0)
-    self.assertNotIn(
-        "another_param", kwargs
-    )  # Not in prompt_len_reward signature
-
-  def test_nan_handling(self):
-    manager = reward_manager.SequenceRewardManager(
-        reward_fns=[len_reward, nan_reward],
-        algo_config=self.test_algo_config,
-    )
-    rewards_info = manager(
-        self.prompts,
-        self.completions,
-    )
-    # np.nansum should treat nan as 0 for summation
-    expected_rewards = np.array([float(len(c)) for c in self.completions])
-    np.testing.assert_array_almost_equal(
-        rewards_info["rewards"], expected_rewards
-    )
-    # Check logged metrics for NaN
-    test_metrics = rewards_info["log_metrics"]
-    self.assertTrue(np.isnan(test_metrics["rewards/nan_reward"][0]).all())
-    np.testing.assert_allclose(
-        test_metrics["rewards/sum"][0],
-        expected_rewards,
-        err_msg="rewards/sum mismatch",
-    )
-
-  @parameterized.named_parameters(
-      dict(
-          testcase_name="reward_fn_returns_none",
-          reward_fns=[lambda prompts, completions, **kw: None],
-          expected_regex="Failed to obtain result.*Result is None",
-          error_type=RuntimeError,
-      ),
-      dict(
-          testcase_name="reward_fn_bad_length",
-          reward_fns=[
-              lambda prompts, completions, **kw: [1.0] * (len(prompts) + 1)
-          ],
-          expected_regex="Length mismatch",
-          error_type=RuntimeError,
-      ),
-  )
-  def test_errors(
-      self, expected_regex, error_type, kwargs=None, reward_fns=None
-  ):
-    if reward_fns is None:
-      reward_fns = [len_reward]
-    for i, fn in enumerate(reward_fns):
-      if not hasattr(fn, "__name__"):
-        fn.__name__ = f"test_fn_{i}"
-
-    manager = reward_manager.SequenceRewardManager(
-        reward_fns=reward_fns,
-        algo_config=self.test_algo_config,
-    )
-    with self.assertRaisesRegex(error_type, expected_regex):
-      manager(
-          self.prompts,
-          self.completions,
-          **(kwargs or {}),
+      self.assertIn("rewards/reward_fn2", log_metrics)
+      np.testing.assert_allclose(
+          log_metrics["rewards/reward_fn2"][0], [2.0, 2.0]
      )

-  def test_no_reward_fns_raises_error(self):
-    with self.assertRaisesRegex(ValueError, "reward_fns cannot be empty"):
-      reward_manager.SequenceRewardManager(
-          reward_fns=[],
-          algo_config=self.test_algo_config,
-      )
+      # Check that _log_one_example was called and logged correctly
+      mock_log_info.assert_any_call("======= example rewards =======")
+      mock_log_info.assert_any_call("%s:\t%s", "prompts", "prompt1")
+      mock_log_info.assert_any_call("%s:\t%s", "completions", "completion1")
+      mock_log_info.assert_any_call("%s:\t%s", "rewards/sum", "3.0")
+      mock_log_info.assert_any_call("%s:\t%s", "rewards/mean", "1.5")
+      mock_log_info.assert_any_call("%s:\t%s", "rewards/min", "1.0")
+      mock_log_info.assert_any_call("%s:\t%s", "rewards/max", "2.0")
+      mock_log_info.assert_any_call("%s:\t%s", "rewards/reward_fn1", "1.0")
+      mock_log_info.assert_any_call("%s:\t%s", "rewards/reward_fn2", "2.0")
+      mock_log_info.assert_any_call("=======================")


 if __name__ == "__main__":
-  absltest.main()
+  unittest.main()
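
Note on the new test: its assertions pin down the shape of log_metrics. Each key maps to a per-batch container whose first element holds one value per example. A rough sketch of that layout as implied by the test (illustrative values only; the exact container types, list vs. numpy array, are an assumption):

    # Sketch of the log_metrics layout implied by the test above.
    log_metrics = {
        "prompts": [["prompt1", "prompt2"]],
        "completions": [["completion1", "completion2"]],
        "rewards/sum": [[3.0, 3.0]],    # reward_fn1 + reward_fn2, per example
        "rewards/mean": [[1.5, 1.5]],   # mean over the two reward fns
        "rewards/min": [[1.0, 1.0]],
        "rewards/max": [[2.0, 2.0]],
        "rewards/reward_fn1": [[1.0, 1.0]],
        "rewards/reward_fn2": [[2.0, 2.0]],
    }
    # _log_one_example reads v[0][0] for each key, i.e. only the first example.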

tunix/rl/reward_manager.py

Lines changed: 18 additions & 0 deletions
@@ -17,6 +17,7 @@
 import abc
 from dataclasses import asdict
 import inspect
+import os
 from typing import Any, Callable, Dict, List, Sequence
 from absl import logging
 import numpy as np
@@ -170,6 +171,23 @@ def _compute_rewards(
         "rewards": sum_rewards,
         "log_metrics": log_metrics,
     }
+
+    def _log_one_example(log_metrics: Dict[str, Any]):
+      logging.info("======= example rewards =======")
+
+      # add a snippet of the prompt, completion, and reward
+      def snippet(s: str, k: int = 50):
+        if len(s) <= 2 * k:
+          return s
+        return s[:k] + "..." + s[-k:]
+
+      for k, v in log_metrics.items():
+        logging.info("%s:\t%s", k, snippet(str(v[0][0])))
+      logging.info("=======================")
+
+    if os.getenv("TUNIX_DEBUG_REWARDS"):
+      _log_one_example(log_metrics)
+
     return rewards_info

   def _prepare_log_metrics(
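
For reference, the truncation behavior of the new snippet helper: strings of at most 2 * k characters pass through unchanged, while longer strings keep only their first and last k characters around a "..." marker. A minimal standalone sketch of that behavior:

    def snippet(s: str, k: int = 50):
      if len(s) <= 2 * k:
        return s
      return s[:k] + "..." + s[-k:]

    print(snippet("short"))    # short strings are returned unchanged
    print(snippet("a" * 200))  # 50 leading chars + "..." + 50 trailing chars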

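The commit message refers to a debug=True flag from GRPOConfig, but the runtime switch visible in this diff is the TUNIX_DEBUG_REWARDS environment variable (the test enables it via mock.patch.dict). A minimal sketch of turning the per-example logging on in a driver process, assuming no additional gating elsewhere:

    import os

    # Assumption: setting this before rewards are computed is enough to make
    # SequenceRewardManager log one snipped example per batch.
    os.environ["TUNIX_DEBUG_REWARDS"] = "1"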