Skip to content

Commit 3729a86

Browse files
committed
maybeInsertTargetShardingAnnotation accepting reference only
1 parent 99cf24e commit 3729a86

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,10 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
203203
// Insert shard op if there is not one that already has the same sharding.
204204
// Use newShardOp if it is not null. Otherwise create a new one.
205205
// May insert resharding if required.
206-
// Return the target ShardOP (new or existing).
207-
ShardOp maybeInsertTargetShardingAnnotation(MeshSharding sharding,
208-
OpOperand &operand,
209-
OpBuilder &builder,
210-
ShardOp newShardOp);
206+
// Potentially updates newShardOp.
207+
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
208+
OpOperand &operand, OpBuilder &builder,
209+
ShardOp &newShardOp);
211210
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
212211
OpBuilder &builder);
213212
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -275,10 +275,10 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
275275
return type;
276276
}
277277

278-
ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
279-
OpOperand &operand,
280-
OpBuilder &builder,
281-
ShardOp newShardOp) {
278+
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
279+
OpOperand &operand,
280+
OpBuilder &builder,
281+
ShardOp &newShardOp) {
282282
OpBuilder::InsertionGuard insertionGuard(builder);
283283
Value operandValue = operand.get();
284284
Operation *operandOp = operand.getOwner();
@@ -287,7 +287,10 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
287287
if (shardOp && sharding == shardOp.getSharding() &&
288288
!shardOp.getAnnotateForUsers()) {
289289
// No need for anything if the correct sharding is already set.
290-
return newShardOp ? newShardOp : shardOp;
290+
if (!newShardOp) {
291+
newShardOp = shardOp;
292+
}
293+
return;
291294
}
292295

293296
if (!newShardOp) {
@@ -304,23 +307,22 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
304307
});
305308

306309
if (!shardOp || shardOp.getAnnotateForUsers()) {
307-
return newShardOp;
310+
return;
308311
}
309312

310313
auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
311314
newShardOp.getSharding(),
312315
/*annotate_for_users*/ true);
313316
rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
314-
return newShardOp;
317+
return;
315318
}
316319

317320
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
318321
OpResult result,
319322
OpBuilder &builder) {
320323
ShardOp newShardOp;
321324
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
322-
newShardOp =
323-
maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
325+
maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
324326
}
325327
}
326328

0 commit comments

Comments
 (0)