Skip to content

Commit f322e2b

Browse files
authored
[Codegen][GPU] Adding heuristic strategy to reduce tile size to fill workloads to all CUs (#21546)
This PR implements #21506, strategy 3: workgroup-cap. This strategy has been tested on the combination of: - MI300x bf16, 478 convolutions, improving geo mean perf by 16.4% and improving 33% of the configs. - MI300x int8, SDXL, improving heuristic perf by 1%. - MI308x int8, SDXL, marginal improvement with fewer CUs available. Note that the CU count is only available when the user supplies the chip. For MI300x with iree-opt, use `--iree-gpu-test-target=mi300x@hip`. For iree-compile, use the `--iree-hip-target=mi300x` flag instead of gfx942 to see the heuristic perf impact. --------- Signed-off-by: jerryyin <[email protected]>
1 parent 4492dda commit f322e2b

File tree

5 files changed

+192
-81
lines changed

5 files changed

+192
-81
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp

Lines changed: 86 additions & 34 deletions
Original file line number · Diff line number · Diff line change
@@ -11,7 +11,7 @@
1111

1212
#include "llvm/ADT/APInt.h"
1313
#include "llvm/ADT/Sequence.h"
14-
#include "llvm/Support/Debug.h"
14+
#include "llvm/Support/DebugLog.h"
1515
#include "llvm/Support/InterleavedRange.h"
1616
#include "llvm/Support/MathExtras.h"
1717
#include "llvm/Support/raw_ostream.h"
@@ -176,11 +176,8 @@ static FailureOr<GPUMMASchedule> fitScheduleInSharedMemory(
176176
llvm::function_ref<bool(const GPUMMASchedule &schedule)> isScheduleValid) {
177177

178178
while (!isScheduleValid(schedule)) {
179-
LLVM_DEBUG({
180-
llvm::dbgs() << "Chosen schedule is invalid:\n";
181-
llvm::dbgs() << schedule << "\n";
182-
llvm::dbgs() << "Shrinking schedule...\n";
183-
});
179+
LDBG() << "Chosen schedule is invalid:\n"
180+
<< schedule << "\nShrinking schedule...";
184181

185182
auto decrementIfPossible =
186183
[](SmallVector<int64_t> &sizes) -> LogicalResult {
@@ -218,10 +215,7 @@ static FailureOr<GPUMMASchedule> fitScheduleInSharedMemory(
218215
return failure();
219216
}
220217

221-
LLVM_DEBUG({
222-
llvm::dbgs() << "Chosen schedule is valid:\n";
223-
llvm::dbgs() << schedule << "\n";
224-
});
218+
LDBG() << "Chosen schedule is valid:\n" << schedule;
225219

226220
return schedule;
227221
}
@@ -351,7 +345,12 @@ static GPUMMASchedule getOptimalMMASchedule(const GPUMatmulShapeType &problem,
351345
nSubgroupCounts(problem.nSizes.size(), 0);
352346
// Start at the innermost nDim and mDim, and try to distribute evenly to M and
353347
// N for each pair of M and N dims. Otherwise, distribute to N and then M.
348+
LDBG() << "Starting MMA schedule distribution";
354349
while (mDim >= 0 || nDim >= 0) {
350+
LDBG() << "Current iteration: mDim: " << mDim << ", nDim: " << nDim
351+
<< ", remainingSubgroups: " << remainingSubgroups
352+
<< ", remainingTiles: " << remainingTiles
353+
<< ", mTileSizes: " << mTileSizes << ", nTileSizes: " << nTileSizes;
355354
int64_t subgroupSqrt =
356355
1ull << (llvm::divideCeil(llvm::Log2_64(remainingSubgroups), 2));
357356
int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2);
@@ -362,6 +361,7 @@ static GPUMMASchedule getOptimalMMASchedule(const GPUMatmulShapeType &problem,
362361
if (mDim >= 0 && nDim >= 0 &&
363362
mTotalTileCounts[mDim] > (subgroupSqrt * tileSqrt) &&
364363
mTotalTileCounts[mDim] % (subgroupSqrt * tileSqrt) == 0) {
364+
LDBG() << "Distributing evenly to M and N dimensions.";
365365
mSubgroupCounts[mDim] = subgroupSqrt;
366366
mTileSizes[mDim] = tileSqrt;
367367

@@ -380,6 +380,7 @@ static GPUMMASchedule getOptimalMMASchedule(const GPUMatmulShapeType &problem,
380380
remainingTiles /= nTileSizes[nDim];
381381
} else {
382382
if (nDim >= 0) {
383+
LDBG() << "Distributing to N dimension first.";
383384
APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
384385
APInt(64, remainingSubgroups));
385386
nSubgroupCounts[nDim] = nGCD.getSExtValue();
@@ -393,6 +394,7 @@ static GPUMMASchedule getOptimalMMASchedule(const GPUMatmulShapeType &problem,
393394
}
394395

395396
if (mDim >= 0) {
397+
LDBG() << "Distributing to M dimension next.";
396398
APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCounts[mDim]),
397399
APInt(64, remainingSubgroups));
398400
mSubgroupCounts[mDim] = mGCD.getSExtValue();
@@ -478,11 +480,66 @@ sortMMAIntrinsics(GPUMatmulShapeType problem,
478480
return sortedIntrinsics;
479481
}
480482

483+
static int64_t adjustSeedsForWgpCount(const GPUMatmulShapeType &problem,
484+
const GPUIntrinsicType &intrinsic,
485+
std::optional<int64_t> wgpCount,
486+
int64_t bestSubgroupCountPerWorkgroup,
487+
int64_t bestMNTileCountPerSubgroup) {
488+
if (!wgpCount.has_value()) {
489+
LDBG() << "WGP count is not available,"
490+
<< "Skipping adjustment of seeds for workgroup count.";
491+
return bestMNTileCountPerSubgroup;
492+
}
493+
494+
int64_t mSize = ShapedType::getNumElements(problem.mSizes);
495+
int64_t nSize = ShapedType::getNumElements(problem.nSizes);
496+
int64_t kSize = ShapedType::getNumElements(problem.kSizes);
497+
float arithmeticIntensity =
498+
(2.0f * mSize * nSize * kSize) /
499+
static_cast<float>(mSize * nSize + nSize * kSize + mSize * kSize);
500+
501+
// TODO(jerryyin): compute arithmetic intensity bound based on the information
502+
// from the target chip.
503+
if (arithmeticIntensity <= 10.0f) {
504+
LDBG() << "Arithmetic intensity is too low, " << arithmeticIntensity
505+
<< ", skipping adjustment of seeds for workgroup count.";
506+
return bestMNTileCountPerSubgroup;
507+
}
508+
auto computeWorkgroupCount = [&] {
509+
// Compute the number of workgroups needed to cover the problem size.
510+
// This number tends to be lower than actual workgroup count, since:
511+
// 1) It assumes tile and subgroup seeds are all allocated.
512+
// 2) It assumes shared memory usage does not exceed hardware limits.
513+
int64_t mnTileSizePerSubgroup =
514+
bestMNTileCountPerSubgroup * intrinsic.mSizes[0] * intrinsic.nSizes[0];
515+
int64_t workgroupSize =
516+
mnTileSizePerSubgroup * bestSubgroupCountPerWorkgroup;
517+
return mSize * nSize / workgroupSize;
518+
};
519+
int64_t numWorkgroups = computeWorkgroupCount();
520+
LDBG() << "Estimated number of workgroups: " << numWorkgroups
521+
<< ", WGP count: " << wgpCount;
522+
523+
while (numWorkgroups < wgpCount) {
524+
if (bestMNTileCountPerSubgroup <= 1) {
525+
LDBG() << "Cannot decrease tile size further, "
526+
"bestMNTileCountPerSubgroup is already 1.";
527+
break;
528+
}
529+
bestMNTileCountPerSubgroup /= 2;
530+
LDBG() << "Decreasing bestMNTileCountPerSubgroup to "
531+
<< bestMNTileCountPerSubgroup;
532+
numWorkgroups = computeWorkgroupCount();
533+
}
534+
return bestMNTileCountPerSubgroup;
535+
}
536+
481537
FailureOr<GPUMMASchedule> deduceMMASchedule(
482538
const GPUMatmulShapeType &problem, ArrayRef<GPUIntrinsicType> intrinsics,
483539
const GPUMMAHeuristicSeeds &seeds, int64_t sharedMemLimitInBytes,
484-
int64_t subgroupSize, bool transposedLhs, bool transposedRhs,
485-
bool canUpcastAcc, bool mustBeAligned, bool doCPromotion) {
540+
int64_t subgroupSize, std::optional<int64_t> wgpCount, bool transposedLhs,
541+
bool transposedRhs, bool canUpcastAcc, bool mustBeAligned,
542+
bool doCPromotion) {
486543

487544
sortMMAIntrinsics(problem, intrinsics);
488545

@@ -492,12 +549,17 @@ FailureOr<GPUMMASchedule> deduceMMASchedule(
492549
continue;
493550
}
494551

495-
GPUMMASchedule schedule = getOptimalMMASchedule(problem, intrinsic, seeds);
552+
// Note: don't amend the original seeds, as deduceMMASchedule can be called
553+
// more than once in a row, and we want to keep the original seeds intact
554+
// for the next call.
555+
GPUMMAHeuristicSeeds localSeeds = seeds;
556+
localSeeds.bestMNTileCountPerSubgroup = adjustSeedsForWgpCount(
557+
problem, intrinsic, wgpCount, seeds.bestSubgroupCountPerWorkgroup,
558+
seeds.bestMNTileCountPerSubgroup);
559+
GPUMMASchedule schedule =
560+
getOptimalMMASchedule(problem, intrinsic, localSeeds);
496561

497-
LLVM_DEBUG({
498-
llvm::dbgs() << "chosen MMA schedule:\n";
499-
llvm::dbgs() << " " << schedule << "\n";
500-
});
562+
LDBG() << "Chosen MMA schedule:\n" << schedule;
501563

502564
auto isValidSchedule = [&](const GPUMMASchedule &schedule) -> bool {
503565
int64_t lhsBitwidth = intrinsic.aType.getIntOrFloatBitWidth();
@@ -513,13 +575,9 @@ FailureOr<GPUMMASchedule> deduceMMASchedule(
513575
calculateResultSharedMemoryUsedInBytes(schedule, resultBitwidth);
514576
}
515577

516-
LLVM_DEBUG({
517-
llvm::dbgs() << "Available Shared Memory: ";
518-
llvm::dbgs() << sharedMemLimitInBytes << " bytes\n";
519-
llvm::dbgs() << "Predicted Shared Memory Used by Schedule: ";
520-
llvm::dbgs() << sharedMemoryUsed << " bytes\n";
521-
});
522-
578+
LDBG() << "Available Shared Memory: " << sharedMemLimitInBytes << " bytes"
579+
<< "Predicted Shared Memory Used by Schedule: " << sharedMemoryUsed
580+
<< " bytes";
523581
return isAligned && sharedMemoryUsed <= sharedMemLimitInBytes;
524582
};
525583
return fitScheduleInSharedMemory(schedule, isValidSchedule);
@@ -680,10 +738,7 @@ FailureOr<std::pair<GPUMMASchedule, GPUMMASchedule>> deduceAttentionSchedule(
680738
GPUMMASchedule schedule =
681739
getOptimalAttentionPVSchedule(pvMatmul, intrinsicB, pvMatmulSeeds);
682740

683-
LLVM_DEBUG({
684-
llvm::dbgs() << "chosen MMA schedule:\n";
685-
llvm::dbgs() << " " << schedule << "\n";
686-
});
741+
LDBG() << "Chosen MMA schedule:\n" << schedule;
687742
int64_t intrinsicAK = intrinsicA.kSizes[0];
688743
auto isValidSchedule = [&](const GPUMMASchedule &schedule) -> bool {
689744
// Create a mma schedule for qkMatmul in attention.
@@ -722,12 +777,9 @@ FailureOr<std::pair<GPUMMASchedule, GPUMMASchedule>> deduceAttentionSchedule(
722777
calculateOperandsSharedMemoryUsedInBytes(
723778
schedule, lhsBBitwidth, rhsBBitwidth);
724779

725-
LLVM_DEBUG({
726-
llvm::dbgs() << "Available Shared Memory: ";
727-
llvm::dbgs() << sharedMemLimitInBytes << " bytes\n";
728-
llvm::dbgs() << "Predicted Shared Memory Used by Schedule: ";
729-
llvm::dbgs() << sharedMemoryUsed << " bytes\n";
730-
});
780+
LDBG() << "Available Shared Memory: " << sharedMemLimitInBytes << " bytes"
781+
<< "Predicted Shared Memory Used by Schedule: " << sharedMemoryUsed
782+
<< " bytes";
731783

732784
return isQKAligned && isPVAligned &&
733785
sharedMemoryUsed <= sharedMemLimitInBytes;

compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h

Lines changed: 8 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -100,12 +100,14 @@ struct GPUMMASchedule {
100100

101101
/// Returns a schedule for using one of the given MMA |intrinsics| to target the
102102
/// input |problem|. Returns std::nullopt if we cannot find such a schedule.
103-
FailureOr<GPUMMASchedule> deduceMMASchedule(
104-
const GPUMatmulShapeType &problem, ArrayRef<GPUIntrinsicType> intrinsics,
105-
const GPUMMAHeuristicSeeds &seeds, int64_t sharedMemLimitInBytes,
106-
int64_t subgroupSize, bool transposedLhs = false,
107-
bool transposedRhs = false, bool canUpcastAcc = false,
108-
bool mustBeAligned = true, bool doCPromotion = false);
103+
FailureOr<GPUMMASchedule>
104+
deduceMMASchedule(const GPUMatmulShapeType &problem,
105+
ArrayRef<GPUIntrinsicType> intrinsics,
106+
const GPUMMAHeuristicSeeds &seeds,
107+
int64_t sharedMemLimitInBytes, int64_t subgroupSize,
108+
std::optional<int64_t> cuCount, bool transposedLhs = false,
109+
bool transposedRhs = false, bool canUpcastAcc = false,
110+
bool mustBeAligned = true, bool doCPromotion = false);
109111

110112
/// Returns a schedule for the pvMatmul in attention using one of the given MMA
111113
/// |intrinsics| to target the given attention matmul problems, |qkMatmul|

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

Lines changed: 6 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -191,10 +191,15 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
191191
}
192192
int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes();
193193

194+
std::optional<int64_t> wgpCount = std::nullopt;
195+
if (TargetChipAttr chip = target.getChip()) {
196+
wgpCount = chip.getWgpCount();
197+
}
198+
194199
// First try to find a schedule with an exactly matching intrinsic.
195200
std::optional<GPUMMASchedule> schedule = deduceMMASchedule(
196201
problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
197-
transposedLhs, transposedRhs, /*canUpcastAcc=*/false,
202+
wgpCount, transposedLhs, transposedRhs, /*canUpcastAcc=*/false,
198203
/*mustBeAligned*/ mustBeAligned, doCPromotion);
199204
return schedule;
200205
}

compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp

Lines changed: 24 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -995,15 +995,21 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target,
995995

996996
int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes();
997997

998+
std::optional<int64_t> wgpCount = std::nullopt;
999+
if (IREE::GPU::TargetChipAttr chip = target.getChip()) {
1000+
wgpCount = chip.getWgpCount();
1001+
}
9981002
// First try to find a schedule with an exactly matching intrinsic.
999-
FailureOr<GPUMMASchedule> schedule = deduceMMASchedule(
1000-
problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize);
1003+
FailureOr<GPUMMASchedule> schedule =
1004+
deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes,
1005+
targetSubgroupSize, wgpCount);
10011006
if (failed(schedule)) {
10021007
// Then try again by allowing upcasting accumulator.
1003-
schedule = deduceMMASchedule(
1004-
problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
1005-
/*transposedLhs*/ false, /*transposedRhs*/ false,
1006-
/*canUpcastAcc=*/true);
1008+
schedule =
1009+
deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes,
1010+
targetSubgroupSize, wgpCount,
1011+
/*transposedLhs*/ false, /*transposedRhs*/ false,
1012+
/*canUpcastAcc=*/true);
10071013
}
10081014
if (failed(schedule)) {
10091015
return failure();
@@ -1240,15 +1246,21 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
12401246
nDim !=
12411247
llvm::cast<AffineDimExpr>(maps[1].getResults().back()).getPosition();
12421248

1249+
std::optional<int64_t> wgpCount = std::nullopt;
1250+
if (IREE::GPU::TargetChipAttr chip = target.getChip()) {
1251+
wgpCount = chip.getWgpCount();
1252+
}
1253+
12431254
// First try to find a schedule with an exactly matching intrinsic.
1244-
std::optional<GPUMMASchedule> schedule = deduceMMASchedule(
1245-
problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize);
1255+
std::optional<GPUMMASchedule> schedule =
1256+
deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes,
1257+
targetSubgroupSize, wgpCount);
12461258
if (!schedule) {
12471259
// Then try again by allowing upcasting accumulator.
1248-
schedule =
1249-
deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes,
1250-
targetSubgroupSize, transposedLhs, transposedRhs,
1251-
/*canUpcastAcc=*/true);
1260+
schedule = deduceMMASchedule(problem, intrinsics, seeds,
1261+
maxSharedMemoryBytes, targetSubgroupSize,
1262+
wgpCount, transposedLhs, transposedRhs,
1263+
/*canUpcastAcc=*/true);
12521264
}
12531265

12541266
if (!schedule) {

0 commit comments

Comments (0)