jax.debug.print() could not print "time.time()" correctly #19102
Hi, I want to measure the execution time of part of a jitted function using `jax.debug.print()` and `time.time()`, but I always get exactly the same output when I run it.
Hi, the issue is that `time.time()` is a Python operation, not a JAX operation, so it is only executed at trace time. Its output becomes a compile-time constant that will be the same value each time you run the function (see How to think in JAX for more background on JAX's execution model).

You could probably do something like what you have in mind by using `jax.debug.callback` to call a function that calls `time.time()`; something like this:

```python
jax.debug.callback(lambda: print(f"the time is {time.time()}"))
```

But keep in mind that this kind of host callback can incur a pretty high performance penalty (especially on accelerators) because it requires a sync between the device and the host.

If you're interested in runtime profiling of JAX code, you might take a look at the resources here: https://jax.readthedocs.io/en/latest/profiling.html

Alternatively, for micro-benchmarks, make sure that the code you benchmark is a single unit rather than attempting to benchmark sections within a jit-compiled function.

Hope that helps!
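To make the trace-time vs. run-time distinction concrete, here is a minimal sketch (not from the original thread) that records a timestamp inside a jitted function twice: once with plain Python and once through `jax.debug.callback`. The names `f`, `trace_calls`, and `host_calls` are illustrative only. `jax.effects_barrier()` is used to wait for any pending callbacks before the lists are inspected.

```python
import time

import jax
import jax.numpy as jnp

trace_calls = []  # appended by plain Python code inside the jitted function
host_calls = []   # appended by a host callback inside the jitted function


@jax.jit
def f(x):
    # Plain Python: runs only while JAX traces f, not when the compiled
    # program executes.
    trace_calls.append(time.time())
    # jax.debug.callback stages a host call into the compiled program,
    # so the lambda runs on every execution.
    jax.debug.callback(lambda: host_calls.append(time.time()))
    return x * 2


x = jnp.arange(3.0)
for _ in range(3):
    f(x).block_until_ready()
jax.effects_barrier()  # wait for any pending callbacks to finish

print(len(trace_calls), len(host_calls))
```

With the same input shape and dtype on every call, `f` should be traced only once, so `trace_calls` ends up with a single (stale) timestamp while `host_calls` gets a fresh one per call.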
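The "single unit" suggestion for micro-benchmarks can be sketched as follows, assuming the section you care about can be split out into its own jitted function. `part_a`, `part_b`, and `benchmark` are hypothetical names, and the workloads are placeholders. Because JAX dispatches work asynchronously, the sketch waits on each result with `jax.block_until_ready` and excludes the first (compiling) call from the timing.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def part_a(x):
    # Hypothetical stand-in for the section you want to time.
    return jnp.sin(x) @ jnp.cos(x).T


@jax.jit
def part_b(y):
    # Hypothetical stand-in for the rest of the original function.
    return jnp.tanh(y).sum()


def benchmark(fn, *args, n_iter=10):
    """Return the mean wall-clock seconds per call of a jitted fn."""
    # Warm-up call: triggers tracing and compilation, excluded from timing.
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(n_iter):
        # JAX dispatches asynchronously; block so we time actual execution.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / n_iter


x = jnp.ones((256, 256))
y = part_a(x)
avg_a = benchmark(part_a, x)
avg_b = benchmark(part_b, y)
print(f"part_a: {avg_a * 1e3:.3f} ms/call, part_b: {avg_b * 1e3:.3f} ms/call")
```

One caveat of this approach: splitting a computation into separately jitted pieces prevents XLA from optimizing across the boundary, so the per-part timings may not add up exactly to the time of the original single jitted function.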