Commit c721fb7
bugfix: partially fix tests/test_trtllm_gen_fused_moe.py unit test failure (#1724)
## 📌 Description

Fixes the `Routing kernel expects top K 6 to be <= #warps 2` error that occurs when running `tests/test_trtllm_gen_fused_moe.py` on GB200 and GB300. In addition, reverts the number-of-experts template parameter introduced in [PR 1696](#1696).

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent 3475e4b commit c721fb7
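For context, a minimal standalone sketch of the arithmetic behind the reported error. It assumes the standard CUDA warp size of 32 and uses a 72-expert configuration purely as an illustration of a case where the old "one thread per expert" launch yields fewer warps than `topK`; it is not the kernel code itself.

```cpp
// Sketch only: shows why tying the block size to the expert count can violate topK <= #warps,
// and why a fixed NumThreads = 384 (as in this commit) avoids it. Assumes WarpSize = 32.
#include <cstdio>

constexpr int WarpSize = 32;

constexpr int numWarps(int numThreads) { return numThreads / WarpSize; }

int main() {
  // Before: one thread per expert, so e.g. 72 experts -> 72 / 32 = 2 warps < topK = 6.
  std::printf("72 threads  -> %d warps\n", numWarps(72));   // prints 2
  // After: NumThreads is fixed at 384, so 384 / 32 = 12 warps >= topK = 6.
  std::printf("384 threads -> %d warps\n", numWarps(384));  // prints 12
  return 0;
}
```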

File tree

5 files changed: +113 −199 lines


csrc/trtllm_fused_moe_routing_deepseek.cu

Lines changed: 42 additions & 62 deletions

@@ -22,10 +22,13 @@ namespace routingDeepSeek {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+static constexpr int NumThreads = 384;
+static constexpr int NumWarps = NumThreads / WarpSize;
 static constexpr int NumTopGroupScores = 2;
 static constexpr int MaxNumTopExperts = 8;
+static constexpr int MaxNumTopGroupsDefault = 16;
 
-__host__ __device__ int constexpr getMaxNumTopGroups(const bool useGroups, const int numExperts) {
+__host__ __device__ int getMaxNumTopGroups(const bool useGroups, const int numExperts) {
   if (useGroups || numExperts <= 256) {
     return 4;
   } else {
@@ -38,10 +41,8 @@ __global__ void routingMainKernel(KernelParams params) {
   // declare types
   using OutputT = typename KernelParams::OutputT;
   using InputT = typename KernelParams::InputT;
-  static constexpr int NumThreads = KernelParams::NumExperts;  // DeepSeek uses 1 thread per expert
   static constexpr int NumWarps = NumThreads / WarpSize;
-  constexpr int MaxNumTopGroups =
-      getMaxNumTopGroups(KernelParams::UseGroups, KernelParams::NumExperts);
+  int MaxNumTopGroups = getMaxNumTopGroups(KernelParams::UseGroups, params.mNumExperts);
 
   // declare shared memory structure
   // number of experts is bounded by number of threads
@@ -71,19 +72,19 @@ __global__ void routingMainKernel(KernelParams params) {
 
   // load bias already; each warp represents one expert group
   auto threadExpert = threadIdx.x;
-  bool expertSelected = threadExpert < KernelParams::NumExperts;
+  bool expertSelected = threadExpert < params.mNumExperts;
   if constexpr (KernelParams::UseGroups) {
     threadExpert = warpIdx * params.mNumExpertsPerGroup + laneIdx;
     expertSelected = laneIdx < params.mNumExpertsPerGroup;
   }
-  auto scoreIdx = int64_t{blockIdx.x} * int64_t{KernelParams::NumExperts} + threadExpert;
+  auto scoreIdx = int64_t{blockIdx.x} * int64_t{params.mNumExperts} + threadExpert;
   auto biasVal = expertSelected ? params.mPtrRoutingBias[threadExpert] : invalidScore;
 
   // initialize the mPtrExpertCounts
   if (params.mPtrExpertCounts) {
     int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x;
     int32_t globalThreadStride = gridDim.x * NumThreads;
-    int32_t expertCountsNum = 2 * KernelParams::NumExperts;
+    int32_t expertCountsNum = 2 * params.mNumExperts;
     initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0);
   }
 
@@ -118,10 +119,10 @@ __global__ void routingMainKernel(KernelParams params) {
   // registers for top group score reduction
   float topExpGroupScores[NumTopGroupScores];
   [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores];
-  float topGroups[MaxNumTopGroups];  // bound of params.mNumLimitedGroups
-  int32_t topGroupIdx[MaxNumTopGroups];
-  float expertScoreGroup[MaxNumTopGroups];
-  int32_t expertIdxGroup[MaxNumTopGroups];
+  float topGroups[MaxNumTopGroupsDefault];  // bound of params.mNumLimitedGroups
+  int32_t topGroupIdx[MaxNumTopGroupsDefault];
+  float expertScoreGroup[MaxNumTopGroupsDefault];
+  int32_t expertIdxGroup[MaxNumTopGroupsDefault];
   float topScores[MaxNumTopExperts];  // bound of params.mTopK
   int32_t topExperts[MaxNumTopExperts];
 
@@ -168,15 +169,13 @@ __global__ void routingMainKernel(KernelParams params) {
   } else {
     // without groups, each thread just takes `MaxNumTopGroups` experts
 
-#pragma unroll
     for (int ii = 0; ii < MaxNumTopGroups; ++ii) {
       auto expertIdx = ii * WarpSize + laneIdx;
       expertIdxGroup[ii] = expertIdx;
       expertScoreGroup[ii] =
-          expertIdx < KernelParams::NumExperts ? smemScoreBias[expertIdx] : invalidScoreFloat;
+          expertIdx < params.mNumExperts ? smemScoreBias[expertIdx] : invalidScoreFloat;
     }
   }
-
   topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup,
                    /* minValue */ invalidScoreFloat, params.mTopK);
 
@@ -214,10 +213,8 @@ __global__ void routingMainKernel(KernelParams params) {
 template <typename KernelParams>
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1)
-    __launch_bounds__(KernelParams::NumExperts) routingIndicesClusterKernel(KernelParams params) {
+    routingIndicesClusterKernel(KernelParams params) {
   using OutputT = typename KernelParams::OutputT;
-  static constexpr int NumThreads = KernelParams::NumExperts;  // DeepSeek uses 1 thread per expert
-  static constexpr int NumWarps = NumThreads / WarpSize;
 
   int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
   int32_t const clusterBlockRank = blockIdx.x;
@@ -241,9 +238,7 @@ __global__ void routingIndicesClusterKernel(KernelParams params) {
 
 template <typename KernelParams>
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-__global__ void __launch_bounds__(KernelParams::NumExperts)
-    routingIndicesCoopKernel(KernelParams params) {
-  static constexpr int NumThreads = KernelParams::NumExperts;  // DeepSeek uses 1 thread per expert
+__global__ void routingIndicesCoopKernel(KernelParams params) {
   static constexpr int NumWarps = NumThreads / WarpSize;
   // number of experts is bounded by number of threads
   __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads];
@@ -336,16 +331,15 @@ __global__ void __launch_bounds__(KernelParams::NumExperts)
   int32_t const localExpertCount = smemExpertCount[threadIdx.x];
 
   int32_t blockExpertOffset = 0;
-  if (threadIdx.x < KernelParams::NumExperts) {
+  if (threadIdx.x < params.mNumExperts) {
     blockExpertOffset = atomicAdd(&params.mPtrExpertCounts[threadIdx.x], localExpertCount);
   }
 
   // Sync to wait for completion of the histogram reduction.
   grid.sync();
 
   // Get total count for this expert.
-  int32_t count =
-      (threadIdx.x < KernelParams::NumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0;
+  int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0;
 
   // Note: the scan is redundant in all CTAs, but doing it in only 1 CTA would be worse for latency.
 
@@ -422,16 +416,13 @@ __global__ void routingIndicesCoopKernel(KernelParams params) {
 #endif
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template <int NumExperts>
 void runImpl(Data& data, void* stream) {
-  static constexpr int NumThreads = NumExperts;  // DeepSeek: 1 thread per expert
   static constexpr int NumWarps = NumThreads / WarpSize;
-  const int MaxNumTopGroups = getMaxNumTopGroups(data.mNumExpertGroups > 1, NumExperts);
+  int MaxNumTopGroups = getMaxNumTopGroups(data.mNumExpertGroups > 1, data.mNumExperts);
 
   // Validate that the template parameter matches the data
-  TORCH_CHECK(data.mNumExperts == NumExperts, "DeepSeek routing kernel expects exactly ",
-              NumExperts, " experts, got ", data.mNumExperts);
+  // TORCH_CHECK(data.mNumExperts == NumExperts, "DeepSeek routing kernel expects exactly ",
+  //             NumExperts, " experts, got ", data.mNumExperts);
   TORCH_CHECK(data.mPtrExpertIdx != nullptr || data.mPtrPermutedIdxSize != nullptr ||
                   data.mPtrExpertWeights != nullptr,
               "Routing kernel requires at least one output parameter");
@@ -503,22 +494,21 @@ void runImpl(Data& data, void* stream) {
   LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
                                  /*coopLaunch=*/false, routingMainKernel, numBlocks, NumThreads,
                                  /*smemSize=*/0,  // No dynamic smem
-                                 stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true,
-                                 NumExperts);
+                                 stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
 
   if (data.mPtrPermutedIdxSize != nullptr) {
     if (useSingleCluster) {
-      LAUNCH_ROUTING_WITH_EXTRA_FLAG(
-          data,
-          /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads,
-          /*smemSize=*/0,  // No dynamic smem
-          stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true, NumExperts);
+      LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
+                                     /*coopLaunch=*/false, routingIndicesClusterKernel,
+                                     NumBlocksPerCluster, NumThreads,
+                                     /*smemSize=*/0,  // No dynamic smem
+                                     stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
     } else if (data.mNumTokens <= maxTokensCoop) {
-      LAUNCH_ROUTING_WITH_EXTRA_FLAG(
-          data,
-          /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, NumThreads,
-          /*smemSize=*/0,  // No dynamic smem
-          stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true, NumExperts);
+      LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
+                                     /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop,
+                                     NumThreads,
+                                     /*smemSize=*/0,  // No dynamic smem
+                                     stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
     } else {
      const int32_t expandedIdxSize = data.mNumTokens * data.mTopK;
 
@@ -533,31 +523,21 @@ void runImpl(Data& data, void* stream) {
      int const numBlocksOffsets =
          std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);
 
-      LAUNCH_ROUTING_WITH_EXTRA_FLAG(
-          data,
-          /*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, NumThreads,
-          /*smemSize=*/0,  // No dynamic smem
-          stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true, NumExperts);
-      LAUNCH_ROUTING_WITH_EXTRA_FLAG(
-          data,
-          /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, NumThreads,
-          /*smemSize=*/0,  // No dynamic smem
-          stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true, NumExperts);
+      LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
+                                     /*coopLaunch=*/false, routingIndicesHistogramKernel,
+                                     numBlocksHistogram, NumThreads,
+                                     /*smemSize=*/0,  // No dynamic smem
+                                     stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
+      LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
+                                     /*coopLaunch=*/false, routingIndicesOffsetsKernel,
+                                     numBlocksOffsets, NumThreads,
+                                     /*smemSize=*/0,  // No dynamic smem
+                                     stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
    }
  }
 }
 
-void run(Data& data, void* stream) {
-  if (data.mNumExperts == 72) {
-    runImpl<72>(data, stream);
-  } else if (data.mNumExperts == 256) {
-    runImpl<256>(data, stream);
-  } else if (data.mNumExperts == 384) {
-    runImpl<384>(data, stream);
-  } else {
-    TORCH_CHECK(false, "Unsupported number of experts: ", data.mNumExperts);
-  }
-}
+void run(Data& data, void* stream) { runImpl(data, stream); }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
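The `run()` hunk above reverts the per-expert-count template dispatch from PR 1696. The following is a simplified, hypothetical sketch of the two dispatch styles; names, the `RoutingData` struct, and signatures are illustrative only and are not the actual FlashInfer API.

```cpp
// Illustration only: compile-time vs runtime expert-count dispatch.
#include <stdexcept>

struct RoutingData { int mNumExperts; };  // stand-in for the real Data struct

// Before the revert (PR 1696 style): the expert count was a template parameter,
// so run() had to enumerate every supported count explicitly.
template <int NumExperts>
void runImplTemplated(const RoutingData& /*data*/) { /* kernels specialized for NumExperts */ }

void runOld(const RoutingData& data) {
  switch (data.mNumExperts) {
    case 72:  runImplTemplated<72>(data);  break;
    case 256: runImplTemplated<256>(data); break;
    case 384: runImplTemplated<384>(data); break;
    default:  throw std::invalid_argument("Unsupported number of experts");
  }
}

// After the revert: the expert count stays a runtime value read from the data struct,
// so any count within the kernel's limits is accepted without a new instantiation.
void runImplRuntime(const RoutingData& data) { (void)data; /* launch with data.mNumExperts */ }

void runNew(const RoutingData& data) { runImplRuntime(data); }
```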

csrc/trtllm_fused_moe_routing_llama4.cu

Lines changed: 17 additions & 34 deletions

@@ -110,13 +110,13 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam
   // for each token, we load the scores then use `reduceTopK` for this.
   // each thread works on 4 experts, so a local reduction is done before
   for (int tokenIdx = 0; tokenIdx < params.mNumTokens; ++tokenIdx) {
-    auto scoreOffset = tokenIdx * KernelParams::NumExperts;
+    auto scoreOffset = tokenIdx * params.mNumExperts;
     int32_t warpMaxExpertIdx[MaxNumTopExperts];
     InputT warpMaxScore[MaxNumTopExperts];
 
     // Use routingTopKExperts function instead of inline logic
     routingTopKExperts<InputT, ExpertsPerThread>(warp, warpMaxScore, warpMaxExpertIdx,
-                                                 threadIdx.x, KernelParams::NumExperts,
+                                                 threadIdx.x, params.mNumExperts,
                                                  params.mPtrScores + scoreOffset);
 
     if (cute::elect_one_sync()) {
@@ -285,7 +285,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu
 
   // TODO(mjoux): expand to more tokens (possibly)
   auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx;
-  auto scoreOffset = warpTokenIdx * KernelParams::NumExperts;
+  auto scoreOffset = warpTokenIdx * params.mNumExperts;
   bool validToken = warpTokenIdx < params.mNumTokens;
   InputT minScore = InputT{-INFINITY};
 
@@ -306,7 +306,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu
 
   if (validToken) {
     routingTopKExperts<InputT, MaxNumExperts / WarpSize>(warp, warpMaxScore, warpMaxExpertIdx,
-                                                         laneIdx, KernelParams::NumExperts,
+                                                         laneIdx, params.mNumExperts,
                                                          params.mPtrScores + scoreOffset);
     if (cute::elect_one_sync()) {
       auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})};
@@ -351,7 +351,7 @@ __global__ void __launch_bounds__(NumThreadsHist)
   auto warp = cg::tiled_partition<WarpSize>(block);
 
   // initialize the mPtrExpertCounts
-  int32_t expertCountsNum = 2 * KernelParams::NumExperts;
+  int32_t expertCountsNum = 2 * params.mNumExperts;
   int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x;
   int32_t globalThreadStride = gridDim.x * NumThreads;
   initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0);
@@ -367,12 +367,12 @@ __global__ void __launch_bounds__(NumThreadsHist)
   // in this case, each warp represents a token, and we use a grid-stride loop
   // over all warps/tokens
   for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride) {
-    auto scoreOffset = tokenIdx * KernelParams::NumExperts;
+    auto scoreOffset = tokenIdx * params.mNumExperts;
     int32_t warpMaxExpertIdx[MaxNumTopExperts];
     InputT warpMaxScore[MaxNumTopExperts];
 
     routingTopKExperts<InputT, MaxNumExperts / WarpSize>(warp, warpMaxScore, warpMaxExpertIdx,
-                                                         laneIdx, KernelParams::NumExperts,
+                                                         laneIdx, params.mNumExperts,
                                                          params.mPtrScores + scoreOffset);
 
     if (cute::elect_one_sync()) {
@@ -384,25 +384,19 @@ __global__ void __launch_bounds__(NumThreadsHist)
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template <int NumExperts>
 void runImpl(Data const& data, void* stream) {
-  // Validate that the runtime value matches the template parameter
-  TORCH_CHECK(data.mNumExperts == NumExperts, "Llama4 routing kernel expects #experts ",
-              data.mNumExperts, " to match template parameter ", NumExperts);
-
   TORCH_CHECK(data.mPtrExpertIdx != nullptr || data.mPtrScores != nullptr,
               "Routing kernel requires at least one input parameter");
   TORCH_CHECK(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr &&
                   data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr,
              "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers");
   TORCH_CHECK(data.mTopK <= MaxNumTopExperts,
              "Routing kernel expects topK experts <= ", MaxNumTopExperts, ", got ", data.mTopK);
-  TORCH_CHECK(NumExperts <= MaxNumExperts, "Routing kernel expects #experts ", NumExperts,
-              " to be at most max #experts ", MaxNumExperts);
+  TORCH_CHECK(data.mNumExperts <= MaxNumExperts, "Routing kernel expects #experts ",
+              data.mNumExperts, " to be at most max #experts ", MaxNumExperts);
   static_assert(MaxNumExperts <= NumThreads, "#experts must be bounded by #threads");
   static_assert(MaxNumExperts <= NumThreadsHist, "#experts must be bounded by #threads");
-  TORCH_CHECK(NumExperts % 4 == 0, "Routing kernel expects #experts ", NumExperts,
+  TORCH_CHECK(data.mNumExperts % 4 == 0, "Routing kernel expects #experts ", data.mNumExperts,
              " to be a multiple of 4.");
   TORCH_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got ",
              data.mPaddingLog2);
@@ -424,13 +418,13 @@ void runImpl(Data const& data, void* stream) {
    LAUNCH_ROUTING(data,
                   /*coopLaunch=*/false, routingIndicesWarpKernel, 1, WarpSize,
                   /*smemSize=*/0,  // No dynamic smem
-                   stream, NumExperts);
+                   stream);
  } else if (useSingleCluster) {
    LAUNCH_ROUTING(data,
                   /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster,
                   NumThreads,
                   /*smemSize=*/0,  // No dynamic smem
-                   stream, NumExperts);
+                   stream);
  } else {
    const uint32_t expandedIdxSize = data.mNumTokens * data.mTopK;
 
@@ -450,23 +444,23 @@ void runImpl(Data const& data, void* stream) {
                     /*coopLaunch=*/false, routingIndicesHistogramScoresKernel, maxNumBlocks,
                     NumThreadsHist,
                     /*smemSize=*/0,  // No dynamic smem
-                     stream, NumExperts);
+                     stream);
    } else {
      // Reset the global histograms.
      CHECK_CUDA_ERROR(cudaMemsetAsync(data.mPtrExpertCounts, 0,
-                                       static_cast<size_t>(2 * NumExperts) * sizeof(int32_t),
+                                       static_cast<size_t>(2 * data.mNumExperts) * sizeof(int32_t),
                                       (cudaStream_t)stream));
    }
    LAUNCH_ROUTING(data,
                   /*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram,
                   NumThreadsHist,
                   /*smemSize=*/0,  // No dynamic smem
-                   stream, NumExperts);
+                   stream);
    LAUNCH_ROUTING(data,
                   /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets,
                   NumThreadsHist,
                   /*smemSize=*/0,  // No dynamic smem
-                   stream, NumExperts);
+                   stream);
  }
 }
 
@@ -481,18 +475,7 @@ void run(Data const& data, void* stream) {
   TORCH_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got ",
               data.mPaddingLog2);
 
-  // Dispatch to the appropriate template instantiation based on the number of experts
-  switch (data.mNumExperts) {
-    case 16:
-      runImpl<16>(data, stream);
-      break;
-    case 128:
-      runImpl<128>(data, stream);
-      break;
-    default:
-      TORCH_CHECK(false, "Unsupported number of experts: ", data.mNumExperts,
-                  ". Supported values are: 16, 128");
-  }
+  runImpl(data, stream);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////