Implementing scatter (and gather) via one-hot for multiple indices #21784
-
I am having a hard time JIT-compiling a function that updates an array at multiple indices. Before describing the solutions I tried, let me first explain the problem. I have an array `arr1` whose rows at `positions` I want to overwrite with the corresponding rows of `arr2`:

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def update_positions(arr1, arr2, positions):
    # JAX arrays are immutable: .at[...].set() returns an updated copy
    arr1 = arr1.at[positions, :, :].set(arr2[positions, :, :])
    return arr1

T = 1024
C = 32
D = 256
positions = jnp.arange(10)
arr1 = jnp.zeros((T, C, D))
arr2 = jnp.asarray(np.random.rand(T, C, D))
# Update the array based on positions
arr1 = update_positions(arr1, arr2, positions)
```

### The Problem

Given that `positions` can have a different length on every call, `jax.jit` retraces and recompiles `update_positions` for each new length.

### Tried solution

Ideally, we want to keep the number of compilations to a bare minimum. One solution is to pad the array of positions to the next power of 2 and make the updates using this padded array. The problem is that the padding itself triggers the same recompilation. So either we pre-cache the padded position array, or we pre-cache the original function by calling it on some dummy data.

### Expectation

The above solution is more of a hack than a proper solution. Ideally, we would have an array of zeros (of full length) into which we scatter the positions, for example:

```python
positions = jnp.arange(10)
ohe_positions = jax.nn.one_hot(positions, T)
zeros_array = jnp.zeros_like(arr1)
# scatter the one-hot positions into zeros_array
# (`scatter` is a placeholder for the operation I am looking for)
mask = scatter(zeros_array, ohe_positions)
# make updates
arr1 = arr1 + arr2 * mask
```

But I couldn't find an easy way to do this, and any help would be much appreciated.
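
A minimal sketch of the fixed-length mask idea from the question, assuming the caller can build a `(T,)` boolean row mask instead of passing a variable-length `positions` array. The names `update_with_mask`, `positions_to_mask` and `TRACE_COUNT` are hypothetical. Because the mask always has shape `(T,)`, `jax.jit` should only need to trace the function once. The mask is built with NumPy on the host, which plays the role of the `scatter` step; for in-range positions it matches `jax.nn.one_hot(positions, T).any(axis=0)`.

```python
import numpy as np
import jax
import jax.numpy as jnp

TRACE_COUNT = 0  # hypothetical counter, bumped only when jax.jit traces

@jax.jit
def update_with_mask(arr1, arr2, row_mask):
    """Overwrite rows of arr1 with rows of arr2 wherever row_mask is True."""
    global TRACE_COUNT
    TRACE_COUNT += 1  # Python side effect: runs at trace time only
    # Broadcast the (T,) row mask over the trailing (C, D) axes and select.
    return jnp.where(row_mask[:, None, None], arr2, arr1)

def positions_to_mask(positions, T):
    """Host-side 'scatter' of row indices into a fixed-length boolean mask."""
    # For in-range positions this matches jax.nn.one_hot(positions, T).any(axis=0).
    mask = np.zeros(T, dtype=bool)
    mask[np.asarray(positions)] = True
    return mask

T, C, D = 1024, 4, 8  # small C, D to keep the example light
arr1 = jnp.zeros((T, C, D))
arr2 = jnp.asarray(np.random.rand(T, C, D), dtype=jnp.float32)

# Different numbers of positions, but the mask shape never changes.
for n in [3, 10, 100]:
    out = update_with_mask(arr1, arr2, positions_to_mask(np.arange(n), T))

print(TRACE_COUNT)  # 1
```

The trade-off is that `jnp.where` evaluates every element of the `(T, C, D)` arrays on each call, so when only a few rows change it may do more work than an index-based `.at[...].set()`. It also uses a select rather than `arr1 + arr2 * mask`, so rows that are already non-zero in `arr1` are overwritten instead of summed.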

---

I would do this by padding the positions with out-of-bound indices and then using `mode='drop'` to ignore them within the `set()` operation. Something like this:

```python
import numpy as np
import jax
import jax.numpy as jnp
# arr1, arr2 and positions are defined as in the question

@jax.jit
def update_positions(arr1, arr2, positions):
    # mode='drop' skips updates whose index is out of bounds
    return arr1.at[positions, :, :].set(arr2[positions, :, :], mode='drop')

size = 16
# pad with arr1.shape[0], which is one past the last valid row index
positions_padded = jnp.pad(positions, (0, size - len(positions)), constant_values=arr1.shape[0])
# Update the array based on positions
result1 = update_positions(arr1, arr2, positions)
result2 = update_positions(arr1, arr2, positions_padded)
np.testing.assert_array_equal(result1, result2)
```
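
The fixed `size = 16` only covers up to 16 positions. A minimal sketch that combines the `mode='drop'` padding with the power-of-two bucketing mentioned in the question might look like this. The helper `pad_to_bucket` and the `TRACE_COUNT` counter are hypothetical names for illustration. The padding is done with NumPy on the host, so that step does not go through XLA compilation, and the counter is a Python side effect that only runs when `jax.jit` traces a new input shape.

```python
import numpy as np
import jax
import jax.numpy as jnp

TRACE_COUNT = 0  # hypothetical counter, bumped only when jax.jit traces

@jax.jit
def update_positions(arr1, arr2, positions):
    global TRACE_COUNT
    TRACE_COUNT += 1  # Python side effect: runs at trace time only
    # Padded (out-of-bounds) indices are skipped thanks to mode="drop".
    return arr1.at[positions].set(arr2[positions], mode="drop")

def pad_to_bucket(positions, fill_value):
    """Pad positions with fill_value up to the next power of two (NumPy, on host)."""
    positions = np.asarray(positions, dtype=np.int32)
    n = len(positions)
    size = 1 << max(n - 1, 0).bit_length()  # 3 -> 4, 5 -> 8, 9 -> 16, ...
    return np.pad(positions, (0, size - n), constant_values=fill_value)

T, C, D = 1024, 4, 8
arr1 = jnp.zeros((T, C, D))
arr2 = jnp.asarray(np.random.rand(T, C, D), dtype=jnp.float32)

for n in [3, 5, 7, 9, 10]:
    # Pad with T, which is one past the last valid row index.
    out = update_positions(arr1, arr2, pad_to_bucket(np.arange(n), fill_value=T))

print(TRACE_COUNT)  # 3: one trace each for lengths 4, 8 and 16
```

With this scheme, at most log2(T) + 1 distinct lengths (for n ≤ T) can reach `update_positions`. The number of compilations therefore stays logarithmic in `T`, and each bucket can be warmed up ahead of time if needed.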