Does JAX use the cache when JIT-compiling nested functions? Consider this example:

```python
import jax
import jax.numpy as jnp


def outer(a: jax.Array):
    def inner(x: jax.Array):
        print("compiling")
        return x * 2

    # print(f"id(inner) = {id(inner)}")
    # print(f"hash(inner) = {hash(inner)}")
    jitted_inner = jax.jit(inner)
    return jitted_inner(a)


a = jax.random.normal(jax.random.key(0), (3, 4))
outer(a)
```

When calling … Some sources say that … If Python re-defines nested functions on each call, how do I make JIT cache them?

Full output after 3 calls: …

System info: …
As you've identified here, since `inner` is re-defined on every call to `outer`, you could move its definition out to the top level:

```python
import jax


def inner(x: jax.Array):
    print("compiling")
    return x * 2


def outer(a: jax.Array):
    jitted_inner = jax.jit(inner)
    return jitted_inner(a)
```

which would work as you intend. To be completely explicit, I would probably also move the `jax.jit` onto the definition of `inner` as a decorator:

```python
import jax


@jax.jit
def inner(x: jax.Array):
    print("compiling")
    return x * 2


def outer(a: jax.Array):
    return inner(a)
```

but the former seems to do the trick as well. In this simple example, I don't see any reason why you wouldn't want to refactor like this, but I'm not sure how easily this generalizes to your case. Either way, I hope it helps!
That's what I expected! Yeah, like you say, the usual advice here would be to move the `jit` as high up the stack as you can. For example, in the flax examples that you link to, the `jit` is applied to the training step, in which case `inner` is only compiled once!

But there are cases where this won't necessarily work (e.g. long compile times). In that case, you could try converting the closure into a compiled function at the global level that takes the relevant parameters as static arguments, which should also lead to a cache hit.
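A sketch of both suggestions, using hypothetical names (`step`, `scaled`, `scale`) and the same trace-counting trick. Jitting the outermost function means helpers defined inside it are inlined into its single trace. A closure over a Python value such as `scale` can instead become a global jitted function that declares `scale` via `static_argnames`, so JAX caches one trace per distinct (hashable) value of `scale`.

```python
from functools import partial

import jax

# Hypothetical trace counters: the Python body only runs while JAX traces.
outer_traces = 0
scaled_traces = 0


# Option 1: jit the outermost function. Helpers defined inside it are
# inlined into the single trace of `step`, so they are traced only once.
@jax.jit
def step(a: jax.Array) -> jax.Array:
    global outer_traces
    outer_traces += 1

    def inner(x: jax.Array) -> jax.Array:
        return x * 2

    return inner(a)


# Option 2: turn a closure over `scale` into a global jitted function that
# receives `scale` as a static (hashable, compile-time) argument.
@partial(jax.jit, static_argnames="scale")
def scaled(x: jax.Array, scale: float) -> jax.Array:
    global scaled_traces
    scaled_traces += 1
    return x * scale


def outer(a: jax.Array, scale: float) -> jax.Array:
    # Hypothetical "before": def inner(x): return x * scale; jax.jit(inner)(a)
    return scaled(a, scale=scale)


a = jax.random.normal(jax.random.key(0), (3, 4))
for _ in range(3):
    step(a)
    outer(a, scale=2.0)
outer(a, scale=3.0)  # a new static value triggers one more trace

print(outer_traces, scaled_traces)  # expected: 1 2
```

The trade-off with static arguments is that every new value of `scale` triggers a fresh trace and compile, so this suits parameters that take only a few distinct values.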