@@ -319,8 +319,7 @@ class RandomizedWorklist : public Worklist {
319319// / This abstract class manages the worklist and contains helper methods for
320320// / rewriting ops on the worklist. Derived classes specify how ops are added
321321// / to the worklist in the beginning.
322- class GreedyPatternRewriteDriver : public PatternRewriter ,
323- public RewriterBase::Listener {
322+ class GreedyPatternRewriteDriver : public RewriterBase ::Listener {
324323protected:
325324 explicit GreedyPatternRewriteDriver (MLIRContext *ctx,
326325 const FrozenRewritePatternSet &patterns,
@@ -339,7 +338,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
339338 // / Notify the driver that the specified operation was inserted. Update the
340339 // / worklist as needed: The operation is enqueued depending on scope and
341340 // / strict mode.
342- void notifyOperationInserted (Operation *op, InsertPoint previous) override ;
341+ void notifyOperationInserted (Operation *op,
342+ OpBuilder::InsertPoint previous) override ;
343343
344344 // / Notify the driver that the specified operation was removed. Update the
345345 // / worklist as needed: The operation and its children are removed from the
@@ -354,6 +354,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
354354 // / reached. Return `true` if any IR was changed.
355355 bool processWorklist ();
356356
357+ // / The pattern rewriter that is used for making IR modifications and is
358+ // / passed to rewrite patterns.
359+ PatternRewriter rewriter;
360+
357361 // / The worklist for this transformation keeps track of the operations that
358362 // / need to be (re)visited.
359363#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
@@ -407,7 +411,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
407411GreedyPatternRewriteDriver::GreedyPatternRewriteDriver (
408412 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
409413 const GreedyRewriteConfig &config)
410- : PatternRewriter (ctx), config(config), matcher(patterns)
414+ : rewriter (ctx), config(config), matcher(patterns)
411415#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
412416 // clang-format off
413417 , expensiveChecks(
@@ -423,9 +427,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
423427#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
424428 // Send IR notifications to the debug handler. This handler will then forward
425429 // all notifications to this GreedyPatternRewriteDriver.
426- setListener (&expensiveChecks);
430+ rewriter. setListener (&expensiveChecks);
427431#else
428- setListener (this );
432+ rewriter. setListener (this );
429433#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
430434}
431435
@@ -473,7 +477,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
473477
474478 // If the operation is trivially dead - remove it.
475479 if (isOpTriviallyDead (op)) {
476- eraseOp (op);
480+ rewriter. eraseOp (op);
477481 changed = true ;
478482
479483 LLVM_DEBUG (logResultWithLine (" success" , " operation is trivially dead" ));
@@ -505,8 +509,8 @@ bool GreedyPatternRewriteDriver::processWorklist() {
505509 // Op results can be replaced with `foldResults`.
506510 assert (foldResults.size () == op->getNumResults () &&
507511 " folder produced incorrect number of results" );
508- OpBuilder::InsertionGuard g (* this );
509- setInsertionPoint (op);
512+ OpBuilder::InsertionGuard g (rewriter );
513+ rewriter. setInsertionPoint (op);
510514 SmallVector<Value> replacements;
511515 bool materializationSucceeded = true ;
512516 for (auto [ofr, resultType] :
@@ -519,7 +523,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
519523 }
520524 // Materialize Attributes as SSA values.
521525 Operation *constOp = op->getDialect ()->materializeConstant (
522- * this , ofr.get <Attribute>(), resultType, op->getLoc ());
526+ rewriter , ofr.get <Attribute>(), resultType, op->getLoc ());
523527
524528 if (!constOp) {
525529 // If materialization fails, cleanup any operations generated for
@@ -532,7 +536,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
532536 replacementOps.insert (replacement.getDefiningOp ());
533537 }
534538 for (Operation *op : replacementOps) {
535- eraseOp (op);
539+ rewriter. eraseOp (op);
536540 }
537541
538542 materializationSucceeded = false ;
@@ -547,7 +551,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
547551 }
548552
549553 if (materializationSucceeded) {
550- replaceOp (op, replacements);
554+ rewriter. replaceOp (op, replacements);
551555 changed = true ;
552556 LLVM_DEBUG (logSuccessfulFolding (dumpRootOp));
553557#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -608,7 +612,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
608612#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
609613
610614 LogicalResult matchResult =
611- matcher.matchAndRewrite (op, * this , canApply, onFailure, onSuccess);
615+ matcher.matchAndRewrite (op, rewriter , canApply, onFailure, onSuccess);
612616
613617 if (succeeded (matchResult)) {
614618 LLVM_DEBUG (logResultWithLine (" success" , " pattern matched" ));
@@ -664,8 +668,8 @@ void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
664668 config.listener ->notifyBlockErased (block);
665669}
666670
667- void GreedyPatternRewriteDriver::notifyOperationInserted (Operation *op,
668- InsertPoint previous) {
671+ void GreedyPatternRewriteDriver::notifyOperationInserted (
672+ Operation *op, OpBuilder:: InsertPoint previous) {
669673 LLVM_DEBUG ({
670674 logger.startLine () << " ** Insert : '" << op->getName () << " '(" << op
671675 << " )\n " ;
@@ -822,7 +826,7 @@ class GreedyPatternRewriteIteration
822826LogicalResult RegionPatternRewriteDriver::simplify (bool *changed) && {
823827 bool continueRewrites = false ;
824828 int64_t iteration = 0 ;
825- MLIRContext *ctx = getContext ();
829+ MLIRContext *ctx = rewriter. getContext ();
826830 do {
827831 // Check if the iteration limit was reached.
828832 if (++iteration > config.maxIterations &&
@@ -834,7 +838,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
834838
835839 // `OperationFolder` CSE's constant ops (and may move them into parents
836840 // regions to enable more aggressive CSE'ing).
837- OperationFolder folder (getContext () , this );
841+ OperationFolder folder (ctx , this );
838842 auto insertKnownConstant = [&](Operation *op) {
839843 // Check for existing constants when populating the worklist. This avoids
840844 // accidentally reversing the constant order during processing.
@@ -872,7 +876,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
872876 // After applying patterns, make sure that the CFG of each of the
873877 // regions is kept up to date.
874878 if (config.enableRegionSimplification )
875- continueRewrites |= succeeded (simplifyRegions (* this , region));
879+ continueRewrites |= succeeded (simplifyRegions (rewriter , region));
876880 },
877881 {®ion}, iteration);
878882 } while (continueRewrites);
0 commit comments