
Hi, thanks for the question! jax.numpy.where follows the API of numpy.where, so its return type depends on how many arguments you pass. With one argument, jnp.where is equivalent to jnp.nonzero and returns a tuple of index arrays, one per dimension of the input:

>>> import jax.numpy as jnp
>>> x = jnp.array([[True, False],
...                [False, True]])
>>> jnp.where(x)
(Array([0, 1], dtype=int32), Array([0, 1], dtype=int32))
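As a small self-contained sketch (reusing the same x as above), the script below checks that the one-argument form gives the same tuple as jnp.nonzero, and shows that the tuple can be used directly as an index:

```python
import jax.numpy as jnp

x = jnp.array([[True, False],
               [False, True]])

# One-argument form: a tuple of index arrays, one per dimension of x.
rows, cols = jnp.where(x)

# Same result as jnp.nonzero(x).
nz_rows, nz_cols = jnp.nonzero(x)
assert (rows == nz_rows).all() and (cols == nz_cols).all()

# The tuple indexes x directly, picking out the True entries.
selected = x[rows, cols]
print(selected)
```

Note that the length of each index array depends on how many entries of x are True, so the output shape is data-dependent.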

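With three arguments, it returns a single array that selects elementwise from the second and third arguments, which is the return type described in the documentation excerpt you quoted: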

>>> jnp.where(x, 100, 200)
Array([[100, 200],
       [200, 100]], dtype=int32)
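A minimal sketch of the three-argument form, assuming NumPy is available for the comparison: the result has the broadcast shape of the inputs and matches numpy.where. Because that shape does not depend on the values in x, this form can also be wrapped in jax.jit (the select, a and b names below are purely illustrative):

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.array([[True, False],
               [False, True]])

# Three-argument form: elementwise select, same shape as the broadcast inputs.
out = jnp.where(x, 100, 200)

# Matches numpy.where on the same inputs.
np.testing.assert_array_equal(out, np.where(np.asarray(x), 100, 200))

# The output shape is fixed by the input shapes, so this form works under jit.
select = jax.jit(lambda c, a, b: jnp.where(c, a, b))
a = jnp.arange(4).reshape(2, 2)  # [[0, 1], [2, 3]]
b = -a                           # [[0, -1], [-2, -3]]
picked = select(x, a, b)         # takes a where x is True, b elsewhere
```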

I think this behavior is covered in the Note at the top of the documentation you linked to.

Answer selected by bitsandscraps