@@ -46,10 +46,13 @@ class Episode:
     request_len: int
     response_len: int
     target: Any | None = None
+    request: str | None = None
+    response: str | None = None
     # Processed data
     completion: Completion | None = None
     ref_logprobs: torch.Tensor | None = None
     reward: float | None = None
+    reward_breakdown: dict[str, float] | None = None
     advantage: float | None = None

     @property
@@ -72,6 +75,32 @@ def response_tensor(self) -> torch.Tensor:
             tensor = F.pad(tensor, (0, diff), value=self.pad_id)
         return tensor

+    def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
+        """Convert episode to dict, optionally excluding specified fields."""
+        result = {
+            "episode_id": self.episode_id,
+            "policy_version": self.policy_version,
+            "prompt": self.request,
+            "response": self.response,
+            "target": str(self.target),
+            "reward": self.reward,
+            "advantage": self.advantage,
+            "request_len": self.request_len,
+            "response_len": self.response_len,
+            "pad_id": self.pad_id,
+            "ref_logprobs": self.ref_logprobs,
+            "completion": self.completion,
+        }
+
+        exclude = exclude or []
+        if self.reward_breakdown is not None and "reward_breakdown" not in exclude:
+            result.update(self.reward_breakdown)
+
+        for key in exclude:
+            result.pop(key, None)
+
+        return result
+

 # Represents the group (G) of episodes in GRPO
 Group = list[Episode]
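
A rough usage sketch of the new to_dict helper (not part of the diff). It assumes the module's Episode class with episode_id, pad_id, policy_version, request_len, and response_len as the remaining required fields, and uses placeholder values; the reward-function names in reward_breakdown are hypothetical:

episode = Episode(
    episode_id="ep-0",
    pad_id=0,
    policy_version=3,
    request_len=512,
    response_len=512,
    target="42",
    request="What is 6 * 7?",
    response="The answer is 42.",
)
episode.reward = 0.75
episode.reward_breakdown = {"MathReward": 1.0, "ThinkingReward": 0.5}

# Excluded keys are dropped; reward_breakdown entries are merged in as
# top-level keys ("MathReward", "ThinkingReward").
sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
assert sample["reward"] == 0.75 and sample["MathReward"] == 1.0
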
@@ -166,8 +195,11 @@ class RewardActor(ForgeActor):
     reward_functions: list[Callable]

     @endpoint
-    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
+    async def evaluate_response(
+        self, prompt: str, response: str, target: str
+    ) -> tuple[dict[str, float], float]:
         total_rewards = 0.0
+        reward_breakdown = {}  # reward breakdown by function
         for reward_fn in self.reward_functions:
             reward = reward_fn(prompt, response, target)
             total_rewards += reward
@@ -176,6 +208,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
             reward_fn_name = getattr(
                 reward_fn, "__name__", reward_fn.__class__.__name__
             )
+            reward_breakdown[reward_fn_name] = reward
             # per function reward
             record_metric(
                 f"reward/evaluate_response/sum_{reward_fn_name}_reward",
@@ -205,8 +238,8 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
                 Reduce.SUM,
             )

-        avg_reward = total_rewards / len(self.reward_functions)
-        return avg_reward
+        avg_reward: float = total_rewards / len(self.reward_functions)
+        return reward_breakdown, avg_reward


 @dataclass
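
For orientation (not part of the diff), a self-contained sketch that mirrors the new breakdown-plus-average return shape of evaluate_response, without the actor, routing, or metric plumbing. Both my_math_reward and evaluate_locally are hypothetical names; my_math_reward follows the (prompt, response, target) -> float signature the loop above assumes:

def my_math_reward(prompt: str, response: str, target: str) -> float:
    # Hypothetical scorer: full credit if the target appears verbatim.
    return 1.0 if target in response else 0.0

def evaluate_locally(
    reward_functions, prompt: str, response: str, target: str
) -> tuple[dict[str, float], float]:
    # Per-function scores keyed by function name, plus their average.
    breakdown = {fn.__name__: fn(prompt, response, target) for fn in reward_functions}
    return breakdown, sum(breakdown.values()) / len(breakdown)

breakdown, avg_reward = evaluate_locally([my_math_reward], "What is 6 * 7?", "42", "42")
# breakdown == {"my_math_reward": 1.0}, avg_reward == 1.0
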
@@ -428,9 +461,14 @@ async def continuous_rollouts():
                     request_len=max_req_tokens,
                     response_len=max_res_tokens,
                     target=target,
+                    request=prompt,
+                    response=response.text,
                     completion=response,
                 )
-                episode.reward = await reward_actor.evaluate_response.route(
+                (
+                    episode.reward_breakdown,
+                    episode.reward,
+                ) = await reward_actor.evaluate_response.route(
                     prompt=prompt, response=response.text, target=target
                 )
                 episodes.append(episode)
@@ -471,6 +509,14 @@ async def continuous_rollouts():
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)

+                sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
+                sample["score"] = sample["reward"]
+                record_metric(
+                    "main_samples/continuous_rollouts/sample_table",
+                    sample,
+                    Reduce.SAMPLE,
+                )
+
             rollout_count += 1
             record_metric(
                 "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM