I'm trying to analyze some code with a line profiler. Using

```python
@profile
@jax.jit
def step(...):
    # your code here
```

it throws an error. However, if I instead use

```python
@jax.jit
@profile
def step(...):
    # your code here
```

it seems to run. However, I'm not sure if it's only profiling the compilation time.

**Output Observed**

To give more detail on the specific output I'm seeing, here is the line profiler output for the training loop where I call the function I want to profile:

```
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   373     10002  495914672.0  49581.6     97.8  self.params, self.opt_state = self.step(
   374      5001       2417.0      0.5      0.0      self.params, self.opt_state, batch
   375                                              )
```

Here it seems like the time spent in the step function is about 495 s. And here is the corresponding output for the step function itself:

```
Total time: 21.4522 s
File: benchmark_singlenode_singleGPU.py
Function: step at line 312

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   312                                           @partial(jax.jit, static_argnums=(0,))
   313                                           @profile
   314                                           def step(self, params, opt_state, batch):
   315         1   21307501.0 21307501.0    99.3      g = grad(self.loss)(params, batch)
   316         1     134959.0   134959.0     0.6      updates, opt_state = self.optimizer.update(g, opt_state)
   317         1       9756.0     9756.0     0.0      params = optax.apply_updates(params, updates)
   318         1          1.0        1.0     0.0      return params, opt_state
```

Here you can see that the reported total time is 21 s and that each line was hit only once, even though the loop above calls `step` thousands of times, which seems to suggest that this profiling covers just the compilation stage.
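A minimal sketch illustrating the same behavior with a toy function (not the benchmark code above): Python-level side effects inside a jit-compiled function run only while JAX traces it, which is why the line profiler records a single hit per line.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Python-level side effect: runs only while JAX traces the function,
    # not on every call to the compiled executable.
    print("tracing f")
    return x * 2

x = jnp.ones(3)
f(x)  # prints "tracing f" (first call triggers tracing + compilation)
f(x)  # no output: the cached executable is reused, the Python body is skipped
```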
Replies: 1 comment 6 replies
Under `jit`, the code within a function is compiled and run as a unit rather than being executed line by line by the Python interpreter, so typical Python line-profiling approaches will not work. You'll have to use other tools: the documentation has some information on profiling here: https://jax.readthedocs.io/en/latest/profiling.html
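As a rough starting point, one can time the jitted call directly (separating the first, compiling call from steady-state calls) and use the JAX profiler for per-op breakdowns. A minimal sketch, assuming a stand-in `step` function and an arbitrary trace directory rather than the benchmark from the question:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Stand-in computation; replace with the real training step.
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((1024, 1024))

# First call includes tracing + compilation.
t0 = time.perf_counter()
step(x).block_until_ready()
print(f"first call (with compile): {time.perf_counter() - t0:.3f}s")

# Later calls reuse the compiled executable; block_until_ready avoids
# measuring only the asynchronous dispatch.
t0 = time.perf_counter()
for _ in range(10):
    step(x).block_until_ready()
print(f"10 steady-state calls: {time.perf_counter() - t0:.3f}s")

# Per-op breakdowns: capture a device trace and inspect it as described
# in the profiling docs linked above.
with jax.profiler.trace("/tmp/jax-trace"):
    step(x).block_until_ready()
```

The captured trace can then be opened with TensorBoard's profile plugin, per the linked documentation, which gives per-operation timings on the device rather than per-Python-line timings.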