Skip to content

Commit 0c52ea5

Browse files
committed
log_samples takes a list of metrics
1 parent 5c67488 commit 0c52ea5

File tree

3 files changed

+12
-10
lines changed

3 files changed

+12
-10
lines changed

apps/grpo/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,11 @@ def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
9191
"request_len": self.request_len,
9292
"response_len": self.response_len,
9393
"pad_id": self.pad_id,
94+
"ref_logprobs": self.ref_logprobs,
95+
"completion": self.completion,
9496
}
9597

96-
if self.reward_breakdown is not None:
98+
if self.reward_breakdown is not None and "reward_breakdown" not in exclude:
9799
result.update(self.reward_breakdown)
98100

99101
if exclude:

src/forge/observability/metric_actors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
429429
m for m in reduced_metrics if m.reduction != Reduce.SAMPLE
430430
]
431431
sample_metrics = {
432-
m.key: m.value for m in reduced_metrics if m.reduction == Reduce.SAMPLE
432+
m for m in reduced_metrics if m.reduction == Reduce.SAMPLE
433433
}
434434

435435
# Log to global backends

src/forge/observability/metrics.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def record_episode_sample(table_name: str, episode):
195195
table_name (str): logging prefix (e.g. "rollout/sample").
196196
episode (Episode): episode object with filled attributes.
197197
"""
198-
sample = episode.to_dict()
198+
sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
199199
record_metric(table_name, sample, Reduce.SAMPLE)
200200

201201

@@ -675,9 +675,7 @@ def push(self, metric: Metric) -> None:
675675
for backend in self.per_rank_no_reduce_backends:
676676

677677
if metric.reduction == Reduce.SAMPLE:
678-
# Wrap singleton Metric into expected {key: [list_of_dicts]} format
679-
sample = {metric.key: [metric.value]}
680-
asyncio.create_task(backend.log_samples(sample, self.global_step))
678+
asyncio.create_task(backend.log_samples([metric], self.global_step))
681679
else:
682680
backend.log_stream(metric=metric, global_step=self.global_step)
683681

@@ -867,11 +865,12 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
867865
async def finish(self) -> None:
868866
pass
869867

870-
async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
868+
async def log_samples(self, samples: List[Metric], step: int) -> None:
871869
"""Pretty-print sample-level logs to console."""
872870

873871
logger.info(f"========== SAMPLE LOGS STEP {step} ==========")
874-
for table_name, table_rows in samples.items():
872+
for sample in samples:
873+
table_name, table_rows = sample.key, sample.value
875874
logger.info(f"[{table_name}] ({len(table_rows)} samples)")
876875
logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False))
877876
logger.info("==============================================\n")
@@ -1014,14 +1013,15 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
10141013
# note: here we don't use step since wandb keeps only the latest value for each step
10151014
self.run.log(log_data)
10161015

1017-
async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
1016+
async def log_samples(self, samples: List[Metric], step: int) -> None:
10181017
"""Log sample-level data incrementally to persistent WandB Tables."""
10191018
import wandb
10201019

10211020
if not self.run:
10221021
return
10231022

1024-
for table_name, table_rows in samples.items():
1023+
for sample in samples:
1024+
table_name, table_rows = sample.key, sample.value
10251025
if not table_rows:
10261026
continue
10271027

0 commit comments

Comments
 (0)