How to construct a 'jit'able function to delete an entry of an array #13633
abhiroop513 asked this question in Q&A
I was trying to write the following `jit`-able function to delete some entries of an array:

```python
import jax.numpy as jnp
from jax import jit

x = jnp.array([1., 2., 3., 4., 5., 6., 7., 8., 9.])

def zdel(x, indx):
  return jnp.delete(x, obj=indx)

zjit = jit(zdel)
indx = 5
z = zjit(x, indx)
print(z)
```

This gives a `ConcretizationTypeError` for the variable `indx`. How can I write a `jit`-able function that does the same thing?
Answered by jakevdp on Dec 13, 2022
The index argument to `delete` must be static, so the easiest way to jit-compile this function is to mark the argument as static:

```python
zjit = jit(zdel, static_argnames=['indx'])
```

If you need to use a dynamic index, you can often compute the same result using operations that are compatible with dynamic values. Here's one way you might do it:

```python
@jit
def zdel2(x, indx):
  # Positions before indx are taken from x[:-1]; positions at or after indx
  # are taken from x[1:] (shifted left by one), which skips the entry at indx.
  return jnp.where(jnp.arange(len(x) - 1) < indx, x[:-1], x[1:])

print(zdel2(x, indx))
# [1. 2. 3. 4. 5. 7. 8. 9.]
```
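The `jnp.where` approach removes a single entry. If several entries need to be dropped and their positions are only known at runtime, one possible generalization is to build the gather indices explicitly. The sketch below is not a JAX API; `delete_dynamic` is a hypothetical helper. It assumes the positions are unique, non-negative and in range, and that their count `k` is fixed at trace time, because under `jit` the output length `len(x) - k` must be static.

```python
import jax
import jax.numpy as jnp


@jax.jit
def delete_dynamic(x, idx):
    """Hypothetical helper: drop entries of 1-D `x` at positions `idx`.

    Assumes `idx` is a 1-D integer array of unique, non-negative,
    in-range positions. Only its length k must be known at trace time;
    the position values themselves may be traced.
    """
    k = idx.shape[0]                        # static under jit
    idx = jnp.sort(idx)
    # The j-th smallest deleted position, expressed in output coordinates
    # (the j earlier deletions have already shifted it left by j).
    shifted = idx - jnp.arange(k)
    out_pos = jnp.arange(x.shape[0] - k)    # static output length
    # Count the deletions at or before each output position (in output
    # coordinates); adding that count maps each output position back to
    # its source position in x.
    offset = (out_pos[None, :] >= shifted[:, None]).sum(axis=0)
    return x[out_pos + offset]


x = jnp.array([1., 2., 3., 4., 5., 6., 7., 8., 9.])
print(delete_dynamic(x, jnp.array([5])))      # [1. 2. 3. 4. 5. 7. 8. 9.]
print(delete_dynamic(x, jnp.array([5, 1])))   # [1. 3. 4. 5. 7. 8. 9.]

# Because the positions are ordinary traced values, the helper also works
# under vmap, with a different deletion position for each row:
rows = jax.vmap(delete_dynamic, in_axes=(None, 0))(x, jnp.array([[0], [8]]))
print(rows)
```

The comparison builds a `(k, len(x) - k)` boolean matrix, which is cheap for a handful of deletions. For many deletions, a boolean keep-mask combined with `jnp.nonzero(mask, size=len(x) - k)` avoids that intermediate.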