@@ -140,32 +140,12 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri
140140 list[Metric]: List of reduced metrics
141141
142142 Example:
143- >>> states = [
144- ... {
145- ... "loss": {"count": 5, "sum": 14, "reduction_type": "mean"},
146- ... "reward/sample": {
147- ... "reduction_type": "sample",
148- ... "samples": [{"episode_id": 1, "reward": 0.5}],
149- ... },
150- ... },
151- ... {
152- ... "loss": {"count": 10, "sum": 16, "reduction_type": "mean"},
153- ... "reward/sample": {
154- ... "reduction_type": "sample",
155- ... "samples": [{"episode_id": 2, "reward": 1.0}],
156- ... },
157- ... },
158- ... ]
159- >>> metrics = reduce_metrics_states(states)
160- >>> for m in metrics:
161- ... print(m)
162- Metric(key='loss', value=2.0, reduction=Reduce.MEAN)
163- Metric(
164- key='reward/sample',
165- value=[{'episode_id': 1, 'reward': 0.5},
166- {'episode_id': 2, 'reward': 1.0}],
167- reduction=Reduce.SAMPLE,
168- )
143+ states = [
144+ {"loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}},
145+ {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}},
146+ ]
147+ reduce_metrics_states(states)
148+ >>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)]
169149
170150 Raises:
171151 ValueError: on mismatched reduction types for the same metric key.
0 commit comments