@@ -24,6 +24,7 @@ class Reduce(Enum):
2424    MAX  =  "max" 
2525    MIN  =  "min" 
2626    STD  =  "std" 
27+     SAMPLE  =  "sample" 
2728
2829    @property  
2930    def  accumulator_class (self ):
@@ -33,6 +34,7 @@ def accumulator_class(self):
3334            Reduce .MAX : MaxAccumulator ,
3435            Reduce .MIN : MinAccumulator ,
3536            Reduce .STD : StdAccumulator ,
37+             Reduce .SAMPLE : SampleAccumulator ,
3638        }
3739        return  mapping [self ]
3840
@@ -188,6 +190,10 @@ def filter_flush(self, samples: List[Dict]) -> List[Dict]:
188190        """ 
189191        return  samples 
190192
193+     def  reset (self ) ->  None :
194+         """Clears for next accumulation cycle.""" 
195+         pass 
196+ 
191197
192198class  RandomRatioFilter :
193199    """Randomly keep a fraction of samples.""" 
@@ -260,6 +266,10 @@ def filter_flush(self, samples: List[Dict]) -> List[Dict]:
260266        bottoms  =  [s  for  _ , s  in  self ._bottom_heap ]
261267        return  bottoms  +  tops 
262268
269+     def  reset (self ):
270+         self ._top_heap  =  []
271+         self ._bottom_heap  =  []
272+ 
263273
264274################ 
265275# Accumulators # 
@@ -449,6 +459,56 @@ def reset(self) -> None:
449459        self .count  =  0 
450460
451461
class SampleAccumulator(MetricAccumulator):
    """Accumulator for sample-level metrics (e.g., prompt/response/reward dicts).

    Samples are kept as dicts in the order they were appended. An optional
    SampleFilter decides what is stored at append time (``filter_append``)
    and what is emitted at flush time (``filter_flush``).
    """

    def __init__(self, reduction: Reduce, filter: SampleFilter | None = None):
        """Create an empty accumulator.

        Args:
            reduction: Reduction type forwarded to the base accumulator.
            filter: Optional filter consulted on append and on flush. When
                None, every appended sample is kept and returned.
        """
        super().__init__(reduction)
        self.samples: List[Dict[str, Any]] = []
        self.filter = filter

    def append(self, value: dict) -> None:
        """Record one sample, subject to the filter's append-time decision.

        Args:
            value: The sample to record; must be a dict.

        Raises:
            AssertionError: If ``value`` is not a dict.
        """
        assert isinstance(
            value, dict
        ), f"SampleAccumulator expects a dict, got {type(value).__name__}"

        # Compare against None instead of relying on truthiness: a filter
        # object that defines __len__/__bool__ could evaluate as falsy and
        # would otherwise be skipped silently.
        if self.filter is None or self.filter.filter_append(value):
            self.samples.append(value)

    def get_value(self) -> list[dict]:
        """Return locally collected (and optionally filtered) samples.

        The result is always a new list. Handing out ``self.samples`` itself
        (or a filter result that is that same list) would let a later
        ``reset()`` empty a value the caller is still holding.
        """
        if self.filter is None:
            selected = self.samples
        else:
            # Apply flush-time filter (e.g. heap selection, threshold trimming)
            selected = self.filter.filter_flush(self.samples)
        return list(selected)

    def get_state(self) -> Dict[str, Any]:
        """Serialize accumulator state for cross-rank reduction.

        Returns:
            A dict with the reduction type's value under ``"reduction_type"``
            and the current (filtered) samples under ``"samples"``.
        """
        return {
            "reduction_type": self.reduction_type.value,
            "samples": self.get_value(),
        }

    @classmethod
    def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> list[dict]:
        """Merge sample states across ranks.

        Args:
            states: State dicts as produced by ``get_state``. A state without
                a ``"samples"`` key contributes nothing.

        Returns:
            All samples concatenated in the order the states were given.
        """
        return [sample for state in states for sample in state.get("samples", [])]

    def reset(self) -> None:
        """Clear local samples and reset filter state."""
        self.samples.clear()
        if self.filter is not None:
            self.filter.reset()
510+ 
511+ 
452512############# 
453513# Collector # 
454514############# 
0 commit comments