2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
@@ -27,7 +27,7 @@ jobs:
python -m pip install -r ./tests/requirements.txt

- name: Checks with pre-commit
uses: pre-commit/action@v2.0.3
uses: pre-commit/action@v3.0.1

- name: Test with pytest
run: |
158 changes: 124 additions & 34 deletions equinox/nn/_batch_norm.py
@@ -1,9 +1,11 @@
import warnings
from collections.abc import Hashable, Sequence
from typing import Literal

import jax
import jax.lax as lax
import jax.numpy as jnp
from jaxtyping import Array, Bool, Float, PRNGKeyArray
from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray

from .._misc import default_floating_dtype
from .._module import field
@@ -40,20 +42,57 @@ class BatchNorm(StatefulLayer, strict=True):
statistics updated. During inference then just the running statistics are used.
Whether the model is in training or inference mode should be toggled using
[`equinox.nn.inference_mode`][].

With `mode = "batch"` during training the batch mean and variance are used
for normalization. For inference the exponential running mean and unbiased
variance are used for normalization. This is in line with how other machine
learning packages (e.g. PyTorch, flax, haiku) implement batch norm.

With `mode = "ema"` exponential running means and variances are kept. During
training the batch statistics are used to fill in the running statistics until
they are populated. During inference the running statistics are used for
normalization.

??? cite

[Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/abs/1502.03167)

```bibtex
@article{DBLP:journals/corr/IoffeS15,
author = {Sergey Ioffe and Christian Szegedy},
title = {Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift},
journal = {CoRR},
volume = {abs/1502.03167},
year = {2015},
url = {http://arxiv.org/abs/1502.03167},
eprinttype = {arXiv},
eprint = {1502.03167},
timestamp = {Mon, 13 Aug 2018 16:47:06 +0200},
biburl = {https://dblp.org/rec/journals/corr/IoffeS15.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
""" # noqa: E501

weight: Float[Array, "input_size"] | None
bias: Float[Array, "input_size"] | None
first_time_index: StateIndex[Bool[Array, ""]]
state_index: StateIndex[
tuple[Float[Array, "input_size"], Float[Array, "input_size"]]
]
ema_first_time_index: None | StateIndex[Bool[Array, ""]]
ema_state_index: (
None | StateIndex[tuple[Float[Array, "input_size"], Float[Array, "input_size"]]]
)
batch_counter: None | StateIndex[Int[Array, ""]]
batch_state_index: (
None | StateIndex[tuple[Float[Array, "input_size"], Float[Array, "input_size"]]]
)
axis_name: Hashable | Sequence[Hashable]
inference: bool
input_size: int = field(static=True)
eps: float = field(static=True)
channelwise_affine: bool = field(static=True)
momentum: float = field(static=True)
mode: Literal["ema", "batch"] = field(static=True)

def __init__(
self,
@@ -64,6 +103,7 @@ def __init__(
momentum: float = 0.99,
inference: bool = False,
dtype=None,
mode: Literal["ema", "batch", "legacy"] = "legacy",
):
"""**Arguments:**

@@ -85,20 +125,44 @@ def __init__(
if `channelwise_affine` is `True`. Defaults to either
`jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in
64-bit mode.
- `mode`: The variant of batch norm to use, either 'ema' or 'batch'.
"""
if mode == "legacy":
mode = "ema"
warnings.warn(
"When `eqx.nn.BatchNorm(..., mode=...)` is unspecified it defaults to "
"'ema', for backward compatibility. This typically has a performance "
"impact, and for new code the user is encouraged to use 'batch' "
"instead. See `https://github.com/patrick-kidger/equinox/issues/659`."
)
if mode not in {"ema", "batch"}:
raise ValueError("Invalid mode, must be 'ema' or 'batch'.")
self.mode = mode
dtype = default_floating_dtype() if dtype is None else dtype
if channelwise_affine:
self.weight = jnp.ones((input_size,), dtype=dtype)
self.bias = jnp.zeros((input_size,), dtype=dtype)
else:
self.weight = None
self.bias = None
self.first_time_index = StateIndex(jnp.array(True))
init_buffers = (
jnp.empty((input_size,), dtype=dtype),
jnp.empty((input_size,), dtype=dtype),
)
self.state_index = StateIndex(init_buffers)
if mode == "ema":
self.ema_first_time_index = StateIndex(jnp.array(True))
init_buffers = (
jnp.empty((input_size,), dtype=dtype),
jnp.empty((input_size,), dtype=dtype),
)
self.ema_state_index = StateIndex(init_buffers)
self.batch_counter = None
self.batch_state_index = None
else:
self.batch_counter = StateIndex(jnp.array(0))
init_hidden = (
jnp.zeros((input_size,), dtype=dtype),
jnp.ones((input_size,), dtype=dtype),
)
self.batch_state_index = StateIndex(init_hidden)
self.ema_first_time_index = None
self.ema_state_index = None
self.inference = inference
self.axis_name = axis_name
self.input_size = input_size
@@ -138,38 +202,64 @@ def __call__(
A `NameError` if no `vmap`s are placed around this operation, or if this vmap
does not have a matching `axis_name`.
"""
del key

if inference is None:
inference = self.inference
if inference:
running_mean, running_var = state.get(self.state_index)
else:

def _stats(y):
mean = jnp.mean(y)
mean = lax.pmean(mean, self.axis_name)
var = jnp.mean((y - mean) * jnp.conj(y - mean))
var = lax.pmean(var, self.axis_name)
var = jnp.maximum(0.0, var)
return mean, var

first_time = state.get(self.first_time_index)
state = state.set(self.first_time_index, jnp.array(False))

batch_mean, batch_var = jax.vmap(_stats)(x)
running_mean, running_var = state.get(self.state_index)
momentum = self.momentum
running_mean = (1 - momentum) * batch_mean + momentum * running_mean
running_var = (1 - momentum) * batch_var + momentum * running_var
running_mean = lax.select(first_time, batch_mean, running_mean)
running_var = lax.select(first_time, batch_var, running_var)
state = state.set(self.state_index, (running_mean, running_var))
def _stats(y):
mean = jnp.mean(y)
mean = lax.pmean(mean, self.axis_name)
var = jnp.mean((y - mean) * jnp.conj(y - mean))
var = lax.pmean(var, self.axis_name)
var = jnp.maximum(0.0, var)
return mean, var

def _norm(y, m, v, w, b):
out = (y - m) / jnp.sqrt(v + self.eps)
if self.channelwise_affine:
out = out * w + b
return out

out = jax.vmap(_norm)(x, running_mean, running_var, self.weight, self.bias)
if self.mode == "ema":
assert self.ema_first_time_index is not None
assert self.ema_state_index is not None
if inference:
mean, var = state.get(self.ema_state_index)
else:
first_time = state.get(self.ema_first_time_index)
state = state.set(self.ema_first_time_index, jnp.array(False))
batch_mean, batch_var = jax.vmap(_stats)(x)
running_mean, running_var = state.get(self.ema_state_index)
momentum = self.momentum
mean = (1 - momentum) * batch_mean + momentum * running_mean
var = (1 - momentum) * batch_var + momentum * running_var
# since jnp.array(0) == False
mean = lax.select(first_time, batch_mean, mean)
var = lax.select(first_time, batch_var, var)
state = state.set(self.ema_state_index, (mean, var))
else:
assert self.batch_state_index is not None
assert self.batch_counter is not None
counter = state.get(self.batch_counter)
hidden_mean, hidden_var = state.get(self.batch_state_index)
if inference:
# Zero-debias approach: mean = hidden_mean / (1 - momentum^counter)
# For simplicity we do the minimal version here (no warmup).
scale = 1 - self.momentum**counter
mean = hidden_mean / scale
var = hidden_var / scale
else:
mean, var = jax.vmap(_stats)(x)
new_counter = counter + 1
new_hidden_mean = hidden_mean * self.momentum + mean * (
1 - self.momentum
)
new_hidden_var = hidden_var * self.momentum + var * (1 - self.momentum)
state = state.set(self.batch_counter, new_counter)
state = state.set(
self.batch_state_index, (new_hidden_mean, new_hidden_var)
)

out = jax.vmap(_norm)(x, mean, var, self.weight, self.bias)
return out, state
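
> **Reviewer note (not part of the diff):** the `'batch'`-mode inference path above keeps its running buffers as exponential moving averages, and dividing by `1 - momentum**counter` is the standard zero-debias correction. A tiny self-contained sketch of why this matters early in training, illustrated for a buffer initialised at zero:

```python
# EMA update as in the diff: hidden <- momentum * hidden + (1 - momentum) * stat.
# With only a few updates the raw EMA is dominated by its initial value;
# dividing by (1 - momentum**counter) removes that bias exactly for a
# zero-initialised buffer.
momentum = 0.99
true_stat = 2.0  # pretend the batch statistic is constant

hidden = 0.0
for counter in range(1, 6):
    hidden = momentum * hidden + (1 - momentum) * true_stat
    debiased = hidden / (1 - momentum**counter)
    print(counter, round(hidden, 4), round(debiased, 4))
# At counter=5 the raw EMA is ~0.098 (still close to the 0.0 initialisation),
# while the debiased estimate is exactly 2.0.
```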
57 changes: 44 additions & 13 deletions tests/test_nn.py
@@ -142,7 +142,7 @@ def test_sequential(getkey):
[
eqx.nn.Linear(2, 4, key=getkey()),
eqx.nn.Linear(4, 1, key=getkey()),
eqx.nn.BatchNorm(1, axis_name="batch"),
eqx.nn.BatchNorm(1, axis_name="batch", mode="ema"),
eqx.nn.Linear(1, 3, key=getkey()),
]
)
@@ -176,7 +176,7 @@ def make():
inner_seq = eqx.nn.Sequential(
[
eqx.nn.Linear(2, 4, key=getkey()),
eqx.nn.BatchNorm(4, axis_name="batch")
eqx.nn.BatchNorm(4, axis_name="batch", mode="ema")
if inner_stateful
else eqx.nn.Identity(),
eqx.nn.Linear(4, 3, key=getkey()),
@@ -186,7 +186,7 @@ def make():
[
eqx.nn.Linear(5, 2, key=getkey()),
inner_seq,
eqx.nn.BatchNorm(3, axis_name="batch")
eqx.nn.BatchNorm(3, axis_name="batch", mode="ema")
if outer_stateful
else eqx.nn.Identity(),
eqx.nn.Linear(3, 6, key=getkey()),
@@ -949,22 +949,28 @@ def test_group_norm(getkey):
gn = eqx.nn.GroupNorm(groups=4, channels=None, channelwise_affine=True)


def test_batch_norm(getkey):
@pytest.mark.parametrize("mode", ("ema", "batch"))
def test_batch_norm(getkey, mode):
x0 = jrandom.uniform(getkey(), (5,))
x1 = jrandom.uniform(getkey(), (10, 5))
x2 = jrandom.uniform(getkey(), (10, 5, 6))
x3 = jrandom.uniform(getkey(), (10, 5, 7, 8))

# Test that it works with a single vmap'd axis_name

bn = eqx.nn.BatchNorm(5, "batch")
bn = eqx.nn.BatchNorm(5, "batch", mode=mode)
state = eqx.nn.State(bn)
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))

for x in (x1, x2, x3):
out, state = vbn(x, state)
assert out.shape == x.shape
running_mean, running_var = state.get(bn.state_index)
if mode == "ema":
assert bn.ema_state_index is not None
running_mean, running_var = state.get(bn.ema_state_index)
else:
assert bn.batch_state_index is not None
running_mean, running_var = state.get(bn.batch_state_index)
assert running_mean.shape == (5,)
assert running_var.shape == (5,)

@@ -985,28 +991,38 @@ def test_batch_norm(getkey):
in_axes=(0, None),
)(x2, state)
assert out.shape == x2.shape
running_mean, running_var = state.get(bn.state_index)
if mode == "ema":
assert bn.ema_state_index is not None
running_mean, running_var = state.get(bn.ema_state_index)
else:
assert bn.batch_state_index is not None
running_mean, running_var = state.get(bn.batch_state_index)
assert running_mean.shape == (10, 5)
assert running_var.shape == (10, 5)

# Test that it handles multiple axis_names

vvbn = eqx.nn.BatchNorm(6, ("batch1", "batch2"))
vvbn = eqx.nn.BatchNorm(6, ("batch1", "batch2"), mode=mode)
vvstate = eqx.nn.State(vvbn)
for axis_name in ("batch1", "batch2"):
vvbn = jax.vmap(
vvbn, axis_name=axis_name, in_axes=(0, None), out_axes=(0, None)
)
out, out_vvstate = vvbn(x2, vvstate)
assert out.shape == x2.shape
running_mean, running_var = out_vvstate.get(vvbn.state_index)
if mode == "ema":
assert vvbn.ema_state_index is not None
running_mean, running_var = out_vvstate.get(vvbn.ema_state_index)
else:
assert vvbn.batch_state_index is not None
running_mean, running_var = out_vvstate.get(vvbn.batch_state_index)
assert running_mean.shape == (6,)
assert running_var.shape == (6,)

# Test that it normalises

x1alt = jrandom.normal(jrandom.PRNGKey(5678), (10, 5)) # avoid flakey test
bn = eqx.nn.BatchNorm(5, "batch", channelwise_affine=False)
bn = eqx.nn.BatchNorm(5, "batch", channelwise_affine=False, mode=mode)
state = eqx.nn.State(bn)
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, state = vbn(x1alt, state)
@@ -1017,9 +1033,19 @@ def test_batch_norm(getkey):

# Test that the statistics update during training
out, state = vbn(x1, state)
running_mean, running_var = state.get(bn.state_index)
if mode == "ema":
assert bn.ema_state_index is not None
running_mean, running_var = state.get(bn.ema_state_index)
else:
assert bn.batch_state_index is not None
running_mean, running_var = state.get(bn.batch_state_index)
out, state = vbn(3 * x1 + 10, state)
running_mean2, running_var2 = state.get(bn.state_index)
if mode == "ema":
assert bn.ema_state_index is not None
running_mean2, running_var2 = state.get(bn.ema_state_index)
else:
assert bn.batch_state_index is not None
running_mean2, running_var2 = state.get(bn.batch_state_index)
assert not jnp.allclose(running_mean, running_mean2)
assert not jnp.allclose(running_var, running_var2)

@@ -1028,7 +1054,12 @@ def test_batch_norm(getkey):
ibn = eqx.nn.inference_mode(bn, value=True)
vibn = jax.vmap(ibn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, state = vibn(4 * x1 + 20, state)
running_mean3, running_var3 = state.get(bn.state_index)
if mode == "ema":
assert bn.ema_state_index is not None
running_mean3, running_var3 = state.get(bn.ema_state_index)
else:
assert bn.batch_state_index is not None
running_mean3, running_var3 = state.get(bn.batch_state_index)
assert jnp.array_equal(running_mean2, running_mean3)
assert jnp.array_equal(running_var2, running_var3)

8 changes: 6 additions & 2 deletions tests/test_serialisation.py
@@ -239,12 +239,16 @@ class Model(eqx.Module):
norm1: eqx.nn.BatchNorm
norm2: eqx.nn.BatchNorm

model = Model(eqx.nn.BatchNorm(3, "hi"), eqx.nn.BatchNorm(4, "bye"))
model = Model(
eqx.nn.BatchNorm(3, "hi", mode="ema"), eqx.nn.BatchNorm(4, "bye", mode="ema")
)
state = eqx.nn.State(model)

eqx.tree_serialise_leaves(tmp_path, (model, state))

model2 = Model(eqx.nn.BatchNorm(3, "hi"), eqx.nn.BatchNorm(4, "bye"))
model2 = Model(
eqx.nn.BatchNorm(3, "hi", mode="ema"), eqx.nn.BatchNorm(4, "bye", mode="ema")
)
state2 = eqx.nn.State(model2)

eqx.tree_deserialise_leaves(tmp_path, (model2, state2))
2 changes: 1 addition & 1 deletion tests/test_stateful.py
@@ -7,7 +7,7 @@


def test_delete_init_state():
model = eqx.nn.BatchNorm(3, "batch")
model = eqx.nn.BatchNorm(3, "batch", mode="ema")
eqx.nn.State(model)
model2 = eqx.nn.delete_init_state(model)
