
Commit 592e214

This CL correctly computes the tensor dim to mesh axis mapping for mixed mesh strategies when computing resharding costs that involve such strategies.
PiperOrigin-RevId: 681513792
1 parent 93be085 commit 592e214
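
For orientation, an illustration of what "tensor dim to mesh axis mapping" means here (this snippet is not part of the commit; the mesh shape is a hypothetical example): without mixed meshes, each tensor dimension maps to at most one mesh dimension, with -1 meaning replicated. Under a mixed mesh strategy a single tensor dimension can be tiled across several mesh axes, so the mapping becomes a set of axes per tensor dimension.

#include <cstdint>
#include <set>
#include <vector>

int main() {
  // Single-mesh-axis view: one mesh dim per tensor dim, -1 means replicated.
  std::vector<int64_t> single_axis_mapping = {0, -1};

  // Mixed-mesh view (the mapping this commit computes when estimating
  // resharding costs): on a hypothetical [2, 2, 4] mesh with a [4, 4] tile
  // assignment, tensor dim 0 is tiled 4-ways across mesh axes {0, 1} and
  // tensor dim 1 is tiled 4-ways across mesh axis {2}.
  std::vector<std::set<int64_t>> mixed_mesh_mapping = {{0, 1}, {2}};

  (void)single_axis_mapping;
  (void)mixed_mesh_mapping;
  return 0;
}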

6 files changed: +142 -27 lines changed

xla/hlo/experimental/auto_sharding/BUILD

Lines changed: 3 additions & 1 deletion
@@ -261,10 +261,12 @@ cc_library(
         ":auto_sharding_strategy",
         ":auto_sharding_util",
         ":profiling_result",
-        "//xla:array",
         "//xla:shape_util",
         "//xla/hlo/ir:hlo",
         "//xla/service/spmd:spmd_partitioner",
+        "@com_google_absl//absl/container:btree",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
     ],

xla/hlo/experimental/auto_sharding/auto_sharding_test.cc

Lines changed: 7 additions & 1 deletion
@@ -272,6 +272,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape2DAllOptions) {
   option.device_mesh_ids = {0, 1, 2, 3};
   option.device_mesh_alpha = {1.0, 1.0};
   option.device_mesh_beta = {0.01, 1.0};
+  option.allow_mixed_mesh_shape = false;
   RunMatMulAutoShardingWithOptions(option, 4, 2);

   option.enable = true;

@@ -288,6 +289,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape2DNoAlphaBeta) {
   option.enable = true;
   option.device_mesh_shape = {2, 2};
   option.device_mesh_ids = {0, 1, 2, 3};
+  option.allow_mixed_mesh_shape = false;
   RunMatMulAutoShardingWithOptions(option, 4, 2);

   option.enable = true;

@@ -304,6 +306,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape2DNoAlphaBetaMeshIds) {
   AutoShardingOption option;
   option.enable = true;
   option.device_mesh_shape = {2, 2};
+  option.allow_mixed_mesh_shape = false;
   RunMatMulAutoShardingWithOptions(option, 4, 2);

   option.enable = true;

@@ -322,6 +325,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape2DNoMeshIds) {
   option.device_mesh_shape = {2, 2};
   option.device_mesh_alpha = {1.0, 1.0};
   option.device_mesh_beta = {0.01, 1.0};
+  option.allow_mixed_mesh_shape = false;
   RunMatMulAutoShardingWithOptions(option, 4, 2);

   option.enable = true;

@@ -349,6 +353,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape3DAllOptions) {
 TEST_F(AutoShardingTest, Matmul3DMeshShape2DSharding) {
   AutoShardingOption option;
   option.enable = true;
+  option.allow_mixed_mesh_shape = false;
   option.device_mesh_shape = {1, 2, 2};
   RunMatMulAutoShardingWithOptions(option, 4, 2);

@@ -458,7 +463,7 @@ TEST_F(AutoShardingTest, LargeSize) {
   option.device_mesh_alpha = {1.0, 1.0, 1.0, 1.0};
   option.device_mesh_beta = {1.0, 1.0, 1.0, 1.0};
   option.memory_budget_per_device = (8192 + 8192 * 2 + 8192 * 4 / 8);
-  RunMatMulAutoShardingWithOptions(option, 7, 1);
+  RunMatMulAutoShardingWithOptions(option, 56, 1);
 }

 TEST_F(AutoShardingTest, InvalidOptions) {

@@ -716,6 +721,7 @@ ENTRY %elementwise {
       .enable = true,
       .preserve_shardings =
           AutoShardingOption::PreserveShardingsType::kKeepAllShardings,
+      .allow_mixed_mesh_shape = false,
       .only_allow_divisible_input_output = false,
       .device_mesh_shape = {16, 16},
       .device_mesh_alpha = {1.0, 1.0},

xla/hlo/experimental/auto_sharding/auto_sharding_util.cc

Lines changed: 49 additions & 0 deletions
@@ -1193,6 +1193,55 @@ absl::StatusOr<std::vector<int64_t>> GetTensorDimToMeshDimNoCrash(
   return tensor_dim_to_device_dim;
 }

+absl::StatusOr<std::vector<absl::btree_set<int64_t>>>
+GetTensorDimToMeshDimMixedMeshSharding(int64_t tensor_shape_rank,
+                                       const HloSharding& sharding,
+                                       const DeviceMesh& device_mesh,
+                                       bool consider_reverse_device_meshes) {
+  CHECK(!sharding.IsReplicated());
+  // Check the compatibility of tensor_shape_rank and spec
+  if (tensor_shape_rank != sharding.TiledDataRank()) {
+    return absl::InvalidArgumentError(
+        "Tensor shape rank should be equal to the tiled data rank of the input "
+        "spec.");
+  }
+  if (!TileAssignmentMatchesMesh(sharding, device_mesh)) {
+    return absl::InvalidArgumentError(
+        "Device mesh and tile assignment need to have the same number of "
+        "sharded dims.");
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<int64_t> axes,
+      GetMeshDimPermutationOrderInShardingSpec(sharding, device_mesh,
+                                               consider_reverse_device_meshes));
+
+  std::vector<absl::btree_set<int64_t>> tensor_dim_to_mesh_axis_mapping;
+  int mesh_axis_idx = 0;
+  for (int i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
+    if (sharding.tile_assignment().dim(i) == 1) {
+      tensor_dim_to_mesh_axis_mapping.push_back({});
+      continue;
+    }
+
+    absl::btree_set<int64_t> mesh_axes_for_this_tensor_dim;
+    int product = 1;
+    do {
+      if (mesh_axis_idx >= device_mesh.num_dimensions()) {
+        return absl::InternalError(
+            "Mismatched mesh shapes encountered. This can happen when the "
+            "sharding does not map well to the mesh shape provided");
+      }
+      product *= device_mesh.dim(axes[mesh_axis_idx]);
+      mesh_axes_for_this_tensor_dim.insert(axes[mesh_axis_idx]);
+      mesh_axis_idx++;
+    } while (product < sharding.tile_assignment().dim(i));
+    CHECK(!mesh_axes_for_this_tensor_dim.empty());
+    tensor_dim_to_mesh_axis_mapping.push_back(mesh_axes_for_this_tensor_dim);
+  }
+  return tensor_dim_to_mesh_axis_mapping;
+}
+
 std::vector<int64_t> GetTensorDimToMeshDim(
     int64_t tensor_shape_rank, const HloSharding& spec,
     const DeviceMesh& device_mesh, bool consider_reverse_device_meshes) {
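
A standalone sketch of the axis-accumulation loop in the new helper above, using only standard-library types in place of HloSharding and DeviceMesh (the function name, axis order, and example mesh below are illustrative assumptions, not XLA API): walk the mesh axes in the order implied by the sharding spec and keep absorbing axes into the current tensor dimension until the product of their sizes reaches that dimension's shard count.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <optional>
#include <set>
#include <vector>

// Sketch: map each tiled tensor dimension to the set of mesh axes it spans.
// `mesh_dims` holds the mesh axis sizes, `axes` is the axis order implied by
// the sharding spec, `tile_dims` holds the per-tensor-dim shard counts.
std::optional<std::vector<std::set<int64_t>>> TensorDimToMeshAxes(
    const std::vector<int64_t>& mesh_dims, const std::vector<int64_t>& axes,
    const std::vector<int64_t>& tile_dims) {
  std::vector<std::set<int64_t>> mapping;
  std::size_t mesh_axis_idx = 0;
  for (int64_t tile_dim : tile_dims) {
    if (tile_dim == 1) {
      mapping.push_back({});  // This tensor dimension is not sharded.
      continue;
    }
    std::set<int64_t> axes_for_dim;
    int64_t product = 1;
    do {
      // Ran out of mesh axes: the sharding does not fit the mesh shape.
      if (mesh_axis_idx >= axes.size()) return std::nullopt;
      product *= mesh_dims[axes[mesh_axis_idx]];
      axes_for_dim.insert(axes[mesh_axis_idx]);
      ++mesh_axis_idx;
    } while (product < tile_dim);
    mapping.push_back(axes_for_dim);
  }
  return mapping;
}

int main() {
  // Hypothetical [2, 2, 4] mesh, axis order {0, 1, 2}, tensor tiled [4, 4]:
  // tensor dim 0 -> mesh axes {0, 1}, tensor dim 1 -> mesh axis {2}.
  auto mapping = TensorDimToMeshAxes({2, 2, 4}, {0, 1, 2}, {4, 4});
  for (const auto& axes : *mapping) {
    for (int64_t axis : axes) std::cout << axis << ' ';
    std::cout << '\n';
  }
  return 0;
}

This is the kind of per-dimension axis set that GetTensorDimToMeshDimMixedMeshSharding returns; the real helper additionally validates the sharding against the mesh and derives the axis order via GetMeshDimPermutationOrderInShardingSpec.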

xla/hlo/experimental/auto_sharding/auto_sharding_util.h

Lines changed: 10 additions & 0 deletions
@@ -24,6 +24,7 @@ limitations under the License.
 #include <vector>

 #include "absl/algorithm/container.h"
+#include "absl/container/btree_set.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/functional/function_ref.h"

@@ -472,6 +473,15 @@ absl::StatusOr<int64_t> CheckArithmeticSequence(
 // device mesh.
 bool TileAssignmentMatchesMesh(const HloSharding& spec, const DeviceMesh& mesh);

+absl::StatusOr<std::vector<int64_t>> GetMeshDimPermutationOrderInShardingSpec(
+    const HloSharding& spec, const Array<int64_t>& device_mesh,
+    bool consider_reverse_device_meshes);
+
+absl::StatusOr<std::vector<absl::btree_set<int64_t>>>
+GetTensorDimToMeshDimMixedMeshSharding(
+    int64_t tensor_shape_rank, const HloSharding& sharding,
+    const DeviceMesh& device_mesh, bool consider_reverse_device_meshes = false);
+
 // Get the mapped mesh dimension for every tensor dimension.
 // The returned value maps ith tensor dim to one mesh dim. -1 means the tensor
 // is replicated on that dimension.

xla/hlo/experimental/auto_sharding/cluster_environment.cc

Lines changed: 70 additions & 22 deletions
@@ -16,16 +16,21 @@ limitations under the License.
 #include "xla/hlo/experimental/auto_sharding/cluster_environment.h"

 #include <algorithm>
+#include <cmath>
 #include <cstddef>
 #include <cstdint>
 #include <optional>
 #include <string>
 #include <utility>
 #include <vector>

+#include "absl/container/btree_set.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/log/check.h"
 #include "absl/types/span.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h"
+#include "xla/hlo/ir/hlo_sharding.h"
 #include "xla/service/spmd/spmd_partitioner_util.h"
 #include "xla/shape.h"

@@ -121,35 +126,79 @@ double ClusterEnvironment::AllToAllCost(double num_bytes, int mesh_dim) const {
   return AllToAllCostUtil(num_bytes, mesh_dim, num_devices);
 }

+template <typename T>
+bool IsSubset(absl::btree_set<T> superset, absl::btree_set<T> subset) {
+  for (const T& element : subset) {
+    if (!superset.contains(element)) {
+      return false;
+    }
+  }
+  return true;
+}
+
 // Do not consider device id changes yet.
 double ClusterEnvironment::ReshardingCostMixedMeshShape(
-    const Shape& shape, absl::Span<const int64_t> src_tensor_dim_to_mesh_dim,
-    absl::Span<const int64_t> dst_tensor_dim_to_mesh_dim) const {
+    const Shape& shape, const HloSharding& src_sharding,
+    const HloSharding& dst_sharding) const {
+  absl::StatusOr<std::vector<absl::btree_set<int64_t>>>
+      src_tensor_dim_to_mesh_axis = GetTensorDimToMeshDimMixedMeshSharding(
+          shape.rank(), src_sharding, device_mesh_,
+          /*consider_reverse_device_meshes=*/true);
+  absl::StatusOr<std::vector<absl::btree_set<int64_t>>>
+      dst_tensor_dim_to_mesh_axis = GetTensorDimToMeshDimMixedMeshSharding(
+          shape.rank(), dst_sharding, device_mesh_,
+          /*consider_reverse_device_meshes=*/true);
+  if (!src_tensor_dim_to_mesh_axis.ok() || !dst_tensor_dim_to_mesh_axis.ok()) {
+    return OverestimateReplicationCost(shape, src_sharding, device_mesh_);
+  }
+
   int64_t num_devices = device_mesh_.num_elements();
-  double resharding_costs = 0.0;
+  std::vector<int64_t> collective_mesh_axes;
+  // Only consider sharded dimensions, do not consider replicate_on_last_dim.
   for (size_t i = 0; i < shape.rank(); ++i) {
-    // Only consider sharded dimensions, do not consider replicate_on_last_dim.
-    if (src_tensor_dim_to_mesh_dim[i] == dst_tensor_dim_to_mesh_dim[i]) {
+    if ((*src_tensor_dim_to_mesh_axis)[i] ==
+        (*dst_tensor_dim_to_mesh_axis)[i]) {
       continue;
     }
-    if (dst_tensor_dim_to_mesh_dim[i] == -1 ||
-        src_tensor_dim_to_mesh_dim[i] == -1) {
-      // AllToAll cost
-      int64_t communication_dim;
-      if (dst_tensor_dim_to_mesh_dim[i] != -1) {
-        communication_dim = dst_tensor_dim_to_mesh_dim[i];
-      } else {
-        communication_dim = src_tensor_dim_to_mesh_dim[i];
-      }
-      int64_t communication_bytes = ByteSizeOfShape(shape);
-      resharding_costs +=
-          AllToAllCostUtil(communication_bytes, communication_dim, num_devices);
-    } else {
+    if (IsSubset((*dst_tensor_dim_to_mesh_axis)[i],
+                 (*src_tensor_dim_to_mesh_axis)[i])) {
+      // do nothing; the src is sharded more than the dest
+      continue;
+    }
+    if (!IsSubset((*src_tensor_dim_to_mesh_axis)[i],
+                  (*dst_tensor_dim_to_mesh_axis)[i])) {
       // Do not support this sharding, assuming it is gonna be very expensive.
-      return kInfinityCost;
+      return OverestimateReplicationCost(shape, src_sharding, device_mesh_);
+    }
+    for (int64_t mesh_dim : (*src_tensor_dim_to_mesh_axis)[i]) {
+      if (!(*dst_tensor_dim_to_mesh_axis)[i].contains(mesh_dim)) {
+        collective_mesh_axes.push_back(mesh_dim);
+      }
     }
   }
-  return resharding_costs;
+
+  auto is_mesh_axis_used_for_dst_sharding = [&](int64_t mesh_dim) {
+    int end = dst_sharding.ReplicateOnLastTileDim()
+                  ? dst_tensor_dim_to_mesh_axis->size() - 1
+                  : dst_tensor_dim_to_mesh_axis->size();
+    for (int i = 0; i < end; ++i) {
+      if ((*dst_tensor_dim_to_mesh_axis)[i].contains(mesh_dim)) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  double resharding_cost = 0.0;
+  int64_t communication_bytes = ByteSizeOfShape(shape);
+  for (int mesh_dim : collective_mesh_axes) {
+    bool used_for_dst_sharding = is_mesh_axis_used_for_dst_sharding(mesh_dim);
+    resharding_cost +=
+        used_for_dst_sharding
+            ? AllToAllCostUtil(communication_bytes, mesh_dim, num_devices)
+            : AllGatherCost(communication_bytes, mesh_dim);
+  }
+  return resharding_cost;
 }

 double ClusterEnvironment::CollectivePermuteCost(

@@ -313,8 +362,7 @@ double ClusterEnvironment::ReshardingCost(const Shape& shape,
       dst_tensor_dim_to_mesh_dim_or.value();

   if (src_n_dim != dst_n_dim && src_n_dim != -1 && dst_n_dim != -1) {
-    return ReshardingCostMixedMeshShape(shape, src_tensor_dim_to_mesh_dim,
-                                        dst_tensor_dim_to_mesh_dim);
+    return ReshardingCostMixedMeshShape(shape, src_spec, dst_spec);
   }

   AdjustTensorMeshDimMapping(src_tensor_dim_to_mesh_dim, src_n_dim);
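
Roughly, the rewritten cost above charges one collective per mesh axis that the source sharding uses for a tensor dimension but the destination does not: an all-to-all if the destination still shards over that axis somewhere, otherwise an all-gather. A minimal standalone sketch of that decision, with placeholder cost callbacks and without the equality, subset, and overestimate early-exits of the real function (names and numbers here are illustrative only):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <set>
#include <vector>

// Each entry of `src`/`dst` maps a tensor dim to the mesh axes it is sharded
// over; `bytes` is the tensor size used for both collectives.
double MixedMeshReshardingCostSketch(
    const std::vector<std::set<int64_t>>& src,
    const std::vector<std::set<int64_t>>& dst, double bytes,
    const std::function<double(double, int64_t)>& all_to_all_cost,
    const std::function<double(double, int64_t)>& all_gather_cost) {
  // Mesh axes the source shards dim i over but the destination does not.
  std::vector<int64_t> moved_axes;
  for (std::size_t i = 0; i < src.size(); ++i) {
    for (int64_t axis : src[i]) {
      if (dst[i].count(axis) == 0) moved_axes.push_back(axis);
    }
  }
  // An axis still used elsewhere in the destination needs an all-to-all;
  // an axis the destination no longer shards over only needs an all-gather.
  auto used_in_dst = [&](int64_t axis) {
    return std::any_of(dst.begin(), dst.end(), [&](const std::set<int64_t>& s) {
      return s.count(axis) > 0;
    });
  };
  double cost = 0.0;
  for (int64_t axis : moved_axes) {
    cost += used_in_dst(axis) ? all_to_all_cost(bytes, axis)
                              : all_gather_cost(bytes, axis);
  }
  return cost;
}

int main() {
  // Hypothetical case: tensor dim 0 goes from mesh axes {0, 1} to {0}; axis 1
  // is unused in the destination, so only an all-gather on axis 1 is charged.
  std::vector<std::set<int64_t>> src = {{0, 1}, {2}};
  std::vector<std::set<int64_t>> dst = {{0}, {2}};
  double cost = MixedMeshReshardingCostSketch(
      src, dst, /*bytes=*/1024.0,
      [](double bytes, int64_t) { return 2.0 * bytes; },   // stub all-to-all
      [](double bytes, int64_t) { return 1.0 * bytes; });  // stub all-gather
  return cost > 0.0 ? 0 : 1;
}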

xla/hlo/experimental/auto_sharding/cluster_environment.h

Lines changed: 3 additions & 3 deletions
@@ -145,9 +145,9 @@ class ClusterEnvironment {

   double AllToAllCost(double num_bytes, int mesh_dim) const;

-  double ReshardingCostMixedMeshShape(
-      const Shape& shape, absl::Span<const int64_t> src_tensor_dim_to_mesh_dim,
-      absl::Span<const int64_t> dst_tensor_dim_to_mesh_dim) const;
+  double ReshardingCostMixedMeshShape(const Shape& shape,
+                                      const HloSharding& src_sharding,
+                                      const HloSharding& dst_sharding) const;

   double CollectivePermuteCost(
       double num_bytes,
