Commit 9303360

[Codegen][GPU] Use arithmetic intensity to guide gemm size categorization - step 3 (#21826)
This PR updates the heuristic for the tile-and-fuse pipeline so that GEMMs and convolutions use different seeds. The change is backed by experimental data on MI300X:
- 478 convolutions: performance stays flat, 103.7807 (before) -> 103.1491 (after).
- 259 GEMMs: performance improves by 6.3403%, measured as the geomean over individual perf uplifts (192.2177 (before) -> 186.197 (after)).

Llama3 perf regression results will follow in the comments while the code review is in progress.

Signed-off-by: jerryyin <[email protected]>
1 parent 6633605 commit 9303360
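
For readers unfamiliar with the aggregation quoted above, here is a minimal, hypothetical sketch of a "geomean over individual perf uplift" computation (this helper is illustrative only and is not part of the commit; the actual benchmark harness may aggregate differently):

#include <cmath>
#include <cstddef>
#include <vector>

// Percentage improvement implied by the geometric mean of per-kernel speedups,
// where each speedup is one kernel's before/after latency ratio.
double geomeanUpliftPercent(const std::vector<double> &beforeMs,
                            const std::vector<double> &afterMs) {
  double logSum = 0.0;
  for (size_t i = 0; i < beforeMs.size(); ++i)
    logSum += std::log(beforeMs[i] / afterMs[i]);
  double geomean = std::exp(logSum / static_cast<double>(beforeMs.size()));
  return (geomean - 1.0) * 100.0;
}

Under this reading, the 6.3403% figure describes the typical per-kernel speedup across the 259 GEMMs, which can legitimately differ from the ratio of the summed totals (192.2177 vs. 186.197).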

File tree: 3 files changed, +59 -35 lines changed


compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

Lines changed: 41 additions & 17 deletions
@@ -240,8 +240,8 @@ static GemmCutoff computeGemmCutoffsForAI(IREE::GPU::TargetAttr target,
 /// problem based on the available mma intrinsics.
 static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
     IREE::GPU::TargetAttr target, GPUMatmulShapeType problem,
-    bool transposedLhs, bool transposedRhs, bool mustBeAligned = true,
-    bool doCPromotion = false, bool scaled = false) {
+    bool transposedLhs, bool transposedRhs, bool isGemm,
+    bool mustBeAligned = true, bool doCPromotion = false, bool scaled = false) {
   const int64_t targetSubgroupSize = target.getPreferredSubgroupSize();
   SmallVector<GPUIntrinsicType> intrinsics;
   if (scaled) {
@@ -307,19 +307,40 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
     // bestMNTileCountPerSubgroup and small bestKTileCountPerSubgroup to
     // amortize launch/memory costs and maximize throughput.
     problem.gemmSize = GemmSize::LargeGemm;
-    seeds = {/*bestSubgroupCountPerWorkgroup=*/8,
-             /*bestMNTileCountPerSubgroup=*/8,
-             /*bestKTileCountPerSubgroup=*/2,
-             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / 2 /
-                 inBitWidth};
+    if (isGemm) {
+      seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
+               /*bestMNTileCountPerSubgroup=*/16,
+               /*bestKTileCountPerSubgroup=*/2,
+               /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / 2 /
+                   inBitWidth};
+    } else {
+      // Favor more subgroups for convolution to help latency hiding from global
+      // loads.
+      seeds = {/*bestSubgroupCountPerWorkgroup=*/8,
+               /*bestMNTileCountPerSubgroup=*/8,
+               /*bestKTileCountPerSubgroup=*/2,
+               /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / 2 /
+                   inBitWidth};
+    }
   } else {
     // Choose balanced tile shapes. Empirically, medium-AI workloads can favor
     // either small or large tiles depending on kernel details.
     problem.gemmSize = GemmSize::MediumGemm;
-    seeds = {/*bestSubgroupCountPerWorkgroup=*/8,
-             /*bestMNTileCountPerSubgroup=*/4,
-             /*bestKTileCountPerSubgroup=*/4,
-             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / inBitWidth};
+    if (isGemm) {
+      seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
+               /*bestMNTileCountPerSubgroup=*/8,
+               /*bestKTileCountPerSubgroup=*/4,
+               /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits /
+                   inBitWidth};
+    } else {
+      // Favor more subgroups for convolution to help latency hiding from global
+      // loads.
+      seeds = {/*bestSubgroupCountPerWorkgroup=*/8,
+               /*bestMNTileCountPerSubgroup=*/4,
+               /*bestKTileCountPerSubgroup=*/4,
+               /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits /
+                   inBitWidth};
+    }
   }
   int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes();

@@ -407,7 +428,7 @@ static FailureOr<std::pair<LoweringConfigAttr, int64_t>>
 getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
     SmallVector<int64_t> bounds, ArrayRef<AffineMap> maps,
     ArrayRef<Value> operands, IREE::GPU::TargetAttr target, bool useDirectLoad,
-    bool scaled,
+    bool isGemm, bool scaled,
     std::optional<DenseMap<int64_t, AffineExpr>> convToIgemmDimMap =
         std::nullopt,
     std::optional<linalg::ConvolutionDimensions> convDims = std::nullopt) {
@@ -537,7 +558,8 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
   bool mustBeAligned = true;
   bool doCPromotion = false;
   std::optional<GPUMMASchedule> schedule = getMmaScheduleFromProblemAndTarget(
-      target, problem, transposedLhs, transposedRhs, /*mustBeAligned*/ true,
+      target, problem, transposedLhs, transposedRhs, isGemm,
+      /*mustBeAligned*/ true,
       /*doCPromotion*/ false, scaled);

   // TODO (nirvedhmeshram, qedawkins): The performance with this will be bad if
@@ -549,7 +571,7 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
     mustBeAligned = false;
     doCPromotion = true;
     schedule = getMmaScheduleFromProblemAndTarget(
-        target, problem, transposedLhs, transposedRhs, mustBeAligned,
+        target, problem, transposedLhs, transposedRhs, isGemm, mustBeAligned,
         doCPromotion, scaled);
   }

@@ -710,7 +732,7 @@ LogicalResult setIGEMMConvolutionLoweringConfig(
   FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
       getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
          bounds, igemmContractionMaps, igemmOperands, target, useDirectLoad,
-          /*scaled*/ false, convToIgemmDimMap, convDims);
+          /*isGemm=*/false, /*scaled*/ false, convToIgemmDimMap, convDims);
   if (failed(configAndWgSize)) {
     return failure();
   }
@@ -756,7 +778,8 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,

   FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
       getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
-          bounds, maps, operands, target, useDirectLoad, /*scaled*/ false);
+          bounds, maps, operands, target, useDirectLoad, /*isGemm=*/true,
+          /*scaled*/ false);

   // TODO (muzasyed) : add generalization for scaled and nonscaled versions of
   // matmul lowering.
@@ -765,7 +788,8 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
     // conflicts when dealing with scaled matmuls. For now it is disabled.
     useDirectLoad = true;
     configAndWgSize = getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
-        bounds, maps, operands, target, useDirectLoad, /*scaled*/ true);
+        bounds, maps, operands, target, useDirectLoad, /*isGemm=*/true,
+        /*scaled*/ true);
   }

   if (failed(configAndWgSize)) {
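
How the seed change above surfaces in the test updates that follow: the GEMM branch now asks for 4 subgroups per workgroup instead of 8, and with the 64-lane subgroups used in the ROCDL tests that roughly halves the flat workgroup size. Below is a minimal sketch of that arithmetic, using a simplified stand-in struct rather than IREE's actual seeds type, and assuming the deduced schedule launches exactly the seeded subgroup count; the real deduction also rebalances against intrinsic shapes, alignment, and shared-memory limits, which is presumably why a few tests move from [256, 1, 1] to [128, 1, 1] rather than from [512, 1, 1].

#include <cstdint>

// Simplified stand-in for the heuristic seeds set above; the field name mirrors
// the /*bestSubgroupCountPerWorkgroup=*/ comment in the diff, but this is not
// the IREE definition.
struct SeedsSketch {
  int64_t bestSubgroupCountPerWorkgroup;
};

// Flat workgroup size if every seeded subgroup is launched.
int64_t flatWorkgroupSize(const SeedsSketch &seeds, int64_t subgroupSize) {
  return seeds.bestSubgroupCountPerWorkgroup * subgroupSize;
}

// With subgroup_size = 64, as in the tests below:
//   new GEMM seed:    4 subgroups -> 4 * 64 = 256 threads, i.e. workgroup_size = [256, 1, 1]
//   convolution seed: 8 subgroups -> 8 * 64 = 512 threads, unchanged from before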

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir

Lines changed: 12 additions & 12 deletions
@@ -37,7 +37,7 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor
 }

 // CHECK-LABEL: func.func @expanded_matmul_transpose_b
-// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
+// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
 // CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>

 // Verify that the fill does not have the lowering config propagated to it.
@@ -47,7 +47,7 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor
 // CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
 // CHECK-SAME: promote_operands = [0, 1]
 // CHECK-SAME: reduction = [0, 0, 0, 0, 4]
-// CHECK-SAME: subgroup = [1, 2, 2, 1, 0]
+// CHECK-SAME: subgroup = [1, 2, 2, 2, 0]
 // CHECK-SAME: workgroup = [1, 2, 64, 64, 0]

 // LATE: LLVMGPUVectorDistribute
@@ -77,14 +77,14 @@ func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4
 }

 // CHECK-LABEL: func.func @multi_dim_mma_schedule
-// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
+// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
 // CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>

 // CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
 // CHECK-SAME: promote_operands = [0, 1]
 // CHECK-SAME: reduction = [0, 0, 0, 0, 4, 1]
-// CHECK-SAME: subgroup = [2, 1, 2, 1, 0, 0]
+// CHECK-SAME: subgroup = [2, 4, 1, 1, 0, 0]
 // CHECK-SAME: workgroup = [2, 4, 32, 32, 0, 0]

 // LATE: LLVMGPUVectorDistribute
@@ -140,7 +140,7 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor<
 }

 // CHECK-LABEL: func.func @mfma_matmul_1024x1024x1024
-// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
+// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
 // CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>

 // Verify that the fill does not have the lowering config propagated to it.
@@ -150,8 +150,8 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor<
 // CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
 // CHECK-SAME: promote_operands = [0, 1]
 // CHECK-SAME: reduction = [0, 0, 4]
-// CHECK-SAME: subgroup = [2, 2, 0]
-// CHECK-SAME: workgroup = [128, 64, 0]
+// CHECK-SAME: subgroup = [2, 4, 0]
+// CHECK-SAME: workgroup = [64, 128, 0]

 // LATE: LLVMGPUVectorDistribute

@@ -380,12 +380,12 @@ func.func @aligned_dynamic_matmul_with_two_reduce_dim(%arg0: tensor<192x?x16xf32
 }

 // CHECK-LABEL: func.func @aligned_dynamic_matmul_with_two_reduce_dim
-// CHECK-SAME: {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+// CHECK-SAME: {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [128, 1, 1] subgroup_size = 64
 // CHECK: linalg.generic
 // CHECK-SAME: {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
 // CHECK-SAME: promote_operands = [0, 1]
 // CHECK-SAME: reduction = [0, 1, 0, 4],
-// CHECK-SAME: subgroup = [1, 0, 1, 0],
+// CHECK-SAME: subgroup = [2, 0, 1, 0],
 // CHECK-SAME: workgroup = [64, 0, 16, 0]}

 // -----
@@ -433,14 +433,14 @@ func.func @unaligned_to_intrinsic_batched_matmul_tiling_check(%lhs : tensor<12x5
 // schedule with nTileSize of 16 while in reality it should be 8.

 // LATE-LABEL: func.func @unaligned_to_intrinsic_batched_matmul_tiling_check
-// LATE-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+// LATE-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [128, 1, 1] subgroup_size = 64
 // LATE-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}
 // LATE: linalg.batch_matmul {{.*}}lowering_config = #iree_gpu.lowering_config
-// LATE-SAME: padding = [1, 16, 128, 4]
+// LATE-SAME: padding = [1, 16, 64, 4]
 // LATE-SAME: promote_operands = [0, 1, 2]
 // LATE-SAME: reduction = [0, 0, 0, 1]
 // LATE-SAME: subgroup = [0, 1, 2, 0]
-// LATE-SAME: workgroup = [1, 16, 128, 0]
+// LATE-SAME: workgroup = [1, 16, 64, 0]

 // -----

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir

Lines changed: 6 additions & 6 deletions
@@ -25,14 +25,14 @@ func.func @scaled_matmul(
 }

 // CHECK-LABEL: func.func @scaled_matmul
-// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
+// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
 // CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = true, use_igemm_convolution = false>
 // CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: mma_kind = #iree_gpu.scaled_mma_layout<intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f4E2M1FN, rhs_elem_type = f4E2M1FN, acc_elem_type = f32>
 // CHECK-SAME: promote_operands = [0, 1]
 // CHECK-SAME: reduction = [0, 0, 8, 1]
-// CHECK-SAME: subgroup = [2, 2, 0, 0]
-// CHECK-SAME: workgroup = [128, 64, 0, 0]
+// CHECK-SAME: subgroup = [2, 4, 0, 0]
+// CHECK-SAME: workgroup = [64, 128, 0, 0]

 // -----

@@ -58,14 +58,14 @@ func.func @scaled_matmul_with_batch(
 }

 // CHECK-LABEL: func.func @scaled_matmul_with_batch
-// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
+// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
 // CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = true, use_igemm_convolution = false>
 // CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: mma_kind = #iree_gpu.scaled_mma_layout<intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f4E2M1FN, rhs_elem_type = f4E2M1FN, acc_elem_type = f32>
 // CHECK-SAME: promote_operands = [0, 1]
 // CHECK-SAME: reduction = [0, 0, 0, 8, 1]
-// CHECK-SAME: subgroup = [0, 2, 2, 0, 0]
-// CHECK-SAME: workgroup = [1, 128, 64, 0, 0]
+// CHECK-SAME: subgroup = [0, 2, 4, 0, 0]
+// CHECK-SAME: workgroup = [1, 64, 128, 0, 0]

 // -----

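The "arithmetic intensity" in the PR title is the FLOPs-per-byte ratio that computeGemmCutoffsForAI (whose signature appears in the first hunk above) uses to place a problem into a GemmSize bucket. Below is a rough, hypothetical illustration of that quantity for a plain GEMM; the exact formula and the target-dependent cutoffs live in ConfigUtils.cpp and are not reproduced here.

#include <cstdint>

// One common definition of arithmetic intensity for an MxNxK GEMM:
// 2*M*N*K FLOPs (a multiply and an add per MAC) divided by the bytes of
// operand and result traffic, assuming each matrix is read or written once.
double gemmArithmeticIntensity(int64_t m, int64_t n, int64_t k,
                               int64_t inBytes, int64_t outBytes) {
  double flops = 2.0 * static_cast<double>(m) * n * k;
  double bytes =
      static_cast<double>((m * k + k * n) * inBytes + m * n * outBytes);
  return flops / bytes;
}

// Example: gemmArithmeticIntensity(1024, 1024, 1024, /*inBytes=*/2, /*outBytes=*/4)
// gives 2 * 1024^3 FLOPs over 8 MiB of traffic, i.e. 256 FLOPs/byte; the cutoffs
// then decide whether such a problem is treated as a small, medium, or large GEMM.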