Cost of evaluating the two branches of jnp.where #10306
Consider the following function:

```python
import jax.numpy as jnp
from jax import jit

@jit
def f(x):
    cond = g(x)
    return jnp.where(cond, A(x), B(x))
```

Here branch `B` is at least 1000x more expensive to evaluate than `A`. My understanding is that in a jit-ed function both branches are evaluated if …
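The sketch below makes the question concrete. It is runnable, but it relies on assumptions that are not from the original post. `g`, `A` and `B` are hypothetical stand-ins: `g` is an elementwise test, so `cond` is a boolean array. `A` is `jnp.sin`, and `B` is `jnp.cos`, which stands in for the costly branch. Printing the jaxpr with `jax.make_jaxpr` shows that both branch computations are present in the traced program. `jnp.where` then picks between their results element by element.

```python
import jax
import jax.numpy as jnp
from jax import jit

def g(x):
    # Hypothetical condition: an elementwise test, so `cond` is a boolean array.
    return x > 0

def A(x):
    # Hypothetical cheap branch.
    return jnp.sin(x)

def B(x):
    # Hypothetical stand-in for the expensive branch.
    return jnp.cos(x)

@jit
def f(x):
    cond = g(x)
    # Both A(x) and B(x) are computed; jnp.where selects elementwise.
    return jnp.where(cond, A(x), B(x))

x = jnp.array([-1.0, 2.0])
jaxpr_text = str(jax.make_jaxpr(f)(x))
print(jaxpr_text)  # contains both `sin` and `cos`
print(f(x))        # element 0 comes from B, element 1 from A
```
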
`jnp.where` always evaluates both branches concretely. If `cond` is an array, the output can be an elementwise mixture of the two branches. For this case you can use `lax.cond`, which only accepts a scalar predicate and is evaluated lazily. `lax.cond` does exactly what you want: both branches are traced into the jaxpr, but only one of them is executed at runtime.
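Here is a minimal sketch of the `lax.cond` approach. As before, `A` and `B` are hypothetical stand-ins (`jnp.sin` for the cheap branch, `jnp.cos` for the expensive one). The scalar predicate `jnp.all(x > 0)` is also an assumption: it shows one way to reduce an array test to the single boolean that `lax.cond` requires. Both branches must return outputs of the same shape and dtype.

```python
import jax
import jax.numpy as jnp
from jax import jit, lax

def A(x):
    # Hypothetical cheap branch.
    return jnp.sin(x)

def B(x):
    # Hypothetical stand-in for the expensive branch.
    return jnp.cos(x)

@jit
def f(x):
    # lax.cond needs a scalar boolean predicate (hypothetical reduction of g(x)).
    pred = jnp.all(x > 0)
    # Both branches are traced into the jaxpr; only the selected one runs.
    return lax.cond(pred, A, B, x)

pos = jnp.array([1.0, 2.0])
mixed = jnp.array([-1.0, 2.0])
print(f(pos))    # predicate True: branch A is executed
print(f(mixed))  # predicate False: branch B is executed
cond_jaxpr_text = str(jax.make_jaxpr(f)(pos))
print(cond_jaxpr_text)  # shows a `cond` primitive holding both branches
```

One caveat: this only avoids the cost of `B` when the predicate really is a single value. If `f` is mapped with `jax.vmap` and the predicate becomes batched, JAX converts the `cond` into a select. In that case both branches are executed again.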