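For this basic scenario, you could do something based on a convolution:

```python
import jax.numpy as jnp
import jax.scipy as jsp

m = jnp.array([[0, 1, 2, 3, 4],
               [5, 6, 7, 8, 9],
               [10, 11, 12, 13, 14]])
vals = jnp.array([2, 6, 12])

# Keep only the target values; zero out everything else.
masked = jnp.where(jnp.isin(m, vals), m, 0)
# A 1x3 kernel of ones spreads each target value onto its left and right neighbours.
convolved = jsp.signal.convolve(masked, jnp.ones((1, 3)), mode='same').astype(m.dtype)
# Positions the convolution left at 0 keep their original value.
result = jnp.where(convolved != 0, convolved, m)
print(result)
```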
This would have to be modified if …
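One possible adjustment, as a sketch: the convolution sums values, so overlapping neighbourhoods add together, and the `!= 0` test cannot tell a target value of 0 apart from an untouched position. Replacing the sum with a maximum over padded, shifted copies avoids both, assuming that the larger value should win where neighbourhoods overlap and that the integer minimum (used here as a sentinel) never appears in `m`. The name `spread_horizontally_max` is made up for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def spread_horizontally_max(m, vals):
    # Sentinel meaning "no target here"; assumed never to occur in m.
    sentinel = jnp.iinfo(m.dtype).min
    # Keep only the target values; everything else becomes the sentinel.
    masked = jnp.where(jnp.isin(m, vals), m, sentinel)
    # Pad one sentinel column on each side so edge elements have two neighbours.
    padded = jnp.pad(masked, ((0, 0), (1, 1)), constant_values=sentinel)
    # Maximum over each 1x3 horizontal window (left, centre, right).
    window_max = jnp.maximum(jnp.maximum(padded[:, :-2], padded[:, 1:-1]), padded[:, 2:])
    # Where no target fell inside the window, keep the original value.
    return jnp.where(window_max != sentinel, window_max, m)

m = jnp.array([[0, 1, 2, 3, 4],
               [5, 6, 7, 8, 9],
               [10, 11, 12, 13, 14]])
print(spread_horizontally_max(m, jnp.array([2, 6, 12])))
print(spread_horizontally_max(m, jnp.array([2, 3])))
```

On the example it reproduces the convolution result. With `vals = [2, 3]` the first row becomes `[0, 2, 3, 3, 3]`, whereas the convolution version would give `[0, 2, 5, 5, 3]` because it adds the overlapping values.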
Hi everyone,
As the title suggests, I want to assign different values to elements at different positions in an array based on multiple conditions simultaneously. Here’s a simplified example: given an array
m = [ [0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14] ]
I want to modify the elements horizontally adjacent to (immediately left and right of) the positions where m == 2, 6, 12 by setting them to 2, 6, and 12, respectively. The result should be:
m = [ [0, 2, 2, 2, 4], [6, 6, 6, 8, 9], [10, 12, 12, 12, 14] ]
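To pin down the intended semantics, here is a plain-Python reference (loops, no JAX) that reproduces the expected result above. It assumes "adjacent" means the immediate left and right neighbours in the same row, as the expected output suggests; `spread_horizontally` is a hypothetical name. A vectorized version can be checked against it on small inputs.

```python
def spread_horizontally(m, vals):
    """Return a copy of m in which every element equal to a value in vals
    is also written onto its left and right neighbours (same row only)."""
    targets = set(vals)
    result = [row[:] for row in m]  # copy so the input is left unchanged
    for i, row in enumerate(m):
        for j, x in enumerate(row):
            if x in targets:
                for k in (j - 1, j + 1):
                    if 0 <= k < len(row):  # skip neighbours outside the row
                        result[i][k] = x
    # Targets are read from m, so where neighbourhoods overlap,
    # the target visited last (row-major order) wins.
    return result

m = [[0, 1, 2, 3, 4],
     [5, 6, 7, 8, 9],
     [10, 11, 12, 13, 14]]
print(spread_horizontally(m, [2, 6, 12]))
# [[0, 2, 2, 2, 4], [6, 6, 6, 8, 9], [10, 12, 12, 12, 14]]
```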
I used the vmap and where functions to achieve this, with the following code:
My general approach is to use vmap to locate all positions and modify them individually, assigning 0 to positions that haven’t been modified. Then I sum everything. Finally, I use the where function to keep the original values at unmodified positions and to assign the accumulated values from the previous step to the modified ones. However, I feel this method is overly complicated, and there might be a more efficient solution. Does anyone have any suggestions?
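A minimal sketch of the approach described above, written as a hypothetical reconstruction rather than the original code: it assumes vmap runs over the target values (not over individual positions), and `contribution` and `spread_with_vmap` are illustrative names. Each target value yields a full-size array holding that value on its neighbourhood and 0 elsewhere; the arrays are summed and `jnp.where` restores the untouched positions.

```python
import jax
import jax.numpy as jnp

def contribution(m, v):
    # 1 where m equals the target v, 0 elsewhere.
    hit = (m == v).astype(jnp.int32)
    # Pad one zero column on each side; a position is in the neighbourhood
    # if the left, centre or right shifted copy is 1 (sum > 0 acts as OR).
    padded = jnp.pad(hit, ((0, 0), (1, 1)))
    neighbourhood = (padded[:, :-2] + padded[:, 1:-1] + padded[:, 2:]) > 0
    # Value v on the neighbourhood, 0 on every unmodified position.
    return jnp.where(neighbourhood, v, 0)

def spread_with_vmap(m, vals):
    # Map over the target values only; m is shared (in_axes=None).
    parts = jax.vmap(contribution, in_axes=(None, 0))(m, vals)  # (len(vals), rows, cols)
    total = parts.sum(axis=0)
    # Positions that received nothing (total == 0) keep their original value.
    return jnp.where(total != 0, total, m)

m = jnp.array([[0, 1, 2, 3, 4],
               [5, 6, 7, 8, 9],
               [10, 11, 12, 13, 14]])
print(spread_with_vmap(m, jnp.array([2, 6, 12])))
```

This materializes one full-size array per target value before summing, so memory grows with `len(vals)`, while the convolution formulation works on a single masked array. Like the convolution version, it also adds values where neighbourhoods overlap.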