Batch BatchNorm#948

Merged
patrick-kidger merged 3 commits into patrick-kidger:dev from lockwo:Owen/batch-norm
Apr 14, 2025

Conversation

@lockwo
Contributor

@lockwo lockwo commented Feb 10, 2025

Revives #675, but ideally a smaller/simpler change (that doesn't require any math) that matches the patterns in flax/haiku (see: https://github.com/google-deepmind/dm-haiku/blob/main/haiku/_src/batch_norm.py#L42-L206, https://github.com/google-deepmind/dm-haiku/blob/main/haiku/_src/moving_averages.py#L41-L139). Has some hardcoded defaults (e.g. I don't support warmup iterations at all), but should help the users of batch norm. In addition to the stability in the initial example, here is an example on an AlphaZero model playing 9x9 Go:

[Screenshot: training curves for an AlphaZero model playing 9x9 Go]

This screenshot (and code) also comes from this PR: sotetsuk/pgx#1300

@lockwo lockwo changed the title from "Batch norm changes" to "Batch BatchNorm" Feb 10, 2025
@lockwo
Contributor Author

lockwo commented Feb 10, 2025

Also, for anyone who wants to try out the comparisons, here's some example code that shows haiku vs equinox

code
import jax
from jax import numpy as jnp
import equinox as eqx
import haiku as hk


class Net(hk.Module):

    def __init__(
        self,
        name="net",
    ):
        super().__init__(name=name)

    def __call__(self, x, is_training):
        x = x.astype(jnp.float32)
        x = hk.BatchNorm(True, True, 0.9)(x, is_training)
        return x


class Neteqx(eqx.Module):
    norm: eqx.nn.BatchNorm

    def __init__(
        self,
        output_channels: int = 64,
    ):
        self.norm = eqx.nn.BatchNorm(
            output_channels, "batch", momentum=0.9, mode="batch"
        )

    def __call__(self, x, state):
        x = x.astype(jnp.float32)
        x = jnp.moveaxis(x, -1, 0)
        x, state = self.norm(x, state)
        x = jnp.moveaxis(x, 0, -1)
        return x, state


def forward_fn(x):
    net = Net()
    v = net(x, True)
    return v


def inf_fn(x):
    net = Net()
    v = net(x, False)
    return v


forward = hk.without_apply_rng(hk.transform_with_state(forward_fn))
inf = hk.without_apply_rng(hk.transform_with_state(inf_fn))

ins = jnp.array([1, 2, 3, 4]).reshape((4, 1, 1, 1)) * jnp.ones((4, 2, 2, 3))
params, state = forward.init(jax.random.key(0), jnp.zeros_like(ins))
# print(params)
# print(hk.data_structures.tree_size(params))
print(state)
jax.tree.map(lambda x: print(x), state)

for i in range(2):
    out, state = forward.apply(params, state, ins[2 * i : 2 * (i + 1)])
    jax.tree.map(lambda x: print(x), state)

inf.apply(params, state, ins)

eq, s = eqx.nn.make_with_state(Neteqx)(3)
eq = eqx.tree_at(
    lambda x: x.norm.weight, eq, params["net/batch_norm"]["scale"].squeeze()
)
eq = eqx.tree_at(
    lambda x: x.norm.bias, eq, params["net/batch_norm"]["offset"].squeeze()
)

# print(sum(x.size for x in jax.tree.leaves(eqx.filter(eq, eqx.is_inexact_array))))
jax.tree.map(lambda x: print(x), s)

for i in range(2):
    out_eqx, s = eqx.filter_vmap(
        eq, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
    )(ins[2 * i : 2 * (i + 1)], s)
    jax.tree.map(lambda x: print(x), s)
    # print(out_eqx)

eqx.filter_vmap(
    eqx.nn.inference_mode(eq), in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(ins, s)

@ZagButNoZig
Contributor

ZagButNoZig commented Mar 30, 2025

Ok so I don't know how useful this is, but I validated this implementation against the old ema impl and PyTorch's batch norm.

mode="ema" has exactly the same output as the old version, and mode="batch" is within 10e-6 of the PyTorch implementation, which is exactly what we want.

Here is the script if anyone wants to reproduce it:

import equinox as eqx
import jax.numpy as jnp
import jax
import numpy as np
import torch
import old_bn

N = 16
m = 0.1
mode = "ema"

bn_t = torch.nn.BatchNorm1d(N, momentum=m)
bn_e, s = eqx.nn.make_with_state(eqx.nn.BatchNorm)(input_size=N, axis_name="b", mode=mode, momentum=(1-m))
bn_eo, s2 = eqx.nn.make_with_state(old_bn.BatchNorm)(input_size=N, axis_name="b", momentum=(1-m))

for _ in range(50):
    x = np.random.rand(N*N).reshape((N, N)).astype(dtype=np.float32)

    x_t = torch.tensor(x)
    x_j = jnp.array(x)
    
    o_t = bn_t(x_t)
    o_j, s = eqx.filter_vmap(bn_e, axis_name="b", in_axes=(0, None), out_axes=(0, None))(x_j, s)
    o_eo, s2 = eqx.filter_vmap(bn_eo, axis_name="b", in_axes=(0, None), out_axes=(0, None))(x_j, s2)

    o_t = np.array(o_t.detach().cpu())
    o_j = np.array(o_j)
    o_eo = np.array(o_eo)

    delta = np.abs(o_t - o_j)
    print('Torch vs. New: ', delta.max())
    
    delta = np.abs(o_j - o_eo)
    print('New vs. Old: ',delta.max())

Where old_bn.BatchNorm is BatchNorm@main.
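
One subtlety in the script above is the `momentum=(1 - m)` conversion. As a minimal plain-Python sketch (illustrative names, not library code): PyTorch's `momentum` weights the incoming batch statistic, while the haiku/Equinox-style `momentum` is the decay applied to the existing running statistic, so the two agree once one is the complement of the other.

```python
import math

# Illustrative sketch (not library code): PyTorch's momentum `m` weights the
# new batch statistic; the haiku/Equinox-style momentum is the decay on the
# old running statistic, so passing `1 - m` makes the two updates agree.
m = 0.1         # PyTorch convention
decay = 1 - m   # Equinox/haiku convention

running_torch = 0.0
running_eqx = 0.0
for batch_stat in [1.0, 2.0, 3.0]:
    running_torch = (1 - m) * running_torch + m * batch_stat
    running_eqx = decay * running_eqx + (1 - decay) * batch_stat

assert math.isclose(running_torch, running_eqx)
```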

Owner

@patrick-kidger patrick-kidger left a comment

Okay gosh, I know this one has languished for a long time.

I think this basically LGTM! I've left some minor comments but I think I'm happy to merge this.

> Ok so I don't know how useful this is, but I validated this implementation against the old ema impl and PyTorch's batch norm.

As for this, this is super useful! So thank you -- it's another proof point that it's doing what we need it to.

self,
input_size: int,
axis_name: Hashable | Sequence[Hashable],
mode: str = "ema",
Owner

Should we make the default "legacy" -- which emits a warning and then sets "ema" -- to encourage picking explicitly?

Also the type annotation here could be a Literal.
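
A sketch of what that suggestion could look like (hypothetical helper, not the actual Equinox signature): a `Literal`-typed `mode` defaulting to `"legacy"`, which warns and then behaves as `"ema"`, nudging users to choose explicitly.

```python
import warnings
from typing import Literal

# Illustrative sketch of the suggestion above (hypothetical function, not
# the actual Equinox API): default to "legacy", which emits a warning and
# then falls back to "ema" (the old behaviour).
def resolve_mode(mode: Literal["legacy", "ema", "batch"] = "legacy") -> str:
    if mode == "legacy":
        warnings.warn(
            "BatchNorm's default mode will require an explicit choice; "
            "falling back to mode='ema' (the old behaviour).",
            stacklevel=2,
        )
        return "ema"
    return mode
```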

Contributor Author

Ah sure, I added a warning

@lockwo
Contributor Author

lockwo commented Apr 1, 2025

Run pre-commit/action@v2.0.3
install pre-commit
Error: getCacheEntry failed: This legacy service is shutting down, effective April 15, 2025. Migrate to the new service ASAP. For more information: https://gh.io/gha-cache-sunset
The brownout dates and times are as follows:

April 1, 2025, 3 p.m. – 7 p.m. UTC

:(

@patrick-kidger
Owner

For the pre-commit check, this appears to be resolved by bumping it to v3.0.1.

@lockwo lockwo requested a review from patrick-kidger April 2, 2025 17:25
@ZagButNoZig
Contributor

Don't mean to be pushy, but are there any blockers left on this one?

Thanks again @lockwo for the amazing work!

Owner

@patrick-kidger patrick-kidger left a comment

Fortunately this is on my list of PRs to review today :)

I have just one question, where I might be missing something?

Comment on lines +258 to +264
_, (mean, var) = state.get(self.batch_state_index)
else:
batch_mean, batch_var = jax.vmap(_stats)(x)
counter = state.get(self.batch_counter)
(hidden_mean, hidden_var), (running_mean, running_var) = state.get(
self.batch_state_index
)
Owner

@patrick-kidger patrick-kidger Apr 13, 2025

IIUC then running_mean and running_var don't actually need to be in the state -- they're never consumed on the inference=False branch.

I think we could afford to have just the hidden_mean, hidden_var in the state, and then do the zero-debiasing on the inference=True branch? This would save on both memory and compute.

WDYT?
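
For context, the scheme being proposed can be sketched in plain NumPy (illustrative names, not the actual implementation): keep only the biased "hidden" EMA and a step counter in the state, and divide out the bias factor lazily when the statistic is consumed.

```python
import numpy as np

# Illustrative sketch of the scheme under discussion (names hypothetical,
# not the actual Equinox code): store only the biased "hidden" EMA plus a
# step counter, and zero-debias lazily at inference time.
momentum = 0.9

def train_step(hidden, counter, batch_stat):
    # Plain EMA update; `hidden` starts at zero, so it is biased low early on.
    return momentum * hidden + (1 - momentum) * batch_stat, counter + 1

def inference_stat(hidden, counter):
    # Divide out the accumulated bias factor only when the statistic is
    # consumed, saving a stored (running_mean, running_var) pair and its
    # per-step update.
    return hidden / (1 - momentum**counter)

hidden, counter = np.zeros(3), 0
for _ in range(2):
    hidden, counter = train_step(hidden, counter, np.ones(3))

# After two updates with a constant statistic of 1, the debiased estimate
# recovers (approximately) 1 in every entry.
assert np.allclose(inference_stat(hidden, counter), 1.0)
```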

Contributor

If your use case is training heavy then yes. But if you mostly run inference then this adds extra compute as we have to recompute the running data every time?

Owner

So for inference then it's already pretty common to want additional optimisations: for example, fusing the weights/biases of batch norm with the parameters of adjacent linear layers.

I think this should probably come under the same banner. Ideally such behaviour should be obtained via constant propagation in the XLA compiler (or similar).

Exactly how this is usually done in practice I've not dug into, as I don't typically run inference-heavy workloads, but ideally we should provide something that can be optimized to do that, rather than the current implementation which unconditionally takes the slower training path?

Contributor

Ah right. Makes sense, thank you!
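
For reference, the weight-fusion optimisation mentioned above is the standard algebraic fold of a BatchNorm affine transform into an adjacent linear layer; a NumPy sketch (illustrative, not Equinox code):

```python
import numpy as np

# Illustrative NumPy sketch (not Equinox code) of fusing BatchNorm's affine
# transform into a preceding linear layer at inference time:
#   BN(W @ x + b) = gamma * (W @ x + b - mean) / sqrt(var + eps) + beta
# which folds into a single linear map W_fused @ x + b_fused.
rng = np.random.default_rng(0)
W, b = rng.normal(size=(4, 3)), rng.normal(size=4)
gamma, beta = rng.normal(size=4), rng.normal(size=4)
mean = rng.normal(size=4)
var, eps = rng.random(4) + 0.5, 1e-5

scale = gamma / np.sqrt(var + eps)
W_fused = scale[:, None] * W
b_fused = scale * (b - mean) + beta

x = rng.normal(size=3)
unfused = gamma * (W @ x + b - mean) / np.sqrt(var + eps) + beta
fused = W_fused @ x + b_fused
assert np.allclose(unfused, fused)
```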

lockwo and others added 2 commits April 14, 2025 21:54
- Now tracking only the running statistics, not the zero-debiased statistics. These are handled at inference time instead.
- Standardised bibtex formatting.
- Moved `mode` argument to the end for backward compatibility.
@patrick-kidger patrick-kidger changed the base branch from main to dev April 14, 2025 20:13
@patrick-kidger
Owner

Okay, so following on from the above discussion, I've gone ahead and adjusted the above. (I'm sympathetic that you must be a bit tired of this PR at this point!) I've:

  • Rebased on the new dev and squashed all your commits together.
  • Updated the zero scaling to happen at inference time, rather than training time.
  • Performed various stylistic simplifications.

@lockwo @ZagButNoZig take a look and let me know what you think?

I've also tried running an adapted version of @ZagButNoZig's script, to test agreement between PyTorch and our mode="batch", and between our mode="batch" and the version of BatchNorm currently available on main, and get agreement throughout.

Click for code

import equinox as eqx
import equinox.nn.old_bn
import jax.numpy as jnp
import numpy as np
import torch


N = 16
m = 0.1

bn_torch = torch.nn.BatchNorm1d(N, momentum=m)
bn_batch, s_batch = eqx.nn.make_with_state(eqx.nn.BatchNorm)(
    input_size=N, axis_name="b", mode="batch", momentum=(1 - m)
)
bn_ema, s_ema = eqx.nn.make_with_state(eqx.nn.BatchNorm)(
    input_size=N, axis_name="b", mode="ema", momentum=(1 - m)
)
bn_old, s_old = eqx.nn.make_with_state(equinox.nn.old_bn.BatchNorm)(
    input_size=N, axis_name="b", momentum=(1 - m)
)

for _ in range(50):
    x = np.random.rand(N * N).reshape((N, N)).astype(dtype=np.float32)

    x_t = torch.tensor(x)
    x_j = jnp.array(x)

    o_torch = bn_torch(x_t)
    o_batch, s_batch = eqx.filter_vmap(
        bn_batch, axis_name="b", in_axes=(0, None), out_axes=(0, None)
    )(x_j, s_batch)
    o_ema, s_ema = eqx.filter_vmap(
        bn_ema, axis_name="b", in_axes=(0, None), out_axes=(0, None)
    )(x_j, s_ema)
    o_old, s_old = eqx.filter_vmap(
        bn_old, axis_name="b", in_axes=(0, None), out_axes=(0, None)
    )(x_j, s_old)

    o_torch = np.array(o_torch.detach().cpu())
    o_batch = np.array(o_batch)
    o_ema = np.array(o_ema)
    o_old = np.array(o_old)

    delta = np.abs(o_torch - o_batch)
    print("Torch vs. batch: ", delta.max())

    delta = np.abs(o_ema - o_old)
    print("ema vs. old: ", delta.max())

@lockwo
Contributor Author

lockwo commented Apr 14, 2025

Yea I think that all looks good. Moving to inference I think is totally reasonable (and cleaning things up is always good, like deleting key and what not). I just made a minor tweak to update one of the comments to match the language.

I can resquash things if you want too.

> I'm sympathetic that you must be a bit tired of this PR at this point!

If FOSS has taught me one lesson, it's that good things come to those who wait :).

@ZagButNoZig
Copy link
Copy Markdown
Contributor

Awesome! LGTM :)

@patrick-kidger patrick-kidger merged commit 9d0ed8a into patrick-kidger:dev Apr 14, 2025
1 check passed
@patrick-kidger
Owner

Awesome, this LGTM then, so merged! Thank you for your help both of you :)

@lockwo lockwo deleted the Owen/batch-norm branch April 14, 2025 21:58