@@ -704,9 +704,9 @@ HloSharding TransposeSharding(const HloSharding& sharding,
704704
705705std::optional<HloSharding> ReshapeSharding (const Shape& source_shape,
706706 const Shape& target_shape,
707- const HloSharding& sharding ) {
708- if (sharding .IsTileMaximal () || sharding .IsManual ()) {
709- return sharding ;
707+ const HloSharding& source_sharding ) {
708+ if (source_sharding .IsTileMaximal () || source_sharding .IsManual ()) {
709+ return source_sharding ;
710710 }
711711
712712 // In case of a tiled sharding, the reshaped sharding will be valid if the
@@ -732,10 +732,24 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
732732 DimensionVector target_dims_stack (target_shape.dimensions ().rbegin (),
733733 target_shape.dimensions ().rend ());
734734 DimensionVector sharding_tile_dims_stack (
735- sharding.tile_assignment ().dimensions ().begin (),
736- sharding.tile_assignment ().dimensions ().begin () + source_shape.rank ());
735+ source_sharding.tile_assignment ().dimensions ().begin (),
736+ source_sharding.tile_assignment ().dimensions ().begin () +
737+ source_shape.rank ());
737738 std::reverse (sharding_tile_dims_stack.begin (),
738739 sharding_tile_dims_stack.end ());
740+ int64_t source_dims_index = -1 ;
741+ std::vector<int64_t > dims_to_replicate;
742+
743+ auto source_dims_push = [&](int64_t shape_size, int64_t partitions) {
744+ source_dims_stack.push_back (shape_size);
745+ sharding_tile_dims_stack.push_back (partitions);
746+ source_dims_index--;
747+ };
748+ auto source_dims_pop = [&]() {
749+ source_dims_stack.pop_back ();
750+ sharding_tile_dims_stack.pop_back ();
751+ source_dims_index++;
752+ };
739753
740754 bool inplace_add_sharding_dim = false ;
741755 auto append_sharding_dim = [&](int64_t size) {
@@ -753,22 +767,20 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
753767 break ;
754768 }
755769
756- int64_t source_dim_product = 1 ;
770+ int64_t source_dims_product = 1 ;
757771 while (!sharding_tile_dims_stack.empty () &&
758772 sharding_tile_dims_stack.back () == 1 ) {
759- sharding_tile_dims_stack.pop_back ();
760- source_dim_product *= source_dims_stack.back ();
761- source_dims_stack.pop_back ();
773+ source_dims_product *= source_dims_stack.back ();
774+ source_dims_pop ();
762775 }
763776 while (!target_dims_stack.empty () && target_dims_stack.back () > 1 &&
764- source_dim_product % target_dims_stack.back () == 0 ) {
765- source_dim_product /= target_dims_stack.back ();
777+ source_dims_product % target_dims_stack.back () == 0 ) {
778+ source_dims_product /= target_dims_stack.back ();
766779 target_dims_stack.pop_back ();
767780 append_sharding_dim (1 );
768781 }
769- if (source_dim_product != 1 ) {
770- source_dims_stack.push_back (source_dim_product);
771- sharding_tile_dims_stack.push_back (1 );
782+ if (source_dims_product != 1 ) {
783+ source_dims_push (source_dims_product, 1 );
772784 }
773785
774786 if (target_dims_stack.empty ()) {
@@ -781,9 +793,8 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
781793 int64_t s_partitions = 1 ;
782794 if (!source_dims_stack.empty ()) {
783795 s_size = source_dims_stack.back ();
784- source_dims_stack.pop_back ();
785796 s_partitions = sharding_tile_dims_stack.back ();
786- sharding_tile_dims_stack. pop_back ();
797+ source_dims_pop ();
787798 }
788799
789800 if (s_size == t_size) {
@@ -793,19 +804,20 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
793804 t_size % s_partitions == 0 ) {
794805 // If s_partitions evenly divides both s_size and t_size, we can add this
795806 // sharding dim and work on shard sized shapes in the next iteration.
796- source_dims_stack. push_back (s_size / s_partitions);
807+ source_dims_push (s_size / s_partitions, 1 );
797808 target_dims_stack.push_back (t_size / s_partitions);
798- sharding_tile_dims_stack.push_back (1 );
799809 append_sharding_dim (s_partitions);
800810 inplace_add_sharding_dim = true ;
801811 } else if (t_size == 1 ) {
802812 // Trivial dimension added.
803813 append_sharding_dim (1 );
804- source_dims_stack.push_back (s_size);
805- sharding_tile_dims_stack.push_back (s_partitions);
814+ source_dims_push (s_size, s_partitions);
806815 } else if (s_size == 1 ) {
807816 // Trivial dimension removed.
808817 target_dims_stack.push_back (t_size);
818+ if (s_partitions > 1 ) {
819+ dims_to_replicate.push_back (source_dims_index);
820+ }
809821 } else if (s_size > t_size) {
810822 // Dimension split.
811823 if (s_size % s_partitions != 0 ) {
@@ -819,13 +831,11 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
819831 if (t_size % s_partitions == 0 ) {
820832 append_sharding_dim (s_partitions);
821833 // We have part of the s_size unprocessed, so put it back to stack.
822- source_dims_stack.push_back (s_size / t_size);
823- sharding_tile_dims_stack.push_back (1 );
834+ source_dims_push (s_size / t_size, 1 );
824835 } else if (s_partitions % t_size == 0 ) {
825836 append_sharding_dim (t_size);
826837 // We have part of the s_size unprocessed, so put it back to stack.
827- source_dims_stack.push_back (s_size / t_size);
828- sharding_tile_dims_stack.push_back (s_partitions / t_size);
838+ source_dims_push (s_size / t_size, s_partitions / t_size);
829839 } else {
830840 append_sharding_dim (std::gcd (t_size, s_partitions));
831841 break ;
@@ -860,6 +870,16 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
860870 while (target_tile_assignment_dimensions.size () < target_shape.rank ()) {
861871 target_tile_assignment_dimensions.push_back (1 );
862872 }
873+
874+ // If there is a source dimension satisfying (1) size is 1, (2) partition > 1,
875+ // and (3) there is no corresponding target dimension, we replicate the source
876+ // sharding along this dimension since the source sharding cannot be
877+ // propagated along this dimension.
878+ const HloSharding sharding = !dims_to_replicate.empty ()
879+ ? PartiallyReplicateTiledShardingOnDims (
880+ source_sharding, dims_to_replicate)
881+ : source_sharding;
882+
863883 for (int64_t i = sharding.TiledDataRank ();
864884 i < sharding.tile_assignment ().num_dimensions (); ++i) {
865885 target_tile_assignment_dimensions.push_back (
0 commit comments