
Commit 9234389

functions, tests
1 parent 37c2ac9 commit 9234389

File tree: 5 files changed, +203 additions, -32 deletions


apps/grpo/main.py: 12 additions, 4 deletions

@@ -29,7 +29,7 @@
 from forge.data.rewards import MathReward, ThinkingReward
 from forge.data_models.completion import Completion
 from forge.observability.metric_actors import get_or_create_metric_logger
-from forge.observability.metrics import record_metric, Reduce
+from forge.observability.metrics import record_episode_sample, record_metric, Reduce
 from forge.observability.perf_tracker import Tracer

 from forge.types import LauncherConfig, ProvisionerConfig

@@ -51,6 +51,7 @@ class Episode:
     completion: Completion | None = None
     ref_logprobs: torch.Tensor | None = None
     reward: float | None = None
+    reward_breakdown: dict[str, float] | None = None
     advantage: float | None = None

     @property

@@ -143,8 +144,11 @@ class RewardActor(ForgeActor):
     reward_functions: list[Callable]

     @endpoint
-    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
+    async def evaluate_response(
+        self, prompt: str, response: str, target: str
+    ) -> dict[str, float]:
         total_rewards = 0.0
+        reward_breakdown = {}  # reward breakdown by function
         for reward_fn in self.reward_functions:
             reward = reward_fn(prompt, response, target)
             total_rewards += reward

@@ -153,6 +157,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
             reward_fn_name = getattr(
                 reward_fn, "__name__", reward_fn.__class__.__name__
             )
+            reward_breakdown[reward_fn_name] = reward
             # per function reward
             record_metric(
                 f"reward/evaluate_response/sum_{reward_fn_name}_reward",

@@ -183,7 +188,8 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
             )

         avg_reward = total_rewards / len(self.reward_functions)
-        return avg_reward
+        reward_breakdown["reward"] = avg_reward
+        return reward_breakdown


 @dataclass

@@ -387,9 +393,10 @@ async def continuous_rollouts():
                     target=target,
                     completion=response,
                 )
-                episode.reward = await reward_actor.evaluate_response.route(
+                episode.reward_breakdown = await reward_actor.evaluate_response.route(
                     prompt=prompt, response=response.text, target=target
                 )
+                episode.reward = episode.reward_breakdown["reward"]
                 episodes.append(episode)

             # Build input_ids for reference logprobs

@@ -411,6 +418,7 @@ async def continuous_rollouts():
             for episode, advantage in zip(episodes, advantages):
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)
+                record_episode_sample("rollout/sample", episode)

             rollout_count += 1
             record_metric(
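A minimal, self-contained sketch of the dict-valued return that evaluate_response now produces; the reward functions below are hypothetical stand-ins for MathReward and ThinkingReward, not the actual RewardActor code:

def evaluate_response(prompt: str, response: str, target: str) -> dict[str, float]:
    # Hypothetical stand-ins for MathReward / ThinkingReward
    reward_functions = {
        "MathReward": lambda p, r, t: 1.0 if t in r else 0.0,
        "ThinkingReward": lambda p, r, t: 0.5 if "<think>" in r else 0.0,
    }
    reward_breakdown, total = {}, 0.0
    for name, fn in reward_functions.items():
        reward = fn(prompt, response, target)
        reward_breakdown[name] = reward
        total += reward
    # The average goes under the reserved "reward" key, as in the diff above
    reward_breakdown["reward"] = total / len(reward_functions)
    return reward_breakdown

scores = evaluate_response("What is 2 + 2?", "<think>...</think> 4", "4")
# scores == {"MathReward": 1.0, "ThinkingReward": 0.5, "reward": 0.75}

The rollout loop then stores the whole dict on episode.reward_breakdown and pulls the scalar average back out via the "reward" key.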

src/forge/observability/__init__.py: 2 additions, 0 deletions

@@ -21,6 +21,7 @@
     MetricAccumulator,
     MetricCollector,
     MinAccumulator,
+    record_episode_sample,
     record_metric,
     Reduce,
     reduce_metrics_states,

@@ -37,6 +38,7 @@
     # Main API functions
     "record_metric",
     "reduce_metrics_states",
+    "record_episode_sample",
     "get_logger_backend_class",
     "get_or_create_metric_logger",
     # Performance tracking

src/forge/observability/metric_actors.py: 12 additions, 1 deletion

@@ -432,9 +432,20 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
         # Reduce metrics from states
         reduced_metrics = reduce_metrics_states(all_local_states)

+        # Split into scalar metrics and sample metrics
+        scalar_metrics = [
+            m for m in reduced_metrics if m.reduction != Reduce.SAMPLE
+        ]
+        sample_metrics = {
+            m.key: m.value for m in reduced_metrics if m.reduction == Reduce.SAMPLE
+        }
+
         # Log to global backends
         for backend_name, backend in self.global_logger_backends.items():
-            await backend.log_batch(reduced_metrics, global_step)
+            if scalar_metrics:
+                await backend.log_batch(scalar_metrics, global_step)
+            if sample_metrics:
+                await backend.log_samples(sample_metrics, global_step)

     @endpoint
     async def has_fetcher(self, proc_id: str) -> bool:
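The scalar/sample split used above, reduced to a self-contained sketch; the Reduce and Metric definitions below are simplified stand-ins for the real ones in forge.observability.metrics:

from dataclasses import dataclass
from enum import Enum
from typing import Any


class Reduce(Enum):  # stand-in for forge.observability.metrics.Reduce
    MEAN = "mean"
    SAMPLE = "sample"


@dataclass
class Metric:  # stand-in for forge.observability.metrics.Metric
    key: str
    value: Any
    reduction: Reduce


reduced_metrics = [
    Metric("loss", 2.0, Reduce.MEAN),
    Metric("rollout/sample", [{"episode_id": 1, "reward": 0.5}], Reduce.SAMPLE),
]

# Scalars keep flowing through log_batch; SAMPLE metrics are routed to log_samples
scalar_metrics = [m for m in reduced_metrics if m.reduction != Reduce.SAMPLE]
sample_metrics = {m.key: m.value for m in reduced_metrics if m.reduction == Reduce.SAMPLE}

assert [m.key for m in scalar_metrics] == ["loss"]
assert sample_metrics == {"rollout/sample": [{"episode_id": 1, "reward": 0.5}]}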

src/forge/observability/metrics.py: 127 additions, 8 deletions

@@ -139,12 +139,32 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri
         list[Metric]: List of reduced metrics

     Example:
-        states = [
-            {"loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}},
-            {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}},
-        ]
-        reduce_metrics_states(states)
-        >>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)]
+        >>> states = [
+        ...     {
+        ...         "loss": {"count": 5, "sum": 14, "reduction_type": "mean"},
+        ...         "reward/sample": {
+        ...             "reduction_type": "sample",
+        ...             "samples": [{"episode_id": 1, "reward": 0.5}],
+        ...         },
+        ...     },
+        ...     {
+        ...         "loss": {"count": 10, "sum": 16, "reduction_type": "mean"},
+        ...         "reward/sample": {
+        ...             "reduction_type": "sample",
+        ...             "samples": [{"episode_id": 2, "reward": 1.0}],
+        ...         },
+        ...     },
+        ... ]
+        >>> metrics = reduce_metrics_states(states)
+        >>> for m in metrics:
+        ...     print(m)
+        Metric(key='loss', value=2.0, reduction=Reduce.MEAN)
+        Metric(
+            key='reward/sample',
+            value=[{'episode_id': 1, 'reward': 0.5},
+                   {'episode_id': 2, 'reward': 1.0}],
+            reduction=Reduce.SAMPLE,
+        )

     Raises:
         ValueError: on mismatched reduction types for the same metric key.
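As the new docstring and the updated tests imply, reducing SAMPLE states amounts to concatenating each state's "samples" list, while scalar states are still combined numerically. A minimal sketch of that concatenation (an illustration, not the library implementation):

def concat_sample_states(states, key):
    # Gather every state's "samples" list for the given key into one flat list
    merged = []
    for state in states:
        merged.extend(state[key]["samples"])
    return merged


states = [
    {"reward/sample": {"reduction_type": "sample", "samples": [{"episode_id": 1, "reward": 0.5}]}},
    {"reward/sample": {"reduction_type": "sample", "samples": [{"episode_id": 2, "reward": 1.0}]}},
]

assert concat_sample_states(states, "reward/sample") == [
    {"episode_id": 1, "reward": 0.5},
    {"episode_id": 2, "reward": 1.0},
]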
@@ -186,6 +206,31 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri
     return reduced_metrics


+def record_episode_sample(table_name: str, episode):
+    """
+    Record a structured sample-level log for a single episode.
+
+    Args:
+        table_name (str): logging prefix (e.g. "rollout/sample").
+        episode (Episode): episode object with filled attributes.
+    """
+    sample = {
+        "episode_id": episode.episode_id,
+        "policy_version": episode.policy_version,
+        "prompt": episode.request,
+        "response": episode.response,
+        "target": str(episode.target),
+        **(
+            episode.reward_breakdown or {}
+        ),  # per-fn breakdown including the average reward
+        "advantage": episode.advantage,
+        "request_len": episode.request_len,
+        "response_len": episode.response_len,
+        "pad_id": episode.pad_id,
+    }
+
+    record_metric(table_name, sample, Reduce.SAMPLE)
+
+
 #################
 # SampleFilters #
 #################
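A usage sketch for the new helper: the episode below is a hypothetical stand-in built with SimpleNamespace, carrying the attributes that the real Episode in apps/grpo/main.py fills in during the rollout loop:

from types import SimpleNamespace

from forge.observability import record_episode_sample

episode = SimpleNamespace(
    episode_id=1,
    policy_version=3,
    request="What is 2 + 2?",
    response="<think>...</think> 4",
    target="4",
    reward_breakdown={"MathReward": 1.0, "ThinkingReward": 0.5, "reward": 0.75},
    advantage=0.12,
    request_len=16,
    response_len=24,
    pad_id=0,
)

# Produces one flat row dict and records it under the "rollout/sample" table
# via record_metric(..., Reduce.SAMPLE)
record_episode_sample("rollout/sample", episode)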
@@ -656,7 +701,12 @@ def push(self, metric: Metric) -> None:

         # For PER_RANK_NO_REDUCE backends: stream without reduce
         for backend in self.per_rank_no_reduce_backends:
-            backend.log_stream(metric=metric, global_step=self.global_step)
+            if metric.reduction == Reduce.SAMPLE:
+                # Wrap singleton Metric into expected {key: [list_of_dicts]} format
+                sample = {metric.key: [metric.value]}
+                asyncio.create_task(backend.log_samples(sample, self.global_step))
+            else:
+                backend.log_stream(metric=metric, global_step=self.global_step)

         # Always accumulate for reduction and state return
         key = metric.key
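The wrapping step above turns a single SAMPLE Metric into the {table_name: [rows]} shape that log_samples expects downstream; a tiny sketch of that shape, using a SimpleNamespace stand-in rather than the library's Metric class:

from types import SimpleNamespace

# Stand-in for a Metric recorded via record_metric(..., Reduce.SAMPLE)
metric = SimpleNamespace(key="rollout/sample", value={"episode_id": 1, "reward": 0.5})

# Singleton wrap, matching the diff above: one key mapping to a one-row list
sample = {metric.key: [metric.value]}
assert sample == {"rollout/sample": [{"episode_id": 1, "reward": 0.5}]}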
@@ -711,8 +761,21 @@ async def flush(
         if self.per_rank_reduce_backends:
             metrics_for_backends = reduce_metrics_states([states])

+            # Split into scalar metrics and sample metrics
+            scalar_metrics = [
+                m for m in metrics_for_backends if m.reduction != Reduce.SAMPLE
+            ]
+            sample_metrics = {
+                m.key: m.value
+                for m in metrics_for_backends
+                if m.reduction == Reduce.SAMPLE
+            }
+
             for backend in self.per_rank_reduce_backends:
-                await backend.log_batch(metrics_for_backends, global_step)
+                if scalar_metrics:
+                    await backend.log_batch(scalar_metrics, global_step)
+                if sample_metrics:
+                    await backend.log_samples(sample_metrics, global_step)

         # Update step counter for streaming backends
         # Note: This is incremented AFTER flush completes, so metrics recorded between
@@ -846,6 +909,16 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
     async def finish(self) -> None:
         pass

+    async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
+        """Pretty-print sample-level logs to console."""
+        import json
+
+        logger.info(f"========== SAMPLE LOGS STEP {step} ==========")
+        for table_name, table_rows in samples.items():
+            logger.info(f"[{table_name}] ({len(table_rows)} samples)")
+            logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False))
+        logger.info("==============================================\n")
+

 class WandbBackend(LoggerBackend):
     """
@@ -882,6 +955,7 @@ def __init__(
         )
         self.run = None
         self.process_name = None
+        self._tables: dict[str, "wandb.Table"] = {}

     async def init(
         self,
@@ -992,13 +1066,58 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
         # note: here we dont use step since wandb keeps only the latest value for each step
         self.run.log(log_data)

+    async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
+        """Log sample-level data incrementally to persistent WandB Tables."""
+        import wandb
+
+        if not self.run:
+            return
+
+        for table_name, table_rows in samples.items():
+            if not table_rows:
+                continue
+
+            # If table doesn't exist yet, create it in INCREMENTAL mode
+            if table_name not in self._tables:
+                columns = list(table_rows[0].keys())
+                table = wandb.Table(columns=columns, log_mode="INCREMENTAL")
+                self._tables[table_name] = table
+                logger.info(
+                    f"WandbBackend: Created new incremental table: {table_name}"
+                )
+            else:
+                table = self._tables[table_name]
+
+            # Add rows (fill missing columns with None)
+            for s in table_rows:
+                values = [s.get(c) for c in table.columns]
+                table.add_data(*values)
+
+            # Log the same table object (INCREMENTAL update)
+            self.run.log({f"{table_name}_table": table})
+            logger.info(
+                f"WandbBackend: Appended {len(table_rows)} rows to incremental table '{table_name}' at step {step}"
+            )
+
     def get_metadata_for_secondary_ranks(self) -> dict[str, Any]:
         if self.run and self.per_rank_share_run:
             return {"shared_run_id": self.run.id}
         return {}

     async def finish(self) -> None:
+        import wandb
+
         if self.run:
+            # Convert each incremental table to immutable before finishing
+            for table_name, incr_table in self._tables.items():
+                final_table = wandb.Table(
+                    columns=incr_table.columns,
+                    data=incr_table.data,
+                    log_mode="IMMUTABLE",
+                )
+                self.run.log({table_name: final_table})
+                logger.info(f"WandbBackend: Finalized table {table_name}")
+
             self.run.finish()
             logger.info(f"WandbBackend {self.process_name}: Finished run")
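Outside of the backend plumbing, the incremental-table pattern used here boils down to the following standalone sketch. The project and table names are placeholders, and it assumes a wandb release that supports the Table log_mode argument:

import wandb

run = wandb.init(project="sample-logging-demo")  # placeholder project

# Create the table once in INCREMENTAL mode and keep logging the same object
table = wandb.Table(columns=["episode_id", "reward"], log_mode="INCREMENTAL")
for row in [{"episode_id": 1, "reward": 0.5}, {"episode_id": 2, "reward": 1.0}]:
    table.add_data(row["episode_id"], row["reward"])
    run.log({"rollout/sample_table": table})

# Before finishing, snapshot the accumulated rows into an IMMUTABLE table
final = wandb.Table(columns=table.columns, data=table.data, log_mode="IMMUTABLE")
run.log({"rollout/sample": final})
run.finish()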

tests/unit_tests/observability/test_metrics.py: 50 additions, 19 deletions

@@ -115,33 +115,64 @@ def test_empty_states(self):

     def test_single_state(self):
         """Test reduce_metrics_states with single state."""
-        states = [{"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}}]
-        result = reduce_metrics_states(states)
-        assert len(result) == 1
-        assert result[0].key == "loss"
-        assert result[0].value == 5.0
-        assert result[0].reduction == Reduce.MEAN
+        states = [
+            {
+                "loss": {"reduction_type": "mean", "sum": 10.0, "count": 2},
+                "rollout/sample": {
+                    "reduction_type": "sample",
+                    "samples": [{"id": 1, "reward": 0.5}],
+                },
+            }
+        ]
+        metrics = reduce_metrics_states(states)
+        assert len(metrics) == 2
+        # Convert to dict for easier testing
+        result_dict = {m.key: (m.value, m.reduction) for m in metrics}
+
+        assert result_dict["loss"][0] == 5.0
+        assert result_dict["loss"][1] == Reduce.MEAN
+
+        assert result_dict["rollout/sample"][0] == [{"id": 1, "reward": 0.5}]
+        assert result_dict["rollout/sample"][1] == Reduce.SAMPLE

     def test_multiple_states(self):
         """Test reduce_metrics_states with multiple states."""
         states = [
-            {"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}},
-            {"loss": {"reduction_type": "mean", "sum": 20.0, "count": 3}},
+            {
+                "loss": {"reduction_type": "mean", "sum": 10.0, "count": 2},
+                "rollout/sample": {
+                    "reduction_type": "sample",
+                    "samples": [{"id": 1, "reward": 0.5}],
+                },
+            },
+            {
+                "loss": {"reduction_type": "mean", "sum": 20.0, "count": 3},
+                "rollout/sample": {
+                    "reduction_type": "sample",
+                    "samples": [{"id": 2, "reward": 0.8}],
+                },
+            },
             {"accuracy": {"reduction_type": "sum", "total": 15.0}},
         ]
-        result = reduce_metrics_states(states)
+        metrics = reduce_metrics_states(states)
+
+        assert len(metrics) == 3

         # Convert to dict for easier testing
-        result_dict = {metric.key: metric.value for metric in result}
-        assert result_dict["loss"] == 30.0 / 5.0  # 6.0
-        assert result_dict["accuracy"] == 15.0
-
-        # Also check reduction types
-        for metric in result:
-            if metric.key == "loss":
-                assert metric.reduction == Reduce.MEAN
-            elif metric.key == "accuracy":
-                assert metric.reduction == Reduce.SUM
+        result_dict = {m.key: (m.value, m.reduction) for m in metrics}
+
+        # Check scalar reductions
+        assert result_dict["loss"][0] == 30.0 / 5.0  # 6.0
+        assert result_dict["loss"][1] == Reduce.MEAN
+        assert result_dict["accuracy"][0] == 15.0
+        assert result_dict["accuracy"][1] == Reduce.SUM
+
+        # Check sample concatenation
+        assert result_dict["rollout/sample"][0] == [
+            {"id": 1, "reward": 0.5},
+            {"id": 2, "reward": 0.8},
+        ]
+        assert result_dict["rollout/sample"][1] == Reduce.SAMPLE

     def test_mismatched_reduction_types_raises_error(self):
         """Test reduce_metrics_states raises error for mismatched reduction types."""
