Slice and repeat #20947
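The question I am asking here is part of a workflow that is very common nowadays, especially for the KV cache used in LLMs and MLLMs, and I cannot think of a good solution to it. I will try to keep the example simple. Let's say you have an array `x` of shape `(bs, seqlen, dim2, dim3)`, and you want to take the first `pos` positions along the sequence axis and then repeat them `num_repeats` times along axis 2.

Normal workflow (without jit)

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1234)
bs = 2
seqlen = 5
dim2 = 3
dim3 = 4
pos = 2
num_repeats = 3
x = jax.random.randint(key=key, shape=(bs, seqlen, dim2, dim3), minval=2, maxval=100)

# Take the first `pos` sequence positions, then repeat each entry along axis 2
x_slice_v1 = x[:, :pos, :, :]
x_slice_v1 = jnp.repeat(x_slice_v1, num_repeats, axis=2)
```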
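To make the shapes in the eager version concrete, here is a small check using the same parameters. It shows that slicing shrinks the sequence axis from 5 to `pos = 2`, and that `jnp.repeat` along axis 2 copies each index consecutively (index 0 three times, then index 1 three times, and so on) rather than tiling the whole axis. Reading `dim2` as a number of KV heads is an assumption for illustration only.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1234)
bs, seqlen, dim2, dim3 = 2, 5, 3, 4
pos, num_repeats = 2, 3
x = jax.random.randint(key, (bs, seqlen, dim2, dim3), minval=2, maxval=100)

x_slice = x[:, :pos]                              # shape (2, 2, 3, 4)
x_rep = jnp.repeat(x_slice, num_repeats, axis=2)  # shape (2, 2, 9, 4)
print(x_slice.shape, x_rep.shape)                 # (2, 2, 3, 4) (2, 2, 9, 4)

# jnp.repeat places copies of each axis-2 index next to each other: 0,0,0,1,1,1,2,2,2
assert (x_rep[:, :, 0:3] == x_slice[:, :, 0:1]).all()
assert (x_rep[:, :, 3:6] == x_slice[:, :, 1:2]).all()
```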
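JIT-based workflow

For a jit-based workflow, there are two options:

```python
from functools import partial

@partial(jax.jit, static_argnums=(1,))
def get_slice(x, pos):
    seqlen = x.shape[1]
    # Boolean mask over the sequence axis, broadcast to the 4-D shape of x
    idx = (jnp.arange(seqlen) < pos)[None, :, None, None]
    # Keep the first `pos` positions and zero out the rest; the shape stays (bs, seqlen, dim2, dim3)
    x_slice = jnp.where(idx, x, 0)
    return x_slice

x_slice_v2 = get_slice(x, pos)
```

The output of the above operation is this: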
The extracted slices are correct, but if I apply … My question is: "How can I obtain the value for …"
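One way to keep every shape static even when `pos` changes at runtime (as it does while decoding into a preallocated KV cache) is to mask along the sequence axis first and then repeat the full-length array. Because the mask acts on axis 1 and the repeat on axis 2, the order does not change the values. This is a minimal sketch, assuming that zero-padding beyond `pos` is acceptable downstream (for example, when combined with an attention mask); `get_masked_repeat` is a hypothetical helper name.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("num_repeats",))
def get_masked_repeat(x, pos, num_repeats):
    # `pos` is traced here, so calling with a new `pos` does not recompile;
    # only `num_repeats` (which changes the output shape) must be static.
    seqlen = x.shape[1]
    mask = (jnp.arange(seqlen) < pos)[None, :, None, None]
    x_masked = jnp.where(mask, x, 0)
    # Output shape is (bs, seqlen, dim2 * num_repeats, dim3) regardless of `pos`
    return jnp.repeat(x_masked, num_repeats, axis=2)


key = jax.random.PRNGKey(1234)
x = jax.random.randint(key, (2, 5, 3, 4), minval=2, maxval=100)
out = get_masked_repeat(x, 2, num_repeats=3)
print(out.shape)  # (2, 5, 9, 4)

# The first `pos` positions match the eager slice-then-repeat result; the rest are zero.
ref = jnp.repeat(x[:, :2], 3, axis=2)
assert (out[:, :2] == ref).all()
assert (out[:, 2:] == 0).all()
```

The trade-off is that the result keeps the full `seqlen` length; recovering the exact `(bs, pos, ...)` shape still requires `pos` to be known at trace time.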
Replies: 1 comment 2 replies
I'm confused as to why you can't use your `x_slice_v1` code under JIT. It seems like it should work, so long as all variables affecting array size are marked as static.
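For example, a sketch of the `x_slice_v1` code wrapped in `jax.jit`, with both size-determining arguments declared static via `static_argnames` (the function name `slice_and_repeat` is hypothetical):

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("pos", "num_repeats"))
def slice_and_repeat(x, pos, num_repeats):
    # Both arguments determine the output shape, so both are static:
    # each distinct (pos, num_repeats) pair triggers a fresh compilation.
    x_slice = x[:, :pos, :, :]
    return jnp.repeat(x_slice, num_repeats, axis=2)


key = jax.random.PRNGKey(1234)
x = jax.random.randint(key, (2, 5, 3, 4), minval=2, maxval=100)
out = slice_and_repeat(x, pos=2, num_repeats=3)
print(out.shape)  # (2, 2, 9, 4)
assert (out == jnp.repeat(x[:, :2], 3, axis=2)).all()
```

The cost of this approach is one compilation per distinct value of `pos`, which matters when `pos` advances on every decoding step.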