44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import random
78import re
89
910
@@ -57,15 +58,28 @@ def _to_float(self, text: str) -> float | None:
5758
5859
5960class ThinkingReward :
60- """Reward class for evaluating use of <think> tags in reasoning."""
61+ """Reward class for evaluating use of thinking tags in reasoning.
6162
62- def __init__ (self , partial_reward : float = 0.2 , full_reward : float = 1.0 ):
63+ Args:
64+ partial_reward: Reward for partial tag usage (incomplete/malformed)
65+ full_reward: Reward for well-formed thinking blocks with content
66+ tag: Tag name to use (default "think", can use "思考" for Japanese, etc.)
67+ """
68+
69+ def __init__ (
70+ self , partial_reward : float = 0.2 , full_reward : float = 1.0 , tag : str = "think"
71+ ):
6372 self .partial_reward = partial_reward
6473 self .full_reward = full_reward
74+ self .tag = tag
75+ # Build regex patterns for the specified tag
6576 self ._THINK_BLOCK_RE = re .compile (
66- r"<\s*think\s*>(.*?)<\s*/\s*think\s*>" , re .IGNORECASE | re .DOTALL
77+ rf"<\s*{ re .escape (tag )} \s*>(.*?)<\s*/\s*{ re .escape (tag )} \s*>" ,
78+ re .IGNORECASE | re .DOTALL ,
79+ )
80+ self ._THINK_TAG_ATTEMPT_RE = re .compile (
81+ rf"<\s*/?\s*{ re .escape (tag )} \s*>" , re .IGNORECASE
6782 )
68- self ._THINK_TAG_ATTEMPT_RE = re .compile (r"<\s*/?\s*think\s*>" , re .IGNORECASE )
6983
7084 def __call__ (self , prompt : str , response : str , target : str | None = None ) -> float :
7185 """Compute thinking reward."""
@@ -80,3 +94,128 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo
8094 elif has_attempt :
8195 return self .partial_reward
8296 return 0.0
97+
98+
99+ class LanguageReward :
100+ """Reward class for evaluating the language used in responses.
101+
102+ This reward uses langid to detect the language and rewards responses that use
103+ the target language. The detection strategy depends on the format:
104+ - If exactly one thinking block: detect language of the block content
105+ - Otherwise (no blocks or multiple blocks): detect language of whole response
106+
107+ Note: Format enforcement (single vs multiple blocks) is handled by ThinkingReward.
108+ This reward focuses purely on language detection.
109+
110+ Args:
111+ target_language: ISO 639-1 language code (e.g., 'en', 'ja', 'zh', 'es')
112+ match_reward: Reward when detected language matches target (default: 1.0)
113+ no_match_reward: Reward when language doesn't match (default: 0.0)
114+ tag: Tag name to use (default "思考" for multilingual, can use "think", etc.)
115+ debug: If True, print debug samples showing model outputs and detected language
116+ debug_sample_rate: Fraction of calls to debug (e.g., 0.1 = 10% of calls)
117+
118+ Note: Requires langid to be installed. Install with: pip install langid
119+ """
120+
121+ def __init__ (
122+ self ,
123+ target_language : str = "ja" ,
124+ match_reward : float = 1.0 ,
125+ no_match_reward : float = 0.0 ,
126+ tag : str = "思考" ,
127+ debug : bool = False ,
128+ debug_sample_rate : float = 0.1 ,
129+ ):
130+ self .target_language = target_language
131+ self .match_reward = match_reward
132+ self .no_match_reward = no_match_reward
133+ self .tag = tag
134+ self .debug = debug
135+ self .debug_sample_rate = debug_sample_rate
136+ self ._debug_counter = 0
137+ # Build regex pattern for the specified tag
138+ self ._THINK_BLOCK_RE = re .compile (
139+ rf"<\s*{ re .escape (tag )} \s*>(.*?)<\s*/\s*{ re .escape (tag )} \s*>" , re .DOTALL
140+ )
141+
142+ # Lazy import langid with helpful error message
143+ try :
144+ import langid
145+
146+ self ._langid = langid
147+ except ImportError :
148+ raise ImportError (
149+ "langid is required for LanguageReward but is not installed. "
150+ "Please install it with: pip install langid"
151+ ) from None
152+
153+ def __call__ (self , prompt : str , response : str , target : str | None = None ) -> float :
154+ """Compute language reward based on detected language.
155+
156+ Detection strategy:
157+ - If exactly one thinking block: detect language of block content
158+ - Otherwise: detect language of whole response
159+
160+ Args:
161+ prompt: The input prompt (unused but kept for signature consistency)
162+ response: The model response
163+ target: Optional target string (unused but kept for signature consistency)
164+
165+ Returns:
166+ match_reward if detected language matches target, no_match_reward otherwise
167+ """
168+
169+ # TODO: refactor pending https://github.com/meta-pytorch/torchforge/issues/187
170+ should_debug = self .debug and (random .random () < self .debug_sample_rate )
171+
172+ if not response :
173+ if should_debug :
174+ print (
175+ f"\n [LanguageReward] Empty response | Reward: { self .no_match_reward } "
176+ )
177+ return self .no_match_reward
178+
179+ # Extract all thinking blocks
180+ matches = self ._THINK_BLOCK_RE .findall (response )
181+
182+ # Determine what text to analyze
183+ if len (matches ) == 1 :
184+ # Single block: detect language of block content only
185+ text_to_analyze = matches [0 ].strip ()
186+ detection_mode = "single block"
187+ else :
188+ # No blocks or multiple blocks: detect language of whole response
189+ text_to_analyze = response .strip ()
190+ detection_mode = f"{ len (matches )} blocks, using whole response"
191+
192+ # Remove extra whitespace
193+ text_to_analyze = re .sub (r"\s+" , " " , text_to_analyze ).strip ()
194+
195+ if not text_to_analyze :
196+ if should_debug :
197+ print (f"\n [LanguageReward] Empty text | Reward: { self .no_match_reward } " )
198+ return self .no_match_reward
199+
200+ # Detect language using langid
201+ detected_lang , confidence = self ._langid .classify (text_to_analyze )
202+
203+ # Check if language matches target
204+ reward = (
205+ self .match_reward
206+ if detected_lang == self .target_language
207+ else self .no_match_reward
208+ )
209+
210+ if should_debug :
211+ sample = text_to_analyze [:1000 ].replace ("\n " , " " )
212+ match_symbol = "✓" if detected_lang == self .target_language else "✗"
213+ print (
214+ f"\n [LanguageReward] Detection mode: { detection_mode } "
215+ f"\n Target: { self .target_language } | Detected: { detected_lang } | "
216+ f"Confidence: { confidence :.2f} "
217+ f"\n Sample: { sample } ..."
218+ f"\n → Reward: { reward } { match_symbol } "
219+ )
220+
221+ return reward
0 commit comments