auto parallelize #19670
Hi - the issue is that …
You need to use `jax.lax.with_sharding_constraint` if you want to constrain the shardings of intermediates inside `jit`. For inputs/outputs, you can use `in_shardings` or `out_shardings`, or shard the inputs before you pass them to the jitted function (I would recommend doing this).

Yeah, this is in the cards. It's just a semantics-breaking change, so we need to do it carefully. There have been lots of discussions about merging `device_put` and `with_sharding_constraint`, but we decided to keep them separate. We will revisit that soon with the memories work.
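The three options from the reply can be sketched roughly as follows. This is a minimal sketch, assuming a 1-D mesh over every visible device. The axis name `"data"`, the function `step_fn`, and the array shapes are hypothetical and chosen for illustration; they are not from the thread. In this sketch, `device_put` shards a concrete array before `jit`, `with_sharding_constraint` annotates an intermediate inside the traced function, and `in_shardings`/`out_shardings` are declared on `jit` itself.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1-D mesh over every visible device; the axis name "data" is arbitrary.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Rows split across the "data" axis, columns replicated.
row_sharding = NamedSharding(mesh, P("data", None))
# Fully replicated on every device.
replicated = NamedSharding(mesh, P())

# Illustrative input whose leading dimension divides evenly by the device count.
n = len(devices)
x = jnp.arange(n * 4 * 8, dtype=jnp.float32).reshape(n * 4, 8)


def step_fn(a):
    # Hypothetical computation; only the sharding calls matter here.
    h = jnp.tanh(a) * 2.0
    # Inside jit: constrain the sharding of the intermediate `h`.
    h = jax.lax.with_sharding_constraint(h, row_sharding)
    return h.sum(axis=0)


# Option recommended in the reply: shard the input before calling the jitted function.
x_sharded = jax.device_put(x, row_sharding)
step = jax.jit(step_fn)
out = step(x_sharded)

# Alternative: declare input/output shardings on jit itself.
step_explicit = jax.jit(
    step_fn,
    in_shardings=(row_sharding,),  # one entry per positional argument
    out_shardings=replicated,
)
out_explicit = step_explicit(x_sharded)

print(out_explicit.sharding)
```

On a single-CPU machine the mesh has one device, so everything still runs. To see actual partitioning on CPU, you can set `XLA_FLAGS=--xla_force_host_platform_device_count=8` in the environment before importing `jax`.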