Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions mlir/docs/PatternRewriter.md
Original file line number Diff line number Diff line change
Expand Up @@ -320,15 +320,33 @@ conversion target, via a set of pattern-based operation rewriting patterns. This
framework also provides support for type conversions. More information on this
driver can be found [here](DialectConversion.md).

### Walk Pattern Rewrite Driver

This is a fast and simple driver that walks the given op and applies patterns
that locally have the most benefit. The benefit of a pattern is decided solely
by the benefit specified on the pattern, and the relative order of the pattern
within the pattern list (when two patterns have the same local benefit).

This driver does not (re)visit modified or newly replaced ops, and does not
allow for progressive rewrites of the same op. Op erasure is only supported for
the currently matched op. If your pattern-set requires these, consider using the
Greedy Pattern Rewrite Driver instead, at the expense of extra overhead.

This driver is exposed using the `walkAndApplyPatterns` function.

#### Debugging

You can debug the Walk Pattern Rewrite Driver by passing the
`--debug-only=walk-rewriter` CLI flag. This will print the visited and matched
ops.

### Greedy Pattern Rewrite Driver

This driver processes ops in a worklist-driven fashion and greedily applies the
patterns that locally have the most benefit. The benefit of a pattern is decided
solely by the benefit specified on the pattern, and the relative order of the
pattern within the pattern list (when two patterns have the same local benefit).
Patterns are iteratively applied to operations until a fixed point is reached or
until the configurable maximum number of iterations exhausted, at which point
the driver finishes.
patterns that locally have the most benefit (same as the Walk Pattern Rewrite
Driver). Patterns are iteratively applied to operations until a fixed point is
reached or until the configurable maximum number of iterations exhausted, at
which point the driver finishes.

This driver comes in two fashions:

