Skip to content

Commit 33295f1

Browse files
committed
debug
1 parent 2d52ebf commit 33295f1

File tree

1 file changed

+25
-11
lines changed

1 file changed

+25
-11
lines changed

src/forge/observability/metrics.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,25 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri
141141
list[Metric]: List of reduced metrics
142142
143143
Example:
144-
states = [
145-
{"loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}},
146-
{"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}},
144+
>>> states = [
145+
... {
146+
... "loss": {"count": 5, "sum": 14, "reduction_type": "mean"},
147+
... "reward/sample": {
148+
... "reduction_type": "sample",
149+
... "samples": [{"episode_id": 1, "reward": 0.5}],
150+
... },
151+
... },
152+
... {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}},
153+
... ]
154+
>>> reduce_metrics_states(states)
155+
[
156+
Metric(key='loss', value=2.0, reduction=Reduce.MEAN),
157+
Metric(
158+
key='reward/sample',
159+
value=[{'episode_id': 1, 'reward': 0.5}],
160+
reduction=Reduce.SAMPLE,
161+
)
147162
]
148-
reduce_metrics_states(states)
149-
>>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)]
150163
151164
Raises:
152165
ValueError: on mismatched reduction types for the same metric key.
@@ -649,7 +662,6 @@ def push(self, metric: Metric) -> None:
649662

650663
# For PER_RANK_NO_REDUCE backends: stream without reduce
651664
for backend in self.per_rank_no_reduce_backends:
652-
653665
if metric.reduction == Reduce.SAMPLE:
654666
asyncio.create_task(backend.log_samples([metric], self.global_step))
655667
else:
@@ -712,11 +724,9 @@ async def flush(
712724
scalar_metrics = [
713725
m for m in metrics_for_backends if m.reduction != Reduce.SAMPLE
714726
]
715-
sample_metrics = {
716-
m.key: m.value
717-
for m in metrics_for_backends
718-
if m.reduction == Reduce.SAMPLE
719-
}
727+
sample_metrics = [
728+
m for m in metrics_for_backends if m.reduction == Reduce.SAMPLE
729+
]
720730

721731
for backend in self.per_rank_reduce_backends:
722732
if scalar_metrics:
@@ -1001,6 +1011,10 @@ async def log_samples(self, samples: List[Metric], step: int) -> None:
10011011
if not table_rows:
10021012
continue
10031013

1014+
# Convert to list if single sample. This happens when logging stream
1015+
if isinstance(table_rows, dict):
1016+
table_rows = [table_rows]
1017+
10041018
# If table doesn't exist yet, create it in INCREMENTAL mode
10051019
if table_name not in self._tables:
10061020
# Collect all unique columns from all rows

0 commit comments

Comments
 (0)