Skip to content

Commit 39ba6f9

Browse files
committed
add to_dict
1 parent 55c94e8 commit 39ba6f9

File tree

2 files changed

+25
-15
lines changed

2 files changed

+25
-15
lines changed

apps/grpo/main.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,30 @@ def response_tensor(self) -> torch.Tensor:
7878
tensor = F.pad(tensor, (0, diff), value=self.pad_id)
7979
return tensor
8080

81+
def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
82+
"""Convert episode to dict, optionally excluding specified fields."""
83+
result = {
84+
"episode_id": self.episode_id,
85+
"policy_version": self.policy_version,
86+
"prompt": self.request,
87+
"response": self.response,
88+
"target": str(self.target),
89+
"reward": self.reward,
90+
"advantage": self.advantage,
91+
"request_len": self.request_len,
92+
"response_len": self.response_len,
93+
"pad_id": self.pad_id,
94+
}
95+
96+
if self.reward_breakdown is not None:
97+
result.update(self.reward_breakdown)
98+
99+
if exclude:
100+
for key in exclude:
101+
result.pop(key, None)
102+
103+
return result
104+
81105

82106
# Represents the group (G) of episodes in GRPO
83107
Group = list[Episode]

src/forge/observability/metrics.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -194,21 +194,7 @@ def record_episode_sample(table_name: str, episode):
194194
table_name (str): logging prefix (e.g. "rollout/sample").
195195
episode (Episode): episode object with filled attributes.
196196
"""
197-
sample = {
198-
"episode_id": episode.episode_id,
199-
"policy_version": episode.policy_version,
200-
"prompt": episode.request,
201-
"response": episode.response,
202-
"target": str(episode.target),
203-
**(
204-
episode.reward_breakdown or {}
205-
), # per-fn breakdown including the average reward
206-
"reward": episode.reward,
207-
"advantage": episode.advantage,
208-
"request_len": episode.request_len,
209-
"response_len": episode.response_len,
210-
"pad_id": episode.pad_id,
211-
}
197+
sample = episode.to_dict()
212198
record_metric(table_name, sample, Reduce.SAMPLE)
213199

214200

0 commit comments

Comments
 (0)