diff --git a/flax/nnx/training/metrics.py b/flax/nnx/training/metrics.py index 51a341609..8000f4320 100644 --- a/flax/nnx/training/metrics.py +++ b/flax/nnx/training/metrics.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations +import inspect import typing as tp import numpy as np @@ -23,7 +24,11 @@ from flax.nnx.variablelib import Variable import jax, jax.numpy as jnp -# TODO: add tests and docstrings + +_MULTIMETRIC_RESERVED_NAMES = frozenset({ + 'reset', 'update', 'compute', 'split', + '_metric_names', '_expected_kwargs', +}) class MetricState(Variable): @@ -123,6 +128,15 @@ def compute(self) -> jax.Array: @struct.dataclass class Statistics: + """Running statistics computed by the Welford algorithm. + + Attributes: + mean: the running mean of the data. + standard_error_of_mean: the standard error of the mean. + standard_deviation: the population standard deviation + (ddof=0) of the data. + """ + mean: jnp.float32 standard_error_of_mean: jnp.float32 standard_deviation: jnp.float32 @@ -386,11 +400,52 @@ def __init__(self, **metrics): **metrics: the key-word arguments that will be used to access the corresponding ``Metric``. """ - # TODO: raise error if a kwarg is passed that is in ('reset', 'update', 'compute'), since these names are reserved for methods - self._metric_names = [] + # Validate metric names before any mutation. + for name in metrics: + if name in _MULTIMETRIC_RESERVED_NAMES: + raise ValueError( + f"Metric name '{name}' conflicts with a reserved " + f'name. Reserved names: ' + f'{sorted(_MULTIMETRIC_RESERVED_NAMES)}' + ) + + self._metric_names: list[str] = [] + self._expected_kwargs: set[str] | None = set() for metric_name, metric in metrics.items(): self._metric_names.append(metric_name) setattr(self, metric_name, metric) + # Collect expected kwargs for validation in update(). + if self._expected_kwargs is None: + continue + sig = inspect.signature(metric.update) + has_named_params = False + has_var_keyword = False + named_param_names: set[str] = set() + for pname, param in sig.parameters.items(): + if pname == 'self': + continue + if param.kind in ( + param.POSITIONAL_OR_KEYWORD, + param.KEYWORD_ONLY, + ): + named_param_names.add(pname) + has_named_params = True + elif param.kind == param.VAR_KEYWORD: + has_var_keyword = True + if has_named_params and has_var_keyword: + # Metric declares specific params but also absorbs + # extras (e.g. Accuracy's **_); can't validate + # without false positives. + self._expected_kwargs = None + elif has_named_params: + self._expected_kwargs.update(named_param_names) + elif hasattr(metric, 'argname'): + # Use argname convention (e.g. Average, Welford). + self._expected_kwargs.add(metric.argname) + elif has_var_keyword: + # Pure **kwargs with no specific params; can't + # validate. + self._expected_kwargs = None def reset(self) -> None: """Reset all underlying ``Metric``'s.""" @@ -398,16 +453,29 @@ def reset(self) -> None: getattr(self, metric_name).reset() def update(self, **updates) -> None: - """In-place update all underlying ``Metric``'s in this ``MultiMetric``. All - ``**updates`` will be passed to the ``update`` method of all underlying - ``Metric``'s. + """In-place update all underlying ``Metric``'s. + + All ``**updates`` are forwarded to each metric's + ``update`` method. Args: - **updates: the key-word arguments that will be passed to the underlying ``Metric``'s - ``update`` method. + **updates: keyword arguments forwarded to each + underlying metric's ``update`` method. + + Raises: + TypeError: if an unexpected keyword argument is + provided and the expected set can be statically + determined from the underlying metrics. """ # TODO: should we give the option of updating only some of the metrics and not all? e.g. if for some kwargs==None, don't do update - # TODO: should we raise an error if a kwarg is passed into **updates that has no match with any underlying metric? e.g. user typo + if self._expected_kwargs is not None: + unexpected = set(updates) - self._expected_kwargs + if unexpected: + raise TypeError( + f'Unexpected keyword argument(s): ' + f'{sorted(unexpected)}. ' + f'Expected: {sorted(self._expected_kwargs)}' + ) for metric_name in self._metric_names: getattr(self, metric_name).update(**updates) diff --git a/tests/nnx/metrics_test.py b/tests/nnx/metrics_test.py index 5d4a1229a..966cdff14 100644 --- a/tests/nnx/metrics_test.py +++ b/tests/nnx/metrics_test.py @@ -156,5 +156,177 @@ def test_accuracy_dims(self, logits, labels, threshold, error_msg): accuracy.update(logits=logits, labels=labels) +class TestAverage(parameterized.TestCase): + + def test_initial_compute_nan(self): + avg = nnx.metrics.Average() + self.assertTrue(jnp.isnan(avg.compute())) + + def test_single_batch(self): + avg = nnx.metrics.Average() + avg.update(values=jnp.array([1, 2, 3, 4])) + np.testing.assert_allclose(avg.compute(), 2.5, rtol=1e-6) + + def test_multiple_batches(self): + avg = nnx.metrics.Average() + avg.update(values=jnp.array([1, 2, 3, 4])) + avg.update(values=jnp.array([3, 2, 1, 0])) + np.testing.assert_allclose(avg.compute(), 2.0, rtol=1e-6) + + def test_reset(self): + avg = nnx.metrics.Average() + avg.update(values=jnp.array([1, 2, 3])) + avg.reset() + self.assertTrue(jnp.isnan(avg.compute())) + + def test_custom_argname(self): + avg = nnx.metrics.Average('loss') + avg.update(loss=jnp.array([10, 20])) + np.testing.assert_allclose(avg.compute(), 15.0, rtol=1e-6) + + def test_missing_argname(self): + avg = nnx.metrics.Average('loss') + with self.assertRaisesRegex(TypeError, "Expected keyword argument 'loss'"): + avg.update(values=jnp.array([1, 2])) + + def test_scalar_float(self): + avg = nnx.metrics.Average() + avg.update(values=5.0) + np.testing.assert_allclose(avg.compute(), 5.0, rtol=1e-6) + + def test_scalar_int(self): + avg = nnx.metrics.Average() + avg.update(values=3) + np.testing.assert_allclose(avg.compute(), 3.0, rtol=1e-6) + + +class TestWelford(parameterized.TestCase): + + def test_multiple_batches(self): + wf = nnx.metrics.Welford() + batch1 = jnp.array([1.0, 2.0, 3.0, 4.0]) + batch2 = jnp.array([3.0, 2.0, 1.0, 0.0]) + wf.update(values=batch1) + wf.update(values=batch2) + all_values = jnp.concatenate([batch1, batch2]) + stats = wf.compute() + np.testing.assert_allclose(stats.mean, all_values.mean(), rtol=1e-6) + np.testing.assert_allclose( + stats.standard_deviation, all_values.std(), rtol=1e-5 + ) + expected_sem = all_values.std() / jnp.sqrt(all_values.size) + np.testing.assert_allclose( + stats.standard_error_of_mean, expected_sem, rtol=1e-5 + ) + + def test_reset(self): + wf = nnx.metrics.Welford() + wf.update(values=jnp.array([1.0, 2.0, 3.0])) + wf.reset() + stats = wf.compute() + np.testing.assert_allclose(stats.mean, 0.0, atol=0) + self.assertTrue(jnp.isnan(stats.standard_error_of_mean)) + self.assertTrue(jnp.isnan(stats.standard_deviation)) + + def test_custom_argname(self): + wf = nnx.metrics.Welford('loss') + wf.update(loss=jnp.array([1.0, 2.0, 3.0])) + stats = wf.compute() + np.testing.assert_allclose(stats.mean, 2.0, rtol=1e-6) + + def test_missing_argname(self): + wf = nnx.metrics.Welford('loss') + with self.assertRaisesRegex(TypeError, "Expected keyword argument 'loss'"): + wf.update(values=jnp.array([1.0])) + + +class TestAccuracy(parameterized.TestCase): + + def test_multiclass_int64_labels(self): + logits = jnp.array([[0.0, 1.0], [1.0, 0.0]]) + labels = np.array([1, 0], dtype=np.int64) + labels = jnp.asarray(labels) + acc = nnx.metrics.Accuracy() + acc.update(logits=logits, labels=labels) + np.testing.assert_allclose(acc.compute(), 1.0, rtol=1e-6) + + def test_invalid_label_dtype(self): + logits = jnp.array([[0.0, 1.0]]) + labels = jnp.array([1.0]) + acc = nnx.metrics.Accuracy() + with self.assertRaisesRegex(ValueError, 'labels.dtype'): + acc.update(logits=logits, labels=labels) + + def test_threshold_type_error(self): + with self.assertRaisesRegex(TypeError, 'float'): + nnx.metrics.Accuracy(threshold=1) + + +class TestMultiMetric(parameterized.TestCase): + + @parameterized.parameters('reset', 'update', 'compute', 'split') + def test_reserved_name_error(self, name): + with self.assertRaisesRegex(ValueError, 'reserved'): + nnx.MultiMetric(**{name: nnx.metrics.Average()}) + + @parameterized.parameters('_metric_names', '_expected_kwargs') + def test_internal_name_error(self, name): + with self.assertRaisesRegex(ValueError, 'reserved'): + nnx.MultiMetric(**{name: nnx.metrics.Average()}) + + def test_unmatched_kwarg_error(self): + metrics = nnx.MultiMetric( + loss=nnx.metrics.Average(), + score=nnx.metrics.Average('score'), + ) + # Guard: validation must be active for this test. + self.assertEqual( + metrics._expected_kwargs, {'values', 'score'} + ) + with self.assertRaisesRegex( + TypeError, 'Unexpected keyword argument' + ): + metrics.update( + values=jnp.array([1.0]), + score=jnp.array([2.0]), + typo_kwarg=jnp.array([3.0]), + ) + + def test_compute_returns_dict(self): + metrics = nnx.MultiMetric( + loss=nnx.metrics.Average(), + score=nnx.metrics.Average('score'), + ) + metrics.update(values=jnp.array([1, 2, 3]), score=jnp.array([4, 5, 6])) + result = metrics.compute() + self.assertIsInstance(result, dict) + self.assertEqual(set(result.keys()), {'loss', 'score'}) + np.testing.assert_allclose(result['loss'], 2.0, rtol=1e-6) + np.testing.assert_allclose(result['score'], 5.0, rtol=1e-6) + + def test_split_merge(self): + metrics = nnx.MultiMetric( + loss=nnx.metrics.Average(), + ) + metrics.update(values=jnp.array([1.0, 2.0, 3.0])) + graphdef, state = metrics.split() + restored = nnx.merge(graphdef, state) + self.assertEqual(restored._metric_names, ['loss']) + self.assertEqual( + restored._expected_kwargs, {'values'} + ) + np.testing.assert_allclose( + restored.compute()['loss'], 2.0, rtol=1e-6 + ) + + def test_validation_disabled_with_var_keyword(self): + metrics = nnx.MultiMetric( + accuracy=nnx.metrics.Accuracy(), + loss=nnx.metrics.Average(), + ) + # Accuracy.update has **_, so validation is disabled. + self.assertIsNone(metrics._expected_kwargs) + + if __name__ == '__main__': absltest.main()