Skip to content

Commit 0635c9a

Browse files
committed
add to_dict
1 parent b413d7b commit 0635c9a

File tree

2 files changed

+25
-15
lines changed

2 files changed

+25
-15
lines changed

apps/grpo/main.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,30 @@ def response_tensor(self) -> torch.Tensor:
7676
tensor = F.pad(tensor, (0, diff), value=self.pad_id)
7777
return tensor
7878

79+
def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
80+
"""Convert episode to dict, optionally excluding specified fields."""
81+
result = {
82+
"episode_id": self.episode_id,
83+
"policy_version": self.policy_version,
84+
"prompt": self.request,
85+
"response": self.response,
86+
"target": str(self.target),
87+
"reward": self.reward,
88+
"advantage": self.advantage,
89+
"request_len": self.request_len,
90+
"response_len": self.response_len,
91+
"pad_id": self.pad_id,
92+
}
93+
94+
if self.reward_breakdown is not None:
95+
result.update(self.reward_breakdown)
96+
97+
if exclude:
98+
for key in exclude:
99+
result.pop(key, None)
100+
101+
return result
102+
79103

80104
# Represents the group (G) of episodes in GRPO
81105
Group = list[Episode]

src/forge/observability/metrics.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -194,21 +194,7 @@ def record_episode_sample(table_name: str, episode):
194194
table_name (str): logging prefix (e.g. "rollout/sample").
195195
episode (Episode): episode object with filled attributes.
196196
"""
197-
sample = {
198-
"episode_id": episode.episode_id,
199-
"policy_version": episode.policy_version,
200-
"prompt": episode.request,
201-
"response": episode.response,
202-
"target": str(episode.target),
203-
**(
204-
episode.reward_breakdown or {}
205-
), # per-fn breakdown including the average reward
206-
"reward": episode.reward,
207-
"advantage": episode.advantage,
208-
"request_len": episode.request_len,
209-
"response_len": episode.response_len,
210-
"pad_id": episode.pad_id,
211-
}
197+
sample = episode.to_dict()
212198
record_metric(table_name, sample, Reduce.SAMPLE)
213199

214200

0 commit comments

Comments
 (0)