Commit d135f59

Add unit test for routing kernels (NVIDIA#5405)

Signed-off-by: Christina Zhang <[email protected]>

Parent: 578dbc8
File tree: 9 files changed, +1522 −59 lines

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cu
(31 additions, 55 deletions)
@@ -327,6 +327,7 @@ __global__ void routingMainKernel(KernelParams params)
     // note that with invalid values, because sigmoid is < 1 and bias is -1,
     // we must get a negative value, which is smaller than any valid value
     auto scoreBias = float{scoreSigmoid + float{biasVal}};
+
     if (expertSelected)
     {
         smemScoreBias[threadExpert] = scoreBias;
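As an aside, the invariant this comment relies on is easy to check numerically: sigmoid maps any finite logit into (0, 1), so an invalid slot's bias of -1 always yields a strictly negative scoreBias, below any valid biased score. A minimal host-side sketch (illustrative only, not code from this file):

```cpp
#include <cmath>
#include <cstdio>
#include <initializer_list>

int main()
{
    // sigmoid(x) lies strictly inside (0, 1) for any finite x, so an
    // invalid slot's bias of -1 forces scoreBias = sigmoid(x) - 1 < 0.
    for (float x : {-10.0f, 0.0f, 10.0f})
    {
        float const scoreSigmoid = 1.0f / (1.0f + std::exp(-x));
        float const scoreBias = scoreSigmoid + (-1.0f);
        std::printf("x=%6.1f sigmoid=%.6f scoreBias=%.6f\n", x, scoreSigmoid, scoreBias);
    }
    return 0;
}
```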
@@ -859,7 +860,6 @@ __global__ void routingIndicesCoopKernel(KernelParams params)
 // inefficient if we have one CTA per token doing a single global atomic.
 
 template <typename KernelParams>
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 __global__ void __launch_bounds__(NumThreads) routingIndicesHistogramKernel(KernelParams params)
 {
     // number of experts is bounded by number of threads
@@ -872,12 +872,14 @@ __global__ void __launch_bounds__(NumThreads) routingIndicesHistogramKernel(Kern
     smemExpertCount[threadIdx.x] = 0;
     __syncthreads();
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Wait on primary grid and trigger secondary kernel.
     if constexpr (KernelParams::UsePdl)
     {
         cudaGridDependencySynchronize();
         cudaTriggerProgrammaticLaunchCompletion();
     }
+#endif
 
     int32_t const expandedIdxSize = params.mNumTokens * params.mTopK;
     int32_t const localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
@@ -932,17 +934,10 @@ __global__ void __launch_bounds__(NumThreads) routingIndicesHistogramKernel(Kern
     int32_t const localExpertCount = smemExpertCount[threadIdx.x];
     atomicAdd(&params.mPtrExpertCounts[threadIdx.x], localExpertCount);
 }
-#else
-__global__ void routingIndicesHistogramKernel(KernelParams params)
-{
-    assert(false && "routingIndicesHistogramKernel is only supported on SM90+ architectures");
-}
-#endif
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template <typename KernelParams>
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 __global__ void __launch_bounds__(NumThreads) routingIndicesOffsetsKernel(KernelParams params)
 {
     // number of experts is bounded by number of threads
@@ -960,11 +955,13 @@ __global__ void __launch_bounds__(NumThreads) routingIndicesOffsetsKernel(Kernel
     int32_t const expandedIdxSize = params.mNumTokens * params.mTopK;
     int32_t const numTiles = (expandedIdxSize + MaxExpandedIdxPerBlock - 1) / (MaxExpandedIdxPerBlock);
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Wait on primary grid.
     if constexpr (KernelParams::UsePdl)
     {
         cudaGridDependencySynchronize();
     }
+#endif
 
     // The expert offsets are common to all tiles of all blocks.
     // Load the histogram, scan it and write offsets to shared memory.
@@ -1163,17 +1160,13 @@ __global__ void __launch_bounds__(NumThreads) routingIndicesOffsetsKernel(Kernel
     // Trigger secondary kernel.
     // Note: this does not guarantee the visibility of prior writes unless the consumer executes a
     // dependency sync.
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     if constexpr (KernelParams::UsePdl)
     {
         cudaTriggerProgrammaticLaunchCompletion();
     }
-}
-#else
-__global__ void routingIndicesOffsetsKernel(KernelParams params)
-{
-    assert(false && "routingIndicesOffsetsKernel is only supported on SM90+ architectures");
-}
 #endif
+}
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
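The change repeated throughout this file follows a single pattern: previously, each SM90-only kernel definition was wrapped whole in `#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))`, with an `#else` stub that merely asserted at runtime; now only the programmatic dependent launch (PDL) intrinsics, which require SM90+, are guarded, so the full kernel body compiles for every architecture. A condensed sketch of the resulting shape (the kernel name and `NumThreads` value here are illustrative, not from the file):

```cpp
#include <cuda_runtime.h>

constexpr int NumThreads = 256; // illustrative

template <typename KernelParams>
__global__ void __launch_bounds__(NumThreads) someRoutingKernel(KernelParams params)
{
    // ... architecture-independent work: histograms, scans, index writes ...

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // The PDL intrinsics are the only SM90+-specific code in these kernels.
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();           // wait on the primary grid
        cudaTriggerProgrammaticLaunchCompletion(); // let the secondary kernel start
    }
#endif

    // ... remaining work, identical on all architectures ...
}
```

This removes the assert-only fallback definitions and, presumably, lets the new unit tests compile and link these kernels regardless of the target architecture list.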

@@ -1577,7 +1570,6 @@ __host__ __device__ constexpr void setBits(int32_t& value, int32_t newBits, int
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template <typename KernelParams>
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParams params)
 {
     // types used in this kernel
@@ -1614,11 +1606,13 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam
     }
     __syncwarp();
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // then wait on primary grid
     if constexpr (KernelParams::UsePdl)
    {
         cudaGridDependencySynchronize();
     }
+#endif
 
     if (params.mPtrScores != nullptr)
     {
@@ -1744,12 +1738,14 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam
         params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
     }
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 #if !defined(PDL_PROFILE) || PDL_PROFILE == 0
     // we can trigger the next kernel at this point
     if constexpr (KernelParams::UsePdl)
     {
         cudaTriggerProgrammaticLaunchCompletion();
     }
+#endif
 #endif
 
     // at this point, all values for offsets are ready, except the final offsets
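For context, `cudaGridDependencySynchronize()` and `cudaTriggerProgrammaticLaunchCompletion()` only take effect when the dependent kernel is launched with the programmatic stream serialization attribute. The launch code is outside this diff, so the following is a generic sketch of that host-side wiring (CUDA 11.8+), not this repository's implementation:

```cpp
#include <cuda_runtime.h>

template <typename KernelParams>
void launchWithPdl(void (*kernel)(KernelParams), dim3 grid, dim3 block,
    cudaStream_t stream, KernelParams const& params)
{
    cudaLaunchConfig_t config{};
    config.gridDim = grid;
    config.blockDim = block;
    config.stream = stream;

    cudaLaunchAttribute attr{};
    attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attr.val.programmaticStreamSerializationAllowed = 1; // opt in to PDL
    config.attrs = &attr;
    config.numAttrs = 1;

    // The kernel may begin before the preceding kernel on the stream has
    // finished; it must call cudaGridDependencySynchronize() before reading
    // that kernel's outputs.
    cudaLaunchKernelEx(&config, kernel, params);
}
```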
@@ -1806,13 +1802,6 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam
         }
     }
 }
-#else
-__global__ void routingIndicesWarpKernel(KernelParams params)
-{
-    assert(false && "routingIndicesWarpKernel is only supported on SM90+ architectures");
-}
-#endif
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template <typename KernelParams>
@@ -2076,7 +2065,6 @@ __global__ void routingIndicesClusterKernel(KernelParams params)
 
 // this kernel is needed in case we have scores as input for the histogram kernel
 template <typename KernelParams>
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresKernel(KernelParams params)
 {
     using TypeExpW = typename KernelParams::TypeExpW;
@@ -2094,12 +2082,14 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresK
     auto block = cg::this_thread_block();
     auto warp = cg::tiled_partition<WarpSize>(block);
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Wait on primary grid and trigger secondary kernel.
     if constexpr (KernelParams::UsePdl)
     {
         cudaGridDependencySynchronize();
         cudaTriggerProgrammaticLaunchCompletion();
     }
+#endif
 
     // in this case, each warp represents a token, and we use a grid-stride loop
     // over all warps/tokens
@@ -2132,12 +2122,6 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresK
         }
     }
 }
-#else
-__global__ void routingIndicesHistogramScoresKernel(KernelParams params)
-{
-    assert(false && "routingIndicesHistogramScoresKernel is only supported on SM90+ architectures");
-}
-#endif
 
 // Two-step approach (if number of tokens exceed limits of what cluster / cooperative launch
 // variants can handle): in order to minimize the amount of data to exchange through global memory,
@@ -2148,7 +2132,6 @@ __global__ void routingIndicesHistogramScoresKernel(KernelParams params)
 // Note: the histogram calculation could also be fused with routingMainKernel, but this might be
 // inefficient if we have one CTA per token doing a single global atomic.
 template <typename KernelParams>
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramKernel(KernelParams params)
 {
     using TypeExpW = typename KernelParams::TypeExpW;
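The two-step scheme described in the comment above can be pictured as two small kernels: one builds a per-expert histogram with global atomics, the next scans it into offsets. The sketch below is a deliberately simplified illustration of that idea (all names invented), not the production kernels, which fuse far more work into each step:

```cpp
// Step 1: per-expert histogram of the selected (token, expert) pairs.
__global__ void histogramStep(int const* expertIdx, int numSelected, int* expertCounts)
{
    int const i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < numSelected)
    {
        atomicAdd(&expertCounts[expertIdx[i]], 1);
    }
}

// Step 2: exclusive scan of the histogram into per-expert offsets.
// A single thread suffices here because the number of experts is small.
__global__ void offsetsStep(int const* expertCounts, int numExperts, int* expertOffsets)
{
    if (blockIdx.x == 0 && threadIdx.x == 0)
    {
        int running = 0;
        for (int e = 0; e < numExperts; ++e)
        {
            expertOffsets[e] = running;
            running += expertCounts[e];
        }
    }
}
```

Only the small expertCounts array crosses global memory between the two steps, which is the point of the comment about minimizing data exchange.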
@@ -2166,12 +2149,14 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramKernel(
     }
     __syncthreads();
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Wait on primary grid and trigger secondary kernel.
     if constexpr (KernelParams::UsePdl)
     {
         cudaGridDependencySynchronize();
         cudaTriggerProgrammaticLaunchCompletion();
     }
+#endif
 
     uint32_t const expandedIdxSize = params.mNumTokens * NumTopExperts;
     uint32_t const localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
@@ -2234,17 +2219,10 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramKernel(
         atomicAdd(&params.mPtrExpertCounts[threadIdx.x], localExpertCount);
     }
 }
-#else
-__global__ void routingIndicesHistogramKernel(KernelParams params)
-{
-    assert(false && "routingIndicesHistogramKernel is only supported on SM90+ architectures");
-}
-#endif
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template <typename KernelParams>
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(KernelParams params)
 {
     using TypeExpW = typename KernelParams::TypeExpW;
@@ -2264,11 +2242,13 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke
     uint32_t const expandedIdxSize = params.mNumTokens * NumTopExperts;
     uint32_t const numTiles = (expandedIdxSize + MaxExpandedIdxPerBlock - 1) / (MaxExpandedIdxPerBlock);
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Wait on primary grid.
     if constexpr (KernelParams::UsePdl)
     {
         cudaGridDependencySynchronize();
     }
+#endif
 
     // The expert offsets are common to all tiles of all blocks.
     // Load the histogram, scan it and write offsets to shared memory.
@@ -2484,6 +2464,7 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke
         }
     }
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Trigger secondary kernel.
     // Note: this does not guarantee the visibility of prior writes unless the consumer executes a
     // dependency sync.
@@ -2493,13 +2474,8 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke
         cudaTriggerProgrammaticLaunchCompletion();
     }
 #endif
-}
-#else
-__global__ void routingIndicesOffsetsKernel(KernelParams params)
-{
-    assert(false && "routingIndicesOffsetsKernel is only supported on SM90+ architectures");
-}
 #endif
+}
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -2599,7 +2575,7 @@ void run(Data const& data, void* stream)
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-namespace routingQwen3
+namespace routingRenormalize
 {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -3230,13 +3206,13 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresK
     auto block = cg::this_thread_block();
     auto warp = cg::tiled_partition<WarpSize>(block);
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Wait on primary grid.
     if constexpr (KernelParams::UsePdl)
     {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
         cudaGridDependencySynchronize();
-#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     }
+#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 
     // initialize the mPtrPermutedIdxToTokenIdx
     int32_t globalThreadIdx = globalWarpIdx * WarpSize + laneIdx;
@@ -3261,13 +3237,13 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresK
         }
     }
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Trigger secondary kernel.
     if constexpr (KernelParams::UsePdl)
     {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
         cudaTriggerProgrammaticLaunchCompletion();
-#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     }
+#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 
     // in this case, each warp represents a token, and we use a grid-stride loop
     // over all warps/tokens
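The "one warp per token, grid-stride loop" pattern that these comments refer to has this general shape (a standalone illustration with invented parameter names, not an excerpt of the kernel):

```cpp
__global__ void warpPerTokenSketch(float const* scores, int numTokens, int numExperts)
{
    constexpr int WarpSize = 32;
    int const warpsPerBlock = blockDim.x / WarpSize;
    int const globalWarpIdx = blockIdx.x * warpsPerBlock + threadIdx.x / WarpSize;
    int const laneIdx = threadIdx.x % WarpSize;
    int const numWarps = gridDim.x * warpsPerBlock;

    // Each warp owns tokens globalWarpIdx, globalWarpIdx + numWarps, ...
    for (int token = globalWarpIdx; token < numTokens; token += numWarps)
    {
        // Lanes stride across the experts of this token.
        for (int expert = laneIdx; expert < numExperts; expert += WarpSize)
        {
            float const score = scores[token * numExperts + expert];
            (void) score; // per-expert work (e.g. top-k selection) goes here
        }
    }
}
```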
@@ -3360,14 +3336,14 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramKernel(
     }
     __syncthreads();
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Wait on primary grid and trigger secondary kernel.
     if constexpr (KernelParams::UsePdl)
     {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
         cudaGridDependencySynchronize();
         cudaTriggerProgrammaticLaunchCompletion();
-#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     }
+#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 
     uint32_t const expandedIdxSize = params.mNumTokens * NumTopExperts;
     uint32_t const localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
@@ -3454,13 +3430,13 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke
     uint32_t const expandedIdxSize = params.mNumTokens * NumTopExperts;
     uint32_t const numTiles = (expandedIdxSize + MaxExpandedIdxPerBlock - 1) / (MaxExpandedIdxPerBlock);
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Wait on primary grid.
     if constexpr (KernelParams::UsePdl)
    {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
         cudaGridDependencySynchronize();
-#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     }
+#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 
     // The expert offsets are common to all tiles of all blocks.
     // Load the histogram, scan it and write offsets to shared memory.
@@ -3676,17 +3652,17 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke
         }
     }
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Trigger secondary kernel.
     // Note: this does not guarantee the visibility of prior writes unless the consumer executes a
     // dependency sync.
 #if !defined(PDL_PROFILE) || PDL_PROFILE == 0
     if constexpr (KernelParams::UsePdl)
     {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
         cudaTriggerProgrammaticLaunchCompletion();
-#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     }
 #endif
+#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -3756,7 +3732,7 @@ void run(Data const& data, void* stream)
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-} // namespace routingQwen3
+} // namespace routingRenormalize
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h
(2 additions, 2 deletions)

@@ -307,7 +307,7 @@ void run(Data const& data, void* stream);
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-namespace routingQwen3
+namespace routingRenormalize
 {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -430,7 +430,7 @@ struct KernelParams
 
 void run(Data const& data, void* stream);
 
-} // namespace routingQwen3
+} // namespace routingRenormalize
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu
(2 additions, 2 deletions)

@@ -154,7 +154,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
     else if (routingMethodType == RoutingMethodType::Renormalize /* default */
         || routingMethodType == RoutingMethodType::RenormalizeNaive /* Softmax -> TopK */)
     {
-        moe::dev::routingQwen3::Data routingData;
+        moe::dev::routingRenormalize::Data routingData;
 
         //
         // Config
@@ -196,7 +196,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
         routingData.mLocalExpertsStrideLog2 = 0;
         routingData.mNumLocalExperts = localNumExperts;
 
-        moe::dev::routingQwen3::run(routingData, stream);
+        moe::dev::routingRenormalize::run(routingData, stream);
     }
     else
     {
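The rename from routingQwen3 to routingRenormalize reflects that this path serves both renormalization-style methods, not one model family. The inline comments distinguish them: RenormalizeNaive applies Softmax over all experts and then takes the top-k, while Renormalize (the default) takes the top-k first and renormalizes only the selected scores. For a single token on the host, the two orderings look roughly like this (an illustrative sketch, not the kernels' fused, batched implementation):

```cpp
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

std::vector<int> topK(std::vector<float> const& scores, int k)
{
    std::vector<int> idx(scores.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
        [&](int a, int b) { return scores[a] > scores[b]; });
    idx.resize(k);
    return idx;
}

std::vector<float> softmax(std::vector<float> const& s)
{
    float const m = *std::max_element(s.begin(), s.end());
    std::vector<float> p(s.size());
    float sum = 0.0f;
    for (size_t i = 0; i < s.size(); ++i)
    {
        p[i] = std::exp(s[i] - m);
        sum += p[i];
    }
    for (float& v : p)
    {
        v /= sum;
    }
    return p;
}

void routeToken(std::vector<float> const& logits, int k)
{
    // RenormalizeNaive: Softmax -> TopK
    auto naiveExperts = topK(softmax(logits), k);

    // Renormalize: TopK -> softmax over the k selected logits
    auto experts = topK(logits, k);
    std::vector<float> selected;
    for (int e : experts)
    {
        selected.push_back(logits[e]);
    }
    auto weights = softmax(selected);
    (void) naiveExperts;
    (void) weights;
}
```

Since softmax is monotonic, both orderings pick the same expert set; they can differ in how the selected weights are normalized and in cost.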

cpp/tests/unit_tests/kernels/CMakeLists.txt
(6 additions, 0 deletions)

@@ -76,3 +76,9 @@ set(SAMPLING_KERNEL_TEST_SRC
     sampling/samplingPenaltyTest.cpp sampling/samplingUtilsTest.cu)
 
 add_gtest(samplingKernelsTest "${SAMPLING_KERNEL_TEST_SRC}")
+
+set(ROUTING_KERNEL_TEST_SRC
+    routing/routingTest.cpp routing/routingLlama4Test.cpp
+    routing/routingRenormalizeTest.cpp routing/routingDeepSeekTest.cpp)
+
+add_gtest(routingKernelsTest "${ROUTING_KERNEL_TEST_SRC}")
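The test sources themselves are not expanded in this view of the commit, so their exact contents are unknown; a routing-kernel test registered this way would typically follow the usual gtest shape, roughly like the hypothetical sketch below (test and step names invented):

```cpp
#include <gtest/gtest.h>

TEST(RoutingKernelTest, RenormalizeMatchesHostReference)
{
    // 1. Fill a small [numTokens x numExperts] score tensor with known values.
    // 2. Populate moe::dev::routingRenormalize::Data and call run(data, stream).
    // 3. Recompute the expected expert indices and offsets on the host.
    // 4. EXPECT_EQ the device results against the host reference.
    SUCCEED(); // placeholder; see routing/routingRenormalizeTest.cpp in the commit
}
```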
