
Commit 5ce1e73

jiahanc committed: add low latency kernel
Signed-off-by: jiahanc <[email protected]>
1 parent 2a2fa30 commit 5ce1e73

File tree: 9 files changed, +79 additions, -84 deletions

csrc/trtllm_batched_gemm_runner.cu

Lines changed: 12 additions & 20 deletions
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include <cstring>
 #include <vector>
 
 #include "flashinfer/trtllm/batched_gemm/KernelRunner.h"
@@ -100,17 +101,9 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
         options.mTransposeMmaOutput == mOptions.transposeMmaOutput &&
         (!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct &&
         options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch &&
-        tileSize == mOptions.tileSize) {
-      auto sm = configs[i].mSm;
-      if (sm != SmVersion::Sm100f) {
-        int smVersion = tensorrt_llm::common::getSMVersion();
-        if (smVersion == 100 && sm != SmVersion::Sm100a) {
-          continue;
-        } else if (smVersion == 103 && sm != SmVersion::Sm103a) {
-          continue;
-        }
-      }
-
+        tileSize == mOptions.tileSize &&
+        options.mUseShuffledMatrixA == mOptions.useShuffledMatrixA &&
+        options.mLayoutA == mOptions.weightLayout) {
       if (options.mFusedAct) {
         if (options.mActType != static_cast<batchedGemm::gemmGatedAct::ActType>(mOptions.actType)) {
           continue;
@@ -123,7 +116,14 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
     }
   }
 
-  FLASHINFER_CHECK(!mPassingConfigIndices.empty(), "No kernel found for the given options");
+  FLASHINFER_CHECK(
+      !mPassingConfigIndices.empty(),
+      "No kernel found for the given options: mDtypeA: %s, mDtypeB: %s, mDtypeC: %s, "
+      "mUseDeepSeekFp8: %d, "
+      "mTransposeMmaOutput: %d, mRouteAct: %d, mFusedAct: %d, mIsStaticBatch: %d, mTileSize: %d",
+      tg::dtypeToString(mOptions.dtypeA).c_str(), tg::dtypeToString(mOptions.dtypeB).c_str(),
+      tg::dtypeToString(mOptions.dtypeC).c_str(), mOptions.deepSeekFp8, mOptions.transposeMmaOutput,
+      mOptions.routeAct, mOptions.fusedAct, mOptions.staticBatch, mOptions.tileSize);
 }
 
 size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(
@@ -169,7 +169,6 @@ void TrtllmGenBatchedGemmRunner::run(
   auto const configs = bmm.getBatchedGemmConfigs();
 
   auto const& config = configs[configIndex];
-  std::cout << "config function name: " << config.mFunctionName << std::endl;
 
   FLASHINFER_CHECK(numBatches > 0, "Batched GEMM requires numBatches > 0");
   if (!mOptions.staticBatch) {
@@ -391,13 +390,6 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
     auto const& config = configs[configIndex];
     auto isValidConfig = bmm.isValidConfig(config, gemmData);
     if (isValidConfig) {
-      // if (static_cast<int32_t>(config.mOptions.mLayoutA) == 0 ){
-      //   std::cout << "config.mLayoutA: " << static_cast<int32_t>(config.mOptions.mLayoutA) <<
-      //   std::endl; std::cout << "config.mLayoutB: " <<
-      //   static_cast<int32_t>(config.mOptions.mLayoutB) << std::endl; std::cout <<
-      //   "config.mFunctionName: " << config.mFunctionName << std::endl;
-      //   validConfigIndices.push_back(configIndex);
-      // }
       validConfigIndices.push_back(configIndex);
     }
   }
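The config-selection loop above now also requires a match on useShuffledMatrixA and the A-matrix layout, and the failure path reports the requested options instead of a bare message. Below is a minimal, self-contained C++ sketch of that filter-then-check pattern; the types and field names are hypothetical simplifications, not the runner's real API.

// Minimal sketch (hypothetical, simplified types) of the filter-then-check pattern above:
// keep only configs whose options match the requested ones, including the newly required
// useShuffledMatrixA / layoutA fields, and report the requested options if nothing matches.
#include <cstdio>
#include <cstdlib>
#include <vector>

struct Opts {
  int tileSize;
  bool useShuffledMatrixA;
  int layoutA;
};

int main() {
  std::vector<Opts> configs = {{8, false, 0}, {16, true, 1}, {16, false, 0}};
  Opts wanted{16, true, 1};

  std::vector<size_t> passingConfigIndices;
  for (size_t i = 0; i < configs.size(); ++i) {
    Opts const& o = configs[i];
    if (o.tileSize == wanted.tileSize &&
        o.useShuffledMatrixA == wanted.useShuffledMatrixA &&  // new match condition
        o.layoutA == wanted.layoutA) {                        // new match condition
      passingConfigIndices.push_back(i);
    }
  }

  if (passingConfigIndices.empty()) {
    std::fprintf(stderr, "No kernel found for tileSize=%d, useShuffledMatrixA=%d, layoutA=%d\n",
                 wanted.tileSize, static_cast<int>(wanted.useShuffledMatrixA), wanted.layoutA);
    return EXIT_FAILURE;
  }
  std::printf("first passing config index: %zu\n", passingConfigIndices[0]);
  return EXIT_SUCCESS;
}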

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 28 additions & 14 deletions
@@ -393,6 +393,12 @@ void trtllm_fp8_block_scale_moe_launcher(
   int32_t max_num_padded_tokens =
       tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount(
           args.num_tokens, top_k, num_experts, tile_tokens_dim);
+  int32_t max_num_padded_tokens_gemm1 =
+      tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
+          max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt));
+  int32_t max_num_padded_tokens_gemm2 =
+      tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
+          max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut));
   Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, routing_logits->device);
   Tensor expanded_idx_to_permuted_idx =
       alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits->device);
@@ -413,16 +419,16 @@ void trtllm_fp8_block_scale_moe_launcher(
   // dl_float8_e4m3fn, hidden_states->device);
   // Tensor activation_output = alloc_tensor({max_num_padded_tokens, intermediate_size},
   // dl_float8_e4m3fn, hidden_states->device);
-  Tensor gemm1_output =
-      alloc_tensor({max_num_padded_tokens, 2 * intermediate_size}, dl_uint8, hidden_states->device);
+  Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * intermediate_size}, dl_uint8,
+                                     hidden_states->device);
   Tensor gemm1_output_scale = alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens},
                                            dl_float32, hidden_states->device);
-  Tensor activation_output =
-      alloc_tensor({max_num_padded_tokens, intermediate_size}, dl_uint8, hidden_states->device);
-  Tensor activation_output_scale = alloc_tensor({intermediate_size / 128, max_num_padded_tokens},
-                                                dl_float32, hidden_states->device);
-  Tensor gemm2_output =
-      alloc_tensor({max_num_padded_tokens, args.hidden_size}, dl_bfloat16, hidden_states->device);
+  Tensor activation_output = alloc_tensor({max_num_padded_tokens_gemm1, intermediate_size},
+                                          dl_uint8, hidden_states->device);
+  Tensor activation_output_scale = alloc_tensor(
+      {intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, hidden_states->device);
+  Tensor gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args.hidden_size}, dl_bfloat16,
+                                     hidden_states->device);
 
   int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim(
       args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim);
@@ -519,7 +525,8 @@ void trtllm_fp8_block_scale_moe_launcher(
 
   // setup workspace
   workspace.total_num_padded_tokens = static_cast<int*>(total_num_padded_tokens->data);
-  workspace.total_max_padded_tokens = max_num_padded_tokens;
+  workspace.total_max_padded_tokens =
+      std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2);
   workspace.ProjUpTileN = tile_tokens_dim;
   workspace.routing_expert_indexes = static_cast<int*>(expert_indexes->data);
   workspace.permuted_idx_size = static_cast<int*>(total_num_padded_tokens->data);
@@ -764,6 +771,12 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
   int32_t max_num_padded_tokens =
       tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount(
          args.num_tokens, top_k, num_experts, tile_tokens_dim);
+  int32_t max_num_padded_tokens_gemm1 =
+      tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
+          max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt));
+  int32_t max_num_padded_tokens_gemm2 =
+      tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
+          max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut));
   Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, hidden_states->device);
   Tensor expanded_idx_to_permuted_idx =
       alloc_tensor({args.num_tokens, args.top_k}, dl_int32, hidden_states->device);
@@ -788,20 +801,20 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
   // Tensor gemm1_output = alloc_tensor(
   //     {max_num_padded_tokens, gemm1_output_hidden},
   //     dtype_act == btg::Dtype::Bfloat16 ? dl_bfloat16 : dl_float8_e4m3fn, hidden_states->device);
-  Tensor gemm1_output = alloc_tensor({max_num_padded_tokens, gemm1_output_hidden},
+  Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, gemm1_output_hidden},
                                      dtype_act == btg::Dtype::Bfloat16 ? dl_bfloat16 : dl_uint8,
                                      hidden_states->device);
 
   Optional<Tensor> gemm1_output_scale = std::nullopt;
   if (dtype_act == btg::Dtype::E2m1 || dtype_act == btg::Dtype::MxE4m3) {
-    int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize(max_num_padded_tokens,
+    int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize(max_num_padded_tokens_gemm1,
                                                                 intermediate_size / sf_vec_size);
     // gemm1_output_scale = alloc_tensor({sf_size}, dl_float8_e4m3fn, hidden_states->device);
     gemm1_output_scale = alloc_tensor({sf_size}, dl_uint8, hidden_states->device);
   }
 
-  Tensor gemm2_output =
-      alloc_tensor({max_num_padded_tokens, args.hidden_size}, dl_bfloat16, hidden_states->device);
+  Tensor gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args.hidden_size}, dl_bfloat16,
+                                     hidden_states->device);
 
   int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim(
       args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim);
@@ -958,7 +971,8 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
 
   // setup workspace
   workspace.total_num_padded_tokens = static_cast<int*>(total_num_padded_tokens->data);
-  workspace.total_max_padded_tokens = max_num_padded_tokens;
+  workspace.total_max_padded_tokens =
+      std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2);
   workspace.ProjUpTileN = tile_tokens_dim;
   workspace.routing_expert_indexes = static_cast<int*>(expert_indices->data);
   workspace.permuted_idx_size = static_cast<int*>(total_num_padded_tokens->data);
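The launchers now size the GEMM1-side buffers (gemm1_output, activation_output and its scales) by max_num_padded_tokens_gemm1, the GEMM2 output by max_num_padded_tokens_gemm2, and the workspace by the larger of the two. Below is a self-contained sketch of that arithmetic using hypothetical sizes (intermediate_size = 2048 with 8-bit activations, hidden_size = 4096 with 16-bit outputs; none of these values come from the diff).

// Self-contained sketch of the buffer-padding arithmetic above, with hypothetical sizes.
#include <algorithm>
#include <cstdio>

static int divUp(int a, int b) { return (a + b - 1) / b; }

// Mirrors maybeGetMinTokenCount from include/flashinfer/trtllm/fused_moe/runner.h:
// pad the row count so each buffer spans at least 128 KiB.
static int maybeGetMinTokenCount(int numPaddedTokens, int hiddenSize, int dtypeSizeBits) {
  int minNumTokensRequired = divUp(128 * 1024 * 8, hiddenSize * dtypeSizeBits);
  return std::max(numPaddedTokens, minNumTokensRequired);
}

int main() {
  int max_num_padded_tokens = 8;  // hypothetical routing result for a tiny batch
  int intermediate_size = 2048;   // hypothetical, 8-bit (fp8) activations
  int hidden_size = 4096;         // hypothetical, 16-bit (bf16) outputs

  int gemm1 = maybeGetMinTokenCount(max_num_padded_tokens, intermediate_size, 8);
  int gemm2 = maybeGetMinTokenCount(max_num_padded_tokens, hidden_size, 16);

  // Prints: gemm1 rows: 64, gemm2 rows: 16, workspace max: 64
  std::printf("gemm1 rows: %d, gemm2 rows: %d, workspace max: %d\n",
              gemm1, gemm2, std::max(gemm1, gemm2));
  return 0;
}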

flashinfer/artifacts.py

Lines changed: 2 additions & 2 deletions
@@ -76,7 +76,7 @@ def get_available_cubin_files(
 class ArtifactPath:
     TRTLLM_GEN_FMHA: str = "7206d64e67f4c8949286246d6e2e07706af5d223/fmha/trtllm-gen"
     TRTLLM_GEN_BMM: str = (
-        "696906bd3985f84662799054f377b4b47a1907d3/batched_gemm-074aec4-3df1e6c"
+        "0b88a3f2499b29b63fc8140cc23a1aa2945bed1b/batched_gemm-074aec4-cc00b23"
     )
     TRTLLM_GEN_GEMM: str = (
         "037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e"
@@ -90,7 +90,7 @@ class MetaInfoHash:
     TRTLLM_GEN_FMHA: str = (
         "2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d"
     )
-    TRTLLM_GEN_BMM: str = "696906bd3985f84662799054f377b4b47a1907d3"
+    TRTLLM_GEN_BMM: str = "0b88a3f2499b29b63fc8140cc23a1aa2945bed1b"
     DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
     TRTLLM_GEN_GEMM: str = (
         "0345358c916d990709f9670e113e93f35c76aa22715e2d5128ec2ca8740be5ba"

flashinfer/utils.py

Lines changed: 0 additions & 1 deletion
@@ -125,7 +125,6 @@ def calculate_tile_tokens_dim(
         tile_tokens_dim = 192
     # Cap to 8-max_tile_tokens_dim tokens per CTA tile as it's the range supported by the kernel.
     tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim)
-
     return tile_tokens_dim
 
 

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h

Lines changed: 19 additions & 7 deletions
@@ -107,12 +107,11 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
           epilogueLdtmBits, epilogueTileM, epilogueTileN, gridTriggerSecondaryA,
           gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA,
           gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits,
-          layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, numRegsCastAWarps,
-          numRegsCopySfLdsSttm, numRegsPerThreadEpilogueWarp, numRegsPerThreadNonEpilogueWarp,
-          numSlicesForSplitK, numSlicesForSliceK, numStages, numStagesMma,
-          numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId,
-          outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC,
-          sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN, tileScheduler,
+          layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n,
+          numRegsCopySfLdsSttm, numSlicesForSplitK, numSlicesForSliceK, numStages,
+          numStagesMma, numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile,
+          numStagesWorkId, outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB,
+          sfLayoutC, sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN, tileScheduler,
           transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8,
           useHoistTryWaitForCustomMmaSchedule, usePerTokenSfA, usePerTokenSfB,
           useShuffledMatrixA, useTmaStore, useTwoTmaLoadWarps, useTwoMmaWarps,
@@ -125,6 +124,9 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
        mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting),
        mIsStaticBatch(isStaticBatch),
        mNumBatches(numBatches),
+       mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp),
+       mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp),
+       mNumRegsCastAWarps(numRegsCastAWarps),
        mNumTokens(numTokens),
        mRouteImpl(routeImpl),
        mRouteSfsImpl(routeSfsImpl),
@@ -145,6 +147,12 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
   bool mIsStaticBatch{true};
   // Number of Gemm batches.
   int mNumBatches;
+  // Number of registers per thread for non-epilogue warps
+  int mNumRegsPerThreadNonEpilogueWarp{0};
+  // Number of registers per thread for epilogue warps
+  int mNumRegsPerThreadEpilogueWarp{0};
+  // Number of registers for the cast A warps.
+  int mNumRegsCastAWarps{0};
   // Total number of tokens.
   int mNumTokens{32};
   // Whether load the input tokens and do routing.
@@ -340,7 +348,6 @@ struct BatchedGemmConfig {
   char const* mHash{nullptr};
 #else
   trtllm::gen::CudaRunner* mCudaRunner{nullptr};
-  int32_t mInstanceIdx{0};
 #endif
 
   BatchedGemmOptions mOptions;
@@ -365,6 +372,11 @@ inline std::string dumpOptions(BatchedGemmOptions const& options) {
      << static_cast<int32_t>(options.mRouteSfsImpl.value()) << ")}," << std::endl;
   ss << "mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << "," << std::endl;
   ss << "mFusedAct=" << options.mFusedAct << "," << std::endl;
+  ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << ","
+     << std::endl;
+  ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << ","
+     << std::endl;
+  ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl;
   ss << "mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl;
   return ss.str();
 }
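Net effect of this file together with the next one: the three per-warp register-count knobs move off the base GemmOptions and onto BatchedGemmOptions, so only the batched-GEMM path carries, constructs, and dumps them. A compact sketch of that pattern with hypothetical, simplified types (not the real option structs):

// Compact sketch (hypothetical, simplified types) of the refactor in this file and the next:
// the register-count knobs are now members of the batched options only.
#include <iostream>
#include <sstream>
#include <string>

struct GemmOptions {
  int mN{256};  // base options no longer carry the register-count knobs
};

struct BatchedGemmOptions : GemmOptions {
  int mNumRegsPerThreadNonEpilogueWarp{0};
  int mNumRegsPerThreadEpilogueWarp{0};
  int mNumRegsCastAWarps{0};
};

std::string dumpOptions(BatchedGemmOptions const& o) {
  std::stringstream ss;
  ss << "mNumRegsPerThreadNonEpilogueWarp=" << o.mNumRegsPerThreadNonEpilogueWarp << ",\n";
  ss << "mNumRegsPerThreadEpilogueWarp=" << o.mNumRegsPerThreadEpilogueWarp << ",\n";
  ss << "mNumRegsCastAWarps=" << o.mNumRegsCastAWarps << ",\n";
  return ss.str();
}

int main() {
  std::cout << dumpOptions(BatchedGemmOptions{});
  return 0;
}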

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h

Lines changed: 2 additions & 18 deletions
@@ -102,9 +102,8 @@ struct GemmOptions {
               bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit,
               bool hoistMmaTaskTryWaits, int k, KernelTraits kernelTraits, MatrixLayout layoutA,
               MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN,
-              bool mockAllReduce, int n, int numRegsCastAWarps, int numRegsCopySfLdsSttm,
-              int numRegsPerThreadEpilogueWarp, int numRegsPerThreadNonEpilogueWarp,
-              int numSlicesForSplitK, int numSlicesForSliceK, int numStages, int numStagesMma,
+              bool mockAllReduce, int n, int numRegsCopySfLdsSttm, int numSlicesForSplitK,
+              int numSlicesForSliceK, int numStages, int numStagesMma,
               int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId,
               bool outputDebugTensors, bool patchF2fp, std::optional<int32_t> sfBlockSizeA,
               tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC,
@@ -152,10 +151,7 @@ struct GemmOptions {
        mMmaN{mmaN},
        mMockAllReduce{mockAllReduce},
        mN{n},
-       mNumRegsCastAWarps(numRegsCastAWarps),
        mNumRegsCopySfLdsSttm(numRegsCopySfLdsSttm),
-       mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp),
-       mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp),
        mNumSlicesForSplitK{numSlicesForSplitK},
        mNumSlicesForSliceK{numSlicesForSliceK},
        mNumStages{numStages},
@@ -273,14 +269,8 @@ struct GemmOptions {
   bool mMockAllReduce{false};
   // The N dimension of GEMM.
   int mN{64 * 4};
-  // Number of registers for the cast A warps.
-  int mNumRegsCastAWarps{0};
   // Number of registers for the LDS+STTM warps.
   int mNumRegsCopySfLdsSttm{0};
-  // Number of registers per thread for epilogue warps
-  int mNumRegsPerThreadEpilogueWarp{0};
-  // Number of registers per thread for non-epilogue warps
-  int mNumRegsPerThreadNonEpilogueWarp{0};
   // Number of partitions along the K dimension. When mNumSlicesForSplitK > 1,
   // the problem is distributed across several SMs, where each CTA works on its local K slice.
   // Partial results are accumulated afterwards using either GMEM or DSMEM (in CGA)
@@ -387,7 +377,6 @@ struct GemmConfig {
   char const* mHash{nullptr};
 #else
   trtllm::gen::CudaRunner* mCudaRunner{nullptr};
-  int32_t mInstanceIdx{0};
 #endif
 
   GemmOptions mOptions{};
@@ -481,12 +470,7 @@ inline std::string dumpOptions(GemmOptions const& options) {
   ss << "mMmaN=" << options.mMmaN << "," << std::endl;
   ss << "mMockAllReduce=" << options.mMockAllReduce << "," << std::endl;
   ss << "mN=" << options.mN << "," << std::endl;
-  ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl;
   ss << "mNumRegsCopySfLdsSttm=" << options.mNumRegsCopySfLdsSttm << "," << std::endl;
-  ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << ","
-     << std::endl;
-  ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << ","
-     << std::endl;
   ss << "mNumSlicesForSplitK=" << options.mNumSlicesForSplitK << "," << std::endl;
   ss << "mNumSlicesForSliceK=" << options.mNumSlicesForSliceK << "," << std::endl;
   ss << "mNumStages=" << options.mNumStages << "," << std::endl;

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h

Lines changed: 3 additions & 3 deletions
@@ -156,7 +156,7 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind,
     char const* errorString;
     cuGetErrorString(result, &errorString);
     std::stringstream ss;
-    ss << "Error: Failed to initialize the TMA descriptor. " << errorString << std::endl;
+    ss << "Error: Failed to initialize the TMA descriptor " << result << std::endl;
 
     ss << "tmaFormat: " << static_cast<int>(tmaDataFormat) << " dim: " << dim
        << " gmem: " << gmemAddr << std::endl;
@@ -195,7 +195,7 @@
 // TODO: make it work with the above descriptor?
 inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, std::vector<uint64_t> const& shapes,
                                         std::vector<uint64_t> const& strides,
-                                        std::vector<uint32_t> const& tileShapes, void* gmemAddr) {
+                                        const std::vector<uint32_t>& tileShapes, void* gmemAddr) {
   CUtensorMap desc{};
   CUtensorMapDataType tmaDataFormat;
   if (dtype == tg::Dtype::E4m3 || dtype == tg::Dtype::UE8m0) {
@@ -251,7 +251,7 @@ inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, std::vector<uint64_t> c
     char const* errorString;
     cuGetErrorString(result, &errorString);
     std::stringstream ss;
-    ss << "Error: Failed to initialize the TMA descriptor for SF. " << errorString << std::endl;
+    ss << "Error: Failed to initialize the TMA descriptor for SF " << errorString << std::endl;
 
     ss << "tmaFormat: " << static_cast<int>(tmaDataFormat) << " dim: " << dim
        << " gmem: " << gmemAddr << std::endl;

include/flashinfer/trtllm/fused_moe/runner.h

Lines changed: 7 additions & 0 deletions
@@ -52,6 +52,13 @@ enum class RoutingMethodType : int64_t {
   Unspecified = 6,
 };
 
+inline int32_t maybeGetMinTokenCount(int32_t numPaddedTokens, int32_t hiddenSize,
+                                     int32_t dtypeSizeBits) {
+  // Pad so total size exceeds 128KiB for performance reasons
+  int32_t minNumTokensRequired = common::divUp(128 * 1024 * 8, hiddenSize * dtypeSizeBits);
+  return std::max(numPaddedTokens, minNumTokensRequired);
+}
+
 inline std::string serializeMoeRoutingMethodType(RoutingMethodType routingMethodType) {
   switch (routingMethodType) {
     case RoutingMethodType::Default:
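Worked example of the 128 KiB bound in the new helper, using hypothetical sizes not taken from this commit: with hiddenSize = 7168 and 16-bit elements, divUp(128 * 1024 * 8, 7168 * 16) = divUp(1048576, 114688) = 10, so a batch padded to 8 tokens is bumped to 10 rows, i.e. 10 * 7168 * 2 bytes ≈ 140 KiB; batches already larger than that pass through unchanged.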
