Skip to content

Commit 98054fc

Browse files
committed
per_rank_no_reduce mode
1 parent b2c8d88 commit 98054fc

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/forge/observability/metrics.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import asyncio
78
import heapq
89
import itertools
910
import logging
@@ -715,12 +716,13 @@ def push(self, metric: Metric) -> None:
715716

716717
# For PER_RANK_NO_REDUCE backends: stream without reduce
717718
for backend in self.per_rank_no_reduce_backends:
718-
# if metric.reduction == Reduce.SAMPLE:
719-
# # Wrap singleton Metric into expected {key: [list_of_dicts]} format
720-
# sample = {metric.key: [metric.value]}
721-
# asyncio.create_task(backend.log_samples(sample, self.global_step))
722-
# else:
723-
backend.log_stream(metric=metric, global_step=self.global_step)
719+
720+
if metric.reduction == Reduce.SAMPLE:
721+
# Wrap singleton Metric into expected {key: [list_of_dicts]} format
722+
sample = {metric.key: [metric.value]}
723+
asyncio.create_task(backend.log_samples(sample, self.global_step))
724+
else:
725+
backend.log_stream(metric=metric, global_step=self.global_step)
724726

725727
# Always accumulate for reduction and state return
726728
key = metric.key

0 commit comments

Comments
 (0)