Skip to content

Commit d635b86

Browse files
authored
[mlir][mesh] Insert resharding during sharding propagation (llvm#84514)
If there are conflicts between the sharding annotations of some op, insert resharding. Make the Spmdization pass more forgiving to allow for more than 2 chained `mesh.shard` ops. Implement `getReductionLoopIteratorKinds` in ShardingInterface for linalg ops.
1 parent bd3f5a4 commit d635b86

File tree

12 files changed

+540
-73
lines changed

12 files changed

+540
-73
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,9 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
151151

152152
let extraClassDeclaration = [{
153153
bool operator==(::mlir::Attribute rhs) const;
154+
bool operator!=(::mlir::Attribute rhs) const;
154155
bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
156+
bool operator!=(::mlir::mesh::MeshShardingAttr rhs) const;
155157
}];
156158

157159
let genVerifyDecl = 1;

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,26 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
5151

5252
// Is the same tensor replicated on all processes.
5353
inline bool isFullReplication(MeshShardingAttr attr) {
54-
return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
54+
return attr.getPartialAxes().empty() &&
55+
llvm::all_of(attr.getSplitAxes(), [](MeshAxesAttr axes) {
56+
return axes.asArrayRef().empty();
57+
});
5558
}
5659

57-
inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
58-
SymbolTableCollection &symbolTableCollection) {
60+
inline mesh::MeshOp
61+
getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
62+
SymbolTableCollection &symbolTableCollection) {
5963
return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
6064
op, meshSymbol);
6165
}
6266

67+
inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
68+
SymbolTableCollection &symbolTableCollection) {
69+
mesh::MeshOp meshOp = getMeshOrNull(op, meshSymbol, symbolTableCollection);
70+
assert(meshOp);
71+
return meshOp;
72+
}
73+
6374
// Get the corresponding mesh op using the standard attribute nomenclature.
6475
template <typename Op>
6576
mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
@@ -128,6 +139,17 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
128139
// `sharding` in that case must be null.
129140
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
130141

142+
// Insert shard op if there is not one that already has the same sharding.
143+
// May insert resharding if required.
144+
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
145+
OpOperand &operand,
146+
OpBuilder &builder);
147+
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
148+
OpResult result, OpBuilder &builder);
149+
void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
150+
OpOperand &operand,
151+
OpBuilder &builder);
152+
131153
} // namespace mesh
132154
} // namespace mlir
133155

mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ struct ShardingOption {
3737
ShardingOption() = default;
3838
ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
3939
: shardingArray(std::move(shardingArray)), mesh(mesh) {}
40+
static ShardingOption makeEmpty() {
41+
auto res = ShardingOption();
42+
res.empty = true;
43+
return res;
44+
}
4045
};
4146

4247
// This method retrieves the 'MeshShardingAttr' attribute from a given operation
@@ -56,6 +61,10 @@ defaultGetShardingOption(Operation *op,
5661
ArrayRef<MeshShardingAttr> operandShardings,
5762
ArrayRef<MeshShardingAttr> resultShardings);
5863

64+
FailureOr<SmallVector<MeshShardingAttr>>
65+
defaultGetShardingAnnotations(Operation *op,
66+
const ShardingOption &shardingOption);
67+
5968
LogicalResult
6069
defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
6170
const ShardingOption &shardingOption);

mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,11 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
7575
InterfaceMethod<
7676
/*desc=*/[{
7777
Given that certain operands or results of the operation may have
78-
sharding annotations, this method leverages this information to deduce
79-
how the operation should be sharded.
78+
sharding annotations, this method leverages this information to
79+
deduce how the operation should be sharded.
80+
The passed sharding may be incomplete, this gives freedom for the
81+
op to select the most appropriate shardings for all the operands
82+
and results and the op itself.
8083
}],
8184
/*retTy=*/"FailureOr<ShardingOption>",
8285
/*methodName=*/"getShardingOption",
@@ -90,6 +93,24 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
9093
$_op.getOperation(), operandShardings, resultShardings);
9194
}]
9295
>,
96+
InterfaceMethod<
97+
/*desc=*/[{
98+
Based on a given ShardingOption, get the operand and result
99+
operations for the operands and results sharding annotations.
100+
This is what shardings the operands and results need to have in order
101+
to shard the op according to shardingOption.
102+
}],
103+
/*retTy=*/"FailureOr<SmallVector<MeshShardingAttr>>",
104+
/*methodName=*/"getShardingAnnotations",
105+
/*args=*/(ins
106+
"const ShardingOption &":$shardingOption
107+
),
108+
/*methodBody=*/"",
109+
/*defaultImplementation=*/[{
110+
return detail::defaultGetShardingAnnotations(
111+
$_op.getOperation(), shardingOption);
112+
}]
113+
>,
93114
InterfaceMethod<
94115
/*desc=*/[{
95116
Based on a given ShardingOption, this method adds `mesh.shard`

mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "llvm/ADT/SmallVector.h"
3737
#include "llvm/ADT/TypeSwitch.h"
3838
#include <iterator>
39+
#include <numeric>
3940
#include <optional>
4041
#include <utility>
4142

@@ -278,6 +279,20 @@ struct StructuredOpShardingInterface
278279
return res;
279280
}
280281

282+
SmallVector<ReductionKind>
283+
getReductionLoopIteratorKinds(Operation *op) const {
284+
LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
285+
SmallVector<utils::IteratorType> iteratorTypes =
286+
linalgOp.getIteratorTypesArray();
287+
unsigned reductionItersCount = std::accumulate(
288+
iteratorTypes.begin(), iteratorTypes.end(), 0,
289+
[](unsigned count, utils::IteratorType iter) {
290+
return count + (iter == utils::IteratorType::reduction);
291+
});
292+
mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
293+
return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
294+
}
295+
281296
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
282297
ArrayRef<MeshShardingAttr> operandShardings,
283298
ArrayRef<MeshShardingAttr> resultShardings,

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/IR/Location.h"
2121
#include "mlir/IR/PatternMatch.h"
2222
#include "mlir/IR/TypeUtilities.h"
23+
#include "mlir/IR/Value.h"
2324
#include "mlir/Interfaces/ViewLikeInterface.h"
2425
#include "mlir/Support/LLVM.h"
2526
#include "mlir/Support/LogicalResult.h"
@@ -28,6 +29,7 @@
2829
#include "llvm/ADT/SmallSet.h"
2930
#include "llvm/ADT/SmallVector.h"
3031
#include "llvm/ADT/TypeSwitch.h"
32+
#include "llvm/Support/Casting.h"
3133
#include <algorithm>
3234
#include <functional>
3335
#include <iterator>
@@ -99,7 +101,7 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
99101
static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
100102
FlatSymbolRefAttr meshSymbol,
101103
SymbolTableCollection &symbolTable) {
102-
mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable);
104+
mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
103105
if (!mesh) {
104106
return op->emitError() << "Undefined required mesh symbol \""
105107
<< meshSymbol.getValue() << "\".";
@@ -178,6 +180,88 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
178180
return type;
179181
}
180182

