-
Fundamentally, there are two different problems. The first is that JAX does not support in-place array updates with the x[:index] = ... syntax. This limitation applies to both jitted and non-jitted functions. As you said, JAX provides an alternative based on jax.ops.index_update (later JAX releases removed it in favor of the x.at[...].set(...) syntax). This is the correct way to handle this issue. The second problem is that jitted code does not support dynamically shaped arrays, and x[:index] in the call to other_fun is dynamically shaped because the value of index is not known at jit time. This limitation is fundamental: the underlying optimizations of JAX need to know the shapes of intermediates in order to speed up the evaluation. If the function is called many times with the same index value, it may make sense to jit a specialized version of it:

import functools
import jax
import jax.numpy as jnp

def other_fun(x):
    # Placeholder: the post does not define other_fun
    return x + 1

def fun(index, x):
    y = other_fun(x[:index])
    # Write y back into the first `index` entries
    # (x.at[...].set is the current form of jax.ops.index_update)
    x = x.at[:index].set(y)
    return x

# Assume that fun will be called many times
# with index=2 in the future
q = jax.jit(functools.partial(fun, 2))
print(q(jnp.arange(5)))
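As an alternative to functools.partial, the index can be marked as a static argument of jax.jit. The sketch below is only an illustration: other_fun is a placeholder (x + 1) that the post does not define, fun_jit is a made-up name, and the update uses x.at[...].set(...), the indexed-update syntax of current JAX releases.

```python
import jax
import jax.numpy as jnp


def other_fun(x):
    # Placeholder computation (assumption for this sketch).
    return x + 1


def fun(index, x):
    # index is static here, so x[:index] has a shape known at trace time.
    y = other_fun(x[:index])
    return x.at[:index].set(y)


# Treat argument 0 (index) as a compile-time constant.
# Each distinct index value triggers a separate compilation.
fun_jit = jax.jit(fun, static_argnums=0)

result = fun_jit(2, jnp.arange(5))
print(result)
```

This only pays off when index takes a small number of distinct values, since every new value is compiled again.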
-
x-ref https://stackoverflow.com/q/68419632; I'll repeat the answer here for completeness. For a dynamic index, you can do this using jnp.where with a mask:

import jax
import jax.numpy as jnp

def other_fun(x):
    return x + 1

@jax.jit
def fun(x, index):
    mask = jnp.arange(x.shape[0]) < index
    return jnp.where(mask, other_fun(x), x)

x = jnp.arange(5)
print(fun(x, 3))
# [1 2 3 3 4]
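The masked version applies other_fun to the whole array, which is fine when other_fun is elementwise. When the function is a reduction instead, one possible variation is to mask the input rather than the output. The sketch below is an assumption-laden illustration, not part of the original answer: prefix_sum is a hypothetical name, and the masked-out entries are replaced by 0 because that is the identity for a sum.

```python
import jax
import jax.numpy as jnp


@jax.jit
def prefix_sum(x, index):
    # True for positions before index; the mask has the static shape of x.
    mask = jnp.arange(x.shape[0]) < index
    # Zero out entries at or after index so they do not affect the sum.
    return jnp.sum(jnp.where(mask, x, 0))


x = jnp.arange(5)
total = prefix_sum(x, 3)
print(total)
```

Other reductions would need their own neutral fill value, for example -inf for a maximum.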
-
Is it possible to extend the solution to the case where we want to return the output of |
-
I want to perform an operation like
This cannot be performed under jit. Is there a way of doing this with jax.ops or jax.lax? I thought of using jax.ops.index_update(x, idx, y), but I cannot find a way of computing y without running into the same problem again.
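On the jax.lax side of the question, here is a hedged sketch for a narrower case: if the slice length is fixed ahead of time and only the start position varies, jax.lax.dynamic_slice and jax.lax.dynamic_update_slice can be used under jit. The name fun_window, the window size of 2, and the placeholder other_fun are assumptions made for illustration. This does not cover a slice whose length depends on a traced value.

```python
import jax
import jax.numpy as jnp
from jax import lax


def other_fun(x):
    # Placeholder computation (assumption for this sketch).
    return x + 1


@jax.jit
def fun_window(x, start):
    size = 2  # static slice length, fixed at trace time
    # Read a window of `size` elements beginning at the traced `start`.
    window = lax.dynamic_slice(x, (start,), (size,))
    # Write the transformed window back at the same position.
    return lax.dynamic_update_slice(x, other_fun(window), (start,))


x = jnp.arange(5)
out = fun_window(x, 1)
print(out)
```

Note that dynamic_slice clamps the start index so that the window stays inside the array.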