Hi, I get a `ConcretizationTypeError` when I try to make my function jit-able. The function takes two integers and returns an array using `jax.numpy.arange`. Here is a simplified version of the function:

```python
from jax import numpy as np
from jax import jit

def func(a, b):
    c = 3*a + 2*b
    return np.arange(c)

print(jit(func)(2, 3))  # This causes ConcretizationTypeError
```

When the function takes a single input, `jit` with `static_argnums` works:

```python
from jax import numpy as np
from jax import jit

def func(a):
    c = 3*a
    return np.arange(c)

print(jit(func, static_argnums=0)(2))  # This works fine
```

Any help or suggestions would be much appreciated.
The issue is that JAX JIT cannot create dynamically shaped arrays, and your function creates one: the value of `3*a + 2*b` is not known at compile time, because `a` and `b` are not marked as static. To fix this, do what you did in the second function: mark as static any parameters that determine the shape of an array created within a jit-compiled function. For example, this works:

```python
print(jit(func, static_argnums=[0, 1])(2, 3))
```

You'll find more background and explanation of this topic in the How To Think In JAX doc.
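A minimal sketch of the same fix in decorator form, assuming a recent JAX release in which `jax.jit` accepts `static_argnames` (the `static_argnums=[0, 1]` form above is equivalent). Because `a` and `b` are static, `3*a + 2*b` is an ordinary Python integer while the function is traced, so `jnp.arange` receives a concrete length. The trade-off is that every new combination of static values triggers a fresh compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Mark both arguments as static so their values are known at trace time.
@partial(jax.jit, static_argnames=("a", "b"))
def func(a, b):
    c = 3 * a + 2 * b      # plain Python int, because a and b are static
    return jnp.arange(c)   # output shape (c,) is fixed for this compiled version


out = func(2, 3)           # compiled for a=2, b=3: 3*2 + 2*3 = 12
print(out.shape)           # (12,)
```

Calling `func(1, 1)` afterwards would compile a second version whose output has shape `(5,)`.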
Hi @jakevdp,

```python
from functools import partial

from jax import jit
from jax import numpy as jnp


@partial(jit, static_argnums=[0])
def func(ts):
    times = jnp.arange(ts[0], ts[-1], 1e-2)
    ts = jnp.array(list(ts))
    s_indices = jnp.searchsorted(times, ts)
    return s_indices  # return the indices so that `y` holds the result


ts = (0, 1)  # works
ts = [0, 1]  # does not work
y = func(ts)
```

I use
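Static arguments to `jax.jit` must be hashable, because JAX uses their values as part of the key for its compilation cache. A tuple is hashable and a list is not, which matches the `(0, 1)` versus `[0, 1]` behavior above. A minimal sketch of one workaround is to convert the list to a tuple before the call; the wrapper name `func_from_list` is hypothetical, and casting `ts` to the dtype of `times` is an illustrative choice rather than something the thread specifies.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=0)
def func(ts):
    # ts is static (a hashable tuple), so ts[0] and ts[-1] are Python numbers
    # and the length of `times` is known at compile time.
    times = jnp.arange(ts[0], ts[-1], 1e-2)
    ts_arr = jnp.asarray(ts, dtype=times.dtype)
    return jnp.searchsorted(times, ts_arr)


def func_from_list(ts_list):
    # Hypothetical helper: turn the unhashable list into a hashable tuple
    # before it is used as a static argument.
    return func(tuple(ts_list))


print(func((0, 1)))
print(func_from_list([0, 1]))
```

Passing the list directly, as in `func([0, 1])`, still raises an error, since JAX cannot hash it.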