Skip to content

Commit be5552d

Browse files
committed
Add fallback reward for correct language without thinking blocks
- Add fallback_reward parameter (default 0.2)
- If no <think> blocks found, check if response text is in target language
- Reward structure:
  * full_reward (1.0): Single block + correct language
  * partial_reward (0.5): Multiple blocks + correct language
  * fallback_reward (0.2): No blocks + correct language in response text
  * no_match_reward (0.0): Wrong language
- Update all tests to reflect new behavior (29 tests passing)
1 parent c47a47e commit be5552d

File tree

2 files changed

+86
-27
lines changed

2 files changed

+86
-27
lines changed

src/forge/data/rewards.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,10 @@ class LanguageReward:
9090
9191
Args:
9292
target_language: ISO 639-1 language code (e.g., 'en', 'ja', 'zh', 'es')
93-
full_reward: Reward when detected language matches target
94-
no_match_reward: Reward when detected language doesn't match target
93+
full_reward: Reward when language matches and format is correct (single block)
94+
partial_reward: Reward when language matches but format is wrong (multiple blocks)
95+
fallback_reward: Reward when no valid blocks but response text is in target language
96+
no_match_reward: Reward when language doesn't match
9597
9698
Note: Requires langid to be installed. Install with: pip install langid
9799
"""
@@ -100,10 +102,14 @@ def __init__(
100102
self,
101103
target_language: str = "en",
102104
full_reward: float = 1.0,
105+
partial_reward: float = 0.5,
106+
fallback_reward: float = 0.2,
103107
no_match_reward: float = 0.0,
104108
):
105109
self.target_language = target_language
106110
self.full_reward = full_reward
111+
self.partial_reward = partial_reward
112+
self.fallback_reward = fallback_reward
107113
self.no_match_reward = no_match_reward
108114
self._THINK_BLOCK_RE = re.compile(
109115
r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL
@@ -129,21 +135,38 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo
129135
target: Optional target string (unused but kept for signature consistency)
130136
131137
Returns:
132-
full_reward if detected language matches target_language and format is correct,
133-
no_match_reward otherwise (including when format is wrong or no thinking block)
138+
full_reward if language matches and exactly one thinking block is found,
139+
partial_reward if language matches but multiple thinking blocks found,
140+
fallback_reward if no valid blocks but response text is in target language,
141+
no_match_reward otherwise (wrong language)
134142
"""
135143
if not response:
136144
return self.no_match_reward
137145

138146
# Extract all thinking blocks
139147
matches = self._THINK_BLOCK_RE.findall(response)
140148

141-
# Return 0 reward if format is wrong (0 or multiple thinking blocks)
142-
if len(matches) != 1:
149+
# If no thinking blocks found, check if response text is in target language
150+
if len(matches) == 0:
151+
# Remove any partial tags that might exist
152+
response_text = re.sub(
153+
r"<\s*/?\s*think\s*>", "", response, flags=re.IGNORECASE
154+
).strip()
155+
156+
if not response_text:
157+
return self.no_match_reward
158+
159+
# Detect language of general response
160+
detected_lang, confidence = self._langid.classify(response_text)
161+
162+
# Give fallback reward if response is in target language
163+
if detected_lang == self.target_language:
164+
return self.fallback_reward
165+
143166
return self.no_match_reward
144167

145-
# Get the single thinking block content
146-
thinking_content = matches[0]
168+
# Concatenate all thinking blocks for language detection
169+
thinking_content = " ".join(matches)
147170

148171
# Remove extra whitespace
149172
thinking_content = re.sub(r"\s+", " ", thinking_content).strip()
@@ -154,8 +177,13 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo
154177
# Detect language using langid
155178
detected_lang, confidence = self._langid.classify(thinking_content)
156179

157-
# Return full reward if language matches target
180+
# Check if language matches target
158181
if detected_lang == self.target_language:
159-
return self.full_reward
182+
# Full reward for correct format (single block)
183+
if len(matches) == 1:
184+
return self.full_reward
185+
# Partial reward for wrong format (multiple blocks) but correct language
186+
else:
187+
return self.partial_reward
160188

161189
return self.no_match_reward

tests/unit_tests/rl/test_language_reward.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,35 @@ def setUp(self):
1919
self.reward_en = LanguageReward(target_language="en")
2020
self.reward_ja = LanguageReward(target_language="ja")
2121
self.custom_reward = LanguageReward(
22-
target_language="ja", full_reward=0.9, no_match_reward=0.1
22+
target_language="ja",
23+
full_reward=0.9,
24+
partial_reward=0.6,
25+
fallback_reward=0.3,
26+
no_match_reward=0.1,
2327
)
2428

2529
def test_init_default_values(self):
2630
"""Test LanguageReward initialization with default values."""
2731
reward = self.LanguageReward()
2832
self.assertEqual(reward.target_language, "en")
2933
self.assertEqual(reward.full_reward, 1.0)
34+
self.assertEqual(reward.partial_reward, 0.5)
35+
self.assertEqual(reward.fallback_reward, 0.2)
3036
self.assertEqual(reward.no_match_reward, 0.0)
3137

