Commit 7b4829c

Simplify LanguageReward logic to focus on language detection only
Since ThinkingReward already enforces format (single vs multiple blocks), LanguageReward now focuses purely on language detection with simplified logic:

Detection strategy:
- If exactly one thinking block: detect language of block content only
- Otherwise (no blocks or multiple blocks): detect language of whole response
- Returns match_reward (1.0) if language matches, no_match_reward (0.0) otherwise

Changes:
- Removed partial_reward and fallback_reward parameters (now just match/no-match)
- Renamed full_reward to match_reward for clarity
- Updated all 29 tests to match new behavior (all passing)
- Updated README with clearer explanation of reward separation
- Updated debug script with new expected rewards

This separation of concerns allows each reward to specialize:
- ThinkingReward: format enforcement
- LanguageReward: language detection

1 parent abb653e commit 7b4829c
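
Callers that still construct `LanguageReward` with the old keyword arguments would need to follow the rename and removals listed under Changes. The sketch below shows one way to do that. `migrate_language_reward_kwargs` is a hypothetical helper, not part of this commit, and it assumes the old configuration is available as a plain dict.

```python
def migrate_language_reward_kwargs(old_kwargs: dict) -> dict:
    """Map old LanguageReward kwargs onto the simplified signature.

    Hypothetical helper for illustration only; not part of the commit.
    """
    new_kwargs = dict(old_kwargs)  # copy so the caller's dict is untouched
    # full_reward was renamed to match_reward.
    if "full_reward" in new_kwargs:
        new_kwargs["match_reward"] = new_kwargs.pop("full_reward")
    # partial_reward and fallback_reward were removed; drop them if present.
    for removed in ("partial_reward", "fallback_reward"):
        new_kwargs.pop(removed, None)
    return new_kwargs


old = {
    "target_language": "ja",
    "full_reward": 0.9,
    "partial_reward": 0.6,
    "fallback_reward": 0.3,
    "no_match_reward": 0.1,
}
print(migrate_language_reward_kwargs(old))
```

Dropping `partial_reward` and `fallback_reward` silently is a design choice of this sketch. A stricter variant could raise instead, so that configs relying on the old graded rewards are noticed.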

4 files changed: 85 additions and 125 deletions.

sandbox/grpo_language/README.md

Lines changed: 8 additions & 3 deletions
@@ -44,11 +44,16 @@ You can use any of the config files from `apps/grpo/` (e.g., `qwen3_1_7b.yaml`,
 1. The model receives a math problem and is instructed to use `<思考>` tags for reasoning
 2. During training, the model generates responses with thinking blocks
 3. Three rewards are computed:
-   - Math correctness (did it get the right answer?)
-   - Thinking usage (did it use `<思考>` tags properly?)
-   - Language usage (did it think in Japanese?)
+   - **MathReward**: Did it get the right answer?
+   - **ThinkingReward**: Did it use `<思考>` tags properly? (single block = full reward, multiple blocks = partial reward)
+   - **LanguageReward**: Did it use the target language? Detection strategy:
+     - If exactly one thinking block: detect language of block content only
+     - Otherwise (no blocks or multiple blocks): detect language of whole response
+     - Returns match_reward (1.0) if detected language matches target, no_match_reward (0.0) otherwise
 4. The model is trained to maximize all three rewards

+**Note**: ThinkingReward enforces format (single vs multiple blocks), while LanguageReward focuses purely on language detection. This separation of concerns allows each reward to specialize in one aspect of the desired behavior.
+
 ## Configuration

 ### Target Language
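
The separation described in the README note can be pictured as two independent scores over the same response. The following is a minimal sketch, not the repository's code: `format_score` is a hypothetical stand-in for ThinkingReward, its `partial=0.5` value and its 0.0 score for zero blocks are assumptions (the README only says "full" and "partial"), and `language_score` takes an already detected language code rather than calling langid.

```python
import re

TAG = "思考"
# Same block pattern as the one shown in src/forge/data/rewards.py.
BLOCK_RE = re.compile(
    rf"<\s*{re.escape(TAG)}\s*>(.*?)<\s*/\s*{re.escape(TAG)}\s*>", re.DOTALL
)


def format_score(response: str, full: float = 1.0, partial: float = 0.5) -> float:
    """Hypothetical stand-in for ThinkingReward: looks at block count only."""
    n_blocks = len(BLOCK_RE.findall(response))
    if n_blocks == 1:
        return full
    if n_blocks > 1:
        return partial  # assumed value; the README only says "partial reward"
    return 0.0  # assumed value for responses without any block


def language_score(detected: str, target: str,
                   match_reward: float = 1.0, no_match_reward: float = 0.0) -> float:
    """Binary language score: no knowledge of block count is needed here."""
    return match_reward if detected == target else no_match_reward


single = "<思考>考え中です。</思考> 答えは4です。"
multiple = "<思考>まず。</思考> <思考>次に。</思考>"
for name, response in [("single", single), ("multiple", multiple)]:
    # Both responses are assumed to be detected as Japanese ("ja").
    print(name, format_score(response), language_score("ja", "ja"))
```

Under these assumptions, a response with several Japanese blocks loses points only through the format score, while the language score stays at `match_reward`.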

sandbox/grpo_language/debug_reward.py

Lines changed: 6 additions & 5 deletions
@@ -80,9 +80,10 @@
 print(" No content to analyze")

 print("\n" + "=" * 80)
-print("Expected rewards:")
-print(" full_reward (1.0): Single Japanese thinking block")
-print(" partial_reward (0.5): Multiple Japanese thinking blocks")
-print(" fallback_reward (0.2): No blocks but Japanese response text")
-print(" no_match_reward (0.0): Wrong language")
+print("Expected rewards (simplified logic):")
+print(" match_reward (1.0): Detected language matches target (ja)")
+print(" no_match_reward (0.0): Detected language doesn't match target")
+print("\nDetection strategy:")
+print(" - Single thinking block: detect language of block content only")
+print(" - Multiple blocks or no blocks: detect language of whole response")
 print("=" * 80)

src/forge/data/rewards.py

Lines changed: 49 additions & 87 deletions
@@ -96,17 +96,20 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo


 class LanguageReward:
-    """Reward class for evaluating the language used in thinking tags.
+    """Reward class for evaluating the language used in responses.

-    This reward uses langid to detect the language of text within thinking blocks
-    and rewards responses that use the target language.
+    This reward uses langid to detect the language and rewards responses that use
+    the target language. The detection strategy depends on the format:
+    - If exactly one thinking block: detect language of the block content
+    - Otherwise (no blocks or multiple blocks): detect language of whole response
+
+    Note: Format enforcement (single vs multiple blocks) is handled by ThinkingReward.
+    This reward focuses purely on language detection.

     Args:
         target_language: ISO 639-1 language code (e.g., 'en', 'ja', 'zh', 'es')
-        full_reward: Reward when language matches and format is correct (single block)
-        partial_reward: Reward when language matches but format is wrong (multiple blocks)
-        fallback_reward: Reward when no valid blocks but response text is in target language
-        no_match_reward: Reward when language doesn't match
+        match_reward: Reward when detected language matches target (default: 1.0)
+        no_match_reward: Reward when language doesn't match (default: 0.0)
         tag: Tag name to use (default "思考" for multilingual, can use "think", etc.)
         debug: If True, print debug samples showing model outputs and detected language
         debug_sample_rate: Fraction of calls to debug (e.g., 0.1 = 10% of calls)
@@ -117,18 +120,14 @@ class LanguageReward:
     def __init__(
         self,
         target_language: str = "en",
-        full_reward: float = 1.0,
-        partial_reward: float = 0.5,
-        fallback_reward: float = 0.2,
+        match_reward: float = 1.0,
         no_match_reward: float = 0.0,
         tag: str = "思考",
         debug: bool = False,
         debug_sample_rate: float = 0.1,
     ):
         self.target_language = target_language
-        self.full_reward = full_reward
-        self.partial_reward = partial_reward
-        self.fallback_reward = fallback_reward
+        self.match_reward = match_reward
         self.no_match_reward = no_match_reward
         self.tag = tag
         self.debug = debug
@@ -138,7 +137,6 @@ def __init__(
         self._THINK_BLOCK_RE = re.compile(
             rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>", re.DOTALL
         )
-        self._TAG_PATTERN = rf"<\s*/?\s*{re.escape(tag)}\s*>"

         # Lazy import langid with helpful error message
         try:
@@ -152,18 +150,19 @@ def __init__(
             ) from None

     def __call__(self, prompt: str, response: str, target: str | None = None) -> float:
-        """Compute language reward based on thinking block content.
+        """Compute language reward based on detected language.
+
+        Detection strategy:
+        - If exactly one thinking block: detect language of block content
+        - Otherwise: detect language of whole response

         Args:
             prompt: The input prompt (unused but kept for signature consistency)
-            response: The model response containing <think> tags
+            response: The model response
             target: Optional target string (unused but kept for signature consistency)

         Returns:
-            full_reward if language matches and exactly one thinking block is found,
-            partial_reward if language matches but multiple thinking blocks found,
-            fallback_reward if no valid blocks but response text is in target language,
-            no_match_reward otherwise (wrong language)
+            match_reward if detected language matches target, no_match_reward otherwise
         """
         # Increment counter for sampling
         self._debug_counter += 1
@@ -174,89 +173,52 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo
             )

         if not response:
-            return self.no_match_reward
-
-        # Extract all thinking blocks
-        matches = self._THINK_BLOCK_RE.findall(response)
-
-        # If no thinking blocks found, check if response text is in target language
-        if len(matches) == 0:
-            # Remove any partial tags that might exist
-            response_text = re.sub(self._TAG_PATTERN, "", response).strip()
-
-            if not response_text:
-                if should_debug:
-                    print(
-                        f"\n[LanguageReward] Empty response | Reward: {self.no_match_reward}"
-                    )
-                return self.no_match_reward
-
-            # Detect language of general response
-            detected_lang, confidence = self._langid.classify(response_text)
-
             if should_debug:
-                sample = response[:150].replace("\n", " ")
                 print(
-                    f"\n[LanguageReward] No thinking blocks found (FALLBACK mode)"
-                    f"\n Target: {self.target_language} | Detected: {detected_lang} | "
-                    f"Confidence: {confidence:.2f}"
-                    f"\n Sample: {sample}..."
+                    f"\n[LanguageReward] Empty response | Reward: {self.no_match_reward}"
                 )
-
-            # Give fallback reward if response is in target language
-            if detected_lang == self.target_language:
-                if should_debug:
-                    print(
-                        f" → Reward: {self.fallback_reward} (fallback, correct language)"
-                    )
-                return self.fallback_reward
-
-            if should_debug:
-                print(f" → Reward: {self.no_match_reward} (wrong language)")
             return self.no_match_reward

-        # Concatenate all thinking blocks for language detection
-        thinking_content = " ".join(matches)
+        # Extract all thinking blocks
+        matches = self._THINK_BLOCK_RE.findall(response)
+
+        # Determine what text to analyze
+        if len(matches) == 1:
+            # Single block: detect language of block content only
+            text_to_analyze = matches[0].strip()
+            detection_mode = "single block"
+        else:
+            # No blocks or multiple blocks: detect language of whole response
+            text_to_analyze = response.strip()
+            detection_mode = f"{len(matches)} blocks, using whole response"

         # Remove extra whitespace
-        thinking_content = re.sub(r"\s+", " ", thinking_content).strip()
+        text_to_analyze = re.sub(r"\s+", " ", text_to_analyze).strip()

-        if not thinking_content:
+        if not text_to_analyze:
             if should_debug:
-                print(
-                    f"\n[LanguageReward] Empty thinking blocks | Reward: {self.no_match_reward}"
-                )
+                print(f"\n[LanguageReward] Empty text | Reward: {self.no_match_reward}")
             return self.no_match_reward

         # Detect language using langid
-        detected_lang, confidence = self._langid.classify(thinking_content)
+        detected_lang, confidence = self._langid.classify(text_to_analyze)
+
+        # Check if language matches target
+        reward = (
+            self.match_reward
+            if detected_lang == self.target_language
+            else self.no_match_reward
+        )

         if should_debug:
-            sample = thinking_content[:150].replace("\n", " ")
+            sample = text_to_analyze[:150].replace("\n", " ")
+            match_symbol = "✓" if detected_lang == self.target_language else "✗"
             print(
-                f"\n[LanguageReward] Found {len(matches)} thinking block(s)"
+                f"\n[LanguageReward] Detection mode: {detection_mode}"
                 f"\n Target: {self.target_language} | Detected: {detected_lang} | "
                 f"Confidence: {confidence:.2f}"
-                f"\n Thinking sample: {sample}..."
+                f"\n Sample: {sample}..."
+                f"\n → Reward: {reward} {match_symbol}"
             )

-        # Check if language matches target
-        if detected_lang == self.target_language:
-            # Full reward for correct format (single block)
-            if len(matches) == 1:
-                if should_debug:
-                    print(
-                        f" → Reward: {self.full_reward} (single block, correct language) ✓"
-                    )
-                return self.full_reward
-            # Partial reward for wrong format (multiple blocks) but correct language
-            else:
-                if should_debug:
-                    print(
-                        f" → Reward: {self.partial_reward} (multiple blocks, correct language)"
-                    )
-                return self.partial_reward
-
-        if should_debug:
-            print(f" → Reward: {self.no_match_reward} (wrong language) ✗")
-        return self.no_match_reward
+        return reward
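
Read end to end, the new `__call__` reduces to "pick the text, normalize whitespace, classify, compare". The sketch below condenses that flow into a runnable form without the debug output. It is an illustration under stated assumptions, not the module itself: `SimpleLanguageReward` and `kana_stub_classifier` are hypothetical names, and the stub (which returns `"ja"` whenever hiragana or katakana is present) merely stands in for `langid.classify` so the example runs without the langid dependency.

```python
import re
from typing import Callable, Optional, Tuple

Classifier = Callable[[str], Tuple[str, float]]


def kana_stub_classifier(text: str) -> Tuple[str, float]:
    """Toy stand-in for langid.classify (assumption): 'ja' if any kana, else 'en'."""
    has_kana = any("\u3040" <= ch <= "\u30ff" for ch in text)
    return ("ja", 1.0) if has_kana else ("en", 1.0)


class SimpleLanguageReward:
    """Condensed, hypothetical re-statement of the simplified reward logic."""

    def __init__(self, target_language: str = "en", match_reward: float = 1.0,
                 no_match_reward: float = 0.0, tag: str = "思考",
                 classify: Classifier = kana_stub_classifier):
        self.target_language = target_language
        self.match_reward = match_reward
        self.no_match_reward = no_match_reward
        self.classify = classify
        self._block_re = re.compile(
            rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>", re.DOTALL
        )

    def __call__(self, prompt: str, response: str, target: Optional[str] = None) -> float:
        if not response:
            return self.no_match_reward
        matches = self._block_re.findall(response)
        # Exactly one block: analyze its content. Otherwise: the whole response,
        # which still contains the tags themselves.
        text = matches[0] if len(matches) == 1 else response
        text = re.sub(r"\s+", " ", text).strip()
        if not text:
            return self.no_match_reward
        detected, _confidence = self.classify(text)
        return self.match_reward if detected == self.target_language else self.no_match_reward


reward_ja = SimpleLanguageReward(target_language="ja")
print(reward_ja("prompt", "<思考>これは日本語です。</思考>"))
print(reward_ja("prompt", "<思考>This is English.</思考>"))
```

Passing the classifier in as a parameter is a choice made for this sketch only. The code in the diff imports langid lazily in `__init__` and calls `self._langid.classify` directly.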

tests/unit_tests/rl/test_language_reward.py

Lines changed: 22 additions & 30 deletions
@@ -20,34 +20,26 @@ def setUp(self):
         self.reward_ja = LanguageReward(target_language="ja")
         self.custom_reward = LanguageReward(
             target_language="ja",
-            full_reward=0.9,
-            partial_reward=0.6,
-            fallback_reward=0.3,
+            match_reward=0.9,
             no_match_reward=0.1,
         )

     def test_init_default_values(self):
         """Test LanguageReward initialization with default values."""
         reward = self.LanguageReward()
         self.assertEqual(reward.target_language, "en")
-        self.assertEqual(reward.full_reward, 1.0)
-        self.assertEqual(reward.partial_reward, 0.5)
-        self.assertEqual(reward.fallback_reward, 0.2)
+        self.assertEqual(reward.match_reward, 1.0)
         self.assertEqual(reward.no_match_reward, 0.0)

     def test_init_custom_values(self):
         """Test LanguageReward initialization with custom values."""
         reward = self.LanguageReward(
             target_language="ja",
-            full_reward=0.9,
-            partial_reward=0.6,
-            fallback_reward=0.3,
+            match_reward=0.9,
             no_match_reward=0.1,
         )
         self.assertEqual(reward.target_language, "ja")
-        self.assertEqual(reward.full_reward, 0.9)
-        self.assertEqual(reward.partial_reward, 0.6)
-        self.assertEqual(reward.fallback_reward, 0.3)
+        self.assertEqual(reward.match_reward, 0.9)
         self.assertEqual(reward.no_match_reward, 0.1)

     def test_init_missing_langid(self):
@@ -130,13 +122,13 @@ def test_call_with_no_thinking_tags(self):
         result = self.reward_en(
             "prompt", "This is just a regular response without any thinking tags."
         )
-        # No thinking blocks but response is in English, should get fallback reward
-        self.assertEqual(result, 0.2)
+        # No thinking blocks -> detect whole response, English detected -> match_reward
+        self.assertEqual(result, 1.0)

     def test_call_with_no_thinking_tags_wrong_language(self):
         """Test __call__ with response containing no thinking tags and wrong language."""
         result = self.reward_en("prompt", "これは日本語の応答です。タグはありません。")
-        # No thinking blocks and wrong language, should get no_match_reward
+        # No thinking blocks -> detect whole response, Japanese detected -> no_match_reward
         self.assertEqual(result, 0.0)

     def test_call_with_empty_thinking_block(self):
def test_call_with_empty_thinking_block(self):
@@ -167,26 +159,26 @@ def test_call_with_whitespace_in_tags(self):
         self.assertEqual(result, 1.0)

     def test_call_multiple_thinking_blocks(self):
-        """Test __call__ with multiple thinking blocks (wrong format but correct language)."""
+        """Test __call__ with multiple thinking blocks - detects whole response language."""
         response = """
         <思考>First thought in English.</思考>
         Some text in between.
         <思考>Second thought also in English.</思考>
         """
         result = self.reward_en("prompt", response)
-        # Multiple blocks = wrong format, but language is correct, should return partial_reward
-        self.assertEqual(result, 0.5)
+        # Multiple blocks -> detect whole response, English detected -> match_reward
+        self.assertEqual(result, 1.0)

     def test_call_multiple_thinking_blocks_mixed_languages(self):
-        """Test __call__ with multiple thinking blocks in different languages (wrong format)."""
+        """Test __call__ with multiple thinking blocks in different languages."""
         response = """
         <思考>First thought in English with lots of content here.</思考>
         <思考>これは短い日本語。</思考>
         """
         result = self.reward_en("prompt", response)
-        # Multiple blocks with mixed languages - langid will detect dominant language
-        # Should return either partial_reward (if detects English) or no_match_reward (if detects Japanese)
-        self.assertIn(result, [0.0, 0.5])
+        # Multiple blocks -> detect whole response, langid will detect dominant language
+        # Should return match_reward (1.0) if English dominant, or no_match_reward (0.0) if not
+        self.assertIn(result, [0.0, 1.0])

     def test_call_multiline_thinking_block(self):
         """Test __call__ with multiline thinking blocks."""
@@ -215,13 +207,13 @@ def test_call_with_target_parameter(self):
         result = self.reward_en("prompt", response, target="some target")
         self.assertEqual(result, 1.0)

-        # Longer English text without tags should get fallback reward
+        # English text without tags -> detect whole response -> match_reward
         result = self.reward_en(
             "prompt",
             "This is a response without thinking tags but in English language.",
             target="some target",
         )
-        self.assertEqual(result, 0.2)
+        self.assertEqual(result, 1.0)

     def test_call_custom_reward_values(self):
         """Test __call__ with custom reward values."""
def test_call_custom_reward_values(self):
227219
"""Test __call__ with custom reward values."""
@@ -231,12 +223,12 @@ def test_call_custom_reward_values(self):
         response_en = "<思考>This is English.</思考>"
         response_none = ""

-        # Test custom full reward (single block, correct language)
+        # Test custom match reward (single block, correct language)
         self.assertEqual(self.custom_reward("prompt", response_ja_single), 0.9)
-        # Test custom partial reward (multiple blocks, correct language)
-        self.assertEqual(self.custom_reward("prompt", response_ja_multiple), 0.6)
-        # Test custom fallback reward (no blocks, correct language)
-        self.assertEqual(self.custom_reward("prompt", response_ja_no_tags), 0.3)
+        # Test custom match reward (multiple blocks -> whole response, correct language)
+        self.assertEqual(self.custom_reward("prompt", response_ja_multiple), 0.9)
+        # Test custom match reward (no blocks -> whole response, correct language)
+        self.assertEqual(self.custom_reward("prompt", response_ja_no_tags), 0.9)
         # Test custom no_match reward (wrong language)
         self.assertEqual(self.custom_reward("prompt", response_en), 0.1)
         # Test empty response
@@ -245,7 +237,7 @@ def test_call_custom_reward_values(self):
     def test_call_zero_custom_values(self):
         """Test __call__ with zero custom values."""
         zero_reward = self.LanguageReward(
-            target_language="en", full_reward=0.0, no_match_reward=0.0
+            target_language="en", match_reward=0.0, no_match_reward=0.0
         )
         result = zero_reward("prompt", "<思考>This is English.</思考>")
         self.assertEqual(result, 0.0)
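
The mixed-language test accepts either 0.0 or 1.0 because the detector's verdict on mixed text is not fixed in advance. One way to pin the outcome is to stub the detector, shown here as a minimal sketch with `unittest.mock`. `language_match` is a hypothetical, reduced stand-in for the final comparison in `LanguageReward.__call__`, and the score `-42.0` is an arbitrary placeholder rather than a real langid value.

```python
from unittest.mock import Mock


def language_match(langid_like, text: str, target_language: str) -> float:
    """Hypothetical stand-in: classify the text, then compare with the target."""
    detected_lang, _confidence = langid_like.classify(text)
    return 1.0 if detected_lang == target_language else 0.0


mixed = "<思考>First thought in English.</思考> <思考>これは短い日本語。</思考>"

fake_langid = Mock()
fake_langid.classify.return_value = ("en", -42.0)  # force an English verdict
assert language_match(fake_langid, mixed, "en") == 1.0

fake_langid.classify.return_value = ("ja", -42.0)  # force a Japanese verdict
assert language_match(fake_langid, mixed, "en") == 0.0
```

Since the diff shows the detector being called as `self._langid.classify(...)`, a similar stub could presumably be assigned to that attribute on a real instance, though that is not something the commit's tests do.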
