@@ -298,13 +298,12 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
298298 return type;
299299}
300300
301- void mlir::mesh::maybeInsertTargetShardingAnnotation (MeshSharding sharding,
302- OpOperand &operand,
303- OpBuilder &builder,
304- ShardOp &newShardOp) {
301+ static void maybeInsertTargetShardingAnnotationImpl (MeshSharding sharding,
302+ Value &operandValue,
303+ Operation *operandOp,
304+ OpBuilder &builder,
305+ ShardOp &newShardOp) {
305306 OpBuilder::InsertionGuard insertionGuard (builder);
306- Value operandValue = operand.get ();
307- Operation *operandOp = operand.getOwner ();
308307 builder.setInsertionPointAfterValue (operandValue);
309308 ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
310309 if (shardOp && sharding == shardOp.getSharding () &&
@@ -323,9 +322,8 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
323322 builder.create <ShardOp>(operandValue.getLoc (), operandValue, shardingOp,
324323 /* annotate_for_users*/ false );
325324 }
326- IRRewriter rewriter (builder);
327- rewriter.replaceUsesWithIf (
328- operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
325+ operandValue.replaceUsesWithIf (
326+ newShardOp, [operandOp, operandValue](OpOperand &use) {
329327 return use.getOwner () == operandOp && use.get () == operandValue;
330328 });
331329
@@ -336,15 +334,20 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
336334 auto newShardOp2 = builder.create <ShardOp>(operandValue.getLoc (), newShardOp,
337335 newShardOp.getSharding (),
338336 /* annotate_for_users*/ true );
339- rewriter. replaceAllUsesExcept (newShardOp, newShardOp2, newShardOp2);
337+ newShardOp. getResult (). replaceAllUsesExcept (newShardOp2, newShardOp2);
340338}
341339
342340void mlir::mesh::maybeInsertTargetShardingAnnotation (MeshSharding sharding,
343341 OpResult result,
344342 OpBuilder &builder) {
345343 ShardOp newShardOp;
346- for (auto &use : llvm::make_early_inc_range (result.getUses ())) {
347- maybeInsertTargetShardingAnnotation (sharding, use, builder, newShardOp);
344+ SmallVector<std::pair<Value, Operation *>> uses;
345+ for (auto &use : result.getUses ()) {
346+ uses.emplace_back (use.get (), use.getOwner ());
347+ }
348+ for (auto &[operandValue, operandOp] : uses) {
349+ maybeInsertTargetShardingAnnotationImpl (sharding, operandValue, operandOp,
350+ builder, newShardOp);
348351 }
349352}
350353
0 commit comments