Hello! I would like to flip the sign of an array over a range of indices that is computed dynamically. I tried a few things, such as this:

```python
import jax

@jax.jit
def flip(x, start, end):
    return x.at[start:end].mul(-1)

x = jax.numpy.arange(10)
flipped_x = flip(x, 0, 5)
```

but nothing works because, apparently, this creates a dynamically shaped array. Is it possible to do this in a … I tried to take a look at … Thanks!
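For comparison, here is a sketch that does not come from the thread. If `start` and `end` are ordinary Python integers known before the call, they can be declared static with `static_argnums`. The slice then has a concrete length at trace time, so the original `.at[start:end].mul(-1)` compiles. The name `flip_static` is hypothetical. The catch is that JAX recompiles for every new `(start, end)` pair, and the bounds can no longer be traced values (for example, outputs of another jitted computation).

```python
from functools import partial

import jax
import jax.numpy as jnp

# Assumption: start/end are plain Python ints, so they can be marked static.
@partial(jax.jit, static_argnums=(1, 2))
def flip_static(x, start, end):
    # With static bounds the slice x[start:end] has a known length,
    # so this traces without the dynamic-shape error.
    return x.at[start:end].mul(-1)

x = jnp.arange(10)
flipped_x = flip_static(x, 0, 5)
print(flipped_x)
```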
One pattern I've found useful in this situation is to use `jnp.where` with a mask:

```python
import jax
import jax.numpy as jnp

@jax.jit
def flip(x, start, end):
    # The mask has the same static shape as x, so start/end can be traced values.
    indices = jnp.arange(x.shape[0])
    return jnp.where((indices >= start) & (indices < end), x * -1, x)

x = jnp.arange(10)
print(flip(x, 0, 5))
# [ 0 -1 -2 -3 -4  5  6  7  8  9]
```
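Because the mask keeps the output shape fixed at `x.shape`, `start` and `end` can stay traced values. As a sketch (the batching with `jax.vmap` and the example bounds are assumptions for illustration, not part of the original reply), the same function can flip several ranges in one call:

```python
import jax
import jax.numpy as jnp

@jax.jit
def flip(x, start, end):
    # Boolean mask, True where start <= i < end; its shape never depends on start/end.
    indices = jnp.arange(x.shape[0])
    mask = (indices >= start) & (indices < end)
    return jnp.where(mask, -x, x)

x = jnp.arange(10)
# Hypothetical example bounds: three (start, end) pairs.
starts = jnp.array([0, 3, 8])
ends = jnp.array([5, 6, 10])

# in_axes=(None, 0, 0): share x across the batch, map over the paired bounds.
batched = jax.vmap(flip, in_axes=(None, 0, 0))(x, starts, ends)
print(batched)
```

Each row of `batched` is `x` with one range negated. No slice with a traced length is ever formed, which is why this works under both `jit` and `vmap`.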