Efficient updates of large array #19125
-
As a toy example, I have a large array `x = jnp.zeros([4000, 4000])`, some indices `update_indices = jr.choice(key, 4000, [n], replace=False)`, and some values `update_values = jr.normal(key, [n, 4000])`, where `n` varies from one update to the next.

I'm aware that I can do `x = x.at[update_indices].set(update_values)`, but that creates a full copy of `x`, which is not efficient. I know that I can jit the update to get in-place updates, but this does not work with variable-sized `update_indices`.

What is the recommended way to update `x` without making a full copy of the array? Should I compile jitted in-place updates for batch sizes of 1, 2, 4, 8, 16, and 32, and then iterate through the updates using the largest compiled batch size available until all updates are done? For example, if `n = 51`, I would apply the updates in chunks of 32, 16, 2, and 1.

I've noticed that when I use the non-jitted `x = x.at[update_indices].set(update_values)`, it seems to compile under the hood: the first few iterations take 0.1 to 0.2 s, but as soon as it encounters an `n` it has seen before, it takes 0.001 s. Even though it compiles automatically, isn't it a problem that it makes a full copy of the array during those first iterations? If the array were very large and I could fit only one instance of it on my GPU, wouldn't that lead to an OOM error?

When I jit the update function, is it necessary to use `donate_argnums=[0]`, or will the update happen in place even if I don't donate the argument?

Thank you for your time.
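The power-of-two chunking proposed in the question can be sketched as follows. This is an illustration under stated assumptions, not an official recipe: `update_chunk`, `power_of_two_chunks`, and `bucketed_update` are hypothetical helpers, `max_chunk` is assumed to be a power of two, and donation assumes a backend that supports it (otherwise JAX warns and copies). Because `jax.jit` compiles once per distinct input shape, at most six versions of `update_chunk` (chunk sizes 1, 2, 4, 8, 16, 32) are compiled for a given `x`.

```python
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as jr


@partial(jax.jit, donate_argnums=(0,))
def update_chunk(x, indices, values):
    # Compiled once per distinct chunk length; x's buffer is donated.
    return x.at[indices].set(values)


def power_of_two_chunks(n, max_chunk=32):
    """Split n into power-of-two chunk sizes, largest first (51 -> [32, 16, 2, 1])."""
    sizes = []
    size = max_chunk
    while n > 0:
        # Shrink the chunk until it fits in what is left.
        while size > n:
            size //= 2
        sizes.append(size)
        n -= size
    return sizes


def bucketed_update(x, indices, values, max_chunk=32):
    """Apply the updates chunk by chunk. The caller's x is donated and must not be reused."""
    start = 0
    for size in power_of_two_chunks(indices.shape[0], max_chunk):
        stop = start + size
        x = update_chunk(x, indices[start:stop], values[start:stop])
        start = stop
    return x


key_idx, key_val = jr.split(jr.PRNGKey(0))
n = 51
x = jnp.zeros((4000, 4000))
update_indices = jr.choice(key_idx, 4000, (n,), replace=False)
update_values = jr.normal(key_val, (n, 4000))
x = bucketed_update(x, update_indices, update_values)  # chunks of 32, 16, 2, 1
```

Each chunk is still a separate dispatch, so this trades a few extra calls per update for a bounded number of compilations.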
-
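You could do this by JIT-compiling the single update operation along with buffer donation (`donate_argnums`). For example:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Donating argument 0 lets XLA reuse x's buffer for the output.
@partial(jax.jit, donate_argnums=(0,))
def update_inplace(x, indices, values):
    return x.at[indices].set(values)

x = jnp.arange(100.0)
indices = jnp.arange(0, 100, 10)
values = jnp.ones(10)

original_pointer = x.unsafe_buffer_pointer()
# The donated input must not be used after this call; rebind x to the output.
x = update_inplace(x, indices, values)
# The output occupies the same device buffer as the donated input.
assert x.unsafe_buffer_pointer() == original_pointer
```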
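One caveat for the variable-`n` case: `jax.jit` traces and compiles a function once per distinct combination of input shapes and dtypes, so a jitted update such as `update_inplace` still recompiles for every new `n`. A minimal sketch that makes this visible with a Python-side counter (the names `update` and `trace_count` are illustrative; the counter works only because Python side effects run at trace time, not on cached calls):

```python
import jax
import jax.numpy as jnp

trace_count = 0  # illustrative counter, incremented only while JAX traces


@jax.jit
def update(x, indices, values):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return x.at[indices].set(values)


x = jnp.zeros((100, 4))
for n in [3, 5, 3, 5, 8]:
    indices = jnp.arange(n)
    values = jnp.ones((n, 4))
    x = update(x, indices, values)

print(trace_count)  # 3: one trace each for n = 3, 5, and 8
```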
-
This is not a useful answer, but when I face this issue, I tend to pad the update arrays to a fixed size.
Do it outside of JIT.
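The padding approach from the comment above can be sketched as follows, with the padding done outside of JIT. Assumptions: `MAX_UPDATES` is a hypothetical fixed capacity at least as large as the biggest `n`, `values` is 2-D, and `pad_updates` and `padded_update` are illustrative names. Padded slots use the row index `x.shape[0]`, which is out of bounds; `.at[...].set(..., mode="drop")` discards out-of-bounds updates, so the jitted function always sees the same shapes and compiles once. A padding index of `-1` would not work, because negative indices wrap around to the last row.

```python
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as jr

MAX_UPDATES = 64  # hypothetical fixed capacity; must be >= the largest n


@partial(jax.jit, donate_argnums=(0,))
def padded_update(x, indices, values):
    # Out-of-bounds indices are dropped, so the padding rows are no-ops.
    return x.at[indices].set(values, mode="drop")


def pad_updates(indices, values, num_rows, capacity=MAX_UPDATES):
    """Pad indices/values to `capacity` rows outside of jit; x itself is not touched."""
    n = indices.shape[0]
    if n > capacity:
        raise ValueError(f"n={n} exceeds capacity={capacity}")
    pad = capacity - n
    # num_rows is one past the last valid row index, so it is always out of bounds.
    padded_indices = jnp.pad(indices, (0, pad), constant_values=num_rows)
    padded_values = jnp.pad(values, ((0, pad), (0, 0)))
    return padded_indices, padded_values


key = jr.PRNGKey(0)
x = jnp.zeros((4000, 4000))
for n in [51, 7, 64]:
    key, k_idx, k_val = jr.split(key, 3)
    update_indices = jr.choice(k_idx, 4000, (n,), replace=False)
    update_values = jr.normal(k_val, (n, 4000))
    idx, vals = pad_updates(update_indices, update_values, x.shape[0])
    x = padded_update(x, idx, vals)  # identical input shapes on every iteration
```

Padding only touches the small `(n,)` and `(n, 4000)` update arrays, never the large `x`.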