3238
def test_init_custom_values(self):
3339
"""Test LanguageReward initialization with custom values."""
3440
reward = self.LanguageReward(
35-
target_language="ja", full_reward=0.9, no_match_reward=0.1
41+
target_language="ja",
42+
full_reward=0.9,
43+
partial_reward=0.6,
44+
fallback_reward=0.3,
45+
no_match_reward=0.1,
3646
)
3747
self.assertEqual(reward.target_language, "ja")
3848
self.assertEqual(reward.full_reward, 0.9)
49+
self.assertEqual(reward.partial_reward, 0.6)
50+
self.assertEqual(reward.fallback_reward, 0.3)
3951
self.assertEqual(reward.no_match_reward, 0.1)
4052

4153
def test_init_missing_langid(self):
@@ -112,10 +124,17 @@ def test_call_language_mismatch(self):
112124
self.assertEqual(result, 0.0)
113125

114126
def test_call_with_no_thinking_tags(self):
115-
"""Test __call__ with response containing no thinking tags."""
127+
"""Test __call__ with response containing no thinking tags but correct language."""
116128
result = self.reward_en(
117129
"prompt", "This is just a regular response without any thinking tags."
118130
)
131+
# No thinking blocks but response is in English, should get fallback reward
132+
self.assertEqual(result, 0.2)
133+
134+
def test_call_with_no_thinking_tags_wrong_language(self):
135+
"""Test __call__ with response containing no thinking tags and wrong language."""
136+
result = self.reward_en("prompt", "これは日本語の応答です。タグはありません。")
137+
# No thinking blocks and wrong language, should get no_match_reward
119138
self.assertEqual(result, 0.0)
120139

121140
def test_call_with_empty_thinking_block(self):
@@ -145,15 +164,15 @@ def test_call_with_whitespace_in_tags(self):
145164
self.assertEqual(result, 1.0)
146165

147166
def test_call_multiple_thinking_blocks(self):
148-
"""Test __call__ with multiple thinking blocks (wrong format)."""
167+
"""Test __call__ with multiple thinking blocks (wrong format but correct language)."""
149168
response = """
150169
<think>First thought in English.</think>
151170
Some text in between.
152171
<think>Second thought also in English.</think>
153172
"""
154173
result = self.reward_en("prompt", response)
155-
# Multiple blocks = wrong format, should return 0
156-
self.assertEqual(result, 0.0)
174+
# Multiple blocks = wrong format, but language is correct, should return partial_reward
175+
self.assertEqual(result, 0.5)
157176

158177
def test_call_multiple_thinking_blocks_mixed_languages(self):
159178
"""Test __call__ with multiple thinking blocks in different languages (wrong format)."""
@@ -162,8 +181,9 @@ def test_call_multiple_thinking_blocks_mixed_languages(self):
162181
<think>これは短い日本語。</think>
163182
"""
164183
result = self.reward_en("prompt", response)
165-
# Multiple blocks = wrong format, should return 0
166-
self.assertEqual(result, 0.0)
184+
# Multiple blocks with mixed languages - langid will detect dominant language
185+
# Should return either partial_reward (if detects English) or no_match_reward (if detects Japanese)
186+
self.assertIn(result, [0.0, 0.5])
167187

168188
def test_call_multiline_thinking_block(self):
169189
"""Test __call__ with multiline thinking blocks."""
@@ -192,20 +212,31 @@ def test_call_with_target_parameter(self):
192212
result = self.reward_en("prompt", response, target="some target")
193213
self.assertEqual(result, 1.0)
194214

195-
result = self.reward_en("prompt", "no tags", target="some target")
196-
self.assertEqual(result, 0.0)
215+
# Longer English text without tags should get fallback reward
216+
result = self.reward_en(
217+
"prompt",
218+
"This is a response without thinking tags but in English language.",
219+
target="some target",
220+
)
221+
self.assertEqual(result, 0.2)
197222

198223
def test_call_custom_reward_values(self):
199224
"""Test __call__ with custom reward values."""
200-
response_ja = "<think>これは日本語です。</think>"
225+
response_ja_single = "<think>これは日本語です。</think>"
226+
response_ja_multiple = "<think>最初の考え。</think><think>次の考え。</think>"
227+
response_ja_no_tags = "これはタグなしの日本語です。"
201228
response_en = "<think>This is English.</think>"
202-
response_none = "no thinking tags"
203-
204-
# Test custom full reward
205-
self.assertEqual(self.custom_reward("prompt", response_ja), 0.9)
206-
# Test custom no_match reward
229+
response_none = ""
230+
231+
# Test custom full reward (single block, correct language)
232+
self.assertEqual(self.custom_reward("prompt", response_ja_single), 0.9)
233+
# Test custom partial reward (multiple blocks, correct language)
234+
self.assertEqual(self.custom_reward("prompt", response_ja_multiple), 0.6)
235+
# Test custom fallback reward (no blocks, correct language)
236+
self.assertEqual(self.custom_reward("prompt", response_ja_no_tags), 0.3)
237+
# Test custom no_match reward (wrong language)
207238
self.assertEqual(self.custom_reward("prompt", response_en), 0.1)
208-
# Test no tags
239+
# Test empty response
209240
self.assertEqual(self.custom_reward("prompt", response_none), 0.1)
210241

211242
def test_call_zero_custom_values(self):

0 commit comments

Comments
 (0)