What's the difference between xmap, shmap, pmap (pjit?) and which should I use? #20312
-
Hello everyone, I'm a bit confused by the many options for using JAX on multiple devices. Can someone please explain the differences between Thanks everyone :)
-
Hi - this is a good question, and things are admittedly a bit confusing now because they're in flux and still not well documented.
The TL;DR is you should use `shard_map` for explicit parallelism, and `jit` for automatic parallelism; `xmap` and `pmap` are deprecated, and `pjit` is now part of `jit`.

Some details:
- `pmap` is the oldest and least flexible parallelizing transformation. It is severely limited: for example, you can only map over an axis whose size matches the number of devices, and it can't be easily nested. It's mostly replaced by `shard_map`, though it may live on as a convenient wrapper of `shard_map`.
- `xmap` is a slightly-less-old attempt to generalize `vmap` and `pmap`, but it has mostly been s…
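
To make the TL;DR concrete, here is a minimal sketch of the two recommended styles computing the same global sum. It is only an illustration under some assumptions: the mesh axis name `"i"`, the helper `local_sum`, and the array size are made up for this example, and the import fallback is there because `shard_map` has lived under `jax.experimental.shard_map` in some JAX versions and at the top level in others.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # top-level export in newer JAX versions
except ImportError:
    from jax.experimental.shard_map import shard_map  # older versions

# One-dimensional mesh over whatever devices are available (works with 1 too).
devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices, ("i",))  # "i" is a hypothetical axis name

x = jnp.arange(4 * n, dtype=jnp.float32)

# Explicit parallelism: the function sees only its local shard of x
# (shape (4,)) and we write the cross-device collective ourselves.
def local_sum(block):
    return jax.lax.psum(jnp.sum(block), "i")

explicit_total = jax.jit(
    shard_map(local_sum, mesh=mesh, in_specs=P("i"), out_specs=P())
)(x)

# Automatic parallelism: shard the input, write ordinary global-view code,
# and let jit and the compiler decide how to partition the computation.
x_sharded = jax.device_put(x, NamedSharding(mesh, P("i")))
auto_total = jax.jit(jnp.sum)(x_sharded)
```

Both calls should give the same number. The difference is who decides on the communication: with `shard_map` you spell out the `psum` yourself, while with `jit` the compiler inserts whatever collectives it needs based on the input sharding.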