@@ -1080,6 +1080,16 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
10801080 // / to modify/access them is invalid rewriter API usage.
10811081 SetVector<Operation *> replacedOps;
10821082
1083+ // / A set of operations that were created by the current pattern.
1084+ SetVector<Operation *> patternNewOps;
1085+
1086+ // / A set of operations that were modified by the current pattern.
1087+ SetVector<Operation *> patternModifiedOps;
1088+
1089+ // / A set of blocks that were inserted (newly-created blocks or moved blocks)
1090+ // / by the current pattern.
1091+ SetVector<Block *> patternInsertedBlocks;
1092+
10831093 // / A mapping of all unresolved materializations (UnrealizedConversionCastOp)
10841094 // / to the corresponding rewrite objects.
10851095 DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
@@ -1571,6 +1581,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
15711581 if (!previous.isSet ()) {
15721582 // This is a newly created op.
15731583 appendRewrite<CreateOperationRewrite>(op);
1584+ patternNewOps.insert (op);
15741585 return ;
15751586 }
15761587 Operation *prevOp = previous.getPoint () == previous.getBlock ()->end ()
@@ -1655,6 +1666,8 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
16551666 }
16561667 });
16571668
1669+ patternInsertedBlocks.insert (block);
1670+
16581671 if (!previous) {
16591672 // This is a newly created block.
16601673 appendRewrite<CreateBlockRewrite>(block);
@@ -1852,6 +1865,8 @@ void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
18521865 assert (!impl->wasOpReplaced (op) &&
18531866 " attempting to modify a replaced/erased op" );
18541867 PatternRewriter::finalizeOpModification (op);
1868+ impl->patternModifiedOps .insert (op);
1869+
18551870 // There is nothing to do here, we only need to track the operation at the
18561871 // start of the update.
18571872#ifndef NDEBUG
@@ -1964,21 +1979,25 @@ class OperationLegalizer {
19641979 // / Legalize the resultant IR after successfully applying the given pattern.
19651980 LogicalResult legalizePatternResult (Operation *op, const Pattern &pattern,
19661981 ConversionPatternRewriter &rewriter,
1967- RewriterState &curState);
1982+ const SetVector<Operation *> &newOps,
1983+ const SetVector<Operation *> &modifiedOps,
1984+ const SetVector<Block *> &insertedBlocks);
19681985
19691986 // / Legalizes the actions registered during the execution of a pattern.
19701987 LogicalResult
19711988 legalizePatternBlockRewrites (Operation *op,
19721989 ConversionPatternRewriter &rewriter,
19731990 ConversionPatternRewriterImpl &impl,
1974- RewriterState &state, RewriterState &newState);
1975- LogicalResult legalizePatternCreatedOperations (
1976- ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1977- RewriterState &state, RewriterState &newState);
1978- LogicalResult legalizePatternRootUpdates (ConversionPatternRewriter &rewriter,
1979- ConversionPatternRewriterImpl &impl,
1980- RewriterState &state,
1981- RewriterState &newState);
1991+ const SetVector<Block *> &insertedBlocks,
1992+ const SetVector<Operation *> &newOps);
1993+ LogicalResult
1994+ legalizePatternCreatedOperations (ConversionPatternRewriter &rewriter,
1995+ ConversionPatternRewriterImpl &impl,
1996+ const SetVector<Operation *> &newOps);
1997+ LogicalResult
1998+ legalizePatternRootUpdates (ConversionPatternRewriter &rewriter,
1999+ ConversionPatternRewriterImpl &impl,
2000+ const SetVector<Operation *> &modifiedOps);
19822001
19832002 // ===--------------------------------------------------------------------===//
19842003 // Cost Model
@@ -2131,6 +2150,15 @@ OperationLegalizer::legalize(Operation *op,
21312150 return failure ();
21322151}
21332152
2153+ // / Helper function that moves and returns the given object. Also resets the
2154+ // / original object, so that it is in a valid, empty state again.
2155+ template <typename T>
2156+ static T moveAndReset (T &obj) {
2157+ T result = std::move (obj);
2158+ obj = T ();
2159+ return result;
2160+ }
2161+
21342162LogicalResult
21352163OperationLegalizer::legalizeWithFold (Operation *op,
21362164 ConversionPatternRewriter &rewriter) {
@@ -2192,6 +2220,9 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
21922220 RewriterState curState = rewriterImpl.getCurrentState ();
21932221 auto onFailure = [&](const Pattern &pattern) {
21942222 assert (rewriterImpl.pendingRootUpdates .empty () && " dangling root updates" );
2223+ rewriterImpl.patternNewOps .clear ();
2224+ rewriterImpl.patternModifiedOps .clear ();
2225+ rewriterImpl.patternInsertedBlocks .clear ();
21952226 LLVM_DEBUG ({
21962227 logFailure (rewriterImpl.logger , " pattern failed to match" );
21972228 if (rewriterImpl.config .notifyCallback ) {
@@ -2212,7 +2243,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
22122243 // successfully applied.
22132244 auto onSuccess = [&](const Pattern &pattern) {
22142245 assert (rewriterImpl.pendingRootUpdates .empty () && " dangling root updates" );
2215- auto result = legalizePatternResult (op, pattern, rewriter, curState);
2246+ SetVector<Operation *> newOps = moveAndReset (rewriterImpl.patternNewOps );
2247+ SetVector<Operation *> modifiedOps =
2248+ moveAndReset (rewriterImpl.patternModifiedOps );
2249+ SetVector<Block *> insertedBlocks =
2250+ moveAndReset (rewriterImpl.patternInsertedBlocks );
2251+ auto result = legalizePatternResult (op, pattern, rewriter, newOps,
2252+ modifiedOps, insertedBlocks);
22162253 appliedPatterns.erase (&pattern);
22172254 if (failed (result)) {
22182255 if (!rewriterImpl.config .allowPatternRollback )
@@ -2253,10 +2290,11 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
22532290 return true ;
22542291}
22552292
2256- LogicalResult
2257- OperationLegalizer::legalizePatternResult (Operation *op, const Pattern &pattern,
2258- ConversionPatternRewriter &rewriter,
2259- RewriterState &curState) {
2293+ LogicalResult OperationLegalizer::legalizePatternResult (
2294+ Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
2295+ const SetVector<Operation *> &newOps,
2296+ const SetVector<Operation *> &modifiedOps,
2297+ const SetVector<Block *> &insertedBlocks) {
22602298 auto &impl = rewriter.getImpl ();
22612299 assert (impl.pendingRootUpdates .empty () && " dangling root updates" );
22622300
@@ -2274,12 +2312,10 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
22742312#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
22752313
22762314 // Legalize each of the actions registered during application.
2277- RewriterState newState = impl.getCurrentState ();
2278- if (failed (legalizePatternBlockRewrites (op, rewriter, impl, curState,
2279- newState)) ||
2280- failed (legalizePatternRootUpdates (rewriter, impl, curState, newState)) ||
2281- failed (legalizePatternCreatedOperations (rewriter, impl, curState,
2282- newState))) {
2315+ if (failed (legalizePatternBlockRewrites (op, rewriter, impl, insertedBlocks,
2316+ newOps)) ||
2317+ failed (legalizePatternRootUpdates (rewriter, impl, modifiedOps)) ||
2318+ failed (legalizePatternCreatedOperations (rewriter, impl, newOps))) {
22832319 return failure ();
22842320 }
22852321
@@ -2289,20 +2325,14 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
22892325
22902326LogicalResult OperationLegalizer::legalizePatternBlockRewrites (
22912327 Operation *op, ConversionPatternRewriter &rewriter,
2292- ConversionPatternRewriterImpl &impl, RewriterState &state,
2293- RewriterState &newState) {
2294- SmallPtrSet<Operation *, 16 > operationsToIgnore;
2328+ ConversionPatternRewriterImpl &impl,
2329+ const SetVector<Block *> &insertedBlocks,
2330+ const SetVector<Operation *> &newOps) {
2331+ SmallPtrSet<Operation *, 16 > alreadyLegalized;
22952332
22962333 // If the pattern moved or created any blocks, make sure the types of block
22972334 // arguments get legalized.
2298- for (int i = state.numRewrites , e = newState.numRewrites ; i != e; ++i) {
2299- BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites [i].get ());
2300- if (!rewrite)
2301- continue ;
2302- Block *block = rewrite->getBlock ();
2303- if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2304- ReplaceBlockArgRewrite, InlineBlockRewrite>(rewrite))
2305- continue ;
2335+ for (Block *block : insertedBlocks) {
23062336 // Only check blocks outside of the current operation.
23072337 Operation *parentOp = block->getParentOp ();
23082338 if (!parentOp || parentOp == op || block->getNumArguments () == 0 )
@@ -2322,41 +2352,26 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
23222352 continue ;
23232353 }
23242354
2325- // Otherwise, check that this operation isn't one generated by this pattern.
2326- // This is because we will attempt to legalize the parent operation, and
2327- // blocks in regions created by this pattern will already be legalized later
2328- // on. If we haven't built the set yet, build it now.
2329- if (operationsToIgnore.empty ()) {
2330- for (unsigned i = state.numRewrites , e = impl.rewrites .size (); i != e;
2331- ++i) {
2332- auto *createOp =
2333- dyn_cast<CreateOperationRewrite>(impl.rewrites [i].get ());
2334- if (!createOp)
2335- continue ;
2336- operationsToIgnore.insert (createOp->getOperation ());
2355+ // Otherwise, try to legalize the parent operation if it was not generated
2356+ // by this pattern. This is because we will attempt to legalize the parent
2357+ // operation, and blocks in regions created by this pattern will already be
2358+ // legalized later on.
2359+ if (!newOps.count (parentOp) && alreadyLegalized.insert (parentOp).second ) {
2360+ if (failed (legalize (parentOp, rewriter))) {
2361+ LLVM_DEBUG (logFailure (
2362+ impl.logger , " operation '{0}'({1}) became illegal after rewrite" ,
2363+ parentOp->getName (), parentOp));
2364+ return failure ();
23372365 }
23382366 }
2339-
2340- // If this operation should be considered for re-legalization, try it.
2341- if (operationsToIgnore.insert (parentOp).second &&
2342- failed (legalize (parentOp, rewriter))) {
2343- LLVM_DEBUG (logFailure (impl.logger ,
2344- " operation '{0}'({1}) became illegal after rewrite" ,
2345- parentOp->getName (), parentOp));
2346- return failure ();
2347- }
23482367 }
23492368 return success ();
23502369}
23512370
23522371LogicalResult OperationLegalizer::legalizePatternCreatedOperations (
23532372 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2354- RewriterState &state, RewriterState &newState) {
2355- for (int i = state.numRewrites , e = newState.numRewrites ; i != e; ++i) {
2356- auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites [i].get ());
2357- if (!createOp)
2358- continue ;
2359- Operation *op = createOp->getOperation ();
2373+ const SetVector<Operation *> &newOps) {
2374+ for (Operation *op : newOps) {
23602375 if (failed (legalize (op, rewriter))) {
23612376 LLVM_DEBUG (logFailure (impl.logger ,
23622377 " failed to legalize generated operation '{0}'({1})" ,
@@ -2369,12 +2384,8 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
23692384
23702385LogicalResult OperationLegalizer::legalizePatternRootUpdates (
23712386 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2372- RewriterState &state, RewriterState &newState) {
2373- for (int i = state.numRewrites , e = newState.numRewrites ; i != e; ++i) {
2374- auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites [i].get ());
2375- if (!rewrite)
2376- continue ;
2377- Operation *op = rewrite->getOperation ();
2387+ const SetVector<Operation *> &modifiedOps) {
2388+ for (Operation *op : modifiedOps) {
23782389 if (failed (legalize (op, rewriter))) {
23792390 LLVM_DEBUG (logFailure (
23802391 impl.logger , " failed to legalize operation updated in-place '{0}'" ,
0 commit comments