
Commit 25d8239

yzhang93 and Max191 authored
[Codegen][IGEMM] Fix and preserve padding dim order for convs (iree-org#21772)
Fix for the first case in iree-org#21660. The previous logic tried to determine the position of each convolution dimension and assign padding values accordingly, which is fragile because convolutions in generic form come in many variants. This PR instead adds a mapping from the dims of a convolution to the corresponding dims in the GEMM space, which is a more robust way to figure out which dims to pad on the convolution.

---------

Signed-off-by: yzhang93 <[email protected]>
Signed-off-by: Max Dawkins <[email protected]>
Co-authored-by: Max Dawkins <[email protected]>
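To make the new mapping concrete, here is a minimal standalone sketch (hypothetical code: plain std:: containers stand in for the DenseMap<int64_t, AffineExpr> the commit actually builds, and the data mirrors the conv added to the tests below, whose iteration space is (n, oh, ow, oc, fh, ic) and whose IGEMM form collapses fh and ic into a single K dim):

// Minimal sketch of consuming a conv->IGEMM dim map (assumption: plain
// std:: containers in place of MLIR's DenseMap<int64_t, AffineExpr>).
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  // Conv dims: d0=n, d1=oh, d2=ow, d3=oc (parallel); d4=fh, d5=ic (reduction).
  // IGEMM dims: d0..d3 map 1:1; d4 and d5 collapse into the K dim at pos 4.
  std::map<int64_t, int64_t> convToIgemmDimMap = {{0, 0}, {1, 1}, {2, 2},
                                                  {3, 3}, {4, 4}, {5, 4}};

  // IGEMM loop bounds (K = fh * ic = 5 * 64 = 320) and padding sizes.
  std::vector<int64_t> bounds = {16, 38, 19, 64, 320};
  std::vector<int64_t> paddingSizes = {2, 2, 32, 64, 32};
  bool isReduction[] = {false, false, false, false, true};

  // Parallel conv dims inherit the padding size of their IGEMM dim; a
  // reduction conv dim needs no conv-level padding when the collapsed K
  // bound is already divisible by the K padding size (320 % 32 == 0 here).
  std::vector<int64_t> paddingConvSizes(convToIgemmDimMap.size(), 0);
  for (auto [convDim, igemmPos] : convToIgemmDimMap) {
    if (isReduction[igemmPos] && bounds[igemmPos] % paddingSizes[igemmPos] == 0)
      continue; // already aligned; avoid over-padding
    paddingConvSizes[convDim] = paddingSizes[igemmPos];
  }

  for (int64_t s : paddingConvSizes)
    std::cout << s << ' '; // prints: 2 2 32 64 0 0
  std::cout << '\n';
}

The real helper in ConfigUtils.cpp additionally returns std::nullopt (so the caller falls back to padding the GEMM) when two padded conv dims would map to the same IGEMM dim, and on the reduction side it only ever pads input-channel dims, never filter dims.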
1 parent 8ba9f68 commit 25d8239

File tree

3 files changed: +112 -77 lines changed


compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

Lines changed: 60 additions & 70 deletions
@@ -206,75 +206,60 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
 
 /// Helper function to get convolution padding sizes if possible.
 static std::optional<ArrayAttr> getPaddingConvSizes(
-    Builder &b, int64_t kSize, int64_t kPaddingSize,
+    Builder &b, const SmallVector<int64_t> &bounds,
+    const SmallVector<int64_t> &paddingSizes,
     const SmallVector<int64_t> &workgroupTileSizes,
-    const SmallVector<int64_t> &mDims, const SmallVector<int64_t> &nDims,
-    const SmallVector<int64_t> &batchDims,
-    std::optional<mlir::linalg::ConvolutionDimensions> &padConvDims) {
-  if (!padConvDims.has_value())
+    const SmallVector<int64_t> &reductionTileSizes,
+    std::optional<DenseMap<int64_t, AffineExpr>> &convToIgemmDimMap,
+    std::optional<mlir::linalg::ConvolutionDimensions> &convDims) {
+  if (!convToIgemmDimMap.has_value() || !convDims.has_value())
     return std::nullopt;
 
-  SmallVector<unsigned> batchAndImageDims;
-  mlir::linalg::ConvolutionDimensions convDims = padConvDims.value();
-  bool isBatchLast = !convDims.batch.empty() &&
-                     convDims.outputImage.back() < convDims.batch.front();
-  if (isBatchLast) {
-    batchAndImageDims.append(convDims.outputImage.begin(),
-                             convDims.outputImage.end());
-    batchAndImageDims.append(convDims.batch.begin(), convDims.batch.end());
-  } else {
-    batchAndImageDims.append(convDims.batch.begin(), convDims.batch.end());
-    batchAndImageDims.append(convDims.outputImage.begin(),
-                             convDims.outputImage.end());
-  }
-
-  SmallVector<unsigned> concatMDims, concatNDims;
-  bool isOutputChannelFirst =
-      convDims.outputChannel.back() < convDims.outputImage.front();
-  if (isOutputChannelFirst) {
-    concatMDims.append(convDims.outputChannel.begin(),
-                       convDims.outputChannel.end());
-    concatNDims = batchAndImageDims;
-  } else {
-    concatMDims = batchAndImageDims;
-    concatNDims.append(convDims.outputChannel.begin(),
-                       convDims.outputChannel.end());
-  }
-
-  // Verify that the number of M, N dimensions from IGEMM match the
-  // corresponding number of convolution dimensions.
-  if (concatMDims.size() != mDims.size() ||
-      concatNDims.size() != nDims.size() ||
-      convDims.depth.size() != batchDims.size()) {
-    return std::nullopt;
-  }
-
+  DenseMap<int64_t, AffineExpr> convToIgemmMap = convToIgemmDimMap.value();
   // Padding sizes for parallel dimensions are the same as workgroup tile
   // sizes.
-  int64_t totalNumDims = convDims.batch.size() + convDims.outputImage.size() +
-                         convDims.outputChannel.size() +
-                         convDims.filterLoop.size() +
-                         convDims.inputChannel.size() + convDims.depth.size();
-  SmallVector<int64_t> paddingConvSizes(totalNumDims, 0);
-  if (batchDims.size() != 0) {
-    for (auto [dim, bDim] : llvm::zip(convDims.depth, batchDims)) {
-      paddingConvSizes[dim] = workgroupTileSizes[bDim];
+  DenseSet<int64_t> paddedIGEMMDims;
+  DenseMap<int64_t, SmallVector<int64_t>> paddedReductionConvDims;
+  SetVector<int64_t> inputChannelDims(convDims->inputChannel.begin(),
+                                      convDims->inputChannel.end());
+  SmallVector<int64_t> paddingConvSizes(convToIgemmMap.size(), 0);
+  for (auto [convDim, IGEMMExpr] : convToIgemmMap) {
+    auto IGEMMDimExpr = cast<AffineDimExpr>(IGEMMExpr);
+    unsigned IGEMMPos = IGEMMDimExpr.getPosition();
+    if (reductionTileSizes[IGEMMPos] != 0) {
+      // For reduction dimensions, avoid setting padding on the convolution
+      // if the product of the corresponding conv sizes are already divisible
+      // by the padding size.
+      if (paddingSizes[IGEMMPos] &&
+          bounds[IGEMMPos] % paddingSizes[IGEMMPos] == 0) {
+        paddedIGEMMDims.insert(IGEMMPos);
+        continue;
+      }
+      // Only pad input channel dims. If we need to pad filter dims, then we
+      // would rather just do padding on the GEMM instead.
+      if (inputChannelDims.contains(convDim)) {
+        // Multiple input channel dims for a single IGEMMPos is not supported.
+        if (paddedIGEMMDims.contains(IGEMMPos)) {
+          return std::nullopt;
+        }
+        paddingConvSizes[convDim] = paddingSizes[IGEMMPos];
+        paddedIGEMMDims.insert(IGEMMPos);
+      }
+      continue;
     }
+    // Multiple padded parallel dims mapping to the same IGEMM dim is not
+    // supported.
+    if (workgroupTileSizes[IGEMMPos] != 0 &&
+        paddedIGEMMDims.contains(IGEMMPos)) {
+      return std::nullopt;
+    }
+    paddingConvSizes[convDim] = paddingSizes[IGEMMPos];
+    paddedIGEMMDims.insert(IGEMMPos);
  }
-  for (auto [dim, mDim] : llvm::zip(concatMDims, mDims))
-    paddingConvSizes[dim] = workgroupTileSizes[mDim];
-  for (auto [dim, nDim] : llvm::zip(concatNDims, nDims))
-    paddingConvSizes[dim] = workgroupTileSizes[nDim];
-
-  // To avoid over-padding, no padding for channel dimensions is needed if
-  // the product of reduction sizes is already multiples of k padding
-  // size. Otherwise, pad the innermost channel dimension.
-  // TODO (vivian): Padding the innermost channel dimension to a multiple
-  // of vector size may still be needed even if the K-dim is aligned, and
-  // this should be validated based on performance.
-  if (kSize % kPaddingSize != 0) {
-    int64_t innerChannelDim = convDims.inputChannel.back();
-    paddingConvSizes[innerChannelDim] = kPaddingSize;
+
+  // Ensure that all dimensions have been padded.
+  if (paddedIGEMMDims.size() != paddingSizes.size()) {
+    return std::nullopt;
   }
   return b.getI64ArrayAttr(paddingConvSizes);
 }
@@ -291,7 +276,9 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
     SmallVector<int64_t> bounds, ArrayRef<AffineMap> maps,
     ArrayRef<Value> operands, IREE::GPU::TargetAttr target, bool useDirectLoad,
     bool scaled,
-    std::optional<mlir::linalg::ConvolutionDimensions> padConvDims = {}) {
+    std::optional<DenseMap<int64_t, AffineExpr>> convToIgemmDimMap =
+        std::nullopt,
+    std::optional<linalg::ConvolutionDimensions> convDims = std::nullopt) {
   if (target.getWgp().getMma().empty()) {
     return failure();
   }
@@ -537,9 +524,9 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
 
   // Create `padding_conv` attribute when padding convolutions before IGEMM is
   // possible, otherwise fallback to pad IGEMM.
-  if (auto attr = getPaddingConvSizes(
-          b, bounds[innerKDim], paddingTileSizes[innerKDim],
-          workgroupTileSizes, mDims, nDims, batchDims, padConvDims)) {
+  if (auto attr = getPaddingConvSizes(b, bounds, paddingTileSizes,
+                                      workgroupTileSizes, reductionTileSizes,
+                                      convToIgemmDimMap, convDims)) {
     attrs.emplace_back(StringAttr::get(context, "padding_conv"), *attr);
   } else {
     attrs.emplace_back(StringAttr::get(context, "padding"),
@@ -580,15 +567,18 @@ LogicalResult setIGEMMConvolutionLoweringConfig(
       igemmGenericConvDetails->igemmLoopBounds;
   SmallVector<Value> igemmOperands = igemmGenericConvDetails->igemmOperands;
 
-  std::optional<mlir::linalg::ConvolutionDimensions> padConvDims;
-  if (padConv)
-    padConvDims = igemmGenericConvDetails->convDims;
+  std::optional<DenseMap<int64_t, AffineExpr>> convToIgemmDimMap;
+  std::optional<linalg::ConvolutionDimensions> convDims;
+  if (padConv) {
+    convDims = igemmGenericConvDetails->convDims;
+    convToIgemmDimMap = igemmGenericConvDetails->convToIgemmDimMap;
+  }
 
   SmallVector<int64_t> bounds = igemmLoopBounds;
   FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
       getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
           bounds, igemmContractionMaps, igemmOperands, target, useDirectLoad,
-          /*scaled*/ false, padConvDims);
+          /*scaled*/ false, convToIgemmDimMap, convDims);
   if (failed(configAndWgSize)) {
     return failure();
   }
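To summarize the new control flow in getPaddingConvSizes: each (conv dim, IGEMM dim) pair is visited exactly once; parallel conv dims take the padding size of their IGEMM dim; reduction conv dims are padded only when they are input-channel dims whose collapsed K bound is not already divisible by the K padding size; and the helper returns std::nullopt, so the caller falls back to the plain `padding` attribute on the GEMM, whenever two padded conv dims collide on one IGEMM dim or some IGEMM dim ends up unpadded.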

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir

Lines changed: 46 additions & 6 deletions
@@ -5,7 +5,7 @@
 // RUN:   --iree-codegen-llvmgpu-use-igemm=true --iree-codegen-llvmgpu-igemm-pad-convolution=false --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s --check-prefixes=CHECK,MI300X
 
 // RUN: iree-opt --mlir-print-local-scope --split-input-file --iree-gpu-test-target=gfx942 \
-// RUN:   --iree-codegen-llvmgpu-use-igemm=true --iree-codegen-llvmgpu-igemm-pad-convolution=true --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s --check-prefix=PAD-CONV
+// RUN:   --iree-codegen-llvmgpu-use-igemm=true --iree-codegen-llvmgpu-igemm-pad-convolution=true --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s --check-prefix=PAD-CONV-GFX942
 
 func.func @nhwc_conv_mfma() {
   %cst = arith.constant 0.000000e+00 : f32
@@ -110,7 +110,7 @@ func.func @nhwc_conv_unaligned_mfma() {
 // MI300X-SAME: subgroup = [1, 1, 1, 1, 0]
 // MI300X-SAME: workgroup = [1, 1, 16, 64, 0]
 
-// PAD-CONV: padding_conv = [2, 1, 32, 64, 0, 0, 0]
+// PAD-CONV-GFX942: padding_conv = [2, 1, 32, 64, 0, 0, 0]
 
 // -----
 
@@ -149,7 +149,7 @@ func.func @nchw_conv_unaligned_mfma() {
 // MI300X-SAME: subgroup = [1, 1, 1, 1, 0]
 // MI300X-SAME: workgroup = [1, 32, 1, 32, 0]
 
-// PAD-CONV: padding_conv = [1, 64, 2, 32, 0, 0, 0]
+// PAD-CONV-GFX942: padding_conv = [1, 64, 2, 32, 0, 0, 0]
 
 // -----
 
@@ -188,7 +188,7 @@ func.func @conv_nhwc_fhwc_unaligned_channel(%arg0: tensor<16x26x19x287xf16>, %ar
 // MI300X-SAME: subgroup = [1, 4, 1, 1, 0]
 // MI300X-SAME: workgroup = [1, 4, 32, 32, 0]
 
-// PAD-CONV: padding_conv = [1, 8, 32, 32, 0, 0, 32]
+// PAD-CONV-GFX942: padding_conv = [1, 8, 32, 32, 0, 0, 32]
 
 // -----
 
@@ -220,7 +220,7 @@ func.func @conv_chwn_chwf_unaligned(%arg0: tensor<16x193x129x40xbf16>, %arg1: te
 // CHECK-SAME: subgroup = [1, 1, 1, 1, 0]
 // CHECK-SAME: workgroup = [16, 1, 1, 16, 0]
 
-// PAD-CONV: padding_conv = [16, 1, 1, 16, 0, 0, 0]
+// PAD-CONV-GFX942: padding_conv = [16, 1, 1, 16, 0, 0, 0]
 
 // -----
 
@@ -258,4 +258,44 @@ func.func @group_conv_unaligned(%arg0: tensor<61x93x16x56xbf16>, %arg1: tensor<1
 // MI300X-SAME: subgroup = [1, 1, 0, 1, 0]
 // MI300X-SAME: workgroup = [1, 32, 1, 32, 0]
 
-// PAD-CONV: padding_conv = [1, 32, 1, 64, 0, 0, 32]
+// PAD-CONV-GFX942: padding_conv = [1, 32, 1, 64, 0, 0, 32]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2, d5)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+module {
+  func.func @conv_nhwc_filter_5x1_unaligned(%arg0: tensor<16x42x19x64xbf16>, %arg1: tensor<64x5x64xbf16>, %arg2: tensor<16x38x19x64xf32>) -> tensor<16x38x19x64xf32> {
+    %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x42x19x64xbf16>, tensor<64x5x64xbf16>) outs(%arg2 : tensor<16x38x19x64xf32>) {
+    ^bb0(%in: bf16, %in_0: bf16, %out: f32):
+      %1 = arith.extf %in : bf16 to f32
+      %2 = arith.extf %in_0 : bf16 to f32
+      %3 = arith.mulf %1, %2 : f32
+      %4 = arith.addf %out, %3 : f32
+      linalg.yield %4 : f32
+    } -> tensor<16x38x19x64xf32>
+    return %0 : tensor<16x38x19x64xf32>
+  }
+}
+
+// CHECK-LABEL: func.func @conv_nhwc_filter_5x1_unaligned
+// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+// CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
+// CHECK-SAME: use_igemm_convolution = true
+
+// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
+// GFX942-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16>
+// GFX942-SAME: padding = [2, 2, 32, 64, 32]
+// GFX942-SAME: promote_operands = [0, 1, 2]
+// GFX942-SAME: reduction = [0, 0, 0, 0, 2]
+// GFX942-SAME: subgroup = [2, 2, 2, 1, 0]
+// GFX942-SAME: workgroup = [2, 2, 32, 64, 0]
+
+// MI300X-SAME: padding = [1, 1, 32, 64, 32]
+// MI300X-SAME: promote_operands = [0, 1, 2]
+// MI300X-SAME: reduction = [0, 0, 0, 0, 2]
+// MI300X-SAME: subgroup = [1, 1, 2, 1, 0]
+// MI300X-SAME: workgroup = [1, 1, 32, 64, 0]
+
+// PAD-CONV-GFX942: padding_conv = [2, 2, 32, 64, 0, 0]
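The expectations in the new test follow from the dim map: the IGEMM K bound is fh * ic = 5 * 64 = 320, which is divisible by the K padding size 32, so the fh and ic conv dims need no conv-level padding and only the four parallel dims carry their workgroup tile sizes, giving padding_conv = [2, 2, 32, 64, 0, 0]. Without conv padding, the config instead pads the IGEMM itself with padding = [2, 2, 32, 64, 32].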

compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp

Lines changed: 6 additions & 1 deletion
@@ -584,12 +584,17 @@ getIGEMMGenericConvDetails(linalg::LinalgOp linalgOp) {
   auto inputMapGEMM =
       AffineMap::get(numParallelDims + numKDims, 0, inputDims, ctx);
 
-  // Prepare filter map.
+  // Prepare filter map and add mapping for reduction dimensions.
   int64_t currKPos = numParallelDims;
   SmallVector<AffineExpr> filterDims;
   for (const auto &[iter, indices] :
        llvm::zip_equal(filterIterators, filterReassocIndices)) {
     if (iter == reduction) {
+      for (int64_t reInd : indices) {
+        int64_t convDimIdx =
+            cast<AffineDimExpr>(filterMap.getResult(reInd)).getPosition();
+        convToIgemmDimMap[convDimIdx] = dims[currKPos];
+      }
       filterDims.push_back(dims[currKPos++]);
     } else {
       assert(iter == parallel && "expected a parallel dim");