pmap ValueError #10328
zohimchandani asked this question in Q&A (Unanswered)
Replies: 1 comment, 5 replies
-
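```python
import jax.numpy as jnp
from jax import pmap, vmap

a = jnp.ones(4)
a = a.reshape((2, 2))  # (devices, items per device); assumes 2 available devices
fun = pmap(vmap(lambda a, b: a + b))  # pmap splits axis 0 across devices, vmap maps axis 1
fun(a, a)
```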
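The reshape in the snippet above only works when the number of values is exactly 2 × the device count. A minimal sketch of the same `pmap(vmap(...))` pattern for an arbitrary number of `b` values, with `a` broadcast to every call, is below. It assumes a 1-D `bs` and a function that returns a single array; `pmap_over_batch` is a hypothetical helper name (not part of JAX), and padding with the last element is just one way to make the batch divide evenly across devices.

```python
import jax
import jax.numpy as jnp


def pmap_over_batch(fn, a, bs):
    """Apply fn(a, b) to every b in the 1-D array `bs`.

    Hypothetical helper, not a JAX API. Work is split across all local
    devices with pmap and vectorized within each device with vmap.
    `a` is broadcast (in_axes=None); `bs` may be longer than the device count.
    """
    n_dev = jax.local_device_count()
    n = bs.shape[0]
    per_dev = -(-n // n_dev)  # ceil(n / n_dev) items per device
    pad = per_dev * n_dev - n
    # Pad with copies of the last element so the batch splits evenly.
    bs_sharded = jnp.pad(bs, (0, pad), mode="edge").reshape(n_dev, per_dev)
    # Outer pmap: one row per device; inner vmap: map over that row on the device.
    # `a` is not mapped by either transform.
    mapped = jax.pmap(jax.vmap(fn, in_axes=(None, 0)), in_axes=(None, 0))
    out = mapped(a, bs_sharded)
    # Merge the device and per-device axes, then drop the padded results.
    return out.reshape((n_dev * per_dev,) + out.shape[2:])[:n]
```

For example, `pmap_over_batch(lambda a, b: a + b, 10.0, jnp.arange(7.0))` should give the values 10 through 16 whether one or two devices are available; with two devices the seven values are padded to eight and the extra result is discarded.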
-
Let's assume I have the following function which I would like to run `pmap` over:

Given that `jax.device_count()` returns `2` on my machine, I can run `pmap` in the following manner for an array of `b` values, where the value for `a` is broadcast:

My question is: what if I wanted to run this for more `b` values than the number of GPUs available? Moreover, I actually want to run this for all combinations of `a` and `b` values, as shown below:

where all possible combinations of `a` and `b` are executed. Any ideas on how to approach this problem?
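For all combinations of `a` and `b`, one possible approach (a sketch, not an answer confirmed in this thread) is to build the Cartesian product with `jnp.meshgrid`, flatten it into a single batch of pairs, and split that batch across devices with a pad-and-reshape step so it works for any number of pairs. `all_pairs_pmap` is a hypothetical helper name; it assumes 1-D `a_vals` and `b_vals` and a function that returns a scalar for scalar inputs.

```python
import jax
import jax.numpy as jnp


def all_pairs_pmap(fn, a_vals, b_vals):
    """Evaluate fn(a, b) for every a in a_vals and every b in b_vals.

    Hypothetical helper, not a JAX API. Returns an array of shape
    (len(a_vals), len(b_vals)) with result[i, j] == fn(a_vals[i], b_vals[j]).
    """
    n_dev = jax.local_device_count()
    # Cartesian product: A[i, j] = a_vals[i], B[i, j] = b_vals[j].
    A, B = jnp.meshgrid(a_vals, b_vals, indexing="ij")
    a_flat, b_flat = A.ravel(), B.ravel()
    n = a_flat.shape[0]
    per_dev = -(-n // n_dev)  # ceil(n / n_dev) pairs per device
    pad = per_dev * n_dev - n
    # Repeat the last pair as padding; its results are discarded below.
    a_sharded = jnp.pad(a_flat, (0, pad), mode="edge").reshape(n_dev, per_dev)
    b_sharded = jnp.pad(b_flat, (0, pad), mode="edge").reshape(n_dev, per_dev)
    # pmap splits the rows across devices; vmap maps over the pairs in each row.
    out = jax.pmap(jax.vmap(fn))(a_sharded, b_sharded)
    # Undo the device split, drop the padding, restore the (len(a), len(b)) grid.
    return out.reshape(-1)[:n].reshape(A.shape)
```

Flattening the grid before sharding keeps the padding below the device count no matter how `len(a_vals)` and `len(b_vals)` relate to it. Each device still evaluates its whole share in a single `vmap`, so a very large grid may need to be processed in smaller chunks to fit in device memory.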