55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
77#include " iree/compiler/Codegen/Common/Passes.h"
8+ #include " iree/compiler/Codegen/Common/Transforms.h"
89#include " iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
910#include " iree/compiler/Codegen/Transforms/Transforms.h"
1011#include " iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
1112#include " iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
1213#include " mlir/Dialect/Affine/IR/AffineOps.h"
1314#include " mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
1415#include " mlir/Dialect/Linalg/Transforms/Transforms.h"
16+ #include " mlir/IR/MLIRContext.h"
1517#include " mlir/Interfaces/FunctionInterfaces.h"
1618#include " mlir/Pass/Pass.h"
1719#include " mlir/Pass/PassRegistry.h"
@@ -26,10 +28,14 @@ namespace {
2628
2729using iree_compiler::IREE::LinalgExt::IREELinalgExtDialect;
2830
31+ // / Pattern to set a lowering configuration on an IGEMM convolution. Searches
32+ // / for a contraction with a linalg_ext.im2col producer, and calls the configFn
33+ // / to set the configuration.
34+ // / TODO(Max191): Use a funcOp walk instead of a pattern for this.
2935struct SetIGEMMConfiguration final : OpRewritePattern<linalg::GenericOp> {
3036 using OpRewritePattern::OpRewritePattern;
3137
32- SetIGEMMConfiguration (MLIRContext *context, ConfigFn configFn)
38+ SetIGEMMConfiguration (MLIRContext *context, IGEMMConfigFn configFn)
3339 : OpRewritePattern(context), configFn(configFn) {}
3440
3541 LogicalResult matchAndRewrite (linalg::GenericOp genericOp,
@@ -67,99 +73,95 @@ struct SetIGEMMConfiguration final : OpRewritePattern<linalg::GenericOp> {
6773 }
6874
6975private:
70- ConfigFn configFn;
76+ IGEMMConfigFn configFn;
7177};
7278
7379class ConvolutionToIGEMMPass final
7480 : public impl::ConvolutionToIGEMMPassBase<ConvolutionToIGEMMPass> {
7581public:
7682 using ConvolutionToIGEMMPassBase::ConvolutionToIGEMMPassBase;
7783
78- explicit ConvolutionToIGEMMPass (ConfigFn configFn) : configFn(configFn) {}
84+ ConvolutionToIGEMMPass (std::optional<IGEMMConfigFn> configFn,
85+ std::optional<IGEMMControlFn> controlFn)
86+ : configFn(configFn), controlFn(controlFn) {}
7987
80- void getDependentDialects (DialectRegistry ®istry) const override {
81- registry.insert <tensor::TensorDialect, IREELinalgExtDialect>();
82- }
83- void runOnOperation () override {
84- MLIRContext *context = &getContext ();
85-
86- // Rewrite convolutions into a im2col and GEMM.
87- {
88- auto conv2dToIm2colControlFn = [](Operation *conv) {
89- // Don't transform convolutions that have a preset lowering config.
90- if (getLoweringConfig (conv)) {
91- return false ;
92- }
93- return true ;
94- };
95- MLIRContext *context = &getContext ();
96- RewritePatternSet patterns (context);
97- iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns (
98- patterns, conv2dToIm2colControlFn);
99- patterns.add <SetIGEMMConfiguration>(context, configFn);
100- if (failed (applyPatternsAndFoldGreedily (getOperation (),
101- std::move (patterns)))) {
102- return signalPassFailure ();
103- }
104- }
105-
106- // The im2col transformation collapses some of the dimensions of the
107- // convolution operands. Try to push the reshape ops towards the boundaries
108- // of the function and fold with interface tensor ops.
109- //
110- // TODO(Max191): Allow for the im2col op to have multiple M dimensions, and
111- // generate a multi-M dim contraction instead of collapsing and
112- // propagating reshapes. It should ultimately become a pass option to
113- // decide whether to collapse the contraction dimensions into a single
114- // M/N/K dimension.
115- {
116- RewritePatternSet bubbleCollapseShapePatterns (context);
117- linalg::ControlFusionFn bubbleUpExpansionControlFn =
118- [](OpOperand *fusedOperand) {
119- Operation *producer = fusedOperand->get ().getDefiningOp ();
120- Operation *consumer = fusedOperand->getOwner ();
121-
122- // Block only if one of the operations has a lowering configuration
123- // which means it likely expects tiling specific to its original
124- // shape.
125- if (getLoweringConfig (producer) || getLoweringConfig (consumer)) {
126- return false ;
127- }
128- return true ;
129- };
130- linalg::populateFoldReshapeOpsByCollapsingPatterns (
131- bubbleCollapseShapePatterns, bubbleUpExpansionControlFn);
132- // Add patterns to do some additional cleanup (on top of canonicalizations
133- // that can be done later) of reshape ops.
134- tensor::populateFoldTensorEmptyPatterns (bubbleCollapseShapePatterns);
135- linalg::FillOp::getCanonicalizationPatterns (bubbleCollapseShapePatterns,
136- context);
137- tensor::CollapseShapeOp::getCanonicalizationPatterns (
138- bubbleCollapseShapePatterns, context);
139- tensor::EmptyOp::getCanonicalizationPatterns (bubbleCollapseShapePatterns,
140- context);
141- tensor::ExpandShapeOp::getCanonicalizationPatterns (
142- bubbleCollapseShapePatterns, context);
143- populateReshapeToInterfaceTensorPatterns (bubbleCollapseShapePatterns);
144- if (failed (applyPatternsAndFoldGreedily (
145- getOperation (), std::move (bubbleCollapseShapePatterns)))) {
146- return signalPassFailure ();
147- }
148- }
149- }
88+ void runOnOperation () override ;
15089
15190private:
152- ConfigFn configFn = [](linalg::GenericOp genericOp,
153- IREE::LinalgExt::Im2colOp im2colOp) {
154- return failure ();
155- };
91+ std::optional<IGEMMConfigFn> configFn;
92+ std::optional<IGEMMControlFn> controlFn;
15693};
15794
15895} // namespace
15996
160- std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
161- createConvolutionToIGEMMPass (ConfigFn configFn) {
162- return std::make_unique<ConvolutionToIGEMMPass>(configFn);
97+ LogicalResult
98+ convertToIGEMMAndSetConfig (FunctionOpInterface funcOp,
99+ std::optional<IGEMMConfigFn> configFn,
100+ std::optional<IGEMMControlFn> controlFn) {
101+ // Rewrite convolutions into a im2col and GEMM.
102+ MLIRContext *context = funcOp->getContext ();
103+ {
104+ RewritePatternSet patterns (context);
105+ iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns (patterns,
106+ controlFn);
107+ if (configFn.has_value ()) {
108+ patterns.add <SetIGEMMConfiguration>(context, configFn.value ());
109+ }
110+ if (failed (applyPatternsAndFoldGreedily (funcOp, std::move (patterns)))) {
111+ return failure ();
112+ }
113+ }
114+
115+ // The im2col transformation collapses some of the dimensions of the
116+ // convolution operands. Try to push the reshape ops towards the boundaries
117+ // of the function and fold with interface tensor ops.
118+ //
119+ // TODO(Max191): Allow for the im2col op to have multiple M dimensions, and
120+ // generate a multi-M dim contraction instead of collapsing and
121+ // propagating reshapes. It should ultimately become a pass option to
122+ // decide whether to collapse the contraction dimensions into a single
123+ // M/N/K dimension.
124+ {
125+ RewritePatternSet bubbleCollapseShapePatterns (context);
126+ linalg::ControlFusionFn bubbleUpExpansionControlFn =
127+ [](OpOperand *fusedOperand) {
128+ Operation *producer = fusedOperand->get ().getDefiningOp ();
129+ Operation *consumer = fusedOperand->getOwner ();
130+
131+ // Block only if one of the operations has a lowering configuration
132+ // which means it likely expects tiling specific to its original
133+ // shape.
134+ if (getLoweringConfig (producer) || getLoweringConfig (consumer)) {
135+ return false ;
136+ }
137+ return true ;
138+ };
139+ linalg::populateFoldReshapeOpsByCollapsingPatterns (
140+ bubbleCollapseShapePatterns, bubbleUpExpansionControlFn);
141+ // Add patterns to do some additional cleanup (on top of canonicalizations
142+ // that can be done later) of reshape ops.
143+ tensor::populateFoldTensorEmptyPatterns (bubbleCollapseShapePatterns);
144+ linalg::FillOp::getCanonicalizationPatterns (bubbleCollapseShapePatterns,
145+ context);
146+ tensor::CollapseShapeOp::getCanonicalizationPatterns (
147+ bubbleCollapseShapePatterns, context);
148+ tensor::EmptyOp::getCanonicalizationPatterns (bubbleCollapseShapePatterns,
149+ context);
150+ tensor::ExpandShapeOp::getCanonicalizationPatterns (
151+ bubbleCollapseShapePatterns, context);
152+ populateReshapeToInterfaceTensorPatterns (bubbleCollapseShapePatterns);
153+ if (failed (applyPatternsAndFoldGreedily (
154+ funcOp, std::move (bubbleCollapseShapePatterns)))) {
155+ return failure ();
156+ }
157+ }
158+ return success ();
159+ }
160+
161+ void ConvolutionToIGEMMPass::runOnOperation () {
162+ if (failed (convertToIGEMMAndSetConfig (getOperation ()))) {
163+ return signalPassFailure ();
164+ }
163165}
164166
165167} // namespace mlir::iree_compiler
0 commit comments