Skip to content

Commit 8a3b870

Browse files
authored
[None][feat] Update TRTLLM MoE MxFP4 cubins; autotune tileN (#8156)
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
1 parent 15de45d commit 8a3b870

File tree

1,240 files changed

+8406
-10379
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

1,240 files changed

+8406
-10379
lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,6 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
144144
}
145145
}
146146

147-
// FIXME: Disable split-k for now.
148-
if (options.mClusterDimZ != 1)
149-
{
150-
continue;
151-
}
152-
153147
if (options.mFusedAct)
154148
{
155149
if (options.mActType != static_cast<batchedGemm::gemmGatedAct::ActType>(mOptions.actType))
@@ -158,14 +152,29 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
158152
}
159153
}
160154

155+
// FIXME: Disables a few static scheduler kernels (schedS) that appear to have issues;
156+
// found after commit e257cb3533; still under investigation. Offending kernels:
157+
// bmm_E2m1_E2m1E2m1_Fp32_t128x64x256_s6_et128x64_m128x64x64_cga1x1x1_16dp256b_TN_transOut_schedS_bN_ldgsts_tmaOpt_clmp_swiGlu_dynBatch_sm100a
158+
// bmm_MxE4m3_MxE2m1MxE4m3_Fp32_t128x64x256_s3_et128x64_m128x64x32_cga1x1x1_16dp256b_TN_transOut_schedS_biasM_bN_ldgsts_tmaOpt_clmp_swiGlu_dynBatch_sm100f
159+
if (options.mTileScheduler == TileScheduler::Static && options.mUseTmaOobOpt == true
160+
&& options.mTileN == 64)
161+
{
162+
continue;
163+
}
164+
161165
if (mOptions.transposeMmaOutput && options.mEpilogueTileM == mOptions.epilogueTileM)
162166
{
163167
mPassingConfigIndices.push_back(i);
164168
}
165169
}
166170
}
167171

168-
TLLM_CHECK_WITH_INFO(!mPassingConfigIndices.empty(), "No kernel found for the given options");
172+
TLLM_CHECK_WITH_INFO(!mPassingConfigIndices.empty(),
173+
"No kernel found for the given options: mDtypeA: %s, mDtypeB: %s, mDtypeC: %s, mUseDeepSeekFp8: %d, "
174+
"mTransposeMmaOutput: %d, mRouteAct: %d, mFusedAct: %d, mIsStaticBatch: %d, mTileSize: %d",
175+
tg::dtypeToString(mOptions.dtypeA).c_str(), tg::dtypeToString(mOptions.dtypeB).c_str(),
176+
tg::dtypeToString(mOptions.dtypeC).c_str(), mOptions.deepSeekFp8, mOptions.transposeMmaOutput,
177+
mOptions.routeAct, mOptions.fusedAct, mOptions.staticBatch, mOptions.tileSize);
169178
}
170179

171180
size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k,
@@ -277,7 +286,8 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
277286
auto envVarVal = std::getenv("TLLM_BATCHED_GEMM_PRINT_NAME");
278287
if (envVarVal && std::atoi(envVarVal) == 1)
279288
{
280-
TLLM_LOG_INFO("numBatches %d Gemm %d %d %d Kernel %s\n", numBatches, m, n, k, config.mFunctionName);
289+
TLLM_LOG_INFO("NumBatches %d, MaxNumCtasInBatchDim %d, ShapeMNK %d %d %d, Kernel %s", numBatches,
290+
maxNumCtasInBatchDim, m, n, k, config.mFunctionName);
281291
}
282292
// FIXME: once we start using all-reduce in the epilogue of the bmm, this can be moved elsewhere
283293
bmm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,12 @@ class TrtllmGenBatchedGemmRunner
7676
int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
7777
void* workspace, CUstream stream, int device, int32_t configIndex);
7878

79-
// NVFP4 per-block scaling GEMM
79+
// Block-scaling GEMM
8080
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
8181
void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device,
8282
int32_t configIndex);
8383

84+
// Block-scaling GEMM with SwiGLU activation
8485
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
8586
void const* b, void const* sfB, float const* bias, float const* swiGluAlpha, float const* swiGluBeta,
8687
float const* clampLimit, void* c, void* outSfC, void* workspace, CUstream stream, int device,

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,9 @@ class BatchedGemmInterface
530530
return std::make_tuple(numCtasBatch, numCtasTile, numCtasInner);
531531
}
532532

533+
// Creates GemmOptions from kernel and data.
534+
BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config, BatchedGemmData const& data) const;
535+
533536
// Returns the number of CTAs of the current kernel.
534537
int32_t getNumCtas(
535538
BatchedGemmOptions const& options, std::optional<int32_t> maxNumCtasInBatchDim = std::nullopt) const
@@ -541,9 +544,6 @@ class BatchedGemmInterface
541544
// Returns true if the configuration of the cubin can be executed for the given params.
542545
bool isValidConfig(BatchedGemmConfig const& config, BatchedGemmData const& data) const;
543546

544-
// Creates GemmOptions from kernel and data.
545-
BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config, BatchedGemmData const& data) const;
546-
547547
private:
548548
// Aligns the pointer to the alignment
549549
template <typename Dtype>

0 commit comments

Comments
 (0)