@@ -861,8 +861,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
861861 // / conversion process succeeds.
862862 void applyRewrites ();
863863
864- // / Reset the state of the rewriter to a previously saved point.
865- void resetState (RewriterState state);
864+ // / Reset the state of the rewriter to a previously saved point. Optionally,
865+ // / the name of the pattern that triggered the rollback can specified for
866+ // / debugging purposes.
867+ void resetState (RewriterState state, StringRef patternName = " " );
866868
867869 // / Append a rewrite. Rewrites are committed upon success and rolled back upon
868870 // / failure.
@@ -873,8 +875,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
873875 }
874876
875877 // / Undo the rewrites (motions, splits) one by one in reverse order until
876- // / "numRewritesToKeep" rewrites remains.
877- void undoRewrites (unsigned numRewritesToKeep = 0 );
878+ // / "numRewritesToKeep" rewrites remains. Optionally, the name of the pattern
879+ // / that triggered the rollback can specified for debugging purposes.
880+ void undoRewrites (unsigned numRewritesToKeep = 0 , StringRef patternName = " " );
878881
879882 // / Remap the given values to those with potentially different types. Returns
880883 // / success if the values could be remapped, failure otherwise. `valueDiagTag`
@@ -1204,9 +1207,10 @@ RewriterState ConversionPatternRewriterImpl::getCurrentState() {
12041207 return RewriterState (rewrites.size (), ignoredOps.size (), replacedOps.size ());
12051208}
12061209
1207- void ConversionPatternRewriterImpl::resetState (RewriterState state) {
1210+ void ConversionPatternRewriterImpl::resetState (RewriterState state,
1211+ StringRef patternName) {
12081212 // Undo any rewrites.
1209- undoRewrites (state.numRewrites );
1213+ undoRewrites (state.numRewrites , patternName );
12101214
12111215 // Pop all of the recorded ignored operations that are no longer valid.
12121216 while (ignoredOps.size () != state.numIgnoredOperations )
@@ -1216,10 +1220,19 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
12161220 replacedOps.pop_back ();
12171221}
12181222
1219- void ConversionPatternRewriterImpl::undoRewrites (unsigned numRewritesToKeep) {
1223+ void ConversionPatternRewriterImpl::undoRewrites (unsigned numRewritesToKeep,
1224+ StringRef patternName) {
12201225 for (auto &rewrite :
1221- llvm::reverse (llvm::drop_begin (rewrites, numRewritesToKeep)))
1226+ llvm::reverse (llvm::drop_begin (rewrites, numRewritesToKeep))) {
1227+ if (!config.allowPatternRollback &&
1228+ !isa<UnresolvedMaterializationRewrite>(rewrite)) {
1229+ // Unresolved materializations can always be rolled back (erased).
1230+ std::string errorMessage = " pattern '" + std::string (patternName) +
1231+ " ' rollback of IR modifications requested" ;
1232+ llvm_unreachable (errorMessage.c_str ());
1233+ }
12221234 rewrite->rollback ();
1235+ }
12231236 rewrites.resize (numRewritesToKeep);
12241237}
12251238
@@ -2158,7 +2171,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
21582171 });
21592172 if (config.listener )
21602173 config.listener ->notifyPatternEnd (pattern, failure ());
2161- rewriterImpl.resetState (curState);
2174+ rewriterImpl.resetState (curState, pattern. getDebugName () );
21622175 appliedPatterns.erase (&pattern);
21632176 };
21642177
@@ -2168,8 +2181,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
21682181 assert (rewriterImpl.pendingRootUpdates .empty () && " dangling root updates" );
21692182 auto result = legalizePatternResult (op, pattern, rewriter, curState);
21702183 appliedPatterns.erase (&pattern);
2171- if (failed (result))
2172- rewriterImpl.resetState (curState);
2184+ if (failed (result)) {
2185+ if (!rewriterImpl.config .allowPatternRollback )
2186+ op->emitError (" pattern '" )
2187+ << pattern.getDebugName ()
2188+ << " ' produced IR that could not be legalized" ;
2189+ rewriterImpl.resetState (curState, pattern.getDebugName ());
2190+ }
21732191 if (config.listener )
21742192 config.listener ->notifyPatternEnd (pattern, result);
21752193 return result;
@@ -2674,9 +2692,20 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
26742692 ConversionPatternRewriter rewriter (ops.front ()->getContext (), config);
26752693 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl ();
26762694
2677- for (auto *op : toConvert)
2678- if (failed (convert (rewriter, op)))
2679- return rewriterImpl.undoRewrites (), failure ();
2695+ for (auto *op : toConvert) {
2696+ if (failed (convert (rewriter, op))) {
2697+ // Dialect conversion failed.
2698+ if (rewriterImpl.config .allowPatternRollback ) {
2699+ // Rollback is allowed: restore the original IR.
2700+ rewriterImpl.undoRewrites ();
2701+ } else {
2702+ // Rollback is not allowed: apply all modifications that have been
2703+ // performed so far.
2704+ rewriterImpl.applyRewrites ();
2705+ }
2706+ return failure ();
2707+ }
2708+ }
26802709
26812710 // After a successful conversion, apply rewrites.
26822711 rewriterImpl.applyRewrites ();
0 commit comments