Those docs could definitely be improved, but hopefully that will get you on the right track!
-
I think samuela needs a lightweight method to log inexact performance stats. The profiler will incur more overhead than a host callback such as `id_tap`:

```python
from jax.experimental.host_callback import id_tap

def do_sth_after_result_ready(arg, transforms):
    # id_tap calls the tap function on the host as tap_func(arg, transforms).
    print("time")

def train_step(state, data):
    new_state = state  # placeholder: do training here to produce new_state
    new_state = id_tap(do_sth_after_result_ready, None, result=new_state)
    return new_state
```

https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.host_callback.id_tap.html
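In recent JAX releases `jax.experimental.host_callback` has been deprecated. A comparable minimal sketch with `jax.debug.callback` might look like the following. Assumptions: `record_time`, `timestamps`, and the toy least-squares `train_step` are hypothetical stand-ins for a real training step, and the recorded times are when the host handles each callback, so they are only approximate. Passing `loss` to the callback means it can only run once `loss` has been computed, and `jax.effects_barrier()` waits for pending callbacks before the list is read.

```python
import functools
import time

import jax
import jax.numpy as jnp

# Host-side list of (label, time.perf_counter()) pairs filled in by the callback.
timestamps = []

def record_time(label, value):
    # Hypothetical helper. Runs on the host after `value` is computed on device;
    # `value` itself is unused, it only creates the data dependency.
    del value
    timestamps.append((label, time.perf_counter()))

@jax.jit
def train_step(params, batch):
    # Hypothetical toy update: one SGD step on a least-squares loss.
    def loss_fn(p):
        return jnp.mean((batch @ p) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Fires on the host once `loss` is available; no block_until_ready() needed.
    jax.debug.callback(functools.partial(record_time, "step_done"), loss)
    return params - 0.1 * grads, loss

params = jnp.ones(4)
batch = jnp.ones((8, 4))
for _ in range(3):
    params, loss = train_step(params, batch)

# Wait for pending callbacks before reading the host-side list.
jax.effects_barrier()
# Differences between consecutive timestamps give approximate per-step times.
step_times = [b - a for (_, a), (_, b) in zip(timestamps, timestamps[1:])]
```

Because the callback is keyed on `loss`, the timestamps track device progress rather than Python dispatch, which returns almost immediately under asynchronous dispatch.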
-
How does one observe the runtime of various operations, e.g. the time it takes to process a batch, within a training loop? I don't want to decimate my performance with `block_until_ready()`, but I would like to get some timing stats. I just want to be able to peek into the execution of asynchronous dispatch a bit.

In a similar vein, is there a (non-blocking) method on `jnp.array`s like `is_ready()`? Having this would be useful to work around issues with external logging libraries like wandb/wandb#3690.
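One common workaround, sketched here under assumptions (a toy least-squares `train_step` and a hypothetical `timed_loop` helper, neither taken from the question), is to block on the *previous* step's result instead of the current one. The next step has already been dispatched, so the device stays busy while the host waits, and the spacing between completions gives an approximate per-step time. The same lag can be applied to logging: reading the previous step's loss for an external logger only waits on work that has already been queued ahead of the current step.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, batch):
    # Hypothetical toy update (one SGD step on a least-squares loss),
    # standing in for a real training step.
    loss, grads = jax.value_and_grad(lambda p: jnp.mean((batch @ p) ** 2))(params)
    return params - 0.1 * grads, loss

def timed_loop(params, batches):
    # Hypothetical helper: times steps while keeping one step in flight.
    step_times = []
    prev_loss = None
    last = time.perf_counter()
    for batch in batches:
        # Dispatch returns immediately; the step is queued on the device.
        params, loss = train_step(params, batch)
        if prev_loss is not None:
            # Wait only for the previous step, so the current one keeps running.
            prev_loss.block_until_ready()
            now = time.perf_counter()
            # The first interval includes jit compilation; the last step is untimed.
            step_times.append(now - last)
            last = now
        prev_loss = loss
    return params, step_times
```

With `N` batches this records `N - 1` intervals; dropping the first one excludes compilation time.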