Skip to content

Commit 4b6b292

Browse files
committed
MaxActiveBlocks
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
1 parent 1c92af1 commit 4b6b292

File tree

4 files changed

+49
-24
lines changed

4 files changed

+49
-24
lines changed

cpp/include/tensorrt_llm/common/cudaUtils.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include <cstdint>
#include <map>
#include <mutex>
#include <optional>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#ifndef _WIN32 // Linux
#include <sys/sysinfo.h>
#endif // not WIN32
@@ -432,6 +433,24 @@ inline int getMaxSharedMemoryPerBlockOptin()
432433
return nByteMaxSharedMemoryPerBlockOptin;
433434
}
434435

436+
template <typename T>
437+
inline int getMaxActiveBlocksPerSM(T kernel, int blockSize, size_t dynamicSMemSize)
438+
{
439+
static std::unordered_map<T, int> cache;
440+
auto it = cache.find(kernel);
441+
if (it == cache.end())
442+
{
443+
int numBlocks;
444+
check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks, kernel, blockSize, dynamicSMemSize));
445+
cache[kernel] = numBlocks;
446+
return numBlocks;
447+
}
448+
else
449+
{
450+
return it->second;
451+
}
452+
}
453+
435454
template <typename T1, typename T2>
436455
inline size_t divUp(T1 const& a, T2 const& b)
437456
{

cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,12 @@ void moePermute(InputType const* input, InputType* permuted_output, SFType const
141141
}
142142
#endif
143143

144+
auto kernel = &moePermuteKernel<InputType, SFType, kSFVecSize, kThreadsPerBlock>;
144145
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
145-
int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
146+
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
147+
int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
146148
int32_t const threads = kThreadsPerBlock;
147149

148-
auto kernel = &moePermuteKernel<InputType, SFType, kSFVecSize, kThreadsPerBlock>;
149-
150150
cudaLaunchConfig_t config;
151151
config.gridDim = blocks;
152152
config.blockDim = threads;
@@ -382,10 +382,6 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
382382
}
383383
#endif
384384

385-
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
386-
int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
387-
int32_t const threads = kThreadsPerBlock;
388-
389385
auto get_act_kernel = [](ActivationType activation_type) -> void (*)(InputType const* input, OutputType* output,
390386
float const* global_sf, SFType* output_sf,
391387
int32_t const* tile_idx_to_mn_limit,
@@ -424,6 +420,11 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
424420
};
425421
auto kernel = get_act_kernel(activation_params.activation_type);
426422

423+
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
424+
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
425+
int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
426+
int32_t const threads = kThreadsPerBlock;
427+
427428
cudaLaunchConfig_t config;
428429
config.gridDim = blocks;
429430
config.blockDim = threads;

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1587,11 +1587,6 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
15871587
int64_t num_padding_tokens = 0;
15881588
#endif
15891589

1590-
static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
1591-
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
1592-
int64_t const blocks = std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens));
1593-
int64_t const threads = EXPAND_THREADS_PER_BLOCK;
1594-
15951590
auto func = [&]()
15961591
{
15971592
#ifdef ENABLE_FP8
@@ -1637,6 +1632,12 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
16371632
}
16381633
}();
16391634

1635+
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
1636+
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(func, EXPAND_THREADS_PER_BLOCK, 0);
1637+
int32_t const blocks
1638+
= std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(std::max(num_rows * k, num_padding_tokens)));
1639+
int32_t const threads = EXPAND_THREADS_PER_BLOCK;
1640+
16401641
cudaLaunchConfig_t config;
16411642
config.gridDim = blocks;
16421643
config.blockDim = threads;
@@ -1891,15 +1892,18 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro
18911892
if (parallelism_config.ep_size > 1 && enable_alltoall)
18921893
{
18931894
// If all-to-all comm is enabled, finalizeMoeRouting doesn't need to fill the invalid output tokens with zeros.
1894-
static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
1895-
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
1896-
int64_t const blocks = smCount * 8;
1897-
int64_t const threads = FINALIZE_THREADS_PER_BLOCK;
1898-
config.gridDim = blocks;
1899-
config.blockDim = threads;
19001895
auto func = final_scales
19011896
? &finalizeMoeRoutingNoFillingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::DEFAULT>
19021897
: &finalizeMoeRoutingNoFillingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::NO_SCALE>;
1898+
1899+
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
1900+
int32_t const maxBlocksPerSM
1901+
= tensorrt_llm::common::getMaxActiveBlocksPerSM(func, FINALIZE_THREADS_PER_BLOCK, 0);
1902+
int32_t const blocks = std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(num_rows * experts_per_token));
1903+
int32_t const threads = FINALIZE_THREADS_PER_BLOCK;
1904+
1905+
config.gridDim = blocks;
1906+
config.blockDim = threads;
19031907
cudaLaunchKernelEx(&config, func, expanded_permuted_rows, reduced_unpermuted_output, bias_ptr, final_scales,
19041908
unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts,
19051909
expert_first_token_offset, num_rows, padded_cols, unpadded_cols, experts_per_token, num_experts_per_node,
@@ -2235,11 +2239,6 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
22352239
int64_t num_padding_tokens = 0;
22362240
#endif
22372241

2238-
static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
2239-
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
2240-
int64_t const blocks = std::min(smCount * 8, std::max(expanded_num_tokens, num_padding_tokens));
2241-
int64_t const threads = ACTIVATION_THREADS_PER_BLOCK;
2242-
22432242
auto fn = [&]()
22442243
{
22452244
// IMPORTANT: Keep the order of the activation functions in the same order as the ActivationType enum in
@@ -2302,6 +2301,12 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
23022301
}
23032302
}();
23042303

2304+
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
2305+
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(fn, ACTIVATION_THREADS_PER_BLOCK, 0);
2306+
int32_t const blocks
2307+
= std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(std::max(expanded_num_tokens, num_padding_tokens)));
2308+
int32_t const threads = ACTIVATION_THREADS_PER_BLOCK;
2309+
23052310
cudaLaunchConfig_t config;
23062311
config.gridDim = blocks;
23072312
config.blockDim = threads;

tensorrt_llm/_torch/autotuner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,7 @@ def choose_one(
727727
# Log the cache miss. Expect no cache miss in inference.
728728
if not is_cache_hit:
729729
logger.warning_once(
730-
f"[AutoTunner] Using the fallback tactic, due to cache miss on input shapes={input_shapes}",
730+
f"[AutoTunner] {custom_op} using the fallback tactic, due to cache miss on input shapes={input_shapes}",
731731
key=(custom_op, "warning_autotuning_cache_miss_fallback"))
732732

733733
return (best_runner, best_tactic)

0 commit comments

Comments
 (0)