To include runtime in the objective, I want to jit a function which applies a set of functions to the same input and returns all the outputs and the elapsed runtime for each function. Like a jit'd, batch'd jax.value_and_runtime.

Here's some fiddling... does perf_counter only run once at compile time? If so, how could we benchmark function runtimes inside jax.jit?

from time import perf_counter
from jax import jit

@jit
def jax_perfcounter_experiment(x):
    start = perf_counter()
    y = x ** 3 + 2 * x ** 2 + 3 * x + 4
    elapsed = perf_counter() - start
    return y, elapsed

y, t = jax_perfcounter_experiment(1000000000)
print(y, t)
# -375890428 0.0020622709

w, t2 = jax_perfcounter_experiment(1)
print(w, t2)
# 10 0.0020622709

t2 == t
# DeviceArray(True, dtype=bool, weak_type=True)

for i in range(10000):
    z, t3 = jax_perfcounter_experiment(i)
    assert t3 == t
print("all timings equal")
# all timings equal
Replies: 1 comment 1 reply
Thanks for the question! Short answer to your first question: within JIT, Python functions will only run once at compile/trace time, so your perf_counter will not do what you want it to do within JIT. The value it produces during tracing is baked into the compiled function as a constant, which is why every call in your example returns the same elapsed time. You might read through How to Think in JAX to help understand why this is.

The only operations that run at runtime are the ones that JAX lowers to XLA. I don't think there is any way to do exactly what you have in mind: XLA does not offer any function that returns a timestamp. Even if it did, it probably would not be that useful, because XLA's compiler will generally re-order operations within JIT so long as the output is unchanged.

You can use any Python construct you want outside the JIT boundary, so perhaps you can accomplish your goal that way?

Regarding profiling: the current best way to profile JIT-compiled code in JAX is to do so via TensorBoard; there is some information on this in the documentation at https://jax.readthedocs.io/en/latest/profiling.html
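As a rough illustration of timing outside the JIT boundary, here is a minimal sketch. The names `values_and_runtimes`, `cubic`, `sine_sum`, and the `repeats` parameter are hypothetical, made up for this example and not part of JAX. It assumes that a warm-up call is enough to exclude tracing and compilation from the measurement. It also relies on `block_until_ready()`, because JAX dispatches work asynchronously and the timer would otherwise stop before the computation finishes.

```python
from time import perf_counter

import jax
import jax.numpy as jnp


def values_and_runtimes(fns, x, repeats=10):
    """Hypothetical helper: apply each function to x and time it outside jit."""
    results = []
    for fn in fns:
        jitted = jax.jit(fn)
        # Warm-up call: triggers tracing and compilation so they are not timed.
        jitted(x).block_until_ready()
        start = perf_counter()
        for _ in range(repeats):
            # Dispatch is asynchronous, so wait for the result before continuing.
            out = jitted(x).block_until_ready()
        elapsed = (perf_counter() - start) / repeats  # mean seconds per call
        results.append((out, elapsed))
    return results


# Example functions (assumptions for illustration only).
def cubic(x):
    return x ** 3 + 2 * x ** 2 + 3 * x + 4


def sine_sum(x):
    return jnp.sin(x).sum()


x = jnp.arange(1000.0)
results = values_and_runtimes((cubic, sine_sum), x)
for fn, (out, seconds) in zip((cubic, sine_sum), results):
    print(fn.__name__, out.shape, f"{seconds:.6f} s")
```

The elapsed times here are ordinary Python floats measured on the host, so they cannot be differentiated or traced. If runtime is meant to enter an objective, it would have to be combined with the outputs outside of any jitted function.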