Return type of jnp.where
#19672
-
Even though the documentation states that the return type of …
-
Hi, thanks for the question! jax.numpy.where follows the API of numpy.where, and the output type depends on the number of arguments. With one argument, jnp.where is equivalent to jnp.nonzero and returns a tuple of indices:

>>> import jax.numpy as jnp
>>> x = jnp.array([[True, False],
...                [False, True]])
>>> jnp.where(x)
(Array([0, 1], dtype=int32), Array([0, 1], dtype=int32))

With three arguments, it returns the type in the documentation excerpt you quoted:

>>> jnp.where(x, 100, 200)
Array([[100, 200],
       [200, 100]], dtype=int32)

I think this behavior is covered in the Note at the top of the documentation you linked to.
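As a quick way to see the two return types side by side, here is a minimal sketch, assuming a recent JAX install with NumPy available. It builds the same boolean array as in the reply, calls jnp.where with one and with three arguments, and keeps the matching jnp.nonzero and numpy.where results around for comparison. The variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([[True, False],
               [False, True]])

# One argument: jnp.where(x) behaves like jnp.nonzero(x) and returns a
# tuple holding one integer index array per dimension of x.
one_arg = jnp.where(x)
nonzero = jnp.nonzero(x)

# Three arguments: a single array with the broadcast shape of the inputs,
# taking 100 where x is True and 200 where x is False.
three_arg = jnp.where(x, 100, 200)

# Reference results from numpy.where, whose API jnp.where follows.
np_one_arg = np.where(np.asarray(x))
np_three_arg = np.where(np.asarray(x), 100, 200)

print(type(one_arg).__name__, len(one_arg))
print(three_arg.shape, three_arg.dtype)
```

Because the length of the index arrays in the one-argument form depends on how many elements of x are True, that form has a data-dependent output shape; under jax.jit you would need to pass a static `size` keyword (accepted by both jnp.nonzero and jnp.where), whereas the three-argument form always has a shape known from its inputs.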