Problem with Boolean masks indexing by jax jit #10975
-
I have written the following code:
I tried to rewrite this code in other forms using `jax.jit`, but it got stuck when it reached the Boolean masks, which I think could be handled by `jnp.where`. This gives the following error:
What is the problem, and how can I get past it?
-
The reason you're getting this error is that the three-argument `jnp.where` function requires its arguments to be broadcast-compatible, and your arguments are not (they have differing shapes).

It sounds like what you're after is a JIT-compatible version of something like this:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10)
y = jnp.ones(5)

def f1(x, y):
    return x.at[x >= 5].set(y)

print(f1(x, y))
# [0 1 2 3 4 1 1 1 1 1]
```

When you try to JIT-compile this, it errors:

```python
# print(jax.jit(f1)(x, y))
# NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[10])
```

The reason for this is that there's no way to know at compile time how many elements of `x >= 5` will be True, so the shape of the selection `x[x >= 5]` is not known when the function is traced.

The alternative you tried is roughly this:

```python
def f2(x, y):
    return jnp.where(x >= 5, y, x)
```

This is invalid because `y` (shape `(5,)`) and `x` (shape `(10,)`) are not broadcast-compatible. Instead, you need to do something like this:

```python
def f3(x, y):
    idx = jnp.where(x >= 5, size=len(y))
    return x.at[idx].set(y)

print(f3(x, y))
# [0 1 2 3 4 1 1 1 1 1]
print(jax.jit(f3)(x, y))
# [0 1 2 3 4 1 1 1 1 1]
```

Notice that this involves specifying at compile time that you expect there to be 5 indices where the condition is True. This information allows JAX to correctly JIT-compile the code.
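One caveat worth checking with the `size=` approach: if fewer than `size` elements satisfy the condition, `jnp.where` pads the returned indices with `fill_value`, which defaults to 0. In `f3` the leftover updates would therefore be written into `x[0]`. Below is a minimal sketch of one way to guard against that, assuming it is acceptable to simply skip the padded slots. It pads with an out-of-bounds index and asks the scatter to drop out-of-bounds updates via `mode="drop"`. The name `masked_set` and the integer values in `y` are illustrative only.

```python
import jax
import jax.numpy as jnp


@jax.jit
def masked_set(x, y):
    n = y.shape[0]  # static under jit, just like len(y) in f3
    # At most n indices where x >= 5; unused slots are padded with
    # x.shape[0], which is one past the last valid index of x.
    (idx,) = jnp.where(x >= 5, size=n, fill_value=x.shape[0])
    # mode="drop" discards updates at out-of-bounds indices, so the
    # padded slots leave x untouched instead of overwriting x[0].
    return x.at[idx].set(y, mode="drop")


y = jnp.array([10, 20, 30, 40, 50])
print(masked_set(jnp.arange(10), y))  # exactly 5 elements satisfy x >= 5
print(masked_set(jnp.arange(8), y))   # only 3 elements satisfy x >= 5
```

If more than `len(y)` elements satisfy the condition, only the first `len(y)` of them are updated and the rest keep their original values.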
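If you would rather not fix the number of matches in advance, a different technique (not the one used in `f3`) keeps every intermediate at the shape of `x`. That way the three-argument `jnp.where` receives broadcast-compatible arguments. A cumulative sum over the mask gives each True position its rank among the True positions, and that rank is used to gather from `y`. This is a sketch under assumptions: `masked_fill_in_order` is a hypothetical helper name, and positions beyond the length of `y` are assumed to keep their original values.

```python
import jax
import jax.numpy as jnp


@jax.jit
def masked_fill_in_order(x, y):
    mask = x >= 5
    # Rank of each True entry among the True entries (0, 1, 2, ...);
    # entries before the first True get rank -1.
    rank = jnp.cumsum(mask.astype(jnp.int32)) - 1
    # Gather a candidate value for every position of x. Clipping keeps the
    # index valid where rank is -1 or >= len(y); those slots are masked out below.
    vals = y[jnp.clip(rank, 0, y.shape[0] - 1)]
    # keep, vals and x all have the shape of x, so jnp.where broadcasts cleanly.
    keep = mask & (rank < y.shape[0])
    return jnp.where(keep, vals, x)


x = jnp.arange(10)
y = jnp.array([10, 20, 30, 40, 50])
print(masked_fill_in_order(x, y))
```

Because no output shape depends on the number of True entries, no `size` argument is needed. The trade-off is an extra cumulative sum and a gather over the full length of `x`.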