[sharding_in_types] Add slice_p and squeeze_p sharding rules to make flash attention work in the backward pass
For `slice_p`'s sharding rule, I error out if an operand dimension is sharded and the corresponding output dimension's size is not divisible by the size of the mesh axes sharding it.
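A minimal sketch of that divisibility check, not the actual JAX implementation: the function name `_slice_sharding_rule` and its signature are hypothetical, and the mesh-axis lookup assumes a `NamedSharding` whose spec entries are `None`, a mesh axis name, or a tuple of axis names.

```python
# Hypothetical sketch of the divisibility check described above; names and
# signature are illustrative, not JAX's actual sharding-rule API.
import numpy as np


def _slice_sharding_rule(operand_sharding, out_shape, mesh):
  """Propagate the operand's sharding to the slice output.

  Errors if a sharded operand dimension would produce an output dimension
  that cannot be split evenly across the mesh axes sharding it.
  """
  for dim, axes in enumerate(operand_sharding.spec):
    if axes is None:
      continue  # this dimension is replicated; nothing to check
    axes = (axes,) if isinstance(axes, str) else tuple(axes)
    # Total number of shards along this dimension.
    axis_size = int(np.prod([mesh.shape[a] for a in axes]))
    if out_shape[dim] % axis_size != 0:
      raise ValueError(
          f"slice output dimension {dim} of size {out_shape[dim]} is not "
          f"divisible by the sharded mesh axis size {axis_size}")
  # Every sharded dimension divides evenly, so the output can keep the
  # operand's sharding unchanged.
  return operand_sharding
```

Under this sketch, once uneven sharding is supported the divisibility check could be dropped and the rule would reduce to returning `operand_sharding` directly, which matches the plan described below.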
I am working on a design to make JAX support uneven sharding at the top level, after which `slice_p`'s sharding rule can simply `return operand.sharding`. Another option is to add `out_sharding` to `slice`, but once uneven sharding support lands, it won't be necessary.
PiperOrigin-RevId: 698522980