To run multiple NNs in parallel, you could use pmap, as you noted. I believe it JIT-compiles the mapped function under the hood automatically, so you don't need to wrap it in jit yourself (see #5681 (comment)). More info on how to use pmap: https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb
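As a rough sketch of the pmap approach (not taken from the linked cookbook), the snippet below runs one small network per local device. The two-layer MLP, the helper names init_params and forward, and all shapes are made up for illustration; the only assumption about JAX itself is that pmap maps over the leading axis, whose size must not exceed the number of local devices.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-layer MLP; names and sizes are illustrative only.
def init_params(key, in_dim=4, hidden=8, out_dim=2):
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden)) * 0.1,
        "w2": jax.random.normal(k2, (hidden, out_dim)) * 0.1,
    }

def forward(params, x):
    h = jnp.tanh(x @ params["w1"])
    return h @ params["w2"]

n_dev = jax.local_device_count()

# One independent set of parameters per device, stacked on the leading axis.
keys = jax.random.split(jax.random.PRNGKey(0), n_dev)
params = jax.vmap(init_params)(keys)

# One input batch per network: (n_dev, batch, in_dim).
xs = jnp.ones((n_dev, 16, 4))

# pmap compiles forward and runs one copy per device;
# no explicit jax.jit wrapper is added here.
parallel_forward = jax.pmap(forward)
outputs = parallel_forward(params, xs)

print(outputs.shape)  # (n_dev, 16, 2)
```

On a machine with a single device this still runs, just with a leading axis of size 1.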

If you want to parallelize a single matmul over multiple cores instead, use xmap or pjit:
xmap: https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html
pjit: https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html
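For the single-matmul case, here is a minimal sketch using the jax.sharding API together with plain jax.jit. As far as I know, this is where the pjit functionality ended up in more recent JAX releases, so treat it as an assumption about your JAX version rather than what the tutorials above show. The mesh axis name "x" and the matrix sizes are arbitrary.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D mesh over all available devices; the axis name "x" is arbitrary.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
n = len(devices)

a = jnp.ones((8 * n, 32))
b = jnp.ones((32, 16))

# Split the rows of `a` across devices; keep `b` replicated on every device.
a_sharded = jax.device_put(a, NamedSharding(mesh, P("x", None)))
b_replicated = jax.device_put(b, NamedSharding(mesh, P()))

@jax.jit
def matmul(a, b):
    # The compiler partitions the computation based on the input shardings.
    return a @ b

c = matmul(a_sharded, b_replicated)
print(c.shape)     # (8 * n, 16)
print(c.sharding)  # shows how the result is laid out across devices
```

Sharding the rows of the left operand means each device computes its own slice of the output without cross-device communication. Splitting along the contraction dimension instead would require a reduction across devices.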

Answer selected by ivandustin