-
Here is the Perfetto trace. I don't understand why there are large gaps between the kernels, why so much computation is done on the host, or why the MemcpyH2D takes as long as 60 microseconds. The input data is only 336 kB, far less than the ~1875 kB that could theoretically be transferred in 60 microseconds at PCIe 4.0's 32 GB/s.
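As a sanity check, I can time the H2D copy on its own with jax.device_put (sketch below, with a placeholder input of the same size as the real one). My understanding is that small transfers like this are usually latency-bound rather than bandwidth-bound, so the 60 microseconds may be mostly fixed per-copy overhead rather than a sign that PCIe bandwidth is the bottleneck.

```python
import time
import numpy as np
import jax

# Placeholder input standing in for the real 336 kB tensor (~86k float32 values).
x_host = np.ones((86_016,), dtype=np.float32)

gpu = jax.devices("gpu")[0]

# Time the H2D copy in isolation; block_until_ready ensures the asynchronous
# transfer has actually finished before the clock stops.
t0 = time.perf_counter()
x_dev = jax.device_put(x_host, gpu)
x_dev.block_until_ready()
print(f"H2D copy: {(time.perf_counter() - t0) * 1e6:.1f} us")

# Reusing x_dev across jitted calls keeps the copy out of the per-run time;
# the kernels then read an input that already lives in GPU memory.
```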
-
I have a program where each jitted single-core CPU run takes 400 microseconds and each jitted GPU run takes 300 microseconds, timed after a 5-iteration warmup loop. The GPU run is therefore only about 25% faster than the single-core CPU run. The program is imported into JAX from an ONNX model using jaxonnxruntime.
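To make the setup concrete, here is a minimal stand-in for the kind of measurement I mean (the model below is just a placeholder, not the imported ONNX graph; I believe block_until_ready is needed so the asynchronous GPU dispatch is actually included in the timing):

```python
import time
import jax
import jax.numpy as jnp

# Placeholder for the jitted function produced from the ONNX model; the real
# one takes the model's actual inputs.
@jax.jit
def model_fn(x):
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((256, 256), dtype=jnp.float32)

# 5-iteration warmup, so compilation and first-run costs are excluded.
for _ in range(5):
    model_fn(x).block_until_ready()

# Time steady-state runs; block_until_ready matters because JAX dispatch is
# asynchronous and the Python call returns before the GPU work finishes.
n = 100
t0 = time.perf_counter()
for _ in range(n):
    model_fn(x).block_until_ready()
print(f"mean per run: {(time.perf_counter() - t0) / n * 1e6:.1f} us")
```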
I did some rough profiling of one inference using jax.profiler.trace() and saw what seemed to be many kernel launches interspersed with host activity. Could it be that the host activity is causing the slowness? How could I debug this and get more things to run on the GPU instead of the host? Is there a better way to import ONNX into JAX than jaxonnxruntime? Is there a way to show the JAX code corresponding to each kernel in the trace? The kernel names don't tell me much. Or are there other avenues to try for understanding and debugging the slowness?
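One thing I am considering, sketched below with placeholder names (not the real imported model), is wrapping regions in jax.named_scope and capturing a trace with jax.profiler.trace; as far as I understand, the scope names get attached to the ops created inside them and show up in the HLO and, usually, in the profiler UI:

```python
import jax
import jax.numpy as jnp

# Placeholder model standing in for the function produced by jaxonnxruntime.
@jax.jit
def model_fn(x):
    # named_scope labels the ops traced inside it, so they can be recognized
    # in the lowered HLO and in the profiler trace.
    with jax.named_scope("preprocess"):
        x = (x - x.mean()) / (x.std() + 1e-6)
    with jax.named_scope("matmul_block"):
        y = jnp.tanh(x @ x.T)
    return y

x = jnp.ones((256, 256), dtype=jnp.float32)
model_fn(x).block_until_ready()  # compile / warm up before tracing

# Capture a trace of a few steady-state runs; view the result in TensorBoard's
# Profile tab or open it in Perfetto.
with jax.profiler.trace("/tmp/jax-trace"):
    for _ in range(10):
        model_fn(x).block_until_ready()
```

Printing the compiled program with jax.jit(model_fn).lower(x).compile().as_text() might also help: if it is a long list of small ops rather than a few large fusions, per-kernel launch overhead and host-side dispatch could plausibly dominate a 300-microsecond run.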