The moment when JIT compilation is done #10430
-
Hello JAX community! While studying just-in-time compilation in JAX, I got confused about one thing. The official JAX documentation has this example:

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
# %timeit is an IPython/Jupyter magic; run these lines in a notebook or IPython
%timeit selu(x).block_until_ready()

selu_jit = jax.jit(selu)

# Warm up: the first call triggers tracing and compilation
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()
```
My question is about the first call. I would expect the first call to run an unoptimized version (without XLA's operation fusion), because tracing and construction of the jaxpr happen during that first call, so XLA cannot compile and optimize the function before those steps are done.
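A minimal sketch (not from the documentation) that makes the first call observable. Python side effects inside a jitted function run only while JAX traces it, so appending to a list shows when tracing happens and what the function body sees at that point: only shape and dtype, not the data. The name `trace_log` is a hypothetical helper, and the input assumes `dtype=jnp.float32` instead of the documentation's integer `jnp.arange`.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: one entry per time JAX runs the Python body

def selu(x, alpha=1.67, lambda_=1.05):
    # This side effect runs only while JAX traces the function.
    # During tracing, x is an abstract tracer: only shape and dtype are known.
    trace_log.append((x.shape, str(x.dtype)))
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
x = jnp.arange(1000000, dtype=jnp.float32)

selu_jit(x).block_until_ready()  # first call: trace, compile, then execute
selu_jit(x).block_until_ready()  # same shape/dtype: cached executable, no retrace
print(len(trace_log))            # 1

y = jnp.arange(10, dtype=jnp.float32)
selu_jit(y).block_until_ready()  # new shape: traced and compiled again
print(trace_log)                 # [((1000000,), 'float32'), ((10,), 'float32')]
```

The log grows only when the abstract signature (shape/dtype) changes, which is why the compiled executable can be reused for any array with the same shape and dtype.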
-
The first time `selu_jit()` is called, JAX traces and compiles it using abstract values that represent the input arrays (their shape and dtype), not the input arrays themselves. Once compilation is finished, the actual arrays are passed to the compiled function for execution. So even the first call runs the compiled, optimized version.

As a side note, very early versions of JAX did this differently: the first function execution was unoptimized, and only subsequent executions were compiled. But JAX has not worked that way for a long time.
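The stages described above can be made explicit with JAX's ahead-of-time API. This is a sketch assuming a reasonably recent JAX release that provides `jax.jit(...).lower(...)`, `.compile()`, and `jax.ShapeDtypeStruct`; a plain `selu_jit(x)` call performs the same steps implicitly. The abstract input carries only shape and dtype, so tracing and compilation finish before any real data is involved.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000, dtype=jnp.float32)

# Abstract stand-in for the input: shape and dtype only, no data.
abstract_x = jax.ShapeDtypeStruct(x.shape, x.dtype)

lowered = jax.jit(selu).lower(abstract_x)  # trace to a jaxpr and lower it for XLA
compiled = lowered.compile()               # XLA compilation and optimization passes

# Only now is the concrete array passed to the compiled executable.
out = compiled(x).block_until_ready()
print(out.shape)  # (1000000,)
```

If you want to look at what XLA produced, `compiled.as_text()` returns the optimized HLO as a string; the exact contents vary by backend and JAX version.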