2 files changed, +25 -15
lines changed Original file line number Diff line number Diff line change @@ -78,6 +78,30 @@ def response_tensor(self) -> torch.Tensor:
7878 tensor = F .pad (tensor , (0 , diff ), value = self .pad_id )
7979 return tensor
8080
81+ def to_dict (self , exclude : list [str ] | None = None ) -> dict [str , Any ]:
82+ """Convert episode to dict, optionally excluding specified fields."""
83+ result = {
84+ "episode_id" : self .episode_id ,
85+ "policy_version" : self .policy_version ,
86+ "prompt" : self .request ,
87+ "response" : self .response ,
88+ "target" : str (self .target ),
89+ "reward" : self .reward ,
90+ "advantage" : self .advantage ,
91+ "request_len" : self .request_len ,
92+ "response_len" : self .response_len ,
93+ "pad_id" : self .pad_id ,
94+ }
95+
96+ if self .reward_breakdown is not None :
97+ result .update (self .reward_breakdown )
98+
99+ if exclude :
100+ for key in exclude :
101+ result .pop (key , None )
102+
103+ return result
104+
81105
82106# Represents the group (G) of episodes in GRPO
83107Group = list [Episode ]
Original file line number Diff line number Diff line change @@ -194,21 +194,7 @@ def record_episode_sample(table_name: str, episode):
194194 table_name (str): logging prefix (e.g. "rollout/sample").
195195 episode (Episode): episode object with filled attributes.
196196 """
197- sample = {
198- "episode_id" : episode .episode_id ,
199- "policy_version" : episode .policy_version ,
200- "prompt" : episode .request ,
201- "response" : episode .response ,
202- "target" : str (episode .target ),
203- ** (
204- episode .reward_breakdown or {}
205- ), # per-fn breakdown including the average reward
206- "reward" : episode .reward ,
207- "advantage" : episode .advantage ,
208- "request_len" : episode .request_len ,
209- "response_len" : episode .response_len ,
210- "pad_id" : episode .pad_id ,
211- }
197+ sample = episode .to_dict ()
212198 record_metric (table_name , sample , Reduce .SAMPLE )
213199
214200
You can’t perform that action at this time.
0 commit comments