Use `jax.lax.all_gather_invariant` if you want `out_specs` to be `P()`.

If you want to use `all_gather`, then your `out_specs` should be `P(None, feats)`.

@AakashKumarNain

@yashk2810

Answer selected by AakashKumarNain
@AakashKumarNain

Category: Q&A
2 participants