@@ -1587,11 +1587,6 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
     int64_t num_padding_tokens = 0;
 #endif
 
-    static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
-    // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
-    int64_t const blocks = std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens));
-    int64_t const threads = EXPAND_THREADS_PER_BLOCK;
-
     auto func = [&]()
     {
 #ifdef ENABLE_FP8
@@ -1637,6 +1632,12 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
         }
     }();
 
+    static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
+    int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(func, EXPAND_THREADS_PER_BLOCK, 0);
+    int32_t const blocks
+        = std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(std::max(num_rows * k, num_padding_tokens)));
+    int32_t const threads = EXPAND_THREADS_PER_BLOCK;
+
     cudaLaunchConfig_t config;
     config.gridDim = blocks;
     config.blockDim = threads;
@@ -1891,15 +1892,18 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro
     if (parallelism_config.ep_size > 1 && enable_alltoall)
     {
         // If all-to-all comm is enabled, finalizeMoeRouting doesn't need to fill the invalid output tokens with zeros.
-        static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
-        // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
-        int64_t const blocks = smCount * 8;
-        int64_t const threads = FINALIZE_THREADS_PER_BLOCK;
-        config.gridDim = blocks;
-        config.blockDim = threads;
         auto func = final_scales
             ? &finalizeMoeRoutingNoFillingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::DEFAULT>
             : &finalizeMoeRoutingNoFillingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::NO_SCALE>;
+
+        static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
+        int32_t const maxBlocksPerSM
+            = tensorrt_llm::common::getMaxActiveBlocksPerSM(func, FINALIZE_THREADS_PER_BLOCK, 0);
+        int32_t const blocks = std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(num_rows * experts_per_token));
+        int32_t const threads = FINALIZE_THREADS_PER_BLOCK;
+
+        config.gridDim = blocks;
+        config.blockDim = threads;
         cudaLaunchKernelEx(&config, func, expanded_permuted_rows, reduced_unpermuted_output, bias_ptr, final_scales,
             unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts,
             expert_first_token_offset, num_rows, padded_cols, unpadded_cols, experts_per_token, num_experts_per_node,
@@ -2235,11 +2239,6 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
     int64_t num_padding_tokens = 0;
 #endif
 
-    static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
-    // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
-    int64_t const blocks = std::min(smCount * 8, std::max(expanded_num_tokens, num_padding_tokens));
-    int64_t const threads = ACTIVATION_THREADS_PER_BLOCK;
-
     auto fn = [&]()
     {
         // IMPORTANT: Keep the order of the activation functions in the same order as the ActivationType enum in
@@ -2302,6 +2301,12 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
         }
     }();
 
+    static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
+    int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(fn, ACTIVATION_THREADS_PER_BLOCK, 0);
+    int32_t const blocks
+        = std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(std::max(expanded_num_tokens, num_padding_tokens)));
+    int32_t const threads = ACTIVATION_THREADS_PER_BLOCK;
+
     cudaLaunchConfig_t config;
     config.gridDim = blocks;
     config.blockDim = threads;
0 commit comments