Skip to content

Commit 2d52ebf

Browse files
committed
merge filter into sampler
1 parent 0c52ea5 commit 2d52ebf

File tree

4 files changed

+40
-67
lines changed

4 files changed

+40
-67
lines changed

src/forge/observability/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
SampleAccumulator,
2929
StdAccumulator,
3030
SumAccumulator,
31-
TopBottomKFilter,
3231
WandbBackend,
3332
)
3433
from .perf_tracker import trace, Tracer
@@ -69,6 +68,4 @@
6968
"MinAccumulator",
7069
"StdAccumulator",
7170
"SampleAccumulator",
72-
# Filter classes
73-
"TopBottomKFilter",
7471
]

src/forge/observability/metric_actors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,9 +428,9 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
428428
scalar_metrics = [
429429
m for m in reduced_metrics if m.reduction != Reduce.SAMPLE
430430
]
431-
sample_metrics = {
431+
sample_metrics = [
432432
m for m in reduced_metrics if m.reduction == Reduce.SAMPLE
433-
}
433+
]
434434

435435
# Log to global backends
436436
for backend_name, backend in self.global_logger_backends.items():

src/forge/observability/metrics.py

Lines changed: 37 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -199,55 +199,6 @@ def record_episode_sample(table_name: str, episode):
199199
record_metric(table_name, sample, Reduce.SAMPLE)
200200

201201

202-
#################
203-
# SampleFilters #
204-
#################
205-
206-
207-
class TopBottomKFilter:
208-
"""Keep the top-k and bottom-k samples by a given key (e.g., reward)."""
209-
210-
def __init__(self, top_k=1, bottom_k=1, key="reward"):
211-
self.top_k = top_k
212-
self.bottom_k = bottom_k
213-
self.key = key
214-
self._top_heap = [] # min-heap for top-k
215-
self._bottom_heap = [] # max-heap for bottom-k (store -value)
216-
self._counter = itertools.count() # tie-breaker id generator
217-
218-
def filter_append(self, sample: Dict) -> bool:
219-
val = sample.get(self.key, 0.0)
220-
idx = next(self._counter) # unique tiebreaker
221-
222-
# If top_k or bottom_k <= 0, it means "disable" that side of filtering (i.e., keep none).
223-
# maintain top-k
224-
if self.top_k > 0:
225-
if len(self._top_heap) < self.top_k:
226-
heapq.heappush(self._top_heap, (val, idx, sample))
227-
else:
228-
heapq.heappushpop(self._top_heap, (val, idx, sample))
229-
230-
# maintain bottom-k
231-
if self.bottom_k > 0:
232-
if len(self._bottom_heap) < self.bottom_k:
233-
heapq.heappush(self._bottom_heap, (-val, idx, sample))
234-
else:
235-
heapq.heappushpop(self._bottom_heap, (-val, idx, sample))
236-
237-
# always return False here because we don't store in samples list
238-
return False
239-
240-
def filter_flush(self, samples: List[Dict]) -> List[Dict]:
241-
tops = [s for _, _, s in self._top_heap]
242-
bottoms = [s for _, _, s in self._bottom_heap]
243-
return bottoms + tops
244-
245-
def reset(self):
246-
self._top_heap = []
247-
self._bottom_heap = []
248-
self._counter = itertools.count()
249-
250-
251202
################
252203
# Accumulators #
253204
################
@@ -459,30 +410,53 @@ def reset(self) -> None:
459410

460411

461412
class SampleAccumulator(MetricAccumulator):
462-
"""Accumulator for sample-level metrics (e.g., prompt/response/reward dicts).
463-
Optionally uses a sample filter to decide what to keep at append/flush time.
413+
"""Accumulator for sample-level metrics with top-k and bottom-k filtering.
414+
415+
Keeps the top-k and bottom-k samples by a given key (e.g., reward).
416+
Useful for logging only the best and worst samples from a batch.
464417
"""
465418

466-
def __init__(self, reduction: Reduce):
419+
def __init__(
420+
self, reduction: Reduce, top_k: int = 1, bottom_k: int = 1, key: str = "reward"
421+
):
467422
super().__init__(reduction)
468423
self.samples: List[Dict[str, Any]] = []
469-
self.filter = TopBottomKFilter()
424+
self.top_k = top_k
425+
self.bottom_k = bottom_k
426+
self.key = key
427+
self._top_heap = [] # min-heap for top-k
428+
self._bottom_heap = [] # max-heap for bottom-k (store -value)
429+
self._counter = itertools.count() # tie-breaker id generator
470430
self.is_reset = True
471431

472432
def append(self, value: dict) -> None:
473433
if not isinstance(value, dict):
474434
raise ValueError(f"Expected dict, got {type(value)}")
475435

476436
self.is_reset = False
477-
# Only keep the sample if filter_append returns True
478-
if self.filter.filter_append(value):
479-
self.samples.append(value)
437+
val = value.get(self.key, 0.0)
438+
idx = next(self._counter) # unique tiebreaker
439+
440+
# If top_k or bottom_k <= 0, it means "disable" that side of filtering (i.e., keep none).
441+
# maintain top-k
442+
if self.top_k > 0:
443+
if len(self._top_heap) < self.top_k:
444+
heapq.heappush(self._top_heap, (val, idx, value))
445+
else:
446+
heapq.heappushpop(self._top_heap, (val, idx, value))
447+
448+
# maintain bottom-k
449+
if self.bottom_k > 0:
450+
if len(self._bottom_heap) < self.bottom_k:
451+
heapq.heappush(self._bottom_heap, (-val, idx, value))
452+
else:
453+
heapq.heappushpop(self._bottom_heap, (-val, idx, value))
480454

481455
def get_value(self) -> list[dict]:
482-
"""Return locally collected (and optionally filtered) samples."""
483-
# Apply flush-time filter (e.g. heap selection, threshold trimming)
484-
results = self.filter.filter_flush(self.samples)
485-
return results
456+
"""Return top-k and bottom-k filtered samples."""
457+
tops = [s for _, _, s in self._top_heap]
458+
bottoms = [s for _, _, s in self._bottom_heap]
459+
return bottoms + tops
486460

487461
def get_state(self) -> Dict[str, Any]:
488462
"""Serialize accumulator state for cross-rank reduction."""
@@ -503,7 +477,9 @@ def reset(self) -> None:
503477
"""Clear local samples and reset filter state."""
504478
self.is_reset = True
505479
self.samples.clear()
506-
self.filter.reset()
480+
self._top_heap = []
481+
self._bottom_heap = []
482+
self._counter = itertools.count()
507483

508484

509485
#############

tests/unit_tests/data/test_metrics_aggregator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def test_handler_replacement_warning(self, caplog):
247247
assert "Replacing handler for AggregationType.SUM" in caplog.records[0].message
248248

249249
def test_sample_accumulator_with_topbottom_filter(self):
250-
"""Ensure SampleAccumulator integrates with TopBottomKFilter correctly."""
250+
"""Ensure SampleAccumulator samples top and bottom correctly."""
251251
from forge.observability.metrics import Reduce, SampleAccumulator
252252

253253
acc = SampleAccumulator(Reduce.SAMPLE)

0 commit comments

Comments
 (0)