Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 77 additions & 9 deletions flax/nnx/training/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations

import inspect
import typing as tp

import numpy as np
Expand All @@ -23,7 +24,11 @@
from flax.nnx.variablelib import Variable
import jax, jax.numpy as jnp

# TODO: add tests and docstrings

_MULTIMETRIC_RESERVED_NAMES = frozenset({
'reset', 'update', 'compute', 'split',
'_metric_names', '_expected_kwargs',
})


class MetricState(Variable):
Expand Down Expand Up @@ -123,6 +128,15 @@ def compute(self) -> jax.Array:

@struct.dataclass
class Statistics:
"""Running statistics computed by the Welford algorithm.

Attributes:
mean: the running mean of the data.
standard_error_of_mean: the standard error of the mean.
standard_deviation: the population standard deviation
(ddof=0) of the data.
"""

mean: jnp.float32
standard_error_of_mean: jnp.float32
standard_deviation: jnp.float32
Expand Down Expand Up @@ -386,28 +400,82 @@ def __init__(self, **metrics):
**metrics: the key-word arguments that will be used to access
the corresponding ``Metric``.
"""
# TODO: raise error if a kwarg is passed that is in ('reset', 'update', 'compute'), since these names are reserved for methods
self._metric_names = []
# Validate metric names before any mutation.
for name in metrics:
if name in _MULTIMETRIC_RESERVED_NAMES:
raise ValueError(
f"Metric name '{name}' conflicts with a reserved "
f'name. Reserved names: '
f'{sorted(_MULTIMETRIC_RESERVED_NAMES)}'
)

self._metric_names: list[str] = []
self._expected_kwargs: set[str] | None = set()
for metric_name, metric in metrics.items():
self._metric_names.append(metric_name)
setattr(self, metric_name, metric)
# Collect expected kwargs for validation in update().
if self._expected_kwargs is None:
continue
sig = inspect.signature(metric.update)
has_named_params = False
has_var_keyword = False
named_param_names: set[str] = set()
for pname, param in sig.parameters.items():
if pname == 'self':
continue
if param.kind in (
param.POSITIONAL_OR_KEYWORD,
param.KEYWORD_ONLY,
):
named_param_names.add(pname)
has_named_params = True
elif param.kind == param.VAR_KEYWORD:
has_var_keyword = True
if has_named_params and has_var_keyword:
# Metric declares specific params but also absorbs
# extras (e.g. Accuracy's **_); can't validate
# without false positives.
self._expected_kwargs = None
elif has_named_params:
self._expected_kwargs.update(named_param_names)
elif hasattr(metric, 'argname'):
# Use argname convention (e.g. Average, Welford).
self._expected_kwargs.add(metric.argname)
elif has_var_keyword:
# Pure **kwargs with no specific params; can't
# validate.
self._expected_kwargs = None

def reset(self) -> None:
  """Reset every underlying ``Metric`` tracked by this container."""
  # Walk the registered metric attributes and clear each one in place.
  for name in self._metric_names:
    metric = getattr(self, name)
    metric.reset()

def update(self, **updates) -> None:
  """In-place update all underlying ``Metric``'s.

  All ``**updates`` are forwarded to each metric's
  ``update`` method.

  Args:
    **updates: keyword arguments forwarded to each
      underlying metric's ``update`` method.

  Raises:
    TypeError: if an unexpected keyword argument is
      provided and the expected set can be statically
      determined from the underlying metrics.
  """
  # _expected_kwargs is None when some metric absorbs arbitrary
  # **kwargs, in which case any keyword may be legitimate and we
  # cannot validate without false positives.
  if self._expected_kwargs is not None:
    unexpected = set(updates) - self._expected_kwargs
    if unexpected:
      raise TypeError(
          f'Unexpected keyword argument(s): '
          f'{sorted(unexpected)}. '
          f'Expected: {sorted(self._expected_kwargs)}'
      )
  # Fan the (validated) kwargs out to every registered metric.
  for metric_name in self._metric_names:
    getattr(self, metric_name).update(**updates)

Expand Down
172 changes: 172 additions & 0 deletions tests/nnx/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,177 @@ def test_accuracy_dims(self, logits, labels, threshold, error_msg):
accuracy.update(logits=logits, labels=labels)


class TestAverage(parameterized.TestCase):

  def test_initial_compute_nan(self):
    # Before any update, the running average is undefined (NaN).
    metric = nnx.metrics.Average()
    self.assertTrue(jnp.isnan(metric.compute()))

  def test_single_batch(self):
    metric = nnx.metrics.Average()
    metric.update(values=jnp.array([1, 2, 3, 4]))
    np.testing.assert_allclose(metric.compute(), 2.5, rtol=1e-6)

  def test_multiple_batches(self):
    metric = nnx.metrics.Average()
    first = jnp.array([1, 2, 3, 4])
    second = jnp.array([3, 2, 1, 0])
    metric.update(values=first)
    metric.update(values=second)
    np.testing.assert_allclose(metric.compute(), 2.0, rtol=1e-6)

  def test_reset(self):
    metric = nnx.metrics.Average()
    metric.update(values=jnp.array([1, 2, 3]))
    metric.reset()
    # A reset metric behaves like a freshly constructed one.
    self.assertTrue(jnp.isnan(metric.compute()))

  def test_custom_argname(self):
    metric = nnx.metrics.Average('loss')
    metric.update(loss=jnp.array([10, 20]))
    np.testing.assert_allclose(metric.compute(), 15.0, rtol=1e-6)

  def test_missing_argname(self):
    metric = nnx.metrics.Average('loss')
    with self.assertRaisesRegex(TypeError, "Expected keyword argument 'loss'"):
      metric.update(values=jnp.array([1, 2]))

  def test_scalar_float(self):
    metric = nnx.metrics.Average()
    metric.update(values=5.0)
    np.testing.assert_allclose(metric.compute(), 5.0, rtol=1e-6)

  def test_scalar_int(self):
    metric = nnx.metrics.Average()
    metric.update(values=3)
    np.testing.assert_allclose(metric.compute(), 3.0, rtol=1e-6)


class TestWelford(parameterized.TestCase):

  def test_multiple_batches(self):
    metric = nnx.metrics.Welford()
    first = jnp.array([1.0, 2.0, 3.0, 4.0])
    second = jnp.array([3.0, 2.0, 1.0, 0.0])
    metric.update(values=first)
    metric.update(values=second)
    # Streaming statistics must match direct computation over the
    # concatenated data.
    combined = jnp.concatenate([first, second])
    result = metric.compute()
    np.testing.assert_allclose(result.mean, combined.mean(), rtol=1e-6)
    np.testing.assert_allclose(
        result.standard_deviation, combined.std(), rtol=1e-5
    )
    sem = combined.std() / jnp.sqrt(combined.size)
    np.testing.assert_allclose(
        result.standard_error_of_mean, sem, rtol=1e-5
    )

  def test_reset(self):
    metric = nnx.metrics.Welford()
    metric.update(values=jnp.array([1.0, 2.0, 3.0]))
    metric.reset()
    result = metric.compute()
    np.testing.assert_allclose(result.mean, 0.0, atol=0)
    self.assertTrue(jnp.isnan(result.standard_error_of_mean))
    self.assertTrue(jnp.isnan(result.standard_deviation))

  def test_custom_argname(self):
    metric = nnx.metrics.Welford('loss')
    metric.update(loss=jnp.array([1.0, 2.0, 3.0]))
    np.testing.assert_allclose(metric.compute().mean, 2.0, rtol=1e-6)

  def test_missing_argname(self):
    metric = nnx.metrics.Welford('loss')
    with self.assertRaisesRegex(TypeError, "Expected keyword argument 'loss'"):
      metric.update(values=jnp.array([1.0]))


class TestAccuracy(parameterized.TestCase):

  def test_multiclass_int64_labels(self):
    # int64 labels (via NumPy) must be accepted for integer-label mode.
    predictions = jnp.array([[0.0, 1.0], [1.0, 0.0]])
    targets = jnp.asarray(np.array([1, 0], dtype=np.int64))
    metric = nnx.metrics.Accuracy()
    metric.update(logits=predictions, labels=targets)
    np.testing.assert_allclose(metric.compute(), 1.0, rtol=1e-6)

  def test_invalid_label_dtype(self):
    metric = nnx.metrics.Accuracy()
    with self.assertRaisesRegex(ValueError, 'labels.dtype'):
      metric.update(logits=jnp.array([[0.0, 1.0]]), labels=jnp.array([1.0]))

  def test_threshold_type_error(self):
    # An integer threshold must be rejected; only floats are accepted.
    with self.assertRaisesRegex(TypeError, 'float'):
      nnx.metrics.Accuracy(threshold=1)


class TestMultiMetric(parameterized.TestCase):

  @parameterized.parameters('reset', 'update', 'compute', 'split')
  def test_reserved_name_error(self, name):
    # Public method names cannot be shadowed by metric names.
    with self.assertRaisesRegex(ValueError, 'reserved'):
      nnx.MultiMetric(**{name: nnx.metrics.Average()})

  @parameterized.parameters('_metric_names', '_expected_kwargs')
  def test_internal_name_error(self, name):
    # Internal bookkeeping attributes are reserved as well.
    with self.assertRaisesRegex(ValueError, 'reserved'):
      nnx.MultiMetric(**{name: nnx.metrics.Average()})

  def test_unmatched_kwarg_error(self):
    multi = nnx.MultiMetric(
        loss=nnx.metrics.Average(),
        score=nnx.metrics.Average('score'),
    )
    # Guard: validation must be active for this test.
    self.assertEqual(multi._expected_kwargs, {'values', 'score'})
    with self.assertRaisesRegex(TypeError, 'Unexpected keyword argument'):
      multi.update(
          values=jnp.array([1.0]),
          score=jnp.array([2.0]),
          typo_kwarg=jnp.array([3.0]),
      )

  def test_compute_returns_dict(self):
    multi = nnx.MultiMetric(
        loss=nnx.metrics.Average(),
        score=nnx.metrics.Average('score'),
    )
    multi.update(values=jnp.array([1, 2, 3]), score=jnp.array([4, 5, 6]))
    computed = multi.compute()
    self.assertIsInstance(computed, dict)
    self.assertEqual(set(computed.keys()), {'loss', 'score'})
    np.testing.assert_allclose(computed['loss'], 2.0, rtol=1e-6)
    np.testing.assert_allclose(computed['score'], 5.0, rtol=1e-6)

  def test_split_merge(self):
    # State must survive a split/merge round trip.
    multi = nnx.MultiMetric(loss=nnx.metrics.Average())
    multi.update(values=jnp.array([1.0, 2.0, 3.0]))
    graphdef, state = multi.split()
    rebuilt = nnx.merge(graphdef, state)
    self.assertEqual(rebuilt._metric_names, ['loss'])
    self.assertEqual(rebuilt._expected_kwargs, {'values'})
    np.testing.assert_allclose(rebuilt.compute()['loss'], 2.0, rtol=1e-6)

  def test_validation_disabled_with_var_keyword(self):
    multi = nnx.MultiMetric(
        accuracy=nnx.metrics.Accuracy(),
        loss=nnx.metrics.Average(),
    )
    # Accuracy.update has **_, so validation is disabled.
    self.assertIsNone(multi._expected_kwargs)


# Standard absltest entry point so this test file can be run directly.
if __name__ == '__main__':
  absltest.main()