SPMD Parallelism on CPU -- Question concerning Best Practice. #15174
Unanswered
danielkelshaw asked this question in Q&A
Hey, I'm currently doing some work where I need to run a function over a variety of initial conditions - I'm looking for the most efficient way to do this.
I have a function of the form:
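For illustration only, a function of this general shape could look like the sketch below. The name `solve`, the fixed-point update `x <- cos(x)`, the tolerance, and the iteration cap are all hypothetical stand-ins chosen to give a small runnable example; they are not the actual function under discussion.

```python
import jax
import jax.numpy as jnp


def solve(x0, tol=1e-6, max_iter=1000):
    """Hypothetical stand-in: iterate x <- cos(x) from x0 until it stops changing."""

    def cond(state):
        x, x_prev, i = state
        # Keep looping while the update is still large and the cap is not hit.
        return (jnp.abs(x - x_prev) > tol) & (i < max_iter)

    def body(state):
        x, _, i = state
        # Carry the new iterate, the previous iterate, and the step counter.
        return jnp.cos(x), x, i + 1

    x0 = jnp.asarray(x0, dtype=jnp.float32)
    init = (jnp.cos(x0), x0, jnp.int32(0))
    x, _, n_iter = jax.lax.while_loop(cond, body, init)
    return x, n_iter


# One initial condition; the number of iterations depends on the input.
x_star, n_iter = solve(1.0)
```

The point of the sketch is only that the run time per call is data-dependent, which is what makes batching over many initial conditions interesting.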
The function is largely a wrapper around a `jax.lax.while_loop`, so the use of `jax.jit` does not do much in this case, as the condition / body functions are already lowered.

Attempt 01 :: `jax.vmap`
My first port of call was `jax.vmap`, to run this as SIMD. However, when I inspect the CPU usage via `htop`, it is clear that not all CPUs are being used. The load factor does not increase much and is nowhere near optimal loading for my 16-core machine.

I understand that `vmap` will take as long as the slowest function call; this is something I'm happy to live with.

Attempt 02 :: `jax.pmap` with `jax.vmap`
I notice that `jax.device_count()` states I am using a single CPU; however, `htop` shows processing on multiple cores. I gather from other discussions that this is due to `BLAS` / `LAPACK` calls for certain functions? In order to access all available CPUs, I set the environment variable `XLA_FLAGS="--xla_force_host_platform_device_count=16"`.

This allows me to use `pmap` to utilise parallelism in an SPMD fashion, limiting the size of the axis over which `pmap` is applied to `16`. In my case, I reshape my input to `(16, 6, ...)` in order to run the function for `96` initial conditions:

This runs considerably faster but brings up a few questions:
1. Is there any way to use this process for an arbitrary number of initial conditions? In this example I have explicitly chosen a multiple of the number of devices visible to `jax`. Is this a restriction, or are there any workarounds?
2. There is nothing stopping me from setting `xla_force_host_platform_device_count` to a number exceeding the total number of cores available on my machine. It is not clear how this works, or how it affects performance. If I set the device count to `200`, I can successfully `pmap` over these, but how this maps to true distribution over CPUs is unclear.
3. Potential use of `jax.experimental.maps.xmap`: the idea of 'easy-to-revise parallelism' is quite alluring, but at first glance it appears that I would need to rewrite the internals of my function in a pretty major way. Is this something worth looking into?
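On the first question, one possible workaround is to pad the batch up to the next multiple of the device count, run the `pmap` of `vmap` on the padded batch, and discard the padded results. The sketch below is a minimal illustration under assumptions: `solve` is a hypothetical stand-in function (a fixed-point iteration wrapped in `jax.lax.while_loop`), `run_all` is a hypothetical helper, and the device count of 4 is an arbitrary choice; none of these come from the original code.

```python
import os

# Assumption: the flag must be set before JAX is first imported to take effect.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp


def solve(x0):
    """Hypothetical stand-in: iterate x <- cos(x) until the update is tiny."""

    def cond(state):
        x, x_prev, i = state
        return (jnp.abs(x - x_prev) > 1e-6) & (i < 1000)

    def body(state):
        x, _, i = state
        return jnp.cos(x), x, i + 1

    x, _, _ = jax.lax.while_loop(cond, body, (jnp.cos(x0), x0, jnp.int32(0)))
    return x


def run_all(x0s):
    """Hypothetical helper: run `solve` over any number of initial conditions."""
    n = x0s.shape[0]
    n_dev = jax.local_device_count()
    per_dev = -(-n // n_dev)  # ceiling division
    pad = n_dev * per_dev - n
    if pad:
        # Pad with copies of the last (valid) input so every device gets per_dev items.
        filler = jnp.broadcast_to(x0s[-1], (pad,) + x0s.shape[1:])
        x0s = jnp.concatenate([x0s, filler])
    sharded = x0s.reshape((n_dev, per_dev) + x0s.shape[1:])
    out = jax.pmap(jax.vmap(solve))(sharded)
    # Flatten the device axis again and drop the padded entries.
    return out.reshape((n_dev * per_dev,) + out.shape[2:])[:n]


# 10 initial conditions: not a multiple of the device count chosen above.
x0s = jnp.linspace(0.0, 1.0, 10)
results = run_all(x0s)
```

Padding with a copy of a valid input, rather than zeros, avoids feeding the loop a value it might not converge on; the cost is fewer than one device-count's worth of wasted evaluations.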
TL;DR

What is the best way to maximise CPU usage when using `jax`? When should we use `vmap`, and when should we look to something more complicated, such as the `pmap` of `vmap` mentioned above? How does `xmap` fit in here, and would it require a major rewrite of the code in order to work?

Thank you for any advice you can offer, it's much appreciated!