 import torch
 import torch.nn.functional as F
 import torchstore as ts
+import yaml
 from datasets import load_dataset
 from forge.actors._torchstore_utils import (
     get_dcp_whole_state_dict_key,
 from forge.actors.trainer import TitanTrainer
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import init_provisioner, shutdown
-from forge.data.rewards import MathReward, ThinkingReward
+from forge.data.rewards import LanguageReward, MathReward, ThinkingReward
 from forge.data_models.completion import Completion
 from forge.observability.metric_actors import get_or_create_metric_logger
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.types import LauncherConfig, ProvisionerConfig
 from forge.util.config import parse
+from forge.util.logging import get_logger
 from forge.util.ops import compute_logprobs
 from monarch.actor import endpoint
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
+logger = get_logger("INFO")
+
 
 @dataclass
 class Episode:
@@ -46,10 +50,13 @@ class Episode:
     request_len: int
     response_len: int
     target: Any | None = None
+    request: str | None = None
+    response: str | None = None
     # Processed data
     completion: Completion | None = None
     ref_logprobs: torch.Tensor | None = None
     reward: float | None = None
+    reward_breakdown: dict[str, float] | None = None
     advantage: float | None = None
 
     @property
@@ -72,6 +79,32 @@ def response_tensor(self) -> torch.Tensor:
         tensor = F.pad(tensor, (0, diff), value=self.pad_id)
         return tensor
 
+    def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
+        """Convert episode to dict, optionally excluding specified fields."""
+        result = {
+            "episode_id": self.episode_id,
+            "policy_version": self.policy_version,
+            "prompt": self.request,
+            "response": self.response,
+            "target": str(self.target),
+            "reward": self.reward,
+            "advantage": self.advantage,
+            "request_len": self.request_len,
+            "response_len": self.response_len,
+            "pad_id": self.pad_id,
+            "ref_logprobs": self.ref_logprobs,
+            "completion": self.completion,
+        }
+
+        # Flatten per-function rewards into top-level keys (guard against exclude=None).
+        if self.reward_breakdown is not None and "reward_breakdown" not in (exclude or []):
+            result.update(self.reward_breakdown)
+
+        if exclude:
+            for key in exclude:
+                result.pop(key, None)
+
+        return result
+
 
 # Represents the group (G) of episodes in GRPO
 Group = list[Episode]
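
A quick illustration of the flattening behavior above (hypothetical values; the per-function keys come from the reward functions' class names):

# Hypothetical Episode.to_dict() output shape; values are made up.
episode.reward_breakdown = {"MathReward": 1.0, "ThinkingReward": 0.5}
row = episode.to_dict(exclude=["ref_logprobs", "completion"])
# row now has top-level "MathReward" and "ThinkingReward" entries alongside
# "prompt", "response", "reward", "advantage", etc., while the excluded
# tensor/completion fields are dropped.
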
@@ -129,7 +162,7 @@ def simple_grpo_loss(
     ref_logprobs: torch.Tensor,
     advantages: torch.Tensor,
     padding_mask: torch.Tensor,
-    beta: float = 0.1,
+    beta: float = 1e-6,
 ) -> torch.Tensor:
     logprobs: torch.Tensor = compute_logprobs(logits, response)
     kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
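
Two notes on the loss above. The KL term is the non-negative k3 estimator exp(r) - r - 1 with r = ref_logprobs - logprobs, whose expectation under the current policy is KL(policy || ref). Lowering the default beta from 0.1 to 1e-6 makes this penalty effectively negligible. A small self-contained sketch (made-up tensors) of how beta scales that term:

import torch

# Standalone sketch of the k3 KL estimator and the effect of beta (made-up values).
logprobs = torch.tensor([-1.2, -0.8])      # current policy log-probs
ref_logprobs = torch.tensor([-1.0, -1.0])  # reference policy log-probs

r = ref_logprobs - logprobs
kl = torch.exp(r) - r - 1                  # >= 0 elementwise, estimates KL(policy || ref)

print(0.1 * kl.mean())   # old default beta: a visible penalty
print(1e-6 * kl.mean())  # new default beta: penalty is effectively off
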
@@ -166,8 +199,11 @@ class RewardActor(ForgeActor):
     reward_functions: list[Callable]
 
     @endpoint
-    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
+    async def evaluate_response(
+        self, prompt: str, response: str, target: str
+    ) -> tuple[dict[str, float], float]:
         total_rewards = 0.0
+        reward_breakdown = {}  # reward breakdown by function
         for reward_fn in self.reward_functions:
             reward = reward_fn(prompt, response, target)
             total_rewards += reward
@@ -176,6 +212,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
             reward_fn_name = getattr(
                 reward_fn, "__name__", reward_fn.__class__.__name__
             )
+            reward_breakdown[reward_fn_name] = reward
             # per function reward
             record_metric(
                 f"reward/evaluate_response/sum_{reward_fn_name}_reward",
@@ -205,8 +242,8 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
                 Reduce.SUM,
             )
 
-        avg_reward = total_rewards / len(self.reward_functions)
-        return avg_reward
+        avg_reward: float = total_rewards / len(self.reward_functions)
+        return reward_breakdown, avg_reward
 
 
 @dataclass
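
With this change, callers receive a (breakdown, average) pair, where the breakdown is keyed by each reward function's name (via __name__, falling back to the class name). A hypothetical example for the three rewards configured later in this file (values made up):

# Hypothetical return value of evaluate_response for the rewards configured below.
reward_breakdown = {"MathReward": 1.0, "ThinkingReward": 0.5, "LanguageReward": 2.0}
avg_reward = sum(reward_breakdown.values()) / len(reward_breakdown)  # (1.0 + 0.5 + 2.0) / 3
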
@@ -237,10 +274,15 @@ async def setup(self):
         self._epoch = 0
 
         def gsm8k_transform(sample):
-            system_prompt = """
-            Put all your scratchpad work between <think> and </think> tags.
-            Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.
-            """
+            system_prompt = """You are a helpful AI assistant that solves math problems.
+
+            Please show your reasoning inside <思考></思考> tags, then provide your final numerical answer inside <answer></answer> tags.
+
+            Example:
+            Question: What is 12 + 5?
+            <思考>12と5を足します。12 + 5 = 17です。</思考>
+            <answer>17</answer>
+            """
             request: str = sample["question"]
             as_chat = [
                 {"role": "system", "content": system_prompt},
@@ -320,9 +362,14 @@ async def drop_weights(version: int):
 
 async def main(cfg: DictConfig):
     """Main GRPO training loop with rollout and training processes."""
-    group_size = cfg.group_size
-    max_req_tokens = cfg.max_req_tokens
-    max_res_tokens = cfg.max_res_tokens
+    # Convert OmegaConf config to plain dict
+    run_config_for_logging = OmegaConf.to_container(cfg, resolve=True)
+
+    # Log config
+    logger.info("=" * 30 + " CONFIGURATION " + "=" * 30)
+    logger.info(
+        yaml.dump(run_config_for_logging, default_flow_style=False, sort_keys=False)
+    )
 
     # ---- Global setups ---- #
     provisioner = None
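
OmegaConf.to_container(cfg, resolve=True) materializes the config into plain dicts/lists with interpolations resolved, so yaml.dump can render it. A minimal standalone example (keys chosen for illustration):

import yaml
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"group_size": 8, "max_req_tokens": "${max_res_tokens}", "max_res_tokens": 512}
)
plain = OmegaConf.to_container(cfg, resolve=True)  # interpolation resolved to 512
print(yaml.dump(plain, default_flow_style=False, sort_keys=False))
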
@@ -334,8 +381,11 @@ async def main(cfg: DictConfig):
         provisioner = await init_provisioner()
 
     metric_logging_cfg = cfg.get("metric_logging", {})
+
     mlogger = await get_or_create_metric_logger(process_name="Controller")
-    await mlogger.init_backends.call_one(metric_logging_cfg)
+    await mlogger.init_backends.call_one(
+        backend_config=metric_logging_cfg, run_config=run_config_for_logging
+    )
 
     # ---- Setup services ---- #
 
@@ -359,10 +409,24 @@ async def main(cfg: DictConfig):
         ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
         ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
         RewardActor.options(**cfg.services.reward_actor).as_service(
-            reward_functions=[MathReward(), ThinkingReward()]
+            reward_functions=[
+                MathReward(),
+                ThinkingReward(tag="思考"),  # use the Japanese <思考> tag
+                LanguageReward(
+                    target_language="ja",
+                    tag="思考",
+                    match_reward=2.0,
+                    debug=False,  # set to True for verbose logging
+                    debug_sample_rate=0.1,
+                ),  # Japanese-language reward (debug disabled)
+            ]
         ),
     )
 
+    group_size = cfg.group_size
+    max_req_tokens = cfg.max_req_tokens
+    max_res_tokens = cfg.max_res_tokens
+
     # Set max_steps to the configured value, or -1 if not specified or Null
     max_steps = cfg.trainer.training.steps or -1
 
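
LanguageReward is imported from forge.data.rewards and its implementation is not part of this diff; the constructor arguments above are the only interface shown. As a rough, hypothetical sketch of the idea (not the forge implementation), a reward with this shape could check whether the text inside the <思考> tag is mostly in the target language:

import re

class NaiveLanguageReward:
    """Hypothetical stand-in for LanguageReward, for illustration only."""

    def __init__(self, target_language="ja", tag="思考", match_reward=2.0,
                 debug=False, debug_sample_rate=0.1):
        # debug/debug_sample_rate are accepted but unused in this sketch.
        self.target_language = target_language
        self.tag = tag
        self.match_reward = match_reward

    def __call__(self, prompt: str, response: str, target: str) -> float:
        match = re.search(rf"<{self.tag}>(.*?)</{self.tag}>", response, re.DOTALL)
        if not match:
            return 0.0
        text = match.group(1)
        # Crude heuristic for "ja": fraction of hiragana/katakana characters.
        kana = sum("\u3040" <= ch <= "\u30ff" for ch in text)
        return self.match_reward if text and kana / len(text) > 0.2 else 0.0
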
@@ -413,9 +477,14 @@ async def continuous_rollouts():
                     request_len=max_req_tokens,
                     response_len=max_res_tokens,
                     target=target,
+                    request=prompt,
+                    response=response.text,
                     completion=response,
                 )
-                episode.reward = await reward_actor.evaluate_response.route(
+                (
+                    episode.reward_breakdown,
+                    episode.reward,
+                ) = await reward_actor.evaluate_response.route(
                     prompt=prompt, response=response.text, target=target
                 )
                 episodes.append(episode)
@@ -456,6 +525,14 @@ async def continuous_rollouts():
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)
 
+                sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
+                sample["score"] = sample["reward"]
+                record_metric(
+                    "main_samples/continuous_rollouts/sample_table",
+                    sample,
+                    Reduce.SAMPLE,
+                )
+
             rollout_count += 1
             record_metric(
                 "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM