"ValueError: Non-hashable static arguments" when vectorizing for loop #13052
Hi, I have a function which does something and its behavior depends on a (static) number. For example, we can write:

```python
import jax
import jax.numpy as jnp

def shifted_sum(x, y, s):
    return jnp.sum(x[s:] + y[:-s])
```

Say that I am now interested in the following (sum of shifted sums):

```python
tot = 0
S = 4
x = y = jnp.arange(5)
for s in range(1, S):
    tot += shifted_sum(x, y, s)
print(tot)
```

Since we are eventually summing the output of `shifted_sum`, I tried to vectorize the loop instead:

```python
shifted_sum_jit = jax.jit(shifted_sum, static_argnums=(2,))
jax.vmap(shifted_sum_jit, in_axes=(None, None, 0))(x, y, jnp.arange(1, S)).sum()
```

However, the above does not work and gives me `ValueError: Non-hashable static arguments are not supported`. What am I doing wrong? This seems like a really trivial example that JAX should be able to take care of. I am using jax 0.2.24.
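To see why `static_argnums` and `vmap` clash here, it can help to look at what `vmap` actually passes in for `s`. The sketch below uses a hypothetical helper, `record_type`, that stores the type of its argument while `vmap` traces it. With a recent JAX version this is a tracer class rather than `int`, and a tracer is not a concrete, hashable value that could serve as a static argument.

```python
import jax
import jax.numpy as jnp

seen_types = []

def record_type(s):
    # Hypothetical helper: runs while vmap traces it, so `s` is whatever
    # vmap hands in for one batch element (not a Python int).
    seen_types.append(type(s))
    return s * 2

out = jax.vmap(record_type)(jnp.arange(1, 4))
print(seen_types[0].__name__)  # a tracer class, e.g. BatchTracer in recent JAX
print(out)
```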
Thanks for the question!

The reason for this error is that a parameter that is vmapped over is not static: inside `vmap`, `s` is a tracer rather than a Python integer, and a tracer cannot be hashed, which `static_argnums` requires. So you cannot pass such a parameter to a function that is marked as requiring a static argument. The fix is to not mark the argument static if you want to vmap over it.

The problem here, though, is that you're passing this argument (`s`) to another function that requires a static argument: index slicing. In JAX the size of an array must be static, so the `s` in `x[s:]` and `y[:-s]` cannot be a traced value.

The solution here, if you want to sum over dynamically-sized slices, is to write your code in a way that does not require generating dynamically-sized arrays. Here is a workaround:

```python
import jax
import jax.numpy as jnp

def shifted_sum(x, y, s):
    # Mask instead of slicing: keep the elements of x[s:] and y[:-s] and
    # zero out the rest, so every array keeps a static shape.
    x_sum = jnp.where(jnp.arange(len(x)) >= s, x, 0).sum()
    y_sum = jnp.where(jnp.arange(len(y)) < len(y) - s, y, 0).sum()
    return x_sum + y_sum

x = y = jnp.arange(5)
S = 4
shifted_sum_jit = jax.jit(shifted_sum)  # no static_argnums, so s can be vmapped
jax.vmap(shifted_sum_jit, in_axes=(None, None, 0))(x, y, jnp.arange(1, S)).sum()
# 36, the same total as the Python loop in the question
```

There is some ongoing work toward supporting dynamically-shaped arrays within …
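Because the masked version only needs fixed-shape arrays, the batching that `vmap` performs over `s` can also be written out by hand with broadcasting, one mask row per shift. This is a minimal sketch, assuming `x` and `y` are 1-D arrays of equal length and `S` is a plain Python int; `sum_of_shifted_sums` and `loop_reference` are hypothetical names used only to compare the masked approach with the slicing loop from the question.

```python
import jax.numpy as jnp

def sum_of_shifted_sums(x, y, S):
    """Hypothetical helper: sum over s in 1..S-1 of sum(x[s:] + y[:-s]).

    Assumes x and y are 1-D with the same length and S is a Python int.
    """
    n = x.shape[0]
    shifts = jnp.arange(1, S)[:, None]  # shape (S-1, 1): one row per shift
    idx = jnp.arange(n)[None, :]        # shape (1, n): element positions
    x_mask = idx >= shifts              # row for shift s keeps x[s:]
    y_mask = idx < n - shifts           # row for shift s keeps y[:-s]
    # x and y (shape (n,)) broadcast against the (S-1, n) masks.
    return jnp.where(x_mask, x, 0).sum() + jnp.where(y_mask, y, 0).sum()

def loop_reference(x, y, S):
    # The slicing loop from the question, with s as a concrete Python int.
    total = 0
    for s in range(1, S):
        total += jnp.sum(x[s:] + y[:-s])
    return total

x = y = jnp.arange(5)
print(int(sum_of_shifted_sums(x, y, 4)), int(loop_reference(x, y, 4)))  # 36 36
```

Writing it out this way makes the cost visible: the explicit version materializes an `(S - 1, n)` mask, which is roughly the batched computation `vmap` builds for the masked `shifted_sum`.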