You're getting this error because the three-argument form of jnp.where requires its arguments to be broadcast-compatible, and yours are not: they have different shapes.
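Here is a minimal sketch of the shape problem. It assumes arrays of shape `(10,)` and `(5,)`, and `f_where` is a hypothetical name. The first call mixes incompatible shapes and fails. Passing a scalar replacement value, which broadcasts against any shape, works even under `jax.jit`.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10)  # shape (10,)
y = jnp.ones(5)     # shape (5,)

# The three-argument where broadcasts the condition and both branches
# against each other. Shapes (10,) and (5,) are not broadcast-compatible,
# so this call raises a broadcasting error.
try:
  jnp.where(x >= 5, y, x)
  broadcast_failed = False
except (ValueError, TypeError) as err:
  broadcast_failed = True
  print("where failed:", err)

# With broadcast-compatible arguments (here a scalar replacement value),
# the same kind of call is fine, including under jax.jit.
@jax.jit
def f_where(x):
  return jnp.where(x >= 5, 1, x)

print(f_where(x))  # [0 1 2 3 4 1 1 1 1 1]
```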

It sounds like what you're after is a JIT-compatible version of something like this:

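```python
import jax
import jax.numpy as jnp

x = jnp.arange(10)
y = jnp.ones(5)

def f1(x, y):
  # Boolean-mask update: replace the entries where x >= 5 with the values of y.
  # This works eagerly because the mask is a concrete array here.
  return x.at[x >= 5].set(y)

print(f1(x, y))
# [0 1 2 3 4 1 1 1 1 1]
```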

When you try to JIT-compile this, it errors:

```python
# print(jax.jit(f1)(x, y))
# NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[10])
```

The reason for this is that there's no way to know at compile…
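Boolean-mask indexing produces a result whose size depends on how many entries of the mask are True, and that count is not available while tracing. One possible workaround, which is not necessarily the one this answer recommends, is to give the selection a static size. The sketch below assumes the mask selects exactly `y.shape[0]` entries, and `f2`/`f3` are hypothetical names. `f2` uses `jnp.nonzero` with its `size` and `fill_value` arguments. `f3` goes back to the three-argument `jnp.where`, but first expands `y` to the full shape of `x` so that all three arguments are broadcast-compatible.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10)
y = jnp.ones(5)

@jax.jit
def f2(x, y):
  mask = x >= 5
  # Assumption: mask selects exactly y.shape[0] entries. Passing size= gives
  # nonzero a static output shape, so it can be traced. Padded slots (if the
  # mask has fewer True entries) get index x.shape[0], which is out of bounds.
  (idx,) = jnp.nonzero(mask, size=y.shape[0], fill_value=x.shape[0])
  # mode="drop" discards updates aimed at out-of-bounds (padded) indices.
  return x.at[idx].set(y.astype(x.dtype), mode="drop")

@jax.jit
def f3(x, y):
  mask = x >= 5
  # Rank of each True entry among the True entries (0, 1, 2, ...);
  # False positions get values that where() below never uses.
  pos = jnp.cumsum(mask.astype(jnp.int32)) - 1
  # y[pos] has the same shape as x, so all three where() arguments broadcast.
  return jnp.where(mask, y[pos].astype(x.dtype), x)

print(f2(x, y))  # [0 1 2 3 4 1 1 1 1 1]
print(f3(x, y))  # [0 1 2 3 4 1 1 1 1 1]
```

The `size` argument trades flexibility for a static shape. If the mask selects fewer entries than `size`, the padded indices point past the end of `x` and `mode="drop"` discards them. If it selects more, the extra True positions are silently left unchanged.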

Answer selected by alisheikholeslam
Category
Q&A