jitted-functions slow performance due to copying of arrays #13225
Hello, do jit-compiled functions always copy every array, even when the function is effectively an identity mapping?

```python
import jax
import jax.numpy as jnp

def gen_time_me(dim1, dim2, jit_outer, jit_inner):
    model = (jnp.zeros((dim1, dim1)), jnp.zeros((dim2, dim2)))

    def inner(model):
        # change only parts of the model, e.g. only some states
        return (model[0], model[1] + 1.0)

    if jit_inner:
        inner = jax.jit(inner)
    # run once
    inner(model)

    def outer(model):
        for _ in range(100):
            model = inner(model)
        return model

    if jit_outer:
        outer = jax.jit(outer)
    # run once
    outer(model)

    def time_me():
        outer(model)

    return time_me

time_me = gen_time_me(1000, 1, False, False)
%timeit time_me()
# 449 µs ± 11.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

time_me = gen_time_me(1000, 1, False, True)
%timeit time_me()
# Why is this so slow compared to the unjitted version?
# Is it because at every inner step we have to make
# an actual copy of the 1000x1000 array?
# 178 ms ± 6.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

time_me = gen_time_me(1000, 1, True, True)
%timeit time_me()
# 638 µs ± 22.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

time_me = gen_time_me(1, 1000, False, False)
%timeit time_me()
# 152 ms ± 3.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

time_me = gen_time_me(1, 1000, False, True)
%timeit time_me()
# 175 ms ± 29.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

time_me = gen_time_me(1, 1000, True, True)
%timeit time_me()
# 2.33 ms ± 184 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
To directly answer your question: yes, jit-compiled functions will generally allocate new memory for their outputs, even if the function happens to be an identity. This is because in JAX arrays are immutable and cannot share memory with other arrays, and the jit-of-identity case is not important enough to warrant an exception to the normal XLA computation path.

There are generally two ways around this. First, you could wrap an outer jit around the function that repeatedly calls the identity; XLA will then optimize away the repeated identity calls and the copies they would generate (you can see this in your example). Second, on GPU or TPU (not CPU) you could use buffer donation to tell XLA that you want the output of the function to share its buffer with the input, which should avoid the intermediate copies.

A side note: you might look over Benchmarking in JAX for some tips on running more robust micro-benchmarks in JAX. In particular, you should use …
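A sketch of what a more robust micro-benchmark might look like, assuming the key tip here is to block on the result: JAX dispatches work asynchronously, so a timing loop that never blocks can end up measuring only the dispatch rather than the computation. The small jitted `outer` below is illustrative, not the one from the question.

```python
import jax
import jax.numpy as jnp

@jax.jit
def outer(x):
    return x + 1.0

x = jnp.zeros((1000, 1000))
outer(x)  # warm-up call, so compilation time is excluded from the timing

def time_me():
    # Blocking on the result ensures the measurement covers the computation
    # itself, not just the time taken to enqueue the work.
    outer(x).block_until_ready()

# %timeit time_me()
```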
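And going back to the buffer-donation workaround: a minimal sketch adapted from the `inner` function in the question (the shapes and the loop are illustrative). Note that donation only takes effect on GPU/TPU; on CPU it is ignored with a warning.

```python
import jax
import jax.numpy as jnp

def inner(model):
    # Only the second array actually changes; the first is passed through.
    return (model[0], model[1] + 1.0)

# donate_argnums=0 tells XLA it may reuse the buffers of the first argument
# for the outputs, so the untouched 1000x1000 array need not be copied.
# A donated input must not be used again after the call.
inner_donated = jax.jit(inner, donate_argnums=0)

model = (jnp.zeros((1000, 1000)), jnp.zeros((1, 1)))
for _ in range(100):
    # Rebinding `model` means the donated input is never reused afterwards.
    model = inner_donated(model)
```

The outer-jit workaround is already visible in the timings above: with `jit_outer=True`, XLA fuses the 100 inner calls and removes the redundant copies.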