Skip to content

Conversation

@matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Nov 10, 2024

This commit adds two new pass options that give users more fine-grained control over which ops are CSE'd / DCE'd.

  • barrier-op-filter specifies ops that should act as CSE'ing barriers. I.e., ops that are nested inside such ops should not be CSE'd with ops that are outside of such ops. (Until now, the only CSE'ing barrier used to be IsolatedFromAbove ops.)
  • eliminate-op-filter specifies ops that are subject to elimination. All non-matching ops are ignored by the CSE pass and remain in place. (If the filter is empty, all ops are subject to elimination.)

This commit adds two new pass options that gives users more fine-grained control over which ops are CSE'd / DCE'd.

* `barrier-op-filter` specifies ops that should act as CSE'ing barriers. I.e., ops that are nested inside such ops should not be CSE'd with ops that are outside of such ops. (Until now, the only CSE'ing barrier used to be IsolatedFromAbove ops.)
* `eliminate-op-filter` specifies ops that are subject to elimination. All non-matching ops are ignored by the CSE pass and remain in place. (If the filter is empty, all ops are subject to elimination.)
@llvmbot
Copy link
Member

llvmbot commented Nov 10, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit adds two new pass options that give users more fine-grained control over which ops are CSE'd / DCE'd.

  • barrier-op-filter specifies ops that should act as CSE'ing barriers. I.e., ops that are nested inside such ops should not be CSE'd with ops that are outside of such ops. (Until now, the only CSE'ing barrier used to be IsolatedFromAbove ops.)
  • eliminate-op-filter specifies ops that are subject to elimination. All non-matching ops are ignored by the CSE pass and remain in place. (If the filter is empty, all ops are subject to elimination.)

Full diff: https://github.com/llvm/llvm-project/pull/115639.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Transforms/CSE.h (+26-1)
  • (modified) mlir/include/mlir/Transforms/Passes.h (+1-1)
  • (modified) mlir/include/mlir/Transforms/Passes.td (+15-2)
  • (modified) mlir/lib/Transforms/CSE.cpp (+39-8)
  • (modified) mlir/test/Transforms/cse.mlir (+72-5)
diff --git a/mlir/include/mlir/Transforms/CSE.h b/mlir/include/mlir/Transforms/CSE.h
index 3d01ece0780509..4edca3e3369f24 100644
--- a/mlir/include/mlir/Transforms/CSE.h
+++ b/mlir/include/mlir/Transforms/CSE.h
@@ -13,19 +13,44 @@
 #ifndef MLIR_TRANSFORMS_CSE_H_
 #define MLIR_TRANSFORMS_CSE_H_
 
+#include <functional>
+
 namespace mlir {
 
 class DominanceInfo;
 class Operation;
 class RewriterBase;
 
+/// Configuration for CSE.
+struct CSEConfig {
+  /// If set, matching ops act as a CSE'ing barrier: ops are not CSE'd across
+  /// matching ops.
+  ///
+  /// Note: IsolatedFromAbove ops are always a CSE'ing barrier, regardless of
+  /// this filter.
+  ///
+  /// Example:
+  /// %0 = arith.constant 0 : index
+  /// scf.for ... {
+  ///   %1 = arith.constant 0 : index
+  ///   ...
+  /// }
+  /// If "scf.for" is marked as a CSE'ing barrier, %0 and %1 are *not* CSE'd.
+  std::function<bool(Operation *)> barrierOpFilter = nullptr;
+
+  /// If set, matching ops are not eliminated (neither CSE'd nor DCE'd). All
+  /// non-matching ops are subject to elimination.
+  std::function<bool(Operation *)> eliminateOpFilter = nullptr;
+};
+
 /// Eliminate common subexpressions within the given operation. This transform
 /// looks for and deduplicates equivalent operations.
 ///
 /// `changed` indicates whether the IR was modified or not.
 void eliminateCommonSubExpressions(RewriterBase &rewriter,
                                    DominanceInfo &domInfo, Operation *op,
-                                   bool *changed = nullptr);
+                                   bool *changed = nullptr,
+                                   CSEConfig config = CSEConfig());
 
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 5c977055e95dc8..41f208216374fe 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -33,7 +33,7 @@ class GreedyRewriteConfig;
 
 #define GEN_PASS_DECL_CANONICALIZER
 #define GEN_PASS_DECL_CONTROLFLOWSINK
-#define GEN_PASS_DECL_CSEPASS
+#define GEN_PASS_DECL_CSE
 #define GEN_PASS_DECL_INLINER
 #define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
 #define GEN_PASS_DECL_MEM2REG
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 000d9f697618e6..429029f21eb307 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -81,12 +81,25 @@ def CSE : Pass<"cse"> {
   let summary = "Eliminate common sub-expressions";
   let description = [{
     This pass implements a generalized algorithm for common sub-expression
-    elimination. This pass relies on information provided by the
-    `Memory SideEffect` interface to identify when it is safe to eliminate
+    elimination. The pass also eliminates dead operation (DCE). The pass
+    relies on information provided by the `MemoryEffectOpInterface`
+    interface and on `DominanceInfo` to identify when it is safe to eliminate
     operations. See [Common subexpression elimination](https://en.wikipedia.org/wiki/Common_subexpression_elimination)
     for more general details on this optimization.
+
+    The types of ops that are subject to elimination can be configured with
+    `eliminate-op-filter`. If set, only those ops are CSE'd or DCE'd.
+
+    Ops are never CSE'd across IsolatedFromAbove ops. Additional CSE'ing
+    barrier ops can be specified with `barrier-op-filter`.
   }];
   let constructor = "mlir::createCSEPass()";
+  let options = [
+    ListOption<"barrierOpFilter", "barrier-op-filter", "std::string",
+               "Names of ops that act as CSE'ing barriers">,
+    ListOption<"eliminateOpFilter", "eliminate-op-filter", "std::string",
+               "If non-empty, list of ops that are subject to elimination">,
+  ];
   let statistics = [
     Statistic<"numCSE", "num-cse'd", "Number of operations CSE'd">,
     Statistic<"numDCE", "num-dce'd", "Number of operations DCE'd">
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 3affd88d158de5..93ac35db276da0 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -23,8 +23,9 @@
 #include "llvm/ADT/ScopedHashTable.h"
 #include "llvm/Support/Allocator.h"
 #include "llvm/Support/RecyclingAllocator.h"
-#include <deque>
 
+#include <deque>
+#include <unordered_set>
 namespace mlir {
 #define GEN_PASS_DEF_CSE
 #include "mlir/Transforms/Passes.h.inc"
@@ -60,8 +61,9 @@ namespace {
 /// Simple common sub-expression elimination.
 class CSEDriver {
 public:
-  CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
-      : rewriter(rewriter), domInfo(domInfo) {}
+  CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo,
+            const CSEConfig &config)
+      : rewriter(rewriter), domInfo(domInfo), config(config) {}
 
   /// Simplify all operations within the given op.
   void simplify(Operation *op, bool *changed = nullptr);
@@ -125,6 +127,9 @@ class CSEDriver {
   // Various statistics.
   int64_t numCSE = 0;
   int64_t numDCE = 0;
+
+  /// CSE configuration.
+  CSEConfig config;
 };
 } // namespace
 
@@ -226,6 +231,10 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
 LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
                                            Operation *op,
                                            bool hasSSADominance) {
+  // Don't simplify operations that are filtered out.
+  if (config.eliminateOpFilter && !config.eliminateOpFilter(op))
+    return failure();
+
   // Don't simplify terminator operations.
   if (op->hasTrait<OpTrait::IsTerminator>())
     return failure();
@@ -288,8 +297,11 @@ void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
     if (op.getNumRegions() != 0) {
       // If this operation is isolated above, we can't process nested regions
       // with the given 'knownValues' map. This would cause the insertion of
-      // implicit captures in explicit capture only regions.
-      if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
+      // implicit captures in explicit capture only regions. Additional barrier
+      // ops can be specified by the user.
+      bool isBarrier = op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
+                       (config.barrierOpFilter && config.barrierOpFilter(&op));
+      if (isBarrier) {
         ScopedMapTy nestedKnownValues;
         for (auto &region : op.getRegions())
           simplifyRegion(nestedKnownValues, region);
@@ -381,8 +393,8 @@ void CSEDriver::simplify(Operation *op, bool *changed) {
 
 void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter,
                                          DominanceInfo &domInfo, Operation *op,
-                                         bool *changed) {
-  CSEDriver driver(rewriter, &domInfo);
+                                         bool *changed, CSEConfig config) {
+  CSEDriver driver(rewriter, &domInfo, config);
   driver.simplify(op, changed);
 }
 
@@ -394,9 +406,28 @@ struct CSE : public impl::CSEBase<CSE> {
 } // namespace
 
 void CSE::runOnOperation() {
+  // Set up CSE configuration from pass options.
+  CSEConfig config;
+  std::unordered_set<std::string> barrierOpNames;
+  for (std::string opName : barrierOpFilter)
+    barrierOpNames.insert(opName);
+  std::unordered_set<std::string> eliminateOpNames;
+  for (std::string opName : eliminateOpFilter)
+    eliminateOpNames.insert(opName);
+  if (!barrierOpNames.empty()) {
+    config.barrierOpFilter = [&](Operation *op) -> bool {
+      return barrierOpNames.count(op->getName().getStringRef().str());
+    };
+  }
+  if (!eliminateOpNames.empty()) {
+    config.eliminateOpFilter = [&](Operation *op) -> bool {
+      return eliminateOpNames.count(op->getName().getStringRef().str());
+    };
+  }
+
   // Simplify the IR.
   IRRewriter rewriter(&getContext());
-  CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
+  CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>(), config);
   bool changed = false;
   driver.simplify(getOperation(), &changed);
 
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index 11a33102684733..5d2da75db6ce2f 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -1,32 +1,47 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' | FileCheck %s
-
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)>
-#map0 = affine_map<(d0) -> (d0 mod 2)>
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse="eliminate-op-filter=arith.constant"))' -split-input-file | FileCheck %s --check-prefix=CHECK-ELIMINATE-FILTER
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse="barrier-op-filter=affine.for"))' -split-input-file | FileCheck %s --check-prefix=CHECK-BARRIER-FILTER
 
 // CHECK-LABEL: @simple_constant
+// CHECK-ELIMINATE-FILTER-LABEL: @simple_constant
 func.func @simple_constant() -> (i32, i32) {
   // CHECK-NEXT: %[[VAR_c1_i32:.*]] = arith.constant 1 : i32
+  // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_c1_i32:.*]] = arith.constant 1 : i32
   %0 = arith.constant 1 : i32
 
   // CHECK-NEXT: return %[[VAR_c1_i32]], %[[VAR_c1_i32]] : i32, i32
+  // CHECK-ELIMINATE-FILTER-NEXT: return %[[VAR_c1_i32]], %[[VAR_c1_i32]] : i32, i32
   %1 = arith.constant 1 : i32
   return %0, %1 : i32, i32
 }
 
+// -----
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)>
+// CHECK-ELIMINATE-FILTER-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)>
+#map0 = affine_map<(d0) -> (d0 mod 2)>
+
 // CHECK-LABEL: @basic
+// CHECK-ELIMINATE-FILTER-LABEL: @basic
 func.func @basic() -> (index, index) {
   // CHECK: %[[VAR_c0:[0-9a-zA-Z_]+]] = arith.constant 0 : index
+  // CHECK-ELIMINATE-FILTER: %[[VAR_c0:[0-9a-zA-Z_]+]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 0 : index
 
   // CHECK-NEXT: %[[VAR_0:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]])
+  // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_0:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]])
+  // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]])
   %0 = affine.apply #map0(%c0)
   %1 = affine.apply #map0(%c1)
 
   // CHECK-NEXT: return %[[VAR_0]], %[[VAR_0]] : index, index
+  // CHECK-ELIMINATE-FILTER-NEXT: return %[[VAR_0]], %[[VAR_1]] : index, index
   return %0, %1 : index, index
 }
 
+// -----
+
 // CHECK-LABEL: @many
 func.func @many(f32, f32) -> (f32) {
 ^bb0(%a : f32, %b : f32):
@@ -52,6 +67,8 @@ func.func @many(f32, f32) -> (f32) {
   return %l : f32
 }
 
+// -----
+
 /// Check that operations are not eliminated if they have different operands.
 // CHECK-LABEL: @different_ops
 func.func @different_ops() -> (i32, i32) {
@@ -64,6 +81,8 @@ func.func @different_ops() -> (i32, i32) {
   return %0, %1 : i32, i32
 }
 
+// -----
+
 /// Check that operations are not eliminated if they have different result
 /// types.
 // CHECK-LABEL: @different_results
@@ -77,6 +96,8 @@ func.func @different_results(%arg0: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<4
   return %0, %1 : tensor<?x?xf32>, tensor<4x?xf32>
 }
 
+// -----
+
 /// Check that operations are not eliminated if they have different attributes.
 // CHECK-LABEL: @different_attributes
 func.func @different_attributes(index, index) -> (i1, i1, i1) {
@@ -93,6 +114,8 @@ func.func @different_attributes(index, index) -> (i1, i1, i1) {
   return %0, %1, %2 : i1, i1, i1
 }
 
+// -----
+
 /// Check that operations with side effects are not eliminated.
 // CHECK-LABEL: @side_effect
 func.func @side_effect() -> (memref<2x1xf32>, memref<2x1xf32>) {
@@ -106,22 +129,32 @@ func.func @side_effect() -> (memref<2x1xf32>, memref<2x1xf32>) {
   return %0, %1 : memref<2x1xf32>, memref<2x1xf32>
 }
 
+// -----
+
 /// Check that operation definitions are properly propagated down the dominance
 /// tree.
 // CHECK-LABEL: @down_propagate_for
+// CHECK-BARRIER-FILTER-LABEL: @down_propagate_for
 func.func @down_propagate_for() {
   // CHECK: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
+  // CHECK-BARRIER-FILTER: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
   %0 = arith.constant 1 : i32
 
   // CHECK-NEXT: affine.for {{.*}} = 0 to 4 {
+  // CHECK-BARRIER-FILTER-NEXT: affine.for {{.*}} = 0 to 4 {
   affine.for %i = 0 to 4 {
-    // CHECK-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR_c1_i32]]) : (i32, i32) -> ()
+    // CHECK-BARRIER-FILTER-NEXT: %[[VAR2_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
     %1 = arith.constant 1 : i32
+
+    // CHECK-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR_c1_i32]]) : (i32, i32) -> ()
+    // CHECK-BARRIER-FILTER-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR2_c1_i32]]) : (i32, i32) -> ()
     "foo"(%0, %1) : (i32, i32) -> ()
   }
   return
 }
 
+// -----
+
 // CHECK-LABEL: @down_propagate
 func.func @down_propagate() -> i32 {
   // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
@@ -142,6 +175,8 @@ func.func @down_propagate() -> i32 {
   return %arg : i32
 }
 
+// -----
+
 /// Check that operation definitions are NOT propagated up the dominance tree.
 // CHECK-LABEL: @up_propagate_for
 func.func @up_propagate_for() -> i32 {
@@ -159,6 +194,8 @@ func.func @up_propagate_for() -> i32 {
   return %1 : i32
 }
 
+// -----
+
 // CHECK-LABEL: func @up_propagate
 func.func @up_propagate() -> i32 {
   // CHECK-NEXT:  %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32
@@ -188,6 +225,8 @@ func.func @up_propagate() -> i32 {
   return %add : i32
 }
 
+// -----
+
 /// The same test as above except that we are testing on a cfg embedded within
 /// an operation region.
 // CHECK-LABEL: func @up_propagate_region
@@ -221,6 +260,8 @@ func.func @up_propagate_region() -> i32 {
   return %0 : i32
 }
 
+// -----
+
 /// This test checks that nested regions that are isolated from above are
 /// properly handled.
 // CHECK-LABEL: @nested_isolated
@@ -248,6 +289,8 @@ func.func @nested_isolated() -> i32 {
   return %0 : i32
 }
 
+// -----
+
 /// This test is checking that CSE gracefully handles values in graph regions
 /// where the use occurs before the def, and one of the defs could be CSE'd with
 /// the other.
@@ -269,6 +312,8 @@ func.func @use_before_def() {
   return
 }
 
+// -----
+
 /// This test is checking that CSE is removing duplicated read op that follow
 /// other.
 // CHECK-LABEL: @remove_direct_duplicated_read_op
@@ -281,6 +326,8 @@ func.func @remove_direct_duplicated_read_op() -> i32 {
   return %2 : i32
 }
 
+// -----
+
 /// This test is checking that CSE is removing duplicated read op that follow
 /// other.
 // CHECK-LABEL: @remove_multiple_duplicated_read_op
@@ -300,6 +347,8 @@ func.func @remove_multiple_duplicated_read_op() -> i64 {
   return %6 : i64
 }
 
+// -----
+
 /// This test is checking that CSE is not removing duplicated read op that
 /// have write op in between.
 // CHECK-LABEL: @dont_remove_duplicated_read_op_with_sideeffecting
@@ -314,6 +363,8 @@ func.func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
   return %2 : i32
 }
 
+// -----
+
 // Check that an operation with a single region can CSE.
 func.func @cse_single_block_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
   -> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -332,6 +383,8 @@ func.func @cse_single_block_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
 //   CHECK-NOT:   test.cse_of_single_block_op
 //       CHECK:   return %[[OP]], %[[OP]]
 
+// -----
+
 // Operations with different number of bbArgs dont CSE.
 func.func @no_cse_varied_bbargs(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
   -> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -350,6 +403,8 @@ func.func @no_cse_varied_bbargs(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
 //       CHECK:   %[[OP1:.+]] = test.cse_of_single_block_op
 //       CHECK:   return %[[OP0]], %[[OP1]]
 
+// -----
+
 // Operations with different regions dont CSE
 func.func @no_cse_region_difference_simple(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
   -> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -368,6 +423,8 @@ func.func @no_cse_region_difference_simple(%a : tensor<?x?xf32>, %b : tensor<?x?
 //       CHECK:   %[[OP1:.+]] = test.cse_of_single_block_op
 //       CHECK:   return %[[OP0]], %[[OP1]]
 
+// -----
+
 // Operation with identical region with multiple statements CSE.
 func.func @cse_single_block_ops_identical_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
   -> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -392,6 +449,8 @@ func.func @cse_single_block_ops_identical_bodies(%a : tensor<?x?xf32>, %b : tens
 //   CHECK-NOT:   test.cse_of_single_block_op
 //       CHECK:   return %[[OP]], %[[OP]]
 
+// -----
+
 // Operation with non-identical regions dont CSE.
 func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
   -> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -416,6 +475,8 @@ func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : t
 //       CHECK:   %[[OP1:.+]] = test.cse_of_single_block_op
 //       CHECK:   return %[[OP0]], %[[OP1]]
 
+// -----
+
 func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor<2xi1>) -> (tensor<2xi1>, tensor<2xi1>) {
   %false_2 = arith.constant false
   %true_5 = arith.constant true
@@ -438,6 +499,8 @@ func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor
 //       CHECK:     test.region_yield %[[TRUE]]
 //       CHECK:   return %[[OP]], %[[OP]]
 
+// -----
+
 func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
   %r1 = scf.if %c -> (tensor<5xf32>) {
     %0 = tensor.empty() : tensor<5xf32>
@@ -463,6 +526,8 @@ func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, te
 //   CHECK-NOT:   scf.if
 //       CHECK:   return %[[if]], %[[if]]
 
+// -----
+
 // CHECK-LABEL: @cse_recursive_effects_success
 func.func @cse_recursive_effects_success() -> (i32, i32, i32) {
   // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
@@ -492,6 +557,8 @@ func.func @cse_recursive_effects_success() -> (i32, i32, i32) {
   return %0, %2, %1 : i32, i32, i32
 }
 
+// -----
+
 // CHECK-LABEL: @cse_recursive_effects_failure
 func.func @cse_recursive_effects_failure() -> (i32, i32, i32) {
   // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32

@llvmbot
Copy link
Member

llvmbot commented Nov 10, 2024

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit adds two new pass options that give users more fine-grained control over which ops are CSE'd / DCE'd.

  • barrier-op-filter specifies ops that should act as CSE'ing barriers. I.e., ops that are nested inside such ops should not be CSE'd with ops that are outside of such ops. (Until now, the only CSE'ing barrier used to be IsolatedFromAbove ops.)
  • eliminate-op-filter specifies ops that are subject to elimination. All non-matching ops are ignored by the CSE pass and remain in place. (If the filter is empty, all ops are subject to elimination.)

Full diff: https://github.com/llvm/llvm-project/pull/115639.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Transforms/CSE.h (+26-1)
  • (modified) mlir/include/mlir/Transforms/Passes.h (+1-1)
  • (modified) mlir/include/mlir/Transforms/Passes.td (+15-2)
  • (modified) mlir/lib/Transforms/CSE.cpp (+39-8)
  • (modified) mlir/test/Transforms/cse.mlir (+72-5)
diff --git a/mlir/include/mlir/Transforms/CSE.h b/mlir/include/mlir/Transforms/CSE.h
index 3d01ece0780509..4edca3e3369f24 100644
--- a/mlir/include/mlir/Transforms/CSE.h
+++ b/mlir/include/mlir/Transforms/CSE.h
@@ -13,19 +13,44 @@
 #ifndef MLIR_TRANSFORMS_CSE_H_
 #define MLIR_TRANSFORMS_CSE_H_
 
+#include <functional>
+
 namespace mlir {
 
 class DominanceInfo;
 class Operation;
 class RewriterBase;
 
+/// Configuration for CSE.
+struct CSEConfig {
+  /// If set, matching ops act as a CSE'ing barrier: ops are not CSE'd across
+  /// matching ops.
+  ///
+  /// Note: IsolatedFromAbove ops are always a CSE'ing barrier, regardless of
+  /// this filter.
+  ///
+  /// Example:
+  /// %0 = arith.constant 0 : index
+  /// scf.for ... {
+  ///   %1 = arith.constant 0 : index
+  ///   ...
+  /// }
+  /// If "scf.for" is marked as a CSE'ing barrier, %0 and %1 are *not* CSE'd.
+  std::function<bool(Operation *)> barrierOpFilter = nullptr;
+
+  /// If set, matching ops are not eliminated (neither CSE'd nor DCE'd). All
+  /// non-matching ops are subject to elimination.
+  std::function<bool(Operation *)> eliminateOpFilter = nullptr;
+};
+
 /// Eliminate common subexpressions within the given operation. This transform
 /// looks for and deduplicates equivalent operations.
 ///
 /// `changed` indicates whether the IR was modified or not.
 void eliminateCommonSubExpressions(RewriterBase &rewriter,
                                    DominanceInfo &domInfo, Operation *op,
-                                   bool *changed = nullptr);
+                                   bool *changed = nullptr,
+                                   CSEConfig config = CSEConfig());
 
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 5c977055e95dc8..41f208216374fe 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -33,7 +33,7 @@ class GreedyRewriteConfig;
 
 #define GEN_PASS_DECL_CANONICALIZER
 #define GEN_PASS_DECL_CONTROLFLOWSINK
-#define GEN_PASS_DECL_CSEPASS
+#define GEN_PASS_DECL_CSE
 #define GEN_PASS_DECL_INLINER
 #define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
 #define GEN_PASS_DECL_MEM2REG
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 000d9f697618e6..429029f21eb307 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -81,12 +81,25 @@ def CSE : Pass<"cse"> {
   let summary = "Eliminate common sub-expressions";
   let description = [{
     This pass implements a generalized algorithm for common sub-expression
-    elimination. This pass relies on information provided by the
-    `Memory SideEffect` interface to identify when it is safe to eliminate
+    elimination. The pass also eliminates dead operation (DCE). The pass
+    relies on information provided by the `MemoryEffectOpInterface`
+    interface and on `DominanceInfo` to identify when it is safe to eliminate
     operations. See [Common subexpression elimination](https://en.wikipedia.org/wiki/Common_subexpression_elimination)
     for more general details on this optimization.
+
+    The types of ops that are subject to elimination can be configured with
+    `eliminate-op-filter`. If set, only those ops are CSE'd or DCE'd.
+
+    Ops are never CSE'd across IsolatedFromAbove ops. Additional CSE'ing
+    barrier ops can be specified with `barrier-op-filter`.
   }];
   let constructor = "mlir::createCSEPass()";
+  let options = [
+    ListOption<"barrierOpFilter", "barrier-op-filter", "std::string",
+               "Names of ops that act as CSE'ing barriers">,
+    ListOption<"eliminateOpFilter", "eliminate-op-filter", "std::string",
+               "If non-empty, list of ops that are subject to elimination">,
+  ];
   let statistics = [
     Statistic<"numCSE", "num-cse'd", "Number of operations CSE'd">,
     Statistic<"numDCE", "num-dce'd", "Number of operations DCE'd">
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 3affd88d158de5..93ac35db276da0 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -23,8 +23,9 @@
 #include "llvm/ADT/ScopedHashTable.h"
 #include "llvm/Support/Allocator.h"
 #include "llvm/Support/RecyclingAllocator.h"
-#include <deque>
 
+#include <deque>
+#include <unordered_set>
 namespace mlir {
 #define GEN_PASS_DEF_CSE
 #include "mlir/Transforms/Passes.h.inc"
@@ -60,8 +61,9 @@ namespace {
 /// Simple common sub-expression elimination.
 class CSEDriver {
 public:
-  CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
-      : rewriter(rewriter), domInfo(domInfo) {}
+  CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo,
+            const CSEConfig &config)
+      : rewriter(rewriter), domInfo(domInfo), config(config) {}
 
   /// Simplify all operations within the given op.
   void simplify(Operation *op, bool *changed = nullptr);
@@ -125,6 +127,9 @@ class CSEDriver {
   // Various statistics.
   int64_t numCSE = 0;
   int64_t numDCE = 0;
+
+  /// CSE configuration.
+  CSEConfig config;
 };
 } // namespace
 
@@ -226,6 +231,10 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
 LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
                                            Operation *op,
                                            bool hasSSADominance) {
+  // Don't simplify operations that are filtered out.
+  if (config.eliminateOpFilter && !config.eliminateOpFilter(op))
+    return failure();
+
   // Don't simplify terminator operations.
   if (op->hasTrait<OpTrait::IsTerminator>())
     return failure();
@@ -288,8 +297,11 @@ void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
     if (op.getNumRegions() != 0) {
       // If this operation is isolated above, we can't process nested regions
       // with the given 'knownValues' map. This would cause the insertion of
-      // implicit captures in explicit capture only regions.
-      if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
+      // implicit captures in explicit capture only regions. Additional barrier
+      // ops can be specified by the user.
+      bool isBarrier = op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
+                       (config.barrierOpFilter && config.barrierOpFilter(&op));
+      if (isBarrier) {
         ScopedMapTy nestedKnownValues;
         for (auto &region : op.getRegions())
           simplifyRegion(nestedKnownValues, region);
@@ -381,8 +393,8 @@ void CSEDriver::simplify(Operation *op, bool *changed) {
 
 void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter,
                                          DominanceInfo &domInfo, Operation *op,
-                                         bool *changed) {
-  CSEDriver driver(rewriter, &domInfo);
+                                         bool *changed, CSEConfig config) {
+  CSEDriver driver(rewriter, &domInfo, config);
   driver.simplify(op, changed);
 }
 
@@ -394,9 +406,28 @@ struct CSE : public impl::CSEBase<CSE> {
 } // namespace
 
 void CSE::runOnOperation() {
+  // Set up CSE configuration from pass options.
+  CSEConfig config;
+  std::unordered_set<std::string> barrierOpNames;
+  for (std::string opName : barrierOpFilter)
+    barrierOpNames.insert(opName);
+  std::unordered_set<std::string> eliminateOpNames;
+  for (std::string opName : eliminateOpFilter)
+    eliminateOpNames.insert(opName);
+  if (!barrierOpNames.empty()) {
+    config.barrierOpFilter = [&](Operation *op) -> bool {
+      return barrierOpNames.count(op->getName().getStringRef().str());
+    };
+  }
+  if (!eliminateOpNames.empty()) {
+    config.eliminateOpFilter = [&](Operation *op) -> bool {
+      return eliminateOpNames.count(op->getName().getStringRef().str());
+    };
+  }
+
   // Simplify the IR.
   IRRewriter rewriter(&getContext());
-  CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
+  CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>(), config);
   bool changed = false;
   driver.simplify(getOperation(), &changed);
 
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index 11a33102684733..5d2da75db6ce2f 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -1,32 +1,47 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' | FileCheck %s
-
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)>
-#map0 = affine_map<(d0) -> (d0 mod 2)>
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse="eliminate-op-filter=arith.constant"))' -split-input-file | FileCheck %s --check-prefix=CHECK-ELIMINATE-FILTER
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse="barrier-op-filter=affine.for"))' -split-input-file | FileCheck %s --check-prefix=CHECK-BARRIER-FILTER
 
 // CHECK-LABEL: @simple_constant
+// CHECK-ELIMINATE-FILTER-LABEL: @simple_constant
 func.func @simple_constant() -> (i32, i32) {
   // CHECK-NEXT: %[[VAR_c1_i32:.*]] = arith.constant 1 : i32
+  // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_c1_i32:.*]] = arith.constant 1 : i32
   %0 = arith.constant 1 : i32
 
   // CHECK-NEXT: return %[[VAR_c1_i32]], %[[VAR_c1_i32]] : i32, i32
+  // CHECK-ELIMINATE-FILTER-NEXT: return %[[VAR_c1_i32]], %[[VAR_c1_i32]] : i32, i32
   %1 = arith.constant 1 : i32
   return %0, %1 : i32, i32
 }
 
+// -----
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)>
+// CHECK-ELIMINATE-FILTER-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)>
+#map0 = affine_map<(d0) -> (d0 mod 2)>
+
 // CHECK-LABEL: @basic
+// CHECK-ELIMINATE-FILTER-LABEL: @basic
 func.func @basic() -> (index, index) {
   // CHECK: %[[VAR_c0:[0-9a-zA-Z_]+]] = arith.constant 0 : index
+  // CHECK-ELIMINATE-FILTER: %[[VAR_c0:[0-9a-zA-Z_]+]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 0 : index
 
   // CHECK-NEXT: %[[VAR_0:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]])
+  // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_0:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]])
+  // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]])
   %0 = affine.apply #map0(%c0)
   %1 = affine.apply #map0(%c1)
 
   // CHECK-NEXT: return %[[VAR_0]], %[[VAR_0]] : index, index
+  // CHECK-ELIMINATE-FILTER-NEXT: return %[[VAR_0]], %[[VAR_1]] : index, index
   return %0, %1 : index, index
 }
 
+// -----
+
 // CHECK-LABEL: @many
 func.func @many(f32, f32) -> (f32) {
 ^bb0(%a : f32, %b : f32):
@@ -52,6 +67,8 @@ func.func @many(f32, f32) -> (f32) {
   return %l : f32
 }
 
+// -----
+
 /// Check that operations are not eliminated if they have different operands.
 // CHECK-LABEL: @different_ops
 func.func @different_ops() -> (i32, i32) {
@@ -64,6 +81,8 @@ func.func @different_ops() -> (i32, i32) {
   return %0, %1 : i32, i32
 }
 
+// -----
+
 /// Check that operations are not eliminated if they have different result
 /// types.
 // CHECK-LABEL: @different_results
@@ -77,6 +96,8 @@ func.func @different_results(%arg0: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<4
   return %0, %1 : tensor<?x?xf32>, tensor<4x?xf32>
 }
 
+// -----
+
 /// Check that operations are not eliminated if they have different attributes.
 // CHECK-LABEL: @different_attributes
 func.func @different_attributes(index, index) -> (i1, i1, i1) {
@@ -93,6 +114,8 @@ func.func @different_attributes(index, index) -> (i1, i1, i1) {
   return %0, %1, %2 : i1, i1, i1
 }
 
+// -----
+
 /// Check that operations with side effects are not eliminated.
 // CHECK-LABEL: @side_effect
 func.func @side_effect() -> (memref<2x1xf32>, memref<2x1xf32>) {
@@ -106,22 +129,32 @@ func.func @side_effect() -> (memref<2x1xf32>, memref<2x1xf32>) {
   return %0, %1 : memref<2x1xf32>, memref<2x1xf32>
 }
 
+// -----
+
 /// Check that operation definitions are properly propagated down the dominance
 /// tree.
 // CHECK-LABEL: @down_propagate_for
+// CHECK-BARRIER-FILTER-LABEL: @down_propagate_for
 func.func @down_propagate_for() {
   // CHECK: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
+  // CHECK-BARRIER-FILTER: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
   %0 = arith.constant 1 : i32
 
   // CHECK-NEXT: affine.for {{.*}} = 0 to 4 {
+  // CHECK-BARRIER-FILTER-NEXT: affine.for {{.*}} = 0 to 4 {
   affine.for %i = 0 to 4 {
-    // CHECK-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR_c1_i32]]) : (i32, i32) -> ()
+    // CHECK-BARRIER-FILTER-NEXT: %[[VAR2_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
     %1 = arith.constant 1 : i32
+
+    // CHECK-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR_c1_i32]]) : (i32, i32) -> ()
+    // CHECK-BARRIER-FILTER-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR2_c1_i32]]) : (i32, i32) -> ()
     "foo"(%0, %1) : (i32, i32) -> ()
   }
   return
 }
 
+// -----
+
 // CHECK-LABEL: @down_propagate
 func.func @down_propagate() -> i32 {
   // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
@@ -142,6 +175,8 @@ func.func @down_propagate() -> i32 {
   return %arg : i32
 }
 
+// -----
+
 /// Check that operation definitions are NOT propagated up the dominance tree.
 // CHECK-LABEL: @up_propagate_for
 func.func @up_propagate_for() -> i32 {
@@ -159,6 +194,8 @@ func.func @up_propagate_for() -> i32 {
   return %1 : i32
 }
 
+// -----
+
 // CHECK-LABEL: func @up_propagate
 func.func @up_propagate() -> i32 {
   // CHECK-NEXT:  %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32
@@ -188,6 +225,8 @@ func.func @up_propagate() -> i32 {
   return %add : i32
 }
 
+// -----
+
 /// The same test as above except that we are testing on a cfg embedded within
 /// an operation region.
 // CHECK-LABEL: func @up_propagate_region
@@ -221,6 +260,8 @@ func.func @up_propagate_region() -> i32 {
   return %0 : i32
 }
 
+// -----
+
 /// This test checks that nested regions that are isolated from above are
 /// properly handled.
 // CHECK-LABEL: @nested_isolated
@@ -248,6 +289,8 @@ func.func @nested_isolated() -> i32 {
   return %0 : i32
 }
 
+// -----
+
 /// This test is checking that CSE gracefully handles values in graph regions
 /// where the use occurs before the def, and one of the defs could be CSE'd with
 /// the other.
@@ -269,6 +312,8 @@ func.func @use_before_def() {
   return
 }
 
+// -----
+
 /// This test is checking that CSE is removing duplicated read op that follow
 /// other.
 // CHECK-LABEL: @remove_direct_duplicated_read_op
@@ -281,6 +326,8 @@ func.func @remove_direct_duplicated_read_op() -> i32 {
   return %2 : i32
 }
 
+// -----
+
 /// This test is checking that CSE is removing duplicated read op that follow
 /// other.
 // CHECK-LABEL: @remove_multiple_duplicated_read_op
@@ -300,6 +347,8 @@ func.func @remove_multiple_duplicated_read_op() -> i64 {
   return %6 : i64
 }
 
+// -----
+
 /// This test is checking that CSE is not removing duplicated read op that
 /// have write op in between.
 // CHECK-LABEL: @dont_remove_duplicated_read_op_with_sideeffecting
@@ -314,6 +363,8 @@ func.func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
   return %2 : i32
 }
 
+// -----
+
 // Check that an operation with a single region can CSE.
 func.func @cse_single_block_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
   -> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -332,6 +383,8 @@ func.func @cse_single_block_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
 //   CHECK-NOT:   test.cse_of_single_block_op
 //       CHECK:   return %[[OP]], %[[OP]]
 
+// -----
+
 // Operations with different number of bbArgs dont CSE.
 func.func @no_cse_varied_bbargs(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
   -> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -350,6 +403,8 @@ func.func @no_cse_varied_bbargs(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
 //       CHECK:   %[[OP1:.+]] = test.cse_of_single_block_op
 //       CHECK:   return %[[OP0]], %[[OP1]]
 
+// -----
+
 // Operations with different regions dont CSE
 func.func @no_cse_region_difference_simple(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
   -> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -368,6 +423,8 @@ func.func @no_cse_region_difference_simple(%a : tensor<?x?xf32>, %b : tensor<?x?
 //       CHECK:   %[[OP1:.+]] = test.cse_of_single_block_op
 //       CHECK:   return %[[OP0]], %[[OP1]]
 
+// -----
+
 // Operation with identical region with multiple statements CSE.
 func.func @cse_single_block_ops_identical_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
   -> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -392,6 +449,8 @@ func.func @cse_single_block_ops_identical_bodies(%a : tensor<?x?xf32>, %b : tens
 //   CHECK-NOT:   test.cse_of_single_block_op
 //       CHECK:   return %[[OP]], %[[OP]]
 
+// -----
+
 // Operation with non-identical regions dont CSE.
 func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
   -> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -416,6 +475,8 @@ func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : t
 //       CHECK:   %[[OP1:.+]] = test.cse_of_single_block_op
 //       CHECK:   return %[[OP0]], %[[OP1]]
 
+// -----
+
 func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor<2xi1>) -> (tensor<2xi1>, tensor<2xi1>) {
   %false_2 = arith.constant false
   %true_5 = arith.constant true
@@ -438,6 +499,8 @@ func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor
 //       CHECK:     test.region_yield %[[TRUE]]
 //       CHECK:   return %[[OP]], %[[OP]]
 
+// -----
+
 func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
   %r1 = scf.if %c -> (tensor<5xf32>) {
     %0 = tensor.empty() : tensor<5xf32>
@@ -463,6 +526,8 @@ func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, te
 //   CHECK-NOT:   scf.if
 //       CHECK:   return %[[if]], %[[if]]
 
+// -----
+
 // CHECK-LABEL: @cse_recursive_effects_success
 func.func @cse_recursive_effects_success() -> (i32, i32, i32) {
   // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
@@ -492,6 +557,8 @@ func.func @cse_recursive_effects_success() -> (i32, i32, i32) {
   return %0, %2, %1 : i32, i32, i32
 }
 
+// -----
+
 // CHECK-LABEL: @cse_recursive_effects_failure
 func.func @cse_recursive_effects_failure() -> (i32, i32, i32) {
   // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32

@River707
Copy link
Contributor

River707 commented Nov 10, 2024

I don't think this is quite the right approach. The information needed here is near exactly the same situation as the materialization region logic in the folder interface. Adding options to the pass feels like trying to side load semantic information that should ideally be provided by the dialect (this is why we have interfaces). I would hope for this that we'd provide interface logic that dialects can hook into for this, im not sure if we could actually just reuse the materialization region hook (if that actually covers all of the same places), or if we need a cse interface. I haven't thought about it in a while, but this problem has come up before.

@joker-eph
Copy link
Collaborator

I agree with River: that seems a bit ad-hoc to me right now.

@matthias-springer
Copy link
Member Author

matthias-springer commented Nov 11, 2024

Some context: In my use case, certain ops with regions should be a hoisting barrier for certain nested ops. CSE is one example of hoisting (by de-deduplication). Another one is the GreedyPatternRewriteDriver, which hoists constants.

I ran into this issue in the past when I was working on a transform dialect-based experiment in IREE: Ops without tensor results were allowed to be hoisted from a "dispatch region" op, but not ops with tensor results.

I just took a look at the DialectFoldInterface interface. It seems related but does not quite fit. The place where a constant is materialized is defined by the folded op (or its dialect to be precise). In my use case, it looks like I'd need the property (of being a hoisting barrier) on the region op, not the op that is going to be hoisted.

What could work is a HoistingBarrierOpInterface with an interface method bool isBarrierFor(Operation *op) with a default implementation of return false. In my use case, the region op would implement the interface and return "true" if op has a tensor result.

CSE and the greedy pattern driver could query that interface. Any thoughts?

@River707
Copy link
Contributor

Some context: In my use case, certain ops with regions should be a hoisting barrier for certain nested ops. CSE is one example of hoisting (by de-deduplication). Another one is the GreedyPatternRewriteDriver, which hoists constants.

I ran into this issue in the past when I was working on a transform dialect-based experiment in IREE: Ops without tensor results were allowed to be hoisted from a "dispatch region" op, but not ops with tensor results.

I just took a look at the DialectFoldInterface interface. It seems related but does not quite fit. The place where a constant is materialized is defined by the folded op (or its dialect to be precise). In my use case, it looks like I'd need the property (of being a hoisting barrier) on the region op, not the op that is going to be hoisted.

I think there is confusion here, that is not what the materialization hook does. That hook is used exactly as you describe your need: it's defined by dialects with region operations that want to define hoist barriers, or more directly region operations that constants should or shouldn't be materialized above. If you look at the logic where this is used, the insertion region is determined by walking upwards from the insertion block to find a suitable region location (using the parent operations as the hooks for the interfaces):

if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))

What could work is a HoistingBarrierOpInterface with an interface method bool isBarrierFor(Operation *op) with a default implementation of return false. In my use case, the region op would implement the interface and return "true" if op has a tensor result.

CSE and the greedy pattern driver could query that interface. Any thoughts?

Yes, I think something like this could be a good idea. I would hope though, that this would replace the shouldMaterializeInto hook (I don't know of any cases that don't map to hoisting more generally). I could be wrong though, but I haven't seen such a case in practice.

nit on the default: I would expect a barrier op interface to default to being a barrier for everything (not the other way around).

@mfrancio
Copy link
Contributor

Is a CSE Interface out of the question here? It seems some work was done in the past towards that, and we would be very interested in seeing it reaching maturity:
#73455
#73520

@joker-eph
Copy link
Collaborator

Some context: In my use case, certain ops with regions should be a hoisting barrier for certain nested ops. CSE is one example of hoisting (by de-deduplication). Another one is the GreedyPatternRewriteDriver, which hoists constants.

In general, every request I saw about this were instead better addressed with a reverse transformation: this is something that needs profitability and would be suitable even if the input program would have the constant (or the expressions) already outside the target region.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants