My main recommendation here would be to avoid using lax.scatter directly. It is a low-level operation that speaks the language of XLA primitives; generally users will find jnp.ndarray.at a more amenable interface.

For example, if I understand your intent correctly, you could write it like this:

result = operand.at[indices].add(PHI)
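
To make that one-liner concrete, here is a minimal sketch. The shapes and values of operand, indices, and PHI are assumptions chosen for illustration, since the original question's data isn't shown. It also shows that, unlike NumPy's operand[indices] += PHI, repeated indices accumulate with .at[...].add:

```python
import jax.numpy as jnp

# Hypothetical inputs: a 1-D operand, integer indices (note the repeated 2),
# and one update value per index.
operand = jnp.zeros(5)
indices = jnp.array([0, 2, 2, 4])
PHI = jnp.array([1.0, 2.0, 3.0, 4.0])

# Scatter-add: each PHI[i] is added at position indices[i].
# Both updates aimed at index 2 are applied (2.0 + 3.0).
result = operand.at[indices].add(PHI)

print(result)   # updated copy
print(operand)  # JAX arrays are immutable, so operand itself is unchanged
```

Because .at[...].add returns a new array rather than modifying operand in place, the result has to be assigned to a name, as in the line above.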

If that's not what you have in mind, perhaps you could post a complete, runnable example of what you're attempting in PyTorch (the code above has undefined variables and syntax errors). Thanks!

Answer selected by ajithmoola