Skip to content

Commit 55c5562

Browse files
authored
[LLVMGPU][NFC] Create LLVMGPU pass for IGEMM (#18871)
This PR refactors the ConvolutionToIGEMM pass into a shared transform function, and creates a new pass for LLVMGPU. This keeps the lowering config details in LLVMGPU separate from the common pass, and removes the need to pass a control function or config function to the pass constructor. This is also a precursor to adding some more complex logic to the control function for LLVMGPU, which will be added in a later PR. --------- Signed-off-by: Max Dawkins <[email protected]>
1 parent c6b3592 commit 55c5562

File tree

15 files changed

+230
-136
lines changed

15 files changed

+230
-136
lines changed

compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp

Lines changed: 82 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
#include "iree/compiler/Codegen/Common/Passes.h"
8+
#include "iree/compiler/Codegen/Common/Transforms.h"
89
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
910
#include "iree/compiler/Codegen/Transforms/Transforms.h"
1011
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
1112
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
1213
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1314
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
1415
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16+
#include "mlir/IR/MLIRContext.h"
1517
#include "mlir/Interfaces/FunctionInterfaces.h"
1618
#include "mlir/Pass/Pass.h"
1719
#include "mlir/Pass/PassRegistry.h"
@@ -26,10 +28,14 @@ namespace {
2628

2729
using iree_compiler::IREE::LinalgExt::IREELinalgExtDialect;
2830

31+
/// Pattern to set a lowering configuration on an IGEMM convolution. Searches
32+
/// for a contraction with a linalg_ext.im2col producer, and calls the configFn
33+
/// to set the configuration.
34+
/// TODO(Max191): Use a funcOp walk instead of a pattern for this.
2935
struct SetIGEMMConfiguration final : OpRewritePattern<linalg::GenericOp> {
3036
using OpRewritePattern::OpRewritePattern;
3137

32-
SetIGEMMConfiguration(MLIRContext *context, ConfigFn configFn)
38+
SetIGEMMConfiguration(MLIRContext *context, IGEMMConfigFn configFn)
3339
: OpRewritePattern(context), configFn(configFn) {}
3440

3541
LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
@@ -67,99 +73,95 @@ struct SetIGEMMConfiguration final : OpRewritePattern<linalg::GenericOp> {
6773
}
6874

6975
private:
70-
ConfigFn configFn;
76+
IGEMMConfigFn configFn;
7177
};
7278

7379
class ConvolutionToIGEMMPass final
7480
: public impl::ConvolutionToIGEMMPassBase<ConvolutionToIGEMMPass> {
7581
public:
7682
using ConvolutionToIGEMMPassBase::ConvolutionToIGEMMPassBase;
7783

78-
explicit ConvolutionToIGEMMPass(ConfigFn configFn) : configFn(configFn) {}
84+
ConvolutionToIGEMMPass(std::optional<IGEMMConfigFn> configFn,
85+
std::optional<IGEMMControlFn> controlFn)
86+
: configFn(configFn), controlFn(controlFn) {}
7987

80-
void getDependentDialects(DialectRegistry &registry) const override {
81-
registry.insert<tensor::TensorDialect, IREELinalgExtDialect>();
82-
}
83-
void runOnOperation() override {
84-
MLIRContext *context = &getContext();
85-
86-
// Rewrite convolutions into a im2col and GEMM.
87-
{
88-
auto conv2dToIm2colControlFn = [](Operation *conv) {
89-
// Don't transform convolutions that have a preset lowering config.
90-
if (getLoweringConfig(conv)) {
91-
return false;
92-
}
93-
return true;
94-
};
95-
MLIRContext *context = &getContext();
96-
RewritePatternSet patterns(context);
97-
iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns(
98-
patterns, conv2dToIm2colControlFn);
99-
patterns.add<SetIGEMMConfiguration>(context, configFn);
100-
if (failed(applyPatternsAndFoldGreedily(getOperation(),
101-
std::move(patterns)))) {
102-
return signalPassFailure();
103-
}
104-
}
105-
106-
// The im2col transformation collapses some of the dimensions of the
107-
// convolution operands. Try to push the reshape ops towards the boundaries
108-
// of the function and fold with interface tensor ops.
109-
//
110-
// TODO(Max191): Allow for the im2col op to have multiple M dimensions, and
111-
// generate a multi-M dim contraction instead of collapsing and
112-
// propagating reshapes. It should ultimately become a pass option to
113-
// decide whether to collapse the contraction dimensions into a single
114-
// M/N/K dimension.
115-
{
116-
RewritePatternSet bubbleCollapseShapePatterns(context);
117-
linalg::ControlFusionFn bubbleUpExpansionControlFn =
118-
[](OpOperand *fusedOperand) {
119-
Operation *producer = fusedOperand->get().getDefiningOp();
120-
Operation *consumer = fusedOperand->getOwner();
121-
122-
// Block only if one of the operations has a lowering configuration
123-
// which means it likely expects tiling specific to its original
124-
// shape.
125-
if (getLoweringConfig(producer) || getLoweringConfig(consumer)) {
126-
return false;
127-
}
128-
return true;
129-
};
130-
linalg::populateFoldReshapeOpsByCollapsingPatterns(
131-
bubbleCollapseShapePatterns, bubbleUpExpansionControlFn);
132-
// Add patterns to do some additional cleanup (on top of canonicalizations
133-
// that can be done later) of reshape ops.
134-
tensor::populateFoldTensorEmptyPatterns(bubbleCollapseShapePatterns);
135-
linalg::FillOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
136-
context);
137-
tensor::CollapseShapeOp::getCanonicalizationPatterns(
138-
bubbleCollapseShapePatterns, context);
139-
tensor::EmptyOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
140-
context);
141-
tensor::ExpandShapeOp::getCanonicalizationPatterns(
142-
bubbleCollapseShapePatterns, context);
143-
populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns);
144-
if (failed(applyPatternsAndFoldGreedily(
145-
getOperation(), std::move(bubbleCollapseShapePatterns)))) {
146-
return signalPassFailure();
147-
}
148-
}
149-
}
88+
void runOnOperation() override;
15089

