This is working as intended. When using JAX in multi-controller mode, each process should pass pmap only the arguments for its own local devices. Here, each process has one local device, so the array you pass to pmap should have a leading axis of size 1.

See https://jax.readthedocs.io/en/latest/multi_process.html
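As a rough sketch of what that looks like in code (not your exact setup): the snippet below sizes the leading axis of the pmap input by `jax.local_device_count()` rather than `jax.device_count()`. The names `local_batch` and `global_sum` and the feature size of 4 are made up for illustration, and a real multi-process job would also need to call `jax.distributed.initialize()` in every process first, which is left out so the snippet also runs in a single process.

```python
import jax
import jax.numpy as jnp

# Assumption: in a real multi-process job, jax.distributed.initialize()
# has already been called in every process before this point.
n_local = jax.local_device_count()   # devices attached to this process
n_global = jax.device_count()        # devices across all processes

# Hypothetical per-process data: the leading axis matches the number of
# LOCAL devices (1 per process in the setup discussed here), not the
# global device count.
local_batch = jnp.ones((n_local, 4))

def global_sum(x):
    # psum reduces over the named pmap axis, which spans all devices
    # taking part in the pmap, including those of other processes.
    return jax.lax.psum(x, axis_name="i")

out = jax.pmap(global_sum, axis_name="i")(local_batch)

# The result is again per-process: one row per local device.
print(out.shape, n_local, n_global)
```

Each process only ever sees its own slice of the input and output, while the collective inside `pmap` still communicates across every device in the job.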

Hope that helps!

Answer selected by JiahaoYao