
Commit 47c2333

DNXie and Felipe Mello authored
Add Sample-level Logging API (#486)
Co-authored-by: Felipe Mello <[email protected]>
1 parent e926707 commit 47c2333

File tree

5 files changed: +321 −33 lines changed


apps/grpo/main.py

Lines changed: 50 additions & 4 deletions
@@ -46,10 +46,13 @@ class Episode:
     request_len: int
     response_len: int
     target: Any | None = None
+    request: str | None = None
+    response: str | None = None
     # Processed data
     completion: Completion | None = None
     ref_logprobs: torch.Tensor | None = None
     reward: float | None = None
+    reward_breakdown: dict[str, float] | None = None
     advantage: float | None = None

     @property
@@ -72,6 +75,32 @@ def response_tensor(self) -> torch.Tensor:
         tensor = F.pad(tensor, (0, diff), value=self.pad_id)
         return tensor

+    def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
+        """Convert episode to dict, optionally excluding specified fields."""
+        result = {
+            "episode_id": self.episode_id,
+            "policy_version": self.policy_version,
+            "prompt": self.request,
+            "response": self.response,
+            "target": str(self.target),
+            "reward": self.reward,
+            "advantage": self.advantage,
+            "request_len": self.request_len,
+            "response_len": self.response_len,
+            "pad_id": self.pad_id,
+            "ref_logprobs": self.ref_logprobs,
+            "completion": self.completion,
+        }
+
+        if self.reward_breakdown is not None and "reward_breakdown" not in exclude:
+            result.update(self.reward_breakdown)
+
+        if exclude:
+            for key in exclude:
+                result.pop(key, None)
+
+        return result
+

 # Represents the group (G) of episodes in GRPO
 Group = list[Episode]
@@ -166,8 +195,11 @@ class RewardActor(ForgeActor):
     reward_functions: list[Callable]

     @endpoint
-    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
+    async def evaluate_response(
+        self, prompt: str, response: str, target: str
+    ) -> (dict[str, float], float):
         total_rewards = 0.0
+        reward_breakdown = {}  # reward breakdown by function
         for reward_fn in self.reward_functions:
             reward = reward_fn(prompt, response, target)
             total_rewards += reward
@@ -176,6 +208,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
             reward_fn_name = getattr(
                 reward_fn, "__name__", reward_fn.__class__.__name__
             )
+            reward_breakdown[reward_fn_name] = reward
             # per function reward
             record_metric(
                 f"reward/evaluate_response/sum_{reward_fn_name}_reward",
@@ -205,8 +238,8 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
             Reduce.SUM,
         )

-        avg_reward = total_rewards / len(self.reward_functions)
-        return avg_reward
+        avg_reward: float = total_rewards / len(self.reward_functions)
+        return reward_breakdown, avg_reward


 @dataclass
@@ -428,9 +461,14 @@ async def continuous_rollouts():
                 request_len=max_req_tokens,
                 response_len=max_res_tokens,
                 target=target,
+                request=prompt,
+                response=response.text,
                 completion=response,
             )
-            episode.reward = await reward_actor.evaluate_response.route(
+            (
+                episode.reward_breakdown,
+                episode.reward,
+            ) = await reward_actor.evaluate_response.route(
                 prompt=prompt, response=response.text, target=target
             )
             episodes.append(episode)
@@ -471,6 +509,14 @@ async def continuous_rollouts():
             episode.advantage = advantage
             await replay_buffer.add.call_one(episode)

+            sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
+            sample["score"] = sample["reward"]
+            record_metric(
+                "main_samples/continuous_rollouts/sample_table",
+                sample,
+                Reduce.SAMPLE,
+            )
+
         rollout_count += 1
         record_metric(
             "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM

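Taken together, the main.py changes attach the raw prompt and response text to each Episode, have RewardActor return a per-function reward breakdown alongside the averaged reward, and emit one sample row per episode via Reduce.SAMPLE. The following self-contained sketch mirrors that flow outside the trainer; the reward functions and concrete values are invented for illustration, and only record_metric, Reduce, the breakdown/averaging shape, and the flattened sample dict come from the diff above.

# Illustrative sketch, not part of the commit: mimics RewardActor.evaluate_response
# and the sample-table logging in continuous_rollouts.
from forge.observability import record_metric, Reduce


def math_reward(prompt: str, response: str, target: str) -> float:
    # Hypothetical reward function: full credit if the target appears in the response.
    return 1.0 if target in response else 0.0


def format_reward(prompt: str, response: str, target: str) -> float:
    # Hypothetical reward function: partial credit for any non-empty response.
    return 0.5 if response.strip() else 0.0


reward_breakdown: dict[str, float] = {}
total_rewards = 0.0
for reward_fn in (math_reward, format_reward):
    reward = reward_fn("2+2=?", "The answer is 4", "4")
    reward_breakdown[reward_fn.__name__] = reward  # keyed by function name, as in RewardActor
    total_rewards += reward
avg_reward = total_rewards / 2  # -> 0.75

# One sample row per episode: breakdown keys are flattened to the top level
# (as Episode.to_dict does), heavy fields are excluded, and "score" is added.
sample = {
    "prompt": "2+2=?",
    "response": "The answer is 4",
    "reward": avg_reward,
    **reward_breakdown,
    "score": avg_reward,
}
record_metric("main_samples/continuous_rollouts/sample_table", sample, Reduce.SAMPLE)
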
src/forge/observability/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@
     record_metric,
     Reduce,
     reduce_metrics_states,
+    SampleAccumulator,
     StdAccumulator,
     SumAccumulator,
     WandbBackend,
@@ -64,4 +65,5 @@
     "MaxAccumulator",
     "MinAccumulator",
     "StdAccumulator",
+    "SampleAccumulator",
 ]
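
The __init__.py change simply re-exports the new accumulator next to the existing ones, so downstream code can reach everything through forge.observability. A minimal import sketch follows; the metric key and payload are hypothetical, and the assumption (not shown in this diff) is that SampleAccumulator is selected internally when a metric is recorded with Reduce.SAMPLE rather than instantiated directly by callers.

# Public surface touched by this commit; the record_metric call is illustrative.
from forge.observability import Reduce, SampleAccumulator, record_metric

record_metric(
    "debug/sample_table",  # hypothetical metric key
    {"prompt": "2+2=?", "response": "4", "reward": 1.0},  # one sample row
    Reduce.SAMPLE,
)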

src/forge/observability/metric_actors.py

Lines changed: 13 additions & 1 deletion
@@ -18,6 +18,7 @@
     LoggerBackend,
     LoggingMode,
     MetricCollector,
+    Reduce,
     reduce_metrics_states,
 )

@@ -432,9 +433,20 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
         # Reduce metrics from states
         reduced_metrics = reduce_metrics_states(all_local_states)

+        # Split into scalar metrics and sample metrics
+        scalar_metrics = [
+            m for m in reduced_metrics if m.reduction != Reduce.SAMPLE
+        ]
+        sample_metrics = [
+            m for m in reduced_metrics if m.reduction == Reduce.SAMPLE
+        ]
+
         # Log to global backends
         for backend_name, backend in self.global_logger_backends.items():
-            await backend.log_batch(reduced_metrics, global_step)
+            if scalar_metrics:
+                await backend.log_batch(scalar_metrics, global_step)
+            if sample_metrics:
+                await backend.log_samples(sample_metrics, global_step)

     @endpoint
     async def has_fetcher(self, proc_id: str) -> bool:
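
Splitting reduced metrics by Reduce.SAMPLE means a global logger backend is now expected to expose a log_samples hook alongside log_batch. A rough sketch of a compatible backend follows; only the method names and the (metrics, global_step) call shape come from the diff, while the class itself, the metric attributes accessed (key, value), and the printing behavior are assumptions for illustration.

# Hypothetical backend compatible with the calls above; not part of the commit.
from typing import Any


class ConsoleSampleBackend:
    async def log_batch(self, metrics: list[Any], global_step: int) -> None:
        # Scalar metrics: one reduced value per key (attribute names assumed).
        for m in metrics:
            print(f"[step {global_step}] {m.key} = {m.value}")

    async def log_samples(self, metrics: list[Any], global_step: int) -> None:
        # Sample metrics: each entry carries a list of per-sample dicts (assumed).
        for m in metrics:
            print(f"[step {global_step}] {m.key}: {len(m.value)} sample rows")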
