Hi all, I'm new to profiling, especially within the JAX framework.
TL;DR:
What is the best practice for profiling a function inside a scan and identifying its computational bottlenecks?
My code contains one big loop over time, implemented as a jax.lax.scan(f, ...). Now I'd like to time the computation performed at each timestep to identify the bottlenecks. For this I'm looking at jax.profiler.trace with Perfetto (https://docs.jax.dev/en/latest/profiling.html).
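Independently of the profiler, one coarse way to time the per-timestep work is to jit each sub-function on a single timestep's input and measure wall-clock time, excluding the first (compiling) call. This is a minimal sketch; `a`, `b`, `bench` and the `(64, 64)` input are hypothetical stand-ins. Timing the functions in isolation ignores any fusion or scheduling XLA applies inside the compiled scan loop, so the numbers are only indicative.

```python
import time

import jax
import jax.numpy as jnp


def a(x):  # hypothetical stand-in for the real a()
    return jnp.sin(x).sum()


def b(x):  # hypothetical stand-in for the real b()
    return (x @ x).sum()


def bench(fn, x, n_repeats=20):
    """Mean wall-clock seconds per call of jax.jit(fn)(x), compilation excluded."""
    jitted = jax.jit(fn)
    jitted(x).block_until_ready()  # first call compiles; keep it out of the timing
    start = time.perf_counter()
    for _ in range(n_repeats):
        # JAX dispatches asynchronously; wait so the timing covers execution, not just dispatch
        jitted(x).block_until_ready()
    return (time.perf_counter() - start) / n_repeats


x_step = jnp.ones((64, 64))  # hypothetical input for a single timestep
timings = {"a": bench(a, x_step), "b": bench(b, x_step)}
print(timings)
```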
When I put the profiler around the scan, I do not see any information about the subfunctions executed within the scanned function (only the element "$loops.py:112 scan"). When I move the profiler inside the scanned function, I do see the separate functions, but I am not sure what exactly is being profiled. I suppose this is compilation time? To profile execution time instead, I added a .block_until_ready(), but this is not very informative either, because the Perfetto trace just shows one big block "$api.py:2932 block_until_ready".
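The suspicion about the profiler inside the scanned function can be checked directly: jax.lax.scan traces the Python body into a jaxpr once (or a few times), and the compiled loop then runs without re-entering Python. Anything Python-level inside the body, a profiler context included, therefore only observes tracing, not the per-step execution. A minimal sketch with a hypothetical `body` and a plain Python counter:

```python
import jax
import jax.numpy as jnp

n_python_calls = 0  # counts how often the Python body of the scan actually runs


def body(carry, x):
    global n_python_calls
    n_python_calls += 1  # plain Python side effect: only happens while JAX traces the body
    return carry + x, jnp.sin(x)


carry, ys = jax.lax.scan(body, jnp.float32(0.0), jnp.arange(1000.0))
print(n_python_calls)  # a small number, not 1000: the body was traced, not called per step
```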
Is there a way to profile the different functions inside a JAX scan function to identify the bottleneck?
So the goal is to identify whether the execution of a() or b() is the bottleneck in the (conceptual) example below:
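import jax
import jax.numpy as jnp

def a(x): return jnp.sin(x).sum()  # hypothetical stand-in for the real a()
def b(x): return (x @ x).sum()     # hypothetical stand-in for the real b()

def foo(carry, xs):
    res_a = a(xs)
    res_b = b(xs)
    return carry, (res_a, res_b)

carry, xs = 0.0, jnp.ones((1000, 64, 64))  # hypothetical inputs
with jax.profiler.trace("traces/jax-trace", create_perfetto_link=True):
    carry, out = jax.lax.scan(foo, carry, xs)  # scan over particles
    jax.block_until_ready(out)  # wait inside the trace so execution is captured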
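One possible way to make a() and b() distinguishable inside the single scan block is to wrap each call in jax.named_scope, which attaches a name to the ops traced inside it (it ends up in the op_name metadata of the compiled HLO). A sketch under assumptions: on GPU/TPU, the device-side events in the trace may then be attributable to "step_a" or "step_b", while on CPU the per-op detail may be limited. `a`, `b`, `run`, the scope names and the input shapes are hypothetical. The warm-up call keeps compilation out of the captured trace.

```python
import tempfile

import jax
import jax.numpy as jnp


def a(x):  # hypothetical stand-in for the real a()
    return jnp.sin(x).sum()


def b(x):  # hypothetical stand-in for the real b()
    return (x @ x).sum()


def foo(carry, xs):
    # Label the ops emitted by each sub-function so they can be told apart in the profile
    with jax.named_scope("step_a"):
        res_a = a(xs)
    with jax.named_scope("step_b"):
        res_b = b(xs)
    return carry, (res_a, res_b)


@jax.jit
def run(carry, xs):
    return jax.lax.scan(foo, carry, xs)


carry0, xs = jnp.float32(0.0), jnp.ones((100, 64, 64))  # hypothetical inputs
jax.block_until_ready(run(carry0, xs))  # warm-up: compile before tracing

log_dir = tempfile.mkdtemp()  # hypothetical output directory for the profile
with jax.profiler.trace(log_dir):
    carry, (out_a, out_b) = run(carry0, xs)
    jax.block_until_ready((out_a, out_b))  # wait inside the trace so execution is captured
```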
Thanks in advance
Cedric