

You'll find that functions in jax.lax are generally more restrictive than equivalent functions in jax.numpy, particularly when it comes to implicit broadcasting, type promotion, and rank promotion. If you want to use lax functions, you'll generally have to do manual dtype and shape conversions to ensure the inputs match the more restrictive APIs. If you don't want the more restrictive API of lax, then you should use jax.numpy.
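A minimal sketch of that difference, assuming a float32 matrix and an int32 vector as inputs. The helper names numpy_style and lax_style are hypothetical. jnp.add promotes the dtype and broadcasts the rank on its own. lax.add needs both conversions done by hand, here with lax.convert_element_type and lax.broadcast_in_dim.

```python
import jax.numpy as jnp
from jax import lax


def numpy_style(a, b):
    # jnp.add applies type promotion (int32 + float32 -> float32)
    # and NumPy-style rank broadcasting automatically.
    return jnp.add(a, b)


def lax_style(a, b):
    # lax.add needs identical dtypes and shapes, so convert by hand.
    b = lax.convert_element_type(b, a.dtype)
    # Assumption: b is 1-D and lines up with the last axis of a.
    b = lax.broadcast_in_dim(b, a.shape, (a.ndim - 1,))
    return lax.add(a, b)


a = jnp.ones((2, 3), dtype=jnp.float32)
b = jnp.arange(3, dtype=jnp.int32)

print(numpy_style(a, b))
print(lax_style(a, b))

try:
    # Mismatched dtypes (float32 vs int32) and ranks (2 vs 1).
    lax.add(a, b)
except TypeError as err:
    print("lax.add rejected the inputs:", err)
```

Both helpers return the same float32 (2, 3) array. The lax version just makes every conversion explicit.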

Another question: how does one handle the case where the condition is a simple scalar and the cases are arrays of the same shape?

It sounds like you might be looking for lax.cond, which is only valid for scalar conditions:

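from jax import lax
import jax.numpy as jnp
x = 1
# Assumed example: scalar predicate, both branches return same-shape arrays
lax.cond(x < 0, lambda: jnp.zeros(3), lambda: jnp.ones(3))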
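A sketch of the scalar-condition case inside jax.jit, assuming the two candidate arrays are passed in as operands. The function names pick_with_cond and pick_with_where are hypothetical. lax.cond takes a scalar boolean predicate plus two branch functions. jnp.where is the less restrictive jax.numpy alternative: it broadcasts a scalar condition against the arrays. The last part shows lax.cond refusing a non-scalar predicate.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def pick_with_cond(flag, a, b):
    # Scalar predicate; each branch receives (a, b) and must return
    # an array with the same shape and dtype as the other branch.
    return lax.cond(flag, lambda u, v: u, lambda u, v: v, a, b)


@jax.jit
def pick_with_where(flag, a, b):
    # jnp.where broadcasts the scalar flag against a and b elementwise.
    return jnp.where(flag, a, b)


a = jnp.arange(4.0)
b = jnp.full(4, 10.0)

print(pick_with_cond(True, a, b))
print(pick_with_where(False, a, b))

try:
    # A vector-valued predicate is not accepted by lax.cond.
    lax.cond(jnp.array([True, False]), lambda: a, lambda: b)
except TypeError as err:
    print("lax.cond rejected the predicate:", err)
```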

Answer selected by LeonEricsson