@@ -275,10 +275,10 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
275275 return type;
276276}
277277
278- ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation (MeshSharding sharding,
279- OpOperand &operand,
280- OpBuilder &builder,
281- ShardOp newShardOp) {
278+ void mlir::mesh::maybeInsertTargetShardingAnnotation (MeshSharding sharding,
279+ OpOperand &operand,
280+ OpBuilder &builder,
281+ ShardOp & newShardOp) {
282282 OpBuilder::InsertionGuard insertionGuard (builder);
283283 Value operandValue = operand.get ();
284284 Operation *operandOp = operand.getOwner ();
@@ -287,7 +287,10 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
287287 if (shardOp && sharding == shardOp.getSharding () &&
288288 !shardOp.getAnnotateForUsers ()) {
289289 // No need for anything if the correct sharding is already set.
290- return newShardOp ? newShardOp : shardOp;
290+ if (!newShardOp) {
291+ newShardOp = shardOp;
292+ }
293+ return ;
291294 }
292295
293296 if (!newShardOp) {
@@ -304,23 +307,22 @@ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
304307 });
305308
306309 if (!shardOp || shardOp.getAnnotateForUsers ()) {
307- return newShardOp ;
310+ return ;
308311 }
309312
310313 auto newShardOp2 = builder.create <ShardOp>(operandValue.getLoc (), newShardOp,
311314 newShardOp.getSharding (),
312315 /* annotate_for_users*/ true );
313316 rewriter.replaceAllUsesExcept (newShardOp, newShardOp2, newShardOp2);
314- return newShardOp ;
317+ return ;
315318}
316319
317320void mlir::mesh::maybeInsertTargetShardingAnnotation (MeshSharding sharding,
318321 OpResult result,
319322 OpBuilder &builder) {
320323 ShardOp newShardOp;
321324 for (auto &use : llvm::make_early_inc_range (result.getUses ())) {
322- newShardOp =
323- maybeInsertTargetShardingAnnotation (sharding, use, builder, newShardOp);
325+ maybeInsertTargetShardingAnnotation (sharding, use, builder, newShardOp);
324326 }
325327}
326328
0 commit comments