
Commit 5ee0652

[Codegen][IGEMM] Support Conv with no input channel dimension (#23271)
Remove the requirement that the input channel dimension be non-empty. This fixes #23270 and #23268.

Signed-off-by: yzhang93 <zhyuhang88@gmail.com>
1 parent 1ce2fa2 commit 5ee0652

3 files changed, +32 -5 lines

compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp

Lines changed: 4 additions & 2 deletions
@@ -916,8 +916,10 @@ FailureOr<SmallVector<Value>> Im2colOp::decomposeOperation(OpBuilder &b) {
   sliceSizes.back() = innerInputTileSize;

   // Set the batch and K offsets for the input tensor.
-  const int64_t kPos = getKPos().front();
-  sliceOffsets[kPos] = inputKOffset.front();
+  if (!getKPos().empty()) {
+    const int64_t kPos = getKPos().front();
+    sliceOffsets[kPos] = inputKOffset.front();
+  }
   SmallVector<int64_t> inverseOutputPerm =
       invertPermutationVector(getOutputPerm());
   for (auto [ivIdx, bPos] : llvm::enumerate(getBatchPos())) {
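For readers outside the IREE codebase, the point of the new guard: when a convolution has no input channel dimension, Im2colOp's K-position list is empty, so there is no K offset to write into the slice offsets, and the previously unconditional getKPos().front() call would read from an empty range. A minimal standalone C++ sketch of the same pattern (the function name and container types here are illustrative, not IREE's actual APIs):

#include <cstdint>
#include <vector>

// Sketch only: mirrors the guarded K-offset write in the hunk above.
// With no input channel dimension, kPos is empty, so front() must not
// be called and no slice offset needs updating.
void setInputKOffset(std::vector<int64_t> &sliceOffsets,
                     const std::vector<int64_t> &kPos,
                     const std::vector<int64_t> &inputKOffset) {
  if (!kPos.empty()) {
    sliceOffsets[kPos.front()] = inputKOffset.front();
  }
}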

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv_to_im2col.mlir

Lines changed: 26 additions & 0 deletions
@@ -493,6 +493,32 @@ util.func public @conv_1d_nhc_chf(%arg0: tensor<1x3x2xf32>, %arg1: tensor<2x2x2x

 // -----

+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2 + d5, d3 + d6, d0, d4)>
+#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d0, d1)>
+#map5 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+util.func public @conv_2d_no_input_channel(%arg0: tensor<61x93x16x64xbf16>, %arg1: tensor<59x91x16x56xbf16>, %arg2: tensor<16x56x3x3x64xf32>) -> tensor<16x56x3x3x64xf32> {
+  %0 = linalg.generic {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<61x93x16x64xbf16>, tensor<59x91x16x56xbf16>) outs(%arg2 : tensor<16x56x3x3x64xf32>) {
+  ^bb0(%in: bf16, %in_0: bf16, %out: f32):
+    %1 = arith.extf %in : bf16 to f32
+    %2 = arith.extf %in_0 : bf16 to f32
+    %3 = arith.mulf %1, %2 : f32
+    %4 = arith.addf %out, %3 : f32
+    linalg.yield %4 : f32
+  } -> tensor<16x56x3x3x64xf32>
+  util.return %0 : tensor<16x56x3x3x64xf32>
+}
+
+// CHECK: util.func public @conv_2d_no_input_channel(
+// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
+// CHECK-SAME: strides = [1, 1] dilations = [1, 1] kernel_size = [59, 91]
+// CHECK-SAME: m_offset = [0, 0] * [3, 1] k_offset = [0] * [1]
+// CHECK-SAME: batch_pos = [3, 2] m_pos = [0, 1] k_pos = []
+// CHECK-SAME: input_k_perm = [0, 1] output_perm = [2, 3, 4, 1, 0]
+// CHECK-SAME: ins({{.*}} : tensor<61x93x16x64xbf16>)
+// CHECK-SAME: outs({{.*}} : tensor<3x3x5369x16x64xbf16>) -> tensor<3x3x5369x16x64xbf16>
+
+// -----
+
 util.func public @conv_2d_nhwgc_gfhwc(%arg0: tensor<2x10x10x7x4xf32>, %arg1: tensor<7x16x3x3x4xf32>, %arg2: tensor<2x8x8x7x16xf32>) -> tensor<2x8x8x7x16xf32> {
   %0 = linalg.conv_2d_nhwgc_gfhwc
     {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
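A quick sanity check on the im2col result type expected by the CHECK lines (my arithmetic, not part of the commit): with k_pos = [] there is no channel contribution, so the flattened K extent is just the kernel footprint, and the output spatial sizes follow from a stride-1 valid convolution:

  K = 59 * 91 = 5369
  H_out = 61 - 59 + 1 = 3, W_out = 93 - 91 + 1 = 3

which matches the tensor<3x3x5369x16x64xbf16> shape, with the 16 and 64 extents carried through from the batch_pos dimensions of the input.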

compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp

Lines changed: 2 additions & 3 deletions
@@ -494,9 +494,8 @@ getIGEMMGenericConvDetails(linalg::LinalgOp linalgOp) {
     return failure();
   }

-  // TODO: Support pooling operations. For pooling ops, the input/output channel
-  // size will be categorized as the additional batch dimension.
-  if (convDims.outputChannel.empty() || convDims.inputChannel.empty()) {
+  // TODO: Support pooling operations.
+  if (convDims.outputChannel.empty()) {
     LDBG() << "[unimplemented] expected no pooling operations.";
     return failure();
   }
