You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
I have a question about doing vmap over a list of haiku functions. This could be useful for example as an easy way to do multihead attention (though, I'm interested in other parallel computation as well).
Following this, I tried using jax.lax.switch inside jax.vmap, but I ran into what JAX reports as a tracer "leak".
Here is minimal code:
import jax
import jax.numpy as jnp
import haiku as hk
# create network + initialize parameters
def linear(x):
    """Apply eight separate ``hk.Linear(64)`` modules, selected per batch index.

    This is the minimal reproduction from the question: it is meant to be
    wrapped in ``hk.transform`` and, as written, raises the
    ``UnexpectedTracerError`` shown in the traceback that follows.

    Args:
        x: Input array. ``jax.vmap`` maps over its leading axis together with
            ``index``, so that axis must have length ``len(functions)`` (8).

    Returns:
        The stacked per-index outputs of the selected linear layers.
    """
    functions = [hk.Linear(64) for i in range(8)]
    index = jnp.arange(len(functions))
    # Problematic pattern: plain jax.vmap / jax.lax.switch trace the Haiku
    # modules, whose parameter initialisation pulls from Haiku's RNG sequence
    # (a side effect), so a tracer escapes. The Haiku error hint suggests
    # using the hk.* versions of the JAX transforms instead.
    vmap_functions = jax.vmap(lambda i, x: jax.lax.switch(i, functions, x))
    x = vmap_functions(index, x)
    return x
# Dummy input; the leading axis (8) is the one jax.vmap maps over inside `linear`.
x = jnp.zeros((8, 10, 128))
# Turn `linear` into an init/apply pair; without_apply_rng lets apply be called
# without an RNG key (see the `net.apply(params, x)` call below).
net = hk.without_apply_rng(hk.transform(linear))
# Parameter initialisation: this is the call that raises UnexpectedTracerError.
params = net.init(jax.random.PRNGKey(42), x)
# Forward pass with the initialised parameters (not reached because init fails).
y = net.apply(params, x)
and this is the error:
UnexpectedTracerError Traceback (most recent call last)
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/transform.py:364, in transform_with_state.<locals>.init_fn(rng, *args, **kwargs)
363 try:
--> 364 f(*args, **kwargs)
365 except jax.errors.UnexpectedTracerError as e:
Input In [1], in linear(x)
11 vmap_functions = jax.vmap(lambda i, x: jax.lax.switch(i, functions, x))
---> 12 x = vmap_functions(index, x)
14 return x
[... skipping hidden 3 frame]
Input In [1], in linear.<locals>.<lambda>(i, x)
9 index = jnp.arange(len(functions))
---> 11 vmap_functions = jax.vmap(lambda i, x: jax.lax.switch(i, functions, x))
12 x = vmap_functions(index, x)
[... skipping hidden 14 frame]
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/module.py:428, in wrap_method.<locals>.wrapped(self, *args, **kwargs)
426 f = stateful.named_call(f, name=local_name)
--> 428 out = f(*args, **kwargs)
430 # Module names are set in the constructor. If `f` is the constructor then
431 # its name will only be set **after** `f` has run. For methods other
432 # than `__init__` we need the name before running in order to wrap their
433 # execution with `named_call`.
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/module.py:279, in run_interceptors(bound_method, method_name, self, *args, **kwargs)
278 if not interceptor_stack:
--> 279 return bound_method(*args, **kwargs)
281 ctx = MethodContext(module=self,
282 method_name=method_name,
283 orig_method=bound_method)
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/basic.py:178, in Linear.__call__(self, inputs, precision)
177 w_init = hk.initializers.TruncatedNormal(stddev=stddev)
--> 178 w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
180 out = jnp.dot(inputs, w, precision=precision)
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/base.py:311, in get_parameter(name, shape, dtype, init)
306 raise ValueError(
307 "Unable to retrieve parameter {!r} for module {!r}. "
308 "All parameters must be created as part of `init`.".format(
309 name, bundle_name))
--> 311 param = run_creators(param_creator_stack, context, shape, dtype, init)
312 params[name] = param # pytype: disable=unsupported-operands
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/base.py:370, in run_creators(stack, context, shape, dtype, init)
369 if not stack:
--> 370 return init(shape, dtype)
372 stack_copy = stack.clone()
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/initializers.py:104, in TruncatedNormal.__call__(self, shape, dtype)
103 s = jax.lax.convert_element_type(self.stddev, dtype)
--> 104 unscaled = jax.random.truncated_normal(hk.next_rng_key(), -2., 2., shape,
105 dtype)
106 return s * unscaled + m
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/base.py:638, in next_rng_key()
637 assert_context("next_rng_key")
--> 638 return next_rng_key_internal()
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/base.py:644, in next_rng_key_internal()
643 rng_seq = rng_seq_or_fail()
--> 644 return next(rng_seq)
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/base.py:586, in PRNGSequence.__next__(self)
585 if not self._subkeys:
--> 586 self.reserve(DEFAULT_PRNG_RESERVE_SIZE)
587 return self._subkeys.popleft()
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/base.py:560, in PRNGSequence.reserve(self, num)
551 if num > 0:
552 # When storing keys we adopt a pattern of key0 being reserved for future
553 # splitting and all other keys being provided to the user in linear order.
(...)
558 #
559 # Where subkey1->subkey4 are provided to the user in order when requested.
--> 560 new_keys = tuple(jax.random.split(self._key, num + 1))
561 self._key = new_keys[0]
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/jax/_src/random.py:190, in split(key, num)
189 key, wrapped = _check_prng_key(key)
--> 190 return _return_prng_keys(wrapped, _split(key, num))
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/jax/_src/random.py:176, in _split(key, num)
173 def _split(key: KeyArray, num: int = 2) -> KeyArray:
174 # Alternative to split() to use within random samplers.
175 # TODO(frostig): remove and use split() once we always enable_custom_prng
--> 176 return key._split(num)
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/jax/_src/prng.py:191, in PRNGKeyArray._split(self, num)
190 def _split(self, num: int) -> 'PRNGKeyArray':
--> 191 return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/jax/_src/prng.py:439, in threefry_split(key, num)
438 def threefry_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
--> 439 return _threefry_split(key, int(num))
[... skipping hidden 5 frame]
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1174, in DynamicJaxprTracer._assert_live(self)
1173 if not self._trace.main.jaxpr_stack: # type: ignore
-> 1174 raise core.escaped_tracer_error(self, None)
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (2,) and dtype uint32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was <unknown> traced for switch.
------------------------------
The leaked intermediate value was created on line /Users/cogscikid/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/base.py:560 (reserve).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/Users/cogscikid/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/initializers.py:104 (__call__)
/Users/cogscikid/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/base.py:638 (next_rng_key)
/Users/cogscikid/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/base.py:644 (next_rng_key_internal)
/Users/cogscikid/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/base.py:586 (__next__)
/Users/cogscikid/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/base.py:560 (reserve)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
The above exception was the direct cause of the following exception:
UnexpectedTracerError Traceback (most recent call last)
Input In [1], in <module>
16 x = jnp.zeros((8, 10, 128))
17 net = hk.without_apply_rng(hk.transform(linear))
---> 18 params = net.init(jax.random.PRNGKey(42), x)
20 y = net.apply(params, x)
22 print(y.shape, jax.tree_map(lambda x: x.shape, params))
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/transform.py:113, in without_state.<locals>.init_fn(*args, **kwargs)
112 def init_fn(*args, **kwargs):
--> 113 params, state = f.init(*args, **kwargs)
114 if state:
115 raise ValueError("If your transformed function uses `hk.{get,set}_state` "
116 "then use `hk.transform_with_state`.")
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/transform.py:366, in transform_with_state.<locals>.init_fn(rng, *args, **kwargs)
364 f(*args, **kwargs)
365 except jax.errors.UnexpectedTracerError as e:
--> 366 raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e
367 return ctx.collect_params(), ctx.collect_initial_state()
UnexpectedTracerError: An UnexpectedTracerError was raised while inside a Haiku transformed function (see error above).
Hint: are you using a JAX transform or JAX control-flow function (jax.vmap/jax.scan/...) inside a Haiku transform? You might want to use the Haiku version of the transform instead (hk.vmap/hk.scan/...).
See https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html on why you can't use JAX transforms inside a Haiku module.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
reacted with thumbs up emoji reacted with thumbs down emoji reacted with laugh emoji reacted with hooray emoji reacted with confused emoji reacted with heart emoji reacted with rocket emoji reacted with eyes emoji
Uh oh!
There was an error while loading. Please reload this page.
-
Hello, first, thanks for the great tool.
I have a question about doing vmap over a list of haiku functions. This could be useful for example as an easy way to do multihead attention (though, I'm interested in other parallel computation as well).
Following this, I tried using
jax.lax.switch
but I found that there's a "leak"? Here is minimal code:
and this is the error:
Beta Was this translation helpful? Give feedback.
All reactions