The problem is that `lax.cond` is not a short-circuiting operation: to push the computation down to XLA, it traces both `true_fun` and `false_fun` to construct an intermediate representation, even though the compiled code will only evaluate one of the branches at runtime.
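
A minimal sketch of this behavior, assuming a recent JAX install. The names `true_fun`, `false_fun`, `f`, and the `traced` list are hypothetical, chosen only for illustration: each branch appends its name to `traced` as a Python side effect, which runs at trace time rather than at execution time, so both names are recorded even though the predicate is `True`.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical log used only to observe which branch functions get traced.
traced = []

def true_fun(x):
    # Python side effect: runs while JAX traces this function,
    # not every time the compiled computation executes.
    traced.append("true_fun")
    return x + 1.0

def false_fun(x):
    traced.append("false_fun")
    return x - 1.0

@jax.jit
def f(pred, x):
    # lax.cond(pred, true_fun, false_fun, *operands)
    return lax.cond(pred, true_fun, false_fun, x)

result = f(True, jnp.float32(2.0))
print(traced)   # both branch names appear, although pred is True
print(result)   # only true_fun's value is returned
```

Because both branches are traced, any branch whose Python code fails for the given operands (for example, evaluating `A + B` when `A` is `None`) raises an error at trace time, even if that branch would never be selected at runtime.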

But even so, I have trouble following the logic of the first code block: for example, it seems that if `C is not None` and `A = None`, then even with short-circuiting logic `positive_branch2` would attempt to execute `A = A + B`, which would be ill-defined even without JAX involved.

Another thing you could consider: if you're branching on python identity comparisons (as opposed to value comparisons), instead of using lax.cond you …

Answer selected by rog77
Category: Q&A