Commit b4dab23

[TRTLLM-5965] perf: Optimize MoE sort kernels for large-scale EP (NVIDIA#5435)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
1 parent 94dc97a commit b4dab23

File tree

6 files changed: +629 additions, -414 deletions


cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h

Lines changed: 31 additions & 60 deletions
@@ -87,32 +87,6 @@ struct LoraParams
 
 namespace cutlass_kernels
 {
-
-static inline size_t pad_to_multiple_of_16(size_t const& input)
-{
-    static constexpr int ALIGNMENT = 16;
-    return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
-}
-
-class CubKeyValueSorter
-{
-public:
-    CubKeyValueSorter();
-
-    CubKeyValueSorter(int const num_experts_per_node);
-
-    void updateNumExperts(int const num_experts_per_node);
-
-    static size_t getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts_per_node);
-
-    void run(void* workspace, size_t const workspace_size, int const* keys_in, int* keys_out, int const* values_in,
-        int* values_out, size_t const num_key_value_pairs, cudaStream_t stream);
-
-private:
-    static int expertsToBits(int experts);
-
-    int num_experts_;
-    int num_bits_;
-};
-
 /**
  * \brief Describes what parallelism mode the MoE is using
  *
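For context, the deleted CubKeyValueSorter was a thin wrapper around a CUB device-wide key-value sort of (expert id, expanded row) pairs, with the bit range sized to the number of experts (see expertsToBits above). The following is only a minimal sketch of that kind of sort, not the removed implementation; the function name and the exact bit computation are illustrative.

    #include <cub/cub.cuh>
    #include <cmath>

    // Sketch: sort (expert id, expanded row) pairs so rows routed to the same expert become contiguous.
    // CUB's two-phase pattern: a first call with workspace == nullptr only reports the required
    // workspace size; a second call with a real buffer performs the sort.
    inline void sortRowsByExpert(void* workspace, size_t& workspace_bytes, int const* experts_in, int* experts_out,
        int const* rows_in, int* rows_out, size_t num_pairs, int num_experts_per_node, cudaStream_t stream)
    {
        // Radix sort only needs enough bits to cover [0, num_experts_per_node]; fewer bits means fewer passes.
        int const num_bits = static_cast<int>(std::ceil(std::log2(num_experts_per_node + 1)));
        cub::DeviceRadixSort::SortPairs(
            workspace, workspace_bytes, experts_in, experts_out, rows_in, rows_out, num_pairs, 0, num_bits, stream);
    }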
@@ -397,9 +371,9 @@ class CutlassMoeFCRunnerInterface
         ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
         QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
-        int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config,
-        bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
-        bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
+        int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
+        bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
+        MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
         = 0;
 
     // Aliases for profiling the gemms
@@ -413,22 +387,22 @@ class CutlassMoeFCRunnerInterface
         int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
         bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids, int start_expert)
+        int* active_expert_global_ids)
         = 0;
 
     virtual void gemm2(void const* const input, void* const gemm_output, void* const final_output,
         int64_t const* const expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput const tma_ws_input_template,
         void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales,
         float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
         QuantParams quant_params, float const* const token_topk_unpermuted_scales,
-        float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
-        int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
+        float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
+        int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
         int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
         int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
         bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
         bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-        int* num_active_experts_per, int* active_expert_global_ids, int start_expert)
+        int* num_active_experts_per, int* active_expert_global_ids)
         = 0;
 
     virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
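The renamed mappings are a pair of mutually inverse permutations over the expanded row space (num_rows * experts_per_token): permuted_row_to_unpermuted_row records which original expanded row sits at each expert-grouped position, while unpermuted_row_to_permuted_row points the other way so the finalize step can gather each token's expert outputs back into token order. A minimal host-side sketch of how the two views relate (illustrative only; the real arrays are built on the device):

    #include <vector>

    // permuted_row_to_unpermuted_row[p] = original expanded row stored at permuted position p.
    // The inverse, unpermuted_row_to_permuted_row[u], gives the permuted position of expanded row u.
    std::vector<int> invertPermutation(std::vector<int> const& permuted_row_to_unpermuted_row)
    {
        std::vector<int> unpermuted_row_to_permuted_row(permuted_row_to_unpermuted_row.size());
        for (int p = 0; p < static_cast<int>(permuted_row_to_unpermuted_row.size()); ++p)
        {
            unpermuted_row_to_permuted_row[permuted_row_to_unpermuted_row[p]] = p;
        }
        return unpermuted_row_to_permuted_row;
    }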
@@ -544,9 +518,9 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
         QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
-        int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config,
-        bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
-        bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
+        int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
+        bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
+        MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
 
     // We make these GEMM1 & GEMM2 static because they need to be stateless for the profiler to work
     static void gemm1(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
@@ -565,7 +539,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
         int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
         bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
-        bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert);
+        bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids);
 
     static void gemm2(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
         DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output,
@@ -574,14 +548,14 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         ScaleBiasType const* const fc2_expert_biases, ScaleBiasType const* const fc2_int_scales,
         float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
         QuantParams quant_params, float const* const token_topk_unpermuted_scales,
-        float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
-        int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
+        float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
+        int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
         int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
         int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
         cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids, int start_expert);
+        int* active_expert_global_ids);
 
     // Overrides to allow us to forward on to the internal functions with the pointers using the correct type
     void gemm1(void const* const input, void* const output, void* const intermediate_result,
@@ -594,7 +568,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
         bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids, int start_expert) override
+        int* active_expert_global_ids) override
     {
         auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
         return Self::gemm1(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input),
@@ -603,33 +577,33 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
             num_valid_tokens_ptr, static_cast<ScaleBiasType const*>(fc1_int_scales), fc1_fp8_dequant, fc2_fp8_quant,
             fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size,
             num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, bias_is_broadcast, stream, config,
-            min_latency_mode, num_active_experts_per, active_expert_global_ids, start_expert);
+            min_latency_mode, num_active_experts_per, active_expert_global_ids);
     }
 
     void gemm2(void const* const input, void* const gemm_output, void* const final_output,
         int64_t const* const expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput const tma_ws_input_template,
         void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales,
         float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
         QuantParams quant_params, float const* const token_topk_unpermuted_scales,
-        float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
-        int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
+        float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
+        int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
         int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
         int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
         bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
         bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-        int* num_active_experts_per, int* active_expert_global_ids, int start_expert) override
+        int* num_active_experts_per, int* active_expert_global_ids) override
     {
         auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
         return Self::gemm2(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input), gemm_output,
             static_cast<OutputType*>(final_output), expert_first_token_offset, tma_ws_input_template,
             static_cast<WeightType const*>(fc2_expert_weights), static_cast<ScaleBiasType const*>(fc2_expert_biases),
             static_cast<ScaleBiasType const*>(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params,
-            token_topk_unpermuted_scales, token_topk_permuted_scales, expanded_source_row_to_expanded_dest_row,
-            expanded_dest_row_to_expanded_source_row, expert_for_source_row, num_valid_tokens_ptr, num_rows,
-            expanded_num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array,
-            use_lora, fc2_lora, stream, parallelism_config, enable_alltoall, config, min_latency_mode,
-            num_active_experts_per, active_expert_global_ids, start_expert);
+            token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row,
+            permuted_row_to_unpermuted_row, token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows,
+            hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora,
+            stream, parallelism_config, enable_alltoall, config, min_latency_mode, num_active_experts_per,
+            active_expert_global_ids);
     }
 
     virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override
@@ -763,30 +737,29 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
     static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output,
         OutputType* const final_output, int64_t const* const expert_first_token_offset,
         WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases,
-        float const* const token_topk_unpermuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
-        int const* const expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
+        float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row,
+        int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const k,
+        int64_t const hidden_size, int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k,
         MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params,
         cudaStream_t stream);
 
     T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
         int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
         cudaStream_t stream);
 
-    CubKeyValueSorter sorter_;
     MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType> moe_gemm_runner_;
     std::unique_ptr<DeepSeekBlockScaleGemmRunner> blockscale_gemm_runner_;
 
     std::optional<cutlass_extensions::CutlassGemmConfig> gemm1_config_;
     std::optional<cutlass_extensions::CutlassGemmConfig> gemm2_config_;
 
     // Pointers
-    int* unpermuted_token_selected_experts_{};
-    int* unpermuted_source_token_ids_{};
-    int* permuted_source_token_ids_{};
+    int* permuted_row_to_unpermuted_row_{};
     int* permuted_token_selected_experts_{};
-    char* sorter_ws_{};
+    int* blocked_expert_counts_{};
+    int* blocked_expert_counts_cumsum_{};
+    int* blocked_row_to_unpermuted_row_{};
     T* permuted_data_{};
     float* permuted_token_final_scales_{};
 
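The new workspace pointers replace the radix-sort scratch buffer (sorter_ws_) with what looks like a counting-sort layout: per-block expert histograms (blocked_expert_counts_), their prefix sum (blocked_expert_counts_cumsum_), and per-block row lists (blocked_row_to_unpermuted_row_) that are merged into permuted_row_to_unpermuted_row_. A simplified, single-threaded C++ sketch of that general technique, under the assumption that each "block" owns a contiguous chunk of expanded rows; this is not the CUDA kernel added by the commit.

    #include <vector>

    // Illustrative host-side counting sort over (expanded row -> expert) assignments. Each block counts
    // its local experts, an exclusive prefix sum over (expert, block) gives every block a private output
    // window, and rows are then scattered into the expert-grouped order without atomics.
    void blockedCountingSort(std::vector<int> const& token_selected_experts, int num_experts, int block_size,
        std::vector<int>& permuted_row_to_unpermuted_row)
    {
        int const num_rows = static_cast<int>(token_selected_experts.size());
        int const num_blocks = (num_rows + block_size - 1) / block_size;

        // blocked_expert_counts[e * num_blocks + b]: rows in block b that selected expert e.
        std::vector<int> blocked_expert_counts(num_experts * num_blocks, 0);
        for (int row = 0; row < num_rows; ++row)
        {
            blocked_expert_counts[token_selected_experts[row] * num_blocks + row / block_size]++;
        }

        // Exclusive prefix sum: the starting offset of each (expert, block) bucket in the permuted order.
        std::vector<int> blocked_expert_counts_cumsum(num_experts * num_blocks, 0);
        int running = 0;
        for (int i = 0; i < num_experts * num_blocks; ++i)
        {
            blocked_expert_counts_cumsum[i] = running;
            running += blocked_expert_counts[i];
        }

        // Scatter: every block writes its rows into its reserved window, keeping a stable order.
        std::vector<int> offsets = blocked_expert_counts_cumsum;
        permuted_row_to_unpermuted_row.assign(num_rows, 0);
        for (int row = 0; row < num_rows; ++row)
        {
            int const bucket = token_selected_experts[row] * num_blocks + row / block_size;
            permuted_row_to_unpermuted_row[offsets[bucket]++] = row;
        }
    }

Compared with a device-wide radix sort, this style of build scales with the number of rows per block rather than the full key width, which is the likely motivation behind the "Optimize MoE sort kernels for large-scale EP" title; the exact kernel strategy lives in the corresponding .cu changes, not this header.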

@@ -859,7 +832,6 @@ struct GemmProfilerBackend
         mParallelismConfig = parallelism_config;
         mEnableAlltoall = enable_alltoall;
         mSM = common::getSMVersion();
-        mSorter.updateNumExperts(mNumExpertsPerNode);
 
         mScalingType = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE;
         if (dtype == nvinfer1::DataType::kFP8
@@ -883,7 +855,6 @@ struct GemmProfilerBackend
         cudaStream_t const& stream);
 
     CutlassMoeFCRunnerInterface* mInterface;
-    CubKeyValueSorter mSorter;
 
     GemmToProfile mGemmToProfile = GemmToProfile::Undefined;
     std::vector<Config> mAllTacticsSaved;
