Hi - so I think the main misunderstanding here is that vmap doesn't explicitly have anything to do with parallelism: it just converts unbatched instructions into batched instructions. So a vmapped dot product is just a matrix-vector product with the batched input. A vmapped sum is just a sum along one axis of the batched input.
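As a minimal sketch of that point (names and shapes here are illustrative, not from the original question): vmapping a scalar dot product over a batch of vectors produces exactly the same computation as a matrix-vector product, and vmapping `jnp.sum` is the same as summing along one axis.

```python
import jax
import jax.numpy as jnp

def dot(v, w):
    # Unbatched scalar dot product of two 1-D vectors.
    return jnp.vdot(v, w)

# Batch the first argument along axis 0; w is shared across the batch.
batched_dot = jax.vmap(dot, in_axes=(0, None))

M = jax.random.normal(jax.random.PRNGKey(0), (4, 3))  # batch of 4 vectors
w = jnp.arange(3.0)

# The vmapped dot over the rows of M is just the matrix-vector product M @ w.
print(jnp.allclose(batched_dot(M, w), M @ w))  # True

# Likewise, a vmapped sum is just a sum along one axis of the batched input.
print(jnp.allclose(jax.vmap(jnp.sum)(M), M.sum(axis=1)))  # True
```

No parallelism is being requested here; vmap simply rewrites the per-example operations into their batched equivalents, and the compiler is then free to execute those however it likes.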

As for why large input sizes slow down your computation... well, the size of the computation under vmap scales linearly with the size of the input. From the times you quote, if I'm understanding correctly, the actual wall-time scaling is very sub-linear in the number of batch elements, which means that for your program the compiler is making very good use of the hardware.

Answer selected by YannBerthelot