Skip to content

Commit 675a896

Browse files
committed
debug; blocked by wandb table upload bug
1 parent 066c464 commit 675a896

File tree

3 files changed

+87
-35
lines changed

3 files changed

+87
-35
lines changed

apps/grpo/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,6 @@ async def continuous_rollouts():
425425
await replay_buffer.add.call_one(episode)
426426
record_episode_sample("rollout/sample", episode)
427427

428-
record_metric("sample/", {}, Reduce.SAMPLE)
429428
# Log metrics
430429
rollout_count += 1
431430
record_metric(

src/forge/observability/metric_actors.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,14 +313,17 @@ async def flush(self, step: int):
313313
return
314314

315315
# Reduce
316-
reduced_metrics = reduce_metrics_states(all_local_states)
316+
reduced_metrics, reduced_samples = reduce_metrics_states(all_local_states)
317317

318318
# Log to each global logger_backend
319319
for (
320320
logger_backend_name,
321321
logger_backend,
322322
) in self.global_logger_backends.items():
323-
await logger_backend.log(reduced_metrics, step)
323+
if reduced_metrics:
324+
await logger_backend.log(reduced_metrics, step)
325+
if reduced_samples:
326+
await logger_backend.log_samples(reduced_samples, step)
324327

325328
@endpoint
326329
def has_fetcher(self, name: str | ProcMesh) -> bool:

src/forge/observability/metrics.py

Lines changed: 82 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import heapq
8+
import itertools
89
import logging
910

1011
import os
@@ -145,37 +146,54 @@ def record_episode_sample(key: str, episode):
145146
record_metric(key, sample, Reduce.SAMPLE)
146147

147148

148-
def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, Any]:
149-
"""Reduce metric accumulators states to a single value per metric.
149+
def reduce_metrics_states(
150+
states: List[Dict[str, Dict[str, Any]]]
151+
) -> tuple[Dict[str, Any], Dict[str, list[dict]]]:
152+
"""
153+
Reduce metric accumulator states across ranks into two groups:
154+
- scalar metrics (mean/sum/etc.)
155+
- sample metrics (list[dict])
150156
151-
Can be used when reducing metrics across ranks or services, as merging
152-
states is more precise than merging locally reduced metrics.
157+
This function merges metric accumulator states from multiple ranks or processes
158+
into final reduced values. It automatically distinguishes between scalar reductions
159+
(e.g., MEAN, SUM) and structured SAMPLE-type reductions (e.g., per-example dicts).
153160
154161
Args:
155162
states (List[Dict[str, Dict[str, Any]]]): List of states of one or more metrics,
156163
normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`.
157164
158165
Returns:
159-
Dict[str, Any]: Dictionary with format {metric_key: reduced_value}
166+
metrics: Dict[str, Any], {metric_key: reduced_scalar_value}
167+
samples: Dict[str, list[dict]], {metric_key: merged_list_of_samples}
160168
161169
Example:
162-
states = [
163-
{"loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}},
164-
{"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}},
165-
]
166-
reduce_metrics_states(states)
167-
>>> {"loss": 2.0}
170+
>>> states = [
171+
... {
172+
... "loss": {"count": 5, "sum": 14, "reduction_type": "mean"},
173+
... "rollout/sample": {"reduction_type": "sample", "samples": [{"id": 1}]},
174+
... },
175+
... {
176+
... "loss": {"count": 10, "sum": 26, "reduction_type": "mean"},
177+
... "rollout/sample": {"reduction_type": "sample", "samples": [{"id": 2}]},
178+
... },
179+
... ]
180+
>>> metrics, samples = reduce_metrics_states(states)
181+
>>> metrics
182+
{'loss': 2.6666666666666665}
183+
>>> samples
184+
{'rollout/sample': [{'id': 1}, {'id': 2}]}
168185
169186
Raises:
170187
ValueError: on mismatched reduction types for the same metric key.
171188
"""
172189
if not states:
173-
return {}
190+
return {}, {}
174191

175192
# Collect unique keys across all
176193
all_keys = set(k for state in states for k in state)
194+
metrics: Dict[str, Any] = {}
195+
samples: Dict[str, list[dict]] = {}
177196

178-
reduced_metrics = {}
179197
for key in all_keys:
180198
metric_states = [state.get(key) for state in states if key in state]
181199
if not metric_states:
@@ -194,9 +212,14 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str,
194212

195213
metric_accumulator = Reduce(first_reduction_type).accumulator_class
196214
reduced_value = metric_accumulator.get_reduced_value_from_states(metric_states)
197-
reduced_metrics[key] = reduced_value
215+
metrics[key] = reduced_value
198216

199-
return reduced_metrics
217+
# separate samples vs normal metrics
218+
if first_reduction_type == Reduce.SAMPLE.value:
219+
samples[key] = reduced_value
220+
else:
221+
metrics[key] = reduced_value
222+
return metrics, samples
200223

201224

202225
#################
@@ -271,36 +294,39 @@ def __init__(self, top_k=1, bottom_k=1, key="reward"):
271294
self.key = key
272295
self._top_heap = [] # min-heap for top-k
273296
self._bottom_heap = [] # max-heap for bottom-k (store -value)
297+
self._counter = itertools.count() # tie-breaker id generator
274298

275299
def filter_append(self, sample: Dict) -> bool:
276300
val = sample.get(self.key, 0.0)
301+
idx = next(self._counter) # unique tiebreaker
277302

278303
# If top_k or bottom_k <= 0, it means "disable" that side of filtering (i.e., keep none).
279304
# maintain top-k
280305
if self.top_k > 0:
281306
if len(self._top_heap) < self.top_k:
282-
heapq.heappush(self._top_heap, (val, sample))
307+
heapq.heappush(self._top_heap, (val, idx, sample))
283308
else:
284-
heapq.heappushpop(self._top_heap, (val, sample))
309+
heapq.heappushpop(self._top_heap, (val, idx, sample))
285310

286311
# maintain bottom-k
287312
if self.bottom_k > 0:
288313
if len(self._bottom_heap) < self.bottom_k:
289-
heapq.heappush(self._bottom_heap, (-val, sample))
314+
heapq.heappush(self._bottom_heap, (-val, idx, sample))
290315
else:
291-
heapq.heappushpop(self._bottom_heap, (-val, sample))
316+
heapq.heappushpop(self._bottom_heap, (-val, idx, sample))
292317

293318
# Always return False here because we don't store anything in the samples list
294319
return False
295320

296321
def filter_flush(self, samples: List[Dict]) -> List[Dict]:
297-
tops = [s for _, s in self._top_heap]
298-
bottoms = [s for _, s in self._bottom_heap]
322+
tops = [s for _, _, s in self._top_heap]
323+
bottoms = [s for _, _, s in self._bottom_heap]
299324
return bottoms + tops
300325

301326
def reset(self):
302327
self._top_heap = []
303328
self._bottom_heap = []
329+
self._counter = itertools.count()
304330

305331

306332
################
@@ -670,14 +696,27 @@ async def flush(
670696

671697
# Reduce metrics from states for logging if any per-rank backend
672698
if self.logger_backends:
673-
metrics = {}
699+
# Prepare two groups: normal metrics and sample-type metrics
700+
metrics: Dict[str, Any] = {}
701+
samples: Dict[str, list[dict]] = {}
674702
for key, state in states.items():
675-
acc_class = Reduce(state["reduction_type"]).accumulator_class
676-
metrics[key] = acc_class.get_reduced_value_from_states([state])
703+
reduction_type = state["reduction_type"]
704+
acc_class = Reduce(reduction_type).accumulator_class
705+
value = acc_class.get_reduced_value_from_states([state])
706+
707+
if reduction_type == Reduce.SAMPLE.value:
708+
# sample-type metrics → list[dict]
709+
samples[key] = value
710+
else:
711+
# scalar metrics → float/int/etc.
712+
metrics[key] = value
677713

678714
# Log to local logger_backends
679715
for logger_backend in self.logger_backends:
680-
await logger_backend.log(metrics, step)
716+
if metrics:
717+
await logger_backend.log(metrics, step)
718+
if samples:
719+
await logger_backend.log_samples(samples, step)
681720

682721
return states if return_state else {}
683722

@@ -728,6 +767,9 @@ async def init(
728767
async def log(self, metrics: Dict[str, Any], step: int) -> None:
729768
pass
730769

770+
async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
771+
pass
772+
731773
async def finish(self) -> None:
732774
pass
733775

@@ -763,13 +805,13 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
763805
"""Pretty-print sample-level logs to console."""
764806
if not samples:
765807
return
766-
import pprint
808+
import json
767809

768810
logger.info(f"=== [{self.prefix}] - SAMPLE LOGS STEP {step} ===")
769811
for key, rows in samples.items():
770812
logger.info(f"[{key}] ({len(rows)} samples)")
771813
for sample in rows:
772-
pretty = pprint.pformat(sample, indent=4, width=120, compact=True)
814+
pretty = json.dumps(sample, indent=2, ensure_ascii=False)
773815
logger.info(pretty)
774816
logger.info("==============================================\n")
775817

@@ -805,6 +847,7 @@ def __init__(self, logger_backend_config: Dict[str, Any]):
805847
"reduce_across_ranks", True
806848
)
807849
self.share_run_id = logger_backend_config.get("share_run_id", False)
850+
self.tables = {} # keep persistent tables per key
808851

809852
async def init(
810853
self,
@@ -891,18 +934,25 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
891934

892935
if not self.run or not samples:
893936
return
894-
895937
for key, rows in samples.items():
896938
if not rows:
897939
continue
898-
899940
# Create a WandB Table dynamically based on keys of first sample
900941
columns = list(rows[0].keys())
901942
table = wandb.Table(columns=columns)
902943
for sample in rows:
903-
table.add_data(*[sample.get(c) for c in columns])
904-
905-
self.run.log({f"{key}_table": table, "global_step": step})
944+
# table.add_data(*[sample.get(c) for c in columns])
945+
values = [sample.get(c) for c in columns]
946+
logger.info(f"Adding row to {key}_table: {values}")
947+
table.add_data(*values)
948+
self.run.log(
949+
{
950+
f"{key}_step_{step}_table": table,
951+
"_sample_rows_logged": len(rows),
952+
"global_step": step,
953+
},
954+
commit=True,
955+
)
906956
logger.info(
907957
f"WandbBackend: Logged {len(rows)} samples for {key} at step {step}"
908958
)

0 commit comments

Comments
 (0)