
Commit e926707

Add LanguageReward for training models to think in target language (#515)
Co-authored-by: Jiyue Wang <[email protected]>
1 parent 7381c3f

7 files changed: +449 −15 lines

apps/grpo/main.py
Lines changed: 22 additions & 7 deletions

@@ -26,7 +26,7 @@
 from forge.actors.trainer import TitanTrainer
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import init_provisioner, shutdown
-from forge.data.rewards import MathReward, ThinkingReward
+from forge.data.rewards import LanguageReward, MathReward, ThinkingReward
 from forge.data_models.completion import Completion
 from forge.observability.metric_actors import get_or_create_metric_logger
 from forge.observability.metrics import record_metric, Reduce

@@ -129,7 +129,7 @@ def simple_grpo_loss(
     ref_logprobs: torch.Tensor,
     advantages: torch.Tensor,
     padding_mask: torch.Tensor,
-    beta: float = 0.1,
+    beta: float = 1e-6,
 ) -> torch.Tensor:
     logprobs: torch.Tensor = compute_logprobs(logits, response)
     kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1

@@ -237,10 +237,15 @@ async def setup(self):
         self._epoch = 0

         def gsm8k_transform(sample):
-            system_prompt = """
-            Put all your scratchpad work between <think> and </think> tags.
-            Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.
-            """
+            system_prompt = """You are a helpful AI assistant that solves math problems.
+
+Please show your reasoning inside <思考></思考> tags, then provide your final numerical answer inside <answer></answer> tags.
+
+Example:
+Question: What is 12 + 5?
+<思考>12と5を足します。12 + 5 = 17です。</思考>
+<answer>17</answer>
+"""
             request: str = sample["question"]
             as_chat = [
                 {"role": "system", "content": system_prompt},

@@ -359,7 +364,17 @@ async def main(cfg: DictConfig):
         ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
         ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
         RewardActor.options(**cfg.services.reward_actor).as_service(
-            reward_functions=[MathReward(), ThinkingReward()]
+            reward_functions=[
+                MathReward(),
+                ThinkingReward(tag="思考"),  # Use Japanese tag
+                LanguageReward(
+                    target_language="ja",
+                    tag="思考",
+                    match_reward=2.0,
+                    debug=True,
+                    debug_sample_rate=0.1,
+                ),  # Japanese language reward with debug
+            ]
         ),
     )
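For reference, a minimal standalone sketch of the reward stack wired up above. How RewardActor aggregates the per-function scores is not shown in this diff, so the simple average at the end is an illustrative assumption, as is MathReward's exact return value for a correct answer:

# Hypothetical standalone check of the three rewards composed above.
from forge.data.rewards import LanguageReward, MathReward, ThinkingReward

prompt = "Question: What is 12 + 5?"
response = "<思考>12と5を足します。12 + 5 = 17です。</思考>\n<answer>17</answer>"

reward_functions = [
    MathReward(),
    ThinkingReward(tag="思考"),
    LanguageReward(target_language="ja", tag="思考", match_reward=2.0),
]
scores = [fn(prompt, response, target="17") for fn in reward_functions]
print(scores)  # e.g. [1.0, 1.0, 2.0] if the answer, tag format, and language all check out
print(sum(scores) / len(scores))  # simple average; actual aggregation lives in RewardActor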

apps/grpo/qwen3_1_7b.yaml
Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 group_size: 8
 local_batch_size: 16 # per-device batch size
 max_req_tokens: 1024
-max_res_tokens: 1024
+max_res_tokens: 2048
 model: "Qwen/Qwen3-1.7B"
 off_by_n: 1 # Off by one by default

apps/grpo/qwen3_8b.yaml
Lines changed: 3 additions & 3 deletions

@@ -2,10 +2,10 @@
 # >>> python -m apps.grpo.main --config apps/grpo/qwen3_8b.yaml

 # Global configuration
-group_size: 8
-local_batch_size: 12 # per-device batch size
+group_size: 16
+local_batch_size: 4 # per-device batch size
 max_req_tokens: 1024
-max_res_tokens: 1024
+max_res_tokens: 2048
 model: "Qwen/Qwen3-8B"
 off_by_n: 1 # Off by one by default
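Rough arithmetic on the qwen3_8b.yaml change, assuming local_batch_size counts prompts per device and group_size counts completions per prompt: a step goes from 12 × 8 = 96 completions of at most 1024 response tokens (≈98K tokens) to 4 × 16 = 64 completions of at most 2048 tokens (≈131K tokens), so the worst-case response-token budget stays in the same ballpark while larger groups and longer completions leave room for tag-wrapped reasoning.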

pyproject.toml
Lines changed: 1 addition & 0 deletions

@@ -47,6 +47,7 @@ dev = [
     "anyio",
     "pytest-asyncio",
     "multiprocess",
+    "langid",
 ]
 docs = [
     "sphinx==7.2.6",

src/forge/data/rewards.py
Lines changed: 143 additions & 4 deletions

@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import random
 import re


@@ -57,15 +58,28 @@ def _to_float(self, text: str) -> float | None:


 class ThinkingReward:
-    """Reward class for evaluating use of <think> tags in reasoning."""
+    """Reward class for evaluating use of thinking tags in reasoning.

-    def __init__(self, partial_reward: float = 0.2, full_reward: float = 1.0):
+    Args:
+        partial_reward: Reward for partial tag usage (incomplete/malformed)
+        full_reward: Reward for well-formed thinking blocks with content
+        tag: Tag name to use (default "think", can use "思考" for Japanese, etc.)
+    """
+
+    def __init__(
+        self, partial_reward: float = 0.2, full_reward: float = 1.0, tag: str = "think"
+    ):
         self.partial_reward = partial_reward
         self.full_reward = full_reward
+        self.tag = tag
+        # Build regex patterns for the specified tag
         self._THINK_BLOCK_RE = re.compile(
-            r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL
+            rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>",
+            re.IGNORECASE | re.DOTALL,
+        )
+        self._THINK_TAG_ATTEMPT_RE = re.compile(
+            rf"<\s*/?\s*{re.escape(tag)}\s*>", re.IGNORECASE
         )
-        self._THINK_TAG_ATTEMPT_RE = re.compile(r"<\s*/?\s*think\s*>", re.IGNORECASE)

     def __call__(self, prompt: str, response: str, target: str | None = None) -> float:
         """Compute thinking reward."""
@@ -80,3 +94,128 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> float:
         elif has_attempt:
             return self.partial_reward
         return 0.0
+
+
+class LanguageReward:
+    """Reward class for evaluating the language used in responses.
+
+    This reward uses langid to detect the language and rewards responses that use
+    the target language. The detection strategy depends on the format:
+    - If exactly one thinking block: detect language of the block content
+    - Otherwise (no blocks or multiple blocks): detect language of whole response
+
+    Note: Format enforcement (single vs multiple blocks) is handled by ThinkingReward.
+    This reward focuses purely on language detection.
+
+    Args:
+        target_language: ISO 639-1 language code (e.g., 'en', 'ja', 'zh', 'es')
+        match_reward: Reward when detected language matches target (default: 1.0)
+        no_match_reward: Reward when language doesn't match (default: 0.0)
+        tag: Tag name to use (default "思考" for multilingual, can use "think", etc.)
+        debug: If True, print debug samples showing model outputs and detected language
+        debug_sample_rate: Fraction of calls to debug (e.g., 0.1 = 10% of calls)
+
+    Note: Requires langid to be installed. Install with: pip install langid
+    """
+
+    def __init__(
+        self,
+        target_language: str = "ja",
+        match_reward: float = 1.0,
+        no_match_reward: float = 0.0,
+        tag: str = "思考",
+        debug: bool = False,
+        debug_sample_rate: float = 0.1,
+    ):
+        self.target_language = target_language
+        self.match_reward = match_reward
+        self.no_match_reward = no_match_reward
+        self.tag = tag
+        self.debug = debug
+        self.debug_sample_rate = debug_sample_rate
+        self._debug_counter = 0
+        # Build regex pattern for the specified tag
+        self._THINK_BLOCK_RE = re.compile(
+            rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>", re.DOTALL
+        )
+
+        # Lazy import langid with helpful error message
+        try:
+            import langid
+
+            self._langid = langid
+        except ImportError:
+            raise ImportError(
+                "langid is required for LanguageReward but is not installed. "
+                "Please install it with: pip install langid"
+            ) from None
+
+    def __call__(self, prompt: str, response: str, target: str | None = None) -> float:
+        """Compute language reward based on detected language.
+
+        Detection strategy:
+        - If exactly one thinking block: detect language of block content
+        - Otherwise: detect language of whole response
+
+        Args:
+            prompt: The input prompt (unused but kept for signature consistency)
+            response: The model response
+            target: Optional target string (unused but kept for signature consistency)
+
+        Returns:
+            match_reward if detected language matches target, no_match_reward otherwise
+        """
+        # TODO: refactor pending https://github.com/meta-pytorch/torchforge/issues/187
+        should_debug = self.debug and (random.random() < self.debug_sample_rate)
+
+        if not response:
+            if should_debug:
+                print(
+                    f"\n[LanguageReward] Empty response | Reward: {self.no_match_reward}"
+                )
+            return self.no_match_reward
+
+        # Extract all thinking blocks
+        matches = self._THINK_BLOCK_RE.findall(response)
+
+        # Determine what text to analyze
+        if len(matches) == 1:
+            # Single block: detect language of block content only
+            text_to_analyze = matches[0].strip()
+            detection_mode = "single block"
+        else:
+            # No blocks or multiple blocks: detect language of whole response
+            text_to_analyze = response.strip()
+            detection_mode = f"{len(matches)} blocks, using whole response"
+
+        # Remove extra whitespace
+        text_to_analyze = re.sub(r"\s+", " ", text_to_analyze).strip()
+
+        if not text_to_analyze:
+            if should_debug:
+                print(f"\n[LanguageReward] Empty text | Reward: {self.no_match_reward}")
+            return self.no_match_reward
+
+        # Detect language using langid
+        detected_lang, confidence = self._langid.classify(text_to_analyze)

+        # Check if language matches target
+        reward = (
+            self.match_reward
+            if detected_lang == self.target_language
+            else self.no_match_reward
+        )
+
+        if should_debug:
+            sample = text_to_analyze[:1000].replace("\n", " ")
+            match_symbol = "✓" if detected_lang == self.target_language else "✗"
+            print(
+                f"\n[LanguageReward] Detection mode: {detection_mode}"
+                f"\n  Target: {self.target_language} | Detected: {detected_lang} | "
+                f"Confidence: {confidence:.2f}"
+                f"\n  Sample: {sample}..."
+                f"\n  → Reward: {reward} {match_symbol}"
+            )
+
+        return reward
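A brief sketch of the detection strategy the docstring describes, with expected rather than guaranteed outcomes, since langid can be uncertain on short strings:

from forge.data.rewards import LanguageReward

reward = LanguageReward(target_language="ja", tag="思考")

# Exactly one well-formed block: only the Japanese content inside the tags is classified.
print(reward("", "<思考>これは日本語の思考です。</思考> The answer is 17."))  # expected: 1.0

# No block: the whole response (English here) is classified instead.
print(reward("", "The answer is 17."))  # expected: 0.0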
