Skip to content

Commit 6345728

Browse files
ZixuanJiang authored and Google-ML-Automation committed
[XLA:SPMD] Fix ReshapeSharding for dimensions of size 1 and >1 partitions.
If there is a source dimension satisfying all of the following conditions, we replicate the source sharding along this dimension, since the source sharding cannot be propagated along it: (1) its size is 1, (2) its number of partitions is greater than 1, and (3) there is no corresponding target dimension. An example is shown below. Please refer to the added examples in hlo_sharding_util_test.cc. ``` input shape: [1,2,16] input sharding: [3,2,2]<=[12] output shape: [2,16] output sharding before this cl: [2,2,3]<=[12] last_tile_dim_replicate output sharding with this cl: [2,2,3]<=[3,2,2]T(1,2,0) last_tile_dim_replicate ``` PiperOrigin-RevId: 681514767
1 parent 592e214 commit 6345728

File tree

3 files changed

+104
-25
lines changed

3 files changed

+104
-25
lines changed

xla/hlo/utils/hlo_sharding_util.cc

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -704,9 +704,9 @@ HloSharding TransposeSharding(const HloSharding& sharding,
704704

705705
std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
706706
const Shape& target_shape,
707-
const HloSharding& sharding) {
708-
if (sharding.IsTileMaximal() || sharding.IsManual()) {
709-
return sharding;
707+
const HloSharding& source_sharding) {
708+
if (source_sharding.IsTileMaximal() || source_sharding.IsManual()) {
709+
return source_sharding;
710710
}
711711

712712
// In case of a tiled sharding, the reshaped sharding will be valid if the
@@ -732,10 +732,24 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
732732
DimensionVector target_dims_stack(target_shape.dimensions().rbegin(),
733733
target_shape.dimensions().rend());
734734
DimensionVector sharding_tile_dims_stack(
735-
sharding.tile_assignment().dimensions().begin(),
736-
sharding.tile_assignment().dimensions().begin() + source_shape.rank());
735+
source_sharding.tile_assignment().dimensions().begin(),
736+
source_sharding.tile_assignment().dimensions().begin() +
737+
source_shape.rank());
737738
std::reverse(sharding_tile_dims_stack.begin(),
738739
sharding_tile_dims_stack.end());
740+
int64_t source_dims_index = -1;
741+
std::vector<int64_t> dims_to_replicate;
742+
743+
auto source_dims_push = [&](int64_t shape_size, int64_t partitions) {
744+
source_dims_stack.push_back(shape_size);
745+
sharding_tile_dims_stack.push_back(partitions);
746+
source_dims_index--;
747+
};
748+
auto source_dims_pop = [&]() {
749+
source_dims_stack.pop_back();
750+
sharding_tile_dims_stack.pop_back();
751+
source_dims_index++;
752+
};
739753

740754
bool inplace_add_sharding_dim = false;
741755
auto append_sharding_dim = [&](int64_t size) {
@@ -753,22 +767,20 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
753767
break;
754768
}
755769

756-
int64_t source_dim_product = 1;
770+
int64_t source_dims_product = 1;
757771
while (!sharding_tile_dims_stack.empty() &&
758772
sharding_tile_dims_stack.back() == 1) {
759-
sharding_tile_dims_stack.pop_back();
760-
source_dim_product *= source_dims_stack.back();
761-
source_dims_stack.pop_back();
773+
source_dims_product *= source_dims_stack.back();
774+
source_dims_pop();
762775
}
763776
while (!target_dims_stack.empty() && target_dims_stack.back() > 1 &&
764-
source_dim_product % target_dims_stack.back() == 0) {
765-
source_dim_product /= target_dims_stack.back();
777+
source_dims_product % target_dims_stack.back() == 0) {
778+
source_dims_product /= target_dims_stack.back();
766779
target_dims_stack.pop_back();
767780
append_sharding_dim(1);
768781
}
769-
if (source_dim_product != 1) {
770-
source_dims_stack.push_back(source_dim_product);
771-
sharding_tile_dims_stack.push_back(1);
782+
if (source_dims_product != 1) {
783+
source_dims_push(source_dims_product, 1);
772784
}
773785

774786
if (target_dims_stack.empty()) {
@@ -781,9 +793,8 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
781793
int64_t s_partitions = 1;
782794
if (!source_dims_stack.empty()) {
783795
s_size = source_dims_stack.back();
784-
source_dims_stack.pop_back();
785796
s_partitions = sharding_tile_dims_stack.back();
786-
sharding_tile_dims_stack.pop_back();
797+
source_dims_pop();
787798
}
788799

789800
if (s_size == t_size) {
@@ -793,19 +804,20 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
793804
t_size % s_partitions == 0) {
794805
// If s_partitions evenly divides both s_size and t_size, we can add this
795806
// sharding dim and work on shard sized shapes in the next iteration.
796-
source_dims_stack.push_back(s_size / s_partitions);
807+
source_dims_push(s_size / s_partitions, 1);
797808
target_dims_stack.push_back(t_size / s_partitions);
798-
sharding_tile_dims_stack.push_back(1);
799809
append_sharding_dim(s_partitions);
800810
inplace_add_sharding_dim = true;
801811
} else if (t_size == 1) {
802812
// Trivial dimension added.
803813
append_sharding_dim(1);
804-
source_dims_stack.push_back(s_size);
805-
sharding_tile_dims_stack.push_back(s_partitions);
814+
source_dims_push(s_size, s_partitions);
806815
} else if (s_size == 1) {
807816
// Trivial dimension removed.
808817
target_dims_stack.push_back(t_size);
818+
if (s_partitions > 1) {
819+
dims_to_replicate.push_back(source_dims_index);
820+
}
809821
} else if (s_size > t_size) {
810822
// Dimension split.
811823
if (s_size % s_partitions != 0) {
@@ -819,13 +831,11 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
819831
if (t_size % s_partitions == 0) {
820832
append_sharding_dim(s_partitions);
821833
// We have part of the s_size unprocessed, so put it back to stack.
822-
source_dims_stack.push_back(s_size / t_size);
823-
sharding_tile_dims_stack.push_back(1);
834+
source_dims_push(s_size / t_size, 1);
824835
} else if (s_partitions % t_size == 0) {
825836
append_sharding_dim(t_size);
826837
// We have part of the s_size unprocessed, so put it back to stack.
827-
source_dims_stack.push_back(s_size / t_size);
828-
sharding_tile_dims_stack.push_back(s_partitions / t_size);
838+
source_dims_push(s_size / t_size, s_partitions / t_size);
829839
} else {
830840
append_sharding_dim(std::gcd(t_size, s_partitions));
831841
break;
@@ -860,6 +870,16 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
860870
while (target_tile_assignment_dimensions.size() < target_shape.rank()) {
861871
target_tile_assignment_dimensions.push_back(1);
862872
}
873+
874+
// If there is a source dimension satisfying (1) size is 1, (2) partition > 1,
875+
// and (3) there is no corresponding target dimension, we replicate the source
876+
// sharding along this dimension since the source sharding cannot be
877+
// propagated along this dimension.
878+
const HloSharding sharding = !dims_to_replicate.empty()
879+
? PartiallyReplicateTiledShardingOnDims(
880+
source_sharding, dims_to_replicate)
881+
: source_sharding;
882+
863883
for (int64_t i = sharding.TiledDataRank();
864884
i < sharding.tile_assignment().num_dimensions(); ++i) {
865885
target_tile_assignment_dimensions.push_back(

xla/hlo/utils/hlo_sharding_util.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ HloSharding TransposeSharding(const HloSharding& sharding,
121121
// maximal sharding returns the original sharding.
122122
std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
123123
const Shape& target_shape,
124-
const HloSharding& sharding);
124+
const HloSharding& source_sharding);
125125

126126
// Propagates sharding through reshape. It tries to find partial matches on
127127
// subsets of dimensions that could satisfy ReshapeSharding() constraints, then

xla/hlo/utils/hlo_sharding_util_test.cc

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,65 @@ TEST(HloShardingUtilTest, TransposeShardingWithCollapsedDimsSubgroupManual) {
132132
output);
133133
}
134134

135+
TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned1) {
136+
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 16});
137+
Shape output_shape = ShapeUtil::MakeShape(F32, {2, 16});
138+
HloSharding input_sharding = HloSharding::IotaTile({3, 2, 2});
139+
HloSharding output_sharding =
140+
HloSharding::PartialTile(TileAssignment({2, 2, 3}, {3, 2, 2}, {1, 2, 0}));
141+
std::optional<HloSharding> result =
142+
ReshapeSharding(input_shape, output_shape, input_sharding);
143+
EXPECT_TRUE(result.has_value());
144+
EXPECT_EQ(result.value(), output_sharding);
145+
}
146+
147+
TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned2) {
148+
Shape input_shape = ShapeUtil::MakeShape(F32, {2, 1, 16});
149+
Shape output_shape = ShapeUtil::MakeShape(F32, {2, 16});
150+
HloSharding input_sharding = HloSharding::IotaTile({2, 3, 2});
151+
HloSharding output_sharding =
152+
HloSharding::PartialTile(TileAssignment({2, 2, 3}, {2, 3, 2}, {0, 2, 1}));
153+
std::optional<HloSharding> result =
154+
ReshapeSharding(input_shape, output_shape, input_sharding);
155+
EXPECT_TRUE(result.has_value());
156+
EXPECT_EQ(result.value(), output_sharding);
157+
}
158+
159+
TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned3) {
160+
Shape input_shape = ShapeUtil::MakeShape(F32, {2, 1, 16});
161+
Shape output_shape = ShapeUtil::MakeShape(F32, {32});
162+
HloSharding input_sharding = HloSharding::IotaTile({2, 3, 2});
163+
HloSharding output_sharding =
164+
HloSharding::PartialTile(TileAssignment({4, 3}, {2, 3, 2}, {0, 2, 1}));
165+
std::optional<HloSharding> result =
166+
ReshapeSharding(input_shape, output_shape, input_sharding);
167+
EXPECT_TRUE(result.has_value());
168+
EXPECT_EQ(result.value(), output_sharding);
169+
}
170+
171+
TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned4) {
172+
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 32});
173+
Shape output_shape = ShapeUtil::MakeShape(F32, {2, 16});
174+
HloSharding input_sharding = HloSharding::IotaTile({3, 4});
175+
HloSharding output_sharding =
176+
HloSharding::PartialTile(TileAssignment({2, 2, 3}, {3, 4}, {1, 0}));
177+
std::optional<HloSharding> result =
178+
ReshapeSharding(input_shape, output_shape, input_sharding);
179+
EXPECT_TRUE(result.has_value());
180+
EXPECT_EQ(result.value(), output_sharding);
181+
}
182+
183+
TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned5) {
184+
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 32});
185+
Shape output_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 16});
186+
HloSharding input_sharding = HloSharding::IotaTile({2, 3, 4});
187+
HloSharding output_sharding = HloSharding::IotaTile({2, 3, 2, 2});
188+
std::optional<HloSharding> result =
189+
ReshapeSharding(input_shape, output_shape, input_sharding);
190+
EXPECT_TRUE(result.has_value());
191+
EXPECT_EQ(result.value(), output_sharding);
192+
}
193+
135194
TEST(HloShardingUtilTest, ReshapeShardingMaximal) {
136195
Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5});
137196
Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 2});

0 commit comments

Comments (0)