Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def LinalgElementwiseOpFusionPass : Pass<"linalg-fuse-elementwise-ops"> {
let dependentDialects = [
"affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
];
let options = [
Option<"introduceTensorEmpty", "introduce-empty", "bool",
/*default=*/"true",
"Replace out by tensor.empty">,
];
}

def LinalgNamedOpConversionPass: Pass<"linalg-named-op-conversion"> {
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1701,7 +1701,8 @@ using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;
/// when both operations are fusable elementwise operations.
void populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
const ControlFusionFn &controlElementwiseOpFusion);
const ControlFusionFn &controlElementwiseOpFusion,
bool introduceTensorEmpty = true);

/// Function type which is used to control propagation of tensor.pack/unpack
/// ops.
Expand Down
11 changes: 7 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2134,11 +2134,13 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(

void mlir::linalg::populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
const ControlFusionFn &controlElementwiseOpsFusion) {
const ControlFusionFn &controlElementwiseOpsFusion,
bool introduceTensorEmpty) {
auto *context = patterns.getContext();
patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
RemoveOutsDependency>(context);
patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant>(context);
if (introduceTensorEmpty)
patterns.add<RemoveOutsDependency>(context);
// Add the patterns that clean up dead operands and results.
populateEraseUnusedOperandsAndResultsPatterns(patterns);
}
Expand Down Expand Up @@ -2180,7 +2182,8 @@ struct LinalgElementwiseOpFusionPass
};

// Add elementwise op fusion patterns.
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn,
introduceTensorEmpty);
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
tensor::populateBubbleUpExpandShapePatterns(patterns);

Expand Down
Loading