Lazy evaluation for jnp.select
#11153
The general answer to your question is that conditionals based on `jnp.select` (or `jnp.where`) compute every branch, because each branch is passed in as an already-evaluated array. To evaluate only one branch at runtime, use `lax.cond` (or `lax.switch`). There is also another issue in your example: it has an infinite recursion at trace time, so the computation never gets a chance to be lowered to XLA. With that in mind, here's an example of a conditional that will only be evaluated for a single branch at runtime:

```python
from jax import lax
import jax.numpy as jnp

def truefunc(x):
  return x

def falsefunc(x):
  return 100 * x

x = jnp.arange(5)
lax.cond(False, truefunc, falsefunc, x)
# DeviceArray([ 0, 100, 200, 300, 400], dtype=int32)
```

However, keep in mind that both functions will be traced, so you won't be able to check this by inserting a Python-side infinite loop. Another detail to keep in mind is that if you vmap over …
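The point that both branch functions are traced can be checked without an infinite loop by using a harmless Python-side side effect. The sketch below reuses `truefunc` and `falsefunc` and adds a hypothetical `calls` list that each function appends to; since Python code inside a branch runs only while JAX traces it, the list records tracing, not runtime execution.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical Python-side log: entries are appended when JAX traces a branch
# function, not when the compiled branch executes.
calls = []

def truefunc(x):
  calls.append("true")
  return x

def falsefunc(x):
  calls.append("false")
  return 100 * x

x = jnp.arange(5)
result = lax.cond(False, truefunc, falsefunc, x)

# Both branches were traced, even though only falsefunc's result is returned.
print(sorted(set(calls)))  # ['false', 'true']
print(result.tolist())     # [0, 100, 200, 300, 400]
```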
Hi, I have a question on the usage of `jnp.select`. Suppose I do the following:
where `f(x)` will always cause an infinite loop. Of course this isn't my use case, but it's representative of the problem that I'm having.

Now, since the condition is true, I would expect the infinite loop never to be entered. However, `f(x)` seems to be evaluated regardless. Is there any way to have lazy evaluation of the default case? In other words, I would like `f(x)` to be evaluated only when all conditions are false (of course, I am using some `f` that won't enter an infinite loop under these conditions).

Otherwise, can anyone recommend a good way to implement `if ... elif ... else` control flow that has this property but isn't just a bunch of nested `lax.cond`s? A bunch of nested `lax.cond`s is just a bit messy... Thank you!