vmap scales approximately linearly in wallclock #19103
Replies: 2 comments
-
So, after some profiling I found where the slow-down arises; it is unsurprisingly isolated to `Learner.step`. After doing some offline testing, I think I'm misunderstanding the potential performance benefits of `vmap`.
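For context, this kind of hot spot can be localized with `jax.profiler.trace`, which writes a trace viewable in TensorBoard/Perfetto. A minimal sketch (the `step` function below is a hypothetical stand-in for `Learner.step`, not the actual code):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for Learner.step; any jitted update works for tracing.
@jax.jit
def step(params, batch):
    grads = jax.tree_util.tree_map(lambda p: p * jnp.mean(batch), params)
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

params = {"w": jnp.ones((256, 256))}
batch = jnp.ones((32, 256))

# The trace records per-op timings, so a dominant step function stands out.
with jax.profiler.trace("/tmp/jax-trace"):
    for _ in range(10):
        params = step(params, batch)
    jax.block_until_ready(params)
```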
-
Hi, thanks for the question! All told, the results you're seeing aren't that surprising: after all, the number of operations in your program increases linearly with the size of the batch dimension, so absent any sort of implicit device parallelism, you'd expect the wall clock time to increase linearly as well.

If you're running on an accelerator like a GPU or TPU, there are situations in which vmapped operations will scale sub-linearly with batch size, particularly when a single operation is not large enough to take advantage of the intrinsic parallelization in the chip's architecture. For example, a vmapped vector product becomes a matrix product, which will generally be faster than performing each vector product sequentially.

Regarding large arrays: when each individual operation is already large enough to saturate the device's parallelism, batching it with `vmap` has no spare parallelism left to exploit, so roughly linear scaling with batch size is what you should expect.

Does that help?
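To make the vector-product point concrete, here's a small illustrative benchmark (the sizes are arbitrary choices, not from this thread); on a GPU the vmapped version is typically much faster than dispatching the products one at a time:

```python
import time
import jax
import jax.numpy as jnp

xs = jax.random.normal(jax.random.key(0), (1000, 512))  # batch of 1000 vectors
w = jax.random.normal(jax.random.key(1), (512,))

# vmapping a vector-vector product lowers to a single matrix-vector product.
batched_dot = jax.jit(jax.vmap(lambda x: jnp.vdot(x, w)))
single_dot = jax.jit(lambda x: jnp.vdot(x, w))

# Warm-up calls so compilation time is not measured.
jax.block_until_ready(batched_dot(xs))
jax.block_until_ready(single_dot(xs[0]))

start = time.perf_counter()
jax.block_until_ready(batched_dot(xs))
print("vmapped:   ", time.perf_counter() - start)

start = time.perf_counter()
jax.block_until_ready([single_dot(x) for x in xs])  # 1000 separate dispatches
print("sequential:", time.perf_counter() - start)
```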
-
Hi, I'm using `vmap` to batch experiments (i.e., batch over `jax.random.key` seeds), but when I scale up the batch dimension it seems to slow down approximately linearly on the GPU.

Is this expected? I would've expected this to only slightly slow down the full computation, but not to such a large degree.
Example
The way that I've designed my code is as follows:

- Random keys are created with `jax.vmap(jax.random.key)(jnp.asarray(seeds))`.
- The environment exposes a `reset` and a `step` function. These functions internally implement batching over the random keys, plus acceleration with `jax.jit`.

So abstractly, this looks something like the sketch below,
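(sketch with placeholder `reset`/`step` bodies and shapes; the real environment logic is omitted):

```python
import jax
import jax.numpy as jnp

seeds = list(range(4))
keys = jax.vmap(jax.random.key)(jnp.asarray(seeds))  # one key per experiment

# Placeholder single-experiment logic; the real reset/step are environment-specific.
def reset(key):
    return jax.random.normal(key, (128,))

def step(state, key):
    return state + 0.01 * jax.random.normal(key, state.shape)

# Batch both functions over the leading key/state axis and compile them.
batched_reset = jax.jit(jax.vmap(reset))
batched_step = jax.jit(jax.vmap(step))

states = batched_reset(keys)
for i in range(100):
    step_keys = jax.vmap(jax.random.fold_in, in_axes=(0, None))(keys, i)
    states = batched_step(states, step_keys)
jax.block_until_ready(states)
```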
Now when I scale up `seeds` to e.g. `list(range(20))`, I notice that the wallclock time per loop iteration on my RTX 4080 scales almost linearly (a bit below linear). But I somehow expected better performance... Any thoughts?
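(For what it's worth, per-iteration wallclock has to be measured with `jax.block_until_ready` because of JAX's async dispatch; roughly, the measurement looks like this, reusing the names from the sketch above:)

```python
import time

times = []
for i in range(50):
    start = time.perf_counter()
    step_keys = jax.vmap(jax.random.fold_in, in_axes=(0, None))(keys, i)
    states = batched_step(states, step_keys)
    jax.block_until_ready(states)  # wait for the GPU before stopping the clock
    times.append(time.perf_counter() - start)

print(f"mean wallclock per iteration: {sum(times) / len(times):.6f} s")
```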