44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import heapq
8+ import itertools
79import logging
810import os
911from abc import ABC , abstractmethod
1012from dataclasses import dataclass
1113from datetime import datetime
1214from enum import Enum
13- from typing import Any
15+ from typing import Any , Dict , List
1416
1517import pytz
1618
@@ -68,6 +70,7 @@ class Reduce(Enum):
6870 MAX = "max"
6971 MIN = "min"
7072 STD = "std"
73+ SAMPLE = "sample"
7174
7275 @property
7376 def accumulator_class (self ):
@@ -77,6 +80,7 @@ def accumulator_class(self):
7780 Reduce .MAX : MaxAccumulator ,
7881 Reduce .MIN : MinAccumulator ,
7982 Reduce .STD : StdAccumulator ,
83+ Reduce .SAMPLE : SampleAccumulator ,
8084 }
8185 return mapping [self ]
8286
@@ -182,6 +186,55 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri
182186 return reduced_metrics
183187
184188
189+ #################
190+ # SampleFilters #
191+ #################
192+
193+
194+ class TopBottomKFilter :
195+ """Keep the top-k and bottom-k samples by a given key (e.g., reward)."""
196+
197+ def __init__ (self , top_k = 1 , bottom_k = 1 , key = "reward" ):
198+ self .top_k = top_k
199+ self .bottom_k = bottom_k
200+ self .key = key
201+ self ._top_heap = [] # min-heap for top-k
202+ self ._bottom_heap = [] # max-heap for bottom-k (store -value)
203+ self ._counter = itertools .count () # tie-breaker id generator
204+
205+ def filter_append (self , sample : Dict ) -> bool :
206+ val = sample .get (self .key , 0.0 )
207+ idx = next (self ._counter ) # unique tiebreaker
208+
209+ # If top_k or bottom_k <= 0, it means "disable" that side of filtering (i.e., keep none).
210+ # maintain top-k
211+ if self .top_k > 0 :
212+ if len (self ._top_heap ) < self .top_k :
213+ heapq .heappush (self ._top_heap , (val , idx , sample ))
214+ else :
215+ heapq .heappushpop (self ._top_heap , (val , idx , sample ))
216+
217+ # maintain bottom-k
218+ if self .bottom_k > 0 :
219+ if len (self ._bottom_heap ) < self .bottom_k :
220+ heapq .heappush (self ._bottom_heap , (- val , idx , sample ))
221+ else :
222+ heapq .heappushpop (self ._bottom_heap , (- val , idx , sample ))
223+
224+ # always return False here because we don't store in samples list
225+ return False
226+
227+ def filter_flush (self , samples : List [Dict ]) -> List [Dict ]:
228+ tops = [s for _ , _ , s in self ._top_heap ]
229+ bottoms = [s for _ , _ , s in self ._bottom_heap ]
230+ return bottoms + tops
231+
232+ def reset (self ):
233+ self ._top_heap = []
234+ self ._bottom_heap = []
235+ self ._counter = itertools .count ()
236+
237+
185238################
186239# Accumulators #
187240################
@@ -392,6 +445,50 @@ def reset(self) -> None:
392445 self .count = 0
393446
394447
448+ class SampleAccumulator (MetricAccumulator ):
449+ """Accumulator for sample-level metrics (e.g., prompt/response/reward dicts).
450+ Optionally uses a sample filter to decide what to keep at append/flush time.
451+ """
452+
453+ def __init__ (self , reduction : Reduce ):
454+ super ().__init__ (reduction )
455+ self .samples : List [Dict [str , Any ]] = []
456+ self .filter = TopBottomKFilter ()
457+
458+ def append (self , value : dict ) -> None :
459+ if not isinstance (value , dict ):
460+ raise ValueError (f"Expected dict, got { type (value )} " )
461+
462+ # Only keep the sample if filter_append returns True
463+ if self .filter .filter_append (value ):
464+ self .samples .append (value )
465+
466+ def get_value (self ) -> list [dict ]:
467+ """Return locally collected (and optionally filtered) samples."""
468+ # Apply flush-time filter (e.g. heap selection, threshold trimming)
469+ return self .filter .filter_flush (self .samples )
470+
471+ def get_state (self ) -> Dict [str , Any ]:
472+ """Serialize accumulator state for cross-rank reduction."""
473+ return {
474+ "reduction_type" : self .reduction_type .value ,
475+ "samples" : self .get_value (),
476+ }
477+
478+ @classmethod
479+ def get_reduced_value_from_states (cls , states : List [Dict [str , Any ]]) -> list [dict ]:
480+ """Merge sample states across ranks."""
481+ merged = []
482+ for s in states :
483+ merged .extend (s .get ("samples" , []))
484+ return merged
485+
486+ def reset (self ) -> None :
487+ """Clear local samples and reset filter state."""
488+ self .samples .clear ()
489+ self .filter .reset ()
490+
491+
395492#############
396493# Collector #
397494#############
0 commit comments