
`lax.scatter` exists, but it is indeed rather complex. For cases where you would use `tf.scatter_nd`, we recommend the indexed update functions or their equivalent syntactic sugar, the `.at` property.
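To make the contrast concrete, here is a small sketch (shapes and values are made up for illustration) that accumulates three point updates into a 3x4 array twice: once with `lax.scatter_add` and explicit `ScatterDimensionNumbers`, and once with `.at[...].add`. The dimension numbers shown are one valid configuration for scalar, point-wise updates, not the only possible one.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative data: a 3x4 target and three (row, col) updates, one duplicated.
operand = jnp.zeros((3, 4), dtype=jnp.float32)
indices = jnp.array([[0, 1], [2, 3], [0, 1]])            # shape (N, 2)
updates = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)  # shape (N,)

# Low-level route: lax.scatter_add needs explicit dimension numbers.
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),                # each update is a scalar (no window)
    inserted_window_dims=(0, 1),          # both operand dims are indexed point-wise
    scatter_dims_to_operand_dims=(0, 1),  # index column k addresses operand dim k
)
via_lax = lax.scatter_add(operand, indices, updates, dnums)

# High-level route: .at with NumPy-style advanced indexing (one array per axis).
via_at = operand.at[indices[:, 0], indices[:, 1]].add(updates)
```

Both produce the same array; the duplicated index `(0, 1)` receives `1.0 + 3.0`.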

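The API for indexed updates is very similar, but the index axes are ordered differently, following NumPy's advanced indexing: `tf.scatter_nd` stores each coordinate tuple along the last axis of `indices`, whereas `.at[...]` takes a tuple with one index array per dimension. To reproduce `scatter_nd` in JAX you could use: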

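```python
import jax.numpy as jnp

def scatter_nd(indices, updates, shape):
    # Start from an all-zeros array, like tf.scatter_nd.
    zeros = jnp.zeros(shape, updates.dtype)
    # indices has shape (..., k): split its last axis into k per-dimension index arrays.
    key = tuple(jnp.moveaxis(indices, -1, 0))
    # Use .add (not .set) so duplicate indices accumulate, as in tf.scatter_nd.
    return zeros.at[key].add(updates)
```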
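A quick usage check of that function (repeated here so the snippet runs on its own; the inputs are illustrative). It covers element updates, duplicate indices being summed, and slice updates where `indices` has fewer components than `shape` has dimensions, so whole rows are written.

```python
import jax.numpy as jnp

def scatter_nd(indices, updates, shape):
    zeros = jnp.zeros(shape, updates.dtype)
    key = tuple(jnp.moveaxis(indices, -1, 0))
    return zeros.at[key].add(updates)

# Element updates: indices of shape (N, 1) into a length-8 vector.
vec = scatter_nd(jnp.array([[4], [3], [1], [7]]), jnp.array([9, 10, 11, 12]), (8,))

# Duplicate indices are accumulated rather than overwritten.
dup = scatter_nd(jnp.array([[0], [0]]), jnp.array([1.0, 2.0]), (3,))

# Slice updates: one index component into a (4, 3) array writes whole rows.
rows = scatter_nd(jnp.array([[0], [2]]), jnp.ones((2, 3)), (4, 3))
```

Here `vec` is `[0, 11, 0, 10, 9, 0, 0, 12]`, `dup` is `[3., 0., 0.]`, and `rows` has ones in rows 0 and 2. Swapping `.add` for `.set` would give "last write wins" semantics instead, which does not match `tf.scatter_nd` when indices repeat.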

Answer selected by mattjj