Skip to content

Commit 14f0a0f

Browse files
committed
resolve comment
1 parent 38326d7 commit 14f0a0f

File tree

3 files changed

+16
-19
lines changed

3 files changed

+16
-19
lines changed

apps/grpo/main.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from forge.data.rewards import MathReward, ThinkingReward
3030
from forge.data_models.completion import Completion
3131
from forge.observability.metric_actors import get_or_create_metric_logger
32-
from forge.observability.metrics import record_episode_sample, record_metric, Reduce
32+
from forge.observability.metrics import record_metric, Reduce
3333
from forge.observability.perf_tracker import Tracer
3434

3535
from forge.types import LauncherConfig, ProvisionerConfig
@@ -449,8 +449,13 @@ async def continuous_rollouts():
449449
for episode, advantage in zip(episodes, advantages):
450450
episode.advantage = advantage
451451
await replay_buffer.add.call_one(episode)
452-
record_episode_sample(
453-
"main_samples/continuous_rollouts/sample_table", episode
452+
453+
sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
454+
sample["score"] = sample["reward"]
455+
record_metric(
456+
"main_samples/continuous_rollouts/sample_table",
457+
sample,
458+
Reduce.SAMPLE,
454459
)
455460

456461
rollout_count += 1

src/forge/observability/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
MetricAccumulator,
2222
MetricCollector,
2323
MinAccumulator,
24-
record_episode_sample,
2524
record_metric,
2625
Reduce,
2726
reduce_metrics_states,
@@ -37,7 +36,6 @@
3736
# Main API functions
3837
"record_metric",
3938
"reduce_metrics_states",
40-
"record_episode_sample",
4139
"get_logger_backend_class",
4240
"get_or_create_metric_logger",
4341
# Performance tracking

src/forge/observability/metrics.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -201,17 +201,6 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri
201201
return reduced_metrics
202202

203203

204-
def record_episode_sample(table_name: str, episode):
205-
"""
206-
Record a structured sample-level log for a single episode.
207-
Args:
208-
table_name (str): logging prefix (e.g. "rollout/sample").
209-
episode (Episode): episode object with filled attributes.
210-
"""
211-
sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
212-
record_metric(table_name, sample, Reduce.SAMPLE)
213-
214-
215204
################
216205
# Accumulators #
217206
################
@@ -430,7 +419,7 @@ class SampleAccumulator(MetricAccumulator):
430419
"""
431420

432421
def __init__(
433-
self, reduction: Reduce, top_k: int = 1, bottom_k: int = 1, key: str = "reward"
422+
self, reduction: Reduce, top_k: int = 1, bottom_k: int = 1, key: str = "score"
434423
):
435424
super().__init__(reduction)
436425
self.samples: List[Dict[str, Any]] = []
@@ -869,12 +858,10 @@ async def finish(self) -> None:
869858
async def log_samples(self, samples: List[Metric], step: int) -> None:
870859
"""Pretty-print sample-level logs to console."""
871860

872-
logger.info(f"========== SAMPLE LOGS STEP {step} ==========")
873861
for sample in samples:
874862
table_name, table_rows = sample.key, sample.value
875863
logger.info(f"[{table_name}] ({len(table_rows)} samples)")
876864
logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False))
877-
logger.info("==============================================\n")
878865

879866

880867
class WandbBackend(LoggerBackend):
@@ -1056,6 +1043,13 @@ async def log_samples(self, samples: List[Metric], step: int) -> None:
10561043

10571044
# Add rows (fill missing columns with None)
10581045
for s in table_rows:
1046+
# Check for extra columns not in the table schema
1047+
extra_columns = set(s.keys()) - set(table.columns)
1048+
if extra_columns:
1049+
logger.warning(
1050+
f"WandbBackend: Row has extra columns not in table '{table_name}': {sorted(extra_columns)}. "
1051+
f"These will be ignored."
1052+
)
10591053
values = [s.get(c) for c in table.columns]
10601054
table.add_data(*values)
10611055

0 commit comments

Comments
 (0)