import torch
from datasets import load_dataset
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.actors.replay_buffer import ReplayBuffer
from forge.controller import ServiceConfig, spawn_service
from forge.controller.actor import ForgeActor
+from forge.data.rewards import MathReward, ThinkingReward
+from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -209,66 +212,18 @@ async def update_weights(self, policy_actor):
        self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds")


-def math_scoring_function(prompt: str, response: str, target: str) -> float:
-    """Function to score math correctness."""
-    import re
-
-    # Extract expected answer from target
-    expected_answer = (
-        float(target.strip())
-        if target.strip().replace(".", "").replace("-", "").isdigit()
-        else None
-    )
-
-    # Extract model answer from response
-    patterns = [
-        r"####\s*([+-]?\d+(?:\.\d+)?)",  # GSM8K style answer format
-        r"(?:the\s+)?answer\s+is\s*([+-]?\d+(?:\.\d+)?)",
-        r"(?:answer:|result:)\s*([+-]?\d+(?:\.\d+)?)",
-        r"=\s*([+-]?\d+(?:\.\d+)?)\s*(?:\.|$)",  # equals near end
-        r"\b([+-]?\d+(?:\.\d+)?)\s*(?:\.|$)",  # number at end
-        r"([+-]?\d+(?:\.\d+)?)",  # any number (fallback)
-    ]
-
-    model_answer = None
-    response_lower = response.lower().strip()
-    for pattern in patterns:
-        matches = re.findall(pattern, response_lower)
-        if matches:
-            model_answer = float(matches[-1])
-            break
-
-    if expected_answer is None or model_answer is None:
-        return 0.1  # Partial credit for attempting
-
-    # Check if answers match (with some tolerance for floating point)
-    if abs(expected_answer - model_answer) < 1e-6:
-        return 1.0  # Correct answer
-    else:
-        return 0.0  # Incorrect answer
-
-
-def thinking_scoring_function(prompt: str, response: str, target: str) -> float:
-    """Function to score thinking tag usage."""
-    # Check if response contains <think></think> tags
-    if "<think>" in response.lower() and "</think>" in response.lower():
-        return 0.5
-    else:
-        return 0.0
-
-
class RewardActor(ForgeActor):
    """Reward actor that uses a list of scoring functions."""

-    def __init__(self, scoring_functions: list[Callable]):
+    def __init__(self, reward_functions: list[Callable]):
        super().__init__()
-        self.scoring_functions = scoring_functions
+        self.reward_functions = reward_functions

    @endpoint
    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
        total_reward = 0.0
-        for scoring_fn in self.scoring_functions:
-            reward = scoring_fn(prompt, response, target)
+        for reward_fn in self.reward_functions:
+            reward = reward_fn(prompt, response, target)
            total_reward += reward
        return total_reward
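
For orientation, here is a minimal, self-contained sketch (not part of the commit) of the interface the refactored RewardActor relies on: each reward function is a callable taking (prompt, response, target) and returning a float, and evaluate_response simply sums their scores. The stand-in callables below are illustrative assumptions, not the MathReward/ThinkingReward implementations from forge.data.rewards.

from typing import Callable

def toy_math_reward(prompt: str, response: str, target: str) -> float:
    # Assumed stand-in: full credit when the response ends with the target string.
    return 1.0 if response.strip().endswith(target.strip()) else 0.0

def toy_thinking_reward(prompt: str, response: str, target: str) -> float:
    # Assumed stand-in: partial credit for wrapping reasoning in <think></think> tags.
    has_tags = "<think>" in response.lower() and "</think>" in response.lower()
    return 0.5 if has_tags else 0.0

def evaluate_response(reward_functions: list[Callable], prompt: str, response: str, target: str) -> float:
    # Mirrors the summation loop in RewardActor.evaluate_response.
    return sum(fn(prompt, response, target) for fn in reward_functions)

score = evaluate_response(
    [toy_math_reward, toy_thinking_reward],
    prompt="What is 2 + 2?",
    response="<think>2 + 2 = 4</think> The answer is 4",
    target="4",
)
print(score)  # 1.5 with these stand-in rewards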
@@ -388,6 +343,13 @@ async def main():
    group_size = 1
    model = "Qwen/Qwen3-1.7B"

+    # ---- Setup WandB Logger ---- #
+    logger = get_metric_logger(
+        "wandb",
+        freq=1,
+        project="grpo-training",
+    )
+
    # ---- Setup services ---- #
    default_service_cfg = ServiceConfig(
        procs_per_replica=1,
@@ -447,7 +409,7 @@ async def main():
    reward_actor = await spawn_service(
        default_service_cfg,
        RewardActor,
-        scoring_functions=[math_scoring_function, thinking_scoring_function],
+        reward_functions=[MathReward(), ThinkingReward()],
    )

    print("All services initialized successfully!")
@@ -498,6 +460,7 @@ async def continuous_rollouts():
                print(
                    f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
                )
+                logger.log("reward/rollout", avg_reward, rollout_count)

    async def continuous_training():
        training_step = 0
@@ -511,7 +474,9 @@ async def continuous_training():
            if training_step % 10 == 0:
                print(f"Completed {training_step} training steps")
                if training_result:
-                    print(f"Latest loss: {training_result.get('loss', 'N/A')}")
+                    loss_value = training_result.get("loss", 0.0)
+                    print(f"Latest loss: {loss_value}")
+                    logger.log("loss/training_step", loss_value, training_step)
            # await trainer.update_weights(policy)

    print("Starting GRPO training loops...")