diff --git a/mlir/include/mlir/Transforms/CSE.h b/mlir/include/mlir/Transforms/CSE.h index 3d01ece078050..4edca3e3369f2 100644 --- a/mlir/include/mlir/Transforms/CSE.h +++ b/mlir/include/mlir/Transforms/CSE.h @@ -13,19 +13,44 @@ #ifndef MLIR_TRANSFORMS_CSE_H_ #define MLIR_TRANSFORMS_CSE_H_ +#include <functional> + namespace mlir { class DominanceInfo; class Operation; class RewriterBase; +/// Configuration for CSE. +struct CSEConfig { + /// If set, matching ops act as a CSE'ing barrier: ops are not CSE'd across + /// matching ops. + /// + /// Note: IsolatedFromAbove ops are always a CSE'ing barrier, regardless of + /// this filter. + /// + /// Example: + /// %0 = arith.constant 0 : index + /// scf.for ... { + /// %1 = arith.constant 0 : index + /// ... + /// } + /// If "scf.for" is marked as a CSE'ing barrier, %0 and %1 are *not* CSE'd. + std::function<bool(Operation *)> barrierOpFilter = nullptr; + + /// If set, only matching ops are subject to elimination (CSE or DCE). Ops + /// that do not match the filter are never eliminated. + std::function<bool(Operation *)> eliminateOpFilter = nullptr; +}; + /// Eliminate common subexpressions within the given operation. This transform /// looks for and deduplicates equivalent operations. /// /// `changed` indicates whether the IR was modified or not.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, - bool *changed = nullptr); + bool *changed = nullptr, + CSEConfig config = CSEConfig()); } // namespace mlir diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 5c977055e95dc..41f208216374f 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -33,7 +33,7 @@ class GreedyRewriteConfig; #define GEN_PASS_DECL_CANONICALIZER #define GEN_PASS_DECL_CONTROLFLOWSINK -#define GEN_PASS_DECL_CSEPASS +#define GEN_PASS_DECL_CSE #define GEN_PASS_DECL_INLINER #define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION #define GEN_PASS_DECL_MEM2REG diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 000d9f697618e..429029f21eb30 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -81,12 +81,25 @@ def CSE : Pass<"cse"> { let summary = "Eliminate common sub-expressions"; let description = [{ This pass implements a generalized algorithm for common sub-expression - elimination. This pass relies on information provided by the - `Memory SideEffect` interface to identify when it is safe to eliminate + elimination. The pass also eliminates dead operations (DCE). The pass + relies on information provided by the `MemoryEffectOpInterface` + interface and on `DominanceInfo` to identify when it is safe to eliminate operations. See [Common subexpression elimination](https://en.wikipedia.org/wiki/Common_subexpression_elimination) for more general details on this optimization. + + The types of ops that are subject to elimination can be configured with + `eliminate-op-filter`. If set, only those ops are CSE'd or DCE'd. + + Ops are never CSE'd across IsolatedFromAbove ops. Additional CSE'ing + barrier ops can be specified with `barrier-op-filter`.
}]; let constructor = "mlir::createCSEPass()"; + let options = [ + ListOption<"barrierOpFilter", "barrier-op-filter", "std::string", + "Names of ops that act as CSE'ing barriers">, + ListOption<"eliminateOpFilter", "eliminate-op-filter", "std::string", + "If non-empty, list of ops that are subject to elimination">, + ]; let statistics = [ Statistic<"numCSE", "num-cse'd", "Number of operations CSE'd">, Statistic<"numDCE", "num-dce'd", "Number of operations DCE'd"> diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 3affd88d158de..93ac35db276da 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -23,8 +23,9 @@ #include "llvm/ADT/ScopedHashTable.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/RecyclingAllocator.h" -#include +#include +#include namespace mlir { #define GEN_PASS_DEF_CSE #include "mlir/Transforms/Passes.h.inc" @@ -60,8 +61,9 @@ namespace { /// Simple common sub-expression elimination. class CSEDriver { public: - CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo) - : rewriter(rewriter), domInfo(domInfo) {} + CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo, + const CSEConfig &config) + : rewriter(rewriter), domInfo(domInfo), config(config) {} /// Simplify all operations within the given op. void simplify(Operation *op, bool *changed = nullptr); @@ -125,6 +127,9 @@ class CSEDriver { // Various statistics. int64_t numCSE = 0; int64_t numDCE = 0; + + /// CSE configuration. + CSEConfig config; }; } // namespace @@ -226,6 +231,10 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp, LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues, Operation *op, bool hasSSADominance) { + // Don't simplify operations that are filtered out. + if (config.eliminateOpFilter && !config.eliminateOpFilter(op)) + return failure(); + // Don't simplify terminator operations. 
if (op->hasTrait<OpTrait::IsTerminator>()) return failure(); @@ -288,8 +297,11 @@ void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb, if (op.getNumRegions() != 0) { // If this operation is isolated above, we can't process nested regions // with the given 'knownValues' map. This would cause the insertion of - // implicit captures in explicit capture only regions. - if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) { + // implicit captures in explicit capture only regions. Additional barrier + // ops can be specified by the user. + bool isBarrier = op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>() || + (config.barrierOpFilter && config.barrierOpFilter(&op)); + if (isBarrier) { ScopedMapTy nestedKnownValues; for (auto &region : op.getRegions()) simplifyRegion(nestedKnownValues, region); @@ -381,8 +393,8 @@ void CSEDriver::simplify(Operation *op, bool *changed) { void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, - bool *changed) { - CSEDriver driver(rewriter, &domInfo); + bool *changed, CSEConfig config) { + CSEDriver driver(rewriter, &domInfo, config); driver.simplify(op, changed); } @@ -394,9 +406,28 @@ struct CSE : public impl::CSEBase<CSE> { } // namespace void CSE::runOnOperation() { + // Set up CSE configuration from pass options. + CSEConfig config; + std::unordered_set<std::string> barrierOpNames; + for (const std::string &opName : barrierOpFilter) + barrierOpNames.insert(opName); + std::unordered_set<std::string> eliminateOpNames; + for (const std::string &opName : eliminateOpFilter) + eliminateOpNames.insert(opName); + if (!barrierOpNames.empty()) { + config.barrierOpFilter = [&](Operation *op) -> bool { + return barrierOpNames.count(op->getName().getStringRef().str()); + }; + } + if (!eliminateOpNames.empty()) { + config.eliminateOpFilter = [&](Operation *op) -> bool { + return eliminateOpNames.count(op->getName().getStringRef().str()); + }; + } + // Simplify the IR.
IRRewriter rewriter(&getContext()); - CSEDriver driver(rewriter, &getAnalysis()); + CSEDriver driver(rewriter, &getAnalysis(), config); bool changed = false; driver.simplify(getOperation(), &changed); diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir index 11a3310268473..5d2da75db6ce2 100644 --- a/mlir/test/Transforms/cse.mlir +++ b/mlir/test/Transforms/cse.mlir @@ -1,32 +1,47 @@ -// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' | FileCheck %s - -// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)> -#map0 = affine_map<(d0) -> (d0 mod 2)> +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' -split-input-file | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse="eliminate-op-filter=arith.constant"))' -split-input-file | FileCheck %s --check-prefix=CHECK-ELIMINATE-FILTER +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse="barrier-op-filter=affine.for"))' -split-input-file | FileCheck %s --check-prefix=CHECK-BARRIER-FILTER // CHECK-LABEL: @simple_constant +// CHECK-ELIMINATE-FILTER-LABEL: @simple_constant func.func @simple_constant() -> (i32, i32) { // CHECK-NEXT: %[[VAR_c1_i32:.*]] = arith.constant 1 : i32 + // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_c1_i32:.*]] = arith.constant 1 : i32 %0 = arith.constant 1 : i32 // CHECK-NEXT: return %[[VAR_c1_i32]], %[[VAR_c1_i32]] : i32, i32 + // CHECK-ELIMINATE-FILTER-NEXT: return %[[VAR_c1_i32]], %[[VAR_c1_i32]] : i32, i32 %1 = arith.constant 1 : i32 return %0, %1 : i32, i32 } +// ----- + +// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)> +// CHECK-ELIMINATE-FILTER-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)> +#map0 = affine_map<(d0) -> (d0 mod 2)> + // CHECK-LABEL: @basic +// CHECK-ELIMINATE-FILTER-LABEL: @basic func.func @basic() -> (index, index) { // CHECK: %[[VAR_c0:[0-9a-zA-Z_]+]] = 
arith.constant 0 : index + // CHECK-ELIMINATE-FILTER: %[[VAR_c0:[0-9a-zA-Z_]+]] = arith.constant 0 : index %c0 = arith.constant 0 : index %c1 = arith.constant 0 : index // CHECK-NEXT: %[[VAR_0:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]]) + // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_0:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]]) + // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]]) %0 = affine.apply #map0(%c0) %1 = affine.apply #map0(%c1) // CHECK-NEXT: return %[[VAR_0]], %[[VAR_0]] : index, index + // CHECK-ELIMINATE-FILTER-NEXT: return %[[VAR_0]], %[[VAR_1]] : index, index return %0, %1 : index, index } +// ----- + // CHECK-LABEL: @many func.func @many(f32, f32) -> (f32) { ^bb0(%a : f32, %b : f32): @@ -52,6 +67,8 @@ func.func @many(f32, f32) -> (f32) { return %l : f32 } +// ----- + /// Check that operations are not eliminated if they have different operands. // CHECK-LABEL: @different_ops func.func @different_ops() -> (i32, i32) { @@ -64,6 +81,8 @@ func.func @different_ops() -> (i32, i32) { return %0, %1 : i32, i32 } +// ----- + /// Check that operations are not eliminated if they have different result /// types. // CHECK-LABEL: @different_results @@ -77,6 +96,8 @@ func.func @different_results(%arg0: tensor<*xf32>) -> (tensor, tensor<4 return %0, %1 : tensor, tensor<4x?xf32> } +// ----- + /// Check that operations are not eliminated if they have different attributes. // CHECK-LABEL: @different_attributes func.func @different_attributes(index, index) -> (i1, i1, i1) { @@ -93,6 +114,8 @@ func.func @different_attributes(index, index) -> (i1, i1, i1) { return %0, %1, %2 : i1, i1, i1 } +// ----- + /// Check that operations with side effects are not eliminated. 
// CHECK-LABEL: @side_effect func.func @side_effect() -> (memref<2x1xf32>, memref<2x1xf32>) { @@ -106,22 +129,32 @@ func.func @side_effect() -> (memref<2x1xf32>, memref<2x1xf32>) { return %0, %1 : memref<2x1xf32>, memref<2x1xf32> } +// ----- + /// Check that operation definitions are properly propagated down the dominance /// tree. // CHECK-LABEL: @down_propagate_for +// CHECK-BARRIER-FILTER-LABEL: @down_propagate_for func.func @down_propagate_for() { // CHECK: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 + // CHECK-BARRIER-FILTER: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 %0 = arith.constant 1 : i32 // CHECK-NEXT: affine.for {{.*}} = 0 to 4 { + // CHECK-BARRIER-FILTER-NEXT: affine.for {{.*}} = 0 to 4 { affine.for %i = 0 to 4 { - // CHECK-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR_c1_i32]]) : (i32, i32) -> () + // CHECK-BARRIER-FILTER-NEXT: %[[VAR2_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 %1 = arith.constant 1 : i32 + + // CHECK-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR_c1_i32]]) : (i32, i32) -> () + // CHECK-BARRIER-FILTER-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR2_c1_i32]]) : (i32, i32) -> () "foo"(%0, %1) : (i32, i32) -> () } return } +// ----- + // CHECK-LABEL: @down_propagate func.func @down_propagate() -> i32 { // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 @@ -142,6 +175,8 @@ func.func @down_propagate() -> i32 { return %arg : i32 } +// ----- + /// Check that operation definitions are NOT propagated up the dominance tree. // CHECK-LABEL: @up_propagate_for func.func @up_propagate_for() -> i32 { @@ -159,6 +194,8 @@ func.func @up_propagate_for() -> i32 { return %1 : i32 } +// ----- + // CHECK-LABEL: func @up_propagate func.func @up_propagate() -> i32 { // CHECK-NEXT: %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32 @@ -188,6 +225,8 @@ func.func @up_propagate() -> i32 { return %add : i32 } +// ----- + /// The same test as above except that we are testing on a cfg embedded within /// an operation region. 
// CHECK-LABEL: func @up_propagate_region @@ -221,6 +260,8 @@ func.func @up_propagate_region() -> i32 { return %0 : i32 } +// ----- + /// This test checks that nested regions that are isolated from above are /// properly handled. // CHECK-LABEL: @nested_isolated @@ -248,6 +289,8 @@ func.func @nested_isolated() -> i32 { return %0 : i32 } +// ----- + /// This test is checking that CSE gracefully handles values in graph regions /// where the use occurs before the def, and one of the defs could be CSE'd with /// the other. @@ -269,6 +312,8 @@ func.func @use_before_def() { return } +// ----- + /// This test is checking that CSE is removing duplicated read op that follow /// other. // CHECK-LABEL: @remove_direct_duplicated_read_op @@ -281,6 +326,8 @@ func.func @remove_direct_duplicated_read_op() -> i32 { return %2 : i32 } +// ----- + /// This test is checking that CSE is removing duplicated read op that follow /// other. // CHECK-LABEL: @remove_multiple_duplicated_read_op @@ -300,6 +347,8 @@ func.func @remove_multiple_duplicated_read_op() -> i64 { return %6 : i64 } +// ----- + /// This test is checking that CSE is not removing duplicated read op that /// have write op in between. // CHECK-LABEL: @dont_remove_duplicated_read_op_with_sideeffecting @@ -314,6 +363,8 @@ func.func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 { return %2 : i32 } +// ----- + // Check that an operation with a single region can CSE. func.func @cse_single_block_ops(%a : tensor, %b : tensor) -> (tensor, tensor) { @@ -332,6 +383,8 @@ func.func @cse_single_block_ops(%a : tensor, %b : tensor) // CHECK-NOT: test.cse_of_single_block_op // CHECK: return %[[OP]], %[[OP]] +// ----- + // Operations with different number of bbArgs dont CSE. 
func.func @no_cse_varied_bbargs(%a : tensor, %b : tensor) -> (tensor, tensor) { @@ -350,6 +403,8 @@ func.func @no_cse_varied_bbargs(%a : tensor, %b : tensor) // CHECK: %[[OP1:.+]] = test.cse_of_single_block_op // CHECK: return %[[OP0]], %[[OP1]] +// ----- + // Operations with different regions dont CSE func.func @no_cse_region_difference_simple(%a : tensor, %b : tensor) -> (tensor, tensor) { @@ -368,6 +423,8 @@ func.func @no_cse_region_difference_simple(%a : tensor, %b : tensor, %b : tensor, %c : f32, %d : i1) -> (tensor, tensor) { @@ -392,6 +449,8 @@ func.func @cse_single_block_ops_identical_bodies(%a : tensor, %b : tens // CHECK-NOT: test.cse_of_single_block_op // CHECK: return %[[OP]], %[[OP]] +// ----- + // Operation with non-identical regions dont CSE. func.func @no_cse_single_block_ops_different_bodies(%a : tensor, %b : tensor, %c : f32, %d : i1) -> (tensor, tensor) { @@ -416,6 +475,8 @@ func.func @no_cse_single_block_ops_different_bodies(%a : tensor, %b : t // CHECK: %[[OP1:.+]] = test.cse_of_single_block_op // CHECK: return %[[OP0]], %[[OP1]] +// ----- + func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor<2xi1>) -> (tensor<2xi1>, tensor<2xi1>) { %false_2 = arith.constant false %true_5 = arith.constant true @@ -438,6 +499,8 @@ func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor // CHECK: test.region_yield %[[TRUE]] // CHECK: return %[[OP]], %[[OP]] +// ----- + func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) { %r1 = scf.if %c -> (tensor<5xf32>) { %0 = tensor.empty() : tensor<5xf32> @@ -463,6 +526,8 @@ func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, te // CHECK-NOT: scf.if // CHECK: return %[[if]], %[[if]] +// ----- + // CHECK-LABEL: @cse_recursive_effects_success func.func @cse_recursive_effects_success() -> (i32, i32, i32) { // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32 @@ -492,6 +557,8 @@ func.func 
@cse_recursive_effects_success() -> (i32, i32, i32) { return %0, %2, %1 : i32, i32, i32 } +// ----- + // CHECK-LABEL: @cse_recursive_effects_failure func.func @cse_recursive_effects_failure() -> (i32, i32, i32) { // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32