 # LICENSE file in the root directory of this source tree.

 import heapq
+import itertools
 import logging

 import os
@@ -145,37 +146,54 @@ def record_episode_sample(key: str, episode):
     record_metric(key, sample, Reduce.SAMPLE)


-def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, Any]:
-    """Reduce metric accumulators states to a single value per metric.
+def reduce_metrics_states(
+    states: List[Dict[str, Dict[str, Any]]]
+) -> tuple[Dict[str, Any], Dict[str, list[dict]]]:
+    """
+    Reduce metric accumulator states across ranks into two groups:
+    - scalar metrics (mean/sum/etc.)
+    - sample metrics (list[dict])

-    Can be used when reducing metrics across ranks or services, as merging
-    states is more precise than merging locally reduced metrics.
+    This function merges metric accumulator states from multiple ranks or processes
+    into final reduced values. It automatically distinguishes between scalar reductions
+    (e.g., MEAN, SUM) and structured SAMPLE-type reductions (e.g., per-example dicts).

     Args:
         states (List[Dict[str, Dict[str, Any]]]): List of states of one or more metrics,
             normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`.

     Returns:
-        Dict[str, Any]: Dictionary with format {metric_key: reduced_value}
+        metrics: Dict[str, Any], {metric_key: reduced_scalar_value}
+        samples: Dict[str, list[dict]], {metric_key: merged_list_of_samples}

     Example:
-        states = [
-            {"loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}},
-            {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}},
-        ]
-        reduce_metrics_states(states)
-        >>> {"loss": 2.0}
+        >>> states = [
+        ...     {
+        ...         "loss": {"count": 5, "sum": 14, "reduction_type": "mean"},
+        ...         "rollout/sample": {"reduction_type": "sample", "samples": [{"id": 1}]},
+        ...     },
+        ...     {
+        ...         "loss": {"count": 10, "sum": 26, "reduction_type": "mean"},
+        ...         "rollout/sample": {"reduction_type": "sample", "samples": [{"id": 2}]},
+        ...     },
+        ... ]
+        >>> metrics, samples = reduce_metrics_states(states)
+        >>> metrics
+        {'loss': 2.6666666666666665}
+        >>> samples
+        {'rollout/sample': [{'id': 1}, {'id': 2}]}

     Raises:
         ValueError: on mismatched reduction types for the same metric key.
     """
     if not states:
-        return {}
+        return {}, {}

     # Collect unique keys across all states
     all_keys = set(k for state in states for k in state)
+    metrics: Dict[str, Any] = {}
+    samples: Dict[str, list[dict]] = {}

-    reduced_metrics = {}
     for key in all_keys:
         metric_states = [state.get(key) for state in states if key in state]
         if not metric_states:
@@ -194,9 +212,14 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str,

         metric_accumulator = Reduce(first_reduction_type).accumulator_class
         reduced_value = metric_accumulator.get_reduced_value_from_states(metric_states)
-        reduced_metrics[key] = reduced_value

-    return reduced_metrics
+        # Route SAMPLE-type reductions into samples (list[dict]) and everything
+        # else into metrics (scalars), matching the return contract above.
+        if first_reduction_type == Reduce.SAMPLE.value:
+            samples[key] = reduced_value
+        else:
+            metrics[key] = reduced_value
+    return metrics, samples

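For reference, the SAMPLE reduction reduces by concatenation across ranks rather than by a numeric fold. A minimal runnable sketch of that contract, assuming the state shape from the docstring example above (this SampleAccumulator is a local stand-in, not the module's class):

from typing import Any, Dict, List

class SampleAccumulator:
    """Stand-in: reduces SAMPLE states by concatenating their row lists."""

    @classmethod
    def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> List[dict]:
        merged: List[dict] = []
        for state in states:
            # Each rank contributes its rows under the "samples" key.
            merged.extend(state.get("samples", []))
        return merged


states = [{"samples": [{"id": 1}]}, {"samples": [{"id": 2}]}]
assert SampleAccumulator.get_reduced_value_from_states(states) == [{"id": 1}, {"id": 2}]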

 #################
@@ -271,36 +294,39 @@ def __init__(self, top_k=1, bottom_k=1, key="reward"):
         self.key = key
         self._top_heap = []  # min-heap for top-k
         self._bottom_heap = []  # max-heap for bottom-k (store -value)
+        self._counter = itertools.count()  # tie-breaker id generator

     def filter_append(self, sample: Dict) -> bool:
         val = sample.get(self.key, 0.0)
+        idx = next(self._counter)  # unique tiebreaker
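+        # Without idx, two entries with equal val would fall back to comparing
+        # the sample dicts themselves, and dicts don't support "<" (TypeError).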

         # If top_k or bottom_k <= 0, it means "disable" that side of filtering (i.e., keep none).
         # maintain top-k
         if self.top_k > 0:
             if len(self._top_heap) < self.top_k:
-                heapq.heappush(self._top_heap, (val, sample))
+                heapq.heappush(self._top_heap, (val, idx, sample))
             else:
-                heapq.heappushpop(self._top_heap, (val, sample))
+                heapq.heappushpop(self._top_heap, (val, idx, sample))

         # maintain bottom-k
         if self.bottom_k > 0:
             if len(self._bottom_heap) < self.bottom_k:
-                heapq.heappush(self._bottom_heap, (-val, sample))
+                heapq.heappush(self._bottom_heap, (-val, idx, sample))
             else:
-                heapq.heappushpop(self._bottom_heap, (-val, sample))
+                heapq.heappushpop(self._bottom_heap, (-val, idx, sample))

         # always return False here because we don't store in the samples list
         return False

     def filter_flush(self, samples: List[Dict]) -> List[Dict]:
-        tops = [s for _, s in self._top_heap]
-        bottoms = [s for _, s in self._bottom_heap]
+        tops = [s for _, _, s in self._top_heap]
+        bottoms = [s for _, _, s in self._bottom_heap]
         return bottoms + tops

     def reset(self):
         self._top_heap = []
         self._bottom_heap = []
+        self._counter = itertools.count()

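The (value, tiebreaker, payload) heap pattern above is easy to sanity-check in isolation. A self-contained sketch (the names here are illustrative, not the module's API):

import heapq
import itertools

counter = itertools.count()
top_heap = []  # min-heap of (reward, idx, sample); the root is the smallest kept entry
top_k = 2

for sample in ({"reward": 1.0}, {"reward": 3.0}, {"reward": 3.0}, {"reward": 2.0}):
    entry = (sample["reward"], next(counter), sample)
    if len(top_heap) < top_k:
        heapq.heappush(top_heap, entry)
    else:
        # Push the new entry and evict the current minimum in one O(log k) step.
        heapq.heappushpop(top_heap, entry)

# The two equal 3.0 rewards compare on idx, never on the dicts themselves.
assert sorted(r for r, _, _ in top_heap) == [3.0, 3.0]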
 ################
@@ -670,14 +696,27 @@ async def flush(

         # Reduce metrics from states for logging if any per-rank backend
         if self.logger_backends:
-            metrics = {}
+            # Prepare two groups: normal metrics and sample-type metrics
+            metrics: Dict[str, Any] = {}
+            samples: Dict[str, list[dict]] = {}
             for key, state in states.items():
-                acc_class = Reduce(state["reduction_type"]).accumulator_class
-                metrics[key] = acc_class.get_reduced_value_from_states([state])
+                reduction_type = state["reduction_type"]
+                acc_class = Reduce(reduction_type).accumulator_class
+                value = acc_class.get_reduced_value_from_states([state])
+
+                if reduction_type == Reduce.SAMPLE.value:
+                    # sample-type metrics → list[dict]
+                    samples[key] = value
+                else:
+                    # scalar metrics → float/int/etc.
+                    metrics[key] = value

             # Log to local logger_backends
             for logger_backend in self.logger_backends:
-                await logger_backend.log(metrics, step)
+                if metrics:
+                    await logger_backend.log(metrics, step)
+                if samples:
+                    await logger_backend.log_samples(samples, step)

         return states if return_state else {}

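Note that the branch above compares the raw state value against Reduce.SAMPLE.value rather than the enum member, which matches the string-serialized states shown in the docstring example earlier. A self-contained illustration of that round-trip (this local Reduce is a stand-in for the module's enum):

from enum import Enum

class Reduce(Enum):  # minimal stand-in for the real Reduce enum
    MEAN = "mean"
    SAMPLE = "sample"

state = {"reduction_type": "sample"}  # reduction type serialized as a plain string
reduction_type = state["reduction_type"]

assert Reduce(reduction_type) is Reduce.SAMPLE  # value -> member lookup, as in flush()
assert reduction_type == Reduce.SAMPLE.value    # the string comparison used above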
@@ -728,6 +767,9 @@ async def init(
     async def log(self, metrics: Dict[str, Any], step: int) -> None:
         pass

+    async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
+        pass
+
     async def finish(self) -> None:
         pass

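Because both hooks default to no-ops, a backend only overrides what it supports and existing backends keep working unchanged. A minimal sketch of a sample-only backend (the class and its base are hypothetical, for illustration):

from typing import Any, Dict, List

class InMemoryBackend:  # would subclass the logger backend base class above
    """Collects sample rows in memory, e.g. for tests."""

    def __init__(self) -> None:
        self.rows: List[tuple] = []

    async def log(self, metrics: Dict[str, Any], step: int) -> None:
        pass  # this toy backend only records samples

    async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
        # Flatten {key: [row, ...]} into (step, key, row) tuples.
        for key, rows in samples.items():
            for row in rows:
                self.rows.append((step, key, row))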
@@ -763,13 +805,13 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
763805 """Pretty-print sample-level logs to console."""
764806 if not samples :
765807 return
766- import pprint
808+ import json
767809
768810 logger .info (f"=== [{ self .prefix } ] - SAMPLE LOGS STEP { step } ===" )
769811 for key , rows in samples .items ():
770812 logger .info (f"[{ key } ] ({ len (rows )} samples)" )
771813 for sample in rows :
772- pretty = pprint . pformat (sample , indent = 4 , width = 120 , compact = True )
814+ pretty = json . dumps (sample , indent = 2 , ensure_ascii = False )
773815 logger .info (pretty )
774816 logger .info ("==============================================\n " )
775817
@@ -805,6 +847,7 @@ def __init__(self, logger_backend_config: Dict[str, Any]):
805847 "reduce_across_ranks" , True
806848 )
807849 self .share_run_id = logger_backend_config .get ("share_run_id" , False )
850+ self .tables = {} # keep persistent tables per key

     async def init(
         self,
@@ -891,18 +934,25 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:

         if not self.run or not samples:
             return
-
         for key, rows in samples.items():
             if not rows:
                 continue
-
             # Create a WandB Table dynamically based on keys of first sample
             columns = list(rows[0].keys())
             table = wandb.Table(columns=columns)
             for sample in rows:
-                table.add_data(*[sample.get(c) for c in columns])
-
-            self.run.log({f"{key}_table": table, "global_step": step})
+                values = [sample.get(c) for c in columns]
+                table.add_data(*values)
+            self.run.log(
+                {
+                    f"{key}_step_{step}_table": table,
+                    "_sample_rows_logged": len(rows),
+                    "global_step": step,
+                },
+                commit=True,
+            )
             logger.info(
                 f"WandbBackend: Logged {len(rows)} samples for {key} at step {step}"
             )