|
9 | 9 | #include "mlir/IR/PatternMatch.h" |
10 | 10 | #include "mlir/IR/Iterators.h" |
11 | 11 | #include "mlir/IR/RegionKindInterface.h" |
| 12 | +#include "llvm/ADT/ScopeExit.h" |
12 | 13 | #include "llvm/ADT/SmallPtrSet.h" |
13 | 14 |
|
14 | 15 | using namespace mlir; |
@@ -348,14 +349,29 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest, |
348 | 349 | /// Split the operations starting at "before" (inclusive) out of the given |
349 | 350 | /// block into a new block, and return it. |
350 | 351 | Block *RewriterBase::splitBlock(Block *block, Block::iterator before) { |
| 352 | + Block *newBlock; |
| 353 | + |
| 354 | + // If the current insertion point is at or after the split point, adjust the |
| 355 | + // insertion point to the new block. |
| 356 | + bool moveIpToNewBlock = getBlock() == block && |
| 357 | + !block->isBeforeInBlock(getInsertionPoint(), before); |
| 358 | + auto adjustInsertionPoint = llvm::make_scope_exit([&]() { |
| 359 | + if (getInsertionPoint() == block->end()) { |
| 360 | + // If the insertion point is at the end of the block, move it to the end |
| 361 | + // of the new block. |
| 362 | + setInsertionPointToEnd(newBlock); |
| 363 | + } else if (moveIpToNewBlock) { |
| 364 | + setInsertionPoint(newBlock, getInsertionPoint()); |
| 365 | + } |
| 366 | + }); |
| 367 | + |
351 | 368 | // Fast path: If no listener is attached, split the block directly. |
352 | 369 | if (!listener) |
353 | | - return block->splitBlock(before); |
| 370 | + return newBlock = block->splitBlock(before); |
354 | 371 |
|
355 | 372 | // `createBlock` sets the insertion point at the beginning of the new block. |
356 | 373 | InsertionGuard g(*this); |
357 | | - Block *newBlock = |
358 | | - createBlock(block->getParent(), std::next(block->getIterator())); |
| 374 | + newBlock = createBlock(block->getParent(), std::next(block->getIterator())); |
359 | 375 |
|
360 | 376 | // If `before` points to end of the block, no ops should be moved. |
361 | 377 | if (before == block->end()) |
@@ -413,6 +429,12 @@ void RewriterBase::moveOpBefore(Operation *op, Block *block, |
413 | 429 | Block *currentBlock = op->getBlock(); |
414 | 430 | Block::iterator nextIterator = std::next(op->getIterator()); |
415 | 431 | op->moveBefore(block, iterator); |
| 432 | + |
| 433 | + // If the current insertion point is before the moved operation, we may have |
| 434 | + // to adjust the insertion block. |
| 435 | + if (getInsertionPoint() == op->getIterator()) |
| 436 | + setInsertionPoint(block, op->getIterator()); |
| 437 | + |
416 | 438 | if (listener) |
417 | 439 | listener->notifyOperationInserted( |
418 | 440 | op, /*previous=*/InsertPoint(currentBlock, nextIterator)); |
|
0 commit comments