Interesting while-loop use case: jacrev OK, JIT not => scan #14732
jecampagne asked this question in Q&A
-
Thanks for sharing! This is similar in spirit to the "early exit scan" or "bounded while loop" proposed here: #13062
-
Hi,
here is what I think is an interesting, simple exercise for working through a while-loop transformation (the idea is adapted from this discussion).
For instance, one can code it with a while-loop like this:
and
gives
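As a hypothetical stand-in for this kind of code (not the original function), consider `f_while`, a name chosen for illustration: it halves a vector until its squared norm is at most 1, using a plain Python `while`, so the number of iterations depends on the input values.

```python
import jax.numpy as jnp


def f_while(x):
    """Hypothetical example: halve x until its squared norm is <= 1.

    This is a plain Python `while`: the trip count depends on the
    concrete *value* of x, which is fine in eager (non-jitted) mode.
    """
    while jnp.sum(x ** 2) > 1.0:
        x = x / 2.0
    return x


x0 = jnp.array([3.0, 4.0])
y0 = f_while(x0)  # squared norms 25 -> 6.25 -> 1.5625 -> 0.390625: three halvings
print(y0)         # values [0.375, 0.5]
```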
Interestingly, the Jacobian can be computed in both forward and backward mode, even though the loop is a plain Python `while` and not `jax.lax.while_loop`.
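A minimal sketch of that observation, assuming the same kind of hypothetical halving function `f_while` as a stand-in for the original code. Outside of `jax.jit`, the autodiff tracers used by `jax.jacfwd` and `jax.jacrev` still carry concrete primal values, so the Python loop condition can be evaluated at every step. Since the trip count is piecewise constant in `x`, the Jacobian at `x0 = [3, 4]` is simply `(1/2)**3` times the identity.

```python
import jax
import jax.numpy as jnp


def f_while(x):
    # Hypothetical example: a plain Python `while`, not jax.lax.while_loop
    while jnp.sum(x ** 2) > 1.0:
        x = x / 2.0
    return x


x0 = jnp.array([3.0, 4.0])

# Forward mode (JVP-based) and reverse mode (VJP-based) both work eagerly
J_fwd = jax.jacfwd(f_while)(x0)
J_rev = jax.jacrev(f_while)(x0)

# Three halvings happen at x0, so both Jacobians equal 0.125 * identity
print(J_fwd)
print(J_rev)
```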
However, the JIT-compiled version crashes due to:
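A sketch of why this happens, again using the hypothetical `f_while` (so the message may differ from the one in the original post): under `jax.jit` the argument becomes an abstract tracer, so the loop condition has no concrete truth value and Python cannot decide whether to run another iteration. JAX reports this as a `jax.errors.ConcretizationTypeError`; recent versions raise the more specific subclass `TracerBoolConversionError`, which the `except` clause below also catches.

```python
import jax
import jax.numpy as jnp


def f_while(x):
    # Hypothetical example: Python `while` whose condition depends on x
    while jnp.sum(x ** 2) > 1.0:
        x = x / 2.0
    return x


x0 = jnp.array([3.0, 4.0])

try:
    jax.jit(f_while)(x0)
    jit_failed = False
except jax.errors.ConcretizationTypeError as err:
    # Under jit, `jnp.sum(x ** 2) > 1.0` is an abstract traced boolean,
    # and it cannot be converted to the Python bool that `while` needs.
    jit_failed = True
    print(type(err).__name__)
```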
So I modified the code as follows, which shows that one has to think differently in JAX to get its full power:
This gives the same results as the non-jitted version.
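One common way to do this kind of rewrite, sketched here with hypothetical names and not necessarily matching the author's `jfunc_bis`, is to run `jax.lax.scan` for a fixed, static number of steps (`MAX_ITERS`, a hypothetical bound) and to mask the update with `jnp.where` once the loop condition becomes false. Because the trip count is now static, the function can be jitted and differentiated in both modes.

```python
import jax
import jax.numpy as jnp

MAX_ITERS = 32  # hypothetical static upper bound on the number of halvings


def f_scan(x):
    def step(x, _):
        # The loop condition is now a traced boolean, not a Python bool
        keep_going = jnp.sum(x ** 2) > 1.0
        # Apply the update only while the condition holds; afterwards x is carried unchanged
        x = jnp.where(keep_going, x / 2.0, x)
        return x, None

    x, _ = jax.lax.scan(step, x, xs=None, length=MAX_ITERS)
    return x


x0 = jnp.array([3.0, 4.0])
y_jit = jax.jit(f_scan)(x0)
J_rev = jax.jit(jax.jacrev(f_scan))(x0)
J_fwd = jax.jit(jax.jacfwd(f_scan))(x0)
print(y_jit)
print(J_rev)
print(J_fwd)
```

The price of this approach is that the scan always executes `MAX_ITERS` steps, even after the condition has become false, and if the true loop needed more than `MAX_ITERS` iterations the result would silently differ from the Python version, so the bound has to come from knowledge of the problem.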
Takeaway:
My `jfunc_bis` may not be the best approach, so if someone has a better solution (not for this particular example, but as a general pattern for other use cases), I would be glad to hear it.
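One possible way to generalize the pattern, close in spirit to the "bounded while loop" idea mentioned in the reply (#13062), is a small helper with the same `cond_fun`/`body_fun` interface as `jax.lax.while_loop` plus a static `max_steps`. The sketch below is a hypothetical helper, `bounded_while_loop`, not a JAX API. It uses `jax.lax.cond` inside `jax.lax.scan`, so the body is skipped once the condition is false, and it agrees with a true while loop only if that loop terminates within `max_steps` iterations.

```python
import jax
import jax.numpy as jnp


def bounded_while_loop(cond_fun, body_fun, init_val, max_steps):
    """Hypothetical helper: a while loop with a static upper bound on iterations.

    Matches jax.lax.while_loop(cond_fun, body_fun, init_val) whenever the loop
    finishes within `max_steps` iterations, but is built on jax.lax.scan so it
    stays jit-compatible and reverse-mode differentiable.
    """
    def step(val, _):
        # Run the body only while the condition holds; otherwise pass val through
        val = jax.lax.cond(cond_fun(val), body_fun, lambda v: v, val)
        return val, None

    final_val, _ = jax.lax.scan(step, init_val, xs=None, length=max_steps)
    return final_val


# Usage on a hypothetical halving loop
def f_generic(x):
    return bounded_while_loop(
        lambda v: jnp.sum(v ** 2) > 1.0,  # cond_fun
        lambda v: v / 2.0,                 # body_fun
        x,
        max_steps=32,
    )


x0 = jnp.array([3.0, 4.0])
y_generic = jax.jit(f_generic)(x0)
J_generic = jax.jit(jax.jacrev(f_generic))(x0)
print(y_generic)
print(J_generic)
```

Compared with masking via `jnp.where`, `lax.cond` avoids evaluating the body after termination; when the predicate is batched under `jax.vmap`, however, JAX turns the `cond` into a select and both branches are computed anyway.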
PS: In fact, at the start of the exercise I was trying to write the loop with `jax.lax.while_loop` to make `jacrev` crash; with that version, `jacrev` does crash, as expected.
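A minimal sketch of that experiment, again with a hypothetical halving loop rather than the original code: written with `jax.lax.while_loop`, the function can be jitted and differentiated in forward mode with `jax.jacfwd`, but `jax.jacrev` fails, because reverse-mode differentiation of `lax.while_loop` with a dynamic trip count is not supported (JAX raises a `ValueError`).

```python
import jax
import jax.numpy as jnp


def f_lax_while(x):
    # Hypothetical halving loop written with jax.lax.while_loop
    return jax.lax.while_loop(
        lambda v: jnp.sum(v ** 2) > 1.0,  # cond_fun
        lambda v: v / 2.0,                 # body_fun
        x,
    )


x0 = jnp.array([3.0, 4.0])

# jit and forward mode work: while_loop has a JVP rule
y_lax = jax.jit(f_lax_while)(x0)
J_fwd = jax.jacfwd(f_lax_while)(x0)
print(y_lax)
print(J_fwd)

# Reverse mode is not supported for a while_loop with a dynamic trip count
try:
    jax.jacrev(f_lax_while)(x0)
    rev_failed = False
except ValueError as err:
    rev_failed = True
    print(err)
```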