Skip to content

Commit 7ad1d5b

Browse files
authored
[mlir][mesh] removing partial/reduction axes from mesh.sharding (#149805)
[mlir][mesh] Removing partial axes from sharding annotations (discourse 87053)
1 parent 5050a15 commit 7ad1d5b

File tree

13 files changed

+207
-520
lines changed

13 files changed

+207
-520
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ class MeshSharding {
4343
private:
4444
::mlir::FlatSymbolRefAttr mesh;
4545
SmallVector<MeshAxesAttr> split_axes;
46-
SmallVector<MeshAxis> partial_axes;
47-
ReductionKind partial_type = ReductionKind::Sum;
4846
SmallVector<int64_t> static_halo_sizes;
4947
SmallVector<int64_t> static_sharded_dims_offsets;
5048
SmallVector<Value> dynamic_halo_sizes;
@@ -55,17 +53,13 @@ class MeshSharding {
5553
MeshSharding(Value rhs);
5654
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
5755
ArrayRef<MeshAxesAttr> split_axes_,
58-
ArrayRef<MeshAxis> partial_axes_ = {},
59-
ReductionKind partial_type_ = ReductionKind::Sum,
6056
ArrayRef<int64_t> static_halo_sizes_ = {},
6157
ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
6258
ArrayRef<Value> dynamic_halo_sizes_ = {},
6359
ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
6460
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
6561
::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
6662
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
67-
ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
68-
ReductionKind getPartialType() const { return partial_type; }
6963
ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
7064
ArrayRef<int64_t> getStaticShardedDimsOffsets() const {
7165
return static_sharded_dims_offsets;
@@ -79,7 +73,7 @@ class MeshSharding {
7973
bool operator!=(Value rhs) const;
8074
bool operator==(const MeshSharding &rhs) const;
8175
bool operator!=(const MeshSharding &rhs) const;
82-
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
76+
bool equalSplitAxes(const MeshSharding &rhs) const;
8377
bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
8478
bool equalHaloSizes(const MeshSharding &rhs) const;
8579
bool equalShardSizes(const MeshSharding &rhs) const;
@@ -110,10 +104,9 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
110104

111105
// Is the same tensor replicated on all processes.
112106
inline bool isFullReplication(MeshSharding sharding) {
113-
return sharding.getPartialAxes().empty() &&
114-
llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
115-
return axes.asArrayRef().empty();
116-
});
107+
return llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
108+
return axes.asArrayRef().empty();
109+
});
117110
}
118111

119112
inline mesh::MeshOp

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
204204
let description = [{
205205
The MeshSharding specifies how a tensor is sharded and distributed across the
206206
process mesh. It is typically used in a `mesh.shard` operation.
207-
The operation has the follwing attributes and operands:
207+
The operation has the following attributes and operands:
208208

209209
1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
210210
mesh where the distributed tensor is placed. The symbol must resolve to a
@@ -215,23 +215,15 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
215215
its value is [x, y], it indicates that the tensor's i-th dimension is split
216216
along the x and y axes of the device mesh.
217217

218-
3. [Optional] `partial_axes`: if not empty, this signifies that the tensor is partial
219-
one along the specified mesh axes. An all-reduce should be applied to obtain
220-
the complete tensor, with reduction type being specified by `partial_type`.
221-
222-
4. [Optional] `partial_type`: indicates the reduction type of the possible all-reduce
223-
op. It has 4 possible values:
224-
`generic`: is not an allowed value inside a shard attribute.
225-
226-
5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
218+
3. [Optional] Sizes of halos to be added for each sharded tensor dimension.
227219
`halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each
228220
sharded dimension. `halo_sizes = [1, 2]` means that the first sharded dimension
229221
gets an additional halo of size 1 at the start of the first dimension and a halo
230222
of size 2 at its end. `halo_sizes = [1, 2, 2, 3]` defines halos for the first 2
231223
sharded dimensions e.g. the first sharded dimension gets `[1,2]` halos and the
232224
second gets `[2,3]` halos. `?` indicates dynamic halo sizes.
233225

234-
6. [Optional] Offsets for each shard and sharded tensor dimension.
226+
4. [Optional] Offsets for each shard and sharded tensor dimension.
235227
`sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each
236228
sharded tensor dimension the offsets (starting index) of all shards in that
237229
dimension and an additional value for the end of the last shard are provided.
@@ -260,14 +252,6 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
260252
// The tensor is sharded on the first dimension along axis 0 of @mesh0
261253
%sharding1 = mesh.sharding @mesh0 split_axes = [[0]]
262254

263-
// The tensor is sharded on its first dimension along axis 0 of @mesh0 and
264-
// it is also a partial_sum along mesh axis 1.
265-
%sharding2 = mesh.sharding @mesh0 split_axes = [[0] split_axes = []] partial = sum[1]
266-
267-
// The tensor is sharded on its first dimension along axis 0 of @mesh0 and
268-
// it is also a partial_max along mesh axis 1.
269-
%sharding3 = mesh.sharding @mesh0 split_axes = [[0]] partial = max[1]
270-
271255
// Could be used for a mesh.shard op
272256
%sharded0 = mesh.shard %arg0 to %sharding1 : tensor<4x8xf32>
273257

@@ -287,8 +271,6 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
287271
let arguments = (ins
288272
FlatSymbolRefAttr:$mesh,
289273
Mesh_MeshAxesArrayAttr:$split_axes,
290-
OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
291-
OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
292274
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets,
293275
Variadic<I64>:$dynamic_sharded_dims_offsets,
294276
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
@@ -300,20 +282,15 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
300282
let assemblyFormat = [{
301283
$mesh
302284
`split_axes` `=` $split_axes
303-
(`partial` `=` $partial_type $partial_axes^)?
304285
(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
305286
(`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
306287
attr-dict `:` type($result)
307288
}];
308289
let builders = [
309290
OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
310291
"ArrayRef<MeshAxesAttr>":$split_axes,
311-
"ArrayRef<MeshAxis>":$partial_axes,
312-
"mesh::ReductionKind":$partial_type,
313292
CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
314293
CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets)>,
315-
OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
316-
"ArrayRef<MeshAxesAttr>":$split_axes)>,
317294
OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
318295
"ArrayRef<MeshAxesAttr>":$split_axes,
319296
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,

mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -187,37 +187,19 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
187187
return newOperands;
188188
}
189189

190-
static void createAllReduceForResultWithoutPartialSharding(
191-
Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
192-
MeshSharding resultSharding, ReductionKind reductionKind,
193-
IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
194-
SmallVector<MeshAxis> allReduceMeshAxes;
195-
llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
196-
[&resultSharding](MeshAxis axis) {
197-
return !llvm::is_contained(resultSharding.getPartialAxes(),
198-
axis);
199-
});
200-
if (allReduceMeshAxes.empty()) {
201-
return;
202-
}
203-
204-
Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
205-
Value reducedValue = builder.create<mesh::AllReduceOp>(
206-
spmdizedLinalgOpResult, resultSharding.getMesh(), allReduceMeshAxes,
207-
reductionKind);
208-
spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
209-
}
210-
211190
static void createAllReduceForResultsWithoutPartialShardings(
212191
LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
213192
ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
214193
ImplicitLocOpBuilder &builder) {
215194
ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
216195
for (auto [unshardedLinalgOpResult, resultSharding] :
217196
llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
218-
createAllReduceForResultWithoutPartialSharding(
219-
unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
220-
reductionKind, spmdizationMap, builder);
197+
Value spmdizedLinalgOpResult =
198+
spmdizationMap.lookup(unshardedLinalgOpResult);
199+
Value reducedValue = builder.create<mesh::AllReduceOp>(
200+
spmdizedLinalgOpResult, resultSharding.getMesh(), opReductionMeshAxes,
201+
reductionKind);
202+
spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
221203
}
222204
}
223205

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 17 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -479,37 +479,23 @@ void MeshShapeOp::getAsmResultNames(
479479
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
480480
FlatSymbolRefAttr mesh,
481481
ArrayRef<MeshAxesAttr> split_axes,
482-
ArrayRef<MeshAxis> partial_axes,
483-
mesh::ReductionKind partial_type,
484482
ArrayRef<int64_t> static_halos,
485483
ArrayRef<int64_t> static_offsets) {
486484
return build(
487485
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
488-
::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
489-
::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
490486
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
491487
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
492488
}
493489

494-
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
495-
FlatSymbolRefAttr mesh,
496-
ArrayRef<MeshAxesAttr> split_axes) {
497-
return build(
498-
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
499-
::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
500-
{}, {}, {}, {});
501-
}
502-
503490
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
504491
llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
505492
ArrayRef<int64_t> static_halos,
506493
ArrayRef<int64_t> static_offsets) {
507-
return build(
508-
b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
509-
MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
510-
::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
511-
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
512-
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
494+
return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
495+
MeshAxesArrayAttr::get(b.getContext(), split_axes),
496+
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
497+
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets),
498+
{});
513499
}
514500

515501
void ShardingOp::build(
@@ -522,8 +508,7 @@ void ShardingOp::build(
522508
dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
523509
dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
524510
return build(
525-
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
526-
::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
511+
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
527512
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
528513
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
529514
}
@@ -533,11 +518,6 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
533518

534519
build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(),
535520
MeshAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
536-
from.getPartialAxes().empty()
537-
? DenseI16ArrayAttr()
538-
: b.getDenseI16ArrayAttr(from.getPartialAxes()),
539-
::mlir::mesh::ReductionKindAttr::get(b.getContext(),
540-
from.getPartialType()),
541521
from.getStaticShardedDimsOffsets().empty()
542522
? DenseI64ArrayAttr()
543523
: b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
@@ -566,9 +546,6 @@ LogicalResult ShardingOp::verify() {
566546
if (failed(checkMeshAxis(subAxesArray)))
567547
return failure();
568548
}
569-
if (getPartialAxes().has_value() &&
570-
failed(checkMeshAxis(getPartialAxes().value())))
571-
return failure();
572549

573550
if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
574551
return emitOpError("halo sizes and shard offsets are mutually exclusive");
@@ -710,17 +687,11 @@ void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
710687
// MeshSharding
711688
//===----------------------------------------------------------------------===//
712689

713-
bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
690+
bool MeshSharding::equalSplitAxes(const MeshSharding &rhs) const {
714691
if (getMesh() != rhs.getMesh()) {
715692
return false;
716693
}
717694

718-
if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
719-
(!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
720-
!llvm::equal(getPartialAxes(), rhs.getPartialAxes())) {
721-
return false;
722-
}
723-
724695
auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
725696
if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
726697
getSplitAxes().begin() + minSize),
@@ -768,13 +739,13 @@ bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
768739
}
769740

770741
bool MeshSharding::operator==(Value rhs) const {
771-
return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs);
742+
return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs);
772743
}
773744

774745
bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); }
775746

776747
bool MeshSharding::operator==(const MeshSharding &rhs) const {
777-
return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs);
748+
return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs);
778749
}
779750

780751
bool MeshSharding::operator!=(const MeshSharding &rhs) const {
@@ -787,30 +758,26 @@ MeshSharding::MeshSharding(Value rhs) {
787758
auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
788759
assert(shardingOp && "expected sharding op");
789760
auto splitAxes = shardingOp.getSplitAxes().getAxes();
790-
auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
791-
// If splitAxes and partialAxes are empty, use "empty" constructor.
792-
if (splitAxes.empty() && partialAxes.empty()) {
761+
// If splitAxes are empty, use "empty" constructor.
762+
if (splitAxes.empty()) {
793763
*this = MeshSharding(shardingOp.getMeshAttr());
794764
return;
795765
}
796-
*this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes,
797-
shardingOp.getPartialType().value_or(ReductionKind::Sum),
798-
shardingOp.getStaticHaloSizes(),
799-
shardingOp.getStaticShardedDimsOffsets(),
800-
SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
801-
SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
766+
*this =
767+
get(shardingOp.getMeshAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
768+
shardingOp.getStaticShardedDimsOffsets(),
769+
SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
770+
SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
802771
}
803772

804773
MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
805774
ArrayRef<MeshAxesAttr> split_axes_,
806-
ArrayRef<MeshAxis> partial_axes_,
807-
ReductionKind partial_type_,
808775
ArrayRef<int64_t> static_halo_sizes_,
809776
ArrayRef<int64_t> static_sharded_dims_offsets_,
810777
ArrayRef<Value> dynamic_halo_sizes_,
811778
ArrayRef<Value> dynamic_sharded_dims_offsets_) {
812779
MeshSharding res(mesh_);
813-
if (split_axes_.empty() && partial_axes_.empty()) {
780+
if (split_axes_.empty()) {
814781
return res;
815782
}
816783

@@ -825,8 +792,6 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
825792
llvm::copy(src, dst.begin());
826793
};
827794

828-
clone(partial_axes_, res.partial_axes);
829-
res.partial_type = partial_type_;
830795
clone(static_halo_sizes_, res.static_halo_sizes);
831796
clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
832797
clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);

0 commit comments

Comments
 (0)