Skip to content

Commit 4126f61

Browse files
casteryhJenniferWang
authored andcommitted
Add LanguageReward for training models to think in target language (meta-pytorch#515)
Co-authored-by: Jiyue Wang <[email protected]>
1 parent c0540f7 commit 4126f61

File tree

4 files changed

+284
-3
lines changed

4 files changed

+284
-3
lines changed

apps/grpo/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,9 @@ async def setup(self):
238238

239239
def gsm8k_transform(sample):
240240
system_prompt = """You are a helpful AI assistant that solves math problems.
241+
241242
Please show your reasoning inside <思考></思考> tags, then provide your final numerical answer inside <answer></answer> tags.
243+
242244
Example:
243245
Question: What is 12 + 5?
244246
<思考>12と5を足します。12 + 5 = 17です。</思考>

apps/grpo/qwen3_8b.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
# >>> python -m apps.grpo.main --config apps/grpo/qwen3_8b.yaml
33

44
# Global configuration
5-
group_size: 8
6-
local_batch_size: 12 # per-device batch size
5+
group_size: 16
6+
local_batch_size: 4 # per-device batch size
77
max_req_tokens: 1024
8-
max_res_tokens: 1024
8+
max_res_tokens: 2048
99
model: "Qwen/Qwen3-8B"
1010
off_by_n: 1 # Off by one by default
1111

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import sys
8+
import unittest
9+
from unittest.mock import patch
10+
11+
12+
class TestLanguageReward(unittest.TestCase):
13+
def setUp(self):
14+
"""Set up test fixtures before each test method."""
15+
# Import after patching to avoid ImportError
16+
from forge.data.rewards import LanguageReward
17+
18+
self.LanguageReward = LanguageReward
19+
self.reward_en = LanguageReward(target_language="en")
20+
self.reward_ja = LanguageReward(target_language="ja")
21+
22+
def test_init_default_values(self):
23+
"""Test LanguageReward initialization with default values."""
24+
reward = self.LanguageReward()
25+
self.assertEqual(reward.target_language, "ja")
26+
self.assertEqual(reward.match_reward, 1.0)
27+
self.assertEqual(reward.no_match_reward, 0.0)
28+
29+
def test_init_custom_values(self):
30+
"""Test LanguageReward initialization with custom values."""
31+
reward = self.LanguageReward(
32+
target_language="ja",
33+
match_reward=0.9,
34+
no_match_reward=0.1,
35+
)
36+
self.assertEqual(reward.target_language, "ja")
37+
self.assertEqual(reward.match_reward, 0.9)
38+
self.assertEqual(reward.no_match_reward, 0.1)
39+
40+
def test_init_missing_langid(self):
41+
"""Test LanguageReward initialization without langid installed."""
42+
# Remove langid from modules if it exists
43+
langid_module = sys.modules.get("langid")
44+
if "langid" in sys.modules:
45+
del sys.modules["langid"]
46+
47+
with patch.dict("sys.modules", {"langid": None}):
48+
with self.assertRaises(ImportError) as context:
49+
# Re-import to trigger the ImportError
50+
import importlib
51+
52+
import forge.data.rewards
53+
54+
importlib.reload(forge.data.rewards)
55+
forge.data.rewards.LanguageReward()
56+
57+
self.assertIn("langid is required", str(context.exception))
58+
self.assertIn("pip install langid", str(context.exception))
59+
60+
# Restore langid module if it existed
61+
if langid_module is not None:
62+
sys.modules["langid"] = langid_module
63+
64+
def test_regex_pattern(self):
65+
"""Test that regex pattern is compiled correctly."""
66+
reward = self.LanguageReward()
67+
self.assertIsNotNone(reward._THINK_BLOCK_RE)
68+
69+
def test_call_with_english_thinking(self):
70+
"""Test __call__ with English text in thinking blocks."""
71+
response = "<思考>This is English reasoning about math problems.</思考>"
72+
result = self.reward_en("prompt", response)
73+
self.assertEqual(result, 1.0)
74+
75+
def test_call_with_japanese_thinking(self):
76+
"""Test __call__ with Japanese text in thinking blocks."""
77+
response = "<思考>これは日本語で考えています。数学の問題を解きます。</思考>"
78+
result = self.reward_ja("prompt", response)
79+
self.assertEqual(result, 1.0)
80+
81+
# English reward should give no_match_reward for Japanese text
82+
result = self.reward_en("prompt", response)
83+
self.assertEqual(result, 0.0)
84+
85+
def test_call_with_chinese_thinking(self):
86+
"""Test __call__ with Chinese text in thinking blocks."""
87+
response = "<思考>这是中文思考。我们需要解决这个数学问题。</思考>"
88+
reward_zh = self.LanguageReward(target_language="zh")
89+
result = reward_zh("prompt", response)
90+
# langid should detect this as Chinese (zh)
91+
self.assertEqual(result, 1.0)
92+
93+
def test_call_with_spanish_thinking(self):
94+
"""Test __call__ with Spanish text in thinking blocks."""
95+
response = "<思考>Este es un razonamiento en español sobre problemas matemáticos.</思考>"
96+
reward_es = self.LanguageReward(target_language="es")
97+
result = reward_es("prompt", response)
98+
# langid should detect this as Spanish (es)
99+
self.assertEqual(result, 1.0)
100+
101+
def test_call_language_mismatch(self):
102+
"""Test __call__ when detected language doesn't match target."""
103+
# Japanese reward with English text
104+
response = "<思考>This is English reasoning.</思考>"
105+
result = self.reward_ja("prompt", response)
106+
self.assertEqual(result, 0.0)
107+
108+
# English reward with Japanese text
109+
response = "<思考>これは日本語です。</思考>"
110+
result = self.reward_en("prompt", response)
111+
self.assertEqual(result, 0.0)
112+
113+
def test_call_with_no_thinking_tags(self):
114+
"""Test __call__ with response containing no thinking tags but correct language."""
115+
result = self.reward_en(
116+
"prompt", "This is just a regular response without any thinking tags."
117+
)
118+
# No thinking blocks -> detect whole response, English detected -> match_reward
119+
self.assertEqual(result, 1.0)
120+
121+
def test_call_with_no_thinking_tags_wrong_language(self):
122+
"""Test __call__ with response containing no thinking tags and wrong language."""
123+
result = self.reward_en("prompt", "これは日本語の応答です。タグはありません。")
124+
# No thinking blocks -> detect whole response, Japanese detected -> no_match_reward
125+
self.assertEqual(result, 0.0)
126+
127+
def test_call_with_empty_thinking_block(self):
128+
"""Test __call__ with empty thinking block."""
129+
result = self.reward_en("prompt", "<思考></思考>")
130+
self.assertEqual(result, 0.0)
131+
132+
def test_call_with_whitespace_only_thinking_block(self):
133+
"""Test __call__ with whitespace-only thinking block."""
134+
result = self.reward_en("prompt", "<思考> \n \t </思考>")
135+
self.assertEqual(result, 0.0)
136+
137+
def test_call_with_proper_tags(self):
138+
"""Test __call__ with properly formatted thinking tags."""
139+
response = "<思考>This is English reasoning.</思考>"
140+
result = self.reward_en("prompt", response)
141+
self.assertEqual(result, 1.0)
142+
143+
# Japanese content should also work
144+
response = "<思考>これは日本語です。</思考>"
145+
result = self.reward_ja("prompt", response)
146+
self.assertEqual(result, 1.0)
147+
148+
def test_call_multiple_thinking_blocks(self):
149+
"""Test __call__ with multiple thinking blocks - detects whole response language."""
150+
response = """
151+
<思考>First thought in English.</思考>
152+
Some text in between.
153+
<思考>Second thought also in English.</思考>
154+
"""
155+
result = self.reward_en("prompt", response)
156+
# Multiple blocks -> detect whole response, English detected -> match_reward
157+
self.assertEqual(result, 1.0)
158+
159+
def test_call_multiline_thinking_block(self):
160+
"""Test __call__ with multiline thinking blocks."""
161+
response = """<思考>
162+
This is a multiline
163+
thinking block with
164+
lots of English content
165+
about solving problems
166+
</思考>"""
167+
result = self.reward_en("prompt", response)
168+
self.assertEqual(result, 1.0)
169+
170+
def test_call_empty_response(self):
171+
"""Test __call__ with empty response."""
172+
result = self.reward_en("prompt", "")
173+
self.assertEqual(result, 0.0)
174+
175+
def test_call_none_response(self):
176+
"""Test __call__ with None response."""
177+
result = self.reward_en("prompt", None)
178+
self.assertEqual(result, 0.0)
179+
180+
def test_call_custom_reward_values(self):
181+
"""Test __call__ with custom reward values."""
182+
response_ja_single = "<思考>これは日本語です。</思考>"
183+
response_ja_multiple = "<思考>最初の考え。</思考><思考>次の考え。</思考>"
184+
response_ja_no_tags = "これはタグなしの日本語です。"
185+
response_en = "<思考>This is English.</思考>"
186+
response_none = ""
187+
188+
custom_reward = self.LanguageReward(
189+
target_language="ja",
190+
match_reward=0.9,
191+
no_match_reward=0.1,
192+
)
193+
# Test custom match reward (single block, correct language)
194+
self.assertEqual(custom_reward("prompt", response_ja_single), 0.9)
195+
# Test custom match reward (multiple blocks -> whole response, correct language)
196+
self.assertEqual(custom_reward("prompt", response_ja_multiple), 0.9)
197+
# Test custom match reward (no blocks -> whole response, correct language)
198+
self.assertEqual(custom_reward("prompt", response_ja_no_tags), 0.9)
199+
# Test custom no_match reward (wrong language)
200+
self.assertEqual(custom_reward("prompt", response_en), 0.1)
201+
# Test empty response
202+
self.assertEqual(custom_reward("prompt", response_none), 0.1)
203+
204+
def test_call_with_special_characters(self):
205+
"""Test __call__ with special characters in thinking blocks."""
206+
response = (
207+
"<思考>English with special chars: @#$%^&*()_+-=[]{}|;':\",./<>?`~</思考>"
208+
)
209+
result = self.reward_en("prompt", response)
210+
self.assertEqual(result, 1.0)
211+
212+
def test_call_with_mixed_content_outside_tags(self):
213+
"""Test __call__ with mixed language content outside thinking tags."""
214+
# Content outside think tags should be ignored
215+
response = """
216+
これは日本語のテキストです。
217+
<思考>But this is English reasoning inside the tags.</思考>
218+
もっと日本語のテキスト。
219+
"""
220+
result = self.reward_en("prompt", response)
221+
# Should detect English from thinking block only
222+
self.assertEqual(result, 1.0)
223+
224+
def test_call_with_numbers_and_symbols(self):
225+
"""Test __call__ with thinking blocks containing mostly numbers."""
226+
response = "<思考>Calculate: 2 + 2 = 4, then 4 * 3 = 12</思考>"
227+
result = self.reward_en("prompt", response)
228+
# Should still detect as English due to words like "Calculate" and "then"
229+
self.assertEqual(result, 1.0)
230+
231+
def test_call_with_code_in_thinking(self):
232+
"""Test __call__ with code snippets in thinking blocks."""
233+
response = """<思考>
234+
Let me write some Python code to solve this:
235+
def calculate(x):
236+
return x * 2
237+
The function doubles the input value.
238+
</思考>"""
239+
result = self.reward_en("prompt", response)
240+
# Should detect as English due to surrounding text
241+
self.assertEqual(result, 1.0)
242+
243+
def test_different_language_codes(self):
244+
"""Test __call__ with various ISO 639-1 language codes."""
245+
# Test a few common languages
246+
languages = {
247+
"fr": "Ceci est un texte en français avec beaucoup de contenu.",
248+
"de": "Dies ist ein deutscher Text mit viel Inhalt.",
249+
"it": "Questo è un testo italiano con molto contenuto.",
250+
"pt": "Este é um texto em português com muito conteúdo.",
251+
}
252+
253+
for lang_code, text in languages.items():
254+
reward = self.LanguageReward(target_language=lang_code)
255+
response = f"<思考>{text}</思考>"
256+
result = reward("prompt", response)
257+
# langid should detect these correctly
258+
self.assertEqual(
259+
result,
260+
1.0,
261+
f"Failed to detect {lang_code} language: '{text[:50]}...'",
262+
)
263+
264+
265+
if __name__ == "__main__":
266+
unittest.main()

tests/unit_tests/rl/test_thinking_reward.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,19 @@ def test_call_very_long_thinking_block(self):
203203
result = self.reward("prompt", f"<think>{long_content}</think>")
204204
self.assertEqual(result, 1.0)
205205

206+
def test_custom_tag(self):
207+
"""Test that ThinkingReward uses the custom tag passed in."""
208+
# Create reward with custom Japanese tag
209+
custom_tag_reward = ThinkingReward(tag="思考")
210+
211+
# Response with custom tag should get full reward
212+
result = custom_tag_reward("prompt", "<思考>This is my reasoning</思考>")
213+
self.assertEqual(result, 1.0)
214+
215+
# Response with default "think" tag should get no reward
216+
result = custom_tag_reward("prompt", "<think>This is my reasoning</think>")
217+
self.assertEqual(result, 0.0)
218+
206219

207220
if __name__ == "__main__":
208221
unittest.main()

0 commit comments

Comments
 (0)