lax.cond executing both branches? #9543
Replies: 2 comments
-
Both Python callables are always executed, in the sense that they are applied to abstract values to stage out their contents, but no FLOPs happen at that stage. I assume you're asking about both sides' FLOPs getting evaluated. One case is that
It should, but it doesn't! It's on our todo list to fix, but to do that we probably need to teach JAX's internals about side effects more holistically. #8699 is the first step in that direction, but currently it's just a draft.
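A minimal sketch of the difference between tracing and executing, assuming a recent JAX release where `jax.debug.callback` and `jax.effects_barrier` are available. The plain-Python `append` in each branch runs while `lax.cond` traces both callables, whereas the `jax.debug.callback` calls fire only for the branch actually taken when the compiled computation runs. The branch functions and log lists are hypothetical, for illustration only.

```python
import jax
import jax.numpy as jnp
from jax import lax

trace_log = []    # filled by ordinary Python code, i.e. only while tracing
runtime_log = []  # filled by jax.debug.callback, i.e. when the compiled code runs


def true_branch(x):
    trace_log.append("true")  # runs when lax.cond traces this callable
    jax.debug.callback(lambda: runtime_log.append("true"))  # runs only if taken
    return x * 2.0


def false_branch(x):
    trace_log.append("false")  # also runs at trace time, even if never taken
    jax.debug.callback(lambda: runtime_log.append("false"))
    return x - 1.0


@jax.jit
def f(pred, x):
    return lax.cond(pred, true_branch, false_branch, x)


out = f(True, jnp.float32(3.0))
jax.effects_barrier()  # wait until all pending debug callbacks have run

print(sorted(set(trace_log)))  # both branches were traced
print(runtime_log)             # only the taken branch ran
print(out)
```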
-
This is a good footgun to know about, but for my specific situation I'm only worried about
If only one of the branches is actually evaluated at FLOP execution time, then the rewrite rule above would be fine. However, if both get executed at FLOP time, then our rewrite rule becomes undefined and we'd need to warn the user instead.
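Whether both branches do work at FLOP time can depend on the transformations applied. A minimal sketch, assuming a recent JAX release: with an unbatched predicate the staged jaxpr keeps a `cond` primitive, while under `jax.vmap` with a batched predicate JAX's batching rule replaces it with a `select_n` over the results of both branches, so both are computed. The function `g` and its branches are hypothetical, for illustration only.

```python
import jax
import jax.numpy as jnp
from jax import lax


def g(pred, x):
    # Hypothetical branches: double x, or subtract one from it.
    return lax.cond(pred, lambda v: v * 2.0, lambda v: v - 1.0, x)


# Scalar (unbatched) predicate: the staged program keeps a `cond` primitive,
# so at run time only the selected branch is evaluated.
plain = jax.make_jaxpr(g)(True, jnp.float32(3.0))

# Batched predicate under vmap: the batching rule replaces `cond` with
# `select_n`, so both branches are computed for every element.
preds = jnp.array([True, False])
xs = jnp.array([3.0, 3.0], dtype=jnp.float32)
batched = jax.make_jaxpr(jax.vmap(g))(preds, xs)

print(plain)
print(batched)
print(jax.vmap(g)(preds, xs))  # elementwise results 6.0 and 2.0
```

If a rewrite rule relies on the untaken branch never running, the vmap-with-batched-predicate case is one place where that assumption does not hold.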
-
When using `lax.cond`, is it ever the case that both branches get executed on device, regardless of the value of the `cond`? Does marking the primitives as having side effects prevent this at all?