diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 0934f36b..806ae013 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -27,7 +27,7 @@ jobs: python -m pip install -r ./tests/requirements.txt - name: Checks with pre-commit - uses: pre-commit/action@v2.0.3 + uses: pre-commit/action@v3.0.1 - name: Test with pytest run: | diff --git a/equinox/nn/_batch_norm.py b/equinox/nn/_batch_norm.py index eec94a73..35ee733e 100644 --- a/equinox/nn/_batch_norm.py +++ b/equinox/nn/_batch_norm.py @@ -1,9 +1,11 @@ +import warnings from collections.abc import Hashable, Sequence +from typing import Literal import jax import jax.lax as lax import jax.numpy as jnp -from jaxtyping import Array, Bool, Float, PRNGKeyArray +from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray from .._misc import default_floating_dtype from .._module import field @@ -40,20 +42,57 @@ class BatchNorm(StatefulLayer, strict=True): statistics updated. During inference then just the running statistics are used. Whether the model is in training or inference mode should be toggled using [`equinox.nn.inference_mode`][]. + + With `mode = "batch"` during training the batch mean and variance are used + for normalization. For inference the exponential running mean and unbiased + variance are used for normalization. This is in line with how other machine + learning packages (e.g. PyTorch, flax, haiku) implement batch norm. + + With `mode = "ema"` exponential running means and variances are kept. During + training the batch statistics are used to fill in the running statistics until + they are populated. During inference the running statistics are used for + normalization. + + ??? cite + + [Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift](https://arxiv.org/abs/1502.03167) + + ```bibtex + @article{DBLP:journals/corr/IoffeS15, + author = {Sergey Ioffe and Christian Szegedy}, + title = {Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift}, + journal = {CoRR}, + volume = {abs/1502.03167}, + year = {2015}, + url = {http://arxiv.org/abs/1502.03167}, + eprinttype = {arXiv}, + eprint = {1502.03167}, + timestamp = {Mon, 13 Aug 2018 16:47:06 +0200}, + biburl = {https://dblp.org/rec/journals/corr/IoffeS15.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} + } + ``` """ # noqa: E501 weight: Float[Array, "input_size"] | None bias: Float[Array, "input_size"] | None - first_time_index: StateIndex[Bool[Array, ""]] - state_index: StateIndex[ - tuple[Float[Array, "input_size"], Float[Array, "input_size"]] - ] + ema_first_time_index: None | StateIndex[Bool[Array, ""]] + ema_state_index: ( + None | StateIndex[tuple[Float[Array, "input_size"], Float[Array, "input_size"]]] + ) + batch_counter: None | StateIndex[Int[Array, ""]] + batch_state_index: ( + None | StateIndex[tuple[Float[Array, "input_size"], Float[Array, "input_size"]]] + ) axis_name: Hashable | Sequence[Hashable] inference: bool input_size: int = field(static=True) eps: float = field(static=True) channelwise_affine: bool = field(static=True) momentum: float = field(static=True) + mode: Literal["ema", "batch"] = field(static=True) def __init__( self, @@ -64,6 +103,7 @@ def __init__( momentum: float = 0.99, inference: bool = False, dtype=None, + mode: Literal["ema", "batch", "legacy"] = "legacy", ): """**Arguments:** @@ -85,7 +125,19 @@ def __init__( if `channelwise_affine` is `True`. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode. + - `mode`: The variant of batch norm to use, either 'ema' or 'batch'. """ + if mode == "legacy": + mode = "ema" + warnings.warn( + "When `eqx.nn.BatchNorm(..., mode=...)` is unspecified it defaults to " + "'ema', for backward compatibility. This typically has a performance " + "impact, and for new code the user is encouraged to use 'batch' " + "instead. See `https://github.com/patrick-kidger/equinox/issues/659`." + ) + if mode not in {"ema", "batch"}: + raise ValueError("Invalid mode, must be 'ema' or 'batch'.") + self.mode = mode dtype = default_floating_dtype() if dtype is None else dtype if channelwise_affine: self.weight = jnp.ones((input_size,), dtype=dtype) @@ -93,12 +145,24 @@ def __init__( else: self.weight = None self.bias = None - self.first_time_index = StateIndex(jnp.array(True)) - init_buffers = ( - jnp.empty((input_size,), dtype=dtype), - jnp.empty((input_size,), dtype=dtype), - ) - self.state_index = StateIndex(init_buffers) + if mode == "ema": + self.ema_first_time_index = StateIndex(jnp.array(True)) + init_buffers = ( + jnp.empty((input_size,), dtype=dtype), + jnp.empty((input_size,), dtype=dtype), + ) + self.ema_state_index = StateIndex(init_buffers) + self.batch_counter = None + self.batch_state_index = None + else: + self.batch_counter = StateIndex(jnp.array(0)) + init_hidden = ( + jnp.zeros((input_size,), dtype=dtype), + jnp.ones((input_size,), dtype=dtype), + ) + self.batch_state_index = StateIndex(init_hidden) + self.ema_first_time_index = None + self.ema_state_index = None self.inference = inference self.axis_name = axis_name self.input_size = input_size @@ -138,32 +202,18 @@ def __call__( A `NameError` if no `vmap`s are placed around this operation, or if this vmap does not have a matching `axis_name`. """ + del key if inference is None: inference = self.inference - if inference: - running_mean, running_var = state.get(self.state_index) - else: - def _stats(y): - mean = jnp.mean(y) - mean = lax.pmean(mean, self.axis_name) - var = jnp.mean((y - mean) * jnp.conj(y - mean)) - var = lax.pmean(var, self.axis_name) - var = jnp.maximum(0.0, var) - return mean, var - - first_time = state.get(self.first_time_index) - state = state.set(self.first_time_index, jnp.array(False)) - - batch_mean, batch_var = jax.vmap(_stats)(x) - running_mean, running_var = state.get(self.state_index) - momentum = self.momentum - running_mean = (1 - momentum) * batch_mean + momentum * running_mean - running_var = (1 - momentum) * batch_var + momentum * running_var - running_mean = lax.select(first_time, batch_mean, running_mean) - running_var = lax.select(first_time, batch_var, running_var) - state = state.set(self.state_index, (running_mean, running_var)) + def _stats(y): + mean = jnp.mean(y) + mean = lax.pmean(mean, self.axis_name) + var = jnp.mean((y - mean) * jnp.conj(y - mean)) + var = lax.pmean(var, self.axis_name) + var = jnp.maximum(0.0, var) + return mean, var def _norm(y, m, v, w, b): out = (y - m) / jnp.sqrt(v + self.eps) @@ -171,5 +221,45 @@ def _norm(y, m, v, w, b): out = out * w + b return out - out = jax.vmap(_norm)(x, running_mean, running_var, self.weight, self.bias) + if self.mode == "ema": + assert self.ema_first_time_index is not None + assert self.ema_state_index is not None + if inference: + mean, var = state.get(self.ema_state_index) + else: + first_time = state.get(self.ema_first_time_index) + state = state.set(self.ema_first_time_index, jnp.array(False)) + batch_mean, batch_var = jax.vmap(_stats)(x) + running_mean, running_var = state.get(self.ema_state_index) + momentum = self.momentum + mean = (1 - momentum) * batch_mean + momentum * running_mean + var = (1 - momentum) * batch_var + momentum * running_var + # since jnp.array(0) == False + mean = lax.select(first_time, batch_mean, mean) + var = lax.select(first_time, batch_var, var) + state = state.set(self.ema_state_index, (mean, var)) + else: + assert self.batch_state_index is not None + assert self.batch_counter is not None + counter = state.get(self.batch_counter) + hidden_mean, hidden_var = state.get(self.batch_state_index) + if inference: + # Zero-debias approach: mean = hidden_mean / (1 - momentum^counter) + # For simplicity we do the minimal version here (no warmup). + scale = 1 - self.momentum**counter + mean = hidden_mean / scale + var = hidden_var / scale + else: + mean, var = jax.vmap(_stats)(x) + new_counter = counter + 1 + new_hidden_mean = hidden_mean * self.momentum + mean * ( + 1 - self.momentum + ) + new_hidden_var = hidden_var * self.momentum + var * (1 - self.momentum) + state = state.set(self.batch_counter, new_counter) + state = state.set( + self.batch_state_index, (new_hidden_mean, new_hidden_var) + ) + + out = jax.vmap(_norm)(x, mean, var, self.weight, self.bias) return out, state diff --git a/tests/test_nn.py b/tests/test_nn.py index a4b766cd..210f39f2 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -142,7 +142,7 @@ def test_sequential(getkey): [ eqx.nn.Linear(2, 4, key=getkey()), eqx.nn.Linear(4, 1, key=getkey()), - eqx.nn.BatchNorm(1, axis_name="batch"), + eqx.nn.BatchNorm(1, axis_name="batch", mode="ema"), eqx.nn.Linear(1, 3, key=getkey()), ] ) @@ -176,7 +176,7 @@ def make(): inner_seq = eqx.nn.Sequential( [ eqx.nn.Linear(2, 4, key=getkey()), - eqx.nn.BatchNorm(4, axis_name="batch") + eqx.nn.BatchNorm(4, axis_name="batch", mode="ema") if inner_stateful else eqx.nn.Identity(), eqx.nn.Linear(4, 3, key=getkey()), @@ -186,7 +186,7 @@ def make(): [ eqx.nn.Linear(5, 2, key=getkey()), inner_seq, - eqx.nn.BatchNorm(3, axis_name="batch") + eqx.nn.BatchNorm(3, axis_name="batch", mode="ema") if outer_stateful else eqx.nn.Identity(), eqx.nn.Linear(3, 6, key=getkey()), @@ -949,7 +949,8 @@ def test_group_norm(getkey): gn = eqx.nn.GroupNorm(groups=4, channels=None, channelwise_affine=True) -def test_batch_norm(getkey): +@pytest.mark.parametrize("mode", ("ema", "batch")) +def test_batch_norm(getkey, mode): x0 = jrandom.uniform(getkey(), (5,)) x1 = jrandom.uniform(getkey(), (10, 5)) x2 = jrandom.uniform(getkey(), (10, 5, 6)) @@ -957,14 +958,19 @@ def test_batch_norm(getkey): # Test that it works with a single vmap'd axis_name - bn = eqx.nn.BatchNorm(5, "batch") + bn = eqx.nn.BatchNorm(5, "batch", mode=mode) state = eqx.nn.State(bn) vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None)) for x in (x1, x2, x3): out, state = vbn(x, state) assert out.shape == x.shape - running_mean, running_var = state.get(bn.state_index) + if mode == "ema": + assert bn.ema_state_index is not None + running_mean, running_var = state.get(bn.ema_state_index) + else: + assert bn.batch_state_index is not None + running_mean, running_var = state.get(bn.batch_state_index) assert running_mean.shape == (5,) assert running_var.shape == (5,) @@ -985,13 +991,18 @@ def test_batch_norm(getkey): in_axes=(0, None), )(x2, state) assert out.shape == x2.shape - running_mean, running_var = state.get(bn.state_index) + if mode == "ema": + assert bn.ema_state_index is not None + running_mean, running_var = state.get(bn.ema_state_index) + else: + assert bn.batch_state_index is not None + running_mean, running_var = state.get(bn.batch_state_index) assert running_mean.shape == (10, 5) assert running_var.shape == (10, 5) # Test that it handles multiple axis_names - vvbn = eqx.nn.BatchNorm(6, ("batch1", "batch2")) + vvbn = eqx.nn.BatchNorm(6, ("batch1", "batch2"), mode=mode) vvstate = eqx.nn.State(vvbn) for axis_name in ("batch1", "batch2"): vvbn = jax.vmap( @@ -999,14 +1010,19 @@ def test_batch_norm(getkey): ) out, out_vvstate = vvbn(x2, vvstate) assert out.shape == x2.shape - running_mean, running_var = out_vvstate.get(vvbn.state_index) + if mode == "ema": + assert vvbn.ema_state_index is not None + running_mean, running_var = out_vvstate.get(vvbn.ema_state_index) + else: + assert vvbn.batch_state_index is not None + running_mean, running_var = out_vvstate.get(vvbn.batch_state_index) assert running_mean.shape == (6,) assert running_var.shape == (6,) # Test that it normalises x1alt = jrandom.normal(jrandom.PRNGKey(5678), (10, 5)) # avoid flakey test - bn = eqx.nn.BatchNorm(5, "batch", channelwise_affine=False) + bn = eqx.nn.BatchNorm(5, "batch", channelwise_affine=False, mode=mode) state = eqx.nn.State(bn) vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None)) out, state = vbn(x1alt, state) @@ -1017,9 +1033,19 @@ def test_batch_norm(getkey): # Test that the statistics update during training out, state = vbn(x1, state) - running_mean, running_var = state.get(bn.state_index) + if mode == "ema": + assert bn.ema_state_index is not None + running_mean, running_var = state.get(bn.ema_state_index) + else: + assert bn.batch_state_index is not None + running_mean, running_var = state.get(bn.batch_state_index) out, state = vbn(3 * x1 + 10, state) - running_mean2, running_var2 = state.get(bn.state_index) + if mode == "ema": + assert bn.ema_state_index is not None + running_mean2, running_var2 = state.get(bn.ema_state_index) + else: + assert bn.batch_state_index is not None + running_mean2, running_var2 = state.get(bn.batch_state_index) assert not jnp.allclose(running_mean, running_mean2) assert not jnp.allclose(running_var, running_var2) @@ -1028,7 +1054,12 @@ def test_batch_norm(getkey): ibn = eqx.nn.inference_mode(bn, value=True) vibn = jax.vmap(ibn, axis_name="batch", in_axes=(0, None), out_axes=(0, None)) out, state = vibn(4 * x1 + 20, state) - running_mean3, running_var3 = state.get(bn.state_index) + if mode == "ema": + assert bn.ema_state_index is not None + running_mean3, running_var3 = state.get(bn.ema_state_index) + else: + assert bn.batch_state_index is not None + running_mean3, running_var3 = state.get(bn.batch_state_index) assert jnp.array_equal(running_mean2, running_mean3) assert jnp.array_equal(running_var2, running_var3) diff --git a/tests/test_serialisation.py b/tests/test_serialisation.py index e89c243a..44b214ba 100644 --- a/tests/test_serialisation.py +++ b/tests/test_serialisation.py @@ -239,12 +239,16 @@ class Model(eqx.Module): norm1: eqx.nn.BatchNorm norm2: eqx.nn.BatchNorm - model = Model(eqx.nn.BatchNorm(3, "hi"), eqx.nn.BatchNorm(4, "bye")) + model = Model( + eqx.nn.BatchNorm(3, "hi", mode="ema"), eqx.nn.BatchNorm(4, "bye", mode="ema") + ) state = eqx.nn.State(model) eqx.tree_serialise_leaves(tmp_path, (model, state)) - model2 = Model(eqx.nn.BatchNorm(3, "hi"), eqx.nn.BatchNorm(4, "bye")) + model2 = Model( + eqx.nn.BatchNorm(3, "hi", mode="ema"), eqx.nn.BatchNorm(4, "bye", mode="ema") + ) state2 = eqx.nn.State(model2) eqx.tree_deserialise_leaves(tmp_path, (model2, state2)) diff --git a/tests/test_stateful.py b/tests/test_stateful.py index d7bc632a..cf57b5cd 100644 --- a/tests/test_stateful.py +++ b/tests/test_stateful.py @@ -7,7 +7,7 @@ def test_delete_init_state(): - model = eqx.nn.BatchNorm(3, "batch") + model = eqx.nn.BatchNorm(3, "batch", mode="ema") eqx.nn.State(model) model2 = eqx.nn.delete_init_state(model)