Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 35 additions & 17 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,26 @@ struct RewriterState {
// IR rewrites
//===----------------------------------------------------------------------===//

static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);

/// Notify the listener that the given block and its contents are being erased.
static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
for (Operation &op : b)
notifyIRErased(listener, op);
listener->notifyBlockErased(&b);
}

/// Notify the listener that the given operation and its contents are being
/// erased.
static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
for (Region &r : op.getRegions()) {
for (Block &b : r) {
notifyIRErased(listener, b);
}
}
listener->notifyOperationErased(&op);
}

/// An IR rewrite that can be committed (upon success) or rolled back (upon
/// failure).
///
Expand Down Expand Up @@ -422,17 +442,20 @@ class EraseBlockRewrite : public BlockRewrite {
}

void commit(RewriterBase &rewriter) override {
// Erase the block.
assert(block && "expected block");
assert(block->empty() && "expected empty block");

// Notify the listener that the block is about to be erased.
// Notify the listener that the block and its contents are being erased.
if (auto *listener =
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
listener->notifyBlockErased(block);
notifyIRErased(listener, *block);
}

void cleanup(RewriterBase &rewriter) override {
// Erase the contents of the block.
for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
rewriter.eraseOp(&op);
assert(block->empty() && "expected empty block");

// Erase the block.
block->dropAllDefinedValueUses();
delete block;
Expand Down Expand Up @@ -1147,12 +1170,9 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
if (getConfig().unlegalizedOps)
getConfig().unlegalizedOps->erase(op);

// Notify the listener that the operation (and its nested operations) was
// erased.
if (listener) {
op->walk<WalkOrder::PostOrder>(
[&](Operation *op) { listener->notifyOperationErased(op); });
}
// Notify the listener that the operation and its contents are being erased.
if (listener)
notifyIRErased(listener, *op);

// Do not erase the operation yet. It may still be referenced in `mapping`.
// Just unlink it for now and erase it during cleanup.
Expand Down Expand Up @@ -1605,13 +1625,18 @@ void ConversionPatternRewriterImpl::replaceOp(
}

void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
assert(!wasOpReplaced(block->getParentOp()) &&
"attempting to erase a block within a replaced/erased op");
appendRewrite<EraseBlockRewrite>(block);

// Unlink the block from its parent region. The block is kept in the rewrite
// object and will be actually destroyed when rewrites are applied. This
// allows us to keep the operations in the block live and undo the removal by
// re-inserting the block.
block->getParent()->getBlocks().remove(block);

// Mark all nested ops as erased.
block->walk([&](Operation *op) { replacedOps.insert(op); });
}

void ConversionPatternRewriterImpl::notifyBlockInserted(
Expand Down Expand Up @@ -1709,13 +1734,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
}

void ConversionPatternRewriter::eraseBlock(Block *block) {
assert(!impl->wasOpReplaced(block->getParentOp()) &&
"attempting to erase a block within a replaced/erased op");

// Mark all ops for erasure.
for (Operation &op : *block)
eraseOp(&op);

impl->eraseBlock(block);
}

Expand Down
18 changes: 16 additions & 2 deletions mlir/test/Transforms/test-legalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -461,12 +461,26 @@ func.func @convert_detached_signature() {

// -----

// CHECK: notifyOperationReplaced: test.erase_op
// CHECK: notifyOperationErased: test.dummy_op_lvl_2
// CHECK: notifyBlockErased
// CHECK: notifyOperationErased: test.dummy_op_lvl_1
// CHECK: notifyBlockErased
// CHECK: notifyOperationErased: test.erase_op
// CHECK: notifyOperationInserted: test.valid, was unlinked
// CHECK: notifyOperationReplaced: test.drop_operands_and_replace_with_valid
// CHECK: notifyOperationErased: test.drop_operands_and_replace_with_valid

// CHECK-LABEL: func @circular_mapping()
// CHECK-NEXT: "test.valid"() : () -> ()
func.func @circular_mapping() {
// Regression test that used to crash due to circular
// unrealized_conversion_cast ops.
%0 = "test.erase_op"() : () -> (i64)
// unrealized_conversion_cast ops.
%0 = "test.erase_op"() ({
"test.dummy_op_lvl_1"() ({
"test.dummy_op_lvl_2"() : () -> ()
}) : () -> ()
}): () -> (i64)
"test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
}

Expand Down