Unable to understand lax.scatter_add and lax.ScatterDimensionNumbers #19717
I have a 2D array. I can easily implement this in PyTorch:
A similar approach in JAX would look like this:
The PyTorch implementation works and yields the correct output, but the JAX implementation throws an error saying … I figured the issue might be with …
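Since the title asks specifically about `lax.scatter_add` and `lax.ScatterDimensionNumbers`, here is a minimal sketch of how the dimension numbers line up for one common case. It assumes the goal is to add each row of an `updates` array into a chosen row of a 2D `operand`. The names `operand`, `indices`, and `updates` and their shapes are hypothetical, not taken from the original post. `lax.scatter` treats the last dimension of `scatter_indices` as the index vector, so 1D row indices are reshaped to `(N, 1)`.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical inputs: a (5, 3) operand and three rows of updates.
operand = jnp.zeros((5, 3), dtype=jnp.float32)
indices = jnp.array([0, 2, 0])  # target rows; row 0 appears twice
updates = jnp.arange(9, dtype=jnp.float32).reshape(3, 3)

dnums = lax.ScatterDimensionNumbers(
    # Dimension 1 of `updates` (the columns) is the window copied into each row.
    update_window_dims=(1,),
    # Operand dimension 0 (rows) has window size 1 and is dropped from `updates`.
    inserted_window_dims=(0,),
    # Component 0 of each index vector selects a position along operand dim 0.
    scatter_dims_to_operand_dims=(0,),
)

# The index vector lives in the last dimension, hence shape (3, 1).
result = lax.scatter_add(operand, indices[:, None], updates, dnums)
```

Here row 0 receives `updates[0] + updates[2]` and row 2 receives `updates[1]`. The same result comes from `operand.at[indices].add(updates)`, which builds these dimension numbers for you.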
My main recommendation here would be to avoid using `lax.scatter` directly. It is a low-level operation that speaks the language of XLA primitives; generally users will find `jnp.ndarray.at` a more amenable interface. For example, if I understand your intent correctly, you could write it like this:

`result = operand.at[indices].add(PHI)`

If that's not what you have in mind, perhaps you could give a full example of what you're attempting in PyTorch (your code above has undefined variables and syntax errors). Thanks!
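A minimal sketch of the `.at[...].add(...)` suggestion, assuming `indices` is a 1D array of row indices into a 2D `operand` and `PHI` (here `phi`) has one row per index. The helper name `scatter_add_rows` and the shapes are hypothetical. It illustrates two properties worth knowing: the update is functional (a new array is returned and `operand` is left unchanged), and repeated indices accumulate, much like NumPy's `np.add.at`.

```python
import jax
import jax.numpy as jnp

def scatter_add_rows(operand, indices, phi):
    # Adds phi[i] into operand[indices[i]]; duplicate indices are summed.
    # Returns a new array rather than modifying `operand` in place.
    return operand.at[indices].add(phi)

operand = jnp.zeros((4, 2))
indices = jnp.array([1, 3, 1])  # hypothetical indices; row 1 appears twice
phi = jnp.ones((3, 2))

# Works under jit as well as eagerly.
result = jax.jit(scatter_add_rows)(operand, indices, phi)
```

Row 1 ends up as `[2., 2.]` because it is targeted twice, row 3 as `[1., 1.]`, and rows 0 and 2 stay zero.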