@Evanyl
Contributor

@Evanyl Evanyl commented Jun 14, 2025

We should allow users to blacklist certain ops from linalg elementwise fusion into reductions, as such fusions may be expensive on their hardware.

@llvmbot
Member

llvmbot commented Jun 14, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Evan Liu (Evanyl)

Changes

We should allow users to blacklist certain ops from linalg elementwise fusion into reductions, as such fusions may be expensive on their hardware.


Full diff: https://github.com/llvm/llvm-project/pull/144176.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+3)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+6-3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+51-11)
  • (added) mlir/test/Dialect/Linalg/fusion-elementwise-blacklist.mlir (+49)
  • (modified) mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp (+23-2)
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 373842c9b03de..5db234770ef5c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -72,6 +72,9 @@ def LinalgFoldUnitExtentDimsPass : Pass<"linalg-fold-unit-extent-dims", ""> {
 
 def LinalgElementwiseOpFusionPass : Pass<"linalg-fuse-elementwise-ops"> {
   let summary = "Fuse elementwise operations on tensors";
+  let options = [ListOption<"reductionFusionOpBlacklist",
+                            "reduction-fusion-blacklist", "std::string",
+                            "List of ops to blacklist for reduction fusion.">];
   let dependentDialects = [
     "affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
   ];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..e4968930ce554 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -450,7 +450,8 @@ using ControlSplitReductionFn =
 /// Return true if two `linalg.generic` operations with producer/consumer
 /// relationship through `fusedOperand` can be fused using elementwise op
 /// fusion.
-bool areElementwiseOpsFusable(OpOperand *fusedOperand);
+bool areElementwiseOpsFusable(OpOperand *fusedOperand,
+                              llvm::StringSet<> &blacklistedReductionFusionOps);
 
 /// Promote memref.subviews feeding linalg-on-buffers operations.
 LogicalResult promoteSubviewsPrecondition(Operation *op,
@@ -505,7 +506,8 @@ struct ElementwiseOpFusionResult {
   llvm::DenseMap<Value, Value> replacements;
 };
 FailureOr<ElementwiseOpFusionResult>
-fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
+fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand,
+                   llvm::StringSet<> &blacklistedReductionFusionOps);
 
 /// Returns a set of indices of the producer's results which would
 /// be preserved after the fusion.
@@ -1783,7 +1785,8 @@ using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;
 /// when both operations are fusable elementwise operations.
 void populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns,
-    const ControlFusionFn &controlElementwiseOpFusion);
+    const ControlFusionFn &controlElementwiseOpFusion,
+    llvm::StringSet<> *blacklistedReductionFusionOps = nullptr);
 
 /// Function type which is used to control propagation of linalg.pack/unpack
 /// ops.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f97ed3d6d5111..e10bba04951b7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -104,6 +104,20 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
              indexingMaps, producer.getContext())) != AffineMap();
 }
 
+static bool
+shouldFuseIntoReduction(linalg::GenericOp op,
+                        llvm::StringSet<> &blacklistedReductionFusionOps) {
+  for (Operation &innerOp : op.getRegion().front()) {
+    if (innerOp.hasTrait<OpTrait::IsTerminator>())
+      continue;
+
+    if (blacklistedReductionFusionOps.contains(
+            innerOp.getName().getStringRef()))
+      return false;
+  }
+  return true;
+}
+
 /// Returns a set of indices of the producer's results which would
 /// be preserved after the fusion.
 /// * There is a chance that the implementation of the transformation does not
@@ -136,7 +150,8 @@ llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
 }
 
 /// Conditions for elementwise fusion of generic operations.
-bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
+bool mlir::linalg::areElementwiseOpsFusable(
+    OpOperand *fusedOperand, llvm::StringSet<> &blacklistedReductionFusionOps) {
   if (!fusedOperand)
     return false;
 
@@ -159,6 +174,10 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
   if (producer.getNumParallelLoops() != producer.getNumLoops())
     return false;
 
+  if (consumer.getNumReductionLoops() > 0 &&
+      !shouldFuseIntoReduction(producer, blacklistedReductionFusionOps))
+    return false;
+
   // Only allow fusing the producer of an input operand for now.
   // TODO: allow fusing the producer of an output operand.
   if (!consumer.isDpsInput(fusedOperand))
@@ -335,10 +354,12 @@ static void generateFusedElementwiseOpRegion(
 }
 
 FailureOr<mlir::linalg::ElementwiseOpFusionResult>
-mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
-                                 OpOperand *fusedOperand) {
-  assert(areElementwiseOpsFusable(fusedOperand) &&
-         "expected elementwise operation pre-conditions to pass");
+mlir::linalg::fuseElementwiseOps(
+    RewriterBase &rewriter, OpOperand *fusedOperand,
+    llvm::StringSet<> &blacklistedReductionFusionOps) {
+  assert(
+      areElementwiseOpsFusable(fusedOperand, blacklistedReductionFusionOps) &&
+      "expected elementwise operation pre-conditions to pass");
   auto producerResult = cast<OpResult>(fusedOperand->get());
   auto producer = cast<GenericOp>(producerResult.getOwner());
   auto consumer = cast<GenericOp>(fusedOperand->getOwner());
@@ -462,16 +483,19 @@ namespace {
 /// Patterns to fuse a generic op, with the producer of its operands.
 class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 public:
+  llvm::StringSet<> &blacklistedReductionFusionOps;
   FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
+                     llvm::StringSet<> &blacklistedReductionFusionOps,
                      PatternBenefit benefit = 1)
       : OpRewritePattern<GenericOp>(context, benefit),
+        blacklistedReductionFusionOps(blacklistedReductionFusionOps),
         controlFn(std::move(fun)) {}
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
     // Find the first operand that is defined by another generic op on tensors.
     for (OpOperand &opOperand : genericOp->getOpOperands()) {
-      if (!areElementwiseOpsFusable(&opOperand))
+      if (!areElementwiseOpsFusable(&opOperand, blacklistedReductionFusionOps))
         continue;
       if (!controlFn(&opOperand))
         continue;
@@ -479,8 +503,8 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
       Operation *producer = opOperand.get().getDefiningOp();
 
       // Find the producer of the operand.
-      FailureOr<ElementwiseOpFusionResult> fusionResult =
-          fuseElementwiseOps(rewriter, &opOperand);
+      FailureOr<ElementwiseOpFusionResult> fusionResult = fuseElementwiseOps(
+          rewriter, &opOperand, blacklistedReductionFusionOps);
       if (failed(fusionResult))
         return rewriter.notifyMatchFailure(genericOp, "fusion failed");
 
@@ -2248,9 +2272,17 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns,
-    const ControlFusionFn &controlElementwiseOpsFusion) {
+    const ControlFusionFn &controlElementwiseOpsFusion,
+    llvm::StringSet<> *blacklistedReductionFusionOps) {
   auto *context = patterns.getContext();
-  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
+  if (blacklistedReductionFusionOps)
+    patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion,
+                                     *blacklistedReductionFusionOps);
+  else {
+    llvm::StringSet<> emptyBlacklistedReductionFusionOps;
+    patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion,
+                                     emptyBlacklistedReductionFusionOps);
+  }
   patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
                RemoveOutsDependency>(context);
   // Add the patterns that clean up dead operands and results.
@@ -2282,11 +2314,18 @@ struct LinalgElementwiseOpFusionPass
           LinalgElementwiseOpFusionPass> {
   using impl::LinalgElementwiseOpFusionPassBase<
       LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
+
+  llvm::StringSet<> blacklistedReductionFusionOps;
+
   void runOnOperation() override {
     Operation *op = getOperation();
     MLIRContext *context = op->getContext();
     RewritePatternSet patterns(context);
 
+    for (const auto &opName : reductionFusionOpBlacklist) {
+      blacklistedReductionFusionOps.insert(opName);
+    }
+
     // Add folding with reshape by expansion patterns.
     ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
       Operation *producer = fusedOperand->get().getDefiningOp();
@@ -2294,7 +2333,8 @@ struct LinalgElementwiseOpFusionPass
     };
 
     // Add elementwise op fusion patterns.
-    populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
+    populateElementwiseOpsFusionPatterns(patterns, defaultControlFn,
+                                         &blacklistedReductionFusionOps);
     populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
     tensor::populateBubbleUpExpandShapePatterns(patterns);
 
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-blacklist.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-blacklist.mlir
new file mode 100644
index 0000000000000..222c73b7695ee
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-blacklist.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=blacklist-ops-for-reduction -split-input-file | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0) -> (0, d0)>
+#map2 = affine_map<(d0) -> (0)>
+func.func @consumer_with_reduction_blacklist(%arg0: tensor<1x10xf32>,
+                              %arg1: tensor<1x10xf32>,
+                              %arg2: tensor<1xf32>) -> tensor<1xf32> {
+  %init = tensor.empty() : tensor<1x10xf32>
+  %0 = linalg.generic
+    {indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %arg1 : tensor<1x10xf32>, tensor<1x10xf32>)
+    outs(%init : tensor<1x10xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+    %2 = arith.addf %arg3, %arg4 : f32
+    linalg.yield %2 : f32
+  } -> tensor<1x10xf32>
+  %1 = linalg.generic
+    {indexing_maps = [#map1, #map2],
+     iterator_types = ["reduction"]}
+    ins(%0 : tensor<1x10xf32>)
+    outs(%arg2 : tensor<1xf32>)  {
+  ^bb0(%arg3: f32, %arg4: f32):
+    %2 = arith.addf %arg3, %arg4 : f32
+    linalg.yield %2 : f32
+  } -> tensor<1xf32>
+  return %1 : tensor<1xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (0, d0)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (0)>
+//      CHECK: func @consumer_with_reduction_blacklist(%[[ARG0:.+]]: tensor<1x10xf32>, %[[ARG1:.+]]: tensor<1x10xf32>, %[[ARG2:.+]]: tensor<1xf32>)
+//      CHECK:   %[[RES0:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME:     iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] : tensor<1x10xf32>, tensor<1x10xf32>)
+//      CHECK:   ^{{.+}}(%[[T0:.+]]: f32, %[[T1:.+]]: f32, %[[T2:.+]]: f32)
+//      CHECK:     %[[T3:.+]] = arith.addf %[[T0]], %[[T1]] : f32
+//      CHECK:     linalg.yield %[[T3]]
+//      CHECK:   %[[RES1:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:     iterator_types = ["reduction"]
+// CHECK-SAME:     ins(%[[RES0]] : tensor<1x10xf32>)
+//      CHECK:   ^{{.+}}(%[[T0:.+]]: f32, %[[T1:.+]]: f32)
+//      CHECK:     %[[T2:.+]] = arith.addf %[[T0]], %[[T1]] : f32
+//      CHECK:     linalg.yield %[[T2]]
+//      CHECK:   return %[[RES1]]
+
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index cb215197253bb..801d72c6c9eac 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -60,8 +60,9 @@ struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
   LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *fusableOperand = nullptr;
+    llvm::StringSet<> blacklist;
     for (OpOperand &operand : genericOp->getOpOperands()) {
-      if (linalg::areElementwiseOpsFusable(&operand)) {
+      if (linalg::areElementwiseOpsFusable(&operand, blacklist)) {
         fusableOperand = &operand;
         break;
       }
@@ -70,7 +71,7 @@ struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
       return rewriter.notifyMatchFailure(genericOp, "no fusable operand found");
     }
     std::optional<linalg::ElementwiseOpFusionResult> fusionResult =
-        linalg::fuseElementwiseOps(rewriter, fusableOperand);
+        linalg::fuseElementwiseOps(rewriter, fusableOperand, blacklist);
     if (!fusionResult)
       return rewriter.notifyMatchFailure(genericOp, "fusion failed");
     for (auto [origValue, replacement] : fusionResult->replacements) {
@@ -143,6 +144,12 @@ struct TestLinalgElementwiseFusion
       llvm::cl::desc("Test fusion of producer ops with multiple uses"),
       llvm::cl::init(false)};
 
+  Option<bool> blacklistOpsForReduction{
+      *this, "blacklist-ops-for-reduction",
+      llvm::cl::desc(
+          "Test fusion of generic operations with a control function."),
+      llvm::cl::init(false)};
+
   ListOption<int64_t> collapseDimensions{
       *this, "collapse-dimensions-control",
       llvm::cl::desc("Test controlling dimension collapse pattern")};
@@ -257,6 +264,20 @@ struct TestLinalgElementwiseFusion
       return;
     }
 
+    if (blacklistOpsForReduction) {
+      RewritePatternSet fusionPatterns(context);
+      auto controlFn = [](OpOperand *operand) { return true; };
+      llvm::StringSet<> blacklist;
+      blacklist.insert("arith.addf");
+
+      linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn,
+                                                   &blacklist);
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(fusionPatterns))))
+        return signalPassFailure();
+      return;
+    }
+
     if (!collapseDimensions.empty()) {
       SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
                                    collapseDimensions.end());
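
The core of the patch is the new `shouldFuseIntoReduction` check inside `areElementwiseOpsFusable`. When the consumer has at least one reduction loop, fusion is refused if any non-terminator op in the producer's body has a blacklisted name. The standalone sketch below models that predicate without depending on MLIR. `Op`, `Generic`, and `fusableIntoConsumer` are hypothetical stand-ins for `Operation`, `linalg::GenericOp`, and the relevant part of `areElementwiseOpsFusable`, and `std::set<std::string>` plays the role of `llvm::StringSet<>`.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for an mlir::Operation in a linalg.generic body.
struct Op {
  std::string name;  // e.g. "arith.addf"
  bool isTerminator; // true for "linalg.yield"
};

// Hypothetical stand-in for linalg::GenericOp: body ops plus reduction loop count.
struct Generic {
  std::vector<Op> body;
  unsigned numReductionLoops;
};

// Models shouldFuseIntoReduction: refuse if any non-terminator op is blacklisted.
bool shouldFuseIntoReduction(const Generic &producer,
                             const std::set<std::string> &blacklist) {
  for (const Op &op : producer.body) {
    if (op.isTerminator)
      continue;
    if (blacklist.count(op.name) != 0)
      return false;
  }
  return true;
}

// Models only the new check in areElementwiseOpsFusable; the other
// preconditions (all-parallel producer, DPS input operand, ...) are omitted.
bool fusableIntoConsumer(const Generic &producer, const Generic &consumer,
                         const std::set<std::string> &blacklist) {
  if (consumer.numReductionLoops > 0 &&
      !shouldFuseIntoReduction(producer, blacklist))
    return false;
  return true;
}
```

For the case in fusion-elementwise-blacklist.mlir (producer body `arith.addf`, single-loop reduction consumer, blacklist `{"arith.addf"}`), `fusableIntoConsumer` returns false, which matches the CHECK lines expecting two separate `linalg.generic` ops. One deliberate difference: the sketch never stores the set. In the patch, `FuseElementwiseOps` keeps an `llvm::StringSet<> &` member, and the `nullptr` branch of `populateElementwiseOpsFusionPatterns` binds it to a block-local empty set, so that reference would dangle once the function returns; holding the set by value would avoid this.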

@llvmbot
Member

llvmbot commented Jun 14, 2025

@llvm/pr-subscribers-mlir

@Evanyl Evanyl force-pushed the user/evanyl/06_13_blacklist_ops_for_linalg_reduction_fusion branch from 0b258e1 to afedeae June 14, 2025 01:55
@rengolin
Member

This looks like a job for a cost model.
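
A cost model, in the sense of this comment, replaces a fixed yes/no list with an estimate compared against a target-specific budget. The toy sketch below is entirely hypothetical: the per-op weights, the default cost of 1 for unknown ops, the budget, and the names `OpCostTable`, `estimateBodyCost`, and `costModelAllowsFusion` are illustrative and not taken from MLIR. It charges the producer body's ops and refuses fusion into a reduction consumer when the total exceeds the budget.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical per-op cost table for one target; weights are illustrative.
using OpCostTable = std::map<std::string, unsigned>;

// Sums the weights of the producer's body ops (terminator excluded by the
// caller); ops missing from the table are charged `defaultCost`.
unsigned estimateBodyCost(const std::vector<std::string> &producerBodyOps,
                          const OpCostTable &costs, unsigned defaultCost = 1) {
  unsigned total = 0;
  for (const std::string &name : producerBodyOps) {
    auto it = costs.find(name);
    total += (it != costs.end()) ? it->second : defaultCost;
  }
  return total;
}

// Allows any fusion into non-reduction consumers; for reduction consumers,
// allows it only if the producer body fits within the target's budget.
bool costModelAllowsFusion(const std::vector<std::string> &producerBodyOps,
                           bool consumerHasReduction, const OpCostTable &costs,
                           unsigned reductionBudget) {
  if (!consumerHasReduction)
    return true;
  return estimateBodyCost(producerBodyOps, costs) <= reductionBudget;
}
```

Unlike a name blacklist, the same code can serve different hardware by swapping the table and budget, and a blacklist falls out as the special case of an "infinite" weight.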
@rolfmorel

@@ -0,0 +1,49 @@
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=blacklist-ops-for-reduction -split-input-file | FileCheck %s
Contributor

Hmm - here the blacklist is left empty, right?

Contributor Author

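This specific test adds arith.addf to the blacklist: the blacklist-ops-for-reduction option in TestLinalgElementwiseFusion.cpp inserts it before populating the fusion patterns.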

@rolfmorel
Contributor

As @rengolin suggests, it is worth considering making this configurable as a function of the IR. In particular, I can see it being useful to be able to annotate a parent scope with a #dlti.map<"reduction-fusion-blacklist" = "linalg.elementwise<mul>"> attribute and have this fusion pass respect that attribute within that scope. This could be achieved by doing a dlti::query(op, ["reduction-fusion-blacklist"]) at any op that is being considered for fusion. (A step further would be to have a DLTI attribute that acts like an actual cost model, taking into account, e.g., tensor dim sizes and hardware info available through other DLTI attributes.)

That also highlights that I am not sure the current approach works with the abstracted linalg.elementwise<operator> op. Does it?
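
This suggestion hinges on scoping: an annotation on a parent op applies to the ops nested inside it. The sketch below models only that nearest-enclosing-scope lookup in plain C++, assuming (as is typical for scoped attributes) that a closer annotation takes precedence over an outer one. It is not the DLTI API: `Scope` is a hypothetical stand-in for an op that may carry a `#dlti.map` attribute, `lookupBlacklist` plays the role the comment assigns to `dlti::query`, and representing the value as a set of op names is an assumption.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>

// Hypothetical stand-in for an op that may carry a #dlti.map-like attribute.
struct Scope {
  const Scope *parent = nullptr;
  // Key -> op names, e.g. "reduction-fusion-blacklist" -> {"arith.addf"}.
  std::map<std::string, std::set<std::string>> entries;
};

// Walks from `scope` outward and returns the entry of the nearest scope that
// defines `key`; returns an empty set if no enclosing scope defines it.
std::set<std::string> lookupBlacklist(const Scope *scope,
                                      const std::string &key) {
  for (; scope != nullptr; scope = scope->parent) {
    auto it = scope->entries.find(key);
    if (it != scope->entries.end())
      return it->second;
  }
  return {};
}
```

A pass would perform this lookup per fusion candidate (for instance inside its control function), so different functions or modules can carry different blacklists without any pass option.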

@Evanyl
Contributor Author

Evanyl commented Jun 16, 2025

> As @rengolin suggests, it is worth considering making this configurable as a function of the IR. In particular, I can see it being useful to be able to annotate a parent scope with a #dlti.map<"reduction-fusion-blacklist" = "linalg.elementwise<mul>"> attribute and have this fusion pass respect that attribute within that scope. This could be achieved by doing a dlti::query(op, ["reduction-fusion-blacklist"]) at any op that is being considered for fusion. (A step further would be to have a DLTI attribute that acts like an actual cost model, taking into account, e.g., tensor dim sizes and hardware info available through other DLTI attributes.)
>
> That also highlights that I am not sure the current approach works with the abstracted linalg.elementwise<operator> op. Does it?

This is a good idea, thank you! I'll switch over to using the dlti.map instead. Also, this might be a noob question, but doesn't this pass only operate on linalg.generic? Where is linalg.elementwise? I don't see it in any of the tests either.

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

I don't think we need this. The fusion control function should already allow callers to blacklist operations.

@MaheshRavishankar
Contributor

> As @rengolin suggests, it is worth considering making this configurable as a function of the IR. In particular, I can see it being useful to be able to annotate a parent scope with a #dlti.map<"reduction-fusion-blacklist" = "linalg.elementwise<mul>"> attribute and have this fusion pass respect that attribute within that scope. This could be achieved by doing a dlti::query(op, ["reduction-fusion-blacklist"]) at any op that is being considered for fusion. (A step further would be to have a DLTI attribute that acts like an actual cost model, taking into account, e.g., tensor dim sizes and hardware info available through other DLTI attributes.)
>
> That also highlights that I am not sure the current approach works with the abstracted linalg.elementwise<operator> op. Does it?

Not sure about that. I am opposed to having any bespoke way of specifying a "blacklist", etc. This is precisely why the fusion function takes a callback that lets the caller disallow any fusion it wants, based on its own cost model, without that leaking into the implementation of the function. All of this could live downstream rather than as a bespoke mechanism upstream.
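
The callback route keeps the policy entirely on the caller's side: upstream, `ControlFusionFn` is `std::function<bool(OpOperand *fusedOperand)>`, and `FuseElementwiseOps` skips any operand for which it returns false. The standalone sketch below models such a policy without MLIR. `FusionCandidate` is a hypothetical stand-in for the information a real callback would read through the `OpOperand *`, and `makeReductionBlacklistControlFn` is a made-up helper name. The blacklist is captured by value, so the callback owns its data for as long as it lives.

```cpp
#include <cassert>
#include <functional>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for what a real ControlFusionFn would derive from the
// OpOperand: the producer's body op names and whether the consumer reduces.
struct FusionCandidate {
  std::vector<std::string> producerBodyOps; // terminator excluded
  bool consumerHasReduction;
};

using ControlFn = std::function<bool(const FusionCandidate &)>;

// Builds a caller-side policy: refuse fusion into reductions when the producer
// uses a blacklisted op; otherwise defer to `base` (e.g. a single-use check).
ControlFn makeReductionBlacklistControlFn(std::set<std::string> blacklist,
                                          ControlFn base) {
  return [blacklist = std::move(blacklist),
          base = std::move(base)](const FusionCandidate &c) {
    if (c.consumerHasReduction)
      for (const std::string &name : c.producerBodyOps)
        if (blacklist.count(name) != 0)
          return false;
    return base(c);
  };
}
```

In a downstream pass, the equivalent lambda would inspect the producer and consumer reached from `fusedOperand` and be passed as the control function to `populateElementwiseOpsFusionPatterns`, so no new upstream parameter is needed.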

@Evanyl
Contributor Author

Evanyl commented Jun 17, 2025

I understand now: users should create their own version of this pass and use the control function as the cost model. Thanks, everyone, for the clarification!

@Evanyl Evanyl closed this Jun 17, 2025
@Evanyl Evanyl deleted the user/evanyl/06_13_blacklist_ops_for_linalg_reduction_fusion branch June 17, 2025 00:40