Skip to content

Commit 37c2ac9

Browse files
committed
add accumulator and test
1 parent 6e77f0b commit 37c2ac9

File tree

2 files changed

+103
-1
lines changed

2 files changed

+103
-1
lines changed

src/forge/observability/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
record_metric,
2525
Reduce,
2626
reduce_metrics_states,
27+
SampleAccumulator,
2728
StdAccumulator,
2829
SumAccumulator,
30+
TopBottomKFilter,
2931
WandbBackend,
3032
)
3133
from .perf_tracker import trace, Tracer
@@ -64,4 +66,7 @@
6466
"MaxAccumulator",
6567
"MinAccumulator",
6668
"StdAccumulator",
69+
"SampleAccumulator",
70+
# Filter classes
71+
"TopBottomKFilter",
6772
]

src/forge/observability/metrics.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import heapq
8+
import itertools
79
import logging
810
import os
911
from abc import ABC, abstractmethod
1012
from dataclasses import dataclass
1113
from datetime import datetime
1214
from enum import Enum
13-
from typing import Any
15+
from typing import Any, Dict, List
1416

1517
import pytz
1618

@@ -68,6 +70,7 @@ class Reduce(Enum):
6870
MAX = "max"
6971
MIN = "min"
7072
STD = "std"
73+
SAMPLE = "sample"
7174

7275
@property
7376
def accumulator_class(self):
@@ -77,6 +80,7 @@ def accumulator_class(self):
7780
Reduce.MAX: MaxAccumulator,
7881
Reduce.MIN: MinAccumulator,
7982
Reduce.STD: StdAccumulator,
83+
Reduce.SAMPLE: SampleAccumulator,
8084
}
8185
return mapping[self]
8286

@@ -182,6 +186,55 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri
182186
return reduced_metrics
183187

184188

189+
#################
190+
# SampleFilters #
191+
#################
192+
193+
194+
class TopBottomKFilter:
195+
"""Keep the top-k and bottom-k samples by a given key (e.g., reward)."""
196+
197+
def __init__(self, top_k=1, bottom_k=1, key="reward"):
198+
self.top_k = top_k
199+
self.bottom_k = bottom_k
200+
self.key = key
201+
self._top_heap = [] # min-heap for top-k
202+
self._bottom_heap = [] # max-heap for bottom-k (store -value)
203+
self._counter = itertools.count() # tie-breaker id generator
204+
205+
def filter_append(self, sample: Dict) -> bool:
206+
val = sample.get(self.key, 0.0)
207+
idx = next(self._counter) # unique tiebreaker
208+
209+
# If top_k or bottom_k <= 0, it means "disable" that side of filtering (i.e., keep none).
210+
# maintain top-k
211+
if self.top_k > 0:
212+
if len(self._top_heap) < self.top_k:
213+
heapq.heappush(self._top_heap, (val, idx, sample))
214+
else:
215+
heapq.heappushpop(self._top_heap, (val, idx, sample))
216+
217+
# maintain bottom-k
218+
if self.bottom_k > 0:
219+
if len(self._bottom_heap) < self.bottom_k:
220+
heapq.heappush(self._bottom_heap, (-val, idx, sample))
221+
else:
222+
heapq.heappushpop(self._bottom_heap, (-val, idx, sample))
223+
224+
# always return False here because we don't store in samples list
225+
return False
226+
227+
def filter_flush(self, samples: List[Dict]) -> List[Dict]:
228+
tops = [s for _, _, s in self._top_heap]
229+
bottoms = [s for _, _, s in self._bottom_heap]
230+
return bottoms + tops
231+
232+
def reset(self):
233+
self._top_heap = []
234+
self._bottom_heap = []
235+
self._counter = itertools.count()
236+
237+
185238
################
186239
# Accumulators #
187240
################
@@ -392,6 +445,50 @@ def reset(self) -> None:
392445
self.count = 0
393446

394447

448+
class SampleAccumulator(MetricAccumulator):
449+
"""Accumulator for sample-level metrics (e.g., prompt/response/reward dicts).
450+
Optionally uses a sample filter to decide what to keep at append/flush time.
451+
"""
452+
453+
def __init__(self, reduction: Reduce):
454+
super().__init__(reduction)
455+
self.samples: List[Dict[str, Any]] = []
456+
self.filter = TopBottomKFilter()
457+
458+
def append(self, value: dict) -> None:
459+
if not isinstance(value, dict):
460+
raise ValueError(f"Expected dict, got {type(value)}")
461+
462+
# Only keep the sample if filter_append returns True
463+
if self.filter.filter_append(value):
464+
self.samples.append(value)
465+
466+
def get_value(self) -> list[dict]:
467+
"""Return locally collected (and optionally filtered) samples."""
468+
# Apply flush-time filter (e.g. heap selection, threshold trimming)
469+
return self.filter.filter_flush(self.samples)
470+
471+
def get_state(self) -> Dict[str, Any]:
472+
"""Serialize accumulator state for cross-rank reduction."""
473+
return {
474+
"reduction_type": self.reduction_type.value,
475+
"samples": self.get_value(),
476+
}
477+
478+
@classmethod
479+
def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> list[dict]:
480+
"""Merge sample states across ranks."""
481+
merged = []
482+
for s in states:
483+
merged.extend(s.get("samples", []))
484+
return merged
485+
486+
def reset(self) -> None:
487+
"""Clear local samples and reset filter state."""
488+
self.samples.clear()
489+
self.filter.reset()
490+
491+
395492
#############
396493
# Collector #
397494
#############

0 commit comments

Comments
 (0)