pmap: data size non-divisible by device_count on data parallel cases #10749
Unanswered
luweizheng asked this question in Q&A
Replies: 0 comments
Hi all,
I am using `pmap`, and it is a very powerful way to parallelize data across multiple devices.

To use `pmap`, we should divide `data_size` by `jax.local_device_count()`, so that the data is split across the devices. However, it is very common that `data_size` is not divisible by `local_device_count`. I want to keep the remaining elements rather than drop them. The common practice is to pad the data with dummy elements so that `data_size` becomes divisible, and then mask out the dummy elements after `pmap`.

My question is: are there any examples of this padding and masking? Since the result of `pmap` is a `ShardedDeviceArray`, can I just mask out the dummy elements on the last device?

Thanks.
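For reference, the pad-and-mask approach described above might be sketched as follows. This is a minimal illustration, not an official JAX recipe: `pad_and_shard`, `per_device_fn`, the axis name `"devices"`, and the sum-of-squares computation are all hypothetical names and choices made for the example. It assumes zero padding appended at the end of the leading axis and a boolean mask that is carried alongside the data.

```python
import jax
import jax.numpy as jnp
import numpy as np


def pad_and_shard(x, n_devices):
    """Pad the leading axis of x to a multiple of n_devices and reshape it
    to (n_devices, per_device, ...). Also return a boolean mask of shape
    (n_devices, per_device): True for real rows, False for padding."""
    data_size = x.shape[0]
    per_device = -(-data_size // n_devices)  # ceiling division
    padded_size = per_device * n_devices
    pad_width = [(0, padded_size - data_size)] + [(0, 0)] * (x.ndim - 1)
    x_padded = np.pad(x, pad_width)  # zero padding at the end
    mask = np.arange(padded_size) < data_size
    x_sharded = x_padded.reshape((n_devices, per_device) + x.shape[1:])
    return x_sharded, mask.reshape(n_devices, per_device)


def per_device_fn(x_shard, mask_shard):
    # Hypothetical per-example computation: sum of squares of each row.
    per_example = jnp.sum(x_shard ** 2, axis=-1)
    # Zero out the rows that are padding so they do not affect reductions.
    per_example = jnp.where(mask_shard, per_example, 0.0)
    # Masked mean over all devices: sum of real values / number of real rows.
    total = jax.lax.psum(jnp.sum(per_example), axis_name="devices")
    count = jax.lax.psum(jnp.sum(mask_shard), axis_name="devices")
    return per_example, total / count


n_devices = jax.local_device_count()
data = np.arange(7 * 3, dtype=np.float32).reshape(7, 3)  # 7 rows of toy data

x_sharded, mask = pad_and_shard(data, n_devices)
per_example, mean = jax.pmap(per_device_fn, axis_name="devices")(x_sharded, mask)

# Bring the result back to the host, flatten the device axis, and drop the
# padding rows, which all sit at the end of the flattened array.
per_example = np.asarray(per_example).reshape(-1)[: data.shape[0]]
mean = np.asarray(mean)[0]  # every device holds the same reduced value
```

On whether masking only the last device is enough: with this padding scheme it is not always. For example, 9 rows on 8 devices gives 2 rows per device and 7 padding rows, which spill over several devices. Slicing the flattened result by the original `data_size`, or carrying a mask through `pmap` as in the sketch, avoids having to reason about which device holds the padding.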