@@ -275,9 +275,10 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
275275 return type;
276276}
277277
278- void mlir::mesh::maybeInsertTargetShardingAnnotation (MeshSharding sharding,
279- OpOperand &operand,
280- OpBuilder &builder) {
278+ ShardOp mlir::mesh::maybeInsertTargetShardingAnnotation (MeshSharding sharding,
279+ OpOperand &operand,
280+ OpBuilder &builder,
281+ ShardOp newShardOp) {
281282 OpBuilder::InsertionGuard insertionGuard (builder);
282283 Value operandValue = operand.get ();
283284 Operation *operandOp = operand.getOwner ();
@@ -286,34 +287,40 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
286287 if (shardOp && sharding == shardOp.getSharding () &&
287288 !shardOp.getAnnotateForUsers ()) {
288289 // No need for anything the correct sharding is already set.
289- return ;
290+ return newShardOp ? newShardOp : shardOp ;
290291 }
291292
292- auto shardingOp = builder.create <ShardingOp>(operandValue.getLoc (), sharding);
293- auto newShardOp =
294- builder.create <ShardOp>(operandValue.getLoc (), operandValue, shardingOp,
295- /* annotate_for_users*/ false );
293+ if (!newShardOp) {
294+ auto shardingOp =
295+ builder.create <ShardingOp>(operandValue.getLoc (), sharding);
296+ newShardOp =
297+ builder.create <ShardOp>(operandValue.getLoc (), operandValue, shardingOp,
298+ /* annotate_for_users*/ false );
299+ }
296300 IRRewriter rewriter (builder);
297301 rewriter.replaceUsesWithIf (
298302 operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
299303 return use.getOwner () == operandOp && use.get () == operandValue;
300304 });
301305
302306 if (!shardOp || shardOp.getAnnotateForUsers ()) {
303- return ;
307+ return newShardOp ;
304308 }
305309
306- auto newShardOp2 =
307- builder. create <ShardOp>(operandValue. getLoc (), newShardOp, shardingOp ,
308- /* annotate_for_users*/ true );
310+ auto newShardOp2 = builder. create <ShardOp>(operandValue. getLoc (), newShardOp,
311+ newShardOp. getSharding () ,
312+ /* annotate_for_users*/ true );
309313 rewriter.replaceAllUsesExcept (newShardOp, newShardOp2, newShardOp2);
314+ return newShardOp;
310315}
311316
312317void mlir::mesh::maybeInsertTargetShardingAnnotation (MeshSharding sharding,
313318 OpResult result,
314319 OpBuilder &builder) {
320+ ShardOp newShardOp;
315321 for (auto &use : llvm::make_early_inc_range (result.getUses ())) {
316- maybeInsertTargetShardingAnnotation (sharding, use, builder);
322+ newShardOp =
323+ maybeInsertTargetShardingAnnotation (sharding, use, builder, newShardOp);
317324 }
318325}
319326
0 commit comments