5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -75,6 +75,11 @@ def LinalgElementwiseOpFusionPass : Pass<"linalg-fuse-elementwise-ops"> {
let dependentDialects = [
"affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
];
let options = [
Option<"removeOutsDependency", "remove-outs-dependency", "bool",
/*default=*/"true",
"Replace out by tensor.empty">,
];
}

def LinalgNamedOpConversionPass: Pass<"linalg-named-op-conversion"> {
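The new remove-outs-dependency option controls whether the RemoveOutsDependency pattern is included in the fusion pattern set. As a rough illustration (hand-written IR, not taken from this patch): when the generic's body never reads the out block argument, that pattern swaps the outs operand for a fresh tensor.empty, dropping the dependency on the original init tensor.

#map = affine_map<(d0) -> (d0)>

// Before: the result is tied to an existing tensor %init.
%0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
    ins(%in : tensor<8xf32>) outs(%init : tensor<8xf32>) {
^bb0(%a: f32, %b: f32):
  %r = arith.addf %a, %a : f32
  linalg.yield %r : f32
} -> tensor<8xf32>

// After RemoveOutsDependency: the init operand is replaced by tensor.empty.
%empty = tensor.empty() : tensor<8xf32>
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
    ins(%in : tensor<8xf32>) outs(%empty : tensor<8xf32>) {
^bb0(%a: f32, %b: f32):
  %r = arith.addf %a, %a : f32
  linalg.yield %r : f32
} -> tensor<8xf32>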
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1701,7 +1701,8 @@ using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;
/// when both operations are fusable elementwise operations.
void populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
const ControlFusionFn &controlElementwiseOpFusion);
const ControlFusionFn &controlElementwiseOpFusion,
bool replaceOutsDependency = true);

/// Function type which is used to control propagation of tensor.pack/unpack
/// ops.
11 changes: 7 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -2134,11 +2134,13 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(

void mlir::linalg::populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
const ControlFusionFn &controlElementwiseOpsFusion) {
const ControlFusionFn &controlElementwiseOpsFusion,
bool removeOutsDependency) {
auto *context = patterns.getContext();
patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
RemoveOutsDependency>(context);
patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant>(context);
if (removeOutsDependency)
patterns.add<RemoveOutsDependency>(context);
// Add the patterns that clean up dead operands and results.
populateEraseUnusedOperandsAndResultsPatterns(patterns);
}
@@ -2180,7 +2182,8 @@ struct LinalgElementwiseOpFusionPass
};

// Add elementwise op fusion patterns.
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn,
removeOutsDependency);
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
tensor::populateBubbleUpExpandShapePatterns(patterns);

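For downstream users of the C++ API, the new trailing parameter makes it possible to populate the elementwise fusion patterns while keeping existing outs operands intact. A minimal sketch, assuming a hypothetical helper in a downstream pass (the control function here simply accepts every fusion candidate):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Hypothetical helper: gather the elementwise fusion patterns but skip
// RemoveOutsDependency so that existing outs operands are preserved.
static void populateFusionKeepingOuts(RewritePatternSet &patterns) {
  linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
    return true; // fuse every candidate; a real pipeline can filter here
  };
  linalg::populateElementwiseOpsFusionPatterns(patterns, controlFn,
                                               /*removeOutsDependency=*/false);
}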
@@ -0,0 +1,18 @@
// RUN: mlir-opt %s -p 'builtin.module(func.func(linalg-fuse-elementwise-ops{remove-outs-dependency=0}))' -split-input-file | FileCheck %s

#identity = affine_map<(d0) -> (d0)>

func.func @redundant_copy_with_target_burst_size_two(%arg: tensor<4xf32>) -> tensor<4xf32> attributes {plhw.toplevel} {
// CHECK-NOT: tensor.empty
%1 = linalg.generic {indexing_maps = [#identity, #identity], iterator_types = ["parallel"] } ins(%arg: tensor<4xf32>) outs(%arg: tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%exp = arith.negf %in: f32
linalg.yield %exp : f32
} -> tensor<4xf32>
%2 = linalg.generic {indexing_maps = [#identity, #identity], iterator_types = ["parallel"] } ins(%1: tensor<4xf32>) outs(%arg: tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%exp = arith.mulf %in, %in: f32
linalg.yield %exp : f32
} -> tensor<4xf32>
return %2 : tensor<4xf32>
}
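With the option disabled, the fused generic is expected to keep %arg as its outs operand rather than materializing a tensor.empty, which is what the CHECK-NOT above asserts. Roughly (a sketch of the anticipated output, not the test's actual CHECK lines), the two ops collapse into something like:

%fused = linalg.generic {indexing_maps = [#identity, #identity], iterator_types = ["parallel"]}
    ins(%arg : tensor<4xf32>) outs(%arg : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
  %neg = arith.negf %in : f32
  %sq = arith.mulf %neg, %neg : f32
  linalg.yield %sq : f32
} -> tensor<4xf32>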