Skip to content

Commit 6af6d65

Browse files
committed
fixing invalid modification fo use-range while iterating
1 parent 9bdb2fc commit 6af6d65

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,6 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
206206
// Use newShardOp if it is not null. Otherwise create a new one.
207207
// May insert resharding if required.
208208
// Potentially updates newShardOp.
209-
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
210-
OpOperand &operand, OpBuilder &builder,
211-
ShardOp &newShardOp);
212209
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
213210
OpBuilder &builder);
214211
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -298,13 +298,12 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
298298
return type;
299299
}
300300

301-
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
302-
OpOperand &operand,
303-
OpBuilder &builder,
304-
ShardOp &newShardOp) {
301+
static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
302+
Value &operandValue,
303+
Operation *operandOp,
304+
OpBuilder &builder,
305+
ShardOp &newShardOp) {
305306
OpBuilder::InsertionGuard insertionGuard(builder);
306-
Value operandValue = operand.get();
307-
Operation *operandOp = operand.getOwner();
308307
builder.setInsertionPointAfterValue(operandValue);
309308
ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
310309
if (shardOp && sharding == shardOp.getSharding() &&
@@ -323,9 +322,8 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
323322
builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
324323
/*annotate_for_users*/ false);
325324
}
326-
IRRewriter rewriter(builder);
327-
rewriter.replaceUsesWithIf(
328-
operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
325+
operandValue.replaceUsesWithIf(
326+
newShardOp, [operandOp, operandValue](OpOperand &use) {
329327
return use.getOwner() == operandOp && use.get() == operandValue;
330328
});
331329

@@ -336,15 +334,20 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
336334
auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
337335
newShardOp.getSharding(),
338336
/*annotate_for_users*/ true);
339-
rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
337+
newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
340338
}
341339

342340
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
343341
OpResult result,
344342
OpBuilder &builder) {
345343
ShardOp newShardOp;
346-
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
347-
maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
344+
SmallVector<std::pair<Value, Operation *>> uses;
345+
for (auto &use : result.getUses()) {
346+
uses.emplace_back(use.get(), use.getOwner());
347+
}
348+
for (auto &[operandValue, operandOp] : uses) {
349+
maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
350+
builder, newShardOp);
348351
}
349352
}
350353

0 commit comments

Comments
 (0)