How can I change this function so that it works with `jax.jit`? #10777
Let's say that I have:
I want to apply `jax.jit` to the following function:
I tried:
which fails with an error:
How should I use
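I think the following code will work. Marking `N` as static with `static_argnums=1` lets the output shape be computed at trace time:

```python
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np

@partial(jax.jit, static_argnums=1)  # N is static because it determines the output shape
def new(a, N):
    out = jnp.zeros((N + 1) * np.array(a.shape) - N, dtype=a.dtype)  # note: use np for shape, not jnp
    out = out.at[::N + 1, ::N + 1].set(a)  # place a on every (N+1)-th row and column
    return out
```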
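A quick usage sketch, assuming a 2-D input (the `::N + 1` slicing only indexes the first two axes). It redefines `new` so the block runs on its own, computing the shape as a plain tuple of Python ints, and calls it with two values of `N`; the input array is illustrative. Because `N` is a static argument, each distinct value of `N` is traced and compiled separately.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@partial(jax.jit, static_argnums=1)
def new(a, N):
    # a.shape and N are both concrete at trace time, so NumPy/Python
    # arithmetic on them yields a concrete shape that jnp.zeros accepts.
    shape = tuple(int(s) for s in (N + 1) * np.array(a.shape) - N)
    out = jnp.zeros(shape, dtype=a.dtype)
    # Scatter the entries of `a` onto every (N + 1)-th row and column.
    return out.at[::N + 1, ::N + 1].set(a)


a = jnp.arange(1, 5, dtype=jnp.float32).reshape(2, 2)  # [[1, 2], [3, 4]]

r1 = new(a, 1)  # 3x3 with values [[1, 0, 2], [0, 0, 0], [3, 0, 4]]
r2 = new(a, 2)  # 4x4; a new static N triggers a fresh compile
print(r1)
print(r2.shape)
```

Passing `N` as a traced (non-static) argument would make the shape passed to `jnp.zeros` depend on a traced value, which `jit` cannot handle; that is why `static_argnums` is needed here.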