Replies: 1 comment
See the answer at #7991 (comment). In the future, there is no need to post your question twice.
Hi all,
When I run the code below without `config.update("jax_enable_x64", True)`, i.e. using 32-bit numbers, I see parallel performance: running a batch of 1000 `test_fn` evaluations with `vmap` takes the same time as a single chain. However, when I use 64-bit numbers, the runtime increases. The increase is not large (1000 chains take about twice as long as 1 chain), but I do not understand why it happens at all. I reproduced the same result in Google Colab; this is the exact printout (from a run AFTER the first one, to exclude compilation time):
```
1 parallel batches takes 0.127 seconds.
101 parallel batches takes 0.135 seconds.
201 parallel batches takes 0.147 seconds.
301 parallel batches takes 0.158 seconds.
401 parallel batches takes 0.159 seconds.
501 parallel batches takes 0.171 seconds.
601 parallel batches takes 0.195 seconds.
701 parallel batches takes 0.207 seconds.
801 parallel batches takes 0.212 seconds.
901 parallel batches takes 0.231 seconds.
1001 parallel batches takes 0.242 seconds.
```
Here is my code.
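For reference, a minimal version of the setup looks roughly like this; `test_fn`, the array shapes, and the timing loop here are cheap stand-ins, not my exact code:

```python
import time

import jax
import jax.numpy as jnp

# Toggle this line to compare the 32-bit vs 64-bit behavior.
jax.config.update("jax_enable_x64", True)

def test_fn(x):
    # Hypothetical stand-in for the real per-chain computation.
    return jnp.sum(jnp.sin(x) ** 2)

batched = jax.jit(jax.vmap(test_fn))
key = jax.random.PRNGKey(0)

for n in range(1, 1002, 100):  # 1, 101, ..., 1001 chains, as in the printout
    xs = jax.random.normal(key, (n, 100))
    batched(xs).block_until_ready()  # warm-up run to exclude compile time
    start = time.perf_counter()
    batched(xs).block_until_ready()
    print(f"{n} parallel batches takes {time.perf_counter() - start:.3f} seconds.")
```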
Sidenote: even using only 32-bit numbers, I observe the same less-than-parallel performance when I `vmap` over functions (MCMC chains via `lax.scan`) that involve large matrix multiplications, neural networks, etc. The exact same MCMC implementation yields parallel performance when sampling a 100-dimensional Gaussian, which requires no more than `jnp.sum` and its `jax.grad` gradient. Could complex operations also ruin parallelism?