@@ -199,55 +199,6 @@ def record_episode_sample(table_name: str, episode):
199199 record_metric (table_name , sample , Reduce .SAMPLE )
200200
201201
202- #################
203- # SampleFilters #
204- #################
205-
206-
207- class TopBottomKFilter :
208- """Keep the top-k and bottom-k samples by a given key (e.g., reward)."""
209-
210- def __init__ (self , top_k = 1 , bottom_k = 1 , key = "reward" ):
211- self .top_k = top_k
212- self .bottom_k = bottom_k
213- self .key = key
214- self ._top_heap = [] # min-heap for top-k
215- self ._bottom_heap = [] # max-heap for bottom-k (store -value)
216- self ._counter = itertools .count () # tie-breaker id generator
217-
218- def filter_append (self , sample : Dict ) -> bool :
219- val = sample .get (self .key , 0.0 )
220- idx = next (self ._counter ) # unique tiebreaker
221-
222- # If top_k or bottom_k <= 0, it means "disable" that side of filtering (i.e., keep none).
223- # maintain top-k
224- if self .top_k > 0 :
225- if len (self ._top_heap ) < self .top_k :
226- heapq .heappush (self ._top_heap , (val , idx , sample ))
227- else :
228- heapq .heappushpop (self ._top_heap , (val , idx , sample ))
229-
230- # maintain bottom-k
231- if self .bottom_k > 0 :
232- if len (self ._bottom_heap ) < self .bottom_k :
233- heapq .heappush (self ._bottom_heap , (- val , idx , sample ))
234- else :
235- heapq .heappushpop (self ._bottom_heap , (- val , idx , sample ))
236-
237- # always return False here because we don't store in samples list
238- return False
239-
240- def filter_flush (self , samples : List [Dict ]) -> List [Dict ]:
241- tops = [s for _ , _ , s in self ._top_heap ]
242- bottoms = [s for _ , _ , s in self ._bottom_heap ]
243- return bottoms + tops
244-
245- def reset (self ):
246- self ._top_heap = []
247- self ._bottom_heap = []
248- self ._counter = itertools .count ()
249-
250-
251202################
252203# Accumulators #
253204################
@@ -459,30 +410,53 @@ def reset(self) -> None:
459410
460411
461412class SampleAccumulator (MetricAccumulator ):
462- """Accumulator for sample-level metrics (e.g., prompt/response/reward dicts).
463- Optionally uses a sample filter to decide what to keep at append/flush time.
413+ """Accumulator for sample-level metrics with top-k and bottom-k filtering.
414+
415+ Keeps the top-k and bottom-k samples by a given key (e.g., reward).
416+ Useful for logging only the best and worst samples from a batch.
464417 """
465418
466- def __init__ (self , reduction : Reduce ):
419+ def __init__ (
420+ self , reduction : Reduce , top_k : int = 1 , bottom_k : int = 1 , key : str = "reward"
421+ ):
467422 super ().__init__ (reduction )
468423 self .samples : List [Dict [str , Any ]] = []
469- self .filter = TopBottomKFilter ()
424+ self .top_k = top_k
425+ self .bottom_k = bottom_k
426+ self .key = key
427+ self ._top_heap = [] # min-heap for top-k
428+ self ._bottom_heap = [] # max-heap for bottom-k (store -value)
429+ self ._counter = itertools .count () # tie-breaker id generator
470430 self .is_reset = True
471431
472432 def append (self , value : dict ) -> None :
473433 if not isinstance (value , dict ):
474434 raise ValueError (f"Expected dict, got { type (value )} " )
475435
476436 self .is_reset = False
477- # Only keep the sample if filter_append returns True
478- if self .filter .filter_append (value ):
479- self .samples .append (value )
437+ val = value .get (self .key , 0.0 )
438+ idx = next (self ._counter ) # unique tiebreaker
439+
440+ # If top_k or bottom_k <= 0, it means "disable" that side of filtering (i.e., keep none).
441+ # maintain top-k
442+ if self .top_k > 0 :
443+ if len (self ._top_heap ) < self .top_k :
444+ heapq .heappush (self ._top_heap , (val , idx , value ))
445+ else :
446+ heapq .heappushpop (self ._top_heap , (val , idx , value ))
447+
448+ # maintain bottom-k
449+ if self .bottom_k > 0 :
450+ if len (self ._bottom_heap ) < self .bottom_k :
451+ heapq .heappush (self ._bottom_heap , (- val , idx , value ))
452+ else :
453+ heapq .heappushpop (self ._bottom_heap , (- val , idx , value ))
480454
481455 def get_value (self ) -> list [dict ]:
482- """Return locally collected ( and optionally filtered) samples."""
483- # Apply flush-time filter (e.g. heap selection, threshold trimming)
484- results = self .filter . filter_flush ( self . samples )
485- return results
456+ """Return top-k and bottom-k filtered samples."""
457+ tops = [ s for _ , _ , s in self . _top_heap ]
458+ bottoms = [ s for _ , _ , s in self ._bottom_heap ]
459+ return bottoms + tops
486460
487461 def get_state (self ) -> Dict [str , Any ]:
488462 """Serialize accumulator state for cross-rank reduction."""
@@ -503,7 +477,9 @@ def reset(self) -> None:
503477 """Clear local samples and reset filter state."""
504478 self .is_reset = True
505479 self .samples .clear ()
506- self .filter .reset ()
480+ self ._top_heap = []
481+ self ._bottom_heap = []
482+ self ._counter = itertools .count ()
507483
508484
509485#############
0 commit comments