
It sounds like you want a dynamic version of `jnp.delete`, i.e. one where the index to delete can be a traced value inside `jax.jit`. There's nothing like that included in JAX, but you can write it with a strategy like this:

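```python
import jax
import jax.numpy as jnp

@jax.jit
def dynamic_delete(v, i):
  # The output has one fewer element than v. Position k takes v[k] while
  # k < i and v[k + 1] afterwards, so element i is skipped. The output shape
  # depends only on len(v), not on i, so i can be a traced value under jit.
  return jnp.where(jnp.arange(len(v) - 1) < i, v[:-1], v[1:])

v = jnp.arange(10)

print(dynamic_delete(v, 3))
# [0 1 2 4 5 6 7 8 9]
```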
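As a quick sanity check of the approach, here is a small sketch that `vmap`s `dynamic_delete` over every possible index (so `i` really is traced inside the function) and compares each row with `jnp.delete` called on a concrete Python index. The names `all_deletions` and the use of `np.testing` for the comparison are just illustrative choices.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def dynamic_delete(v, i):
  # Position k takes v[k] before the deleted index and v[k + 1] from it onward.
  return jnp.where(jnp.arange(len(v) - 1) < i, v[:-1], v[1:])

v = jnp.arange(10)

# in_axes=(None, 0): share v across the batch, map over the index argument,
# so i is a traced scalar inside dynamic_delete.
all_deletions = jax.vmap(dynamic_delete, in_axes=(None, 0))(v, jnp.arange(len(v)))
print(all_deletions.shape)  # (10, 9)

# Each row should match jnp.delete with the corresponding static index.
for i in range(len(v)):
  np.testing.assert_array_equal(all_deletions[i], jnp.delete(v, i))
```

Note that this sketch does not wrap negative indices the way `jnp.delete` does: with `i = -1` every comparison is false, so it returns `v[1:]` (dropping the first element), and any `i >= len(v) - 1` returns `v[:-1]` (dropping the last).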

Answer selected by femtomc