hk.vmap over a Haiku module with RNG #11297
-
I am struggling to get Haiku's `hk.vmap` to work over a module that uses RNG. Imports and constants:

```python
import haiku as hk
import jax.numpy as jnp
import jax.random as jrand
from jax import jit, vmap
```

I have created a simplified model below, for the sake of clarity. In this case we have a single linear layer with some additive Gaussian noise.

```python
INPUT_DIM = 16
OUTPUT_DIM = 3

class SimpleNetwork(hk.Module):
    def __call__(self, xx: jnp.ndarray) -> jnp.ndarray:
        net = hk.Sequential(layers=[hk.Linear(OUTPUT_DIM)])
        return net(xx) + jrand.normal(hk.next_rng_key(), (OUTPUT_DIM,))
```

Using this module for a single input is straightforward and well-documented:

```python
key = jrand.PRNGKey(12345)
dummy_input = jnp.ones((1, INPUT_DIM))
network = hk.transform(lambda xx: SimpleNetwork()(xx))
params = network.init(key, dummy_input)
dummy_output = network.apply(params, key, dummy_input)
print(dummy_output)
```

```
[[ 1.1346992 0.22772181 -1.3999319 ]]
```

Now I want to vectorise this operation over a batch. Doing it naïvely with the vanilla `jax.vmap`:

```python
BATCH_SIZE = 8
batched_dummy_input = jnp.ones((BATCH_SIZE, INPUT_DIM))
batched_dummy_output = vmap(network.apply, in_axes=(None,None,0))(params, key, batched_dummy_input)
print(batched_dummy_output)
```

```
[[ 1.1346992 0.22772181 -1.3999318 ]
 [ 1.1346992 0.22772181 -1.3999318 ]
 [ 1.1346992 0.22772181 -1.3999318 ]
 [ 1.1346992 0.22772181 -1.3999318 ]
 [ 1.1346992 0.22772181 -1.3999318 ]
 [ 1.1346992 0.22772181 -1.3999318 ]
 [ 1.1346992 0.22772181 -1.3999318 ]
 [ 1.1346992 0.22772181 -1.3999318 ]]
```

Now, of course this is the expected result: the same key is used for every example in the batch, so every row gets identical noise. This is what leads me to `hk.vmap`. Firstly, I get it working without the RNG splitting:

```python
network = hk.transform(hk.vmap(lambda xx: SimpleNetwork()(xx), split_rng=False))
params = network.init(key, dummy_input) # Works
batched_dummy_output = network.apply(params, key, batched_dummy_input)  # Works, prints as above
```

Now, with RNG splitting:

```python
network = hk.transform(hk.vmap(lambda xx: SimpleNetwork()(xx), split_rng=True))
params = network.init(key, dummy_input)
```

Error:

```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [22], in <cell line: 1>()
----> 1 params = network.init(key, dummy_input)
File ~/miniconda3/envs/diss310/lib/python3.10/site-packages/haiku/_src/transform.py:113, in without_state.<locals>.init_fn(*args, **kwargs)
112 def init_fn(*args, **kwargs):
--> 113 params, state = f.init(*args, **kwargs)
114 if state:
115 raise ValueError("If your transformed function uses `hk.{get,set}_state` "
116 "then use `hk.transform_with_state`.")
File ~/miniconda3/envs/diss310/lib/python3.10/site-packages/haiku/_src/transform.py:335, in transform_with_state.<locals>.init_fn(rng, *args, **kwargs)
333 with base.new_context(rng=rng) as ctx:
334 try:
--> 335 f(*args, **kwargs)
336 except jax.errors.UnexpectedTracerError as e:
337 raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e
File ~/miniconda3/envs/diss310/lib/python3.10/site-packages/haiku/_src/stateful.py:717, in vmap.<locals>.mapped_fun(*args)
714 saved_rng = state.rng
715 state = InternalState(state.params, state.state, rng)
--> 717 out, state = mapped_pure_fun(args, state)
719 if split_rng:
720 state = InternalState(state.params, state.state, saved_rng)
[... skipping hidden 6 frame]
File ~/miniconda3/envs/diss310/lib/python3.10/site-packages/jax/interpreters/batching.py:674, in matchaxis(axis_name, sz, src, dst, x, sum_match)
671 raise ValueError(f'vmap has mapped output (axis_name={axis_name}) '
672 f'but out_axes is {dst}')
673 else:
--> 674 raise ValueError(f'vmap has mapped output but out_axes is {dst}')
ValueError: vmap has mapped output but out_axes is None
```

I have tried an assortment of variations on the above, without success. Is there something fundamental that I am missing here? I am struggling to find clear documentation on this matter, only that I must be careful. I'd appreciate any help or insights!
-
Hi Callum,

Thanks for the clear question!

The short answer is that you can still use `jax.vmap` here, if you want to `vmap` the `apply_fn`. You only really need `hk.vmap` if you need the `vmap` to be inside of the `hk.transform`. With `jax.vmap` you have complete control over whether the RNG state is shared or split across the vmap dimension.

If you rewrite your batched `jax.vmap` call from the question along the lines of the sketch below,
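(A minimal sketch, assuming the names `network`, `params`, `key`, `BATCH_SIZE` and `batched_dummy_input` from the question above.)

```python
# Make one PRNG key per example in the batch, then map over both the keys
# and the inputs; params stay shared across the batch (in_axes=None).
keys = jrand.split(key, BATCH_SIZE)
batched_dummy_output = vmap(network.apply, in_axes=(None, 0, 0))(
    params, keys, batched_dummy_input
)
print(batched_dummy_output)  # each row now carries different noise
```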
this should do what you want it to (note that we are splitting the key ourselves and changing the `in_axes` of the keys to be 0). The apply key will now be different on every batch element.

There is a longer answer I can give about that error in `hk.vmap`, for the case where you'd like to apply the `vmap` inside of the `hk.transform`:
The error occurs because you're splitting the RNG at initialization time, meaning you're initializing different parameters for every batch element. However, Haiku assumes the `out_axes` of the params is always `None`, i.e. that parameters are never mapped over the vmap dimension. I agree this is not very well documented! I will start by making this error better (tracking here: google-deepmind/dm-haiku#458).

(The expressiveness issue I mentioned before is that you might want to split your RNG at init time. There are a few use cases for that, e.g. batched parameters.)
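If you do want to keep the `vmap` inside the transform, one possible pattern (a sketch, not taken from the reply above; it assumes Haiku's `hk.running_init()` helper, which reports whether the init function is currently being run) is to split the RNG only at apply time, so that the parameters are still created unmapped:

```python
def forward(xx):
    # During init we must not split the RNG, so a single, unbatched set of
    # parameters is created; during apply we do split it, so each example
    # in the batch gets its own noise.
    return hk.vmap(
        lambda x: SimpleNetwork()(x),
        split_rng=not hk.running_init(),
    )(xx)

network = hk.transform(forward)
params = network.init(key, dummy_input)                                # unbatched params
batched_dummy_output = network.apply(params, key, batched_dummy_input)  # per-example noise
```

Note that `split_rng` has to be evaluated inside the transformed function, which is why `hk.vmap` is wrapped in `forward` here rather than passed directly to `hk.transform` as in the question.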