|
2 | 2 |
|
3 | 3 | """Tests for metrics.py."""
|
4 | 4 |
|
| 5 | +import contextlib |
5 | 6 | from typing import Optional
|
6 | 7 |
|
7 | 8 | import chex
|
8 | 9 | import jax
|
9 | 10 | import jax.numpy as jnp
|
10 |
| -from absl.testing import parameterized |
| 11 | +from absl.testing import absltest, parameterized |
11 | 12 |
|
12 | 13 | from axlearn.common import metrics, summary, test_utils, utils
|
| 14 | +from axlearn.common.metrics import MaxSummary, MinSummary |
13 | 15 | from axlearn.common.module import Summable
|
14 | 16 |
|
15 | 17 |
|
@@ -55,6 +57,29 @@ def test_metric_accumulator(self):
|
55 | 57 | expected = jax.tree_util.tree_leaves(expected)
|
56 | 58 | chex.assert_trees_all_close(result, expected)
|
57 | 59 |
|
| 60 | + @parameterized.parameters( |
| 61 | + dict(cls=metrics.MinSummary, expected=-10), |
| 62 | + dict(cls=metrics.MaxSummary, expected=10), |
| 63 | + ) |
| 64 | + def test_metric_min_max_accumulator(self, cls, expected): |
| 65 | + acc = metrics.MetricAccumulator.default_config().instantiate() |
| 66 | + summaries = [ |
| 67 | + dict(foo=cls(jnp.array(5))), |
| 68 | + dict(foo=cls(jnp.array(-10))), |
| 69 | + dict(foo=cls(jnp.array(10))), |
| 70 | + ] |
| 71 | + |
| 72 | + summaries_copy = jax.tree.map(lambda x: x, summaries) |
| 73 | + for s in summaries_copy: |
| 74 | + acc.update(s) |
| 75 | + result = acc.summaries() |
| 76 | + expected = dict(foo=cls(jnp.array(expected))) |
| 77 | + |
| 78 | + chex.assert_trees_all_equal_structs(result, expected) |
| 79 | + result = jax.tree_util.tree_leaves(result) |
| 80 | + expected = jax.tree_util.tree_leaves(expected) |
| 81 | + self.assertEqual(result, expected) |
| 82 | + |
58 | 83 | def test_flatten_unflatten_metric_accumulator(self):
|
59 | 84 | acc = metrics.MetricAccumulator.default_config().instantiate()
|
60 | 85 | summaries = [
|
@@ -128,3 +153,45 @@ def add(weight):
|
128 | 153 |
|
129 | 154 | # Test isinstance check.
|
130 | 155 | self.assertIsInstance(metrics.WeightedScalar(1.0, 1.0), Summable)
|
| 156 | + |
| 157 | + |
| 158 | +class MinSummaryTest(test_utils.TestCase): |
| 159 | + @parameterized.parameters( |
| 160 | + (jnp.array(1), jnp.array(1)), |
| 161 | + (jnp.array([1, 2]), ValueError("MinSummary value must be a scalar, but got val.ndim=1.")), |
| 162 | + (jnp.array([[1, 2]]), ValueError("MinSummary value must be a scalar, but got val.ndim=2.")), |
| 163 | + (1, ValueError("MinSummary value must be a Tensor, but got <class 'int'>.")), |
| 164 | + ) |
| 165 | + def test_min_summary(self, value, expected): |
| 166 | + min_summary = MinSummary(value) |
| 167 | + if isinstance(expected, ValueError): |
| 168 | + ctx = self.assertRaisesRegex(ValueError, expected.args[0]) |
| 169 | + else: |
| 170 | + ctx = contextlib.nullcontext() |
| 171 | + with ctx: |
| 172 | + min_summary.validate() |
| 173 | + new_summary = min_summary.accumulate(MinSummary(jnp.array(10))) |
| 174 | + self.assertEqual(new_summary.value(), expected) |
| 175 | + |
| 176 | + |
| 177 | +class MaxSummaryTest(test_utils.TestCase): |
| 178 | + @parameterized.parameters( |
| 179 | + (jnp.array(1), jnp.array(1)), |
| 180 | + (jnp.array([1, 2]), ValueError("MaxSummary value must be a scalar, but got val.ndim=1.")), |
| 181 | + (jnp.array([[1, 2]]), ValueError("MaxSummary value must be a scalar, but got val.ndim=2.")), |
| 182 | + (1, ValueError("MaxSummary value must be a Tensor, but got <class 'int'>.")), |
| 183 | + ) |
| 184 | + def test_min_summary(self, value, expected): |
| 185 | + max_summary = MaxSummary(value) |
| 186 | + if isinstance(expected, ValueError): |
| 187 | + ctx = self.assertRaisesRegex(ValueError, expected.args[0]) |
| 188 | + else: |
| 189 | + ctx = contextlib.nullcontext() |
| 190 | + with ctx: |
| 191 | + max_summary.validate() |
| 192 | + new_summary = max_summary.accumulate(MaxSummary(jnp.array(-10))) |
| 193 | + self.assertEqual(new_summary.value(), expected) |
| 194 | + |
| 195 | + |
| 196 | +if __name__ == "__main__": |
| 197 | + absltest.main() |
0 commit comments