15190
private:
152-
ConfigFn configFn = [](linalg::GenericOp genericOp,
153-
IREE::LinalgExt::Im2colOp im2colOp) {
154-
return failure();
155-
};
91+
std::optional<IGEMMConfigFn> configFn;
92+
std::optional<IGEMMControlFn> controlFn;
15693
};
15794

15895
} // namespace
15996

160-
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
161-
createConvolutionToIGEMMPass(ConfigFn configFn) {
162-
return std::make_unique<ConvolutionToIGEMMPass>(configFn);
97+
/// Rewrites the convolutions in `funcOp` into an im2col + GEMM form and then
/// cleans up the reshapes that the rewrite introduces.
///
/// `controlFn` is handed to the Conv2D-to-im2col pattern population unchanged.
/// When `configFn` holds a value, a SetIGEMMConfiguration pattern using it is
/// added to the same pattern set; when it is empty, no configuration pattern
/// is registered.
///
/// Returns failure() as soon as either greedy rewrite fails, and success()
/// otherwise.
LogicalResult
convertToIGEMMAndSetConfig(FunctionOpInterface funcOp,
                           std::optional<IGEMMConfigFn> configFn,
                           std::optional<IGEMMControlFn> controlFn) {
  MLIRContext *ctx = funcOp->getContext();

  // Phase 1: rewrite convolutions into an im2col and GEMM.
  RewritePatternSet igemmPatterns(ctx);
  iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns(
      igemmPatterns, controlFn);
  if (configFn) {
    igemmPatterns.add<SetIGEMMConfiguration>(ctx, *configFn);
  }
  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(igemmPatterns)))) {
    return failure();
  }

  // Phase 2: the im2col transformation collapses some of the dimensions of
  // the convolution operands. Try to push the reshape ops towards the
  // boundaries of the function and fold with interface tensor ops.
  //
  // TODO(Max191): Allow for the im2col op to have multiple M dimensions, and
  // generate a multi-M dim contraction instead of collapsing and
  // propagating reshapes. It should ultimately become a pass option to
  // decide whether to collapse the contraction dimensions into a single
  // M/N/K dimension.
  //
  // Folding is blocked only if one of the two operations carries a lowering
  // configuration, which means it likely expects tiling specific to its
  // original shape.
  linalg::ControlFusionFn allowReshapeFolding = [](OpOperand *fusedOperand) {
    Operation *producer = fusedOperand->get().getDefiningOp();
    Operation *consumer = fusedOperand->getOwner();
    return !(getLoweringConfig(producer) || getLoweringConfig(consumer));
  };

  RewritePatternSet reshapePatterns(ctx);
  linalg::populateFoldReshapeOpsByCollapsingPatterns(reshapePatterns,
                                                     allowReshapeFolding);
  // Additional cleanup of reshape ops (on top of canonicalizations that can
  // be done later).
  tensor::populateFoldTensorEmptyPatterns(reshapePatterns);
  linalg::FillOp::getCanonicalizationPatterns(reshapePatterns, ctx);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(reshapePatterns, ctx);
  tensor::EmptyOp::getCanonicalizationPatterns(reshapePatterns, ctx);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(reshapePatterns, ctx);
  populateReshapeToInterfaceTensorPatterns(reshapePatterns);
  if (failed(
          applyPatternsAndFoldGreedily(funcOp, std::move(reshapePatterns)))) {
    return failure();
  }

  return success();
}
160+
161+
void ConvolutionToIGEMMPass::runOnOperation() {
162+
if (failed(convertToIGEMMAndSetConfig(getOperation()))) {
163+
return signalPassFailure();
164+
}
163165
}
164166

