How to add trace annotations within a jitted function? #34148
I use the JAX profiler with `TraceAnnotation` to see in the trace where things happen in my code. However, `TraceAnnotation` works only at tracing time, and there is no runtime equivalent to use within jitted functions.

I can imagine working around this myself by using (and abusing) `pure_callback` or `debug.callback` to time when a certain array is computed. However, I expect integrating that into the trace viewer to be too much work for me on my own.

Currently I split my function into a few individually compiled chunks and trace-annotate each of them. This is quite brutal and has side effects, since XLA would probably fuse and reorder operations across my somewhat arbitrary boundaries.

Ideally I would like something that starts and ends a trace event with semantics similar to `pure_callback` or `debug.callback`, i.e., the trace start and end correspond to when one or more arrays are actually computed (with all the side effects that implies, e.g., possibly forcing them to be computed when they would otherwise be fused, or dropping the trace event in case of dead code elimination).

Hypothetical API:

    @jit
    def f(x):
        t, x = start_trace_event(x)
        y = g(x)
        y = t.end_event(y)
        return y

Does anyone have ideas? How have you solved this problem?
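For concreteness, the chunk-splitting workaround described above might look roughly like the sketch below. The stage functions `stage_one` and `stage_two` and the annotation names are hypothetical placeholders, not code from my project. The `block_until_ready()` calls are there so that each annotation spans the actual device work rather than just the asynchronous dispatch.

```python
import jax
import jax.numpy as jnp
import jax.profiler


# Hypothetical stages: each one is compiled separately, so XLA cannot
# fuse or reorder operations across the boundary between them.
@jax.jit
def stage_one(x):
    return jnp.sin(x) * 2.0


@jax.jit
def stage_two(y):
    return jnp.sum(y ** 2)


def f(x):
    # TraceAnnotation is a host-side context manager, so it has to wrap
    # the calls from outside the jitted code.
    with jax.profiler.TraceAnnotation("stage_one"):
        y = stage_one(x).block_until_ready()
    with jax.profiler.TraceAnnotation("stage_two"):
        z = stage_two(y).block_until_ready()
    return z


out = f(jnp.arange(4.0))
```

The annotations only show up while a profiler trace is being collected; otherwise the function simply runs as usual. The cost is exactly the one mentioned above: the boundaries become hard compilation barriers.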
Replies: 1 comment 3 replies
Ok, maybe I just found the answer: `jax.named_call`.
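A minimal sketch of how this could be used inside a jitted function, assuming a placeholder computation `g` and made-up names `"g_block"` and `"post_process"`. `jax.named_scope` is a related context manager that may be handy when the region to label is not a single function call.

```python
import jax
import jax.numpy as jnp


def g(x):
    # Placeholder for the computation to be labelled.
    return jnp.sin(x) @ x.T


@jax.jit
def f(x):
    # Wrap g so that the operations it produces carry the name "g_block".
    y = jax.named_call(g, name="g_block")(x)
    # named_scope labels every operation created inside the with-block.
    with jax.named_scope("post_process"):
        z = jnp.tanh(y).sum()
    return z


x = jnp.ones((4, 4))
result = f(x)
```

As far as I understand, both attach names to the operations handed to the compiler rather than opening and closing explicit runtime events, so this is not quite the callback-like semantics of the hypothetical API: XLA can still fuse or reorder the labelled operations, and how the names appear in the trace viewer may depend on the backend.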