Expand Down Expand Up @@ -368,7 +386,7 @@ rewriter and do not bypass the rewriter API by modifying ops directly.
Note: This driver is the one used by the [canonicalization](Canonicalization.md)
[pass](Passes.md/#-canonicalize) in MLIR.

### Debugging
#### Debugging

To debug the execution of the greedy pattern rewrite driver,
`-debug-only=greedy-rewriter` may be used. This command line flag activates
Expand Down
37 changes: 37 additions & 0 deletions mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===- WALKPATTERNREWRITEDRIVER.h - Walk Pattern Rewrite Driver -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Declares a helper function to walk the given op and apply rewrite patterns.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
#define MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_

#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"

namespace mlir {

/// A fast walk-based pattern rewrite driver. Rewrites ops nested under the
/// given operation by walking it and applying the highest benefit patterns.
/// This rewriter *does not* wait until a fixpoint is reached and *does not*
/// visit modified or newly replaced ops. Also *does not* perform folding or
/// dead-code elimination.
///
/// This is intended as the simplest and most lightweight pattern rewriter in
/// cases when a simple walk gets the job done.
///
/// Note: Does not apply patterns to the given operation itself.
void walkAndApplyPatterns(Operation *op,
const FrozenRewritePatternSet &patterns,
RewriterBase::Listener *listener = nullptr);

} // namespace mlir

#endif // MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
namespace arith {
Expand Down Expand Up @@ -157,11 +161,7 @@ struct ArithUnsignedWhenEquivalentPass
RewritePatternSet patterns(ctx);
populateUnsignedWhenEquivalentPatterns(patterns, solver);

GreedyRewriteConfig config;
config.listener = &listener;

if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
signalPassFailure();
walkAndApplyPatterns(op, std::move(patterns), &listener);
}
};
} // end anonymous namespace
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Transforms/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ add_mlir_library(MLIRTransformUtils
LoopInvariantCodeMotionUtils.cpp
OneToNTypeConversion.cpp
RegionUtils.cpp
WalkPatternRewriteDriver.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
Expand Down
86 changes: 86 additions & 0 deletions mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements mlir::walkAndApplyPatterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/WalkPatternRewriteDriver.h"

#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"

#define DEBUG_TYPE "walk-rewriter"

namespace mlir {

namespace {
// Forwarding listener to guard against unsupported erasures. Because we use
// walk-based pattern application, erasing the op from the *next* iteration
// (e.g., a user of the visited op) is not valid.
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
using RewriterBase::ForwardingListener::ForwardingListener;

void notifyOperationErased(Operation *op) override {
if (op != visitedOp)
llvm::report_fatal_error("unsupported op erased in WalkPatternRewriter; "
"erasure is only supported for matched ops");

ForwardingListener::notifyOperationErased(op);
}

Operation *visitedOp = nullptr;
};
} // namespace

void walkAndApplyPatterns(Operation *op,
const FrozenRewritePatternSet &patterns,
RewriterBase::Listener *listener) {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (failed(verify(op)))
llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

PatternRewriter rewriter(op->getContext());
ErasedOpsListener erasedListener(listener);
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
rewriter.setListener(&erasedListener);
#else
(void)erasedListener;
rewriter.setListener(listener);
#endif

PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();

op->walk([&](Operation *visitedOp) {
if (visitedOp == op)
return;

LLVM_DEBUG(llvm::dbgs() << "Visiting op: ";
visitedOp->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
llvm::dbgs() << "\n";);
erasedListener.visitedOp = visitedOp;
if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
}
});

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (failed(verify(op)))
llvm::report_fatal_error(
"walk pattern rewriter result IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}

} // namespace mlir
2 changes: 1 addition & 1 deletion mlir/test/IR/enum-attr-roundtrip.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s | mlir-opt -test-patterns | FileCheck %s
// RUN: mlir-opt %s | mlir-opt -test-greedy-patterns | FileCheck %s

// CHECK-LABEL: @test_enum_attr_roundtrip
func.func @test_enum_attr_roundtrip() -> () {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-patterns="max-iterations=1" \
// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1" \
// RUN: -allow-unregistered-dialect --split-input-file | FileCheck %s

// CHECK-LABEL: func @add_to_worklist_after_inplace_update()
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-patterns="max-iterations=1 top-down=true" \
// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1 top-down=true" \
// RUN: --split-input-file | FileCheck %s

// Tests for https://github.com/llvm/llvm-project/issues/86765. Ensure
Expand Down
107 changes: 107 additions & 0 deletions mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// RUN: mlir-opt %s --test-walk-pattern-rewrite-driver="dump-notifications=true" \
// RUN: --allow-unregistered-dialect --split-input-file | FileCheck %s

// The following op is updated in-place and will not be added back to the worklist.
// CHECK-LABEL: func.func @inplace_update()
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
func.func @inplace_update() {
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
"test.any_attr_of_i32_str"() {attr = 1 : i32} : () -> ()
return
}

// Check that the driver does not fold visited ops.
// CHECK-LABEL: func.func @add_no_fold()
// CHECK: arith.constant
// CHECK: arith.constant
// CHECK: %[[RES:.+]] = arith.addi
// CHECK: return %[[RES]]
func.func @add_no_fold() -> i32 {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%res = arith.addi %c0, %c1 : i32
return %res : i32
}

// Check that the driver handles rewriter.moveBefore.
// CHECK-LABEL: func.func @move_before(
// CHECK: "test.move_before_parent_op"
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
// CHECK: scf.if
// CHECK: return
func.func @move_before(%cond : i1) {
scf.if %cond {
"test.move_before_parent_op"() ({
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
}) : () -> ()
}
return
}

// Check that the driver handles rewriter.moveAfter. In this case, we expect
// the moved op to be visited only once since walk uses `make_early_inc_range`.
// CHECK-LABEL: func.func @move_after(
// CHECK: scf.if
// CHECK: }
// CHECK: "test.move_after_parent_op"
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
// CHECK: return
func.func @move_after(%cond : i1) {
scf.if %cond {
"test.move_after_parent_op"() ({
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
}) : () -> ()
}
return
}

// Check that the driver handles rewriter.moveAfter. In this case, we expect
// the moved op to be visited twice since we advance its position to the next
// node after the parent.
// CHECK-LABEL: func.func @move_forward_and_revisit(
// CHECK: scf.if
// CHECK: }
// CHECK: arith.addi
// CHECK: "test.move_after_parent_op"
// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
// CHECK: arith.addi
// CHECK: return
func.func @move_forward_and_revisit(%cond : i1) {
scf.if %cond {
"test.move_after_parent_op"() ({
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
}) {advance = 1 : i32} : () -> ()
}
%a = arith.addi %cond, %cond : i1
%b = arith.addi %a, %cond : i1
return
}

// Operation inserted just after the currently visited one won't be visited.
// CHECK-LABEL: func.func @insert_just_after
// CHECK: "test.clone_me"() ({
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
// CHECK: }) {was_cloned} : () -> ()
// CHECK: "test.clone_me"() ({
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
// CHECK: }) : () -> ()
// CHECK: return
func.func @insert_just_after(%cond : i1) {
"test.clone_me"() ({
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
}) : () -> ()
return
}

// Check that we can replace the current operation with a new one.
// Note that the new op won't be visited.
// CHECK-LABEL: func.func @replace_with_new_op
// CHECK: %[[NEW:.+]] = "test.new_op"
// CHECK: %[[RES:.+]] = arith.addi %[[NEW]], %[[NEW]]
// CHECK: return %[[RES]]
func.func @replace_with_new_op() -> i32 {
%a = "test.replace_with_new_op"() : () -> (i32)
%res = arith.addi %a, %a : i32
return %res : i32
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt --pass-pipeline="builtin.module(test-patterns)" %s | FileCheck %s
// RUN: mlir-opt --pass-pipeline="builtin.module(test-greedy-patterns)" %s | FileCheck %s

// CHECK-LABEL: func @test_reorder_constants_and_match
func.func @test_reorder_constants_and_match(%arg0 : i32) -> (i32) {
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Transforms/test-operation-folder.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: mlir-opt -test-patterns='top-down=false' %s | FileCheck %s
// RUN: mlir-opt -test-patterns='top-down=true' %s | FileCheck %s
// RUN: mlir-opt -test-greedy-patterns='top-down=false' %s | FileCheck %s
// RUN: mlir-opt -test-greedy-patterns='top-down=true' %s | FileCheck %s

func.func @foo() -> i32 {
%c42 = arith.constant 42 : i32
Expand Down
Loading
Loading