@@ -327,6 +327,7 @@ __global__ void routingMainKernel(KernelParams params)
     // note that with invalid values, because sigmoid is < 1 and bias is -1,
     // we must get a negative value, which is smaller than any valid value
     auto scoreBias = float{scoreSigmoid + float{biasVal}};
+
     if (expertSelected)
     {
         smemScoreBias[threadExpert] = scoreBias;
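For reference, the invariant the comment in this hunk relies on is simple arithmetic: the sigmoid output lies in (0, 1), so pairing an invalid entry with a bias of -1 always yields a negative score, which the comment states is below any valid score. A minimal sketch of that sentinel arithmetic (the helper name and the `validBias` parameter are illustrative, not taken from the diff):

```cuda
// Sketch only: sigmoid(x) is in (0, 1); a bias of -1.f marks an invalid entry,
// so the combined score is negative and loses against any valid score.
__device__ float scoreWithSentinelBias(float logit, bool valid, float validBias)
{
    float const scoreSigmoid = 1.f / (1.f + expf(-logit)); // in (0, 1)
    float const biasVal = valid ? validBias : -1.f;        // -1.f == invalid sentinel
    return scoreSigmoid + biasVal;
}
```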
@@ -859,7 +860,6 @@ __global__ void routingIndicesCoopKernel(KernelParams params)
 // inefficient if we have one CTA per token doing a single global atomic.

 template <typename KernelParams>
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 __global__ void __launch_bounds__(NumThreads) routingIndicesHistogramKernel(KernelParams params)
 {
     // number of experts is bounded by number of threads
@@ -872,12 +872,14 @@ __global__ void __launch_bounds__(NumThreads) routingIndicesHistogramKernel(Kern
     smemExpertCount[threadIdx.x] = 0;
     __syncthreads();

+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Wait on primary grid and trigger secondary kernel.
     if constexpr (KernelParams::UsePdl)
     {
         cudaGridDependencySynchronize();
         cudaTriggerProgrammaticLaunchCompletion();
     }
+#endif

     int32_t const expandedIdxSize = params.mNumTokens * params.mTopK;
     int32_t const localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
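The pattern in this hunk, repeated throughout the diff, is to compile the kernel body on every architecture and restrict only the programmatic-dependent-launch (PDL) intrinsics to SM90+, instead of guarding the whole kernel and keeping an assert-only stub. A condensed sketch of the resulting shape (the kernel name is illustrative; `NumThreads`, `KernelParams::UsePdl`, and the intrinsics come from the surrounding code):

```cuda
// Condensed sketch of the guard placement after this change: only the PDL
// intrinsics are SM90-specific, so only they sit behind __CUDA_ARCH__; the
// rest of the kernel builds everywhere.
template <typename KernelParams>
__global__ void __launch_bounds__(NumThreads) exampleHistogramShapedKernel(KernelParams params)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();           // wait on the primary grid
        cudaTriggerProgrammaticLaunchCompletion(); // allow the secondary kernel to start
    }
#endif
    // ... architecture-independent histogram work ...
}
```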
@@ -932,17 +934,10 @@ __global__ void __launch_bounds__(NumThreads) routingIndicesHistogramKernel(Kern
     int32_t const localExpertCount = smemExpertCount[threadIdx.x];
     atomicAdd(&params.mPtrExpertCounts[threadIdx.x], localExpertCount);
 }
-#else
-__global__ void routingIndicesHistogramKernel(KernelParams params)
-{
-    assert(false && "routingIndicesHistogramKernel is only supported on SM90+ architectures");
-}
-#endif

 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template <typename KernelParams>
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 __global__ void __launch_bounds__(NumThreads) routingIndicesOffsetsKernel(KernelParams params)
 {
     // number of experts is bounded by number of threads
@@ -960,11 +955,13 @@ __global__ void __launch_bounds__(NumThreads) routingIndicesOffsetsKernel(Kernel
     int32_t const expandedIdxSize = params.mNumTokens * params.mTopK;
     int32_t const numTiles = (expandedIdxSize + MaxExpandedIdxPerBlock - 1) / (MaxExpandedIdxPerBlock);

+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Wait on primary grid.
     if constexpr (KernelParams::UsePdl)
     {
         cudaGridDependencySynchronize();
     }
+#endif

     // The expert offsets are common to all tiles of all blocks.
     // Load the histogram, scan it and write offsets to shared memory.
@@ -1163,17 +1160,13 @@ __global__ void __launch_bounds__(NumThreads) routingIndicesOffsetsKernel(Kernel
     // Trigger secondary kernel.
     // Note: this does not guarantee the visibility of prior writes unless the consumer executes a
     // dependency sync.
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     if constexpr (KernelParams::UsePdl)
     {
         cudaTriggerProgrammaticLaunchCompletion();
     }
-}
-#else
-__global__ void routingIndicesOffsetsKernel(KernelParams params)
-{
-    assert(false && "routingIndicesOffsetsKernel is only supported on SM90+ architectures");
-}
 #endif
+}

 ////////////////////////////////////////////////////////////////////////////////////////////////////

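The visibility note in the hunk above is worth spelling out: `cudaTriggerProgrammaticLaunchCompletion()` only lets the dependent grid start early; ordering of the producer's memory writes is established when the consumer calls `cudaGridDependencySynchronize()`. A sketch of the consumer side under that assumption (hypothetical kernel; the guard mirrors the ones introduced in this diff):

```cuda
// Sketch of the consumer-side obligation mentioned in the note above: a kernel
// launched via PDL must execute a dependency sync before reading anything the
// primary grid wrote, e.g. the expert counts/offsets produced here.
template <typename KernelParams>
__global__ void exampleConsumerKernel(KernelParams params)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize(); // makes the producer's prior writes visible
    }
#endif
    // ... from here on it is safe to read the routing offsets ...
}
```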
@@ -1577,7 +1570,6 @@ __host__ __device__ constexpr void setBits(int32_t& value, int32_t newBits, int
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template <typename KernelParams>
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParams params)
 {
     // types used in this kernel
@@ -1614,11 +1606,13 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam
     }
     __syncwarp();

+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // then wait on primary grid
     if constexpr (KernelParams::UsePdl)
     {
         cudaGridDependencySynchronize();
     }
+#endif

     if (params.mPtrScores != nullptr)
     {
@@ -1744,12 +1738,14 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam
         params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
     }

+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 #if !defined(PDL_PROFILE) || PDL_PROFILE == 0
     // we can trigger the next kernel at this point
     if constexpr (KernelParams::UsePdl)
     {
         cudaTriggerProgrammaticLaunchCompletion();
     }
+#endif
 #endif

     // at this point, all values for offsets are ready, except the final offsets
@@ -1806,13 +1802,6 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam
         }
     }
 }
-#else
-__global__ void routingIndicesWarpKernel(KernelParams params)
-{
-    assert(false && "routingIndicesWarpKernel is only supported on SM90+ architectures");
-}
-#endif
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template <typename KernelParams>
@@ -2076,7 +2065,6 @@ __global__ void routingIndicesClusterKernel(KernelParams params)

 // this kernel is needed in case we have scores as input for the histogram kernel
 template <typename KernelParams>
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresKernel(KernelParams params)
 {
     using TypeExpW = typename KernelParams::TypeExpW;
@@ -2094,12 +2082,14 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresK
     auto block = cg::this_thread_block();
     auto warp = cg::tiled_partition<WarpSize>(block);

+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Wait on primary grid and trigger secondary kernel.
     if constexpr (KernelParams::UsePdl)
     {
         cudaGridDependencySynchronize();
         cudaTriggerProgrammaticLaunchCompletion();
     }
+#endif

     // in this case, each warp represents a token, and we use a grid-stride loop
     // over all warps/tokens
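The comment above describes the mapping used by the score-input path: one warp per token, with a grid-stride loop so a fixed-size grid covers any token count. A minimal sketch of that loop structure (the kernel and loop body are illustrative; `WarpSize`, `NumThreadsHist`, and `params.mNumTokens` come from the surrounding code):

```cuda
// Minimal sketch of the warp-per-token, grid-stride mapping described above.
template <typename KernelParams>
__global__ void __launch_bounds__(NumThreadsHist) exampleWarpPerTokenKernel(KernelParams params)
{
    int32_t const warpsPerBlock = blockDim.x / WarpSize;
    int32_t const globalWarpIdx = blockIdx.x * warpsPerBlock + threadIdx.x / WarpSize;
    int32_t const numWarpsInGrid = gridDim.x * warpsPerBlock;

    // grid-stride loop: this warp handles tokens globalWarpIdx, globalWarpIdx + numWarpsInGrid, ...
    for (int32_t tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += numWarpsInGrid)
    {
        // the lanes of this warp cooperate on token `tokenIdx`, e.g. each lane
        // processes a strided subset of that token's expert scores
    }
}
```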
@@ -2132,12 +2122,6 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresK
         }
     }
 }
-#else
-__global__ void routingIndicesHistogramScoresKernel(KernelParams params)
-{
-    assert(false && "routingIndicesHistogramScoresKernel is only supported on SM90+ architectures");
-}
-#endif

 // Two-step approach (if number of tokens exceed limits of what cluster / cooperative launch
 // variants can handle): in order to minimize the amount of data to exchange through global memory,
@@ -2148,7 +2132,6 @@ __global__ void routingIndicesHistogramScoresKernel(KernelParams params)
 // Note: the histogram calculation could also be fused with routingMainKernel, but this might be
 // inefficient if we have one CTA per token doing a single global atomic.
 template <typename KernelParams>
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramKernel(KernelParams params)
 {
     using TypeExpW = typename KernelParams::TypeExpW;
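The two-step approach introduced in the comment above boils down to: pass one builds a per-expert histogram with atomics, pass two turns that histogram into offsets via an exclusive scan, so only the small count/offset arrays travel through global memory rather than the full permutation. A deliberately simplified sketch of that data flow (standalone hypothetical kernels with a single-thread scan, not the tiled implementation in this file):

```cuda
// Step 1: per-expert histogram over the expanded (token, expert) assignments.
__global__ void exampleHistogramStep(int32_t const* expertIdx, int32_t expandedIdxSize, int32_t* expertCounts)
{
    for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < expandedIdxSize; i += gridDim.x * blockDim.x)
    {
        atomicAdd(&expertCounts[expertIdx[i]], 1);
    }
}

// Step 2: exclusive scan of the histogram to get per-expert offsets
// (single thread here purely to show the data flow, not for performance).
__global__ void exampleOffsetsStep(int32_t const* expertCounts, int32_t numExperts, int32_t* expertOffsets)
{
    if (blockIdx.x == 0 && threadIdx.x == 0)
    {
        int32_t acc = 0;
        for (int32_t e = 0; e < numExperts; ++e)
        {
            expertOffsets[e] = acc;
            acc += expertCounts[e];
        }
    }
}
```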
@@ -2166,12 +2149,14 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramKernel(
     }
     __syncthreads();

+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Wait on primary grid and trigger secondary kernel.
     if constexpr (KernelParams::UsePdl)
     {
         cudaGridDependencySynchronize();
         cudaTriggerProgrammaticLaunchCompletion();
     }
+#endif

     uint32_t const expandedIdxSize = params.mNumTokens * NumTopExperts;
     uint32_t const localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
@@ -2234,17 +2219,10 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramKernel(
         atomicAdd(&params.mPtrExpertCounts[threadIdx.x], localExpertCount);
     }
 }
-#else
-__global__ void routingIndicesHistogramKernel(KernelParams params)
-{
-    assert(false && "routingIndicesHistogramKernel is only supported on SM90+ architectures");
-}
-#endif

 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template <typename KernelParams>
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(KernelParams params)
 {
     using TypeExpW = typename KernelParams::TypeExpW;
@@ -2264,11 +2242,13 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke
     uint32_t const expandedIdxSize = params.mNumTokens * NumTopExperts;
     uint32_t const numTiles = (expandedIdxSize + MaxExpandedIdxPerBlock - 1) / (MaxExpandedIdxPerBlock);

+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Wait on primary grid.
     if constexpr (KernelParams::UsePdl)
     {
         cudaGridDependencySynchronize();
     }
+#endif

     // The expert offsets are common to all tiles of all blocks.
     // Load the histogram, scan it and write offsets to shared memory.
@@ -2484,6 +2464,7 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke
         }
     }

+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Trigger secondary kernel.
     // Note: this does not guarantee the visibility of prior writes unless the consumer executes a
     // dependency sync.
@@ -2493,13 +2474,8 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke
         cudaTriggerProgrammaticLaunchCompletion();
     }
 #endif
-}
-#else
-__global__ void routingIndicesOffsetsKernel(KernelParams params)
-{
-    assert(false && "routingIndicesOffsetsKernel is only supported on SM90+ architectures");
-}
 #endif
+}

 ////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -2599,7 +2575,7 @@ void run(Data const& data, void* stream)

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-namespace routingQwen3
+namespace routingRenormalize
 {

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -3230,13 +3206,13 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresK
     auto block = cg::this_thread_block();
     auto warp = cg::tiled_partition<WarpSize>(block);

+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Wait on primary grid.
     if constexpr (KernelParams::UsePdl)
     {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
         cudaGridDependencySynchronize();
-#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     }
+#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))

     // initialize the mPtrPermutedIdxToTokenIdx
     int32_t globalThreadIdx = globalWarpIdx * WarpSize + laneIdx;
@@ -3261,13 +3237,13 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresK
         }
     }

+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Trigger secondary kernel.
     if constexpr (KernelParams::UsePdl)
     {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
         cudaTriggerProgrammaticLaunchCompletion();
-#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     }
+#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))

     // in this case, each warp represents a token, and we use a grid-stride loop
     // over all warps/tokens
@@ -3360,14 +3336,14 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramKernel(
     }
     __syncthreads();

+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Wait on primary grid and trigger secondary kernel.
     if constexpr (KernelParams::UsePdl)
     {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
         cudaGridDependencySynchronize();
         cudaTriggerProgrammaticLaunchCompletion();
-#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     }
+#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))

     uint32_t const expandedIdxSize = params.mNumTokens * NumTopExperts;
     uint32_t const localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
@@ -3454,13 +3430,13 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke
     uint32_t const expandedIdxSize = params.mNumTokens * NumTopExperts;
     uint32_t const numTiles = (expandedIdxSize + MaxExpandedIdxPerBlock - 1) / (MaxExpandedIdxPerBlock);

+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Wait on primary grid.
     if constexpr (KernelParams::UsePdl)
     {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
         cudaGridDependencySynchronize();
-#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     }
+#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))

     // The expert offsets are common to all tiles of all blocks.
     // Load the histogram, scan it and write offsets to shared memory.
@@ -3676,17 +3652,17 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke
         }
     }

+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     // Trigger secondary kernel.
     // Note: this does not guarantee the visibility of prior writes unless the consumer executes a
     // dependency sync.
 #if !defined(PDL_PROFILE) || PDL_PROFILE == 0
     if constexpr (KernelParams::UsePdl)
     {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
         cudaTriggerProgrammaticLaunchCompletion();
-#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     }
 #endif
+#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -3756,7 +3732,7 @@ void run(Data const& data, void* stream)

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-} // namespace routingQwen3
+} // namespace routingRenormalize

 ////////////////////////////////////////////////////////////////////////////////////////////////////
