Hi, the issue is that time.time() is a Python operation, not a JAX operation, so it is executed only at trace time. Its output becomes a compile-time constant that has the same value each time you run the function (see How to think in JAX for more background on JAX's execution model).
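To make the trace-time behavior concrete, here is a minimal sketch, assuming the function is wrapped in jax.jit and is called twice with inputs of the same shape and dtype. The names `start`, `python_calls`, and `seconds_since_start` are made up for illustration and are not from the original question.

```python
import time

import jax
import jax.numpy as jnp

start = time.time()
python_calls = []  # records every time the Python body actually runs


@jax.jit
def seconds_since_start(x):
    # Plain Python: evaluated once, while JAX traces the function.
    elapsed = time.time() - start
    python_calls.append(elapsed)
    # `elapsed` is a Python float here, so it is embedded as a constant.
    return x + elapsed


first = seconds_since_start(jnp.zeros(()))
time.sleep(0.2)
# Same shape and dtype as before, so the cached compilation is reused.
second = seconds_since_start(jnp.zeros(()))

values_match = bool(first == second)
print(len(python_calls), values_match)
```

Even though 0.2 seconds pass between the two calls, the Python body runs only once and both calls return the same value, because the timestamp was captured during tracing.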

You could probably get what you have in mind by using jax.debug.callback with a function that calls time.time(), something like this:

jax.debug.callback(lambda: print(f"the time is {time.time()}"))

But keep in mind that this kind of host callback can incur a pretty high performance penalty (especially on accelerators) because it requires a sync between the device and the host.
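As a rough sketch of how the callback approach behaves under jax.jit, the example below records a host timestamp on every call instead of printing it. The names `timestamps`, `record_time`, and `step` are hypothetical, and the call to jax.effects_barrier() is there on the assumption that you want to wait for pending callbacks before reading the list.

```python
import time

import jax
import jax.numpy as jnp

timestamps = []  # filled in on the host, once per executed call


def record_time():
    # Runs on the host each time the compiled function executes.
    timestamps.append(time.time())


@jax.jit
def step(x):
    # Staged into the compiled program rather than run at trace time.
    jax.debug.callback(record_time)
    return x * 2.0


for _ in range(3):
    step(jnp.ones(3)).block_until_ready()

# Wait for any outstanding side-effecting callbacks to finish.
jax.effects_barrier()
print(len(timestamps))
```

Unlike a bare time.time() call inside the jitted function, the callback fires on each execution, at the cost of the device-to-host round trip mentioned above.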

If you're…

Answer selected by Sun-Xiaohui