
All of JAX's parallelism APIs assume SPMD (single program, multiple data) execution, which means that every device taking part in a single parallel launch must run the same program. If you can target a different posterior with the same program and only need to change the data passed in, you should be all set. If you truly need to run different code on different devices, you can compile multiple computations with `jax.jit`, assign them to specific devices, and dispatch them separately from Python. JAX's runtime dispatches asynchronously: each call returns before its computation finishes, so computations queued on independent sets of devices should run concurrently.
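A minimal sketch of the same-program, different-data case, assuming each posterior differs only in its data. The Gaussian `log_density` and the arrays `thetas` and `mus` are hypothetical stand-ins; `jax.pmap` compiles the gradient once and runs that single program on every local device, each device receiving its own row of the inputs.

```python
import jax
import jax.numpy as jnp

# Hypothetical target: each posterior is N(mu, I) with a different mean `mu`,
# so every device can run the same program on different data.
def log_density(theta, mu):
    # Unnormalised log-density of N(mu, I) evaluated at theta.
    return -0.5 * jnp.sum((theta - mu) ** 2)

# Gradient with respect to theta, as a gradient-based sampler would use.
grad_log_density = jax.grad(log_density)

n_dev = jax.local_device_count()

# One row per device: separate parameters and separate data per posterior.
thetas = jnp.zeros((n_dev, 3))
mus = jnp.arange(n_dev * 3, dtype=jnp.float32).reshape(n_dev, 3)

# pmap maps over the leading axis of both arguments, running the same
# compiled program on every local device (SPMD).
grads = jax.pmap(grad_log_density)(thetas, mus)
print(grads.shape)
```

On a single-device machine this still runs, with a leading axis of size 1.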
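A sketch of the case where different devices need different code. `program_a` and `program_b` are hypothetical stand-ins for two distinct computations. Committing each input to a device with `jax.device_put` is one way to make a jitted function execute there, since `jax.jit` follows the placement of committed inputs. With only one device available, both calls land on the same device, so the example runs anywhere but only shows concurrency with two or more devices.

```python
import jax
import jax.numpy as jnp

devices = jax.devices()
# Use two distinct devices if available; otherwise fall back to the same one.
dev_a = devices[0]
dev_b = devices[1 % len(devices)]

# Two different programs, each compiled separately with jax.jit.
@jax.jit
def program_a(x):
    return jnp.sin(x).sum()

@jax.jit
def program_b(x):
    return jnp.trace(x @ x.T)

# Committing each input to a device makes its computation run on that device.
x_a = jax.device_put(jnp.ones((4, 4)), dev_a)
x_b = jax.device_put(jnp.ones((4, 4)), dev_b)

# Both calls return immediately (asynchronous dispatch); work is queued per device.
out_a = program_a(x_a)
out_b = program_b(x_b)

# Block only once the results are actually needed.
out_a.block_until_ready()
out_b.block_until_ready()
```

Waiting until the end matters: calling `block_until_ready()` right after `program_a` would serialize the two computations.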

@jeremiecoullon

Answer selected by jeremiecoullon
Category
Q&A