File tree Expand file tree Collapse file tree 2 files changed +25
-15
lines changed
Expand file tree Collapse file tree 2 files changed +25
-15
lines changed Original file line number Diff line number Diff line change @@ -76,6 +76,30 @@ def response_tensor(self) -> torch.Tensor:
7676 tensor = F .pad (tensor , (0 , diff ), value = self .pad_id )
7777 return tensor
7878
79+ def to_dict (self , exclude : list [str ] | None = None ) -> dict [str , Any ]:
80+ """Convert episode to dict, optionally excluding specified fields."""
81+ result = {
82+ "episode_id" : self .episode_id ,
83+ "policy_version" : self .policy_version ,
84+ "prompt" : self .request ,
85+ "response" : self .response ,
86+ "target" : str (self .target ),
87+ "reward" : self .reward ,
88+ "advantage" : self .advantage ,
89+ "request_len" : self .request_len ,
90+ "response_len" : self .response_len ,
91+ "pad_id" : self .pad_id ,
92+ }
93+
94+ if self .reward_breakdown is not None :
95+ result .update (self .reward_breakdown )
96+
97+ if exclude :
98+ for key in exclude :
99+ result .pop (key , None )
100+
101+ return result
102+
79103
80104# Represents the group (G) of episodes in GRPO
81105Group = list [Episode ]
Original file line number Diff line number Diff line change @@ -194,21 +194,7 @@ def record_episode_sample(table_name: str, episode):
194194 table_name (str): logging prefix (e.g. "rollout/sample").
195195 episode (Episode): episode object with filled attributes.
196196 """
197- sample = {
198- "episode_id" : episode .episode_id ,
199- "policy_version" : episode .policy_version ,
200- "prompt" : episode .request ,
201- "response" : episode .response ,
202- "target" : str (episode .target ),
203- ** (
204- episode .reward_breakdown or {}
205- ), # per-fn breakdown including the average reward
206- "reward" : episode .reward ,
207- "advantage" : episode .advantage ,
208- "request_len" : episode .request_len ,
209- "response_len" : episode .response_len ,
210- "pad_id" : episode .pad_id ,
211- }
197+ sample = episode .to_dict ()
212198 record_metric (table_name , sample , Reduce .SAMPLE )
213199
214200
You can’t perform that action at this time.
0 commit comments