 #include "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h"
 #include "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"
 #include "flashinfer/trtllm/common.h"
+#include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/envUtils.h"
 
 namespace tensorrt_llm {
@@ -306,6 +307,8 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
   auto const bmm = BatchedGemmInterface();
   auto const configs = bmm.getBatchedGemmConfigs();
 
+  int32_t multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
+
   BatchedGemmData gemmData;
   // Dims
   gemmData.mProblemDimensions.mNumBatches = numBatches;
@@ -322,67 +325,57 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
   gemmData.mProblemDimensions.mWorldSize = 1;
   gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
 
-  // Tier 0: K < tileK, prefer higher efficiency.
-  auto cmpTier0 = [&configs, &gemmData](int64_t idx0, int64_t idx1) {
+  auto cmpFunc = [&configs, &gemmData, &bmm, &multiProcessorCount](int64_t idx0, int64_t idx1) {
     auto const& optionsA = configs[idx0].mOptions;
     auto const& optionsB = configs[idx1].mOptions;
     int32_t sizeK = gemmData.mProblemDimensions.mK;
-    // Both waste computation, prefer higher efficiency.
-    if (sizeK <= optionsA.mTileK && sizeK <= optionsB.mTileK) {
-      double eff_a = (double)sizeK / optionsA.mTileK;
-      double eff_b = (double)sizeK / optionsB.mTileK;
-      return eff_a > eff_b;
-    }
-    // If either can be utilized, sort by tileK.
-    else {
-      return optionsA.mTileK > optionsB.mTileK;
+
+    // Tier 0: K < tileK, prefer higher efficiency.
+    if (optionsA.mTileK != optionsB.mTileK) {
+      // Both waste computation, prefer higher efficiency.
+      if (sizeK <= optionsA.mTileK && sizeK <= optionsB.mTileK) {
+        double eff_a = (double)sizeK / optionsA.mTileK;
+        double eff_b = (double)sizeK / optionsB.mTileK;
+        return eff_a > eff_b;
+      }
+      // If either can be utilized, sort by tileK.
+      else {
+        return optionsA.mTileK > optionsB.mTileK;
+      }
     }
-  };
-  // Tier 1: When tileK is the same, prefer unroll loop 2x for mma.
-  auto cmpTier1 = [&configs](int64_t idx0, int64_t idx1) {
-    auto const& optionsA = configs[idx0].mOptions;
-    auto const& optionsB = configs[idx1].mOptions;
-    if (optionsA.mTileK == optionsB.mTileK) {
+
+    // Tier 1: When tileK is the same, prefer unroll loop 2x for mma.
+    if (optionsA.mUseUnrollLoop2xForMma != optionsB.mUseUnrollLoop2xForMma) {
       return optionsA.mUseUnrollLoop2xForMma;
     }
-    return false;
-  };
-  // Tier 2+: When previous comparators are the same, prefer higher tileM.
-  auto cmpTier2 = [&configs](int64_t idx0, int64_t idx1) {
-    auto const& optionsA = configs[idx0].mOptions;
-    auto const& optionsB = configs[idx1].mOptions;
-    if (optionsA.mTileK == optionsB.mTileK &&
-        optionsA.mUseUnrollLoop2xForMma == optionsB.mUseUnrollLoop2xForMma) {
+
+    // Tier 2+: When previous comparators are the same, prefer higher tileM.
+    if (optionsA.mTileM != optionsB.mTileM) {
       return optionsA.mTileM > optionsB.mTileM;
     }
-    return false;
-  };
-  // Tier 2+: When previous comparators are the same, and when number of estimated CTAs is on the
-  // larger side, prefer persistent tile scheduler. The threshold is hardcoded as >148 CTAs at the
-  // moment.
-  auto cmpTier3 = [&configs, &gemmData](int64_t idx0, int64_t idx1) {
-    int32_t sizeM = gemmData.mProblemDimensions.mM;
-    int32_t sizeN = gemmData.mProblemDimensions.mN;
-    auto const& optionsA = configs[idx0].mOptions;
-    auto const& optionsB = configs[idx1].mOptions;
-    if (optionsA.mTileK == optionsB.mTileK &&
-        optionsA.mUseUnrollLoop2xForMma == optionsB.mUseUnrollLoop2xForMma &&
-        optionsA.mTileM == optionsB.mTileM) {
-      int64_t numTilesM = batchedGemm::gemm::divUp(sizeM, optionsA.mTileM);
-      int64_t numTilesN = batchedGemm::gemm::divUp(sizeN, optionsA.mTileN);
-      if (numTilesM * numTilesN > 148) {
+
+    // Tier 2+: When previous comparators are the same, prefer higher tileN.
+    if (optionsA.mTileN != optionsB.mTileN) {
+      return optionsA.mTileN > optionsB.mTileN;
+    }
+
+    // Tier 2+: When previous comparators are the same, and when the number of estimated CTAs is on
+    // the larger side, prefer persistent tile scheduler.
+    if (optionsA.mTileScheduler != optionsB.mTileScheduler) {
+      auto options = bmm.getOptionsFromConfigAndData(configs[idx0], gemmData);
+      auto numCtas = bmm.getNumCtas(options, gemmData.mProblemDimensions.mMaxNumCtasInTokenDim);
+      if (numCtas > multiProcessorCount) {
         return optionsA.mTileScheduler == batchedGemm::gemm::TileScheduler::Persistent;
+      } else {
+        return optionsB.mTileScheduler == batchedGemm::gemm::TileScheduler::Persistent;
       }
     }
-    return false;
-  };
+
     return false;
   };
-
   // Sort configs by options.
   std::vector<int64_t> sortedIndices = mPassingConfigIndices;
-  std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier0);
-  std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier1);
-  std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier2);
-  std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier3);
+  std::sort(sortedIndices.begin(), sortedIndices.end(), cmpFunc);
 
   // Special rules for corner cases, if applicable.
   std::vector<int64_t> prioritizedIndices =
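
Collapsing the four tier comparators into a single cmpFunc is more than a cleanup. std::sort makes no stability guarantee, so running it repeatedly with independent comparators (cmpTier0 through cmpTier3, as the old code did) does not compose into a lexicographic ordering: each later pass may freely reorder elements that its own comparator treats as equivalent, discarding the work of earlier tiers. A single comparator that decides at the first tier where the options differ, and otherwise falls through to the next tie-breaker, gives std::sort the strict weak ordering it requires in one pass. A minimal standalone sketch of the pattern (the Config struct and its fields below are illustrative stand-ins, not the real BatchedGemmConfig):

#include <algorithm>
#include <vector>

// Illustrative stand-in for the kernel config options (hypothetical type).
struct Config {
  int tileK;
  bool unroll2x;
  int tileM;
};

// Tiered comparator: each tier only decides when its fields differ,
// otherwise it falls through to the next tie-breaker.
bool cmp(Config const& a, Config const& b) {
  if (a.tileK != b.tileK) return a.tileK > b.tileK;    // Tier 0
  if (a.unroll2x != b.unroll2x) return a.unroll2x;     // Tier 1
  if (a.tileM != b.tileM) return a.tileM > b.tileM;    // Tier 2+
  return false;  // equivalent under all tiers
}

int main() {
  std::vector<Config> v{{64, false, 128}, {128, true, 64}, {128, false, 128}};
  std::sort(v.begin(), v.end(), cmp);  // one pass yields the full tiered order
}

The alternative of keeping separate passes would require std::stable_sort applied in reverse tier order; a single comparator avoids both the stability requirement and the extra passes.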
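The scheduler tier likewise drops the hardcoded threshold of 148 CTAs (one particular GPU's SM count) in favor of comparing bmm.getNumCtas(...) against a multiProcessorCount queried at runtime, so the persistent-vs-static decision adapts to whatever device the runner executes on; the new else branch also keeps the comparator symmetric when the CTA count is small. Presumably getMultiProcessorCount() from the newly included cudaUtils.h wraps the standard CUDA attribute query; a self-contained sketch of that query (an assumption about the helper, not its actual implementation):

#include <cuda_runtime.h>

// Query the number of SMs on the current device; a persistent tile
// scheduler only pays off once the launched grid exceeds this count.
int querySmCount() {  // hypothetical name, not the TensorRT-LLM helper
  int device = 0;
  cudaGetDevice(&device);
  int smCount = 0;
  cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, device);
  return smCount;
}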