One pattern I've found useful in this situation is to use jnp.where with a boolean mask:

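import jax
import jax.numpy as jnp

@jax.jit
def flip(x, start, end):
  # Boolean mask for the half-open range [start, end); start and end may be traced values.
  indices = jnp.arange(x.shape[0])
  mask = (indices >= start) & (indices < end)
  # Negate the masked elements and keep the rest unchanged.
  return jnp.where(mask, -x, x)

x = jnp.arange(10)
print(flip(x, 0, 5))
# [ 0 -1 -2 -3 -4  5  6  7  8  9]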
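
The same mask idea may extend to other element-wise operations and to arrays with more than one dimension. Below is a minimal sketch, assuming you want to apply a function to a range of rows along axis 0; the helper name masked_apply and the wrapper flip_rows are hypothetical and not part of JAX.

```python
import jax
import jax.numpy as jnp

def masked_apply(fn, x, start, end):
  # Hypothetical helper: apply fn to rows in [start, end) along axis 0,
  # leaving all other rows unchanged.
  indices = jnp.arange(x.shape[0])
  mask = (indices >= start) & (indices < end)
  # Add trailing singleton axes so the mask broadcasts against x.
  mask = mask.reshape((-1,) + (1,) * (x.ndim - 1))
  # fn is evaluated on every element; jnp.where then selects per element.
  return jnp.where(mask, fn(x), x)

@jax.jit
def flip_rows(x, start, end):
  return masked_apply(jnp.negative, x, start, end)

x = jnp.arange(6).reshape(3, 2)
out = flip_rows(x, 1, 3)
print(out)
# [[ 0  1]
#  [-2 -3]
#  [-4 -5]]
```

Note that fn runs over the whole array and the mask only decides which results are kept, so this trades some extra computation for shapes that stay static under jit.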

Answer selected by dedeswim