Skip to content

Commit 63ed28b

Browse files
[GPU] Add support for conv padding when decomposing im2col (#20203)
Adds support for a fused pad producer on linalg conv ops that go down the IGEMM path. This is done by swapping the pad with the extract slice of the decomposed im2col op. See issue #20200 for details on the flags needed to fuse the pad with the conv in a single dispatch. --------- Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent 1193f50 commit 63ed28b

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeIm2col.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ void DecomposeIm2colPass::runOnOperation() {
8989

9090
RewritePatternSet patterns(context);
9191
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
92+
// After im2col is decomposed, im2col extract slice can be swapped with input
93+
// padding.
94+
patterns.insert<linalg::ExtractSliceOfPadTensorSwapPattern>(
95+
context, [](tensor::ExtractSliceOp) { return false; });
9296
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
9397
return signalPassFailure();
9498
}

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,3 +233,27 @@ module {
233233
// CHECK-UNROLL: %[[INSERT3:.+]] = tensor.insert_slice %[[COPY]] into %[[INSERT2]][%[[C1]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
234234

235235
// CHECK-UNROLL: return %[[INSERT3]] : tensor<2x2x4xf32>
236+
237+
// -----
238+
239+
module {
240+
func.func @im2col_padding(%input: tensor<1x8x3x3xf32>) -> tensor<1x2x2x12xf32> {
241+
%cst = arith.constant 0.000000e+00 : f32
242+
%empty = tensor.empty() : tensor<1x2x2x12xf32>
243+
%padded = tensor.pad %input low[0, 0, 3, 3] high[0, 0, 3, 3] {
244+
^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index):
245+
tensor.yield %cst : f32
246+
} : tensor<1x8x3x3xf32> to tensor<1x8x9x9xf32>
247+
%im2col = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
248+
m_offset = [0, 0] * [2, 1] k_offset = [0] * [1]
249+
batch_pos = [0] m_pos = [2, 3] k_pos = [1]
250+
ins(%padded : tensor<1x8x9x9xf32>)
251+
outs(%empty : tensor<1x2x2x12xf32>) -> tensor<1x2x2x12xf32>
252+
return %im2col : tensor<1x2x2x12xf32>
253+
}
254+
}
255+
256+
// CHECK-LABEL: func.func @im2col_padding
257+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
258+
// CHECK: %[[T1:.+]] = tensor.extract_slice %[[ARG0]]
259+
// CHECK: %[[T2:.+]] = tensor.pad %[[T1]]

0 commit comments

Comments
 (0)