183+
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
184+
OpOperand &operand,
185+
OpBuilder &builder) {
186+
OpBuilder::InsertionGuard insertionGuard(builder);
187+
Value operandValue = operand.get();
188+
Operation *operandOp = operand.getOwner();
189+
builder.setInsertionPointAfterValue(operandValue);
190+
ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
191+
if (shardOp && shardOp.getShard() == sharding &&
192+
!shardOp.getAnnotateForUsers()) {
193+
// No need for anything the correct sharding is already set.
194+
return;
195+
}
196+
197+
auto newShardOp =
198+
builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
199+
/*annotate_for_users*/ false);
200+
IRRewriter rewriter(builder);
201+
rewriter.replaceUsesWithIf(
202+
operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
203+
return use.getOwner() == operandOp && use.get() == operandValue;
204+
});
205+
206+
if (!shardOp || shardOp.getAnnotateForUsers()) {
207+
return;
208+
}
209+
210+
auto newShardOp2 = builder.create<ShardOp>(
211+
operandValue.getLoc(), newShardOp, sharding, /*annotate_for_users*/ true);
212+
rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
213+
}
214+
215+
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
216+
OpResult result,
217+
OpBuilder &builder) {
218+
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
219+
maybeInsertTargetShardingAnnotation(sharding, use, builder);
220+
}
221+
}
222+
223+
void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
224+
OpOperand &operand,
225+
OpBuilder &builder) {
226+
OpBuilder::InsertionGuard insertionGuard(builder);
227+
Value operandValue = operand.get();
228+
Operation *operandOp = operand.getOwner();
229+
Operation *operandSrcOp = operandValue.getDefiningOp();
230+
bool isBlockArg = !operandSrcOp;
231+
ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
232+
233+
if (shardOp && shardOp.getShard() == sharding &&
234+
shardOp.getAnnotateForUsers()) {
235+
// No need for anything the correct sharding is already set.
236+
return;
237+
}
238+
239+
builder.setInsertionPoint(operandOp);
240+
auto newShardOp =
241+
builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
242+
/*annotate_for_users*/ true);
243+
IRRewriter rewriter(builder);
244+
rewriter.replaceUsesWithIf(
245+
operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
246+
return use.getOwner() == operandOp && use.get() == operandValue;
247+
});
248+
249+
if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
250+
// No need for resharding.
251+
return;
252+
}
253+
254+
builder.setInsertionPoint(newShardOp);
255+
auto newPreceedingShardOp =
256+
builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
257+
/*annotate_for_users*/ false);
258+
rewriter.replaceUsesWithIf(newShardOp.getOperand(), newPreceedingShardOp,
259+
[&newShardOp](OpOperand &use) {
260+
return use.getOwner() ==
261+
newShardOp.getOperation();
262+
});
263+
}
264+
181265
//===----------------------------------------------------------------------===//
182266
// mesh.mesh op
183267
//===----------------------------------------------------------------------===//
@@ -286,6 +370,10 @@ bool MeshShardingAttr::operator==(Attribute rhs) const {
286370
return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
287371
}
288372

373+
bool MeshShardingAttr::operator!=(Attribute rhs) const {
374+
return !(*this == rhs);
375+
}
376+
289377
bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
290378
if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
291379
return false;
@@ -311,6 +399,10 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
311399
std::mem_fn(&MeshAxesAttr::empty));
312400
}
313401

402+
bool MeshShardingAttr::operator!=(MeshShardingAttr rhs) const {
403+
return !(*this == rhs);
404+
}
405+
314406
//===----------------------------------------------------------------------===//
315407
// mesh.shard op
316408
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)