Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,19 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops">,
let dependentDialects = ["linalg::LinalgDialect"];
}

def LinalgNamedOpConversionPass: Pass<"linalg-named-op-conversion"> {
def LinalgNamedOpConversionPass: Pass<"linalg-named-op-conversion">,
Deprecated<"Use 'simplify-depthwise-conv' instead."> {
let summary = "Convert from one named linalg op to another.";
let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
}

// ------------------ End of "form" conversions

def SimplifyDepthwiseConvPass: Pass<"simplify-depthwise-conv"> {
let summary = "Simplify depthwise convolution.";
let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
}

def ConvertElementwiseToLinalgPass : Pass<"convert-elementwise-to-linalg", ""> {
let summary = "Convert ElementwiseMappable ops to linalg";
let description = [{
Expand Down
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Compiler.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"

Expand Down Expand Up @@ -1962,8 +1963,11 @@ void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
void populateFuseTensorPadWithProducerLinalgOpPatterns(
RewritePatternSet &patterns);

/// Patterns to convert from one named op to another. These can be seen as
/// canonicalizations of named ops into another named op.
/// Patterns to simplify depthwise convolutions.
void populateSimplifyDepthwiseConvPatterns(RewritePatternSet &patterns);

/// Patterns to convert from one named op to another. So far only used on
/// depthwise convolutions, so deprecated. Use the pattern avove.
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);

/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MorphOps.cpp
TransposeMatmul.cpp
ShardingInterfaceImpl.cpp
NamedOpConversions.cpp
SimplifyDepthwiseConv.cpp
NamedToElementwise.cpp
BlockPackMatmul.cpp
PackAndUnpackPatterns.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

namespace mlir {
#define GEN_PASS_DEF_LINALGNAMEDOPCONVERSIONPASS
#define GEN_PASS_DEF_SIMPLIFYDEPTHWISECONVPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

Expand Down Expand Up @@ -143,6 +144,22 @@ struct SimplifyDepthwiseConvQOp
}
};

struct SimplifyDepthwiseConvPass
: public impl::SimplifyDepthwiseConvPassBase<
SimplifyDepthwiseConvPass> {
using impl::SimplifyDepthwiseConvPassBase<
SimplifyDepthwiseConvPass>::SimplifyDepthwiseConvPassBase;

void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
populateSimplifyDepthwiseConvPatterns(patterns);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return signalPassFailure();
}
};

// Deprecated, use the one above
struct LinalgNamedOpConversionPass
: public impl::LinalgNamedOpConversionPassBase<
LinalgNamedOpConversionPass> {
Expand All @@ -159,6 +176,13 @@ struct LinalgNamedOpConversionPass
};
} // namespace

void mlir::linalg::populateSimplifyDepthwiseConvPatterns(
RewritePatternSet &patterns) {
patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
patterns.getContext());
}

// Deprecated, use the one above
void mlir::linalg::populateLinalgNamedOpConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
Expand Down
Loading