 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/Sequence.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
 #include "llvm/Support/InterleavedRange.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
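
Reviewer note: the include swap above pulls in the streaming LDBG() macro from llvm/Support/DebugLog.h, which the rest of this diff uses to replace LLVM_DEBUG blocks. A minimal before/after sketch of the pattern, with identifiers taken from this diff (the exact prefix/newline handling is owned by DebugLog.h):

// Before: an explicit block, each statement writing to llvm::dbgs().
LLVM_DEBUG({
  llvm::dbgs() << "Chosen schedule is invalid:\n";
  llvm::dbgs() << schedule << "\n";
});

// After: a single stream expression, gated on the same DEBUG_TYPE.
LDBG() << "Chosen schedule is invalid:\n" << schedule;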
@@ -176,11 +176,8 @@ static FailureOr<GPUMMASchedule> fitScheduleInSharedMemory(
     llvm::function_ref<bool(const GPUMMASchedule &schedule)> isScheduleValid) {
 
   while (!isScheduleValid(schedule)) {
-    LLVM_DEBUG({
-      llvm::dbgs() << "Chosen schedule is invalid:\n";
-      llvm::dbgs() << schedule << "\n";
-      llvm::dbgs() << "Shrinking schedule...\n";
-    });
+    LDBG() << "Chosen schedule is invalid:\n"
+           << schedule << "\nShrinking schedule...";
 
     auto decrementIfPossible =
         [](SmallVector<int64_t> &sizes) -> LogicalResult {
@@ -218,10 +215,7 @@ static FailureOr<GPUMMASchedule> fitScheduleInSharedMemory(
     return failure();
   }
 
-  LLVM_DEBUG({
-    llvm::dbgs() << "Chosen schedule is valid:\n";
-    llvm::dbgs() << schedule << "\n";
-  });
+  LDBG() << "Chosen schedule is valid:\n" << schedule;
 
   return schedule;
 }
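
Reviewer note: fitScheduleInSharedMemory is a shrink-until-valid loop: keep shrinking the schedule until the predicate accepts it, and fail once nothing can shrink. A standalone sketch of the same pattern under a hypothetical Schedule type with a tileSizes vector (not the exact IREE implementation, which shrinks via decrementIfPossible):

#include <cstdint>
#include <vector>

struct Schedule {
  std::vector<int64_t> tileSizes;  // Hypothetical stand-in for GPUMMASchedule.
};

template <typename Pred>
bool shrinkUntilValid(Schedule &schedule, Pred isValid) {
  while (!isValid(schedule)) {
    bool shrunk = false;
    for (int64_t &size : schedule.tileSizes) {
      if (size > 1) {
        size /= 2;  // Halve one dimension per iteration.
        shrunk = true;
        break;
      }
    }
    if (!shrunk)
      return false;  // Nothing left to shrink: no valid schedule exists.
  }
  return true;
}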
@@ -351,7 +345,12 @@ static GPUMMASchedule getOptimalMMASchedule(const GPUMatmulShapeType &problem,
       nSubgroupCounts(problem.nSizes.size(), 0);
   // Start at the innermost nDim and mDim, and try to distribute evenly to M and
   // N for each pair of M and N dims. Otherwise, distribute to N and then M.
+  LDBG() << "Starting MMA schedule distribution";
   while (mDim >= 0 || nDim >= 0) {
+    LDBG() << "Current iteration: mDim: " << mDim << ", nDim: " << nDim
+           << ", remainingSubgroups: " << remainingSubgroups
+           << ", remainingTiles: " << remainingTiles
+           << ", mTileSizes: " << mTileSizes << ", nTileSizes: " << nTileSizes;
     int64_t subgroupSqrt =
         1ull << (llvm::divideCeil(llvm::Log2_64(remainingSubgroups), 2));
     int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2);
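
Reviewer note: a worked example of the near-square split above; subgroupSqrt rounds the square root of the budget up while tileSqrt rounds it down (plain arithmetic, no new API):

// With remainingSubgroups = 8 and remainingTiles = 8:
//   llvm::Log2_64(8) = 3
//   subgroupSqrt = 1 << divideCeil(3, 2) = 1 << 2 = 4   (rounds up)
//   tileSqrt     = 1 << (3 / 2)          = 1 << 1 = 2   (rounds down)
// An even M/N split then seeds mSubgroupCounts[mDim] = 4 and
// mTileSizes[mDim] = 2 before distributing the remaining budget to N.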
@@ -362,6 +361,7 @@ static GPUMMASchedule getOptimalMMASchedule(const GPUMatmulShapeType &problem,
     if (mDim >= 0 && nDim >= 0 &&
         mTotalTileCounts[mDim] > (subgroupSqrt * tileSqrt) &&
         mTotalTileCounts[mDim] % (subgroupSqrt * tileSqrt) == 0) {
+      LDBG() << "Distributing evenly to M and N dimensions.";
       mSubgroupCounts[mDim] = subgroupSqrt;
       mTileSizes[mDim] = tileSqrt;
 
@@ -380,6 +380,7 @@ static GPUMMASchedule getOptimalMMASchedule(const GPUMatmulShapeType &problem,
       remainingTiles /= nTileSizes[nDim];
     } else {
       if (nDim >= 0) {
+        LDBG() << "Distributing to N dimension first.";
         APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]),
                                            APInt(64, remainingSubgroups));
         nSubgroupCounts[nDim] = nGCD.getSExtValue();
@@ -393,6 +394,7 @@ static GPUMMASchedule getOptimalMMASchedule(const GPUMatmulShapeType &problem,
     }
 
     if (mDim >= 0) {
+      LDBG() << "Distributing to M dimension next.";
       APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCounts[mDim]),
                                          APInt(64, remainingSubgroups));
       mSubgroupCounts[mDim] = mGCD.getSExtValue();
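
Reviewer note: the nGCD/mGCD steps allocate to a dimension exactly as many of the remaining subgroups as divide its tile count evenly. A self-contained sketch of that step using the same llvm::APIntOps::GreatestCommonDivisor call (helper name and values are illustrative only, not from this diff):

#include "llvm/ADT/APInt.h"
#include <cstdint>

// Give a dimension the largest subgroup count that divides its tile count
// evenly, then shrink the remaining budget accordingly.
static int64_t allocateSubgroups(int64_t totalTileCount,
                                 int64_t &remainingSubgroups) {
  llvm::APInt gcd = llvm::APIntOps::GreatestCommonDivisor(
      llvm::APInt(64, totalTileCount), llvm::APInt(64, remainingSubgroups));
  int64_t allocated = gcd.getSExtValue();
  remainingSubgroups /= allocated;  // e.g. 12 tiles, 8 subgroups -> gcd = 4.
  return allocated;
}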
@@ -478,11 +480,66 @@ sortMMAIntrinsics(GPUMatmulShapeType problem,
   return sortedIntrinsics;
 }
 
+static int64_t adjustSeedsForWgpCount(const GPUMatmulShapeType &problem,
+                                      const GPUIntrinsicType &intrinsic,
+                                      std::optional<int64_t> wgpCount,
+                                      int64_t bestSubgroupCountPerWorkgroup,
+                                      int64_t bestMNTileCountPerSubgroup) {
+  if (!wgpCount.has_value()) {
+    LDBG() << "WGP count is not available,"
+           << " skipping adjustment of seeds for workgroup count.";
+    return bestMNTileCountPerSubgroup;
+  }
+
+  int64_t mSize = ShapedType::getNumElements(problem.mSizes);
+  int64_t nSize = ShapedType::getNumElements(problem.nSizes);
+  int64_t kSize = ShapedType::getNumElements(problem.kSizes);
+  float arithmeticIntensity =
+      (2.0f * mSize * nSize * kSize) /
+      static_cast<float>(mSize * nSize + nSize * kSize + mSize * kSize);
+
+  // TODO(jerryyin): compute the arithmetic intensity bound based on
+  // information from the target chip.
+  if (arithmeticIntensity <= 10.0f) {
+    LDBG() << "Arithmetic intensity is too low, " << arithmeticIntensity
+           << ", skipping adjustment of seeds for workgroup count.";
+    return bestMNTileCountPerSubgroup;
+  }
+  auto computeWorkgroupCount = [&] {
+    // Compute the number of workgroups needed to cover the problem size.
+    // This number tends to be lower than the actual workgroup count, since:
+    // 1) It assumes tile and subgroup seeds are all allocated.
+    // 2) It assumes shared memory usage does not exceed hardware limits.
+    int64_t mnTileSizePerSubgroup =
+        bestMNTileCountPerSubgroup * intrinsic.mSizes[0] * intrinsic.nSizes[0];
+    int64_t workgroupSize =
+        mnTileSizePerSubgroup * bestSubgroupCountPerWorkgroup;
+    return mSize * nSize / workgroupSize;
+  };
+  int64_t numWorkgroups = computeWorkgroupCount();
+  LDBG() << "Estimated number of workgroups: " << numWorkgroups
+         << ", WGP count: " << *wgpCount;
+
+  while (numWorkgroups < *wgpCount) {
+    if (bestMNTileCountPerSubgroup <= 1) {
+      LDBG() << "Cannot decrease tile size further, "
+                "bestMNTileCountPerSubgroup is already 1.";
+      break;
+    }
+    bestMNTileCountPerSubgroup /= 2;
+    LDBG() << "Decreasing bestMNTileCountPerSubgroup to "
+           << bestMNTileCountPerSubgroup;
+    numWorkgroups = computeWorkgroupCount();
+  }
+  return bestMNTileCountPerSubgroup;
+}
+
 FailureOr<GPUMMASchedule> deduceMMASchedule(
     const GPUMatmulShapeType &problem, ArrayRef<GPUIntrinsicType> intrinsics,
     const GPUMMAHeuristicSeeds &seeds, int64_t sharedMemLimitInBytes,
-    int64_t subgroupSize, bool transposedLhs, bool transposedRhs,
-    bool canUpcastAcc, bool mustBeAligned, bool doCPromotion) {
+    int64_t subgroupSize, std::optional<int64_t> wgpCount, bool transposedLhs,
+    bool transposedRhs, bool canUpcastAcc, bool mustBeAligned,
+    bool doCPromotion) {
 
   sortMMAIntrinsics(problem, intrinsics);
 
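
Reviewer note: to make the arithmetic-intensity gate in adjustSeedsForWgpCount concrete, two illustrative shapes (not taken from this diff):

// Square GEMM, M = N = K = 1024:
//   AI = 2*M*N*K / (M*N + N*K + M*K) = 2 * 1024^3 / (3 * 1024^2) ≈ 682.7
// well above the 10.0 threshold, so the workgroup-count adjustment runs.
//
// Skinny, matvec-like case, M = 1, N = K = 1024:
//   AI = 2 * 1024^2 / (1024 + 1024^2 + 1024) ≈ 2.0
// below the threshold, so the seeds are returned unchanged.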
@@ -492,12 +549,17 @@ FailureOr<GPUMMASchedule> deduceMMASchedule(
       continue;
     }
 
-    GPUMMASchedule schedule = getOptimalMMASchedule(problem, intrinsic, seeds);
+    // Note: don't amend the original seeds, as deduceMMASchedule can be called
+    // more than once in a row, and we want to keep the original seeds intact
+    // for the next call.
+    GPUMMAHeuristicSeeds localSeeds = seeds;
+    localSeeds.bestMNTileCountPerSubgroup = adjustSeedsForWgpCount(
+        problem, intrinsic, wgpCount, seeds.bestSubgroupCountPerWorkgroup,
+        seeds.bestMNTileCountPerSubgroup);
+    GPUMMASchedule schedule =
+        getOptimalMMASchedule(problem, intrinsic, localSeeds);
 
-    LLVM_DEBUG({
-      llvm::dbgs() << "chosen MMA schedule:\n";
-      llvm::dbgs() << "  " << schedule << "\n";
-    });
+    LDBG() << "Chosen MMA schedule:\n" << schedule;
 
     auto isValidSchedule = [&](const GPUMMASchedule &schedule) -> bool {
       int64_t lhsBitwidth = intrinsic.aType.getIntOrFloatBitWidth();
@@ -513,13 +575,9 @@ FailureOr<GPUMMASchedule> deduceMMASchedule(
             calculateResultSharedMemoryUsedInBytes(schedule, resultBitwidth);
       }
 
-      LLVM_DEBUG({
-        llvm::dbgs() << "Available Shared Memory: ";
-        llvm::dbgs() << sharedMemLimitInBytes << " bytes\n";
-        llvm::dbgs() << "Predicted Shared Memory Used by Schedule: ";
-        llvm::dbgs() << sharedMemoryUsed << " bytes\n";
-      });
-
+      LDBG() << "Available Shared Memory: " << sharedMemLimitInBytes
+             << " bytes, Predicted Shared Memory Used by Schedule: "
+             << sharedMemoryUsed << " bytes";
       return isAligned && sharedMemoryUsed <= sharedMemLimitInBytes;
     };
     return fitScheduleInSharedMemory(schedule, isValidSchedule);
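
Reviewer note: with the extra wgpCount parameter, a caller of the updated signature might look like the following (all values hypothetical; the real call sites live in the tuning/configuration code, not in this diff):

// Hypothetical caller: 64 KiB of shared memory, subgroup size 64, and a
// target reporting 48 workgroup processors.
FailureOr<GPUMMASchedule> schedule = deduceMMASchedule(
    problem, intrinsics, seeds, /*sharedMemLimitInBytes=*/65536,
    /*subgroupSize=*/64, /*wgpCount=*/48, /*transposedLhs=*/false,
    /*transposedRhs=*/false, /*canUpcastAcc=*/false, /*mustBeAligned=*/true,
    /*doCPromotion=*/false);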
@@ -680,10 +738,7 @@ FailureOr<std::pair<GPUMMASchedule, GPUMMASchedule>> deduceAttentionSchedule(
     GPUMMASchedule schedule =
         getOptimalAttentionPVSchedule(pvMatmul, intrinsicB, pvMatmulSeeds);
 
-    LLVM_DEBUG({
-      llvm::dbgs() << "chosen MMA schedule:\n";
-      llvm::dbgs() << "  " << schedule << "\n";
-    });
+    LDBG() << "Chosen MMA schedule:\n" << schedule;
     int64_t intrinsicAK = intrinsicA.kSizes[0];
     auto isValidSchedule = [&](const GPUMMASchedule &schedule) -> bool {
       // Create a mma schedule for qkMatmul in attention.
@@ -722,12 +777,9 @@ FailureOr<std::pair<GPUMMASchedule, GPUMMASchedule>> deduceAttentionSchedule(
           calculateOperandsSharedMemoryUsedInBytes(
               schedule, lhsBBitwidth, rhsBBitwidth);
 
-      LLVM_DEBUG({
-        llvm::dbgs() << "Available Shared Memory: ";
-        llvm::dbgs() << sharedMemLimitInBytes << " bytes\n";
-        llvm::dbgs() << "Predicted Shared Memory Used by Schedule: ";
-        llvm::dbgs() << sharedMemoryUsed << " bytes\n";
-      });
+      LDBG() << "Available Shared Memory: " << sharedMemLimitInBytes
+             << " bytes, Predicted Shared Memory Used by Schedule: "
+             << sharedMemoryUsed << " bytes";
 
       return isQKAligned && isPVAligned &&
              sharedMemoryUsed <= sharedMemLimitInBytes;