Skip to content

Commit 2b0496e

Browse files
committed
log_samples take list of metirc
1 parent 9d2a0cb commit 2b0496e

File tree

3 files changed

+12
-10
lines changed

3 files changed

+12
-10
lines changed

apps/grpo/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,11 @@ def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
8989
"request_len": self.request_len,
9090
"response_len": self.response_len,
9191
"pad_id": self.pad_id,
92+
"ref_logprobs": self.ref_logprobs,
93+
"completion": self.completion,
9294
}
9395

94-
if self.reward_breakdown is not None:
96+
if self.reward_breakdown is not None and "reward_breakdown" not in exclude:
9597
result.update(self.reward_breakdown)
9698

9799
if exclude:

src/forge/observability/metric_actors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
438438
m for m in reduced_metrics if m.reduction != Reduce.SAMPLE
439439
]
440440
sample_metrics = {
441-
m.key: m.value for m in reduced_metrics if m.reduction == Reduce.SAMPLE
441+
m for m in reduced_metrics if m.reduction == Reduce.SAMPLE
442442
}
443443

444444
# Log to global backends

src/forge/observability/metrics.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def record_episode_sample(table_name: str, episode):
195195
table_name (str): logging prefix (e.g. "rollout/sample").
196196
episode (Episode): episode object with filled attributes.
197197
"""
198-
sample = episode.to_dict()
198+
sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
199199
record_metric(table_name, sample, Reduce.SAMPLE)
200200

201201

@@ -675,9 +675,7 @@ def push(self, metric: Metric) -> None:
675675
for backend in self.per_rank_no_reduce_backends:
676676

677677
if metric.reduction == Reduce.SAMPLE:
678-
# Wrap singleton Metric into expected {key: [list_of_dicts]} format
679-
sample = {metric.key: [metric.value]}
680-
asyncio.create_task(backend.log_samples(sample, self.global_step))
678+
asyncio.create_task(backend.log_samples([metric], self.global_step))
681679
else:
682680
backend.log_stream(metric=metric, global_step=self.global_step)
683681

@@ -882,11 +880,12 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
882880
async def finish(self) -> None:
883881
pass
884882

885-
async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
883+
async def log_samples(self, samples: List[Metric], step: int) -> None:
886884
"""Pretty-print sample-level logs to console."""
887885

888886
logger.info(f"========== SAMPLE LOGS STEP {step} ==========")
889-
for table_name, table_rows in samples.items():
887+
for sample in samples:
888+
table_name, table_rows = sample.key, sample.value
890889
logger.info(f"[{table_name}] ({len(table_rows)} samples)")
891890
logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False))
892891
logger.info("==============================================\n")
@@ -1038,14 +1037,15 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
10381037
# note: here we dont use step since wandb keeps only the latest value for each step
10391038
self.run.log(log_data)
10401039

1041-
async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
1040+
async def log_samples(self, samples: List[Metric], step: int) -> None:
10421041
"""Log sample-level data incrementally to persistent WandB Tables."""
10431042
import wandb
10441043

10451044
if not self.run:
10461045
return
10471046

1048-
for table_name, table_rows in samples.items():
1047+
for sample in samples:
1048+
table_name, table_rows = sample.key, sample.value
10491049
if not table_rows:
10501050
continue
10511051

0 commit comments

Comments
 (0)