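I would do this by padding the positions array to a fixed length with out-of-bounds indices, and then using mode='drop' so that the set() operation ignores them. Because the padded array always has the same shape, the jitted function does not have to be retraced for every new number of positions. Something like this: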

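import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def update_positions(arr1, arr2, positions):
    # Out-of-bounds indices (the padding) are skipped because of mode='drop'.
    return arr1.at[positions, :, :].set(arr2[positions, :, :], mode='drop')

# Example inputs, assumed here only to make the snippet runnable
arr1 = jnp.zeros((20, 3, 4))
arr2 = jnp.ones((20, 3, 4))
positions = jnp.array([1, 5, 7])

# Pad to a fixed length; arr1.shape[0] is one past the last valid row index
size = 16
positions_padded = jnp.pad(positions, (0, size - len(positions)), constant_values=arr1.shape[0])

# Update the array based on positions
result1 = update_positions(arr1, arr2, positions)
result2 = update_positions(arr1, arr2, positions_padded)

np.testing.assert_array_equal(result1, result2)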
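To see why padding to a fixed size matters under jax.jit, here is a minimal sketch that counts how often the function is traced. It uses small made-up arrays of shape (20, 3, 4); trace_count and pad_positions are hypothetical names added for illustration, and the counter relies on the fact that Python side effects inside a jitted function run only while JAX traces it, not on every call.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_count = 0  # incremented only when JAX (re)traces update_positions

@jax.jit
def update_positions(arr1, arr2, positions):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    # Padded (out-of-bounds) indices are skipped by mode='drop'.
    return arr1.at[positions, :, :].set(arr2[positions, :, :], mode='drop')

def pad_positions(positions, size, fill):
    # Hypothetical helper: pad a 1-D index array to `size` entries with `fill`.
    positions = jnp.asarray(positions, dtype=jnp.int32)
    return jnp.pad(positions, (0, size - positions.shape[0]), constant_values=fill)

arr1 = jnp.zeros((20, 3, 4))  # made-up example data
arr2 = jnp.ones((20, 3, 4))
size = 16

for pos in ([1, 5, 7], [0, 2], [3, 4, 9, 11, 19]):
    padded = pad_positions(pos, size, fill=arr1.shape[0])
    result = update_positions(arr1, arr2, padded)
    # Only the requested rows are overwritten; padded entries are dropped.
    expected = np.zeros((20, 3, 4))
    expected[pos] = 1.0
    np.testing.assert_array_equal(result, expected)

print(trace_count)  # 1: every call passed an index array of shape (16,)
```

Without padding, the three calls would pass index arrays of lengths 3, 2 and 5, and each new shape would trigger a fresh trace and compile. For the padded entries, the gather arr2[positions, :, :] still reads some in-bounds row (JAX clamps out-of-bounds gather indices by default), but those values never reach the result because set(..., mode='drop') skips the corresponding updates.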

Answer selected by AakashKumarNain
Category: Q&A