Skip to content

Commit 1d86186

Browse files
committed
sharding propagation: add only one shardop for each result
1 parent c1324d3 commit 1d86186

File tree

2 files changed

+26
-16
lines changed

2 files changed

+26
-16
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,13 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
201201
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
202202

203203
// Insert shard op if there is not one that already has the same sharding.
204+
// Use newShardOp if it is not null. Otherwise create a new one.
204205
// May insert resharding if required.
205-
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
206-
OpOperand &operand,
207-
OpBuilder &builder);
206+
// Return the target ShardOP (new or existing).
207+
ShardOp maybeInsertTargetShardingAnnotation(MeshSharding sharding,
208+
OpOperand &operand,
209+
OpBuilder &builder,
210+
ShardOp newShardOp);
208211
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
209212
OpBuilder &builder);
210213
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,10 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
275275
return type;
276276
}
277277

278-
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
279-
OpOperand &operand,
280-
OpBuilder &builder) {
278+
ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
279+
OpOperand &operand,
280+
OpBuilder &builder,
281+
ShardOp newShardOp) {
281282
OpBuilder::InsertionGuard insertionGuard(builder);
282283
Value operandValue = operand.get();
283284
Operation *operandOp = operand.getOwner();
@@ -286,34 +287,40 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
286287
if (shardOp && sharding == shardOp.getSharding() &&
287288
!shardOp.getAnnotateForUsers()) {
288289
// No need for anything the correct sharding is already set.
289-
return;
290+
return newShardOp ? newShardOp : shardOp;
290291
}
291292

292-
auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
293-
auto newShardOp =
294-
builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
295-
/*annotate_for_users*/ false);
293+
if (!newShardOp) {
294+
auto shardingOp =
295+
builder.create<ShardingOp>(operandValue.getLoc(), sharding);
296+
newShardOp =
297+
builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
298+
/*annotate_for_users*/ false);
299+
}
296300
IRRewriter rewriter(builder);
297301
rewriter.replaceUsesWithIf(
298302
operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
299303
return use.getOwner() == operandOp && use.get() == operandValue;
300304
});
301305

302306
if (!shardOp || shardOp.getAnnotateForUsers()) {
303-
return;
307+
return newShardOp;
304308
}
305309

306-
auto newShardOp2 =
307-
builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
308-
/*annotate_for_users*/ true);
310+
auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
311+
newShardOp.getSharding(),
312+
/*annotate_for_users*/ true);
309313
rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
314+
return newShardOp;
310315
}
311316

312317
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
313318
OpResult result,
314319
OpBuilder &builder) {
320+
ShardOp newShardOp;
315321
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
316-
maybeInsertTargetShardingAnnotation(sharding, use, builder);
322+
newShardOp =
323+
maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
317324
}
318325
}
319326

0 commit comments

Comments
 (0)