
Commit b86ed92

[Im2Col] Support converting group convs to im2col (#20611)
This adds support for converting group convs to im2col, allowing them to go down the IGEMM path. Group dimensions are parallel iterator dims that index into the image, filter, and output; for im2col they are treated as batch dimensions. This also fixes #20498.
1 parent 73c7462 commit b86ed92
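The core idea, in one self-contained example: after im2col, a grouped convolution is just a batched GEMM whose batch dims are (n, g). A minimal shape-level sketch in plain C++ (not IREE code; shapes are taken from the conv_2d_nhwgc_gfhwc test added in this commit):

```cpp
// Shape bookkeeping for a grouped conv going down the IGEMM path.
// linalg.conv_2d_nhwgc_gfhwc: image 2x10x10x7x4, filter 7x16x3x3x4,
// output 2x8x8x7x16 (g = groups, f = output channels per group).
#include <cstdio>

int main() {
  const int n = 2, oh = 8, ow = 8, g = 7, f = 16, kh = 3, kw = 3, c = 4;
  const int k = kh * kw * c; // flattened reduction size after im2col
  // im2col result: batch dims (n, g), then M dims (oh, ow), then K.
  std::printf("im2col tensor: %dx%dx%dx%dx%d\n", n, g, oh, ow, k);
  // One GEMM per (n, g) pair: [oh*ow x k] * [k x f] -> [oh*ow x f].
  std::printf("per-group GEMM: (%d x %d) * (%d x %d)\n", oh * ow, k, k, f);
}
```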

4 files changed: +205 lines, -107 lines


compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConvToIm2ColOp.cpp

Lines changed: 49 additions & 57 deletions
@@ -62,54 +62,38 @@ static SmallVector<int64_t> getBasisFromShape(ArrayRef<int64_t> shape) {
   return basis;
 }

-// Collect all AffineDimExprs from an AffineExpr.
-static void collectDimExprs(ArrayRef<AffineExpr> exprs,
-                            DenseSet<AffineExpr> &out) {
-  for (auto &expr : exprs) {
-    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
-      out.insert(dimExpr);
-    } else if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
-      collectDimExprs({binOpExpr.getLHS(), binOpExpr.getRHS()}, out);
-    } else {
-      LLVM_DEBUG(llvm::dbgs()
-                 << "Non-dimension expression found: " << expr << "\n");
-    }
-  }
-}
-
 // Computes `inputKPerm` that maps the input spatial and channel dimension order
 // to filter's.
-static SmallVector<int64_t> computeInputKPerm(AffineMap inputMap,
-                                              AffineMap filterMap) {
-  DenseSet<AffineExpr> inputDimsSet;
-  DenseSet<AffineExpr> filterDimsSet;
-  collectDimExprs(inputMap.getResults(), inputDimsSet);
-  collectDimExprs(filterMap.getResults(), filterDimsSet);
-
-  // Get shared dims from input and filter in order of appearance.
-  SmallVector<AffineExpr> inputSharedDims;
-  SmallVector<AffineExpr> filterSharedDims;
-  for (AffineExpr expr : inputMap.getResults()) {
-    expr.walk([&](AffineExpr dimExpr) {
-      if (filterDimsSet.contains(dimExpr)) {
-        inputSharedDims.push_back(dimExpr);
+static SmallVector<int64_t>
+computeInputKPerm(AffineMap inputMap, AffineMap filterMap,
+                  const mlir::linalg::ConvolutionDimensions &convDims) {
+  // Get reduction dims from input and filter in order of appearance.
+  auto reductionDims =
+      llvm::concat<const unsigned>(convDims.inputChannel, convDims.filterLoop);
+  SmallVector<int64_t> inputReductionDims;
+  for (AffineExpr dimExpr : inputMap.getResults()) {
+    for (unsigned reductionDim : reductionDims) {
+      if (dimExpr.isFunctionOfDim(reductionDim)) {
+        inputReductionDims.push_back(reductionDim);
       }
-    });
+    }
   }
-  for (AffineExpr expr : filterMap.getResults()) {
-    expr.walk([&](AffineExpr dimExpr) {
-      if (inputDimsSet.contains(dimExpr)) {
-        filterSharedDims.push_back(dimExpr);
+  SmallVector<int64_t> filterReductionDims;
+  for (AffineExpr dimExpr : filterMap.getResults()) {
+    for (unsigned reductionDim : reductionDims) {
+      if (dimExpr.isFunctionOfDim(reductionDim)) {
+        filterReductionDims.push_back(reductionDim);
       }
-    });
+    }
   }
+
   // Compute the permutation that maps inputSharedDims to filterSharedDims.
   SmallVector<int64_t> inputKPerm;
-  for (AffineExpr filterExpr : filterSharedDims) {
-    auto it = llvm::find(inputSharedDims, filterExpr);
-    assert(it != inputSharedDims.end() &&
+  for (int64_t dim : filterReductionDims) {
+    auto it = llvm::find(inputReductionDims, dim);
+    assert(it != inputReductionDims.end() &&
            "Filter dimension not found in input shared dimensions");
-    inputKPerm.push_back(std::distance(inputSharedDims.begin(), it));
+    inputKPerm.push_back(std::distance(inputReductionDims.begin(), it));
   }
   return inputKPerm;
 }
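The rewritten computeInputKPerm collects, for the input and the filter, the reduction dims (input channel + filter loops) in the order each map's results reference them, then emits the permutation taking the input's K order to the filter's. A standalone sketch of that permutation logic, with plain vectors standing in for AffineMaps (computeKPerm and the dim numbers are illustrative):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// perm[i] = position in inputRed of the i-th filter reduction dim.
std::vector<int64_t> computeKPerm(const std::vector<int64_t> &inputRed,
                                  const std::vector<int64_t> &filterRed) {
  std::vector<int64_t> perm;
  for (int64_t dim : filterRed) {
    auto it = std::find(inputRed.begin(), inputRed.end(), dim);
    assert(it != inputRed.end() && "filter dim missing from input dims");
    perm.push_back(std::distance(inputRed.begin(), it));
  }
  return perm;
}

int main() {
  // conv_1d_nhc_chf-style case: the input references (kw, c) but the filter
  // references (c, kw), so the permutation is [1, 0] (cf. the test file).
  for (int64_t p : computeKPerm(/*inputRed=*/{6, 5}, /*filterRed=*/{5, 6}))
    std::printf("%lld ", static_cast<long long>(p));
}
```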
@@ -211,18 +195,20 @@ class ConvertConvGeneric final
           rewriter.getIndexAttr(filterShape[maybeDim.value()]));
     }

-    // Shape of the resulting tensor from im2col.
-    SmallVector<int64_t> colTensorShape;
-    SmallVector<int64_t> batchPos;
-    for (auto batch : convDims.batch) {
-      std::optional<int64_t> maybeBatch = inputMap.getResultPosition(
-          getAffineDimExpr(batch, inputMap.getContext()));
-      if (!maybeBatch) {
-        return rewriter.notifyMatchFailure(linalgOp,
-                                           "Failed to infer batch shape.");
-      }
-      batchPos.push_back(maybeBatch.value());
-      colTensorShape.push_back(inputShape[maybeBatch.value()]);
+    // Batch dims for the im2col also include the depth/group dimensions of the
+    // conv.
+    auto im2colBatchIterDims =
+        llvm::to_vector(llvm::concat<unsigned>(convDims.depth, convDims.batch));
+    SmallVector<int64_t> batchPos(im2colBatchIterDims.size());
+    for (int64_t convDim : im2colBatchIterDims) {
+      AffineExpr convDimExpr = getAffineDimExpr(convDim, getContext());
+      int64_t im2colInputDim = inputMap.getResultPosition(convDimExpr).value();
+
+      AffineExpr igemmDimExpr = igemmConvDetails.convToIgemmDimMap.at(convDim);
+      int64_t igemmInputDim = igemmConvDetails.getIgemmInputImageMap()
+                                  .getResultPosition(igemmDimExpr)
+                                  .value();
+      batchPos[igemmInputDim] = im2colInputDim;
     }

     SmallVector<int64_t> mPos;
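Note that batchPos is filled by a scatter (batchPos[igemmInputDim] = im2colInputDim) rather than by push_back: the loop visits dims in concat(depth, batch) order, i.e. group first, while the result must be ordered like the batch dims of the IGEMM input tensor. A sketch with the numbers from the conv_2d_nhwgc_gfhwc test (the DimInfo struct is illustrative):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  struct DimInfo {
    int64_t im2colInputDim; // position in the conv input tensor (n h w g c)
    int64_t igemmInputDim;  // position among the im2col result's batch dims
  };
  // Visit order is concat(depth, batch): g first, then n. g lives at input
  // position 3 but is batch dim 1 of the (n, g, oh, ow, k) im2col result.
  std::vector<DimInfo> im2colBatchIterDims = {{3, 1} /*g*/, {0, 0} /*n*/};
  std::vector<int64_t> batchPos(im2colBatchIterDims.size());
  for (DimInfo d : im2colBatchIterDims)
    batchPos[d.igemmInputDim] = d.im2colInputDim; // scatter, not push_back
  // Prints [0, 3]; push_back order would have produced [3, 0].
  std::printf("batch_pos = [%lld, %lld]\n", (long long)batchPos[0],
              (long long)batchPos[1]);
}
```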
@@ -236,7 +222,6 @@ class ConvertConvGeneric final
       for (auto [idx, e] : llvm::enumerate(outputMap.getResults())) {
         if (e.isFunctionOfDim(outputImage)) {
           mShape.push_back(outputShape[idx]);
-          colTensorShape.push_back(outputShape[idx]);
         }
       }
     }
@@ -251,12 +236,11 @@ class ConvertConvGeneric final
     }
     // The index at which the reduction dimension bounds starts in
     // igemmLoopBounds.
-    int64_t reductionBoundIndex = convDims.batch.size() +
-                                  convDims.outputImage.size() +
-                                  convDims.outputChannel.size();
+    int64_t reductionBoundIndex =
+        convDims.batch.size() + convDims.depth.size() +
+        convDims.outputImage.size() + convDims.outputChannel.size();
     SmallVector<int64_t> kShape(igemmLoopBounds.begin() + reductionBoundIndex,
                                 igemmLoopBounds.end());
-    colTensorShape.insert(colTensorShape.end(), kShape.begin(), kShape.end());

     SmallVector<OpFoldResult> mBasis =
         getAsIndexOpFoldResult(getContext(), getBasisFromShape(mShape));
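A worked example of the updated reductionBoundIndex, assuming (as the code above does) that the IGEMM loop bounds list every parallel dim before the reduction dims; the values follow the conv_2d_nhwgc_gfhwc test. Counting convDims.depth is what keeps the K slice correct for group convs; without it the slice would start one parallel dim too early:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Iteration-space bounds: n, oh, ow, g, f are parallel; k is reduction.
  std::vector<int64_t> igemmLoopBounds = {2, 8, 8, 7, 16, 36};
  size_t numBatch = 1, numDepth = 1, numOutputImage = 2, numOutputChannel = 1;
  size_t reductionBoundIndex =
      numBatch + numDepth + numOutputImage + numOutputChannel; // == 5
  std::vector<int64_t> kShape(igemmLoopBounds.begin() + reductionBoundIndex,
                              igemmLoopBounds.end());
  // Prints kShape = [36]; dropping numDepth would wrongly pull in f = 16.
  std::printf("kShape = [%lld]\n", (long long)kShape[0]);
}
```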
@@ -266,9 +250,17 @@ class ConvertConvGeneric final
     SmallVector<OpFoldResult> kOffset(kBasis.size(), rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> mOffset(mBasis.size(), rewriter.getIndexAttr(0));

-    SmallVector<int64_t> inputKPerm = computeInputKPerm(inputMap, filterMap);
+    SmallVector<int64_t> inputKPerm =
+        computeInputKPerm(inputMap, filterMap, convDims);

     auto loc = linalgOp.getLoc();
+    // Shape of the resulting tensor from im2col.
+    SmallVector<int64_t> colTensorShape;
+    for (int64_t dim : batchPos) {
+      colTensorShape.push_back(inputShape[dim]);
+    }
+    colTensorShape.append(mShape);
+    colTensorShape.append(kShape);
     Value colTensor = rewriter.create<tensor::EmptyOp>(
         loc, colTensorShape, inputType.getElementType());
     Value img2ColTensor =
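The colTensorShape assembly above, replayed with the concrete numbers of the conv_2d_nhwgc_gfhwc test (a plain C++ sketch, not IREE code): batch sizes gathered from the input via batchPos, then the M dims, then the flattened K dims:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> inputShape = {2, 10, 10, 7, 4}; // n h w g c
  std::vector<int64_t> batchPos = {0, 3};              // n, g
  std::vector<int64_t> mShape = {8, 8};                // oh, ow
  std::vector<int64_t> kShape = {36};                  // kh * kw * c
  std::vector<int64_t> colTensorShape;
  for (int64_t dim : batchPos)
    colTensorShape.push_back(inputShape[dim]);
  colTensorShape.insert(colTensorShape.end(), mShape.begin(), mShape.end());
  colTensorShape.insert(colTensorShape.end(), kShape.begin(), kShape.end());
  for (int64_t d : colTensorShape)
    std::printf("%lld ", (long long)d); // 2 7 8 8 36 -> tensor<2x7x8x8x36xf32>
  std::printf("\n");
}
```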

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv_to_im2col.mlir

Lines changed: 108 additions & 0 deletions
@@ -453,3 +453,111 @@ util.func public @conv_1d_nhc_chf(%arg0: tensor<1x3x2xf32>, %arg1: tensor<2x2x2x
 // CHECK-SAME:     input_k_perm = [1, 0]
 // CHECK-SAME:     ins({{.*}} : tensor<1x3x2xf32>)
 // CHECK-SAME:     outs({{.*}} : tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
+
+// -----
+
+util.func public @conv_2d_nhwgc_gfhwc(%arg0: tensor<2x10x10x7x4xf32>, %arg1: tensor<7x16x3x3x4xf32>, %arg2: tensor<2x8x8x7x16xf32>) -> tensor<2x8x8x7x16xf32> {
+  %0 = linalg.conv_2d_nhwgc_gfhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+     ins(%arg0, %arg1: tensor<2x10x10x7x4xf32>, tensor<7x16x3x3x4xf32>)
+    outs(%arg2: tensor<2x8x8x7x16xf32>) -> tensor<2x8x8x7x16xf32>
+  util.return %0 : tensor<2x8x8x7x16xf32>
+}
+// n h w g f c
+// CHECK-DAG: #[[LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2, d5)>
+// CHECK-DAG: #[[RHS_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+// CHECK-DAG: #[[OUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4)>
+// CHECK:      util.func public @conv_2d_nhwgc_gfhwc(
+// CHECK-SAME:   %[[IMG:.+]]: [[IMG_T:tensor<2x10x10x7x4xf32>]]
+// CHECK-SAME:   %[[FIL:.+]]: [[FIL_T:tensor<7x16x3x3x4xf32>]]
+// CHECK-SAME:   %[[OUT:.+]]: [[OUT_T:tensor<2x8x8x7x16xf32>]]
+// CHECK:        %[[EMPTY:.+]] = tensor.empty() : [[LHS_T:tensor<2x7x8x8x36xf32>]]
+// CHECK:        %[[IM2COL:.+]] = iree_linalg_ext.im2col
+// CHECK-SAME:     strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+// CHECK-SAME:     m_offset = [0, 0] * [8, 1] k_offset = [0] * [1]
+// CHECK-SAME:     batch_pos = [0, 3] m_pos = [1, 2] k_pos = [4]
+// CHECK-SAME:     input_k_perm = [0, 1, 2]
+// CHECK-SAME:     ins(%[[IMG]] : [[IMG_T]])
+// CHECK-SAME:     outs(%[[EMPTY]] : [[LHS_T]])
+// CHECK-DAG:    %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FIL]] {{\[}}[0], [1], [2, 3, 4]] : [[FIL_T]] into [[RHS_T:tensor<7x16x36xf32>]]
+// CHECK:        %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[LHS_MAP]], #[[RHS_MAP]], #[[OUT_MAP]]]
+// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     ins(%[[IM2COL]], %[[COLLAPSED]] : [[LHS_T]], [[RHS_T]])
+// CHECK-SAME:     outs(%[[OUT]] : [[OUT_T]]) {
+// CHECK:        }
+// CHECK:        util.return %[[MATMUL]]
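A quick way to sanity-check the maps in this test (a sketch; dim labels follow the `// n h w g f c` comment, with c flattened into K): applying LHS_MAP to the iteration dims shows the im2col result is addressed as (n, g, oh, ow, k), i.e. the group dim sits next to the batch dim and the generic op below it is a plain batched GEMM:

```cpp
#include <cstdio>
#include <vector>

int main() {
  const char *iterDims[] = {"n", "h", "w", "g", "f", "k"};
  std::vector<int> lhsMap = {0, 3, 1, 2, 5}; // #[[LHS_MAP]] result positions
  for (int d : lhsMap)
    std::printf("%s ", iterDims[d]); // n g h w k -> tensor<2x7x8x8x36xf32>
  std::printf("\n");
}
```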
+
+// -----
+
+util.func public @conv_2d_ngchw_fgchw(%arg0: tensor<2x7x4x10x10xf32>, %arg1: tensor<16x7x4x3x3xf32>, %arg2: tensor<2x7x16x8x8xf32>) -> tensor<2x7x16x8x8xf32> {
+  %0 = linalg.conv_2d_ngchw_fgchw
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+     ins(%arg0, %arg1: tensor<2x7x4x10x10xf32>, tensor<16x7x4x3x3xf32>)
+    outs(%arg2: tensor<2x7x16x8x8xf32>) -> tensor<2x7x16x8x8xf32>
+  util.return %0 : tensor<2x7x16x8x8xf32>
+}
+// n g f h w c
+// CHECK-DAG: #[[LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5)>
+// CHECK-DAG: #[[RHS_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d5)>
+// CHECK-DAG: #[[OUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4)>
+// CHECK:      util.func public @conv_2d_ngchw_fgchw(
+// CHECK-SAME:   %[[IMG:.+]]: [[IMG_T:tensor<2x7x4x10x10xf32>]]
+// CHECK-SAME:   %[[FIL:.+]]: [[FIL_T:tensor<16x7x4x3x3xf32>]]
+// CHECK-SAME:   %[[OUT:.+]]: [[OUT_T:tensor<2x7x16x8x8xf32>]]
+// CHECK:        %[[EMPTY:.+]] = tensor.empty() : [[RHS_T:tensor<2x7x8x8x36xf32>]]
+// CHECK:        %[[IM2COL:.+]] = iree_linalg_ext.im2col
+// CHECK-SAME:     strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+// CHECK-SAME:     m_offset = [0, 0] * [8, 1] k_offset = [0] * [1]
+// CHECK-SAME:     batch_pos = [0, 1] m_pos = [3, 4] k_pos = [2]
+// CHECK-SAME:     input_k_perm = [0, 1, 2]
+// CHECK-SAME:     ins(%[[IMG]] : [[IMG_T]])
+// CHECK-SAME:     outs(%[[EMPTY]] : [[RHS_T]])
+// CHECK-DAG:    %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FIL]] {{\[}}[0], [1], [2, 3, 4]] : [[FIL_T]] into [[LHS_T:tensor<16x7x36xf32>]]
+// CHECK:        %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[LHS_MAP]], #[[RHS_MAP]], #[[OUT_MAP]]]
+// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     ins(%[[COLLAPSED]], %[[IM2COL]] : [[LHS_T]], [[RHS_T]])
+// CHECK-SAME:     outs(%[[OUT]] : [[OUT_T]]) {
+// CHECK:        }
+// CHECK:        util.return %[[MATMUL]]
+
+// -----
+// n g h w f c kh kw
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d0, d2, d3, d4)>
+// Output has 'n' and 'g' dimensions transposed.
+util.func public @conv_2d_ngchw_fgchw_gnfhw(%arg0: tensor<2x7x4x10x10xf32>, %arg1: tensor<16x7x4x3x3xf32>, %arg2: tensor<7x2x16x8x8xf32>) -> tensor<7x2x16x8x8xf32> {
+  %0 = linalg.generic {
+      indexing_maps = [#map, #map1, #map2],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
+    } ins(%arg0, %arg1 : tensor<2x7x4x10x10xf32>, tensor<16x7x4x3x3xf32>) outs(%arg2 : tensor<7x2x16x8x8xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<7x2x16x8x8xf32>
+  util.return %0 : tensor<7x2x16x8x8xf32>
+}
+// g n f h w c
+// CHECK-DAG: #[[LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d5)>
+// CHECK-DAG: #[[RHS_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4, d5)>
+// CHECK-DAG: #[[OUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4)>
+// CHECK:      util.func public @conv_2d_ngchw_fgchw_gnfhw(
+// CHECK-SAME:   %[[IMG:.+]]: [[IMG_T:tensor<2x7x4x10x10xf32>]]
+// CHECK-SAME:   %[[FIL:.+]]: [[FIL_T:tensor<16x7x4x3x3xf32>]]
+// CHECK-SAME:   %[[OUT:.+]]: [[OUT_T:tensor<7x2x16x8x8xf32>]]
+// CHECK:        %[[EMPTY:.+]] = tensor.empty() : [[RHS_T:tensor<2x7x8x8x36xf32>]]
+// CHECK:        %[[IM2COL:.+]] = iree_linalg_ext.im2col
+// CHECK-SAME:     batch_pos = [0, 1] m_pos = [3, 4] k_pos = [2]
+// CHECK-SAME:     ins(%[[IMG]] : [[IMG_T]])
+// CHECK-SAME:     outs(%[[EMPTY]] : [[RHS_T]])
+// CHECK-DAG:    %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FIL]] {{\[}}[0], [1], [2, 3, 4]] : [[FIL_T]] into [[LHS_T:tensor<16x7x36xf32>]]
+// CHECK:        %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[LHS_MAP]], #[[RHS_MAP]], #[[OUT_MAP]]]
+// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     ins(%[[COLLAPSED]], %[[IM2COL]] : [[LHS_T]], [[RHS_T]])
+// CHECK-SAME:     outs(%[[OUT]] : [[OUT_T]]) {
+// CHECK:        }
+// CHECK:        util.return %[[MATMUL]]
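Reading the transposed-output test the same way (a sketch; labels follow the `// g n f h w c` comment): the im2col result is still addressed in the input's (n, g, ...) order, and the n/g transpose is absorbed entirely by the IGEMM indexing maps rather than by the im2col op:

```cpp
#include <cstdio>
#include <vector>

int main() {
  const char *iterDims[] = {"g", "n", "f", "h", "w", "k"};
  std::vector<int> rhsMap = {1, 0, 3, 4, 5}; // #[[RHS_MAP]] result positions
  std::vector<int> outMap = {0, 1, 2, 3, 4}; // #[[OUT_MAP]] result positions
  for (int d : rhsMap)
    std::printf("%s ", iterDims[d]); // n g h w k -> tensor<2x7x8x8x36xf32>
  std::printf("| ");
  for (int d : outMap)
    std::printf("%s ", iterDims[d]); // g n f h w -> tensor<7x2x16x8x8xf32>
  std::printf("\n");
}
```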
