[mlir] Allow blacklist ops for reduction linalg elementwise fusion #144176
@llvm/pr-subscribers-mlir-linalg
Author: Evan Liu (Evanyl)
Changes
We should allow the user to blacklist certain linalg elementwise fusion patterns as they may be expensive on their hardware.
Full diff: https://github.com/llvm/llvm-project/pull/144176.diff
5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 373842c9b03de..5db234770ef5c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -72,6 +72,9 @@ def LinalgFoldUnitExtentDimsPass : Pass<"linalg-fold-unit-extent-dims", ""> {
def LinalgElementwiseOpFusionPass : Pass<"linalg-fuse-elementwise-ops"> {
let summary = "Fuse elementwise operations on tensors";
+ let options = [ListOption<"reductionFusionOpBlacklist",
+ "reduction-fusion-blacklist", "std::string",
+ "List of ops to blacklist for reduction fusion.">];
let dependentDialects = [
"affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..e4968930ce554 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -450,7 +450,8 @@ using ControlSplitReductionFn =
/// Return true if two `linalg.generic` operations with producer/consumer
/// relationship through `fusedOperand` can be fused using elementwise op
/// fusion.
-bool areElementwiseOpsFusable(OpOperand *fusedOperand);
+bool areElementwiseOpsFusable(OpOperand *fusedOperand,
+ llvm::StringSet<> &blacklistedReductionFusionOps);
/// Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult promoteSubviewsPrecondition(Operation *op,
@@ -505,7 +506,8 @@ struct ElementwiseOpFusionResult {
llvm::DenseMap<Value, Value> replacements;
};
FailureOr<ElementwiseOpFusionResult>
-fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
+fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand,
+ llvm::StringSet<> &blacklistedReductionFusionOps);
/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
@@ -1783,7 +1785,8 @@ using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;
/// when both operations are fusable elementwise operations.
void populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
- const ControlFusionFn &controlElementwiseOpFusion);
+ const ControlFusionFn &controlElementwiseOpFusion,
+ llvm::StringSet<> *blacklistedReductionFusionOps = nullptr);
/// Function type which is used to control propagation of linalg.pack/unpack
/// ops.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f97ed3d6d5111..e10bba04951b7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -104,6 +104,20 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
indexingMaps, producer.getContext())) != AffineMap();
}
+static bool
+shouldFuseIntoReduction(linalg::GenericOp op,
+ llvm::StringSet<> &blacklistedReductionFusionOps) {
+ for (Operation &innerOp : op.getRegion().front()) {
+ if (innerOp.hasTrait<OpTrait::IsTerminator>())
+ continue;
+
+ if (blacklistedReductionFusionOps.contains(
+ innerOp.getName().getStringRef()))
+ return false;
+ }
+ return true;
+}
+
/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * There is a chance that the implementation of the transformation does not
@@ -136,7 +150,8 @@ llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
}
/// Conditions for elementwise fusion of generic operations.
-bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
+bool mlir::linalg::areElementwiseOpsFusable(
+ OpOperand *fusedOperand, llvm::StringSet<> &blacklistedReductionFusionOps) {
if (!fusedOperand)
return false;
@@ -159,6 +174,10 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
if (producer.getNumParallelLoops() != producer.getNumLoops())
return false;
+ if (consumer.getNumReductionLoops() > 0 &&
+ !shouldFuseIntoReduction(producer, blacklistedReductionFusionOps))
+ return false;
+
// Only allow fusing the producer of an input operand for now.
// TODO: allow fusing the producer of an output operand.
if (!consumer.isDpsInput(fusedOperand))
@@ -335,10 +354,12 @@ static void generateFusedElementwiseOpRegion(
}
FailureOr<mlir::linalg::ElementwiseOpFusionResult>
-mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
- OpOperand *fusedOperand) {
- assert(areElementwiseOpsFusable(fusedOperand) &&
- "expected elementwise operation pre-conditions to pass");
+mlir::linalg::fuseElementwiseOps(
+ RewriterBase &rewriter, OpOperand *fusedOperand,
+ llvm::StringSet<> &blacklistedReductionFusionOps) {
+ assert(
+ areElementwiseOpsFusable(fusedOperand, blacklistedReductionFusionOps) &&
+ "expected elementwise operation pre-conditions to pass");
auto producerResult = cast<OpResult>(fusedOperand->get());
auto producer = cast<GenericOp>(producerResult.getOwner());
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
@@ -462,16 +483,19 @@ namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
+ llvm::StringSet<> &blacklistedReductionFusionOps;
FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
+ llvm::StringSet<> &blacklistedReductionFusionOps,
PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit),
+ blacklistedReductionFusionOps(blacklistedReductionFusionOps),
controlFn(std::move(fun)) {}
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
for (OpOperand &opOperand : genericOp->getOpOperands()) {
- if (!areElementwiseOpsFusable(&opOperand))
+ if (!areElementwiseOpsFusable(&opOperand, blacklistedReductionFusionOps))
continue;
if (!controlFn(&opOperand))
continue;
@@ -479,8 +503,8 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
Operation *producer = opOperand.get().getDefiningOp();
// Find the producer of the operand.
- FailureOr<ElementwiseOpFusionResult> fusionResult =
- fuseElementwiseOps(rewriter, &opOperand);
+ FailureOr<ElementwiseOpFusionResult> fusionResult = fuseElementwiseOps(
+ rewriter, &opOperand, blacklistedReductionFusionOps);
if (failed(fusionResult))
return rewriter.notifyMatchFailure(genericOp, "fusion failed");
@@ -2248,9 +2272,17 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
void mlir::linalg::populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
- const ControlFusionFn &controlElementwiseOpsFusion) {
+ const ControlFusionFn &controlElementwiseOpsFusion,
+ llvm::StringSet<> *blacklistedReductionFusionOps) {
auto *context = patterns.getContext();
- patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
+ if (blacklistedReductionFusionOps)
+ patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion,
+ *blacklistedReductionFusionOps);
+ else {
+ llvm::StringSet<> emptyBlacklistedReductionFusionOps;
+ patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion,
+ emptyBlacklistedReductionFusionOps);
+ }
patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
RemoveOutsDependency>(context);
// Add the patterns that clean up dead operands and results.
@@ -2282,11 +2314,18 @@ struct LinalgElementwiseOpFusionPass
LinalgElementwiseOpFusionPass> {
using impl::LinalgElementwiseOpFusionPassBase<
LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
+
+ llvm::StringSet<> blacklistedReductionFusionOps;
+
void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *context = op->getContext();
RewritePatternSet patterns(context);
+ for (const auto &opName : reductionFusionOpBlacklist) {
+ blacklistedReductionFusionOps.insert(opName);
+ }
+
// Add folding with reshape by expansion patterns.
ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
@@ -2294,7 +2333,8 @@ struct LinalgElementwiseOpFusionPass
};
// Add elementwise op fusion patterns.
- populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
+ populateElementwiseOpsFusionPatterns(patterns, defaultControlFn,
+ &blacklistedReductionFusionOps);
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
tensor::populateBubbleUpExpandShapePatterns(patterns);
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-blacklist.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-blacklist.mlir
new file mode 100644
index 0000000000000..222c73b7695ee
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-blacklist.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=blacklist-ops-for-reduction -split-input-file | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0) -> (0, d0)>
+#map2 = affine_map<(d0) -> (0)>
+func.func @consumer_with_reduction_blacklist(%arg0: tensor<1x10xf32>,
+ %arg1: tensor<1x10xf32>,
+ %arg2: tensor<1xf32>) -> tensor<1xf32> {
+ %init = tensor.empty() : tensor<1x10xf32>
+ %0 = linalg.generic
+ {indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<1x10xf32>, tensor<1x10xf32>)
+ outs(%init : tensor<1x10xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %2 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %2 : f32
+ } -> tensor<1x10xf32>
+ %1 = linalg.generic
+ {indexing_maps = [#map1, #map2],
+ iterator_types = ["reduction"]}
+ ins(%0 : tensor<1x10xf32>)
+ outs(%arg2 : tensor<1xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32):
+ %2 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %2 : f32
+ } -> tensor<1xf32>
+ return %1 : tensor<1xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (0)>
+// CHECK: func @consumer_with_reduction_blacklist(%[[ARG0:.+]]: tensor<1x10xf32>, %[[ARG1:.+]]: tensor<1x10xf32>, %[[ARG2:.+]]: tensor<1xf32>)
+// CHECK: %[[RES0:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x10xf32>, tensor<1x10xf32>)
+// CHECK: ^{{.+}}(%[[T0:.+]]: f32, %[[T1:.+]]: f32, %[[T2:.+]]: f32)
+// CHECK: %[[T3:.+]] = arith.addf %[[T0]], %[[T1]] : f32
+// CHECK: linalg.yield %[[T3]]
+// CHECK: %[[RES1:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["reduction"]
+// CHECK-SAME: ins(%[[RES0]] : tensor<1x10xf32>)
+// CHECK: ^{{.+}}(%[[T0:.+]]: f32, %[[T1:.+]]: f32)
+// CHECK: %[[T2:.+]] = arith.addf %[[T0]], %[[T1]] : f32
+// CHECK: linalg.yield %[[T2]]
+// CHECK: return %[[RES1]]
+
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index cb215197253bb..801d72c6c9eac 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -60,8 +60,9 @@ struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter &rewriter) const override {
OpOperand *fusableOperand = nullptr;
+ llvm::StringSet<> blacklist;
for (OpOperand &operand : genericOp->getOpOperands()) {
- if (linalg::areElementwiseOpsFusable(&operand)) {
+ if (linalg::areElementwiseOpsFusable(&operand, blacklist)) {
fusableOperand = &operand;
break;
}
@@ -70,7 +71,7 @@ struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
return rewriter.notifyMatchFailure(genericOp, "no fusable operand found");
}
std::optional<linalg::ElementwiseOpFusionResult> fusionResult =
- linalg::fuseElementwiseOps(rewriter, fusableOperand);
+ linalg::fuseElementwiseOps(rewriter, fusableOperand, blacklist);
if (!fusionResult)
return rewriter.notifyMatchFailure(genericOp, "fusion failed");
for (auto [origValue, replacement] : fusionResult->replacements) {
@@ -143,6 +144,12 @@ struct TestLinalgElementwiseFusion
llvm::cl::desc("Test fusion of producer ops with multiple uses"),
llvm::cl::init(false)};
+ Option<bool> blacklistOpsForReduction{
+ *this, "blacklist-ops-for-reduction",
+ llvm::cl::desc(
+ "Test fusion of generic operations with a control function."),
+ llvm::cl::init(false)};
+
ListOption<int64_t> collapseDimensions{
*this, "collapse-dimensions-control",
llvm::cl::desc("Test controlling dimension collapse pattern")};
@@ -257,6 +264,20 @@ struct TestLinalgElementwiseFusion
return;
}
+ if (blacklistOpsForReduction) {
+ RewritePatternSet fusionPatterns(context);
+ auto controlFn = [](OpOperand *operand) { return true; };
+ llvm::StringSet<> blacklist;
+ blacklist.insert("arith.addf");
+
+ linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn,
+ &blacklist);
+ if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+ std::move(fusionPatterns))))
+ return signalPassFailure();
+ return;
+ }
+
if (!collapseDimensions.empty()) {
SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
collapseDimensions.end());
0b258e1 to afedeae
This looks like a job for a cost model.
@@ -0,0 +1,49 @@
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=blacklist-ops-for-reduction -split-input-file | FileCheck %s
Hmm - here the blacklist is left empty, right?
This specific test adds arith.addf to the blacklist.
As @rengolin suggests, it is worth considering making this configurable as a function of the IR. In particular, I can see it being useful to be able to annotate a parent scope with a That also highlights that I am not sure the current approach works with the abstracted |
This is a good idea, thank you! I'll switch over to using the |
I don't think we need this. The fusion control function should already allow callers to blacklist operations.
Not sure about that. I am opposed to having any bespoke way of specifying a "blacklist", etc. This is precisely why there is a callback to the fusion function: it allows you to disallow any fusion you want, based on a cost model from the caller, without that leaking into the implementation of the function. All of these could live downstream and not try to have a bespoke way upstream.
I understand now: users should create their own version of this pass and use the control function as the cost model. Thanks, everyone, for the clarification!
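For readers following along, here is a minimal sketch of the callback-based approach the reviewers describe. It is not MLIR code: `GenericOpModel`, `FusionCandidate`, `ControlFn`, and `makeReductionBlacklistControlFn` are hypothetical stand-ins invented for illustration so the snippet compiles with only the standard library. In a real downstream pass the same check would presumably live in a `ControlFusionFn` lambda that inspects the producer and consumer reached through the `OpOperand *` and is then passed to `populateElementwiseOpsFusionPatterns`.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical stand-in for a linalg.generic op. It records only the two
// facts the check needs: the names of the ops in the body, and how many
// reduction loops the op has.
struct GenericOpModel {
  std::vector<std::string> bodyOpNames;
  unsigned numReductionLoops;
};

// Hypothetical stand-in for the producer/consumer pair that a fused operand
// connects.
struct FusionCandidate {
  const GenericOpModel *producer;
  const GenericOpModel *consumer;
};

// Mirrors the shape of a fusion control callback: return true to allow the
// fusion, false to reject it.
using ControlFn = std::function<bool(const FusionCandidate &)>;

// Builds a control callback that rejects fusing a producer into a reduction
// consumer when the producer body contains any blacklisted op name. The
// blacklist is captured by value, so the callback owns its own copy.
ControlFn
makeReductionBlacklistControlFn(std::unordered_set<std::string> blacklist) {
  return [blacklist = std::move(blacklist)](const FusionCandidate &candidate) {
    // Consumers without reduction loops are never restricted by this check.
    if (candidate.consumer->numReductionLoops == 0)
      return true;
    for (const std::string &name : candidate.producer->bodyOpNames)
      if (blacklist.count(name) != 0)
        return false;
    return true;
  };
}
```

Capturing the set by value is a deliberate choice in this sketch: the callback stays valid no matter how long the pattern set that stores it lives, and the fusion implementation itself needs no knowledge of the policy.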
We should allow the user to blacklist certain linalg elementwise fusion patterns as they may be expensive on their hardware.