Skip to content

Commit 44b03ca

Browse files
Improve im2col conversion to handle permuted filter and output maps (#20150)
The original generic im2col implementation assumed that the filter and output affine maps would not contain permutations. While this is true for most named ops, it is not strictly required. See, for example, llvm/llvm-project#129547, where it was not the case for `conv_3d_ncdhw_fcdhw`. While in that case we just "normalized" the map upstream, this PR adds support for such maps directly. Fixes: #20139 --------- Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent b07e692 commit 44b03ca

File tree

10 files changed

+90
-28
lines changed

10 files changed

+90
-28
lines changed

compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ convertToIGEMMAndSetConfig(FunctionOpInterface funcOp,
101101
MLIRContext *context = funcOp->getContext();
102102
{
103103
RewritePatternSet patterns(context);
104-
iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns(patterns,
105-
controlFn);
104+
iree_compiler::IREE::LinalgExt::populateConvToIm2colOpPatterns(patterns,
105+
controlFn);
106106
if (configFn.has_value()) {
107107
patterns.add<SetIGEMMConfiguration>(context, configFn.value());
108108
}

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ iree_gentbl_cc_library(
3030
iree_compiler_cc_library(
3131
name = "Transforms",
3232
srcs = [
33-
"ConvertConv2DToIm2ColOp.cpp",
3433
"ConvertConv2DToWinograd.cpp",
34+
"ConvertConvToIm2ColOp.cpp",
3535
"ConvertToLoops.cpp",
3636
"DecomposeAttention.cpp",
3737
"DecomposeIm2col.cpp",

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ iree_cc_library(
2727
"Passes.h.inc"
2828
"Transforms.h"
2929
SRCS
30-
"ConvertConv2DToIm2ColOp.cpp"
3130
"ConvertConv2DToWinograd.cpp"
31+
"ConvertConvToIm2ColOp.cpp"
3232
"ConvertToLoops.cpp"
3333
"DecomposeAttention.cpp"
3434
"DecomposeIm2col.cpp"

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp renamed to compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConvToIm2ColOp.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
namespace mlir::iree_compiler::IREE::LinalgExt {
1616

17-
#define GEN_PASS_DEF_CONVERTCONV2DTOIM2COLOPPASS
17+
#define GEN_PASS_DEF_CONVERTCONVTOIM2COLOPPASS
1818
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"
1919

2020
static bool hasAllOneValues(ArrayRef<int64_t> attr) {
@@ -254,14 +254,14 @@ class ConvertConvGeneric final
254254
std::optional<ControlFnTy> controlFn;
255255
};
256256

257-
struct ConvertConv2DToIm2ColOpPass final
258-
: impl::ConvertConv2DToIm2ColOpPassBase<ConvertConv2DToIm2ColOpPass> {
257+
struct ConvertConvToIm2ColOpPass final
258+
: impl::ConvertConvToIm2ColOpPassBase<ConvertConvToIm2ColOpPass> {
259259
void getDependentDialects(DialectRegistry &registry) const override {
260260
registry.insert<tensor::TensorDialect, IREELinalgExtDialect>();
261261
}
262262
void runOnOperation() override {
263263
RewritePatternSet patterns(&getContext());
264-
populateConv2DToIm2colOpPatterns(patterns);
264+
populateConvToIm2colOpPatterns(patterns);
265265
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
266266
return signalPassFailure();
267267
}
@@ -270,8 +270,8 @@ struct ConvertConv2DToIm2ColOpPass final
270270

271271
} // namespace
272272

273-
void populateConv2DToIm2colOpPatterns(RewritePatternSet &patterns,
274-
std::optional<ControlFnTy> controlFn) {
273+
void populateConvToIm2colOpPatterns(RewritePatternSet &patterns,
274+
std::optional<ControlFnTy> controlFn) {
275275
patterns.insert<ConvertConvGeneric>(patterns.getContext(),
276276
std::move(controlFn));
277277
}

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ splitReduction(RewriterBase &rewriter, LinalgExt::TopkOp topkOp,
3232
/// op and reshapes on the inputs.
3333
/// TODO(Max191): Maybe move to transforms and use a funcOp walk instead of a
3434
/// rewrite pattern for this.
35-
void populateConv2DToIm2colOpPatterns(
35+
void populateConvToIm2colOpPatterns(
3636
RewritePatternSet &patterns,
3737
std::optional<std::function<bool(Operation *)>> controlFn = std::nullopt);
3838

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ def DecomposeWinogradTransformPass :
6666
"Decomposes winograd transform ops into linalg ops";
6767
}
6868

69-
def ConvertConv2DToIm2ColOpPass :
70-
InterfacePass<"iree-linalg-ext-convert-conv2d-to-im2col-op", "mlir::FunctionOpInterface"> {
69+
def ConvertConvToIm2ColOpPass :
70+
InterfacePass<"iree-linalg-ext-convert-conv-to-im2col-op", "mlir::FunctionOpInterface"> {
7171
let summary = "Convert linalg convolution ops to im2col gemm based implementation.";
7272
}
7373

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ iree_lit_test_suite(
1616
name = "lit",
1717
srcs = enforce_glob(
1818
[
19-
"conv2d_to_im2col.mlir",
2019
"conv2d_to_winograd.mlir",
20+
"conv_to_im2col.mlir",
2121
"convert_to_loops.mlir",
2222
"convert_to_online_attention.mlir",
2323
"decompose_im2col.mlir",

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ iree_lit_test_suite(
1414
NAME
1515
lit
1616
SRCS
17-
"conv2d_to_im2col.mlir"
1817
"conv2d_to_winograd.mlir"
18+
"conv_to_im2col.mlir"
1919
"convert_to_loops.mlir"
2020
"convert_to_online_attention.mlir"
2121
"decompose_im2col.mlir"

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir renamed to compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv_to_im2col.mlir

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-linalg-ext-convert-conv2d-to-im2col-op))" %s | FileCheck %s
1+
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-linalg-ext-convert-conv-to-im2col-op))" %s | FileCheck %s
22

33
util.func public @conv_2d_nhwc_hwcf(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
44
%0 = linalg.conv_2d_nhwc_hwcf
@@ -139,9 +139,9 @@ util.func public @conv_strided(%arg0: tensor<1x16x16x4xf16>, %arg1: tensor<3x3x4
139139
// CHECK: util.return %[[MATMUL]] : tensor<1x7x7x16xf32>
140140

141141
// -----
142-
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
143-
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d3, d6)>
144-
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
142+
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2 + d5, d3 + d6, d4)>
143+
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d1, d4)>
144+
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d1)>
145145
util.func public @conv_nhwc_hwfc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x16x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
146146
%0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x16x16x4xf32>, tensor<3x3x16x4xf32>) outs(%arg2 : tensor<1x14x14x16xf32>) {
147147
^bb0(%in: f32, %in_0: f32, %out: f32):
@@ -195,3 +195,53 @@ util.func public @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<
195195
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
196196
// CHECK-SAME: ins(%[[IM2COL]], %[[COLLAPSED]] : tensor<1x14x14x36xf32>, tensor<16x36xf32>)
197197
// CHECK: util.return %[[MATMUL]] : tensor<1x14x14x16xf32>
198+
199+
// -----
200+
201+
util.func public @conv_1d_ncw_fcw_transpose_maps(%arg0: tensor<1x8x130xf32>, %arg1: tensor<16x8x3xf32>) -> tensor<1x16x128xf32> {
202+
%cst = arith.constant 0.000000e+00 : f32
203+
%empty = tensor.empty() : tensor<1x16x128xf32>
204+
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x16x128xf32>) -> tensor<1x16x128xf32>
205+
%0 = linalg.generic {
206+
indexing_maps = [
207+
affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d1 + d3)>,
208+
affine_map<(d0, d1, d2, d3, d4) -> (d2, d4, d3)>,
209+
affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1)>],
210+
iterator_types =
211+
["parallel", "parallel", "parallel", "reduction", "reduction"]}
212+
ins(%arg0, %arg1 : tensor<1x8x130xf32>,tensor<16x8x3xf32>)
213+
outs(%fill : tensor<1x16x128xf32>) {
214+
^bb0(%in: f32, %in_0: f32, %out: f32):
215+
%8 = arith.mulf %in, %in_0 : f32
216+
%9 = arith.addf %out, %8 : f32
217+
linalg.yield %9 : f32
218+
} -> tensor<1x16x128xf32>
219+
util.return %0 : tensor<1x16x128xf32>
220+
}
221+
222+
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
223+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
224+
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
225+
// CHECK: util.func public @conv_1d_ncw_fcw_transpose_maps(
226+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x8x130xf32>
227+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<16x8x3xf32>
228+
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
229+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x16x128xf32>
230+
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY]] : tensor<1x16x128xf32>) -> tensor<1x16x128xf32>
231+
// CHECK: %[[EMPTY2:.+]] = tensor.empty() : tensor<1x128x24xf32>
232+
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
233+
// CHECK-SAME: strides = [1] dilations = [1] kernel_size = [3]
234+
// CHECK-SAME: m_offset = [0] * [1] k_offset = [0] * [1]
235+
// CHECK-SAME: batch_pos = [0] m_pos = [2] k_pos = [1]
236+
// CHECK-SAME: ins(%[[ARG0]] : tensor<1x8x130xf32>)
237+
// CHECK-SAME: outs(%[[EMPTY2]] : tensor<1x128x24xf32>) -> tensor<1x128x24xf32>
238+
// CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2]] : tensor<16x8x3xf32> into tensor<16x24xf32>
239+
// CHECK: %[[MATMUL:.+]] = linalg.generic
240+
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
241+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
242+
// CHECK-SAME: ins(%[[COLLAPSED]], %[[IM2COL]] : tensor<16x24xf32>, tensor<1x128x24xf32>)
243+
// CHECK-SAME: outs(%[[FILL]] : tensor<1x16x128xf32>) {
244+
// CHECK: arith.mulf
245+
// CHECK: arith.addf
246+
// CHECK: } -> tensor<1x16x128xf32>
247+
// CHECK: util.return %[[MATMUL]] : tensor<1x16x128xf32>

compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -445,20 +445,40 @@ getIGEMMGenericConvDetails(linalg::LinalgOp linalgOp) {
445445
auto outputShape = outputType.getShape();
446446
auto indexingMaps = linalgOp.getIndexingMapsArray();
447447
auto filterMap = indexingMaps[1];
448+
auto outputMap = indexingMaps[2];
448449

449450
SmallVector<int64_t> reductionDims;
450451
for (auto iter : llvm::enumerate(linalgOp.getIteratorTypesArray())) {
451452
if (linalg::isReductionIterator(iter.value())) {
452453
reductionDims.push_back(iter.index());
453454
}
454455
}
456+
457+
bool isOutputChannelFirst = false;
458+
auto outputChannelPos = convDims.outputChannel;
459+
auto outputImagePos = convDims.outputImage;
460+
461+
std::optional<int64_t> outputChannelLastDim = outputMap.getResultPosition(
462+
getAffineDimExpr(outputChannelPos.back(), outputMap.getContext()));
463+
std::optional<int64_t> outputImageFirstDim = outputMap.getResultPosition(
464+
getAffineDimExpr(outputImagePos[0], outputMap.getContext()));
465+
if (!outputImageFirstDim || !outputChannelLastDim) {
466+
LDBG("output image or output channel dim not found in output.");
467+
return failure();
468+
}
469+
if (outputChannelLastDim.value() < outputImageFirstDim.value())
470+
isOutputChannelFirst = true;
471+
455472
SmallVector<int64_t> filterkPos;
456473
for (auto reductionDim : reductionDims) {
457474
std::optional<int64_t> maybeDim = filterMap.getResultPosition(
458475
getAffineDimExpr(reductionDim, filterMap.getContext()));
459476
filterkPos.push_back(maybeDim.value());
460477
}
461-
// group together adjacent reduction dimensions in the filter
478+
// group together adjacent reduction dimensions in the filter.
479+
// First we want to sort the dims, as the lookup from the filterMap
480+
// can place the dims in arbitrary order.
481+
std::sort(filterkPos.begin(), filterkPos.end());
462482
SmallVector<ReassociationIndices> collapsedFilterReductionDim;
463483
int64_t prevFilterIndex = filterkPos[0];
464484
int64_t currCollapsedIndex = 0;
@@ -526,12 +546,6 @@ getIGEMMGenericConvDetails(linalg::LinalgOp linalgOp) {
526546
SmallVector<AffineExpr>(dims.begin(), dims.begin() + numParallelDims),
527547
ctx);
528548

529-
bool isOutputChannelFirst = false;
530-
auto outputChannelPos = convDims.outputChannel;
531-
auto outputImagePos = convDims.outputImage;
532-
if (outputChannelPos.back() < outputImagePos[0])
533-
isOutputChannelFirst = true;
534-
535549
// prepare the input map.
536550
SmallVector<AffineExpr> inputDims;
537551
// Add the batch dimensions.
@@ -579,7 +593,6 @@ getIGEMMGenericConvDetails(linalg::LinalgOp linalgOp) {
579593
SmallVector<int64_t> igemmLoopBounds;
580594
igemmLoopBounds.insert(igemmLoopBounds.end(), outputShape.begin(),
581595
outputShape.begin() + numParallelDims);
582-
583596
SmallVector<utils::IteratorType> igemmLoopIterators(outputShape.size(),
584597
parallel);
585598

@@ -594,7 +607,6 @@ getIGEMMGenericConvDetails(linalg::LinalgOp linalgOp) {
594607
igemmDetails.isOutputChannelFirst = isOutputChannelFirst;
595608
igemmDetails.convDims = convDims;
596609
igemmDetails.igemmLoopIterators = igemmLoopIterators;
597-
598610
return igemmDetails;
599611
}
600612

0 commit comments

Comments
 (0)