Skip to content

Commit 647b002

Browse files
committed
resolve comments
1 parent 3a99663 commit 647b002

File tree

1 file changed

+11
-20
lines changed

1 file changed

+11
-20
lines changed

src/forge/observability/metrics.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,6 @@ def record_episode_sample(key: str, episode):
143143
episode.reward_breakdown or {}
144144
), # per-fn breakdown including the average reward
145145
"advantage": episode.advantage,
146-
"ref_logprobs": float(
147-
episode.ref_logprobs.mean().item()
148-
if episode.ref_logprobs is not None
149-
else None
150-
),
151146
"request_len": episode.request_len,
152147
"response_len": episode.response_len,
153148
"pad_id": episode.pad_id,
@@ -850,16 +845,12 @@ def log_stream(self, metric: Metric, step: int, *args, **kwargs) -> None:
850845

851846
async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
852847
"""Pretty-print sample-level logs to console."""
853-
if not samples:
854-
return
855848
import json
856849

857850
logger.info(f"========== SAMPLE LOGS STEP {step} ==========")
858-
for key, rows in samples.items():
859-
logger.info(f"[{key}] ({len(rows)} samples)")
860-
for sample in rows:
861-
pretty = json.dumps(sample, indent=2, ensure_ascii=False)
862-
logger.info(pretty)
851+
for table_name, table_rows in samples.items():
852+
logger.info(f"[{table_name}] ({len(table_rows)} samples)")
853+
logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False))
863854
logger.info("==============================================\n")
864855

865856
async def finish(self) -> None:
@@ -999,24 +990,24 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
999990
if not self.run or not samples:
1000991
return
1001992

1002-
for key, rows in samples.items():
1003-
if not rows:
993+
for table_name, table_rows in samples.items():
994+
if not table_rows:
1004995
continue
1005996

1006997
# Use all keys to avoid dropped fields
1007-
columns = sorted({k for s in rows for k in s.keys()})
998+
columns = sorted({k for s in table_rows for k in s.keys()})
1008999
table = wandb.Table(columns=columns)
10091000

1010-
for s in rows:
1011-
values = [s.get(c) for c in columns]
1001+
for s in table_rows:
1002+
values = [s.get(c) for c in columns] # returns None for missing keys
10121003
table.add_data(*values)
10131004

10141005
# Unique table name avoids overwrite; commit forces sync
1015-
table_name = f"{key}_table_step{step}"
1016-
self.run.log({table_name: table, "_num_rows": len(rows)}, commit=True)
1006+
table_name = f"{table_name}_table_step{step}"
1007+
self.run.log({table_name: table, "_num_rows": len(table_rows)}, commit=True)
10171008

10181009
logger.info(
1019-
f"WandbBackend: Logged {len(rows)} samples for {key} at step {step}"
1010+
f"WandbBackend: Logged {len(table_rows)} samples for {table_name} at step {step}"
10201011
)
10211012

10221013
def get_metadata_for_secondary_ranks(self) -> Dict[str, Any]:

0 commit comments

Comments
 (0)