Do I need to introduce "dummy" batches when using PMAP? #7303
Unanswered
shahbuland asked this question in Q&A
Replies: 1 comment
According to the pmap documentation, when the mapped (batch) axis is smaller than the device count, pmap automatically runs on only a subset of the devices. There is no need to pad for pmap to work.
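A minimal sketch of the no-padding approach, assuming a toy linear-regression loss. `loss_fn`, `accumulated_grad` and the data shapes are hypothetical, chosen to match the 8 microbatches of size 4 from the question. Microbatches are fed to pmap in chunks of at most `jax.local_device_count()`, and the final, smaller chunk simply runs on fewer devices. The `XLA_FLAGS` line only fakes 5 CPU devices for demonstration and has to run before JAX is imported.

```python
import os

# Demo-only assumption: expose 5 host (CPU) devices to mirror the question.
# Must be set before jax is imported; it does not affect GPU/TPU devices.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=5"
)

import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    # Hypothetical model: linear regression with mean-squared error.
    return jnp.mean((x @ w - y) ** 2)


# Parameters are broadcast to every device (in_axes=None); the data
# arguments are split along their leading (microbatch) axis.
per_device_grad = jax.pmap(jax.grad(loss_fn), in_axes=(None, 0, 0))


def accumulated_grad(w, xs, ys):
    """Average the gradients of equally sized microbatches xs[i], ys[i]."""
    n_dev = jax.local_device_count()
    n_micro = xs.shape[0]
    total = jnp.zeros_like(w)
    for start in range(0, n_micro, n_dev):
        # The last chunk may hold fewer than n_dev microbatches; pmap then
        # uses only as many devices as there are microbatches in the chunk.
        g = per_device_grad(w, xs[start:start + n_dev], ys[start:start + n_dev])
        total = total + g.sum(axis=0)
    return total / n_micro
```

One trade-off to keep in mind: pmap compiles separately for each distinct size of the mapped axis, so with 5 devices this compiles twice (once for 5 microbatches, once for 3). Padding keeps a single compiled shape at the cost of computing gradients on dummy data.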
I'm trying to do gradient accumulation over some microbatches (pmap'd over the microbatch axis). I want each device to compute the gradient for one microbatch. Let's say there are 8 microbatches of size 4, and I have 5 devices. Spreading the first 5 microbatches across the 5 devices is simple, but after I've computed those gradients, I still need to handle the 3 remaining microbatches. The most obvious solution that comes to mind is to pad with 2 dummy microbatches so every device gets something, then ignore the gradients returned by the devices that received dummy data. Is that the move, or is there something in JAX for cases like this that I should use?
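For comparison, the padding approach described here can be sketched as follows. This is a hypothetical example, not a JAX utility: `loss_fn`, `masked_grad` and `padded_accumulated_grad` are made-up names, the loss is a toy linear regression, and the `XLA_FLAGS` line only fakes 5 CPU devices for demonstration (it must run before JAX is imported). Dummy microbatches are zeros, and a per-microbatch boolean mask discards their gradients.

```python
import os

# Demo-only assumption: expose 5 host (CPU) devices to mirror the question.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=5"
)

import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    # Hypothetical model: linear regression with mean-squared error.
    return jnp.mean((x @ w - y) ** 2)


def masked_grad(w, x, y, is_real):
    g = jax.grad(loss_fn)(w, x, y)
    # Zero out gradients from dummy microbatches. jnp.where (rather than
    # multiplying by 0) also discards any NaN/inf computed on dummy data.
    return jnp.where(is_real, g, jnp.zeros_like(g))


per_device_grad = jax.pmap(masked_grad, in_axes=(None, 0, 0, 0))


def padded_accumulated_grad(w, xs, ys):
    n_dev = jax.local_device_count()
    n_micro = xs.shape[0]
    n_pad = (-n_micro) % n_dev  # 8 microbatches on 5 devices -> 2 dummies
    xs = jnp.concatenate([xs, jnp.zeros((n_pad,) + xs.shape[1:], xs.dtype)])
    ys = jnp.concatenate([ys, jnp.zeros((n_pad,) + ys.shape[1:], ys.dtype)])
    is_real = jnp.arange(n_micro + n_pad) < n_micro
    total = jnp.zeros_like(w)
    # Every chunk now has exactly n_dev microbatches, so pmap sees one shape.
    for start in range(0, n_micro + n_pad, n_dev):
        chunk = slice(start, start + n_dev)
        g = per_device_grad(w, xs[chunk], ys[chunk], is_real[chunk])
        total = total + g.sum(axis=0)
    # Average over real microbatches only, not the padded count.
    return total / n_micro
```

This trades a little wasted compute on the dummy devices for a single compiled shape per step.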