Detect when a new JAX frame starts #6009
dionhaefner asked this question in Q&A
For mpi4jax, our API would get a lot cleaner if we could keep track of local state within each `jit` block and use it to inject tokens into our custom calls. For this to work, I need to detect when a new frame starts so that the local state can be cleared. I can make this work with something like this:
Then we can use it like this:
Is there any way to detect that we are now in a different frame from the one that `current_token` originated from, without relying on catching `UnexpectedTracerError`?