Thanks for the question! Short answer to your first question: inside a JIT-compiled function, Python code runs only once, at compile/trace time, so your `perf_counter` calls will not do what you want within JIT. Reading through How to Think in JAX may help explain why.
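A minimal sketch of this effect, assuming the timing was done with `time.perf_counter` (the function name `timed_sin` is hypothetical). Both timer calls run while JAX traces the function, so the "elapsed time" is frozen into the compiled program as a constant and comes back unchanged on every later call:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def timed_sin(x):
    # Both perf_counter() calls execute once, while JAX traces this function.
    # Their difference becomes a constant baked into the compiled program.
    t0 = time.perf_counter()
    y = jnp.sin(x) * 2.0
    t1 = time.perf_counter()
    return y, t1 - t0


x = jnp.arange(4.0)
_, elapsed_first = timed_sin(x)   # traces, compiles, and runs
_, elapsed_second = timed_sin(x)  # reuses the cached compiled program
# Same value both times: it was measured during tracing, not at runtime.
print(float(elapsed_first) == float(elapsed_second))  # True
```

Calling `timed_sin` with a different shape or dtype would trigger a retrace, which would record a new (but still trace-time) value.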

The only operations that run at runtime are the ones that JAX lowers to XLA. I don't think there is any way to do exactly what you have in mind: XLA does not offer any function that returns a timestamp. Even if it did, it probably would not be very useful, because XLA's compiler will generally reorder operations within a JIT-compiled function as long as the output is unchanged.
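To see which operations actually reach XLA, one option is to inspect the lowered program. A sketch with a hypothetical function `scaled_sin`: `jax.jit(...).lower(...)` produces a `Lowered` object whose `as_text()` returns the StableHLO text of the program, and the Python `perf_counter` call leaves nothing behind in it:

```python
import time

import jax
import jax.numpy as jnp


def scaled_sin(x):
    _ = time.perf_counter()  # Python-only: never becomes part of the XLA program
    return jnp.sin(x) * 2.0


lowered = jax.jit(scaled_sin).lower(jnp.arange(4.0))
hlo_text = lowered.as_text()  # StableHLO text of what will actually run
print(hlo_text)
```

`lowered.compile().as_text()` shows the HLO after XLA's optimization passes, where operations may already be fused or reordered relative to the order in the Python source.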

You can use any Python construct you want outside the JIT bo…
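Timing from outside the jitted function is one way to measure it. A sketch, assuming a hypothetical helper `time_call`: because JAX dispatches work asynchronously, it calls `block_until_ready()` so the clock stops only after the result exists, and it excludes the first call, which includes tracing and compilation:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def scaled_sin(x):
    return jnp.sin(x) * 2.0


def time_call(fn, *args, n_runs=10):
    """Average wall-clock seconds per call, measured outside the jit boundary."""
    fn(*args).block_until_ready()  # warm-up: triggers tracing and compilation
    start = time.perf_counter()
    for _ in range(n_runs):
        # Wait for each result; otherwise we would only time the async dispatch.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / n_runs


x = jnp.arange(1_000_000.0)
print(f"{time_call(scaled_sin, x) * 1e6:.1f} us per call")
```

This measures the whole compiled call; it cannot attribute time to individual operations inside the program, for the reordering reasons above.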

Answer selected by bionicles