Skip to content

Commit da937d7

Browse files
perf: Fix the tactic sorting in TrtllmGenBatchedGemmRunner::getValidConfigIndices (#1615)
1 parent beebc60 commit da937d7

File tree

2 files changed

+43
-49
lines changed

2 files changed

+43
-49
lines changed

csrc/trtllm_batched_gemm_runner.cu

Lines changed: 39 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h"
2525
#include "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"
2626
#include "flashinfer/trtllm/common.h"
27+
#include "tensorrt_llm/common/cudaUtils.h"
2728
#include "tensorrt_llm/common/envUtils.h"
2829

2930
namespace tensorrt_llm {
@@ -306,6 +307,8 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
306307
auto const bmm = BatchedGemmInterface();
307308
auto const configs = bmm.getBatchedGemmConfigs();
308309

310+
int32_t multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
311+
309312
BatchedGemmData gemmData;
310313
// Dims
311314
gemmData.mProblemDimensions.mNumBatches = numBatches;
@@ -322,67 +325,57 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
322325
gemmData.mProblemDimensions.mWorldSize = 1;
323326
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
324327

325-
// Tier 0: K < tileK, prefer higher efficiency.
326-
auto cmpTier0 = [&configs, &gemmData](int64_t idx0, int64_t idx1) {
328+
auto cmpFunc = [&configs, &gemmData, &bmm, &multiProcessorCount](int64_t idx0, int64_t idx1) {
327329
auto const& optionsA = configs[idx0].mOptions;
328330
auto const& optionsB = configs[idx1].mOptions;
329331
int32_t sizeK = gemmData.mProblemDimensions.mK;
330-
// Both waste computation, prefer higher efficiency.
331-
if (sizeK <= optionsA.mTileK && sizeK <= optionsB.mTileK) {
332-
double eff_a = (double)sizeK / optionsA.mTileK;
333-
double eff_b = (double)sizeK / optionsB.mTileK;
334-
return eff_a > eff_b;
335-
}
336-
// If either can be utilized, sort by tileK.
337-
else {
338-
return optionsA.mTileK > optionsB.mTileK;
332+
333+
// Tier 0: K < tileK, prefer higher efficiency.
334+
if (optionsA.mTileK != optionsB.mTileK) {
335+
// Both waste computation, prefer higher efficiency.
336+
if (sizeK <= optionsA.mTileK && sizeK <= optionsB.mTileK) {
337+
double eff_a = (double)sizeK / optionsA.mTileK;
338+
double eff_b = (double)sizeK / optionsB.mTileK;
339+
return eff_a > eff_b;
340+
}
341+
// If either can be utilized, sort by tileK.
342+
else {
343+
return optionsA.mTileK > optionsB.mTileK;
344+
}
339345
}
340-
};
341-
// Tier 1: When tileK is the same, prefer unroll loop 2x for mma.
342-
auto cmpTier1 = [&configs](int64_t idx0, int64_t idx1) {
343-
auto const& optionsA = configs[idx0].mOptions;
344-
auto const& optionsB = configs[idx1].mOptions;
345-
if (optionsA.mTileK == optionsB.mTileK) {
346+
347+
// Tier 1: When tileK is the same, prefer unroll loop 2x for mma.
348+
if (optionsA.mUseUnrollLoop2xForMma != optionsB.mUseUnrollLoop2xForMma) {
346349
return optionsA.mUseUnrollLoop2xForMma;
347350
}
348-
return false;
349-
};
350-
// Tier 2+: When previous comparators are the same, prefer higher tileM.
351-
auto cmpTier2 = [&configs](int64_t idx0, int64_t idx1) {
352-
auto const& optionsA = configs[idx0].mOptions;
353-
auto const& optionsB = configs[idx1].mOptions;
354-
if (optionsA.mTileK == optionsB.mTileK &&
355-
optionsA.mUseUnrollLoop2xForMma == optionsB.mUseUnrollLoop2xForMma) {
351+
352+
// Tier 2+: When previous comparators are the same, prefer higher tileM.
353+
if (optionsA.mTileM != optionsB.mTileM) {
356354
return optionsA.mTileM > optionsB.mTileM;
357355
}
358-
return false;
359-
};
360-
// Tier 2+: When previous comparators are the same, and when number of estimated CTAs is on the
361-
// larger side, prefer persistent tile scheduler. The threshold is hardcoded as >148 CTAs at the
362-
// moment.
363-
auto cmpTier3 = [&configs, &gemmData](int64_t idx0, int64_t idx1) {
364-
int32_t sizeM = gemmData.mProblemDimensions.mM;
365-
int32_t sizeN = gemmData.mProblemDimensions.mN;
366-
auto const& optionsA = configs[idx0].mOptions;
367-
auto const& optionsB = configs[idx1].mOptions;
368-
if (optionsA.mTileK == optionsB.mTileK &&
369-
optionsA.mUseUnrollLoop2xForMma == optionsB.mUseUnrollLoop2xForMma &&
370-
optionsA.mTileM == optionsB.mTileM) {
371-
int64_t numTilesM = batchedGemm::gemm::divUp(sizeM, optionsA.mTileM);
372-
int64_t numTilesN = batchedGemm::gemm::divUp(sizeN, optionsA.mTileN);
373-
if (numTilesM * numTilesN > 148) {
356+
357+
// Tier 2+: When previous comparators are the same, prefer higher tileN.
358+
if (optionsA.mTileN != optionsB.mTileN) {
359+
return optionsA.mTileN > optionsB.mTileN;
360+
}
361+
362+
// Tier 2+: When previous comparators are the same, and when the number of estimated CTAs is on
363+
// the larger side, prefer persistent tile scheduler.
364+
if (optionsA.mTileScheduler != optionsB.mTileScheduler) {
365+
auto options = bmm.getOptionsFromConfigAndData(configs[idx0], gemmData);
366+
auto numCtas = bmm.getNumCtas(options, gemmData.mProblemDimensions.mMaxNumCtasInTokenDim);
367+
if (numCtas > multiProcessorCount) {
374368
return optionsA.mTileScheduler == batchedGemm::gemm::TileScheduler::Persistent;
369+
} else {
370+
return optionsB.mTileScheduler == batchedGemm::gemm::TileScheduler::Persistent;
375371
}
376372
}
373+
377374
return false;
378375
};
379-
380376
// Sort configs by options.
381377
std::vector<int64_t> sortedIndices = mPassingConfigIndices;
382-
std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier0);
383-
std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier1);
384-
std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier2);
385-
std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier3);
378+
std::sort(sortedIndices.begin(), sortedIndices.end(), cmpFunc);
386379

387380
// Special rules for corner cases, if applicable.
388381
std::vector<int64_t> prioritizedIndices =

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -522,13 +522,14 @@ class BatchedGemmInterface {
522522
// Returns true if the configuration of the cubin can be executed for the given params.
523523
bool isValidConfig(BatchedGemmConfig const& config, BatchedGemmData const& data) const;
524524

525+
// Creates GemmOptions from kernel and data.
526+
BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config,
527+
BatchedGemmData const& data) const;
528+
525529
private:
526530
// Aligns the pointer to the alignment
527531
template <typename Dtype>
528532
inline Dtype* alignPtr(Dtype* ptr, int64_t alignment) const;
529-
// Creates GemmOptions from kernel and data.
530-
BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config,
531-
BatchedGemmData const& data) const;
532533

533534
// Returns the size of the workspace buffers in bytes
534535
std::vector<size_t> getWorkspaceSizesInBytes(BatchedGemmConfig const& config,

0 commit comments

Comments (0)