I got @pschuh to help me think on this one.

We recently added an experimental (intentionally not-yet-documented) option to vmap via the keyword argument spmd_axis_name that might be useful here. See #11807. What do you think?
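For concreteness, here is a minimal sketch of how that keyword is passed. It assumes `"data"` is the name of the mesh axis you shard the batch over; that name, and the helper `dropout_mask`, are made up for illustration. No mesh or sharding constraint is set up here, so this only shows the call shape and per-example `random.bernoulli` under `vmap`. As far as I understand it, the keyword only matters once sharding constraints appear inside the mapped function.

```python
import jax
import jax.numpy as jnp
from jax import random


def dropout_mask(key, x, rate=0.5):
    # Hypothetical per-example function: draws a Bernoulli keep-mask,
    # i.e. an array that is created inside the computation.
    keep = random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)


batch = jnp.ones((8, 4))
keys = random.split(random.PRNGKey(0), batch.shape[0])

# "data" is an assumed mesh axis name, not something taken from the thread.
# spmd_axis_name is experimental and may change.
batched = jax.vmap(dropout_mask, spmd_axis_name="data")
out = batched(keys, batch)
print(out.shape)
```

Because the option is experimental and undocumented, treat this as a starting point to try against your own mesh rather than a confirmed recipe.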

> Is there an easy way to see what decisions the partitioner is making for arrays created deep inside the computation, specifically random.bernoulli?

In general, you can't inspect pjit's sharding specs because they're applied only downstream, at compilation time.

> should I just be using xmap for the batching and use pjit for the model partitioning? I think this would ensure that the batch axis is always sharded, even if it's a bit messy.

We're fairly certain that xm…

pschuh Aug 11, 2022
Collaborator

Answer selected by dlwh
Category
Q&A
3 participants