Avoiding an infinite loop in a divide-and-conquer algorithm with lax.cond
#13454
Unanswered
s-a-barnett asked this question in Q&A
Replies: 1 comment 1 reply
-
You cannot recurse on dynamic conditions in JAX because of this issue: lax.cond traces both branches. Fortunately, in this case the condition is static, because you are branching on the length of the array, which is known at trace time. So you can replace this:

    return lax.cond(pred, true_fun, false_fun, B)

with this:

    return true_fun(B) if pred else false_fun(B)

and your code will execute under JIT. Note that with this approach, the entire sequence of operations is inlined before being sent to the compiler, so it may or may not be a suitable solution to your problem.
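A minimal sketch of the suggested fix follows. It assumes the question's algorithm (not shown) is the standard 2×2 block inversion via the Schur complement. block_inverse and block_inverse_jit are hypothetical names, and the sketch assumes every leading diagonal block is invertible (no pivoting), which holds for symmetric positive definite input. Because M.shape[0] is a plain Python int during tracing, an ordinary if statement selects the branch, only that branch is traced, and the recursion terminates under jax.jit.

```python
import jax
import jax.numpy as jnp


def block_inverse(M):
    """Hypothetical divide-and-conquer inverse via the Schur complement.

    Assumes every leading diagonal block encountered is invertible
    (no pivoting), e.g. M symmetric positive definite.
    """
    n = M.shape[0]  # a static Python int, even under jit
    # Python-level branch on a static shape: only the taken branch is traced,
    # so the recursion stops once the block size reaches 1.
    if n == 1:
        return 1.0 / M
    h = n // 2
    A, B = M[:h, :h], M[:h, h:]
    C, D = M[h:, :h], M[h:, h:]
    A_inv = block_inverse(A)
    S = D - C @ A_inv @ B  # Schur complement of A in M
    S_inv = block_inverse(S)
    top_left = A_inv + A_inv @ B @ S_inv @ C @ A_inv
    top_right = -A_inv @ B @ S_inv
    bottom_left = -S_inv @ C @ A_inv
    return jnp.block([[top_left, top_right], [bottom_left, S_inv]])


block_inverse_jit = jax.jit(block_inverse)

if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    X = jax.random.normal(key, (5, 5))
    M = X @ X.T + 5.0 * jnp.eye(5)  # symmetric positive definite
    M_inv = block_inverse_jit(M)
    print(jnp.allclose(M @ M_inv, jnp.eye(5), atol=1e-4))
```

With this design, each distinct input shape triggers a fresh trace that unrolls the whole recursion into one flat program, which is the inlining cost mentioned above.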
-
I have implemented the divide-and-conquer algorithm for matrix inversion in JAX as follows:
With JIT disabled, this works as expected. However, as written above, we exceed the maximum recursion depth because lax.cond traces true_fun even when pred is False. Is there a workaround for this that still uses lax.cond?
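A minimal sketch of the behaviour described in the question: Python side effects inside the branch functions run only while JAX traces them, so recording them shows that lax.cond traces both branches even though only one is executed. The names traced, true_fun, false_fun, and f are hypothetical, and the sketch uses a traced (value-dependent) predicate. This is consistent with the observation above that the recursive version only works with JIT disabled: a recursive call placed inside either branch is traced regardless of the predicate, so the recursion never reaches a base case.

```python
import jax
import jax.numpy as jnp
from jax import lax

traced = []  # records which branch functions JAX traced


def true_fun(x):
    traced.append("true_fun")  # Python side effect: runs at trace time only
    return x + 1.0


def false_fun(x):
    traced.append("false_fun")  # Python side effect: runs at trace time only
    return x - 1.0


@jax.jit
def f(x):
    # The predicate is a tracer here, but lax.cond traces both branches
    # to build the conditional, whatever value the predicate takes at runtime.
    return lax.cond(x.sum() > 0, true_fun, false_fun, x)


if __name__ == "__main__":
    print(f(jnp.ones(3)))
    print(sorted(set(traced)))
```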