diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 3b669f51a615f..ff48647f43305 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -896,7 +896,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { bool wasOpReplaced(Operation *op) const; //===--------------------------------------------------------------------===// - // Type Conversion + // IR Rewrites / Type Conversion //===--------------------------------------------------------------------===// /// Convert the types of block arguments within the given region. @@ -916,6 +916,22 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion); + /// Replace the results of the given operation with the given values and + /// erase the operation. + /// + /// There can be multiple replacement values for each result (1:N + /// replacement). If the replacement values are empty, the respective result + /// is dropped and a source materialization is built if the result still has + /// uses. + void replaceOp(Operation *op, SmallVector> &&newValues); + + /// Erase the given block and its contents. + void eraseBlock(Block *block); + + /// Inline the source block into the destination block before the given + /// iterator. + void inlineBlockBefore(Block *source, Block *dest, Block::iterator before); + //===--------------------------------------------------------------------===// // Materializations //===--------------------------------------------------------------------===// @@ -952,21 +968,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override; - /// Notifies that an op is about to be replaced with the given values. - void notifyOpReplaced(Operation *op, - SmallVector> &&newValues); - - /// Notifies that a block is about to be erased. - void notifyBlockIsBeingErased(Block *block); - /// Notifies that a block was inserted. void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override; - /// Notifies that a block is being inlined into another block. - void notifyBlockBeingInlined(Block *block, Block *srcBlock, - Block::iterator before); - /// Notifies that a pattern match failed for the given reason. void notifyMatchFailure(Location loc, @@ -1548,7 +1553,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( appendRewrite(op, previous.getBlock(), prevOp); } -void ConversionPatternRewriterImpl::notifyOpReplaced( +void ConversionPatternRewriterImpl::replaceOp( Operation *op, SmallVector> &&newValues) { assert(newValues.size() == op->getNumResults()); assert(!ignoredOps.contains(op) && "operation was already replaced"); @@ -1599,8 +1604,14 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( op->walk([&](Operation *op) { replacedOps.insert(op); }); } -void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) { +void ConversionPatternRewriterImpl::eraseBlock(Block *block) { appendRewrite(block); + + // Unlink the block from its parent region. The block is kept in the rewrite + // object and will be actually destroyed when rewrites are applied. This + // allows us to keep the operations in the block live and undo the removal by + // re-inserting the block. + block->getParent()->getBlocks().remove(block); } void ConversionPatternRewriterImpl::notifyBlockInserted( @@ -1628,9 +1639,10 @@ void ConversionPatternRewriterImpl::notifyBlockInserted( appendRewrite(block, previous, prevBlock); } -void ConversionPatternRewriterImpl::notifyBlockBeingInlined( - Block *block, Block *srcBlock, Block::iterator before) { - appendRewrite(block, srcBlock, before); +void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source, + Block *dest, + Block::iterator before) { + appendRewrite(dest, source, before); } void ConversionPatternRewriterImpl::notifyMatchFailure( @@ -1673,7 +1685,7 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { llvm::map_to_vector(newValues, [](Value v) -> SmallVector { return v ? SmallVector{v} : SmallVector(); }); - impl->notifyOpReplaced(op, std::move(newVals)); + impl->replaceOp(op, std::move(newVals)); } void ConversionPatternRewriter::replaceOpWithMultiple( @@ -1684,7 +1696,7 @@ void ConversionPatternRewriter::replaceOpWithMultiple( impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); - impl->notifyOpReplaced(op, std::move(newValues)); + impl->replaceOp(op, std::move(newValues)); } void ConversionPatternRewriter::eraseOp(Operation *op) { @@ -1693,7 +1705,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { << "** Erase : '" << op->getName() << "'(" << op << ")\n"; }); SmallVector> nullRepls(op->getNumResults(), {}); - impl->notifyOpReplaced(op, std::move(nullRepls)); + impl->replaceOp(op, std::move(nullRepls)); } void ConversionPatternRewriter::eraseBlock(Block *block) { @@ -1704,12 +1716,7 @@ void ConversionPatternRewriter::eraseBlock(Block *block) { for (Operation &op : *block) eraseOp(&op); - // Unlink the block from its parent region. The block is kept in the rewrite - // object and will be actually destroyed when rewrites are applied. This - // allows us to keep the operations in the block live and undo the removal by - // re-inserting the block. - impl->notifyBlockIsBeingErased(block); - block->getParent()->getBlocks().remove(block); + impl->eraseBlock(block); } Block *ConversionPatternRewriter::applySignatureConversion( @@ -1797,7 +1804,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, bool fastPath = !impl->config.listener; if (fastPath) - impl->notifyBlockBeingInlined(dest, source, before); + impl->inlineBlockBefore(source, dest, before); // Replace all uses of block arguments. for (auto it : llvm::zip(source->getArguments(), argValues))