-
Consider a top-level function:

```python
import jax
import jax.numpy as jnp

def wrapper(x: jax.Array, start: int):
    return jnp.triu(x, k=start + 1)

x = jax.random.normal(jax.random.PRNGKey(0), (3, 3))
wrapper(x, 1)  # OK
```
Now I want to JIT-compile `wrapper` too. However, calling …

As far as I understand, JIT-compiling the wrapper cancels JIT-compilation of the inner …

But in more realistic examples the call stack is much deeper, and specifying …

Is my understanding correct? If so, are there any workarounds to specify …
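A minimal, self-contained reproduction of this failure follows. It uses a hypothetical jitted helper `shifted_triu` with a static `k` as a stand-in for `jnp.triu`, so the example does not depend on how a particular JAX version declares `jnp.triu`'s arguments. The exact exception type and message also vary between JAX versions, so the sketch only checks that some error is raised.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for a library function with a static `k` argument.
@partial(jax.jit, static_argnames="k")
def shifted_triu(x: jax.Array, k: int) -> jax.Array:
    # int(k) needs a concrete Python value, which a static argument provides.
    return jnp.triu(x, k=int(k))


def wrapper(x: jax.Array, start: int) -> jax.Array:
    return shifted_triu(x, k=start + 1)


x = jax.random.normal(jax.random.PRNGKey(0), (3, 3))
eager_out = wrapper(x, 1)  # works: start + 1 is the Python int 2

try:
    jax.jit(wrapper)(x, 1)  # start is traced, so k would be a tracer
    failed = False
except Exception as err:  # exception type and message vary across JAX versions
    failed = True
    print(type(err).__name__)
```

Calling `wrapper` eagerly works because `start + 1` is an ordinary Python int; under `jax.jit(wrapper)`, `start` is a tracer, which is not a valid value for a static argument.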
-
Correct me if I'm wrong, but static arguments can actually be passed to a JIT-wrapped function; they just trigger re-compilation whenever the value changes, right? For example, if I invoke …

```
In [3]: x = jax.random.normal(jax.random.PRNGKey(0), (3, 3))
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

In [4]: jnp.triu._cache_size()
Out[4]: 0

In [5]: jnp.triu(x, 1)
Out[5]:
Array([[ 0.        ,  0.26423115, -0.18252768],
       [ 0.        ,  0.        , -0.1521442 ],
       [ 0.        ,  0.        ,  0.        ]], dtype=float32)

In [6]: jnp.triu._cache_size()
Out[6]: 1

In [7]: jnp.triu(x, 2)
Out[7]:
Array([[ 0.        ,  0.        , -0.18252768],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ]], dtype=float32)

In [8]: jnp.triu._cache_size()
Out[8]: 2

In [9]: x2 = jax.random.normal(jax.random.PRNGKey(0), (3, 3))

In [10]: jnp.triu(x2, 2)  # x2 is not static and has the same size as x, so it doesn't trigger re-compilation
Out[10]:
Array([[ 0.        ,  0.        , -0.18252768],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ]], dtype=float32)

In [11]: jnp.triu._cache_size()
Out[11]: 2
```

Then the question is whether a JIT-compiled function can access this cache and compile missing sets of arguments. From your answer, I understand that it cannot, probably because that would require XLA to call back into the compiler, which is impossible. Is this analysis correct? If so, could you please clarify one more detail? If I lift the static arguments to the wrapper (i.e. …
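The caching behavior in the session above can also be observed without the private `_cache_size()` method, by counting how often JAX executes the Python body of a jitted function: the body runs only while the function is traced, and a new trace is normally followed by a new compilation. A minimal sketch, using a hypothetical `counted_triu` helper (not part of JAX) with a static `k`:

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # number of times the Python body below has been traced


@partial(jax.jit, static_argnames="k")
def counted_triu(x: jax.Array, k: int) -> jax.Array:
    # Hypothetical helper. This Python side effect runs only while JAX traces
    # the function, not when a cached compiled version is reused.
    global trace_count
    trace_count += 1
    return jnp.triu(x, k=k)


x = jax.random.normal(jax.random.PRNGKey(0), (3, 3))
x2 = jax.random.normal(jax.random.PRNGKey(1), (3, 3))

counted_triu(x, k=1)   # new static value: traced
counted_triu(x, k=2)   # another new static value: traced again
counted_triu(x2, k=2)  # same shape/dtype and same k: cached, not traced
print(trace_count)     # 2
```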
-
The issue is: you cannot pass a dynamic argument to a function which expects a static argument.

In your first example, `jit(wrapper)(x, 1)` results in `start` being treated as a dynamic variable, and so it cannot be passed to a function that requires a static input. There is no way for the inner `jit` to communicate to the outer `jit` which arguments should be static and which should not, and in my opinion it would not be desirable for the API to support that kind of "action at a distance". As they say, explicit is better than implicit, and so you must explicitly specify which arguments are static for each function you wrap in `jit`.

Does that make sense?
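In the original example, this means declaring `start` static on the outer function as well, so that it is still a plain Python int by the time it reaches `jnp.triu`. A minimal sketch follows. For deeper call stacks, one option (an assumed pattern, not part of the answer above) is to bundle the static settings into a single hashable object, here a hypothetical frozen dataclass `Config`, mark only that argument static at the outermost `jit`, and pass it down through ordinary Python helpers; `inner`, `middle` and `top` are likewise hypothetical names.

```python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp


# Mark `start` static on the outer function too, so that while tracing it is
# a plain Python int and `start + 1` is a valid value for jnp.triu's `k`.
@partial(jax.jit, static_argnames="start")
def wrapper(x: jax.Array, start: int) -> jax.Array:
    return jnp.triu(x, k=start + 1)


x = jax.random.normal(jax.random.PRNGKey(0), (3, 3))
out = wrapper(x, 1)  # `start` is matched to static_argnames even when passed positionally


# Hypothetical pattern for deeper call stacks: bundle static settings into one
# hashable object and mark only that argument static at the outermost jit.
@dataclass(frozen=True)  # frozen -> hashable, as static arguments must be
class Config:
    start: int
    scale: float


def inner(x: jax.Array, cfg: Config) -> jax.Array:
    # Not jitted: cfg.start is an ordinary Python int here.
    return jnp.triu(x, k=cfg.start + 1) * cfg.scale


def middle(x: jax.Array, cfg: Config) -> jax.Array:
    return inner(x, cfg) + 1.0


@partial(jax.jit, static_argnames="cfg")
def top(x: jax.Array, cfg: Config) -> jax.Array:
    return middle(x, cfg)


out_deep = top(x, Config(start=1, scale=2.0))
```

As with any static argument, each distinct `start` or `Config` value triggers a new trace and compilation. A frozen dataclass gets `__eq__` and `__hash__` generated from its fields, which is what `jit` needs to look up cached compilations for a static argument.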