Skip to content

Commit 6dba8b5

Browse files
authored
[Im2col] Add input-filter permutation info to im2col metadata (#20531)
This PR extends the `im2col` operation to support a new attribute `input_k_perm`, which encodes the permutation required to align the layout of input and filter reduction dimensions. This is useful when the logical layout of the filter (e.g., `CHW`) differs from that of the input (e.g., `HWC`). The im2col decomposition then uses this permutation metadata to handle layout mismatches and reorder index computations. This fixes #20473. --------- Signed-off-by: yzhang93 <[email protected]>
1 parent ff5b150 commit 6dba8b5

File tree

10 files changed

+315
-19
lines changed

10 files changed

+315
-19
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_derived_thread_config.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ module {
179179
strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
180180
m_offset = [0] * [1] k_offset = [0] * [1]
181181
batch_pos = [0] m_pos = [2, 3] k_pos = [1]
182+
input_k_perm = [0, 1, 2]
182183
ins(%2 : tensor<2x34x34x128xf16>)
183184
outs(%3 : tensor<2x128x8xf16>) -> tensor<2x128x8xf16>
184185
return %4 : tensor<2x128x8xf16>

compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,12 @@ FailureOr<SmallVector<Value>> Im2colOp::decomposeOperation(OpBuilder &b) {
777777
}
778778
kBasis.push_back(size);
779779
}
780+
781+
// Transpose the order of (P, Q, C) according to `inputKPerm` encoded in
782+
// im2col metadata.
783+
ArrayRef<int64_t> inputKPerm = getInputKPerm();
784+
applyPermutationToVector(kBasis, inputKPerm);
785+
780786
OpFoldResult kIndex = kOffset;
781787
for (auto [i, ivIdx, stride] :
782788
llvm::enumerate(getKOutputDims(), getMixedKStrides())) {
@@ -792,17 +798,18 @@ FailureOr<SmallVector<Value>> Im2colOp::decomposeOperation(OpBuilder &b) {
792798
/*hasOuterBound=*/true)
793799
.getResults();
794800
// Split the delinearized offsets into the window offsets (for M offsets)
795-
// and the K offsets for the input tensor.
801+
// and the K offsets for the input tensor based on the layout.
796802
SmallVector<Value> windowOffset, inputKOffset;
797803
int delinKIdx = 0;
804+
SmallVector<int64_t> invInputKPerm = invertPermutationVector(inputKPerm);
798805
for (int i = 0; i < getInputRank(); ++i) {
799806
if (batchPosSet.contains(i))
800807
continue;
801808
if (mPosSet.contains(i)) {
802-
windowOffset.push_back(delinKOffset[delinKIdx++]);
809+
windowOffset.push_back(delinKOffset[invInputKPerm[delinKIdx++]]);
803810
continue;
804811
}
805-
inputKOffset.push_back(delinKOffset[delinKIdx++]);
812+
inputKOffset.push_back(delinKOffset[invInputKPerm[delinKIdx++]]);
806813
}
807814

808815
// Compute offsets for extract. The linearized im2col result M offset is

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1646,13 +1646,16 @@ SmallVector<int64_t> Im2colOp::getKOutputDims() {
16461646
}
16471647

16481648
/// Custom builder methods for im2col op.
1649-
void Im2colOp::build(
1650-
OpBuilder &builder, OperationState &state, Value input, Value output,
1651-
ArrayRef<int64_t> strides, ArrayRef<int64_t> dilations,
1652-
ArrayRef<OpFoldResult> kernelSize, ArrayRef<OpFoldResult> mOffset,
1653-
ArrayRef<OpFoldResult> mStrides, ArrayRef<OpFoldResult> kOffset,
1654-
ArrayRef<OpFoldResult> kStrides, ArrayRef<int64_t> batchPos,
1655-
ArrayRef<int64_t> mPos, ArrayRef<int64_t> kPos) {
1649+
void Im2colOp::build(OpBuilder &builder, OperationState &state, Value input,
1650+
Value output, ArrayRef<int64_t> strides,
1651+
ArrayRef<int64_t> dilations,
1652+
ArrayRef<OpFoldResult> kernelSize,
1653+
ArrayRef<OpFoldResult> mOffset,
1654+
ArrayRef<OpFoldResult> mStrides,
1655+
ArrayRef<OpFoldResult> kOffset,
1656+
ArrayRef<OpFoldResult> kStrides,
1657+
ArrayRef<int64_t> batchPos, ArrayRef<int64_t> mPos,
1658+
ArrayRef<int64_t> kPos, ArrayRef<int64_t> inputKPerm) {
16561659
assert(strides.size() == kernelSize.size() &&
16571660
dilations.size() == kernelSize.size() &&
16581661
mPos.size() == kernelSize.size() &&
@@ -1680,7 +1683,8 @@ void Im2colOp::build(
16801683
builder.getDenseI64ArrayAttr(staticKOffset), dynamicKStrides,
16811684
builder.getDenseI64ArrayAttr(staticKStrides),
16821685
builder.getDenseI64ArrayAttr(batchPos),
1683-
builder.getDenseI64ArrayAttr(mPos), builder.getDenseI64ArrayAttr(kPos));
1686+
builder.getDenseI64ArrayAttr(mPos), builder.getDenseI64ArrayAttr(kPos),
1687+
builder.getDenseI64ArrayAttr(inputKPerm));
16841688
}
16851689

16861690
LogicalResult Im2colOp::verify() {
@@ -1743,6 +1747,7 @@ LogicalResult Im2colOp::verify() {
17431747
ArrayRef<int64_t> strides = getStrides();
17441748
ArrayRef<int64_t> dilations = getDilations();
17451749
SmallVector<OpFoldResult> kernelSize = getMixedKernelSize();
1750+
ArrayRef<int64_t> inputKPerm = getInputKPerm();
17461751
if (kernelSize.size() != mPos.size()) {
17471752
return op->emitOpError(
17481753
"expected kernel rank to be equal to the m_pos rank");
@@ -1756,6 +1761,23 @@ LogicalResult Im2colOp::verify() {
17561761
"expected dilations rank to be equal to the kernel rank");
17571762
}
17581763

1764+
size_t sharedRank = mPos.size() + kPos.size();
1765+
if (inputKPerm.size() != sharedRank) {
1766+
return op->emitOpError("expected input_k_perm size (")
1767+
<< inputKPerm.size()
1768+
<< ") to match the number of shared dimensions (m_Pos + k_pos = "
1769+
<< sharedRank << ")";
1770+
}
1771+
SmallVector<int64_t> permVec(inputKPerm.begin(), inputKPerm.end());
1772+
llvm::sort(permVec);
1773+
for (int64_t i = 0; i < static_cast<int64_t>(sharedRank); ++i) {
1774+
if (permVec[i] != i) {
1775+
return op->emitOpError(
1776+
"expected input_k_perm to be a permutation of [0, ")
1777+
<< sharedRank << ")";
1778+
}
1779+
}
1780+
17591781
// Verify input and output shapes.
17601782
ArrayRef<int64_t> inputShape = inputType.getShape();
17611783
ArrayRef<int64_t> outputShape = outputType.getShape();

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,17 @@ def IREELinalgExt_Im2colOp : IREELinalgExt_Op<"im2col",
889889
would be 4 for `K0`, and 1 for `K1`, meaning as `K0` increases by 1, the
890890
index into the flat `K` increases by 4. The strides in M from `m_strides`
891891
are orthogonal to the strides in `K` from `k_strides`.
892+
893+
The `input_k_perm` attribute defines the permutation needed to align the
894+
reduction dimensions of the input layout with those of the filter layout
895+
when computing the K dimension of the im2col output. This is useful when the
896+
layout of the filter (e.g., `CHW`) differs from that of the input (e.g., `HWC`).
897+
For instance, an `input_k_perm = [2, 0, 1]` indicates the input indices need
898+
to be transposed from `HWC` to `CHW` layout before extracting slices during
899+
decomposition. The identity permutation (e.g., input_k_perm = [0, 1, 2])
900+
indicates that the input layout is already aligned with the filter layout
901+
in terms of reduction dimensions, so no transposition of indices is necessary
902+
before slice extraction.
892903
}];
893904

894905
let arguments = (ins AnyShaped:$input, AnyShaped:$output,
@@ -906,7 +917,8 @@ def IREELinalgExt_Im2colOp : IREELinalgExt_Op<"im2col",
906917
DenseI64ArrayAttr:$static_k_strides,
907918
DenseI64ArrayAttr:$batch_pos,
908919
DenseI64ArrayAttr:$m_pos,
909-
DenseI64ArrayAttr:$k_pos);
920+
DenseI64ArrayAttr:$k_pos,
921+
DenseI64ArrayAttr:$input_k_perm);
910922

911923
let results = (outs Variadic<AnyShaped>:$results);
912924
let hasFolder = 1;
@@ -925,6 +937,7 @@ def IREELinalgExt_Im2colOp : IREELinalgExt_Op<"im2col",
925937
`batch_pos` `=` $batch_pos
926938
`m_pos` `=` $m_pos
927939
`k_pos` `=` $k_pos
940+
`input_k_perm` `=` $input_k_perm
928941
`ins` `(` $input `:` type($input) `)`
929942
`outs` `(` $output `:` type($output) `)`
930943
(`->` type($results)^)?
@@ -941,7 +954,8 @@ def IREELinalgExt_Im2colOp : IREELinalgExt_Op<"im2col",
941954
"ArrayRef<OpFoldResult>":$k_strides,
942955
"ArrayRef<int64_t>":$batch_dimensions,
943956
"ArrayRef<int64_t>":$m_dimensions,
944-
"ArrayRef<int64_t>":$k_dimensions)>
957+
"ArrayRef<int64_t>":$k_dimensions,
958+
"ArrayRef<int64_t>":$input_k_perm)>
945959
];
946960

