@@ -441,8 +441,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
441441
442442 void commit () override ;
443443
444- void cleanup () override ;
445-
446444 void rollback () override ;
447445
448446private:
@@ -791,24 +789,27 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
791789 // / block is returned containing the new arguments. Returns `block` if it did
792790 // / not require conversion.
793791 FailureOr<Block *> convertBlockSignature (
794- Block *block, const TypeConverter *converter,
792+ ConversionPatternRewriter &rewriter, Block *block,
793+ const TypeConverter *converter,
795794 TypeConverter::SignatureConversion *conversion = nullptr );
796795
797796 // / Convert the types of non-entry block arguments within the given region.
798797 LogicalResult convertNonEntryRegionTypes (
799- Region *region, const TypeConverter &converter,
798+ ConversionPatternRewriter &rewriter, Region *region,
799+ const TypeConverter &converter,
800800 ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
801801
802802 // / Apply a signature conversion on the given region, using `converter` for
803803 // / materializations if not null.
804804 Block *
805- applySignatureConversion (Region *region,
805+ applySignatureConversion (ConversionPatternRewriter &rewriter, Region *region,
806806 TypeConverter::SignatureConversion &conversion,
807807 const TypeConverter *converter);
808808
809809 // / Convert the types of block arguments within the given region.
810810 FailureOr<Block *>
811- convertRegionTypes (Region *region, const TypeConverter &converter,
811+ convertRegionTypes (ConversionPatternRewriter &rewriter, Region *region,
812+ const TypeConverter &converter,
812813 TypeConverter::SignatureConversion *entryConversion);
813814
814815 // / Apply the given signature conversion on the given block. The new block
@@ -818,7 +819,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
818819 // / translate between the origin argument types and those specified in the
819820 // / signature conversion.
820821 Block *applySignatureConversion (
821- Block *block, const TypeConverter *converter,
822+ ConversionPatternRewriter &rewriter, Block *block,
823+ const TypeConverter *converter,
822824 TypeConverter::SignatureConversion &signatureConversion);
823825
824826 // ===--------------------------------------------------------------------===//
@@ -991,24 +993,8 @@ void BlockTypeConversionRewrite::commit() {
991993 }
992994}
993995
994- void BlockTypeConversionRewrite::cleanup () {
995- assert (origBlock->empty () && " expected empty block" );
996- origBlock->dropAllDefinedValueUses ();
997- delete origBlock;
998- origBlock = nullptr ;
999- }
1000-
1001996void BlockTypeConversionRewrite::rollback () {
1002- // Drop all uses of the new block arguments and replace uses of the new block.
1003- for (int i = block->getNumArguments () - 1 ; i >= 0 ; --i)
1004- block->getArgument (i).dropAllUses ();
1005997 block->replaceAllUsesWith (origBlock);
1006-
1007- // Move the operations back the original block, move the original block back
1008- // into its original location and the delete the new block.
1009- origBlock->getOperations ().splice (origBlock->end (), block->getOperations ());
1010- block->getParent ()->getBlocks ().insert (Region::iterator (block), origBlock);
1011- eraseBlock (block);
1012998}
1013999
10141000LogicalResult BlockTypeConversionRewrite::materializeLiveConversions (
@@ -1224,10 +1210,11 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
12241210// Type Conversion
12251211
12261212FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature (
1227- Block *block, const TypeConverter *converter,
1213+ ConversionPatternRewriter &rewriter, Block *block,
1214+ const TypeConverter *converter,
12281215 TypeConverter::SignatureConversion *conversion) {
12291216 if (conversion)
1230- return applySignatureConversion (block, converter, *conversion);
1217+ return applySignatureConversion (rewriter, block, converter, *conversion);
12311218
12321219 // If a converter wasn't provided, and the block wasn't already converted,
12331220 // there is nothing we can do.
@@ -1236,35 +1223,39 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
12361223
12371224 // Try to convert the signature for the block with the provided converter.
12381225 if (auto conversion = converter->convertBlockSignature (block))
1239- return applySignatureConversion (block, converter, *conversion);
1226+ return applySignatureConversion (rewriter, block, converter, *conversion);
12401227 return failure ();
12411228}
12421229
12431230Block *ConversionPatternRewriterImpl::applySignatureConversion (
1244- Region *region, TypeConverter::SignatureConversion &conversion,
1231+ ConversionPatternRewriter &rewriter, Region *region,
1232+ TypeConverter::SignatureConversion &conversion,
12451233 const TypeConverter *converter) {
12461234 if (!region->empty ())
1247- return *convertBlockSignature (®ion->front (), converter, &conversion);
1235+ return *convertBlockSignature (rewriter, ®ion->front (), converter,
1236+ &conversion);
12481237 return nullptr ;
12491238}
12501239
12511240FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes (
1252- Region *region, const TypeConverter &converter,
1241+ ConversionPatternRewriter &rewriter, Region *region,
1242+ const TypeConverter &converter,
12531243 TypeConverter::SignatureConversion *entryConversion) {
12541244 regionToConverter[region] = &converter;
12551245 if (region->empty ())
12561246 return nullptr ;
12571247
1258- if (failed (convertNonEntryRegionTypes (region, converter)))
1248+ if (failed (convertNonEntryRegionTypes (rewriter, region, converter)))
12591249 return failure ();
12601250
1261- FailureOr<Block *> newEntry =
1262- convertBlockSignature ( ®ion->front (), &converter, entryConversion);
1251+ FailureOr<Block *> newEntry = convertBlockSignature (
1252+ rewriter, ®ion->front (), &converter, entryConversion);
12631253 return newEntry;
12641254}
12651255
12661256LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes (
1267- Region *region, const TypeConverter &converter,
1257+ ConversionPatternRewriter &rewriter, Region *region,
1258+ const TypeConverter &converter,
12681259 ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
12691260 regionToConverter[region] = &converter;
12701261 if (region->empty ())
@@ -1285,16 +1276,18 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
12851276 : const_cast <TypeConverter::SignatureConversion *>(
12861277 &blockConversions[blockIdx++]);
12871278
1288- if (failed (convertBlockSignature (&block, &converter, blockConversion)))
1279+ if (failed (convertBlockSignature (rewriter, &block, &converter,
1280+ blockConversion)))
12891281 return failure ();
12901282 }
12911283 return success ();
12921284}
12931285
12941286Block *ConversionPatternRewriterImpl::applySignatureConversion (
1295- Block *block, const TypeConverter *converter,
1287+ ConversionPatternRewriter &rewriter, Block *block,
1288+ const TypeConverter *converter,
12961289 TypeConverter::SignatureConversion &signatureConversion) {
1297- MLIRContext *ctx = eraseRewriter .getContext ();
1290+ MLIRContext *ctx = rewriter .getContext ();
12981291
12991292 // If no arguments are being changed or added, there is nothing to do.
13001293 unsigned origArgCount = block->getNumArguments ();
@@ -1304,11 +1297,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13041297
13051298 // Split the block at the beginning to get a new block to use for the updated
13061299 // signature.
1307- Block *newBlock = block-> splitBlock (block->begin ());
1300+ Block *newBlock = rewriter. splitBlock (block, block->begin ());
13081301 block->replaceAllUsesWith (newBlock);
1309- // Unlink the block, but do not erase it yet, so that the change can be rolled
1310- // back.
1311- block->getParent ()->getBlocks ().remove (block);
13121302
13131303 // Map all new arguments to the location of the argument they originate from.
13141304 SmallVector<Location> newLocs (convertedTypes.size (),
@@ -1384,6 +1374,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13841374
13851375 appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
13861376 converter);
1377+
1378+ // Erase the old block. (It is just unlinked for now and will be erased during
1379+ // cleanup.)
1380+ rewriter.eraseBlock (block);
1381+
13871382 return newBlock;
13881383}
13891384
@@ -1592,7 +1587,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
15921587 assert (!impl->wasOpReplaced (region->getParentOp ()) &&
15931588 " attempting to apply a signature conversion to a block within a "
15941589 " replaced/erased op" );
1595- return impl->applySignatureConversion (region, conversion, converter);
1590+ return impl->applySignatureConversion (* this , region, conversion, converter);
15961591}
15971592
15981593FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes (
@@ -1601,7 +1596,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
16011596 assert (!impl->wasOpReplaced (region->getParentOp ()) &&
16021597 " attempting to apply a signature conversion to a block within a "
16031598 " replaced/erased op" );
1604- return impl->convertRegionTypes (region, converter, entryConversion);
1599+ return impl->convertRegionTypes (* this , region, converter, entryConversion);
16051600}
16061601
16071602LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes (
@@ -1610,7 +1605,8 @@ LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
16101605 assert (!impl->wasOpReplaced (region->getParentOp ()) &&
16111606 " attempting to apply a signature conversion to a block within a "
16121607 " replaced/erased op" );
1613- return impl->convertNonEntryRegionTypes (region, converter, blockConversions);
1608+ return impl->convertNonEntryRegionTypes (*this , region, converter,
1609+ blockConversions);
16141610}
16151611
16161612void ConversionPatternRewriter::replaceUsesOfBlockArgument (BlockArgument from,
@@ -2104,7 +2100,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
21042100 // If the region of the block has a type converter, try to convert the block
21052101 // directly.
21062102 if (auto *converter = impl.regionToConverter .lookup (block->getParent ())) {
2107- if (failed (impl.convertBlockSignature (block, converter))) {
2103+ if (failed (impl.convertBlockSignature (rewriter, block, converter))) {
21082104 LLVM_DEBUG (logFailure (impl.logger , " failed to convert types of moved "
21092105 " block" ));
21102106 return failure ();
0 commit comments