Jax indexing code runs extremely slowly? #4436
-
This code takes more than 15 seconds to run, while running it in standard numpy is instantaneous. I've looked around but can't find any explanation why. Intuitively, it feels like doing a simple indexing operation should not take nearly this long. Would appreciate any insight here, thank you so much.
(128 is arbitrary; it just happens to be relevant to my use case.)
-
I tried your code (compiled using …), replaced your last line with …, and got the following on my setup (times are in seconds): …
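To separate how much of that first call is tracing and how much is XLA compilation, one option is JAX's ahead-of-time API: `jax.jit(f).lower(...)` traces the function and `.compile()` compiles it. This is a sketch that assumes a recent JAX version with that API. Since the question's code isn't shown here, `foo` below is a hypothetical stand-in (a nested Python loop that repeatedly indexes `a[0]`). Absolute timings will vary by machine.

```python
import time

import jax
import jax.numpy as jnp


def foo(a):
    # Hypothetical stand-in: nested Python loops, each iteration doing one
    # indexing op that tracing records separately.
    for i in range(32):
        for j in range(32):
            s = a[0]
    return 0


a = jnp.zeros(128)

t0 = time.perf_counter()
lowered = jax.jit(foo).lower(a)         # trace the Python function
t1 = time.perf_counter()
compiled = lowered.compile()            # XLA compilation
t2 = time.perf_counter()
out = compiled(a).block_until_ready()   # run the compiled executable
t3 = time.perf_counter()

print(f"trace: {t1 - t0:.3f}s  compile: {t2 - t1:.3f}s  run: {t3 - t2:.6f}s")
```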
-
I ran into the same problem with my own code at first, so I checked the common gotchas, especially the in-place updates section. You can see that the first call will probably create a new array every time you assign a value to an array element, and will copy it (that was the case with my code :P).

Also, the first JIT call can take much longer because the Python for-loops get unrolled during tracing. You can reduce the long compilation time by running it on a smaller array first and by using `lax` loops (e.g. `lax.fori_loop`); sometimes just using NumPy for looping and variable assignment is a lot faster. (In my case, almost all of my code is based on creating new arrays and padding, so I just use pure NumPy and then convert to JAX arrays; it's a lot faster when you have a lot of dynamically shaped arrays.)

```python
from jax import jit

@jit
def foo(a):
    for i in range(128):
        for j in range(128):
            # Here you probably create a new array every time the assignment happens;
            # without jit, in pure Python, it would just be a 'pointer'.
            s = a[0]
    return 0
```
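A minimal sketch of the `lax` loop suggestion, assuming the goal is the same 128×128 nested loop over `a[0]`. Because the original loop body discards its result, this version accumulates `a[0]` into a loop carry so the loop has an observable output; `foo_fori` is a hypothetical name. `lax.fori_loop` traces its body once instead of unrolling it, so the size of the traced program no longer grows with the trip count.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def foo_fori(a):
    # Inner loop body: traced once, not 128 times.
    def inner(j, acc):
        return acc + a[0]

    # Outer loop body: runs the inner fori_loop; also traced once.
    def outer(i, acc):
        return lax.fori_loop(0, 128, inner, acc)

    return lax.fori_loop(0, 128, outer, jnp.zeros((), a.dtype))


a = jnp.ones(128)
result = foo_fori(a)
print(result)
```

With `a = jnp.ones(128)` every iteration adds 1.0, so the result is 128 × 128 = 16384.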
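The "build in NumPy, convert once" pattern mentioned above might look like the sketch below; `pad_rows` is a hypothetical helper for illustration. NumPy arrays are mutable, so the per-row writes happen in place on the host, whereas a JAX array is immutable and an update such as `x.at[i].set(v)` returns a new array. The finished batch is moved into JAX with a single `jnp.asarray` call.

```python
import numpy as np
import jax.numpy as jnp


def pad_rows(rows, width):
    # Allocate the padded batch once on the host.
    out = np.zeros((len(rows), width), dtype=np.float32)
    for i, row in enumerate(rows):
        # Cheap in-place NumPy write; no new array per assignment.
        out[i, : len(row)] = row
    # Single conversion to a JAX array at the end.
    return jnp.asarray(out)


batch = pad_rows([[1, 2], [3, 4, 5], [6]], width=4)
print(batch)
```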
-
The reason this takes a long time to run is that you are timing the tracing and compilation of the function. Once compilation has finished, subsequent runs will be fast:

```python
import jax
import jax.numpy as jnp
from jax import jit

@jit
def foo(a):
    for i in range(32):
        for j in range(32):
            s = a[0]
    return 0

a = jnp.zeros(128)
# %time is an IPython magic; run this in a notebook or IPython shell.
%time foo(a).block_until_ready()  # trace, compile & run
# CPU times: user 2.38 s, sys: 28.2 ms, total: 2.41 s
# Wall time: 2.4 s
%time foo(a).block_until_ready()  # run the compiled function
# CPU times: user 1.03 ms, sys: 0 ns, total: 1.03 ms
# Wall time: 827 µs
```

Why is compilation so slow here? Well, `jax.jit` effectively unrolls all Python loops in the … You can read more about how JIT interacts with Python control flow in https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT
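To see the unrolling directly, you can count the equations in the traced program with `jax.make_jaxpr`. In this sketch, `unrolled` and `num_eqns` are hypothetical helpers, and the loop accumulates `a[0]` rather than discarding it so that no operation can be dropped as unused. The exact number of equations per iteration depends on the JAX version (it reflects how indexing is lowered), but the total grows linearly with `n`, and that whole unrolled program is what XLA then has to compile.

```python
import jax
import jax.numpy as jnp


def unrolled(a, n):
    # Python loop: tracing records an index op and an add on every iteration.
    total = jnp.zeros((), a.dtype)
    for _ in range(n):
        total = total + a[0]
    return total


def num_eqns(n):
    # Number of top-level equations in the traced program for loop length n.
    a = jnp.zeros(128)
    closed = jax.make_jaxpr(lambda x: unrolled(x, n))(a)
    return len(closed.jaxpr.eqns)


for n in (8, 64, 512):
    print(n, num_eqns(n))
```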