I think the following code will work.

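from functools import partial
import jax
import jax.numpy as jnp
import numpy as np

@partial(jax.jit, static_argnums=1)  # N sets the output shape, so it must be static
def new(a, N):
    out = jnp.zeros((N+1)*np.array(a.shape)-N, dtype=a.dtype)  # note: use np for shape, not jnp
    out = out.at[::N+1, ::N+1].set(a)  # place a's entries every N+1 steps along both axes
    return out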
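
To see what this does on a small input, here is a minimal sketch of the same idea. The name `dilate_2d` is hypothetical, and the shape is built as a plain Python tuple instead of with `np.array`, which is my assumption of an equivalent form. The shape has to be computed from concrete values because a `jnp` array created inside a jitted function is traced and cannot be used as a shape. Like `new`, it assumes a 2-D input.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def dilate_2d(a, n):
    # Hypothetical helper: insert n zeros between adjacent entries of a 2-D array.
    # a.shape is a static Python tuple under jit, so this arithmetic is concrete.
    rows, cols = a.shape
    out = jnp.zeros(((n + 1) * rows - n, (n + 1) * cols - n), dtype=a.dtype)
    # Strided assignment: original entries land at indices 0, n+1, 2(n+1), ...
    return out.at[::n + 1, ::n + 1].set(a)


a = jnp.array([[1, 2], [3, 4]])
result = dilate_2d(a, 2)
print(result)
# [[1 0 0 2]
#  [0 0 0 0]
#  [0 0 0 0]
#  [3 0 0 4]]
```

Each distinct value of `n` (and each input shape) triggers a separate compilation, since both determine the output shape.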

Answer selected by alisheikholeslam