165167
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/Passes.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,6 @@ std::unique_ptr<InterfacePass<FunctionOpInterface>>
6060
createConvertToDestinationPassingStylePass(
6161
bool useWARForCooperativeMatrixCodegen);
6262

63-
using ConfigFn =
64-
std::function<LogicalResult(linalg::GenericOp, IREE::LinalgExt::Im2colOp)>;
65-
/// Pass to convert Conv2D ops into IGEMM (Im2colOp + matmul). `configFn` is
66-
/// used to set lowering configurations on the resulting ops, if necessary.
67-
std::unique_ptr<InterfacePass<FunctionOpInterface>>
68-
createConvolutionToIGEMMPass(ConfigFn configFn);
69-
7063
std::unique_ptr<Pass> createDecomposeSoftmaxPass(bool useFusion);
7164

7265
/// Pass to perform linalg on tensor bufferization. The function passed into

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ def ConvolutionToIGEMMPass :
8383
InterfacePass<"iree-codegen-convolution-to-igemm", "mlir::FunctionOpInterface"> {
8484
let summary =
8585
"Transforms convolution operations into an implicit GEMM format.";
86+
let dependentDialects = [
87+
"tensor::TensorDialect",
88+
"iree_compiler::IREE::LinalgExt::IREELinalgExtDialect"
89+
];
8690
}
8791

