I have a small array-wrangling problem that I think should be solvable in a JIT-compatible way, but I'm having trouble figuring out the right idiom (especially for permutations). I have a runtime index array So Is there a simple way to construct a runtime permutation, apply it to put the value at I'm stuck on constructing the permutation. I don't think I can do any stacking (that's not trace-time compatible). I'm wondering if some clever combination of
It sounds like you want a dynamic version of `jnp.delete`. There's nothing like that included in JAX, but you can write it using a strategy that's something like this:

```python
import jax
import jax.numpy as jnp

@jax.jit
def dynamic_delete(v, i):
    # len(v) is static under jit; only the index i is traced.
    # Positions before i read from v[:-1]; positions from i onward read the shifted v[1:].
    return jnp.where(jnp.arange(len(v) - 1) < i, v[:-1], v[1:])

v = jnp.arange(10)
print(dynamic_delete(v, 3))
# [0 1 2 4 5 6 7 8 9]
```
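Since the question asks specifically about building a permutation, here is a minimal sketch of how the same `jnp.where`-against-`jnp.arange` comparison can produce the permutation itself as an integer index array, without any stacking. It assumes the goal is to move the element at a traced index `i` to the end of the array; `move_to_end` is a hypothetical helper name, not a JAX API. The permutation is then applied with ordinary gather indexing, and `jnp.argsort(perm)` gives its inverse.

```python
import jax
import jax.numpy as jnp


@jax.jit
def move_to_end(v, i):
    # Hypothetical helper: move v[i] to the last position, keeping the
    # relative order of all other elements. The length n is static under jit.
    n = v.shape[0]
    idx = jnp.arange(n)
    # Slots before i keep their own index; slots from i onward read the next one.
    perm = jnp.where(idx < i, idx, idx + 1)
    # The last slot would otherwise read index n (out of range); point it at i.
    perm = perm.at[-1].set(i)
    return v[perm], perm


v = jnp.arange(10)
moved, perm = move_to_end(v, 3)
print(perm)         # [0 1 2 4 5 6 7 8 9 3]
print(moved[:-1])   # [0 1 2 4 5 6 7 8 9], same as dynamic_delete(v, 3)

# argsort of a permutation is its inverse, which undoes the move.
inv = jnp.argsort(perm)
print(moved[inv])   # [0 1 2 3 4 5 6 7 8 9]
```

Dropping the last element of the permuted array reproduces the `dynamic_delete` result, so the two formulations are interchangeable; building `perm` explicitly is mainly useful if you need to apply the same reordering to several arrays or invert it later.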