[mlir][Transforms] CSE: Add filter options to control CSE'ing #115639
This commit adds two new pass options that give users more fine-grained control over which ops are CSE'd / DCE'd.

* `barrier-op-filter` specifies ops that act as CSE'ing barriers: ops nested inside such ops are not CSE'd with ops outside of them. (Until now, the only CSE'ing barriers were IsolatedFromAbove ops.)
* `eliminate-op-filter` specifies ops that are subject to elimination. All non-matching ops are ignored by the CSE pass and remain in place. (If the filter is empty, all ops are subject to elimination.)
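The interplay of the two filters can be sketched with a small standalone model. This is not the MLIR implementation: `ToyOp`, `ToyConfig`, `makeNameFilter`, `countCSE`, and `demoProgram` are hypothetical names invented for illustration, op equivalence is reduced to comparing a string key, and region-holding ops are never deduplicated themselves. The sketch only assumes what the description states: a barrier op starts its nested ops with an empty set of known values, and an op rejected by the eliminate filter is skipped entirely.

```cpp
#include <cassert>  // only needed if you add the checks from the note below
#include <functional>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for mlir::Operation.
struct ToyOp {
  std::string name;         // op name, e.g. "arith.constant"
  std::string key;          // ops with equal keys count as equivalent
  std::vector<ToyOp> body;  // ops nested in the op's region (may be empty)
};

using ToyFilter = std::function<bool(const ToyOp &)>;

// Mirrors the shape of the two filters; an unset filter means "no restriction".
struct ToyConfig {
  ToyFilter barrierOpFilter = nullptr;
  ToyFilter eliminateOpFilter = nullptr;
};

// Turn a list of op names into a predicate. An empty list leaves the filter
// unset, which matches "if the filter is empty, all ops are subject to
// elimination".
ToyFilter makeNameFilter(std::vector<std::string> names) {
  if (names.empty())
    return nullptr;
  std::unordered_set<std::string> set(names.begin(), names.end());
  return [set](const ToyOp &op) { return set.count(op.name) != 0; };
}

// Count how many ops would be deduplicated. `known` is passed by value so
// that entries added in a nested scope vanish when that scope ends.
int countCSE(const std::vector<ToyOp> &ops,
             std::unordered_set<std::string> known, const ToyConfig &cfg) {
  int numCSE = 0;
  for (const ToyOp &op : ops) {
    if (!op.body.empty()) {
      // A barrier op hides everything known outside of it from its body.
      bool isBarrier = cfg.barrierOpFilter && cfg.barrierOpFilter(op);
      numCSE += countCSE(
          op.body, isBarrier ? std::unordered_set<std::string>{} : known, cfg);
      continue;
    }
    // Ops rejected by the eliminate filter are ignored and stay in place.
    if (cfg.eliminateOpFilter && !cfg.eliminateOpFilter(op))
      continue;
    // A key that is already known marks a duplicate.
    if (!known.insert(op.key).second)
      ++numCSE;
  }
  return numCSE;
}

// One constant outside a loop, the same constant plus two identical
// affine.apply ops inside it.
std::vector<ToyOp> demoProgram() {
  ToyOp loop{"scf.for", "", {}};
  loop.body = {ToyOp{"arith.constant", "c0", {}},
               ToyOp{"affine.apply", "a", {}},
               ToyOp{"affine.apply", "a", {}}};
  return {ToyOp{"arith.constant", "c0", {}}, loop};
}
```

In this model, `countCSE(demoProgram(), {}, cfg)` yields 2 with no filters, 1 when `scf.for` is a barrier (the inner constant is no longer merged with the outer one), 1 when only `arith.constant` is eligible for elimination, and 0 when both filters are set.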
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
Changes
This commit adds two new pass options that give users more fine-grained control over which ops are CSE'd / DCE'd.
Full diff: https://github.com/llvm/llvm-project/pull/115639.diff
5 Files Affected:
diff --git a/mlir/include/mlir/Transforms/CSE.h b/mlir/include/mlir/Transforms/CSE.h
index 3d01ece0780509..4edca3e3369f24 100644
--- a/mlir/include/mlir/Transforms/CSE.h
+++ b/mlir/include/mlir/Transforms/CSE.h
@@ -13,19 +13,44 @@
#ifndef MLIR_TRANSFORMS_CSE_H_
#define MLIR_TRANSFORMS_CSE_H_
+#include <functional>
+
namespace mlir {
class DominanceInfo;
class Operation;
class RewriterBase;
+/// Configuration for CSE.
+struct CSEConfig {
+ /// If set, matching ops act as a CSE'ing barrier: ops are not CSE'd across
+ /// matching ops.
+ ///
+ /// Note: IsolatedFromAbove ops are always a CSE'ing barrier, regardless of
+ /// this filter.
+ ///
+ /// Example:
+ /// %0 = arith.constant 0 : index
+ /// scf.for ... {
+ /// %1 = arith.constant 0 : index
+ /// ...
+ /// }
+ /// If "scf.for" is marked as a CSE'ing barrier, %0 and %1 are *not* CSE'd.
+ std::function<bool(Operation *)> barrierOpFilter = nullptr;
+
+ /// If set, only matching ops are subject to elimination (CSE or DCE). All
+ /// non-matching ops are ignored and remain in place.
+ std::function<bool(Operation *)> eliminateOpFilter = nullptr;
+};
+
/// Eliminate common subexpressions within the given operation. This transform
/// looks for and deduplicates equivalent operations.
///
/// `changed` indicates whether the IR was modified or not.
void eliminateCommonSubExpressions(RewriterBase &rewriter,
DominanceInfo &domInfo, Operation *op,
- bool *changed = nullptr);
+ bool *changed = nullptr,
+ CSEConfig config = CSEConfig());
} // namespace mlir
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 5c977055e95dc8..41f208216374fe 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -33,7 +33,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_CANONICALIZER
#define GEN_PASS_DECL_CONTROLFLOWSINK
-#define GEN_PASS_DECL_CSEPASS
+#define GEN_PASS_DECL_CSE
#define GEN_PASS_DECL_INLINER
#define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
#define GEN_PASS_DECL_MEM2REG
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 000d9f697618e6..429029f21eb307 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -81,12 +81,25 @@ def CSE : Pass<"cse"> {
let summary = "Eliminate common sub-expressions";
let description = [{
This pass implements a generalized algorithm for common sub-expression
- elimination. This pass relies on information provided by the
- `Memory SideEffect` interface to identify when it is safe to eliminate
+ elimination. The pass also eliminates dead operations (DCE). The pass
+ relies on information provided by the `MemoryEffectOpInterface`
+ and on `DominanceInfo` to identify when it is safe to eliminate
operations. See [Common subexpression elimination](https://en.wikipedia.org/wiki/Common_subexpression_elimination)
for more general details on this optimization.
+
+ The types of ops that are subject to elimination can be configured with
+ `eliminate-op-filter`. If set, only those ops are CSE'd or DCE'd.
+
+ Ops are never CSE'd across IsolatedFromAbove ops. Additional CSE'ing
+ barrier ops can be specified with `barrier-op-filter`.
}];
let constructor = "mlir::createCSEPass()";
+ let options = [
+ ListOption<"barrierOpFilter", "barrier-op-filter", "std::string",
+ "Names of ops that act as CSE'ing barriers">,
+ ListOption<"eliminateOpFilter", "eliminate-op-filter", "std::string",
+ "If non-empty, list of ops that are subject to elimination">,
+ ];
let statistics = [
Statistic<"numCSE", "num-cse'd", "Number of operations CSE'd">,
Statistic<"numDCE", "num-dce'd", "Number of operations DCE'd">
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 3affd88d158de5..93ac35db276da0 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -23,8 +23,9 @@
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/RecyclingAllocator.h"
-#include <deque>
+#include <deque>
+#include <unordered_set>
namespace mlir {
#define GEN_PASS_DEF_CSE
#include "mlir/Transforms/Passes.h.inc"
@@ -60,8 +61,9 @@ namespace {
/// Simple common sub-expression elimination.
class CSEDriver {
public:
- CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
- : rewriter(rewriter), domInfo(domInfo) {}
+ CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo,
+ const CSEConfig &config)
+ : rewriter(rewriter), domInfo(domInfo), config(config) {}
/// Simplify all operations within the given op.
void simplify(Operation *op, bool *changed = nullptr);
@@ -125,6 +127,9 @@ class CSEDriver {
// Various statistics.
int64_t numCSE = 0;
int64_t numDCE = 0;
+
+ /// CSE configuration.
+ CSEConfig config;
};
} // namespace
@@ -226,6 +231,10 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
Operation *op,
bool hasSSADominance) {
+ // Don't simplify operations that are filtered out.
+ if (config.eliminateOpFilter && !config.eliminateOpFilter(op))
+ return failure();
+
// Don't simplify terminator operations.
if (op->hasTrait<OpTrait::IsTerminator>())
return failure();
@@ -288,8 +297,11 @@ void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
if (op.getNumRegions() != 0) {
// If this operation is isolated above, we can't process nested regions
// with the given 'knownValues' map. This would cause the insertion of
- // implicit captures in explicit capture only regions.
- if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
+ // implicit captures in explicit capture only regions. Additional barrier
+ // ops can be specified by the user.
+ bool isBarrier = op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
+ (config.barrierOpFilter && config.barrierOpFilter(&op));
+ if (isBarrier) {
ScopedMapTy nestedKnownValues;
for (auto &region : op.getRegions())
simplifyRegion(nestedKnownValues, region);
@@ -381,8 +393,8 @@ void CSEDriver::simplify(Operation *op, bool *changed) {
void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter,
DominanceInfo &domInfo, Operation *op,
- bool *changed) {
- CSEDriver driver(rewriter, &domInfo);
+ bool *changed, CSEConfig config) {
+ CSEDriver driver(rewriter, &domInfo, config);
driver.simplify(op, changed);
}
@@ -394,9 +406,28 @@ struct CSE : public impl::CSEBase<CSE> {
} // namespace
void CSE::runOnOperation() {
+ // Set up CSE configuration from pass options.
+ CSEConfig config;
+ std::unordered_set<std::string> barrierOpNames;
+ for (std::string opName : barrierOpFilter)
+ barrierOpNames.insert(opName);
+ std::unordered_set<std::string> eliminateOpNames;
+ for (std::string opName : eliminateOpFilter)
+ eliminateOpNames.insert(opName);
+ if (!barrierOpNames.empty()) {
+ config.barrierOpFilter = [&](Operation *op) -> bool {
+ return barrierOpNames.count(op->getName().getStringRef().str());
+ };
+ }
+ if (!eliminateOpNames.empty()) {
+ config.eliminateOpFilter = [&](Operation *op) -> bool {
+ return eliminateOpNames.count(op->getName().getStringRef().str());
+ };
+ }
+
// Simplify the IR.
IRRewriter rewriter(&getContext());
- CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
+ CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>(), config);
bool changed = false;
driver.simplify(getOperation(), &changed);
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index 11a33102684733..5d2da75db6ce2f 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -1,32 +1,47 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' | FileCheck %s
-
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)>
-#map0 = affine_map<(d0) -> (d0 mod 2)>
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse="eliminate-op-filter=arith.constant"))' -split-input-file | FileCheck %s --check-prefix=CHECK-ELIMINATE-FILTER
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse="barrier-op-filter=affine.for"))' -split-input-file | FileCheck %s --check-prefix=CHECK-BARRIER-FILTER
// CHECK-LABEL: @simple_constant
+// CHECK-ELIMINATE-FILTER-LABEL: @simple_constant
func.func @simple_constant() -> (i32, i32) {
// CHECK-NEXT: %[[VAR_c1_i32:.*]] = arith.constant 1 : i32
+ // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_c1_i32:.*]] = arith.constant 1 : i32
%0 = arith.constant 1 : i32
// CHECK-NEXT: return %[[VAR_c1_i32]], %[[VAR_c1_i32]] : i32, i32
+ // CHECK-ELIMINATE-FILTER-NEXT: return %[[VAR_c1_i32]], %[[VAR_c1_i32]] : i32, i32
%1 = arith.constant 1 : i32
return %0, %1 : i32, i32
}
+// -----
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)>
+// CHECK-ELIMINATE-FILTER-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)>
+#map0 = affine_map<(d0) -> (d0 mod 2)>
+
// CHECK-LABEL: @basic
+// CHECK-ELIMINATE-FILTER-LABEL: @basic
func.func @basic() -> (index, index) {
// CHECK: %[[VAR_c0:[0-9a-zA-Z_]+]] = arith.constant 0 : index
+ // CHECK-ELIMINATE-FILTER: %[[VAR_c0:[0-9a-zA-Z_]+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 0 : index
// CHECK-NEXT: %[[VAR_0:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]])
+ // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_0:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]])
+ // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]])
%0 = affine.apply #map0(%c0)
%1 = affine.apply #map0(%c1)
// CHECK-NEXT: return %[[VAR_0]], %[[VAR_0]] : index, index
+ // CHECK-ELIMINATE-FILTER-NEXT: return %[[VAR_0]], %[[VAR_1]] : index, index
return %0, %1 : index, index
}
+// -----
+
// CHECK-LABEL: @many
func.func @many(f32, f32) -> (f32) {
^bb0(%a : f32, %b : f32):
@@ -52,6 +67,8 @@ func.func @many(f32, f32) -> (f32) {
return %l : f32
}
+// -----
+
/// Check that operations are not eliminated if they have different operands.
// CHECK-LABEL: @different_ops
func.func @different_ops() -> (i32, i32) {
@@ -64,6 +81,8 @@ func.func @different_ops() -> (i32, i32) {
return %0, %1 : i32, i32
}
+// -----
+
/// Check that operations are not eliminated if they have different result
/// types.
// CHECK-LABEL: @different_results
@@ -77,6 +96,8 @@ func.func @different_results(%arg0: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<4
return %0, %1 : tensor<?x?xf32>, tensor<4x?xf32>
}
+// -----
+
/// Check that operations are not eliminated if they have different attributes.
// CHECK-LABEL: @different_attributes
func.func @different_attributes(index, index) -> (i1, i1, i1) {
@@ -93,6 +114,8 @@ func.func @different_attributes(index, index) -> (i1, i1, i1) {
return %0, %1, %2 : i1, i1, i1
}
+// -----
+
/// Check that operations with side effects are not eliminated.
// CHECK-LABEL: @side_effect
func.func @side_effect() -> (memref<2x1xf32>, memref<2x1xf32>) {
@@ -106,22 +129,32 @@ func.func @side_effect() -> (memref<2x1xf32>, memref<2x1xf32>) {
return %0, %1 : memref<2x1xf32>, memref<2x1xf32>
}
+// -----
+
/// Check that operation definitions are properly propagated down the dominance
/// tree.
// CHECK-LABEL: @down_propagate_for
+// CHECK-BARRIER-FILTER-LABEL: @down_propagate_for
func.func @down_propagate_for() {
// CHECK: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
+ // CHECK-BARRIER-FILTER: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
%0 = arith.constant 1 : i32
// CHECK-NEXT: affine.for {{.*}} = 0 to 4 {
+ // CHECK-BARRIER-FILTER-NEXT: affine.for {{.*}} = 0 to 4 {
affine.for %i = 0 to 4 {
- // CHECK-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR_c1_i32]]) : (i32, i32) -> ()
+ // CHECK-BARRIER-FILTER-NEXT: %[[VAR2_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
%1 = arith.constant 1 : i32
+
+ // CHECK-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR_c1_i32]]) : (i32, i32) -> ()
+ // CHECK-BARRIER-FILTER-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR2_c1_i32]]) : (i32, i32) -> ()
"foo"(%0, %1) : (i32, i32) -> ()
}
return
}
+// -----
+
// CHECK-LABEL: @down_propagate
func.func @down_propagate() -> i32 {
// CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
@@ -142,6 +175,8 @@ func.func @down_propagate() -> i32 {
return %arg : i32
}
+// -----
+
/// Check that operation definitions are NOT propagated up the dominance tree.
// CHECK-LABEL: @up_propagate_for
func.func @up_propagate_for() -> i32 {
@@ -159,6 +194,8 @@ func.func @up_propagate_for() -> i32 {
return %1 : i32
}
+// -----
+
// CHECK-LABEL: func @up_propagate
func.func @up_propagate() -> i32 {
// CHECK-NEXT: %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32
@@ -188,6 +225,8 @@ func.func @up_propagate() -> i32 {
return %add : i32
}
+// -----
+
/// The same test as above except that we are testing on a cfg embedded within
/// an operation region.
// CHECK-LABEL: func @up_propagate_region
@@ -221,6 +260,8 @@ func.func @up_propagate_region() -> i32 {
return %0 : i32
}
+// -----
+
/// This test checks that nested regions that are isolated from above are
/// properly handled.
// CHECK-LABEL: @nested_isolated
@@ -248,6 +289,8 @@ func.func @nested_isolated() -> i32 {
return %0 : i32
}
+// -----
+
/// This test is checking that CSE gracefully handles values in graph regions
/// where the use occurs before the def, and one of the defs could be CSE'd with
/// the other.
@@ -269,6 +312,8 @@ func.func @use_before_def() {
return
}
+// -----
+
/// This test is checking that CSE is removing duplicated read op that follow
/// other.
// CHECK-LABEL: @remove_direct_duplicated_read_op
@@ -281,6 +326,8 @@ func.func @remove_direct_duplicated_read_op() -> i32 {
return %2 : i32
}
+// -----
+
/// This test is checking that CSE is removing duplicated read op that follow
/// other.
// CHECK-LABEL: @remove_multiple_duplicated_read_op
@@ -300,6 +347,8 @@ func.func @remove_multiple_duplicated_read_op() -> i64 {
return %6 : i64
}
+// -----
+
/// This test is checking that CSE is not removing duplicated read op that
/// have write op in between.
// CHECK-LABEL: @dont_remove_duplicated_read_op_with_sideeffecting
@@ -314,6 +363,8 @@ func.func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
return %2 : i32
}
+// -----
+
// Check that an operation with a single region can CSE.
func.func @cse_single_block_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -332,6 +383,8 @@ func.func @cse_single_block_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
// CHECK-NOT: test.cse_of_single_block_op
// CHECK: return %[[OP]], %[[OP]]
+// -----
+
// Operations with different number of bbArgs dont CSE.
func.func @no_cse_varied_bbargs(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -350,6 +403,8 @@ func.func @no_cse_varied_bbargs(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
// CHECK: return %[[OP0]], %[[OP1]]
+// -----
+
// Operations with different regions dont CSE
func.func @no_cse_region_difference_simple(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -368,6 +423,8 @@ func.func @no_cse_region_difference_simple(%a : tensor<?x?xf32>, %b : tensor<?x?
// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
// CHECK: return %[[OP0]], %[[OP1]]
+// -----
+
// Operation with identical region with multiple statements CSE.
func.func @cse_single_block_ops_identical_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -392,6 +449,8 @@ func.func @cse_single_block_ops_identical_bodies(%a : tensor<?x?xf32>, %b : tens
// CHECK-NOT: test.cse_of_single_block_op
// CHECK: return %[[OP]], %[[OP]]
+// -----
+
// Operation with non-identical regions dont CSE.
func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -416,6 +475,8 @@ func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : t
// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
// CHECK: return %[[OP0]], %[[OP1]]
+// -----
+
func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor<2xi1>) -> (tensor<2xi1>, tensor<2xi1>) {
%false_2 = arith.constant false
%true_5 = arith.constant true
@@ -438,6 +499,8 @@ func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor
// CHECK: test.region_yield %[[TRUE]]
// CHECK: return %[[OP]], %[[OP]]
+// -----
+
func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
%r1 = scf.if %c -> (tensor<5xf32>) {
%0 = tensor.empty() : tensor<5xf32>
@@ -463,6 +526,8 @@ func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, te
// CHECK-NOT: scf.if
// CHECK: return %[[if]], %[[if]]
+// -----
+
// CHECK-LABEL: @cse_recursive_effects_success
func.func @cse_recursive_effects_success() -> (i32, i32, i32) {
// CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
@@ -492,6 +557,8 @@ func.func @cse_recursive_effects_success() -> (i32, i32, i32) {
return %0, %2, %1 : i32, i32, i32
}
+// -----
+
// CHECK-LABEL: @cse_recursive_effects_failure
func.func @cse_recursive_effects_failure() -> (i32, i32, i32) {
// CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
+
// Operations with different regions dont CSE
func.func @no_cse_region_difference_simple(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -368,6 +423,8 @@ func.func @no_cse_region_difference_simple(%a : tensor<?x?xf32>, %b : tensor<?x?
// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
// CHECK: return %[[OP0]], %[[OP1]]
+// -----
+
// Operation with identical region with multiple statements CSE.
func.func @cse_single_block_ops_identical_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -392,6 +449,8 @@ func.func @cse_single_block_ops_identical_bodies(%a : tensor<?x?xf32>, %b : tens
// CHECK-NOT: test.cse_of_single_block_op
// CHECK: return %[[OP]], %[[OP]]
+// -----
+
// Operation with non-identical regions dont CSE.
func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -416,6 +475,8 @@ func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : t
// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
// CHECK: return %[[OP0]], %[[OP1]]
+// -----
+
func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor<2xi1>) -> (tensor<2xi1>, tensor<2xi1>) {
%false_2 = arith.constant false
%true_5 = arith.constant true
@@ -438,6 +499,8 @@ func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor
// CHECK: test.region_yield %[[TRUE]]
// CHECK: return %[[OP]], %[[OP]]
+// -----
+
func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
%r1 = scf.if %c -> (tensor<5xf32>) {
%0 = tensor.empty() : tensor<5xf32>
@@ -463,6 +526,8 @@ func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, te
// CHECK-NOT: scf.if
// CHECK: return %[[if]], %[[if]]
+// -----
+
// CHECK-LABEL: @cse_recursive_effects_success
func.func @cse_recursive_effects_success() -> (i32, i32, i32) {
// CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
@@ -492,6 +557,8 @@ func.func @cse_recursive_effects_success() -> (i32, i32, i32) {
return %0, %2, %1 : i32, i32, i32
}
+// -----
+
// CHECK-LABEL: @cse_recursive_effects_failure
func.func @cse_recursive_effects_failure() -> (i32, i32, i32) {
// CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
I don't think this is quite the right approach. The information needed here is nearly the same as in the materialization region logic in the folder interface. Adding options to the pass feels like trying to side-load semantic information that should ideally be provided by the dialect (this is why we have interfaces). I would hope that we'd provide interface logic that dialects can hook into for this. I'm not sure if we could just reuse the materialization region hook (if that actually covers all of the same places), or if we need a CSE interface. I haven't thought about it in a while, but this problem has come up before.
I agree with River: that seems a bit ad-hoc to me right now.
Some context: In my use case, certain ops with regions should be a hoisting barrier for certain nested ops. CSE is one example of hoisting (by deduplication). Another one is the I ran into this issue in the past when I was working on a transform dialect-based experiment in IREE: Ops without tensor results were allowed to be hoisted from a "dispatch region" op, but not ops with tensor results. I just took a look at the What could work is a CSE and the greedy pattern driver could query that interface. Any thoughts?
I think there is confusion here, that is not what the materialization hook does. That hook is used exactly as you describe your need: it's defined by dialects with region operations that want to define hoist barriers, or more directly region operations that constants should or shouldn't be materialized above. If you look at the logic where this is used, the insertion region is determined by walking upwards from the insertion block to find a suitable region location (using the parent operations as the hooks for the interfaces):
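For readers who have not looked at that logic, the upward walk described above can be sketched with a small standalone model. This is not the MLIR source: `ToyOp`, `isolatedFromAbove`, `materializeInto` and `findInsertionOp` are hypothetical stand-ins for an operation, the IsolatedFromAbove trait, the dialect's `shouldMaterializeInto` answer and the search itself, and the real hook is asked about regions rather than ops.

```cpp
#include <cassert>

// Toy stand-in for a region-holding operation (hypothetical, not an MLIR type).
struct ToyOp {
  const char *name;
  ToyOp *parent;           // enclosing op, or nullptr for the top-level op
  bool isolatedFromAbove;  // models the IsolatedFromAbove trait
  bool materializeInto;    // models the dialect hook's answer for this op
};

// Walk upward from the op that encloses the insertion point and return the
// first op whose region should receive the constant.
inline const ToyOp *findInsertionOp(const ToyOp *enclosing) {
  assert(enclosing && "expected an enclosing op");
  for (const ToyOp *op = enclosing;; op = op->parent) {
    // Stop at isolated or top-level ops: nothing may be hoisted past them.
    if (op->isolatedFromAbove || op->parent == nullptr)
      return op;
    // Stop where the dialect asks for constants to stay inside this op.
    if (op->materializeInto)
      return op;
  }
}
```

In this model, an op that sets `materializeInto` behaves as a hoisting barrier for constants: the search stops there instead of continuing to the enclosing function-like op.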
Yes, I think something like this could be a good idea. I would hope, though, that this would replace the shouldMaterializeInto hook (I don't know of any cases that don't map to hoisting more generally). I could be wrong, but I haven't seen such a case in practice. Nit on the default: I would expect a barrier op interface to default to being a barrier for everything (not the other way around).
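To make the proposal a bit more concrete, here is a rough sketch of what such an interface could look like, modelled in plain C++ rather than as an MLIR op interface. Everything in it is hypothetical: `HoistingBarrier`, `isBarrierFor`, `SimpleOp`, `DispatchRegionBarrier` and `canHoistPast` are not existing MLIR names. The default follows the suggestion that an op implementing the interface blocks everything unless it opts out.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical description of a nested op that a transform wants to hoist or
// deduplicate with an equivalent op further out.
struct SimpleOp {
  std::string name;
  bool hasTensorResult;
};

// Hypothetical interface implemented by region-holding ops.
struct HoistingBarrier {
  virtual ~HoistingBarrier() = default;
  // Default: act as a barrier for every nested op.
  virtual bool isBarrierFor(const SimpleOp &) const { return true; }
};

// Example override modelled on the dispatch-region case from the thread:
// only ops with tensor results must stay inside.
struct DispatchRegionBarrier : HoistingBarrier {
  bool isBarrierFor(const SimpleOp &op) const override {
    return op.hasTensorResult;
  }
};

// A transform (CSE, a pattern driver) would ask every enclosing barrier that
// sits between the nested op and the candidate outer position.
inline bool canHoistPast(const std::vector<const HoistingBarrier *> &enclosing,
                         const SimpleOp &op) {
  for (const HoistingBarrier *barrier : enclosing) {
    assert(barrier && "expected a valid barrier");
    if (barrier->isBarrierFor(op))
      return false;
  }
  return true;
}
```

With a barrier-by-default design, a dialect that forgets to override the hook ends up with too little hoisting rather than an incorrect transformation.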
In general, every request I saw about this was instead better addressed with a reverse transformation: this is something that needs profitability and would be suitable even if the input program already had the constant (or the expressions) outside the target region.