Skip to content

Commit 403e417

Browse files
ds-hwangchanglan
authored andcommitted
Introduce MinSummary and MaxSummary.
GitOrigin-RevId: 8d3511d
1 parent 3f88f22 commit 403e417

File tree

2 files changed

+110
-1
lines changed

2 files changed

+110
-1
lines changed

axlearn/common/metrics.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,48 @@ def accumulate(self, other: Summary) -> Summary:
5050
return self + other
5151

5252

53+
class MinSummary(Summary):
54+
"""A summary that computes the minimum value across tensor elements."""
55+
56+
_value: Tensor
57+
58+
def value(self) -> Tensor:
59+
return self._value
60+
61+
def validate(self):
62+
val = self._value
63+
if not isinstance(val, Tensor):
64+
raise ValueError(f"MinSummary value must be a Tensor, but got {str(type(val))}.")
65+
if val.ndim >= 1:
66+
raise ValueError(f"MinSummary value must be a scalar, but got {val.ndim=}.")
67+
68+
def accumulate(self, other: Summary) -> Summary:
69+
if not isinstance(other, MinSummary):
70+
raise TypeError(f"Expected MinSummary, got {type(other)}.")
71+
return MinSummary(jnp.minimum(self.value(), other.value()))
72+
73+
74+
class MaxSummary(Summary):
75+
"""A summary that computes the maximum value across tensor elements."""
76+
77+
_value: Tensor
78+
79+
def value(self) -> Tensor:
80+
return self._value
81+
82+
def validate(self):
83+
val = self._value
84+
if not isinstance(val, Tensor):
85+
raise ValueError(f"MaxSummary value must be a Tensor, but got {str(type(val))}.")
86+
if val.ndim >= 1:
87+
raise ValueError(f"MaxSummary value must be a scalar, but got {val.ndim=}.")
88+
89+
def accumulate(self, other: Summary) -> Summary:
90+
if not isinstance(other, MaxSummary):
91+
raise TypeError(f"Expected MaxSummary, got {type(other)}.")
92+
return MaxSummary(jnp.maximum(self.value(), other.value()))
93+
94+
5395
class MetricAccumulator(Configurable):
5496
"""A MetricAccumulator is used during evaluation to accumulate metrics across batches."""
5597

axlearn/common/metrics_test.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22

33
"""Tests for metrics.py."""
44

5+
import contextlib
56
from typing import Optional
67

78
import chex
89
import jax
910
import jax.numpy as jnp
10-
from absl.testing import parameterized
11+
from absl.testing import absltest, parameterized
1112

1213
from axlearn.common import metrics, summary, test_utils, utils
14+
from axlearn.common.metrics import MaxSummary, MinSummary
1315
from axlearn.common.module import Summable
1416

1517

@@ -55,6 +57,29 @@ def test_metric_accumulator(self):
5557
expected = jax.tree_util.tree_leaves(expected)
5658
chex.assert_trees_all_close(result, expected)
5759

60+
@parameterized.parameters(
61+
dict(cls=metrics.MinSummary, expected=-10),
62+
dict(cls=metrics.MaxSummary, expected=10),
63+
)
64+
def test_metric_min_max_accumulator(self, cls, expected):
65+
acc = metrics.MetricAccumulator.default_config().instantiate()
66+
summaries = [
67+
dict(foo=cls(jnp.array(5))),
68+
dict(foo=cls(jnp.array(-10))),
69+
dict(foo=cls(jnp.array(10))),
70+
]
71+
72+
summaries_copy = jax.tree.map(lambda x: x, summaries)
73+
for s in summaries_copy:
74+
acc.update(s)
75+
result = acc.summaries()
76+
expected = dict(foo=cls(jnp.array(expected)))
77+
78+
chex.assert_trees_all_equal_structs(result, expected)
79+
result = jax.tree_util.tree_leaves(result)
80+
expected = jax.tree_util.tree_leaves(expected)
81+
self.assertEqual(result, expected)
82+
5883
def test_flatten_unflatten_metric_accumulator(self):
5984
acc = metrics.MetricAccumulator.default_config().instantiate()
6085
summaries = [
@@ -128,3 +153,45 @@ def add(weight):
128153

129154
# Test isinstance check.
130155
self.assertIsInstance(metrics.WeightedScalar(1.0, 1.0), Summable)
156+
157+
158+
class MinSummaryTest(test_utils.TestCase):
159+
@parameterized.parameters(
160+
(jnp.array(1), jnp.array(1)),
161+
(jnp.array([1, 2]), ValueError("MinSummary value must be a scalar, but got val.ndim=1.")),
162+
(jnp.array([[1, 2]]), ValueError("MinSummary value must be a scalar, but got val.ndim=2.")),
163+
(1, ValueError("MinSummary value must be a Tensor, but got <class 'int'>.")),
164+
)
165+
def test_min_summary(self, value, expected):
166+
min_summary = MinSummary(value)
167+
if isinstance(expected, ValueError):
168+
ctx = self.assertRaisesRegex(ValueError, expected.args[0])
169+
else:
170+
ctx = contextlib.nullcontext()
171+
with ctx:
172+
min_summary.validate()
173+
new_summary = min_summary.accumulate(MinSummary(jnp.array(10)))
174+
self.assertEqual(new_summary.value(), expected)
175+
176+
177+
class MaxSummaryTest(test_utils.TestCase):
178+
@parameterized.parameters(
179+
(jnp.array(1), jnp.array(1)),
180+
(jnp.array([1, 2]), ValueError("MaxSummary value must be a scalar, but got val.ndim=1.")),
181+
(jnp.array([[1, 2]]), ValueError("MaxSummary value must be a scalar, but got val.ndim=2.")),
182+
(1, ValueError("MaxSummary value must be a Tensor, but got <class 'int'>.")),
183+
)
184+
def test_min_summary(self, value, expected):
185+
max_summary = MaxSummary(value)
186+
if isinstance(expected, ValueError):
187+
ctx = self.assertRaisesRegex(ValueError, expected.args[0])
188+
else:
189+
ctx = contextlib.nullcontext()
190+
with ctx:
191+
max_summary.validate()
192+
new_summary = max_summary.accumulate(MaxSummary(jnp.array(-10)))
193+
self.assertEqual(new_summary.value(), expected)
194+
195+
196+
if __name__ == "__main__":
197+
absltest.main()

0 commit comments

Comments
 (0)