@@ -22,10 +22,13 @@ namespace routingDeepSeek {

////////////////////////////////////////////////////////////////////////////////////////////////////

+static constexpr int NumThreads = 384;
+static constexpr int NumWarps = NumThreads / WarpSize;
static constexpr int NumTopGroupScores = 2;
static constexpr int MaxNumTopExperts = 8;
+static constexpr int MaxNumTopGroupsDefault = 16;

-__host__ __device__ int constexpr getMaxNumTopGroups(const bool useGroups, const int numExperts) {
+__host__ __device__ int getMaxNumTopGroups(const bool useGroups, const int numExperts) {
  if (useGroups || numExperts <= 256) {
    return 4;
  } else {
@@ -38,10 +41,8 @@ __global__ void routingMainKernel(KernelParams params) {
  // declare types
  using OutputT = typename KernelParams::OutputT;
  using InputT = typename KernelParams::InputT;
-  static constexpr int NumThreads = KernelParams::NumExperts;  // DeepSeek uses 1 thread per expert
  static constexpr int NumWarps = NumThreads / WarpSize;
-  constexpr int MaxNumTopGroups =
-      getMaxNumTopGroups(KernelParams::UseGroups, KernelParams::NumExperts);
+  int MaxNumTopGroups = getMaxNumTopGroups(KernelParams::UseGroups, params.mNumExperts);

  // declare shared memory structure
  // number of experts is bounded by number of threads
@@ -71,19 +72,19 @@ __global__ void routingMainKernel(KernelParams params) {

  // load bias already; each warp represents one expert group
  auto threadExpert = threadIdx.x;
-  bool expertSelected = threadExpert < KernelParams::NumExperts;
+  bool expertSelected = threadExpert < params.mNumExperts;
  if constexpr (KernelParams::UseGroups) {
    threadExpert = warpIdx * params.mNumExpertsPerGroup + laneIdx;
    expertSelected = laneIdx < params.mNumExpertsPerGroup;
  }
-  auto scoreIdx = int64_t{blockIdx.x} * int64_t{KernelParams::NumExperts} + threadExpert;
+  auto scoreIdx = int64_t{blockIdx.x} * int64_t{params.mNumExperts} + threadExpert;
  auto biasVal = expertSelected ? params.mPtrRoutingBias[threadExpert] : invalidScore;

  // initialize the mPtrExpertCounts
  if (params.mPtrExpertCounts) {
    int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x;
    int32_t globalThreadStride = gridDim.x * NumThreads;
-    int32_t expertCountsNum = 2 * KernelParams::NumExperts;
+    int32_t expertCountsNum = 2 * params.mNumExperts;
    initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0);
  }

@@ -118,10 +119,10 @@ __global__ void routingMainKernel(KernelParams params) {
  // registers for top group score reduction
  float topExpGroupScores[NumTopGroupScores];
  [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores];
-  float topGroups[MaxNumTopGroups];  // bound of params.mNumLimitedGroups
-  int32_t topGroupIdx[MaxNumTopGroups];
-  float expertScoreGroup[MaxNumTopGroups];
-  int32_t expertIdxGroup[MaxNumTopGroups];
+  float topGroups[MaxNumTopGroupsDefault];  // bound of params.mNumLimitedGroups
+  int32_t topGroupIdx[MaxNumTopGroupsDefault];
+  float expertScoreGroup[MaxNumTopGroupsDefault];
+  int32_t expertIdxGroup[MaxNumTopGroupsDefault];
  float topScores[MaxNumTopExperts];  // bound of params.mTopK
  int32_t topExperts[MaxNumTopExperts];

@@ -168,15 +169,13 @@ __global__ void routingMainKernel(KernelParams params) {
  } else {
    // without groups, each thread just takes `MaxNumTopGroups` experts

-    #pragma unroll
    for (int ii = 0; ii < MaxNumTopGroups; ++ii) {
      auto expertIdx = ii * WarpSize + laneIdx;
      expertIdxGroup[ii] = expertIdx;
      expertScoreGroup[ii] =
-          expertIdx < KernelParams::NumExperts ? smemScoreBias[expertIdx] : invalidScoreFloat;
+          expertIdx < params.mNumExperts ? smemScoreBias[expertIdx] : invalidScoreFloat;
    }
  }
-
  topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup,
                   /* minValue */ invalidScoreFloat, params.mTopK);

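Note on the two hunks above: MaxNumTopGroups is now computed at run time, so the per-thread arrays are sized by the fixed compile-time bound MaxNumTopGroupsDefault (16) and the #pragma unroll on the loop over MaxNumTopGroups is dropped, since a loop with a runtime trip count cannot be fully unrolled. The sketch below is illustrative only and is not part of the commit; exampleKernel, scores, and the kMax* constants are hypothetical names, and a 32-lane warp is assumed.

// Illustrative sketch: CUDA local (register) arrays need a compile-time size,
// so the array is sized by an upper bound while the loop bound stays runtime.
#include <cfloat>

constexpr int kMaxTopGroups = 16;  // analogue of MaxNumTopGroupsDefault
constexpr int kWarpSize = 32;

__global__ void exampleKernel(const float* scores, int numExperts, int numTopGroups) {
  // Caller guarantees numTopGroups <= kMaxTopGroups (analogue of getMaxNumTopGroups()).
  float expertScoreGroup[kMaxTopGroups];  // sized by the compile-time bound
  int laneIdx = threadIdx.x % kWarpSize;
  // Runtime trip count: the compiler cannot fully unroll this loop, which is
  // presumably why the `#pragma unroll` was removed in the hunk above.
  for (int ii = 0; ii < numTopGroups; ++ii) {
    int expertIdx = ii * kWarpSize + laneIdx;
    expertScoreGroup[ii] = expertIdx < numExperts ? scores[expertIdx] : -FLT_MAX;
  }
  // ... a top-k reduction over expertScoreGroup would follow here ...
}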
@@ -214,10 +213,8 @@ __global__ void routingMainKernel(KernelParams params) {
template <typename KernelParams>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1)
-    __launch_bounds__(KernelParams::NumExperts) routingIndicesClusterKernel(KernelParams params) {
+    routingIndicesClusterKernel(KernelParams params) {
  using OutputT = typename KernelParams::OutputT;
-  static constexpr int NumThreads = KernelParams::NumExperts;  // DeepSeek uses 1 thread per expert
-  static constexpr int NumWarps = NumThreads / WarpSize;

  int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
  int32_t const clusterBlockRank = blockIdx.x;
@@ -241,9 +238,7 @@ __global__ void routingIndicesClusterKernel(KernelParams params) {

template <typename KernelParams>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-__global__ void __launch_bounds__(KernelParams::NumExperts)
-    routingIndicesCoopKernel(KernelParams params) {
-  static constexpr int NumThreads = KernelParams::NumExperts;  // DeepSeek uses 1 thread per expert
+__global__ void routingIndicesCoopKernel(KernelParams params) {
  static constexpr int NumWarps = NumThreads / WarpSize;
  // number of experts is bounded by number of threads
  __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads];
@@ -336,16 +331,15 @@ __global__ void __launch_bounds__(KernelParams::NumExperts)
  int32_t const localExpertCount = smemExpertCount[threadIdx.x];

  int32_t blockExpertOffset = 0;
-  if (threadIdx.x < KernelParams::NumExperts) {
+  if (threadIdx.x < params.mNumExperts) {
    blockExpertOffset = atomicAdd(&params.mPtrExpertCounts[threadIdx.x], localExpertCount);
  }

  // Sync to wait for completion of the histogram reduction.
  grid.sync();

  // Get total count for this expert.
-  int32_t count =
-      (threadIdx.x < KernelParams::NumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0;
+  int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0;

  // Note: the scan is redundant in all CTAs, but doing it in only 1 CTA would be worse for latency.
@@ -422,16 +416,13 @@ __global__ void routingIndicesCoopKernel(KernelParams params) {
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template <int NumExperts>
void runImpl(Data& data, void* stream) {
-  static constexpr int NumThreads = NumExperts;  // DeepSeek: 1 thread per expert
  static constexpr int NumWarps = NumThreads / WarpSize;
-  const int MaxNumTopGroups = getMaxNumTopGroups(data.mNumExpertGroups > 1, NumExperts);
+  int MaxNumTopGroups = getMaxNumTopGroups(data.mNumExpertGroups > 1, data.mNumExperts);

  // Validate that the template parameter matches the data
-  TORCH_CHECK(data.mNumExperts == NumExperts, "DeepSeek routing kernel expects exactly ",
-              NumExperts, " experts, got ", data.mNumExperts);
+  // TORCH_CHECK(data.mNumExperts == NumExperts, "DeepSeek routing kernel expects exactly ",
+  //             NumExperts, " experts, got ", data.mNumExperts);
  TORCH_CHECK(data.mPtrExpertIdx != nullptr || data.mPtrPermutedIdxSize != nullptr ||
                  data.mPtrExpertWeights != nullptr,
              "Routing kernel requires at least one output parameter");
@@ -503,22 +494,21 @@ void runImpl(Data& data, void* stream) {
  LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
                                 /*coopLaunch=*/false, routingMainKernel, numBlocks, NumThreads,
                                 /*smemSize=*/0,  // No dynamic smem
-                                 stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true,
-                                 NumExperts);
+                                 stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);

  if (data.mPtrPermutedIdxSize != nullptr) {
    if (useSingleCluster) {
-      LAUNCH_ROUTING_WITH_EXTRA_FLAG(
-          data,
-          /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads,
-          /*smemSize=*/0,  // No dynamic smem
-          stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true, NumExperts);
+      LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
+                                     /*coopLaunch=*/false, routingIndicesClusterKernel,
+                                     NumBlocksPerCluster, NumThreads,
+                                     /*smemSize=*/0,  // No dynamic smem
+                                     stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
    } else if (data.mNumTokens <= maxTokensCoop) {
-      LAUNCH_ROUTING_WITH_EXTRA_FLAG(
-          data,
-          /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, NumThreads,
-          /*smemSize=*/0,  // No dynamic smem
-          stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true, NumExperts);
+      LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
+                                     /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop,
+                                     NumThreads,
+                                     /*smemSize=*/0,  // No dynamic smem
+                                     stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
    } else {
      const int32_t expandedIdxSize = data.mNumTokens * data.mTopK;

@@ -533,31 +523,21 @@ void runImpl(Data& data, void* stream) {
      int const numBlocksOffsets =
          std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);

-      LAUNCH_ROUTING_WITH_EXTRA_FLAG(
-          data,
-          /*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, NumThreads,
-          /*smemSize=*/0,  // No dynamic smem
-          stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true, NumExperts);
-      LAUNCH_ROUTING_WITH_EXTRA_FLAG(
-          data,
-          /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, NumThreads,
-          /*smemSize=*/0,  // No dynamic smem
-          stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true, NumExperts);
+      LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
+                                     /*coopLaunch=*/false, routingIndicesHistogramKernel,
+                                     numBlocksHistogram, NumThreads,
+                                     /*smemSize=*/0,  // No dynamic smem
+                                     stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
+      LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
+                                     /*coopLaunch=*/false, routingIndicesOffsetsKernel,
+                                     numBlocksOffsets, NumThreads,
+                                     /*smemSize=*/0,  // No dynamic smem
+                                     stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
    }
  }
}

-void run(Data& data, void* stream) {
-  if (data.mNumExperts == 72) {
-    runImpl<72>(data, stream);
-  } else if (data.mNumExperts == 256) {
-    runImpl<256>(data, stream);
-  } else if (data.mNumExperts == 384) {
-    runImpl<384>(data, stream);
-  } else {
-    TORCH_CHECK(false, "Unsupported number of experts: ", data.mNumExperts);
-  }
-}
+void run(Data& data, void* stream) { runImpl(data, stream); }

////////////////////////////////////////////////////////////////////////////////////////////////////
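Net effect visible in this diff: the routing kernels no longer specialize on the expert count at compile time. NumThreads becomes a fixed 384 (with NumWarps derived from it), the expert count is read from params.mNumExperts / data.mNumExperts at run time, the per-count __launch_bounds__ annotations are removed, and the old run() that dispatched to runImpl<72>/<256>/<384> collapses to a single runtime call. The before/after sketch below only illustrates that dispatch change; it is not the committed code (Data is reduced to the one field used here, and a plain exception stands in for TORCH_CHECK).

// Illustrative sketch of the dispatch change, not the committed code.
#include <stdexcept>

struct Data { int mNumExperts; };

// Before: one template instantiation per supported expert count.
template <int NumExperts>
void runImplTemplated(Data& /*data*/, void* /*stream*/) { /* launches sized by NumExperts */ }

void runOld(Data& data, void* stream) {
  if (data.mNumExperts == 72) runImplTemplated<72>(data, stream);
  else if (data.mNumExperts == 256) runImplTemplated<256>(data, stream);
  else if (data.mNumExperts == 384) runImplTemplated<384>(data, stream);
  else throw std::runtime_error("Unsupported number of experts");
}

// After: a single implementation; the expert count is a runtime parameter and
// the block size is a fixed upper bound (384 threads in the diff above).
void runImplRuntime(Data& /*data*/, void* /*stream*/) { /* launches read data.mNumExperts */ }
void runNew(Data& data, void* stream) { runImplRuntime(data, stream); }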