@@ -141,12 +141,25 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri
141141 list[Metric]: List of reduced metrics
142142
143143 Example:
144- states = [
145- {"loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}},
146- {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}},
144+ >>> states = [
145+ ... {
146+ ... "loss": {"count": 5, "sum": 14, "reduction_type": "mean"},
147+ ... "reward/sample": {
148+ ... "reduction_type": "sample",
149+ ... "samples": [{"episode_id": 1, "reward": 0.5}],
150+ ... },
151+ ... },
152+ ... {"loss": {"count": 10, "sum": 16, "reduction_type": "mean"}},
153+ ... ]
154+ >>> reduce_metrics_states(states)
155+ [
156+ Metric(key='loss', value=2.0, reduction=Reduce.MEAN),
157+ Metric(
158+ key='reward/sample',
159+ value=[{'episode_id': 1, 'reward': 0.5}],
160+ reduction=Reduce.SAMPLE,
161+ )
147162 ]
148- reduce_metrics_states(states)
149- >>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)]
150163
151164 Raises:
152165 ValueError: on mismatched reduction types for the same metric key.
@@ -649,7 +662,6 @@ def push(self, metric: Metric) -> None:
649662
650663 # For PER_RANK_NO_REDUCE backends: stream without reduce
651664 for backend in self.per_rank_no_reduce_backends:
652-
653665 if metric.reduction == Reduce.SAMPLE:
654666 asyncio.create_task(backend.log_samples([metric], self.global_step))
655667 else:
@@ -712,11 +724,9 @@ async def flush(
712724 scalar_metrics = [
713725 m for m in metrics_for_backends if m.reduction != Reduce.SAMPLE
714726 ]
715- sample_metrics = {
716- m.key: m.value
717- for m in metrics_for_backends
718- if m.reduction == Reduce.SAMPLE
719- }
727+ sample_metrics = [
728+ m for m in metrics_for_backends if m.reduction == Reduce.SAMPLE
729+ ]
720730
721731 for backend in self.per_rank_reduce_backends:
722732 if scalar_metrics:
@@ -1001,6 +1011,10 @@ async def log_samples(self, samples: List[Metric], step: int) -> None:
10011011 if not table_rows :
10021012 continue
10031013
1014+ # Convert to a list if a single sample dict was passed. This happens when logging a stream
1015+ if isinstance(table_rows, dict):
1016+ table_rows = [table_rows]
1017+
10041018 # If table doesn't exist yet, create it in INCREMENTAL mode
10051019 if table_name not in self ._tables :
10061020 # Collect all unique columns from all rows
0 commit comments