8892
def DecomposeAffineOpsPass: Pass<"iree-codegen-decompose-affine-ops"> {

compiler/src/iree/compiler/Codegen/Common/Transforms.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@ struct OneShotBufferizationOptions;
1818

1919
namespace mlir::iree_compiler {
2020

21+
/// Callback used to set a lowering configuration on an IGEMM contraction. It
/// receives the contraction (a linalg.generic) together with its
/// linalg_ext.im2col producer and reports success or failure.
using IGEMMConfigFn =
22+
std::function<LogicalResult(linalg::GenericOp, IREE::LinalgExt::Im2colOp)>;
23+
/// Predicate forwarded to the Conv2D-to-im2col pattern population.
/// NOTE(review): presumably returning false leaves the given op unconverted;
/// confirm against populateConv2DToIm2colOpPatterns.
using IGEMMControlFn = std::function<bool(Operation *)>;
24+
25+
/// Converts conv_2d ops into linalg_ext.im2col + matmul, and sets a lowering
26+
/// configuration on the matmul.
///
/// `configFn` is optional: when it is std::nullopt, no configuration-setting
/// pattern is registered and only the conversion and reshape cleanup run.
/// `controlFn` is optional and is passed through to the im2col conversion
/// patterns as-is. Returns failure if either greedy rewrite fails.
27+
LogicalResult convertToIGEMMAndSetConfig(
28+
FunctionOpInterface funcOp,
29+
std::optional<IGEMMConfigFn> configFn = std::nullopt,
30+
std::optional<IGEMMControlFn> controlFn = std::nullopt);
31+
2132
/// Eliminates tensor.empty ops to avoid buffer allocations.
2233
LogicalResult eliminateEmptyTensors(
2334
RewriterBase &rewriter, Operation *op,

compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -69,25 +69,6 @@ module {
6969

7070
// -----
7171

72-
#map = affine_map<(d0, d1, d2, d3)->(d0, d1, d2, d3)>
73-
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 4, 32], [0, 1, 2, 4], [0, 0, 0, 0, 1, 1, 4], [0, 1, 0, 0]]>
74-
func.func public @conv_with_lowering_config(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> {
75-
%cst = arith.constant 0.0 : f32
76-
%empty = tensor.empty() : tensor<1x14x14x16xf32>
77-
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
78-
%0 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config,
79-
dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
80-
ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
81-
outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
82-
return %0 : tensor<1x14x14x16xf32>
83-
}
84-
// CHECK: func.func public @conv_with_lowering_config
85-
// CHECK-NOT: iree_linalg_ext.im2col
86-
// CHECK: linalg.conv_2d_nhwc_hwcf
87-
// CHECK-SAME: lowering_config
88-
89-
// -----
90-
9172
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
9273
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
9374
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ iree_compiler_cc_library(
9595
"LLVMGPUCastTypeToFitMMA.cpp",
9696
"LLVMGPUConfigureTensorLayouts.cpp",
9797
"LLVMGPUConfigureVectorLayouts.cpp",
98+
"LLVMGPUConvolutionToIGEMM.cpp",
9899
"LLVMGPULowerExecutableTarget.cpp",
99100
"LLVMGPUPackSharedMemoryAlloc.cpp",
100101
"LLVMGPUPrefetching.cpp",

compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ iree_cc_library(
8080
"LLVMGPUCastTypeToFitMMA.cpp"
8181
"LLVMGPUConfigureTensorLayouts.cpp"
8282
"LLVMGPUConfigureVectorLayouts.cpp"
83+
"LLVMGPUConvolutionToIGEMM.cpp"
8384
"LLVMGPULowerExecutableTarget.cpp"
8485
"LLVMGPUPackSharedMemoryAlloc.cpp"
8586
"LLVMGPUPrefetching.cpp"
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright 2024 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Common/Passes.h"
8+
#include "iree/compiler/Codegen/Common/Transforms.h"
9+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
10+
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
11+
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
12+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
13+
14+
#define DEBUG_TYPE "iree-llvmgpu-convolution-to-igemm"
15+
16+
namespace mlir::iree_compiler {
17+
18+
#define GEN_PASS_DEF_LLVMGPUCONVOLUTIONTOIGEMMPASS
19+
#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
20+
21+
namespace {
22+
23+
/// Function for setting lowering configurations on contractions resulting from
24+
/// the IGEMM transformation. This currently uses the TileAndFuse pipeline, and
25+
/// tries to target MMA intrinsics.
26+
static LogicalResult llvmgpuConfigFn(linalg::GenericOp genericOp,
27+
IREE::LinalgExt::Im2colOp im2colOp) {
28+
auto funcOp = genericOp->getParentOfType<FunctionOpInterface>();
29+
if (!funcOp) {
30+
return genericOp.emitError("cannot find parent funcOp");
31+
}
32+
IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
33+
if (!target) {
34+
return funcOp.emitError("missing GPU target in parent funcOp");
35+
}
36+
if (failed(IREE::GPU::setMatmulLoweringConfig(target, funcOp, genericOp))) {
37+
return IREE::GPU::setTileAndFuseLoweringConfig(target, funcOp, genericOp);
38+
}
39+
return success();
40+
}
41+
42+
/// Control predicate for the LLVMGPU IGEMM conversion: returns true only for
/// ops without a lowering configuration, so that anything already configured
/// is not converted.
static bool llvmgpuControlFn(Operation *op) {
  return !getLoweringConfig(op);
}
49+
50+
struct LLVMGPUConvolutionToIGEMMPass final
51+
: impl::LLVMGPUConvolutionToIGEMMPassBase<LLVMGPUConvolutionToIGEMMPass> {
52+
using impl::LLVMGPUConvolutionToIGEMMPassBase<
53+
LLVMGPUConvolutionToIGEMMPass>::LLVMGPUConvolutionToIGEMMPassBase;
54+
55+
void runOnOperation() override;
56+
};
57+
58+
void LLVMGPUConvolutionToIGEMMPass::runOnOperation() {
59+
if (failed(convertToIGEMMAndSetConfig(getOperation(), llvmgpuConfigFn,
60+
llvmgpuControlFn))) {
61+
return signalPassFailure();
62+
}
63+
}
64+
65+
} // namespace
66+
} // namespace mlir::iree_compiler

0 commit comments

Comments
 (0)