*different* MCMC samplers in parallel #8231
-
Hello! I'm aware of the SPMD tools in JAX that let you run several MCMC chains in parallel (i.e., different realisations of MCMC targeting the same distribution). However, I'm wondering whether it's possible to run several MCMC samplers in parallel, each targeting a different posterior. Is this possible in JAX, or does the code running on each device have to be the same? Thanks!
-
All of JAX's parallelism APIs assume SPMD (single program, multiple data) execution, which means that the program dispatched to every device in a single parallel launch has to be the same. If targeting a different posterior is doable with the same program and you only need to change the data, then you should be all set. If you truly need to run different code on different devices, you can compile multiple computations with `jax.jit`, assign them to specific devices, and dispatch them separately from Python. JAX uses an asynchronous runtime, which should queue the computations on independent sets of devices and let them run concurrently.
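A minimal sketch of the "same program, different data" case, assuming a toy model (`mu ~ Normal(0, 10)`, `y_i ~ Normal(mu, 1)`) and a hand-written random-walk Metropolis sampler; `log_posterior`, `rw_metropolis` and the data are hypothetical and not from this thread. Each posterior differs only in its observed data `y`, so one function can be mapped over a leading "posterior" axis. `jax.vmap` is used here so it runs on a single CPU; with one posterior per device, the same function could instead be passed to `jax.pmap`.

```python
import jax
import jax.numpy as jnp


def log_posterior(mu, y):
    # Hypothetical model: mu ~ Normal(0, 10), y_i ~ Normal(mu, 1) (up to a constant).
    log_prior = -0.5 * (mu / 10.0) ** 2
    log_lik = -0.5 * jnp.sum((y - mu) ** 2)
    return log_prior + log_lik


def rw_metropolis(key, y, num_steps=2000, step_size=0.2):
    """Random-walk Metropolis targeting log_posterior(., y) for one dataset y."""

    def step(carry, step_key):
        mu, logp = carry
        k_prop, k_acc = jax.random.split(step_key)
        prop = mu + step_size * jax.random.normal(k_prop)
        logp_prop = log_posterior(prop, y)
        # Metropolis acceptance test in log space.
        accept = jnp.log(jax.random.uniform(k_acc)) < logp_prop - logp
        mu = jnp.where(accept, prop, mu)
        logp = jnp.where(accept, logp_prop, logp)
        return (mu, logp), mu

    mu0 = jnp.array(0.0)
    keys = jax.random.split(key, num_steps)
    _, samples = jax.lax.scan(step, (mu0, log_posterior(mu0, y)), keys)
    return samples


# Same program for every posterior; only the key and data are mapped over.
run_all = jax.jit(jax.vmap(rw_metropolis))

key_data, key_chains = jax.random.split(jax.random.PRNGKey(0))
true_means = jnp.array([-3.0, 0.0, 5.0])  # three different (made-up) datasets
y = true_means[:, None] + jax.random.normal(key_data, (3, 50))
chain_keys = jax.random.split(key_chains, 3)

samples = run_all(chain_keys, y)  # shape (3, 2000): one chain per posterior
posterior_means = samples[:, 500:].mean(axis=1)  # discard burn-in
```

Whether the data-only approach is enough depends on whether the posteriors can share one model definition (e.g. different datasets or hyperparameters passed as arrays); if the models differ structurally, the data-only trick no longer applies.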
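A minimal sketch of the "different code on different devices" route, assuming two made-up targets (`gaussian_logpdf`, `banana_logpdf`) and a hypothetical `make_sampler` helper; none of these come from the thread. Each target gets its own `jax.jit`-compiled program, and `jax.device_put` commits each program's inputs to a chosen device, since a jitted computation runs on the device where its committed inputs live. On a machine with only one device, both programs simply share it.

```python
import jax
import jax.numpy as jnp


def make_sampler(log_density, num_steps=1000, step_size=0.5):
    """Build a jitted random-walk Metropolis sampler for one fixed target."""

    @jax.jit
    def run(key, x0):
        def step(carry, step_key):
            x, logp = carry
            k_prop, k_acc = jax.random.split(step_key)
            prop = x + step_size * jax.random.normal(k_prop, x.shape)
            logp_prop = log_density(prop)
            accept = jnp.log(jax.random.uniform(k_acc)) < logp_prop - logp
            x = jnp.where(accept, prop, x)
            logp = jnp.where(accept, logp_prop, logp)
            return (x, logp), x

        keys = jax.random.split(key, num_steps)
        _, samples = jax.lax.scan(step, (x0, log_density(x0)), keys)
        return samples

    return run


# Two hypothetical targets with different code and dimensions.
def gaussian_logpdf(x):  # 1-D standard normal
    return -0.5 * jnp.sum(x ** 2)


def banana_logpdf(x):  # 2-D "banana"-shaped density
    return -0.5 * (x[0] ** 2 / 4.0 + (x[1] - 0.5 * x[0] ** 2) ** 2)


sample_gaussian = make_sampler(gaussian_logpdf)  # compiled separately
sample_banana = make_sampler(banana_logpdf)

devices = jax.devices()
dev_a = devices[0]
dev_b = devices[1 % len(devices)]  # same device if only one is available

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
# Committing the inputs to a device makes the jitted computation run there.
args_a = jax.device_put((key_a, jnp.zeros(1)), dev_a)
args_b = jax.device_put((key_b, jnp.zeros(2)), dev_b)

# Asynchronous dispatch: both calls return before the work finishes, so the
# two programs can overlap when they are placed on different devices.
out_a = sample_gaussian(*args_a)
out_b = sample_banana(*args_b)

# Wait for both results only when they are actually needed.
out_a.block_until_ready()
out_b.block_until_ready()
```

The explicit `block_until_ready()` calls are only needed when you want to wait for (or time) the results; converting the outputs to NumPy or printing them would also synchronize.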