
You need to use jax.lax.with_sharding_constraint if you want to constrain the shardings of intermediates inside jit.
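
A minimal sketch of constraining an intermediate inside `jit`, assuming a 1-D mesh built over whatever devices are available. The mesh axis name `"x"` and the function `f` are illustrative choices, not taken from the thread. On a single CPU device the constraint is trivially satisfied; with more devices the rows of the intermediate are split across them.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 1-D mesh over all available devices; axis name "x" is an assumption.
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

@jax.jit
def f(a):
    b = jnp.sin(a) * 2.0
    # Constrain the intermediate `b`: shard its first axis over mesh axis "x".
    b = jax.lax.with_sharding_constraint(b, NamedSharding(mesh, P("x")))
    return b.sum(axis=1)

n = len(jax.devices())
a = jnp.ones((4 * n, 8))  # first dim divisible by the number of devices
out = f(a)
```

The constraint only affects how XLA lays out `b`; the computed values are the same as without it.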

For inputs and outputs, you can pass in_shardings or out_shardings to jit, or shard the inputs before you pass them to the jitted function (I would recommend the latter).
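
A sketch of both options, assuming the same kind of 1-D mesh. The axis name `"x"`, the function `g`, and the specific shardings chosen (rows split on input, fully replicated output) are illustrative assumptions. Option 1 passes `in_shardings`/`out_shardings` to `jax.jit`; option 2 places the argument with `jax.device_put` before the call, so the jitted function picks up the sharding already attached to its input.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative mesh and shardings; the names here are assumptions.
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
row_sharding = NamedSharding(mesh, P("x"))  # split rows across devices
replicated = NamedSharding(mesh, P())       # full copy on every device

def g(a):
    return a @ a.T

# Option 1: declare input/output shardings on jit itself.
g_jit = jax.jit(g, in_shardings=row_sharding, out_shardings=replicated)

# Option 2: shard the input up front and let jit use the argument's sharding.
g_plain = jax.jit(g)

n = len(jax.devices())
a = np.arange(4 * n * 3, dtype=np.float32).reshape(4 * n, 3)
a_sharded = jax.device_put(a, row_sharding)

out1 = g_jit(a)
out2 = g_plain(a_sharded)
```

Both calls compute the same values; they differ only in where the sharding decision is made.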

We should make device_put with sharding specifications an error when run under JIT.

Yeah, this is in the cards. It's just a semantics-breaking change, so we need to do it carefully. There have been lots of discussions about merging device_put and with_sharding_constraint, but we decided to keep them separate. We will revisit that soon as part of the memories work.

Answer selected by pharringtonp19
Category: Q&A