How do I profile layer-by-layer JAX performance? #11180
Unfortunately, JAX can in general fuse operations across layer boundaries, so exact per-layer performance is hard to measure. However, you can use …
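One rough workaround for the fusion problem is to compile each layer as its own jitted function and time it separately. This is offered as an assumption, not necessarily what the reply above was going to recommend. The sketch uses a hypothetical MLP; `dense_relu`, `init_params` and `time_layers` are illustrative names, not JAX APIs. Because each layer is compiled in isolation, XLA cannot fuse across layer boundaries, so the per-layer numbers need not add up to the runtime of the fully jitted model.

```python
import time

import jax
import jax.numpy as jnp


def dense_relu(params, x):
    # One hypothetical "layer": affine transform followed by ReLU.
    w, b = params
    return jax.nn.relu(x @ w + b)


def init_params(key, sizes):
    # Hypothetical MLP: one (w, b) pair per consecutive pair of sizes.
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


def time_layers(params, x, repeats=10):
    """Time each layer as its own jitted function.

    Each layer is compiled in isolation, so XLA cannot fuse across layer
    boundaries; the numbers approximate, but need not sum to, the cost of
    the whole model compiled at once.
    """
    layer_fn = jax.jit(dense_relu)
    timings = []
    for i, p in enumerate(params):
        layer_fn(p, x).block_until_ready()  # warm-up: exclude compile time
        start = time.perf_counter()
        for _ in range(repeats):
            y = layer_fn(p, x)
        y.block_until_ready()  # dispatch is async; wait for the device
        timings.append((i, (time.perf_counter() - start) / repeats))
        x = y  # the next layer consumes this layer's output
    return timings


params = init_params(jax.random.PRNGKey(0), [512, 1024, 1024, 10])
x = jnp.ones((64, 512))
for i, seconds in time_layers(params, x):
    print(f"layer {i}: {seconds * 1e3:.3f} ms")
```

Summing these per-layer times and comparing against a single `jax.jit` of the whole forward pass gives a rough sense of how much cross-layer fusion is saving.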
The main tool we tend to use is the tracing/profiling infrastructure described in the JAX profiling docs (https://jax.readthedocs.io/en/latest/profiling.html). As YouJiacheng says, we don't compile layers in isolation, so it's not always easy to point to a single region of a trace that corresponds exactly to one layer. One thing that will certainly work would be to …
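A minimal sketch of that profiling workflow, assuming a toy two-layer MLP (`mlp` and its parameters are hypothetical): wrapping each layer in `jax.named_scope` attaches a `layer_<i>` label to the operations it produces, and `jax.profiler.trace` writes a trace that can be opened in TensorBoard or Perfetto. Since a fused kernel can span several scopes, the labels help locate a layer in the trace but do not guarantee a one-to-one mapping.

```python
import tempfile

import jax
import jax.numpy as jnp


def mlp(params, x):
    # Each layer's ops carry the name "layer_<i>" in the compiled program,
    # which is what shows up in the profiler trace.
    for i, (w, b) in enumerate(params):
        with jax.named_scope(f"layer_{i}"):
            x = jax.nn.relu(x @ w + b)
    return x


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = [
    (jax.random.normal(k1, (256, 256)), jnp.zeros(256)),
    (jax.random.normal(k2, (256, 10)), jnp.zeros(10)),
]
x = jnp.ones((32, 256))

forward = jax.jit(mlp)
forward(params, x).block_until_ready()  # compile before tracing starts

log_dir = tempfile.mkdtemp()  # hypothetical output location
with jax.profiler.trace(log_dir):
    out = forward(params, x)
    out.block_until_ready()  # make sure the work finishes inside the trace

print("trace written to", log_dir)
```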
I am looking to break down JAX performance layer by layer. Is there a way to place a breakpoint during forward propagation?
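On stopping during the forward pass: `jax.debug.breakpoint()` pauses inside jitted code for interactive inspection, and `jax.debug.callback` runs a host-side Python function with concrete values. The sketch below uses the callback as a non-interactive "breakpoint" after each layer, with a hypothetical two-layer model and a hypothetical `record` helper. These hooks add host round-trips and ordering constraints, so they suit inspecting intermediate values rather than timing layers.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

# Host-side log filled in by the callbacks while the compiled code runs.
seen = []


def record(name, activation):
    # Called on the host with concrete values, even under jit, so it acts
    # like a non-interactive "breakpoint" between layers.
    act = np.asarray(activation)
    seen.append((name, act.shape, float(np.abs(act).mean())))


@jax.jit
def forward(params, x):
    for i, (w, b) in enumerate(params):
        x = jax.nn.relu(x @ w + b)
        # ordered=True keeps callbacks in program order. Positional args are
        # traced as arrays, so the string label is bound with partial.
        jax.debug.callback(functools.partial(record, f"layer_{i}"), x, ordered=True)
        # jax.debug.breakpoint() could go here instead to pause interactively.
    return x


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = [
    (jax.random.normal(k1, (16, 32)), jnp.zeros(32)),
    (jax.random.normal(k2, (32, 4)), jnp.zeros(4)),
]
out = forward(params, jnp.ones((8, 16)))
out.block_until_ready()
jax.effects_barrier()  # wait for any pending callbacks to finish

for name, shape, mean_abs in seen:
    print(name, shape, f"mean |activation| = {mean_abs:.4f}")
```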