Measuring runtime of functions inside a jitted function #11800
-
I'm trying to build a reinforcement learning algorithm, and I have every train step (rollout + parameter updates) jitted under one function. The rollout and the parameter updates are two separate functions, and I want to time them to measure how long each one takes every train step. I did this originally … but it appears that … I was wondering if there is some way to measure the runtime of code within a jitted function. It would help me determine exactly what is and isn't a bottleneck (RL train loops being one such example). Thanks!
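One pitfall worth keeping in mind here (not necessarily what the original attempt did, since that code was not preserved): Python-level timers such as `time.perf_counter()` placed inside a jitted function run only while JAX traces it, not on every call. The sketch below uses a hypothetical `train_step` with toy stand-ins for the rollout and the update. It calls the function three times, yet the timer fires only once, because it measured tracing rather than execution.

```python
import time

import jax
import jax.numpy as jnp

# Collects one entry each time the Python body of train_step actually runs.
trace_log = []


@jax.jit
def train_step(x):
    # This Python code executes during tracing only, not on each compiled call.
    start = time.perf_counter()
    y = jnp.sin(x) * 2.0  # hypothetical stand-in for a rollout
    trace_log.append(time.perf_counter() - start)
    return y.sum()  # hypothetical stand-in for a parameter update


x = jnp.ones(1000)
for _ in range(3):
    train_step(x).block_until_ready()

# Three calls with the same input shape and dtype, but only one trace.
print(len(trace_log))  # 1
```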
-
Have you looked over the FAQ entry on benchmarking? You may need a …
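The benchmarking advice in the JAX FAQ comes down to two points: call the function once before timing so compilation is not measured, and call `block_until_ready()` on the result, because JAX dispatches work asynchronously. A minimal sketch applying that to the two pieces of a train step, assuming they can be jitted and called on their own; `rollout`, `update`, `benchmark`, and the shapes are hypothetical stand-ins, not the poster's code.

```python
import time

import jax
import jax.numpy as jnp


# Hypothetical stand-ins for the rollout and parameter-update functions.
@jax.jit
def rollout(params, obs):
    return jnp.tanh(obs @ params)


@jax.jit
def update(params, traj):
    grad = jax.grad(lambda p: jnp.mean((traj @ p) ** 2))(params)
    return params - 1e-3 * grad


def benchmark(fn, *args, n_iters=100):
    """Average wall-clock seconds per call, excluding compilation."""
    fn(*args).block_until_ready()  # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(n_iters):
        # Block on each result; otherwise we would only time the async dispatch.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / n_iters


params = jax.random.normal(jax.random.PRNGKey(0), (64, 64))
obs = jnp.ones((256, 64))
traj = rollout(params, obs)

print(f"rollout: {benchmark(rollout, params, obs) * 1e6:.1f} us/call")
print(f"update:  {benchmark(update, params, traj) * 1e6:.1f} us/call")
```

Numbers obtained this way describe each function compiled as a standalone program. Once both are inlined into one jitted train step, the compiler is free to optimize across the boundary between them, so the two figures need not add up to the train step's time.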
-
I think your question is ill-posed: within the computation of an outer jit-compiled function, internal jit-compiled functions no longer exist as units that can be benchmarked. As a simple example of this, let's consider these three functions:

```python
from jax import jit
import jax.numpy as jnp

@jit
def func1(x):
    return x + 1

@jit
def func2(x):
    return x - 1

@jit
def func2_of_func1(x):
    return func2(func1(x))
```

We can look at the compiled HLO for the first two functions like this:

```python
x = jnp.arange(5)
print(func1.lower(x).compile().as_text())
# HloModule jit_func1.14, entry_computation_layout={(s32[5]{0})->s32[5]{0}}
# %fused_computation (param_0: s32[5]) -> s32[5] {
#   %param_0 = s32[5]{0} parameter(0)
#   %constant.0 = s32[] constant(1)
#   %broadcast.0 = s32[5]{0} broadcast(s32[] %constant.0), dimensions={}
#   ROOT %add.0 = s32[5]{0} add(s32[5]{0} %param_0, s32[5]{0} %broadcast.0), metadata={op_name="jit(func1)/jit(main)/add" source_file="<ipython-input-7-123cc4c05ee1>" source_line=6}
# }
# ENTRY %main.5 (Arg_0.1: s32[5]) -> s32[5] {
#   %Arg_0.1 = s32[5]{0} parameter(0)
#   ROOT %fusion = s32[5]{0} fusion(s32[5]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(func1)/jit(main)/add" source_file="<ipython-input-7-123cc4c05ee1>" source_line=6}
# }
print(func2.lower(x).compile().as_text())
# HloModule jit_func2.15, entry_computation_layout={(s32[5]{0})->s32[5]{0}}
# %fused_computation (param_0: s32[5]) -> s32[5] {
#   %param_0 = s32[5]{0} parameter(0)
#   %constant.1 = s32[] constant(-1)
#   %broadcast.1 = s32[5]{0} broadcast(s32[] %constant.1), dimensions={}
#   ROOT %add.1 = s32[5]{0} add(s32[5]{0} %param_0, s32[5]{0} %broadcast.1), metadata={op_name="jit(func2)/jit(main)/sub" source_file="<ipython-input-7-123cc4c05ee1>" source_line=10}
# }
# ENTRY %main.5 (Arg_0.1: s32[5]) -> s32[5] {
#   %Arg_0.1 = s32[5]{0} parameter(0)
#   ROOT %fusion = s32[5]{0} fusion(s32[5]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(func2)/jit(main)/sub" source_file="<ipython-input-7-123cc4c05ee1>" source_line=10}
# }
```

It takes a bit to get used to reading this, but note that the first function has a … Now let's construct another function that calls these two functions, and look at its HLO:

```python
print(func2_of_func1.lower(x).compile().as_text())
# HloModule jit_func2_of_func1.16, entry_computation_layout={(s32[5]{0})->s32[5]{0}}
# ENTRY %main.14 (Arg_0.1: s32[5]) -> s32[5] {
#   %Arg_0.1 = s32[5]{0} parameter(0), metadata={op_name="jit(func2_of_func1)/jit(main)/jit(func2)/sub" source_file="<ipython-input-7-123cc4c05ee1>" source_line=10}
#   ROOT %copy = s32[5]{0} copy(s32[5]{0} %Arg_0.1)
# }
```

Notice there is no longer any … In your case, the functions are more complicated, but I think the answer is the same: in general, JIT is going to rearrange, fuse, and/or elide parts of the computation that go into it, and a question about the runtime of the units that make up the function has no meaningful answer.
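A rough sketch of the practical consequence, reusing `func1`, `func2`, and `func2_of_func1` from the example (redefined so the block runs on its own; `time_call` is a hypothetical helper). It times the two small functions as separate programs and compares their sum with the single combined program. Absolute timings depend on the backend and XLA version, so no numbers are claimed here; the point is only that, since the combined HLO shown above reduced to a plain `copy`, there is no reason to expect the combined time to equal the sum of the parts.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def func1(x):
    return x + 1


@jax.jit
def func2(x):
    return x - 1


@jax.jit
def func2_of_func1(x):
    return func2(func1(x))


def time_call(fn, x, n_iters=200):
    """Average seconds per call, after a warm-up call that compiles fn."""
    fn(x).block_until_ready()
    start = time.perf_counter()
    for _ in range(n_iters):
        fn(x).block_until_ready()
    return (time.perf_counter() - start) / n_iters


x = jnp.arange(100_000)
t1 = time_call(func1, x)
t2 = time_call(func2, x)
t_combined = time_call(func2_of_func1, x)

print(f"func1 + func2, timed separately: {(t1 + t2) * 1e6:.1f} us")
print(f"func2_of_func1, one program:     {t_combined * 1e6:.1f} us")
```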