diff --git a/mlir/docs/ActionTracing.md b/mlir/docs/ActionTracing.md index 978fdbbe54d81..984516d5c5e7e 100644 --- a/mlir/docs/ActionTracing.md +++ b/mlir/docs/ActionTracing.md @@ -86,7 +86,7 @@ An action can also carry arbitrary payload, for example we can extend the ```c++ /// A custom Action can be defined minimally by deriving from -/// `tracing::ActionImpl`. It can has any members! +/// `tracing::ActionImpl`. It can have any members! class MyCustomAction : public tracing::ActionImpl { public: using Base = tracing::ActionImpl; diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md index 0ba76199874cc..c61ceaf81681e 100644 --- a/mlir/docs/PatternRewriter.md +++ b/mlir/docs/PatternRewriter.md @@ -320,15 +320,41 @@ conversion target, via a set of pattern-based operation rewriting patterns. This framework also provides support for type conversions. More information on this driver can be found [here](DialectConversion.md). +### Walk Pattern Rewrite Driver + +This is a fast and simple driver that walks the given op and applies patterns +that locally have the most benefit. The benefit of a pattern is decided solely +by the benefit specified on the pattern, and the relative order of the pattern +within the pattern list (when two patterns have the same local benefit). + +The driver performs a post-order traversal. Note that it walks regions of the +given op but does not visit the op. + +This driver does not (re)visit modified or newly replaced ops, and does not +allow for progressive rewrites of the same op. Op and block erasure is only +supported for the currently matched op and its descendant. If your pattern +set requires these, consider using the Greedy Pattern Rewrite Driver instead, +at the expense of extra overhead. + +This driver is exposed using the `walkAndApplyPatterns` function. + +Note: This driver listens for IR changes via the callbacks provided by +`RewriterBase`. It is important that patterns announce all IR changes to the +rewriter and do not bypass the rewriter API by modifying ops directly. + +#### Debugging + +You can debug the Walk Pattern Rewrite Driver by passing the +`--debug-only=walk-rewriter` CLI flag. This will print the visited and matched +ops. + ### Greedy Pattern Rewrite Driver This driver processes ops in a worklist-driven fashion and greedily applies the -patterns that locally have the most benefit. The benefit of a pattern is decided -solely by the benefit specified on the pattern, and the relative order of the -pattern within the pattern list (when two patterns have the same local benefit). -Patterns are iteratively applied to operations until a fixed point is reached or -until the configurable maximum number of iterations exhausted, at which point -the driver finishes. +patterns that locally have the most benefit (same as the Walk Pattern Rewrite +Driver). Patterns are iteratively applied to operations until a fixed point is +reached or until the configurable maximum number of iterations exhausted, at +which point the driver finishes. This driver comes in two fashions: @@ -368,7 +394,7 @@ rewriter and do not bypass the rewriter API by modifying ops directly. Note: This driver is the one used by the [canonicalization](Canonicalization.md) [pass](Passes.md/#-canonicalize) in MLIR. -### Debugging +#### Debugging To debug the execution of the greedy pattern rewrite driver, `-debug-only=greedy-rewriter` may be used. This command line flag activates diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 896fdf1c899e3..2ab0405043a54 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -461,54 +461,60 @@ class RewriterBase : public OpBuilder { /// struct can be used as a base to create listener chains, so that multiple /// listeners can be notified of IR changes. struct ForwardingListener : public RewriterBase::Listener { - ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {} + ForwardingListener(OpBuilder::Listener *listener) + : listener(listener), + rewriteListener( + dyn_cast_if_present(listener)) {} void notifyOperationInserted(Operation *op, InsertPoint previous) override { - listener->notifyOperationInserted(op, previous); + if (listener) + listener->notifyOperationInserted(op, previous); } void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override { - listener->notifyBlockInserted(block, previous, previousIt); + if (listener) + listener->notifyBlockInserted(block, previous, previousIt); } void notifyBlockErased(Block *block) override { - if (auto *rewriteListener = dyn_cast(listener)) + if (rewriteListener) rewriteListener->notifyBlockErased(block); } void notifyOperationModified(Operation *op) override { - if (auto *rewriteListener = dyn_cast(listener)) + if (rewriteListener) rewriteListener->notifyOperationModified(op); } void notifyOperationReplaced(Operation *op, Operation *newOp) override { - if (auto *rewriteListener = dyn_cast(listener)) + if (rewriteListener) rewriteListener->notifyOperationReplaced(op, newOp); } void notifyOperationReplaced(Operation *op, ValueRange replacement) override { - if (auto *rewriteListener = dyn_cast(listener)) + if (rewriteListener) rewriteListener->notifyOperationReplaced(op, replacement); } void notifyOperationErased(Operation *op) override { - if (auto *rewriteListener = dyn_cast(listener)) + if (rewriteListener) rewriteListener->notifyOperationErased(op); } void notifyPatternBegin(const Pattern &pattern, Operation *op) override { - if (auto *rewriteListener = dyn_cast(listener)) + if (rewriteListener) rewriteListener->notifyPatternBegin(pattern, op); } void notifyPatternEnd(const Pattern &pattern, LogicalResult status) override { - if (auto *rewriteListener = dyn_cast(listener)) + if (rewriteListener) rewriteListener->notifyPatternEnd(pattern, status); } void notifyMatchFailure( Location loc, function_ref reasonCallback) override { - if (auto *rewriteListener = dyn_cast(listener)) + if (rewriteListener) rewriteListener->notifyMatchFailure(loc, reasonCallback); } private: OpBuilder::Listener *listener; + RewriterBase::Listener *rewriteListener; }; /// Move the blocks that belong to "region" before the given position in diff --git a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h new file mode 100644 index 0000000000000..6d62ae3dd43dc --- /dev/null +++ b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h @@ -0,0 +1,37 @@ +//===- WALKPATTERNREWRITEDRIVER.h - Walk Pattern Rewrite Driver -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Declares a helper function to walk the given op and apply rewrite patterns. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_ +#define MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_ + +#include "mlir/IR/Visitors.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" + +namespace mlir { + +/// A fast walk-based pattern rewrite driver. Rewrites ops nested under the +/// given operation by walking it and applying the highest benefit patterns. +/// This rewriter *does not* wait until a fixpoint is reached and *does not* +/// visit modified or newly replaced ops. Also *does not* perform folding or +/// dead-code elimination. +/// +/// This is intended as the simplest and most lightweight pattern rewriter in +/// cases when a simple walk gets the job done. +/// +/// Note: Does not apply patterns to the given operation itself. +void walkAndApplyPatterns(Operation *op, + const FrozenRewritePatternSet &patterns, + RewriterBase::Listener *listener = nullptr); + +} // namespace mlir + +#endif // MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_ diff --git a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp index bebe0b5a7c0b6..8922e93e399f9 100644 --- a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp @@ -14,7 +14,7 @@ #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" namespace mlir { namespace arith { @@ -157,11 +157,7 @@ struct ArithUnsignedWhenEquivalentPass RewritePatternSet patterns(ctx); populateUnsignedWhenEquivalentPatterns(patterns, solver); - GreedyRewriteConfig config; - config.listener = &listener; - - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) - signalPassFailure(); + walkAndApplyPatterns(op, std::move(patterns), &listener); } }; } // end anonymous namespace diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt index eb588640dbf83..72eb34f36cf5f 100644 --- a/mlir/lib/Transforms/Utils/CMakeLists.txt +++ b/mlir/lib/Transforms/Utils/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_library(MLIRTransformUtils LoopInvariantCodeMotionUtils.cpp OneToNTypeConversion.cpp RegionUtils.cpp + WalkPatternRewriteDriver.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp new file mode 100644 index 0000000000000..ee5c642c943c4 --- /dev/null +++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp @@ -0,0 +1,116 @@ +//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implements mlir::walkAndApplyPatterns. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/WalkPatternRewriteDriver.h" + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Rewrite/PatternApplicator.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" + +#define DEBUG_TYPE "walk-rewriter" + +namespace mlir { + +namespace { +struct WalkAndApplyPatternsAction final + : tracing::ActionImpl { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction) + using ActionImpl::ActionImpl; + static constexpr StringLiteral tag = "walk-and-apply-patterns"; + void print(raw_ostream &os) const override { os << tag; } +}; + +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS +// Forwarding listener to guard against unsupported erasures of non-descendant +// ops/blocks. Because we use walk-based pattern application, erasing the +// op/block from the *next* iteration (e.g., a user of the visited op) is not +// valid. Note that this is only used with expensive pattern API checks. +struct ErasedOpsListener final : RewriterBase::ForwardingListener { + using RewriterBase::ForwardingListener::ForwardingListener; + + void notifyOperationErased(Operation *op) override { + checkErasure(op); + ForwardingListener::notifyOperationErased(op); + } + + void notifyBlockErased(Block *block) override { + checkErasure(block->getParentOp()); + ForwardingListener::notifyBlockErased(block); + } + + void checkErasure(Operation *op) const { + Operation *ancestorOp = op; + while (ancestorOp && ancestorOp != visitedOp) + ancestorOp = ancestorOp->getParentOp(); + + if (ancestorOp != visitedOp) + llvm::report_fatal_error( + "unsupported erasure in WalkPatternRewriter; " + "erasure is only supported for matched ops and their descendants"); + } + + Operation *visitedOp = nullptr; +}; +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS +} // namespace + +void walkAndApplyPatterns(Operation *op, + const FrozenRewritePatternSet &patterns, + RewriterBase::Listener *listener) { +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (failed(verify(op))) + llvm::report_fatal_error("walk pattern rewriter input IR failed to verify"); +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + + MLIRContext *ctx = op->getContext(); + PatternRewriter rewriter(ctx); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + ErasedOpsListener erasedListener(listener); + rewriter.setListener(&erasedListener); +#else + rewriter.setListener(listener); +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + + PatternApplicator applicator(patterns); + applicator.applyDefaultCostModel(); + + ctx->executeAction( + [&] { + for (Region ®ion : op->getRegions()) { + region.walk([&](Operation *visitedOp) { + LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print( + llvm::dbgs(), OpPrintingFlags().skipRegions()); + llvm::dbgs() << "\n";); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + erasedListener.visitedOp = visitedOp; +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) { + LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";); + } + }); + } + }, + {op}); + +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (failed(verify(op))) + llvm::report_fatal_error( + "walk pattern rewriter result IR failed to verify"); +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS +} + +} // namespace mlir diff --git a/mlir/test/IR/enum-attr-roundtrip.mlir b/mlir/test/IR/enum-attr-roundtrip.mlir index 0b4d379cfb7d5..36e605bdbff4d 100644 --- a/mlir/test/IR/enum-attr-roundtrip.mlir +++ b/mlir/test/IR/enum-attr-roundtrip.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s | mlir-opt -test-patterns | FileCheck %s +// RUN: mlir-opt %s | mlir-opt -test-greedy-patterns | FileCheck %s // CHECK-LABEL: @test_enum_attr_roundtrip func.func @test_enum_attr_roundtrip() -> () { diff --git a/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir b/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir index f3da9a147fcb9..d619eefd72102 100644 --- a/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir +++ b/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-patterns="max-iterations=1" \ +// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1" \ // RUN: -allow-unregistered-dialect --split-input-file | FileCheck %s // CHECK-LABEL: func @add_to_worklist_after_inplace_update() diff --git a/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir b/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir index a362d6f99b947..9f4a7924b725a 100644 --- a/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir +++ b/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-patterns="max-iterations=1 top-down=true" \ +// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1 top-down=true" \ // RUN: --split-input-file | FileCheck %s // Tests for https://github.com/llvm/llvm-project/issues/86765. Ensure diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir new file mode 100644 index 0000000000000..02f7e60671c9b --- /dev/null +++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir @@ -0,0 +1,121 @@ +// RUN: mlir-opt %s --test-walk-pattern-rewrite-driver="dump-notifications=true" \ +// RUN: --allow-unregistered-dialect --split-input-file | FileCheck %s + +// The following op is updated in-place and will not be added back to the worklist. +// CHECK-LABEL: func.func @inplace_update() +// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> () +// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> () +func.func @inplace_update() { + "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> () + "test.any_attr_of_i32_str"() {attr = 1 : i32} : () -> () + return +} + +// Check that the driver does not fold visited ops. +// CHECK-LABEL: func.func @add_no_fold() +// CHECK: arith.constant +// CHECK: arith.constant +// CHECK: %[[RES:.+]] = arith.addi +// CHECK: return %[[RES]] +func.func @add_no_fold() -> i32 { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %res = arith.addi %c0, %c1 : i32 + return %res : i32 +} + +// Check that the driver handles rewriter.moveBefore. +// CHECK-LABEL: func.func @move_before( +// CHECK: "test.move_before_parent_op" +// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> () +// CHECK: scf.if +// CHECK: return +func.func @move_before(%cond : i1) { + scf.if %cond { + "test.move_before_parent_op"() ({ + "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> () + }) : () -> () + } + return +} + +// Check that the driver handles rewriter.moveAfter. In this case, we expect +// the moved op to be visited only once since walk uses `make_early_inc_range`. +// CHECK-LABEL: func.func @move_after( +// CHECK: scf.if +// CHECK: } +// CHECK: "test.move_after_parent_op" +// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> () +// CHECK: return +func.func @move_after(%cond : i1) { + scf.if %cond { + "test.move_after_parent_op"() ({ + "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> () + }) : () -> () + } + return +} + +// Check that the driver handles rewriter.moveAfter. In this case, we expect +// the moved op to be visited twice since we advance its position to the next +// node after the parent. +// CHECK-LABEL: func.func @move_forward_and_revisit( +// CHECK: scf.if +// CHECK: } +// CHECK: arith.addi +// CHECK: "test.move_after_parent_op" +// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> () +// CHECK: arith.addi +// CHECK: return +func.func @move_forward_and_revisit(%cond : i1) { + scf.if %cond { + "test.move_after_parent_op"() ({ + "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> () + }) {advance = 1 : i32} : () -> () + } + %a = arith.addi %cond, %cond : i1 + %b = arith.addi %a, %cond : i1 + return +} + +// Operation inserted just after the currently visited one won't be visited. +// CHECK-LABEL: func.func @insert_just_after +// CHECK: "test.clone_me"() ({ +// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> () +// CHECK: }) {was_cloned} : () -> () +// CHECK: "test.clone_me"() ({ +// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> () +// CHECK: }) : () -> () +// CHECK: return +func.func @insert_just_after(%cond : i1) { + "test.clone_me"() ({ + "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> () + }) : () -> () + return +} + +// Check that we can replace the current operation with a new one. +// Note that the new op won't be visited. +// CHECK-LABEL: func.func @replace_with_new_op +// CHECK: %[[NEW:.+]] = "test.new_op" +// CHECK: %[[RES:.+]] = arith.addi %[[NEW]], %[[NEW]] +// CHECK: return %[[RES]] +func.func @replace_with_new_op() -> i32 { + %a = "test.replace_with_new_op"() : () -> (i32) + %res = arith.addi %a, %a : i32 + return %res : i32 +} + +// Check that we can erase nested blocks. +// CHECK-LABEL: func.func @erase_nested_block +// CHECK: %[[RES:.+]] = "test.erase_first_block" +// CHECK-NEXT: foo.bar +// CHECK: return %[[RES]] +func.func @erase_nested_block() -> i32 { + %a = "test.erase_first_block"() ({ + "foo.foo"() : () -> () + ^bb1: + "foo.bar"() : () -> () + }): () -> (i32) + return %a : i32 +} diff --git a/mlir/test/Transforms/test-operation-folder-commutative.mlir b/mlir/test/Transforms/test-operation-folder-commutative.mlir index 8ffdeb54f399d..55556c1ec5844 100644 --- a/mlir/test/Transforms/test-operation-folder-commutative.mlir +++ b/mlir/test/Transforms/test-operation-folder-commutative.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --pass-pipeline="builtin.module(test-patterns)" %s | FileCheck %s +// RUN: mlir-opt --pass-pipeline="builtin.module(test-greedy-patterns)" %s | FileCheck %s // CHECK-LABEL: func @test_reorder_constants_and_match func.func @test_reorder_constants_and_match(%arg0 : i32) -> (i32) { diff --git a/mlir/test/Transforms/test-operation-folder.mlir b/mlir/test/Transforms/test-operation-folder.mlir index 46ee07af993cc..3c0cd15dc6c51 100644 --- a/mlir/test/Transforms/test-operation-folder.mlir +++ b/mlir/test/Transforms/test-operation-folder.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt -test-patterns='top-down=false' %s | FileCheck %s -// RUN: mlir-opt -test-patterns='top-down=true' %s | FileCheck %s +// RUN: mlir-opt -test-greedy-patterns='top-down=false' %s | FileCheck %s +// RUN: mlir-opt -test-greedy-patterns='top-down=true' %s | FileCheck %s func.func @foo() -> i32 { %c42 = arith.constant 42 : i32 diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 3eade0369f765..d97f3b41f2ef2 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -13,12 +13,16 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "llvm/ADT/ScopeExit.h" +#include using namespace mlir; using namespace test; @@ -214,6 +218,30 @@ struct MoveBeforeParentOp : public RewritePattern { } }; +/// This pattern moves "test.move_after_parent_op" after the parent op. +struct MoveAfterParentOp : public RewritePattern { + MoveAfterParentOp(MLIRContext *context) + : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // Do not hoist past functions. + if (isa(op->getParentOp())) + return failure(); + + int64_t moveForwardBy = 0; + if (auto advanceBy = op->getAttrOfType("advance")) + moveForwardBy = advanceBy.getInt(); + + Operation *moveAfter = op->getParentOp(); + for (int64_t i = 0; i < moveForwardBy; ++i) + moveAfter = moveAfter->getNextNode(); + + rewriter.moveOpAfter(op, moveAfter); + return success(); + } +}; + /// This pattern inlines blocks that are nested in /// "test.inline_blocks_into_parent" into the parent block. struct InlineBlocksIntoParent : public RewritePattern { @@ -286,14 +314,65 @@ struct CloneRegionBeforeOp : public RewritePattern { } }; -struct TestPatternDriver - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver) +/// Replace an operation may introduce the re-visiting of its users. +class ReplaceWithNewOp : public RewritePattern { +public: + ReplaceWithNewOp(MLIRContext *context) + : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + Operation *newOp; + if (op->hasAttr("create_erase_op")) { + newOp = rewriter.create( + op->getLoc(), + OperationName("test.erase_op", op->getContext()).getIdentifier(), + ValueRange(), TypeRange()); + } else { + newOp = rewriter.create( + op->getLoc(), + OperationName("test.new_op", op->getContext()).getIdentifier(), + op->getOperands(), op->getResultTypes()); + } + // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp". + // A "notifyOperationReplaced" callback is triggered in either case. + rewriter.replaceAllOpUsesWith(op, newOp->getResults()); + rewriter.eraseOp(op); + return success(); + } +}; + +/// Erases the first child block of the matched "test.erase_first_block" +/// operation. +class EraseFirstBlock : public RewritePattern { +public: + EraseFirstBlock(MLIRContext *context) + : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + llvm::errs() << "Num regions: " << op->getNumRegions() << "\n"; + for (Region &r : op->getRegions()) { + for (Block &b : r.getBlocks()) { + rewriter.eraseBlock(&b); + llvm::errs() << "Erasing block: " << b << "\n"; + return success(); + } + } + + return failure(); + } +}; + +struct TestGreedyPatternDriver + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver) - TestPatternDriver() = default; - TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {} + TestGreedyPatternDriver() = default; + TestGreedyPatternDriver(const TestGreedyPatternDriver &other) + : PassWrapper(other) {} - StringRef getArgument() const final { return "test-patterns"; } + StringRef getArgument() const final { return "test-greedy-patterns"; } StringRef getDescription() const final { return "Run test dialect patterns"; } void runOnOperation() override { mlir::RewritePatternSet patterns(&getContext()); @@ -470,34 +549,6 @@ struct TestStrictPatternDriver } }; - // Replace an operation may introduce the re-visiting of its users. - class ReplaceWithNewOp : public RewritePattern { - public: - ReplaceWithNewOp(MLIRContext *context) - : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - Operation *newOp; - if (op->hasAttr("create_erase_op")) { - newOp = rewriter.create( - op->getLoc(), - OperationName("test.erase_op", op->getContext()).getIdentifier(), - ValueRange(), TypeRange()); - } else { - newOp = rewriter.create( - op->getLoc(), - OperationName("test.new_op", op->getContext()).getIdentifier(), - op->getOperands(), op->getResultTypes()); - } - // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp". - // A "notifyOperationReplaced" callback is triggered in either case. - rewriter.replaceAllOpUsesWith(op, newOp->getResults()); - rewriter.eraseOp(op); - return success(); - } - }; - // Remove an operation may introduce the re-visiting of its operands. class EraseOp : public RewritePattern { public: @@ -560,6 +611,39 @@ struct TestStrictPatternDriver }; }; +struct TestWalkPatternDriver final + : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver) + + TestWalkPatternDriver() = default; + TestWalkPatternDriver(const TestWalkPatternDriver &other) + : PassWrapper(other) {} + + StringRef getArgument() const override { + return "test-walk-pattern-rewrite-driver"; + } + StringRef getDescription() const override { + return "Run test walk pattern rewrite driver"; + } + void runOnOperation() override { + mlir::RewritePatternSet patterns(&getContext()); + + // Patterns for testing the WalkPatternRewriteDriver. + patterns.add, MoveBeforeParentOp, + MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>( + &getContext()); + + DumpNotifications dumpListener; + walkAndApplyPatterns(getOperation(), std::move(patterns), + dumpNotifications ? &dumpListener : nullptr); + } + + Option dumpNotifications{ + *this, "dump-notifications", + llvm::cl::desc("Print rewrite listener notifications"), + llvm::cl::init(false)}; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -1978,8 +2062,9 @@ void registerPatternsTestPass() { PassRegistration(); - PassRegistration(); + PassRegistration(); PassRegistration(); + PassRegistration(); PassRegistration([] { return std::make_unique(legalizerConversionMode); diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir index 5ff8710b93770..60d46e676d2a3 100644 --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-patterns -mlir-print-debuginfo -mlir-print-local-scope %s | FileCheck %s +// RUN: mlir-opt -test-greedy-patterns -mlir-print-debuginfo -mlir-print-local-scope %s | FileCheck %s // CHECK-LABEL: verifyFusedLocs func.func @verifyFusedLocs(%arg0 : i32) -> i32 {