947961
let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{

compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,7 @@ func.func @illegal_im2col_strides(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x10
699699
%1 = iree_linalg_ext.im2col strides = [1] dilations = [1, 1] kernel_size = [3, 3]
700700
m_offset = [0] * [1] k_offset = [0] * [1]
701701
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
702+
input_k_perm = [0, 1, 2]
702703
ins(%arg0 : tensor<2x34x34x640xf32>)
703704
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
704705
return %1 : tensor<2x1024x5760xf32>
@@ -712,6 +713,7 @@ func.func @illegal_im2col_dilations(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x
712713
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1, 1] kernel_size = [3, 3]
713714
m_offset = [0] * [1] k_offset = [0] * [1]
714715
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
716+
input_k_perm = [0, 1, 2]
715717
ins(%arg0 : tensor<2x34x34x640xf32>)
716718
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
717719
return %1 : tensor<2x1024x5760xf32>
@@ -725,6 +727,7 @@ func.func @illegal_im2col_kernel_size(%arg0: tensor<2x34x34x640xf32>) -> tensor<
725727
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3]
726728
m_offset = [0] * [1] k_offset = [0] * [1]
727729
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
730+
input_k_perm = [0, 1, 2]
728731
ins(%arg0 : tensor<2x34x34x640xf32>)
729732
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
730733
return %1 : tensor<2x1024x5760xf32>
@@ -738,6 +741,7 @@ func.func @illegal_im2col_m_offset(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1
738741
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
739742
m_offset = [0, 0] * [1] k_offset = [0] * [1]
740743
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
744+
input_k_perm = [0, 1, 2]
741745
ins(%arg0 : tensor<2x34x34x640xf32>)
742746
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
743747
return %1 : tensor<2x1024x5760xf32>
@@ -751,6 +755,7 @@ func.func @illegal_im2col_k_offset(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1
751755
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
752756
m_offset = [0] * [1] k_offset = [0, 0] * [1]
753757
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
758+
input_k_perm = [0, 1, 2]
754759
ins(%arg0 : tensor<2x34x34x640xf32>)
755760
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
756761
return %1 : tensor<2x1024x5760xf32>
@@ -764,6 +769,7 @@ func.func @illegal_im2col_m_strides(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x
764769
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
765770
m_offset = [0] * [0] k_offset = [0] * [1]
766771
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
772+
input_k_perm = [0, 1, 2]
767773
ins(%arg0 : tensor<2x34x34x640xf32>)
768774
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
769775
return %1 : tensor<2x1024x5760xf32>
@@ -777,6 +783,7 @@ func.func @illegal_im2col_k_strides(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x
777783
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
778784
m_offset = [0] * [1] k_offset = [0] * [2]
779785
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
786+
input_k_perm = [0, 1, 2]
780787
ins(%arg0 : tensor<2x34x34x640xf32>)
781788
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
782789
return %1 : tensor<2x1024x5760xf32>
@@ -790,6 +797,7 @@ func.func @illegal_im2col_input_rank(%arg0: tensor<1x2x34x34x640xf32>) -> tensor
790797
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
791798
m_offset = [0] * [1] k_offset = [0] * [1]
792799
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
800+
input_k_perm = [0, 1, 2]
793801
ins(%arg0 : tensor<1x2x34x34x640xf32>)
794802
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
795803
return %1 : tensor<2x1024x5760xf32>
@@ -803,13 +811,42 @@ func.func @illegal_im2col_output_rank(%arg0: tensor<2x34x34x640xf32>) -> tensor<
803811
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
804812
m_offset = [0] * [1] k_offset = [0] * [1]
805813
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
814+
input_k_perm = [0, 1, 2]
806815
ins(%arg0 : tensor<2x34x34x640xf32>)
807816
outs(%0 : tensor<2x1024x9x640xf32>) -> tensor<2x1024x9x640xf32>
808817
return %1 : tensor<2x1024x9x640xf32>
809818
}
810819

811820
// -----
812821

822+
func.func @illegal_im2col_perm_num(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> {
823+
%0 = tensor.empty() : tensor<2x1024x5760xf32>
824+
// expected-error @+1 {{expected input_k_perm size (2) to match the number of shared dimensions (m_Pos + k_pos = 3)}}
825+
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
826+
m_offset = [0] * [1] k_offset = [0] * [1]
827+
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
828+
input_k_perm = [0, 1]
829+
ins(%arg0 : tensor<2x34x34x640xf32>)
830+
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
831+
return %1 : tensor<2x1024x5760xf32>
832+
}
833+
834+
// -----
835+
836+
func.func @illegal_im2col_perm_value(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> {
837+
%0 = tensor.empty() : tensor<2x1024x5760xf32>
838+
// expected-error @+1 {{expected input_k_perm to be a permutation of [0, 3)}}
839+
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
840+
m_offset = [0] * [1] k_offset = [0] * [1]
841+
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
842+
input_k_perm = [1, 2, 3]
843+
ins(%arg0 : tensor<2x34x34x640xf32>)
844+
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
845+
return %1 : tensor<2x1024x5760xf32>
846+
}
847+
848+
// -----
849+
813850
func.func @illegal_winograd_input_shape(%arg0: tensor<1x10x10x32xf32>) -> tensor<8x8x1x6x6x32xf32> {
814851
%0 = tensor.empty() : tensor<8x8x1x6x6x32xf32>
815852
// expected-error @+1 {{incompatible output shape}}

0 commit comments

Comments
 (0)