@@ -230,59 +230,62 @@ inline __device__ __host__ T divUp(T m, T n)
230230// Return (block_size, cluster_size, loads_per_thread)
231231std::tuple<int , int , int > adjustGridConfig (int numTokens, int dim, int eltsPerThread)
232232{
233- // Start with preferred block_size and cluster_size
234- #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
235- int clusterSize = 8 ;
236- #else
237- int clusterSize = 1 ;
238- #endif
233+ static int SM = tensorrt_llm::common::getSMVersion ();
234+
235+ int clusterSize = SM >= 90 ? 8 : 1 ;
239236 int blockSize = 128 ;
240237 // ========================== Adjust the grid configuration ==========================
241238 int threadsNeeded = divUp (dim, eltsPerThread);
242239 int loadsPerThread = 1 ;
243240
244241 blockSize = divUp (threadsNeeded, clusterSize);
245- #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
246- while (threadsNeeded % clusterSize != 0 && clusterSize > 1 )
247- {
248- clusterSize /= 2 ;
249- }
250- blockSize = divUp (threadsNeeded, clusterSize);
251- while (blockSize < 128 && clusterSize >= 2 )
252- {
253- blockSize *= 2 ;
254- clusterSize /= 2 ;
255- }
256- int smCount = getMultiProcessorCount ();
257- while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512 )
242+ if (clusterSize > 1 )
258243 {
259- blockSize *= 2 ;
260- clusterSize /= 2 ;
244+ while (threadsNeeded % clusterSize != 0 && clusterSize > 1 )
245+ {
246+ clusterSize /= 2 ;
247+ }
248+ blockSize = divUp (threadsNeeded, clusterSize);
249+ while (blockSize < 128 && clusterSize >= 2 )
250+ {
251+ blockSize *= 2 ;
252+ clusterSize /= 2 ;
253+ }
254+ int smCount = getMultiProcessorCount ();
255+ while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512 )
256+ {
257+ blockSize *= 2 ;
258+ clusterSize /= 2 ;
259+ }
261260 }
262- #endif
263261
264262 // Trying to scale up use multiple loads or CGA
265263 while (blockSize > 1024 )
266264 {
267- # if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
268- if (clusterSize < 8 )
265+ // Scale up with CGA if supported
266+ if (SM >= 90 )
269267 {
270- clusterSize = clusterSize << 1 ;
271- }
272- else
273- {
274- break ;
275- }
276- #else
277- if (loadsPerThread < 8 )
278- {
279- loadsPerThread += 1 ;
268+ if (clusterSize < 8 )
269+ {
270+ clusterSize = clusterSize << 1 ;
271+ }
272+ else
273+ {
274+ break ;
275+ }
280276 }
281277 else
282278 {
283- break ;
279+
280+ if (loadsPerThread < 8 )
281+ {
282+ loadsPerThread += 1 ;
283+ }
284+ else
285+ {
286+ break ;
287+ }
284288 }
285- #endif
286289 blockSize = divUp (threadsNeeded, clusterSize * loadsPerThread);
287290 }
288291 return {blockSize, clusterSize, loadsPerThread};
@@ -420,9 +423,9 @@ __global__ void __launch_bounds__(1024) oneshotAllreduceFusionKernel(T* outputPt
420423 }
421424 float blockSum = blockReduceSum<float , true >(threadSum);
422425
423- __shared__ float sharedVal[8 ]; // Temporary variable to share the sum within block
424426 float fullSum = blockSum;
425427#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
428+ __shared__ float sharedVal[8 ]; // Temporary variable to share the sum within block
426429 namespace cg = cooperative_groups;
427430 cg::cluster_group cluster = cg::this_cluster ();
428431 int const numBlocks = cluster.num_blocks ();
@@ -459,45 +462,40 @@ using detail::adjustGridConfig;
459462
460463void oneshotAllreduceFusionOp (AllReduceFusionParams const & params)
461464{
465+
466+ static int const kSMVersion = tensorrt_llm::common::getSMVersion ();
462467 int const numTokens = params.numTokens ;
463468 int const tokenDim = params.tokenDim ;
464469 int const eltsPerThread = sizeof (float4 ) / getDTypeSize (params.dType );
465470
466471 auto [blockSize, clusterSize, loadsPerThread] = adjustGridConfig (numTokens, tokenDim, eltsPerThread);
467472 dim3 grid (numTokens, clusterSize, 1 );
468473
469- TLLM_CHECK_WITH_INFO (blockSize <= 1024 && loadsPerThread == 1 ,
470- " Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)" , tokenDim,
471- #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
472- 1024 * 8 * eltsPerThread);
473- #else
474- 1024 * eltsPerThread);
475- #endif
476-
477474 TLLM_LOG_DEBUG (
478475 " [MNNVL AllReduceOneShot] Dispatch: grid size: (%d, %d, 1), block_size: %d, cluster_size: %d, "
479476 " loads_per_thread: %d, "
480477 " threads_needed: %d" ,
481478 numTokens, clusterSize, blockSize, clusterSize, loadsPerThread, divUp (tokenDim, eltsPerThread));
482479
480+ TLLM_CHECK_WITH_INFO (blockSize <= 1024 && loadsPerThread == 1 ,
481+ " Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)" , tokenDim,
482+ 1024 * (kSMVersion >= 90 ? 8 : 1 ) * eltsPerThread);
483+
483484 cudaLaunchAttribute attrs[2 ];
484485 attrs[0 ].id = cudaLaunchAttributeProgrammaticStreamSerialization;
485486 attrs[0 ].val .programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL () ? 1 : 0 ;
486- #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
487487 attrs[1 ].id = cudaLaunchAttributeClusterDimension;
488488 attrs[1 ].val .clusterDim .x = 1 ;
489489 attrs[1 ].val .clusterDim .y = clusterSize;
490490 attrs[1 ].val .clusterDim .z = 1 ;
491- #endif
492491
493- cudaLaunchConfig_t config
494- {
495- .gridDim = grid, .blockDim = blockSize, .dynamicSmemBytes = 0 , .stream = params.stream , .attrs = attrs,
496- #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
497- .numAttrs = 2 ,
498- #else
499- .numAttrs = 1 ,
500- #endif
492+ cudaLaunchConfig_t config{
493+ .gridDim = grid,
494+ .blockDim = blockSize,
495+ .dynamicSmemBytes = 0 ,
496+ .stream = params.stream ,
497+ .attrs = attrs,
498+ .numAttrs = kSMVersion >= 90 ? 2U : 1U ,
501499 };
502500
503501#define LAUNCH_ALLREDUCE_KERNEL (WORLD_SIZE, T, RMSNORM ) \
@@ -831,9 +829,9 @@ __global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OU
831829 float blockSum = blockReduceSum<float , true >(threadSum);
832830
833831 float fullSum = blockSum;
834- __shared__ float sharedVal[8 ];
835832 // Use CGA Reduction if supported
836833#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
834+ __shared__ float sharedVal[8 ];
837835 int const numBlocks = cluster.num_blocks ();
838836 if (numBlocks > 1 )
839837 {
@@ -876,13 +874,19 @@ __global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OU
876874 }
877875 constexpr int kELTS_SIZE = sizeof (T_IN);
878876
877+ // Issue ACQBLK at the end. Assuming preceding kernel will not modify the buffer_flags.
878+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
879+ cudaGridDependencySynchronize ();
880+ #endif
881+
879882 // Update the buffer pointers
880883 flag.waitAndUpdate ({static_cast <uint32_t >(divUp<uint32_t >(numTokens, worldSize) * worldSize * dim * kELTS_SIZE ),
881884 static_cast <uint32_t >(numTokens * dim * kELTS_SIZE ), 0 , 0 });
882885}
883886
884887void twoshotAllreduceFusionOp (AllReduceFusionParams const & params)
885888{
889+ static int const kSMVersion = tensorrt_llm::common::getSMVersion ();
886890 int const numTokens = params.numTokens ;
887891 int const tokenDim = params.tokenDim ;
888892 int const numEltsPerThread = sizeof (float4 ) / getDTypeSize (params.dType );
@@ -959,17 +963,13 @@ void twoshotAllreduceFusionOp(AllReduceFusionParams const& params)
959963 rnConfig.attrs = rnAttrs;
960964 rnAttrs[0 ].id = cudaLaunchAttributeProgrammaticStreamSerialization;
961965 rnAttrs[0 ].val .programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL () ? 1 : 0 ;
962- #ifndef DISABLE_CGA
963966 rnAttrs[1 ].id = cudaLaunchAttributeClusterDimension;
964967 rnAttrs[1 ].val .clusterDim .x = 1 ;
965968 rnAttrs[1 ].val .clusterDim .y = rnClusterSize;
966969 rnAttrs[1 ].val .clusterDim .z = 1 ;
967- rnConfig.numAttrs = 2 ;
968- #else
969- rnConfig.numAttrs = 1 ;
970- #endif
970+ rnConfig.numAttrs = (kSMVersion >= 90 ) ? 2U : 1U ;
971971
972- bool const rnUseCGA = rnClusterSize > 1 ;
972+ bool const rnUseCGA = kSMVersion >= 90 && rnClusterSize > 1 ;
973973 int const dimPadded = divUp (tokenDim, numEltsPerThread * rnNumThreads) * numEltsPerThread * rnNumThreads;
974974 int const iters = dimPadded / rnNumThreads;
975975
0 commit comments