
Commit 1171f2e

integrate sampling
1 parent 58ce06b commit 1171f2e

2 files changed, +85 -5 lines changed


apps/grpo/main.py

Lines changed: 13 additions & 4 deletions
@@ -30,7 +30,7 @@
 from forge.controller.provisioner import shutdown
 from forge.data.rewards import MathReward, ThinkingReward
 from forge.observability.metric_actors import get_or_create_metric_logger
-from forge.observability.metrics import record_metric, Reduce
+from forge.observability.metrics import record_episode_sample, record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.util.ops import compute_logprobs
 from monarch.actor import endpoint
@@ -54,6 +54,7 @@ class Episode:
     response_tokens: list[int] | None = None
     ref_logprobs: torch.Tensor | None = None
     reward: float | None = None
+    reward_breakdown: dict[str, float] | None = None
     advantage: float | None = None

     @property
@@ -168,8 +169,11 @@ class RewardActor(ForgeActor):
     reward_functions: list[Callable]

     @endpoint
-    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
+    async def evaluate_response(
+        self, prompt: str, response: str, target: str
+    ) -> dict[str, float]:
         total_rewards = 0.0
+        reward_breakdown = {}  # reward breakdown by function
         for reward_fn in self.reward_functions:
             reward = reward_fn(prompt, response, target)
             total_rewards += reward
@@ -178,6 +182,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
             reward_fn_name = getattr(
                 reward_fn, "__name__", reward_fn.__class__.__name__
             )
+            reward_breakdown[reward_fn_name] = reward
             # per function reward
             record_metric(
                 f"reward/evaluate_response/sum_{reward_fn_name}_reward",
@@ -210,7 +215,8 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
             )

         avg_reward = total_rewards / len(self.reward_functions)
-        return avg_reward
+        reward_breakdown["reward"] = avg_reward
+        return reward_breakdown


 @dataclass
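
With this change, evaluate_response returns a single mapping that carries each reward function's score plus the averaged "reward" entry consumed downstream. Roughly, for the MathReward and ThinkingReward functions imported above (values illustrative):

# Illustrative shape of the dict returned by evaluate_response; keys come from
# each reward function's __name__ or class name, plus the averaged "reward".
reward_breakdown = {
    "MathReward": 0.8,
    "ThinkingReward": 1.0,
    "reward": 0.9,  # average over all reward functions
}
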
@@ -395,9 +401,10 @@ async def continuous_rollouts():
             episode.response = response.text
             input_ids[i, :max_req_tokens] = episode.request_tensor
             input_ids[i, max_req_tokens:] = episode.response_tensor
-            episode.reward = await reward_actor.evaluate_response.route(
+            episode.reward_breakdown = await reward_actor.evaluate_response.route(
                 prompt=prompt, response=response.text, target=target
             )
+            episode.reward = episode.reward_breakdown["reward"]

         t.step("reward_evaluation")

@@ -416,7 +423,9 @@ async def continuous_rollouts():
         for episode, advantage in zip(group.episodes, advantages):
             episode.advantage = advantage
             await replay_buffer.add.call_one(episode)
+            record_episode_sample("rollout/sample", episode)

+        record_metric("sample/", {}, Reduce.SAMPLE)
         # Log metrics
         rollout_count += 1
         record_metric(

src/forge/observability/metrics.py

Lines changed: 72 additions & 1 deletion
@@ -113,6 +113,38 @@ def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None
     collector.push(key, value, reduction)


+def record_episode_sample(key: str, episode):
+    """
+    Record a structured sample-level log for a single episode.
+
+    Args:
+        key (str): logging prefix (e.g. "rollout/sample").
+        episode (Episode): episode object with filled attributes, including the
+            per-function reward_breakdown, e.g. {"MathReward": 0.8, "FormatReward": 1.0}.
+    """
+    sample = {
+        "episode_id": episode.episode_id,
+        "policy_version": episode.policy_version,
+        "prompt": episode.request,
+        "response": episode.response,
+        "target": episode.target,
+        **(
+            episode.reward_breakdown or {}
+        ),  # per-fn breakdown including the average reward
+        "advantage": episode.advantage,
+        "ref_logprobs": (
+            episode.ref_logprobs.mean().item()
+            if episode.ref_logprobs is not None
+            else None
+        ),
+        "request_len": episode.request_len,
+        "response_len": episode.response_len,
+        "pad_id": episode.pad_id,
+    }
+
+    record_metric(key, sample, Reduce.SAMPLE)
+
+
 def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, Any]:
     """Reduce metric accumulators states to a single value per metric.
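
The pushed sample is a flat dict, so it can land directly in a console listing or a WandB Table row. A minimal sketch of what record_episode_sample reads and produces; the SimpleNamespace below is a stand-in that only mimics the Episode attributes accessed above, not the real Episode class from apps/grpo/main.py:

# Stand-in episode with the fields record_episode_sample reads (illustrative values).
from types import SimpleNamespace

import torch

episode = SimpleNamespace(
    episode_id="ep-0",
    policy_version=3,
    request="What is 2 + 2?",
    response="<think>...</think> 4",
    target="4",
    reward_breakdown={"MathReward": 1.0, "ThinkingReward": 0.5, "reward": 0.75},
    advantage=0.12,
    ref_logprobs=torch.tensor([-0.3, -0.7]),
    request_len=512,
    response_len=512,
    pad_id=0,
)

# record_episode_sample("rollout/sample", episode) would then push roughly:
# {"episode_id": "ep-0", "policy_version": 3, "prompt": "What is 2 + 2?", ...,
#  "MathReward": 1.0, "ThinkingReward": 0.5, "reward": 0.75, "advantage": 0.12,
#  "ref_logprobs": -0.5, "request_len": 512, "response_len": 512, "pad_id": 0}
# under the SAMPLE reduction.
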
@@ -465,7 +497,9 @@ class SampleAccumulator(MetricAccumulator):
     Optionally uses a SampleFilter to decide what to keep at append/flush time.
     """

-    def __init__(self, reduction: Reduce, filter: SampleFilter | None = None):
+    def __init__(
+        self, reduction: Reduce, filter: SampleFilter | None = TopBottomKFilter()
+    ):
         super().__init__(reduction)
         self.samples: List[Dict[str, Any]] = []
         self.filter = filter
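
Note that a default like filter=TopBottomKFilter() is evaluated once, when __init__ is defined, so every SampleAccumulator constructed without an explicit filter shares the same filter instance; harmless if the filter is stateless, surprising if it ever buffers state. A generic sketch of this Python behaviour, using hypothetical stand-in classes rather than the real SampleFilter/TopBottomKFilter:

# Hypothetical stand-ins, not forge classes; demonstrates default-argument sharing.
class StatefulFilter:
    def __init__(self):
        self.seen = 0  # per-instance state


class Accumulator:
    def __init__(self, filter=StatefulFilter()):  # default object created once
        self.filter = filter


a, b = Accumulator(), Accumulator()
assert a.filter is b.filter  # both accumulators share the single default filter
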
@@ -598,6 +632,7 @@ def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None:
             raise ValueError("Collector not initialized—call init first")

         if key not in self.accumulators:
+            # TODO: make sample filter configurable
             self.accumulators[key] = reduction.accumulator_class(reduction)

         self.accumulators[key].append(value)
@@ -724,6 +759,20 @@ async def log(self, metrics: Dict[str, Any], step: int) -> None:
             logger.info(f" {key}: {value}")
         logger.info("==============================\n")

+    async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
+        """Pretty-print sample-level logs to console."""
+        if not samples:
+            return
+        import pprint
+
+        logger.info(f"=== [{self.prefix}] - SAMPLE LOGS STEP {step} ===")
+        for key, rows in samples.items():
+            logger.info(f"[{key}] ({len(rows)} samples)")
+            for sample in rows:
+                pretty = pprint.pformat(sample, indent=4, width=120, compact=True)
+                logger.info(pretty)
+        logger.info("==============================================\n")
+
     async def finish(self) -> None:
         pass
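
Assuming the logger controller flushes SAMPLE-reduced keys as a mapping of key to a list of sample dicts, calling the console backend directly would look roughly like this (console_backend is assumed to be an already-initialized ConsoleBackend; its construction is not part of this diff):

# Rough call shape for the new ConsoleBackend.log_samples endpoint.
async def demo(console_backend):
    samples = {
        "rollout/sample": [
            {"episode_id": "ep-0", "reward": 0.75, "advantage": 0.12},
            {"episode_id": "ep-1", "reward": 0.25, "advantage": -0.31},
        ],
    }
    await console_backend.log_samples(samples, step=10)
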

@@ -836,6 +885,28 @@ async def log(self, metrics: Dict[str, Any], step: int) -> None:
         else:
             logger.debug(f"WandbBackend: No run started, skipping log for {self.name}")

+    async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
+        """Log sample-level data to WandB Tables."""
+        import wandb
+
+        if not self.run or not samples:
+            return
+
+        for key, rows in samples.items():
+            if not rows:
+                continue
+
+            # Create a WandB Table dynamically based on keys of first sample
+            columns = list(rows[0].keys())
+            table = wandb.Table(columns=columns)
+            for sample in rows:
+                table.add_data(*[sample.get(c) for c in columns])
+
+            self.run.log({f"{key}_table": table, "global_step": step})
+            logger.info(
+                f"WandbBackend: Logged {len(rows)} samples for {key} at step {step}"
+            )
+
     def get_metadata_for_secondary_ranks(self) -> Dict[str, Any]:
         if self.run and not self.reduce_across_ranks and self.share_run_id:
             return {"shared_run_id": self.run.id}
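
Given the "rollout/sample" key used in apps/grpo/main.py, each flush would therefore surface in the WandB UI as a table logged under "rollout/sample_table", one row per kept sample, with columns taken from the first sample's keys. A standalone sketch of the equivalent wandb calls (project name and rows are illustrative):

# Standalone illustration of the table logging done above for one key.
import wandb

run = wandb.init(project="grpo-demo")  # hypothetical project name
rows = [
    {"episode_id": "ep-0", "reward": 0.75},
    {"episode_id": "ep-1", "reward": 0.25},
]
columns = list(rows[0].keys())
table = wandb.Table(columns=columns)
for row in rows:
    table.add_data(*[row.get(c) for c in columns])
run.log({"rollout/sample_table": table, "global_step": 10})
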
