Replies: 3 comments 8 replies
-
I'm not sure what's happening either! What jax + jaxlib versions are you using? I'm trying to repro.
-
I think the computations aren't running in parallel because they're so short that the dispatch time is longer than the time it takes them to run. You can see that every GPU kernel begins and finishes execution during its corresponding `JaxCompiledFunction(mean)` call on the host, which is what dispatches the computation and prepares the result. I believe if you were to jit a larger function, you would begin to see overlapping execution.
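For example, something along these lines (a minimal sketch; the function, matrix size, and iteration counts are illustrative, not from your code) should show kernels overlapping across devices:

```python
import jax
import jax.numpy as jnp

devices = jax.devices()

@jax.jit
def heavy(x):
    # Enough work per call that kernel runtime exceeds host dispatch time.
    for _ in range(10):
        x = x @ x / x.shape[0]
    return jnp.mean(x)

# Commit one input to each device; jit runs on its input's device.
xs = [jax.device_put(jnp.ones((4096, 4096)), d) for d in devices]

# Each call returns as soon as the work is dispatched, so the kernels
# on different devices can execute concurrently.
results = [heavy(x) for x in xs]
for r in results:
    r.block_until_ready()
```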
-
@skye friendly ping :)
-
Hi JAX team,
I expect a jitted function to be dispatched asynchronously: the calling thread should return immediately and be free to make other calls. One could leverage this property to run GPU kernels on different devices in parallel.
e.g.
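Roughly the following (the array shape and loop count are illustrative; `device=` pins each jitted `jnp.mean` to one device):

```python
import jax
import jax.numpy as jnp

devices = jax.devices()

# One jitted mean per device.
means = [jax.jit(jnp.mean, device=d) for d in devices]

s = jnp.ones((1000, 1000))
for idx in range(4 * len(devices)):
    # With async dispatch, each call should return immediately,
    # letting kernels queue up on all devices at once.
    m = means[idx % len(devices)](s)
```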
I expect the CUDA kernels from the above code to overlap across the different devices. However, in the profiler trace the jitted functions appear to run synchronously, one after another. In fact, writing
`m = means[idx % len(devices)](s).block_until_ready()`
produces a timeline similar to the version without `block_until_ready`.
Is my understanding of async dispatch wrong?