jnp.where always evaluates both branches concretely, and if cond is an array the output can be an elementwise mixture of the two branches.
You can use lax.cond for this case instead: it only accepts a scalar predicate and is lazily evaluated.
lax.cond does exactly what you want: both branches are traced into the jaxpr, but only one of them is executed at runtime.
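A minimal sketch contrasting the two, assuming a recent JAX version. The input array and the branch functions (pos_branch, neg_branch, f) are hypothetical, made up only for illustration. The traced list is a Python side effect that runs at trace time, so it shows which branch functions were traced, not which one ran on the device.

```python
import jax
import jax.numpy as jnp
from jax import lax

# --- jnp.where: both operands are fully computed, then selected elementwise ---
x = jnp.array([-2.0, 0.5, 3.0])
sqrt_all = jnp.sqrt(x)                   # computed for every element: nan at x = -2
mixed = jnp.where(x > 0, sqrt_all, -x)   # elementwise mix; the nan is discarded by selection

# --- lax.cond: scalar predicate, both branches traced, one executed ---
traced = []  # hypothetical trace-time log of which branch functions were traced

def pos_branch(v):
    traced.append("pos")
    return jnp.sqrt(v)

def neg_branch(v):
    traced.append("neg")
    return -v

@jax.jit
def f(s):
    # lax.cond(pred, true_fun, false_fun, *operands); pred must be a scalar
    return lax.cond(s > 0, pos_branch, neg_branch, s)

out_pos = f(4.0)   # runtime executes only pos_branch
out_neg = f(-3.0)  # same compiled function; runtime executes only neg_branch

# A non-scalar predicate is rejected, unlike jnp.where
try:
    lax.cond(x > 0, jnp.sqrt, jnp.negative, x)
    rejected = False
except (TypeError, ValueError):
    rejected = True
```

One caveat: if lax.cond ends up under jax.vmap with a batched predicate, JAX converts it into a select, so in that case both branches are executed, just as with jnp.where.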

Replies: 2 comments 8 replies

Replies from @jakevdp, @YouJiacheng, @fbartolic and @soraros
Answer selected by fbartolic
Category: Q&A
Labels: None yet
4 participants