Skip to content

Commit 510b024

Browse files
authored
[Codegen] Add padding for convolutions before IGEMM (#21470)
Previously, for the IGEMM path, we first converted convolutions to IGEMM and then padded the GEMM operands using `GPUPadOperandsPass`. However, this doesn't guarantee that the im2col op can be vectorized after decomposition, because the reduction (K) dimension is padded only after all the reduction dimensions of the convolution have been collapsed. The im2col input tensor, on the other hand, still depends on the original layout of the convolution, so vector loads can span multiple channel dimensions, which are not contiguous. To solve this problem, this PR adds a pass that pads convolutions before converting them to IGEMM. We directly pad the input channel dimension to be a multiple of the vector size, so that all vector loads from the input are contiguous. In addition, this pass makes use of the upstream `PadTilingInterface` for padding and adapts the `padding_size` in the `lowering_config` as padding options. --------- Signed-off-by: yzhang93 <[email protected]>
1 parent 90729fb commit 510b024

18 files changed

+365
-38
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ iree_compiler_cc_library(
8080
"GPUMultiBuffering.cpp",
8181
"GPUNestedLayoutDistributionPatterns.cpp",
8282
"GPUPackToIntrinsics.cpp",
83+
"GPUPadConvs.cpp",
8384
"GPUPadOperands.cpp",
8485
"GPUPatterns.cpp",
8586
"GPUPipelining.cpp",

compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ iree_cc_library(
7373
"GPUMultiBuffering.cpp"
7474
"GPUNestedLayoutDistributionPatterns.cpp"
7575
"GPUPackToIntrinsics.cpp"
76+
"GPUPadConvs.cpp"
7677
"GPUPadOperands.cpp"
7778
"GPUPatterns.cpp"
7879
"GPUPipelining.cpp"
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
8+
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
9+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
10+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
11+
#include "mlir/Transforms/Passes.h"
12+
13+
namespace mlir::iree_compiler {
14+
15+
#define GEN_PASS_DEF_GPUPADCONVSPASS
16+
#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
17+
18+
namespace {
19+
20+
// Pads `tilingInterfaceOp` through `linalg::rewriteAsPaddedOp`, using
// `paddingSizes` as the padding sizes and a zero attribute as the padding
// value for every operand.
//
// Returns success() when the rewrite succeeds; otherwise emits the op error
// "failed to pad op" on `tilingInterfaceOp` and returns that failure.
static LogicalResult padToStaticSizes(RewriterBase &rewriter,
21+
TilingInterface tilingInterfaceOp,
22+
SmallVector<OpFoldResult> paddingSizes) {
23+
// One zero-valued padding attribute per operand, typed by the operand's
// element type.
SmallVector<Attribute> paddingValues;
24+
for (Value operand : tilingInterfaceOp.getOperation()->getOperands()) {
25+
paddingValues.push_back(
26+
rewriter.getZeroAttr(getElementTypeOrSelf(operand.getType())));
27+
}
28+
29+
// `setPadToMultipleOf(true)`: the sizes are used as multiples to pad up to
// rather than as absolute sizes.
auto options = linalg::PadTilingInterfaceOptions()
30+
.setPaddingSizes(paddingSizes)
31+
.setPaddingValues(paddingValues)
32+
.setPadToMultipleOf(true);
33+
34+
// `padOps` is only passed to the rewrite call; it is not read afterwards in
// this function.
SmallVector<tensor::PadOp> padOps;
35+
FailureOr<TilingInterface> maybePaddedOp =
36+
linalg::rewriteAsPaddedOp(rewriter, tilingInterfaceOp, options, padOps);
37+
if (failed(maybePaddedOp)) {
38+
return tilingInterfaceOp->emitOpError("failed to pad op");
39+
}
40+
41+
return success();
42+
}
43+
44+
// Function-level pass that pads convolution ops according to the padding
// list stored in their GPU lowering config. Ops that are not linalg
// convolutions, have no `IREE::GPU::LoweringConfigAttr`, or have no padding
// list are left untouched.
struct GPUPadConvsPass final : impl::GPUPadConvsPassBase<GPUPadConvsPass> {
45+
void runOnOperation() override {
46+
FunctionOpInterface funcOp = getOperation();
47+
48+
IRRewriter rewriter(funcOp);
49+
funcOp.walk([&](TilingInterface op) {
50+
// Only linalg ops that implement the convolution interface are padded.
auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
51+
if (!linalgOp || !linalg::isaConvolutionOpInterface(linalgOp)) {
52+
return;
53+
}
54+
55+
auto loweringConfig =
56+
getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
57+
if (!loweringConfig) {
58+
return;
59+
}
60+
61+
// Get padding sizes from lowering_config.
62+
// Passing `true` requests the convolution-specific padding list.
std::optional<SmallVector<int64_t>> paddingSizes =
63+
getPaddingList(loweringConfig, /*padConv*/ true);
64+
if (!paddingSizes) {
65+
return;
66+
}
67+
68+
// Convert the static sizes to index OpFoldResults and insert the padding
// right before the op being rewritten.
SmallVector<OpFoldResult> padSizes =
69+
getAsIndexOpFoldResult(rewriter.getContext(), paddingSizes.value());
70+
rewriter.setInsertionPoint(op);
71+
if (failed(padToStaticSizes(rewriter, op, padSizes))) {
72+
// The callback returns void, so this marks the pass as failed and
// leaves the callback for this op only.
return signalPassFailure();
73+
}
74+
});
75+
}
76+
};
77+
78+
} // namespace
79+
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,16 @@ def GPUPackToIntrinsicsPass :
195195
];
196196
}
197197

198+
// Declares the convolution padding pass. It is registered under the
// `iree-codegen-gpu-pad-convs` flag and runs on ops implementing
// `mlir::FunctionOpInterface`.
def GPUPadConvsPass :
199+
InterfacePass<"iree-codegen-gpu-pad-convs",
200+
"mlir::FunctionOpInterface"> {
201+
let summary = "Pass to pad operands of a convolution with padding configuration provided.";
202+
// Dialects declared as dependencies of this pass.
let dependentDialects = [
203+
"::mlir::linalg::LinalgDialect",
204+
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect"
205+
];
206+
}
207+
198208
def GPUPadOperandsPass :
199209
InterfacePass<"iree-codegen-gpu-pad-operands",
200210
"mlir::FunctionOpInterface"> {

compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ iree_lit_test_suite(
4545
"gpu_nested_layout_vector_distribution_mask.mlir",
4646
"gpu_nested_layout_vector_distribution_multi_reduce.mlir",
4747
"gpu_nested_layout_vector_distribution_step.mlir",
48+
"gpu_pad_convs.mlir",
4849
"gpu_pad_operands.mlir",
4950
"gpu_pipeline.mlir",
5051
"gpu_promote_matmul_operands.mlir",

compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ iree_lit_test_suite(
4141
"gpu_nested_layout_vector_distribution_multi_reduce.mlir"
4242
"gpu_nested_layout_vector_distribution_step.mlir"
4343
"gpu_pack_to_instrinsics.mlir"
44+
"gpu_pad_convs.mlir"
4445
"gpu_pad_operands.mlir"
4546
"gpu_pipeline.mlir"
4647
"gpu_promote_matmul_operands.mlir"
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-pad-convs))" | FileCheck %s
2+
3+
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
4+
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
5+
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
6+
#lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, padding_conv = [1, 8, 32, 32, 0, 0, 32]}>
7+
func.func @conv_2d_nhwc_fhwc(%arg0: tensor<16x26x19x287xf16>, %arg1: tensor<287x3x3x287xf16>, %arg2: tensor<16x24x17x287xf32>) -> tensor<16x24x17x287xf32> {
8+
%0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x26x19x287xf16>, tensor<287x3x3x287xf16>) outs(%arg2 : tensor<16x24x17x287xf32>) attrs = {lowering_config = #lowering_config} {
9+
^bb0(%in: f16, %in_0: f16, %out: f32):
10+
%1 = arith.extf %in : f16 to f32
11+
%2 = arith.extf %in_0 : f16 to f32
12+
%3 = arith.mulf %1, %2 : f32
13+
%4 = arith.addf %out, %3 : f32
14+
linalg.yield %4 : f32
15+
} -> tensor<16x24x17x287xf32>
16+
return %0 : tensor<16x24x17x287xf32>
17+
}
18+
19+
// CHECK-LABEL: func.func @conv_2d_nhwc_fhwc
20+
// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<16x26x19x287xf16>
21+
// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<287x3x3x287xf16>
22+
// CHECK-SAME: %[[C:[A-Za-z0-9]+]]: tensor<16x24x17x287xf32>
23+
// CHECK: %[[PADDED_LHS:.+]] = tensor.pad %[[A]] low[0, 0, 0, 0] high[0, 0, 15, 1]
24+
// CHECK: %[[PADDED_RHS:.+]] = tensor.pad %[[B]] low[0, 0, 0, 0] high[1, 0, 0, 1]
25+
// CHECK: %[[PADDED_INIT:.+]] = tensor.pad %[[C]] low[0, 0, 0, 0] high[0, 0, 15, 1]
26+
// CHECK: %[[PADDED_RESULT:.+]] = linalg.generic
27+
// CHECK-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]] : tensor<16x26x34x288xf16>, tensor<288x3x3x288xf16>)
28+
// CHECK-SAME: outs(%[[PADDED_INIT]] : tensor<16x24x32x288xf32>)
29+
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[PADDED_RESULT]][0, 0, 0, 0] [16, 24, 17, 287] [1, 1, 1, 1]
30+
// CHECK-SAME: : tensor<16x24x32x288xf32> to tensor<16x24x17x287xf32>
31+
// CHECK: return %[[EXTRACT]] : tensor<16x24x17x287xf32>
32+
33+
// -----
34+
35+
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1 + d5 * 2, d2 + d6 * 2, d3)>
36+
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d0)>
37+
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
38+
#lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16>, padding_conv = [16, 1, 1, 16, 0, 0, 0]}>
39+
func.func @conv_2d_chwn_chwf(%arg0: tensor<16x193x129x40xbf16>, %arg1: tensor<16x96x64x40xbf16>, %arg2: tensor<40x3x3x40xf32>) -> tensor<40x3x3x40xf32> {
40+
%0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x193x129x40xbf16>, tensor<16x96x64x40xbf16>) outs(%arg2 : tensor<40x3x3x40xf32>) attrs = {lowering_config = #lowering_config} {
41+
^bb0(%in: bf16, %in_0: bf16, %out: f32):
42+
%1 = arith.extf %in : bf16 to f32
43+
%2 = arith.extf %in_0 : bf16 to f32
44+
%3 = arith.mulf %1, %2 : f32
45+
%4 = arith.addf %out, %3 : f32
46+
linalg.yield %4 : f32
47+
} -> tensor<40x3x3x40xf32>
48+
return %0 : tensor<40x3x3x40xf32>
49+
}
50+
51+
// CHECK-LABEL: func.func @conv_2d_chwn_chwf
52+
// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<16x193x129x40xbf16>
53+
// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<16x96x64x40xbf16>
54+
// CHECK-SAME: %[[C:[A-Za-z0-9]+]]: tensor<40x3x3x40xf32>
55+
// CHECK: %[[PADDED_LHS:.+]] = tensor.pad %[[A]] low[0, 0, 0, 0] high[0, 0, 0, 8]
56+
// CHECK: %[[PADDED_RHS:.+]] = tensor.pad %[[B]] low[0, 0, 0, 0] high[0, 0, 0, 8]
57+
// CHECK: %[[PADDED_INIT:.+]] = tensor.pad %[[C]] low[0, 0, 0, 0] high[8, 0, 0, 8]
58+
// CHECK: %[[PADDED_RESULT:.+]] = linalg.generic
59+
// CHECK-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]] : tensor<16x193x129x48xbf16>, tensor<16x96x64x48xbf16>)
60+
// CHECK-SAME: outs(%[[PADDED_INIT]] : tensor<48x3x3x48xf32>)
61+
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[PADDED_RESULT]][0, 0, 0, 0] [40, 3, 3, 40] [1, 1, 1, 1]
62+
// CHECK-SAME: : tensor<48x3x3x48xf32> to tensor<40x3x3x40xf32>
63+
// CHECK: return %[[EXTRACT]] : tensor<40x3x3x40xf32>

compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
363363
// Cleanup patterns for tile and distribute
364364
{
365365
RewritePatternSet patterns(context);
366+
populateSwapExtractWithCollapsePattern(patterns);
366367
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
367368
tensor::populateFoldTensorEmptyPatterns(patterns);
368369
context->getOrLoadDialect<tensor::TensorDialect>()

compiler/src/iree/compiler/Codegen/Common/Transforms.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,9 +516,14 @@ swapCollapseShapeWithSlice(RewriterBase &rewriter,
516516
"collapsed size must be static");
517517
}
518518

519+
// Compose all nested affine.apply chains and check if the offset is
520+
// multiple of collapsed size.
521+
SmallVector<Value> operands(applyOp.getOperands());
522+
affine::fullyComposeAffineMapAndOperands(&map, &operands);
523+
map = simplifyAffineMap(map);
519524
if (!map.getResult(0).isMultipleOf(maybeStaticSize.value())) {
520525
return rewriter.notifyMatchFailure(
521-
sliceOp, "collapsed size is not divisible by offset multiplier");
526+
sliceOp, "offset multiplier must be multiple of collapsed size");
522527
}
523528

524529
unsigned lastReassocSize = srcShape[reassocIndices.back()];

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,12 @@ setPromotedOperandsList(MLIRContext *context,
178178
}
179179

180180
constexpr StringLiteral kPaddingName = "padding";
181+
constexpr StringLiteral kPaddingConvName = "padding_conv";
181182

182-
std::optional<SmallVector<int64_t>> getPaddingList(LoweringConfigAttr config) {
183-
auto array = config.getAttributes().getAs<ArrayAttr>(kPaddingName);
183+
std::optional<SmallVector<int64_t>> getPaddingList(LoweringConfigAttr config,
184+
bool paddingConv) {
185+
auto attrName = paddingConv ? kPaddingConvName : kPaddingName;
186+
auto array = config.getAttributes().getAs<ArrayAttr>(attrName);
184187
if (!array) {
185188
return std::nullopt;
186189
}

0 commit comments

Comments
 (0)