Conversation
Also, for anyone who wants to try out the comparisons, here's some example code that compares the Haiku and Equinox versions:

```python
import jax
from jax import numpy as jnp
import equinox as eqx
import haiku as hk


class Net(hk.Module):
    def __init__(self, name="net"):
        super().__init__(name=name)

    def __call__(self, x, is_training):
        x = x.astype(jnp.float32)
        x = hk.BatchNorm(True, True, 0.9)(x, is_training)
        return x


class Neteqx(eqx.Module):
    norm: eqx.nn.BatchNorm

    def __init__(self, output_channels: int = 64):
        self.norm = eqx.nn.BatchNorm(
            output_channels, "batch", momentum=0.9, mode="batch"
        )

    def __call__(self, x, state):
        x = x.astype(jnp.float32)
        x = jnp.moveaxis(x, -1, 0)  # channel-first for eqx.nn.BatchNorm
        x, state = self.norm(x, state)
        x = jnp.moveaxis(x, 0, -1)
        return x, state


def forward_fn(x):
    net = Net()
    return net(x, True)


def inf_fn(x):
    net = Net()
    return net(x, False)


forward = hk.without_apply_rng(hk.transform_with_state(forward_fn))
inf = hk.without_apply_rng(hk.transform_with_state(inf_fn))

ins = jnp.array([1, 2, 3, 4]).reshape((4, 1, 1, 1)) * jnp.ones((4, 2, 2, 3))
params, state = forward.init(jax.random.key(0), jnp.zeros_like(ins))
# print(params)
# print(hk.data_structures.tree_size(params))
print(state)
jax.tree.map(lambda x: print(x), state)
for i in range(2):
    out, state = forward.apply(params, state, ins[2 * i : 2 * (i + 1)])
    jax.tree.map(lambda x: print(x), state)
inf.apply(params, state, ins)

eq, s = eqx.nn.make_with_state(Neteqx)(3)
eq = eqx.tree_at(
    lambda x: x.norm.weight, eq, params["net/batch_norm"]["scale"].squeeze()
)
eq = eqx.tree_at(
    lambda x: x.norm.bias, eq, params["net/batch_norm"]["offset"].squeeze()
)
# print(sum(x.size for x in jax.tree.leaves(eqx.filter(eq, eqx.is_inexact_array))))
jax.tree.map(lambda x: print(x), s)
for i in range(2):
    out_eqx, s = eqx.filter_vmap(
        eq, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
    )(ins[2 * i : 2 * (i + 1)], s)
    jax.tree.map(lambda x: print(x), s)
# print(out_eqx)
eqx.filter_vmap(
    eqx.nn.inference_mode(eq), in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(ins, s)
```
Ok so I don't know how useful this is, but I validated this implementation against the old ema impl and PyTorch's batch norm.
Here is the script if anyone wants to reproduce it:

```python
import equinox as eqx
import jax.numpy as jnp
import jax
import numpy as np
import torch

import old_bn

N = 16
m = 0.1
mode = "ema"

bn_t = torch.nn.BatchNorm1d(N, momentum=m)
bn_e, s = eqx.nn.make_with_state(eqx.nn.BatchNorm)(
    input_size=N, axis_name="b", mode=mode, momentum=(1 - m)
)
bn_eo, s2 = eqx.nn.make_with_state(old_bn.BatchNorm)(
    input_size=N, axis_name="b", momentum=(1 - m)
)

for _ in range(50):
    x = np.random.rand(N * N).reshape((N, N)).astype(dtype=np.float32)
    x_t = torch.tensor(x)
    x_j = jnp.array(x)
    o_t = bn_t(x_t)
    o_j, s = eqx.filter_vmap(
        bn_e, axis_name="b", in_axes=(0, None), out_axes=(0, None)
    )(x_j, s)
    o_eo, s2 = eqx.filter_vmap(
        bn_eo, axis_name="b", in_axes=(0, None), out_axes=(0, None)
    )(x_j, s2)
    o_t = np.array(o_t.detach().cpu())
    o_j = np.array(o_j)
    o_eo = np.array(o_eo)
    delta = np.abs(o_t - o_j)
    print("Torch vs. New: ", delta.max())
    delta = np.abs(o_j - o_eo)
    print("New vs. Old: ", delta.max())
```
Okay gosh, I know this one has languished for a long time.
I think this basically LGTM! I've left some minor comments but I think I'm happy to merge this.
Ok so I don't know how useful this is, but I validated this implementation against the old ema impl and PyTorch's batch norm.
As for this, this is super useful! So thank you -- it's another proof point that it's doing what we need it to.
In `equinox/nn/_batch_norm.py` (outdated):

```python
    self,
    input_size: int,
    axis_name: Hashable | Sequence[Hashable],
    mode: str = "ema",
```
Should we make the default "legacy" -- which emits a warning and then sets "ema" -- to encourage picking explicitly?
Also the type annotation here could be a Literal.
Ah sure, I added a warning
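A hypothetical sketch of what such a "legacy" default could look like (the class and message here are illustrative, not the actual Equinox implementation):

```python
import warnings
from typing import Literal


class BatchNormSketch:
    """Illustrative sketch: a 'legacy' default that warns, then falls back to EMA."""

    def __init__(self, mode: Literal["ema", "batch", "legacy"] = "legacy"):
        if mode == "legacy":
            warnings.warn(
                "BatchNorm's default mode may change in the future; pass "
                "mode='ema' or mode='batch' explicitly.",
                stacklevel=2,
            )
            mode = "ema"  # preserve the old EMA behaviour for now
        self.mode = mode
```

Using a `Literal` type also lets static checkers catch typos in the mode string, per the suggestion above.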
:(
For the pre-commit check, this appears to be resolved by bumping it to v3.0.1.
Don't mean to be pushy, but are there any blockers left on this one? Thanks again @lockwo for the amazing work!
patrick-kidger
left a comment
Fortunately this is on my list of PRs to review today :)
I have just one question, where I might be missing something?
In `equinox/nn/_batch_norm.py` (outdated):

```python
    _, (mean, var) = state.get(self.batch_state_index)
else:
    batch_mean, batch_var = jax.vmap(_stats)(x)
    counter = state.get(self.batch_counter)
    (hidden_mean, hidden_var), (running_mean, running_var) = state.get(
        self.batch_state_index
    )
```
IIUC then running_mean and running_var don't actually need to be in the state -- they're never consumed on the inference=False branch.
I think we could afford to have just the hidden_mean, hidden_var in the state, and then do the zero-debiasing on the inference=True branch? This would save on both memory and compute.
WDYT?
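For reference, the zero-debiasing being discussed is the Adam-style bias correction: the raw EMA starts at zero, so it is rescaled by `1 - momentum**t`. A sketch of computing it lazily at inference time rather than storing it in the state (an illustration of the idea, not the exact PR code):

```python
momentum = 0.9

# Training: store only the raw (biased) EMA and a step counter.
hidden_mean = 0.0
counter = 0
for batch_mean in [1.0, 1.0, 1.0]:
    hidden_mean = momentum * hidden_mean + (1 - momentum) * batch_mean
    counter += 1

# Inference: debias on demand instead of keeping running_mean in the state.
running_mean = hidden_mean / (1 - momentum**counter)
print(running_mean)  # recovers ~1.0 despite the zero initialisation
```

With the true statistic held at 1.0, the raw EMA after three steps is only 0.271, and dividing by `1 - 0.9**3 = 0.271` recovers 1.0 exactly.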
If your use case is training-heavy then yes. But if you mostly run inference then this adds extra compute, as we have to recompute the running data every time?
So for inference then it's already pretty common to want additional optimisations: for example, fusing the weights/biases of batch norm with the parameters of adjacent linear layers.
I think this should probably come under the same banner. Ideally such behaviour should be obtained via constant propagation in the XLA compiler (or similar).
Exactly how this is usually done in practice I've not dug into, as I don't typically run inference-heavy workloads, but ideally we should provide something that can be optimized to do that, rather than the current implementation which unconditionally takes the slower training path?
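For context, the weight-fusing optimisation mentioned above can be written out explicitly. A sketch of folding a batch norm's affine transform into a preceding linear layer (a standard technique; the function name and shapes here are illustrative):

```python
import numpy as np


def fold_bn_into_linear(W, b, gamma, beta, mean, var, eps=1e-5):
    # y = gamma * (Wx + b - mean) / sqrt(var + eps) + beta
    # is itself an affine map of x, so it collapses into one linear layer.
    scale = gamma / np.sqrt(var + eps)
    return W * scale[:, None], (b - mean) * scale + beta


eps = 1e-5
rng = np.random.default_rng(0)
W, b = rng.normal(size=(4, 3)), rng.normal(size=4)
gamma, beta = rng.normal(size=4), rng.normal(size=4)
mean, var = rng.normal(size=4), rng.uniform(0.5, 2.0, size=4)
x = rng.normal(size=3)

ref = gamma * (W @ x + b - mean) / np.sqrt(var + eps) + beta
Wf, bf = fold_bn_into_linear(W, b, gamma, beta, mean, var, eps)
print(np.allclose(Wf @ x + bf, ref))  # True
```

This only works once the normalisation statistics are constants, which is why an implementation that the compiler can constant-propagate on the inference path is attractive.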
Ah right. Makes sense, thank you!
- Now tracking only the running statistics, not the zero-debiased statistics; these are handled at inference time instead.
- Standardised BibTeX formatting.
- Moved `mode` argument to the end for backward compatibility.
Force-pushed from bdad129 to 6dd6e69.
Okay, so following on from the above discussion, I've gone ahead and adjusted the above. (I'm sympathetic that you must be a bit tired of this PR at this point!)
@lockwo @ZagButNoZig take a look and let me know what you think? I've also tried running an adapted version of @ZagButNoZig's script, to test agreement between PyTorch and our implementation.
Yeah, I think that all looks good. Moving to inference I think is totally reasonable (and cleaning things up is always good, like deleting key and whatnot). I just made a minor tweak to update one of the comments to match the language. I can resquash things if you want, too.
If FOSS has taught me one lesson, it's that good things come to those who wait :).
Awesome! LGTM :)
Awesome, this LGTM then, so merged! Thank you for your help, both of you :)
Revives #675, but as an ideally smaller/simpler change (one that doesn't require any math) that matches the patterns in Flax/Haiku (see: https://github.com/google-deepmind/dm-haiku/blob/main/haiku/_src/batch_norm.py#L42%23L206, https://github.com/google-deepmind/dm-haiku/blob/main/haiku/_src/moving_averages.py#L41%23L139). Has some hardcoded defaults (e.g. I don't support warmup iterations at all), but should help the users of batch norm. In addition to the stability in the initial example, here is an example on an AlphaZero model playing 9x9 Go:
This screenshot (and code) also comes from this PR: sotetsuk/pgx#1300
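To summarise the behavioural change this PR introduces: in the new "batch" mode (matching PyTorch/Haiku), training normalises with the *current batch's* statistics, while the old "ema" mode normalises with the running averages themselves. A rough NumPy sketch of the distinction, as I understand it (illustrative, not the PR's actual code):

```python
import numpy as np


def batchnorm_step(x, running_mean, running_var, momentum=0.9, eps=1e-5, mode="batch"):
    """One training step of batch norm, illustrating 'batch' vs 'ema' mode."""
    batch_mean, batch_var = x.mean(axis=0), x.var(axis=0)
    # Both modes update the running statistics the same way...
    running_mean = momentum * running_mean + (1 - momentum) * batch_mean
    running_var = momentum * running_var + (1 - momentum) * batch_var
    # ...but differ in which statistics normalise the current batch.
    if mode == "batch":  # PyTorch/Haiku-style: the current batch's stats
        mean, var = batch_mean, batch_var
    else:  # "ema": the running averages themselves
        mean, var = running_mean, running_var
    return (x - mean) / np.sqrt(var + eps), running_mean, running_var


x = np.random.default_rng(0).normal(loc=5.0, scale=2.0, size=(64, 3))
out, rm, rv = batchnorm_step(x, np.zeros(3), np.ones(3), mode="batch")
print(out.mean(), out.var())  # ~0 and ~1: the batch is normalised immediately
```

In "ema" mode the running averages start far from the data's true statistics, so early training steps are poorly normalised, which is the stability issue motivating the change.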