There are multiple arrays generated by individual processes that I need to combine into a single array. So my question is: in NumPy there is a considerable difference in performance between the different ways of doing this, and I want to know whether that's also true in JAX.
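For context, the NumPy comparison this presumably refers to (a sketch only; the original post doesn't show the benchmark, and the helper names below are just for illustration) is np.concatenate versus preallocating the output and filling it with slice assignment:

import numpy as np

arrays = [np.arange(i, 2 * i) for i in range(10, 100)]

def concat(arrays):
    # One call: NumPy allocates the output and copies each input once.
    return np.concatenate(arrays)

def prealloc(arrays):
    # Preallocate the output, then copy each array into its slice.
    out = np.empty(sum(len(a) for a in arrays), dtype=arrays[0].dtype)
    start = 0
    for a in arrays:
        out[start:start + len(a)] = a
        start += len(a)
    return out

print(np.array_equal(concat(arrays), prealloc(arrays)))
# True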
Replies: 1 comment
Great question! The general advice for something like this is that, under JIT, what you do at the high level shouldn't matter: the XLA compiler should be able to find the optimal route to computing what your code expresses. That's the ideal, but in practice you can sometimes improve things by choosing a different high-level approach. In cases like this, micro-benchmarks can be revealing, and it looks like a simple lax.concatenate, besides being shorter and easier to read, is 40-50% faster than a loop over index updates:
import jax.numpy as jnp
from jax import jit, lax

arrays = [jnp.arange(i, 2 * i) for i in range(10, 100)]

@jit
def f1(*arrays):
    return lax.concatenate(arrays, 0)

@jit
def f2(*arrays):
    size = sum(len(arr) for arr in arrays)
    out = jnp.zeros(size, arrays[0].dtype)
    start = 0
    end = 0
    for arr in arrays:
        end += len(arr)
        out = out.at[start:end].set(arr)
        start = end
    return out

print(jnp.allclose(f1(*arrays), f2(*arrays)))
# True

%timeit f1(*arrays).block_until_ready()
# 10000 loops, best of 5: 152 µs per loop

%timeit f2(*arrays).block_until_ready()
# 1000 loops, best of 5: 275 µs per loop
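If you're curious where the difference comes from, one way to see it (not part of the original reply; a sketch using jax.make_jaxpr on un-jitted copies of the two functions, here named g1 and g2) is to look at what each version traces to: the concatenate version is a single primitive, while the loop unrolls into one index-update primitive per input array, all of which XLA then has to fuse or schedule:

import jax
import jax.numpy as jnp
from jax import lax

arrays = [jnp.arange(i, 2 * i) for i in range(10, 100)]

# Un-jitted copies of f1/f2 so make_jaxpr shows the primitives directly.
def g1(*arrays):
    return lax.concatenate(arrays, 0)

def g2(*arrays):
    out = jnp.zeros(sum(len(a) for a in arrays), arrays[0].dtype)
    start = 0
    for a in arrays:
        out = out.at[start:start + len(a)].set(a)
        start += len(a)
    return out

print(jax.make_jaxpr(g1)(*arrays))  # a single concatenate primitive
print(jax.make_jaxpr(g2)(*arrays))  # one update primitive (scatter/dynamic_update_slice) per array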
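As a side note (not from the original reply): the NumPy-style spelling jnp.concatenate works here as well, so NumPy code written that way should port over essentially unchanged; as far as I know it is implemented in terms of lax.concatenate. A minimal sketch:

import jax.numpy as jnp
from jax import jit, lax

arrays = [jnp.arange(i, 2 * i) for i in range(10, 100)]

@jit
def f3(*arrays):
    # NumPy-style API for the same operation as lax.concatenate(arrays, 0).
    return jnp.concatenate(arrays)

print(jnp.allclose(f3(*arrays), lax.concatenate(arrays, 0)))
# True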