lax.select same shape constraint #19779
I switched over some of my
this will fail because
Is this case simply more suited for

EDIT: Another question: how does one deal with the scenario where the condition is a simple scalar and the cases are arrays of the same shape?
You'll find that functions in `jax.lax` are generally more restrictive than the equivalent functions in `jax.numpy`, particularly when it comes to implicit broadcasting, type promotion, and rank promotion. If you want to use `lax` functions, you'll generally have to do manual dtype and shape conversions so that the inputs match the more restrictive APIs. If you don't want the more restrictive API of `lax`, then you should use `jax.numpy`.

It sounds like you might be looking for `lax.cond`, which is only valid for scalar conditions:

```python
from jax import lax

x = 1
lax.cond(x > 0, lambda: True, lambda: False)
# Array(True, dtype=bool)
```
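To make the "manual dtype and shape conversions" concrete, here is a minimal sketch, assuming a recent JAX install with default 32-bit dtypes. The arrays `x`, `y`, `pred` and the scalar `flag` are hypothetical values chosen for illustration, and `jnp.broadcast_to` is just one way to do the manual conversion, not necessarily the only or preferred one. It compares `jnp.where` (implicit broadcasting) with `lax.select` (cases must already match) and uses `lax.cond` for a scalar condition choosing between two arrays of the same shape.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical example inputs.
x = jnp.arange(4.0)   # float32, shape (4,)
y = jnp.zeros(4)      # float32, shape (4,)
pred = x > 1.0        # bool, shape (4,)

# jnp.where broadcasts its arguments, so the Python scalar 0.0
# is promoted to an array of shape (4,) automatically.
a = jnp.where(pred, x, 0.0)

# lax.select requires on_true and on_false to have the same shape and
# dtype, so build the "false" case explicitly with matching shape/dtype.
zeros_like_x = jnp.broadcast_to(jnp.asarray(0.0, dtype=x.dtype), x.shape)
b = lax.select(pred, x, zeros_like_x)

# A scalar condition choosing between two same-shaped arrays:
# lax.cond takes a scalar boolean predicate and two branch functions
# whose outputs must have matching shapes and dtypes.
flag = x.sum() > 0
c = lax.cond(flag, lambda: x, lambda: y)

# jnp.where also accepts a scalar condition, because it broadcasts
# the scalar against the array cases (both x and y are computed).
d = jnp.where(flag, x, y)
```

As a design note: `jnp.where` with a scalar flag is often the simplest option when both cases are cheap, while `lax.cond` is generally the better fit when the branches are expensive and only one should run (under `vmap` with a batched predicate, `lax.cond` is converted to a select anyway).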