@@ -87,32 +87,6 @@ struct LoraParams
 
 namespace cutlass_kernels
 {
-static inline size_t pad_to_multiple_of_16(size_t const& input)
-{
-    static constexpr int ALIGNMENT = 16;
-    return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
-}
-
-class CubKeyValueSorter
-{
-public:
-    CubKeyValueSorter();
-
-    CubKeyValueSorter(int const num_experts_per_node);
-
-    void updateNumExperts(int const num_experts_per_node);
-
-    static size_t getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts_per_node);
-
-    void run(void* workspace, size_t const workspace_size, int const* keys_in, int* keys_out, int const* values_in,
-        int* values_out, size_t const num_key_value_pairs, cudaStream_t stream);
-
-private:
-    static int expertsToBits(int experts);
-    int num_experts_;
-    int num_bits_;
-};
-
 /**
  * \brief Describes what parallelism mode the MoE is using
  *
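For context on what this hunk removes: CubKeyValueSorter was a thin wrapper around CUB's device-wide radix sort of (expert id, token id) pairs, restricted to the low key bits needed for the expert count. The sketch below is illustrative only, assuming the standard two-phase cub::DeviceRadixSort::SortPairs API; it is not the deleted implementation, and sort_workspace_bytes is a hypothetical helper name.

// Illustrative sketch, not the removed implementation: query the CUB temp-storage
// size for sorting expert ids (keys) paired with token ids (values), using only the
// low bits that can differ for num_experts distinct keys.
#include <cub/cub.cuh>
#include <algorithm>
#include <cmath>
#include <cstddef>

inline size_t sort_workspace_bytes(size_t num_pairs, int num_experts)
{
    int end_bit = static_cast<int>(std::ceil(std::log2(static_cast<float>(std::max(num_experts, 2)))));
    size_t bytes = 0;
    // Phase 1: a null workspace pointer makes CUB report the required temp-storage size.
    cub::DeviceRadixSort::SortPairs(nullptr, bytes, static_cast<int const*>(nullptr),
        static_cast<int*>(nullptr), static_cast<int const*>(nullptr), static_cast<int*>(nullptr),
        static_cast<int>(num_pairs), 0, end_bit);
    // Phase 2 (not shown): call SortPairs again with an allocated workspace of this
    // size and real device pointers on the target stream.
    return bytes;
}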
@@ -397,9 +371,9 @@ class CutlassMoeFCRunnerInterface
         ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
         QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
-        int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config,
-        bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
-        bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
+        int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
+        bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
+        MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
         = 0;
 
     // Aliases for profiling the gemms
@@ -413,22 +387,22 @@ class CutlassMoeFCRunnerInterface
         int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
         bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids, int start_expert)
+        int* active_expert_global_ids)
         = 0;
 
     virtual void gemm2(void const* const input, void* const gemm_output, void* const final_output,
         int64_t const* const expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput const tma_ws_input_template,
         void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales,
         float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
         QuantParams quant_params, float const* const token_topk_unpermuted_scales,
-        float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
-        int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
+        float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
+        int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
         int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
         int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
         bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
         bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-        int* num_active_experts_per, int* active_expert_global_ids, int start_expert)
+        int* num_active_experts_per, int* active_expert_global_ids)
         = 0;
 
     virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
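The renamed index arrays read more directly: unpermuted_row_to_permuted_row maps each expanded (token, k) row to its slot in the expert-sorted buffer, and permuted_row_to_unpermuted_row is its inverse. As a rough host-side reference of how the forward map is consumed in the finalize/unpermute step, the sketch below is illustrative only; the expanded-row layout (k * num_rows + token), the scale layout, and the name finalize_reference are assumptions, not code from this header.

// Host-side reference (illustrative only) of the gather-scale-reduce that folds
// experts_per_token permuted GEMM2 rows back into each token's final output.
#include <cstdint>

void finalize_reference(float const* permuted_out, float* final_out,
    int const* unpermuted_row_to_permuted_row, float const* token_topk_unpermuted_scales,
    int64_t num_rows, int64_t hidden_size, int64_t experts_per_token)
{
    for (int64_t token = 0; token < num_rows; ++token)
    {
        for (int64_t h = 0; h < hidden_size; ++h)
            final_out[token * hidden_size + h] = 0.f;
        for (int64_t k = 0; k < experts_per_token; ++k)
        {
            // Assumed expanded (unpermuted) row index for this (token, k) pair.
            int64_t unpermuted_row = k * num_rows + token;
            int64_t permuted_row = unpermuted_row_to_permuted_row[unpermuted_row];
            // Assumed row-major per-token layout for the top-k routing scales.
            float scale = token_topk_unpermuted_scales[token * experts_per_token + k];
            for (int64_t h = 0; h < hidden_size; ++h)
                final_out[token * hidden_size + h] += scale * permuted_out[permuted_row * hidden_size + h];
        }
    }
}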
@@ -544,9 +518,9 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
         QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
-        int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config,
-        bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
-        bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
+        int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
+        bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
+        MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
 
     // We make these GEMM1 & GEMM2 static because they need to be stateless for the profiler to work
     static void gemm1(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
@@ -565,7 +539,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
         int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
         bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
-        bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert);
+        bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids);
 
     static void gemm2(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
         DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output,
@@ -574,14 +548,14 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         ScaleBiasType const* const fc2_expert_biases, ScaleBiasType const* const fc2_int_scales,
         float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
         QuantParams quant_params, float const* const token_topk_unpermuted_scales,
-        float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
-        int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
+        float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
+        int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
         int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
         int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
         cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids, int start_expert);
+        int* active_expert_global_ids);
 
     // Overrides to allow us to forward on to the internal functions with the pointers using the correct type
     void gemm1(void const* const input, void* const output, void* const intermediate_result,
@@ -594,7 +568,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
         bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids, int start_expert) override
+        int* active_expert_global_ids) override
     {
         auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
         return Self::gemm1(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input),
@@ -603,33 +577,33 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
             num_valid_tokens_ptr, static_cast<ScaleBiasType const*>(fc1_int_scales), fc1_fp8_dequant, fc2_fp8_quant,
             fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size,
             num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, bias_is_broadcast, stream, config,
-            min_latency_mode, num_active_experts_per, active_expert_global_ids, start_expert);
+            min_latency_mode, num_active_experts_per, active_expert_global_ids);
     }
 
     void gemm2(void const* const input, void* const gemm_output, void* const final_output,
         int64_t const* const expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput const tma_ws_input_template,
         void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales,
         float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
         QuantParams quant_params, float const* const token_topk_unpermuted_scales,
-        float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
-        int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
+        float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
+        int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
         int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
         int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
         bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
         bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-        int* num_active_experts_per, int* active_expert_global_ids, int start_expert) override
+        int* num_active_experts_per, int* active_expert_global_ids) override
     {
         auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
         return Self::gemm2(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input), gemm_output,
             static_cast<OutputType*>(final_output), expert_first_token_offset, tma_ws_input_template,
             static_cast<WeightType const*>(fc2_expert_weights), static_cast<ScaleBiasType const*>(fc2_expert_biases),
             static_cast<ScaleBiasType const*>(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params,
-            token_topk_unpermuted_scales, token_topk_permuted_scales, expanded_source_row_to_expanded_dest_row,
-            expanded_dest_row_to_expanded_source_row, expert_for_source_row, num_valid_tokens_ptr, num_rows,
-            expanded_num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array,
-            use_lora, fc2_lora, stream, parallelism_config, enable_alltoall, config, min_latency_mode,
-            num_active_experts_per, active_expert_global_ids, start_expert);
+            token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row,
+            permuted_row_to_unpermuted_row, token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows,
+            hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora,
+            stream, parallelism_config, enable_alltoall, config, min_latency_mode, num_active_experts_per,
+            active_expert_global_ids);
     }
 
     virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override
@@ -763,30 +737,29 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
     static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output,
         OutputType* const final_output, int64_t const* const expert_first_token_offset,
         WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases,
-        float const* const token_topk_unpermuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
-        int const* const expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
+        float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row,
+        int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const k,
+        int64_t const hidden_size, int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k,
         MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params,
         cudaStream_t stream);
 
     T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
         int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
         cudaStream_t stream);
 
-    CubKeyValueSorter sorter_;
     MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType> moe_gemm_runner_;
     std::unique_ptr<DeepSeekBlockScaleGemmRunner> blockscale_gemm_runner_;
 
     std::optional<cutlass_extensions::CutlassGemmConfig> gemm1_config_;
     std::optional<cutlass_extensions::CutlassGemmConfig> gemm2_config_;
 
     // Pointers
-    int* unpermuted_token_selected_experts_{};
-    int* unpermuted_source_token_ids_{};
-    int* permuted_source_token_ids_{};
+    int* permuted_row_to_unpermuted_row_{};
     int* permuted_token_selected_experts_{};
-    char* sorter_ws_{};
+    int* blocked_expert_counts_{};
+    int* blocked_expert_counts_cumsum_{};
+    int* blocked_row_to_unpermuted_row_{};
     T* permuted_data_{};
     float* permuted_token_final_scales_{};
 
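The sorter workspace gives way to blocked_expert_counts_, blocked_expert_counts_cumsum_, and blocked_row_to_unpermuted_row_, which is consistent with a counting-sort style permutation (per-expert histogram, cumulative sum, stable scatter) instead of a radix sort. The host sketch below shows only the unblocked idea; the blocked layout and the helper name build_permutation are assumptions, not code from this change.

// Illustrative host sketch of building permuted_row_to_unpermuted_row with a
// per-expert histogram and an exclusive cumulative sum (counting sort).
#include <vector>

void build_permutation(std::vector<int> const& token_selected_experts, int num_experts,
    std::vector<int>& permuted_row_to_unpermuted_row)
{
    std::vector<int> counts(num_experts, 0);
    for (int e : token_selected_experts)
        ++counts[e];                                 // per-expert histogram
    std::vector<int> offsets(num_experts + 1, 0);
    for (int e = 0; e < num_experts; ++e)
        offsets[e + 1] = offsets[e] + counts[e];     // exclusive cumulative sum
    permuted_row_to_unpermuted_row.assign(token_selected_experts.size(), 0);
    for (int row = 0; row < static_cast<int>(token_selected_experts.size()); ++row)
        permuted_row_to_unpermuted_row[offsets[token_selected_experts[row]]++] = row; // stable scatter
}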
@@ -859,7 +832,6 @@ struct GemmProfilerBackend
         mParallelismConfig = parallelism_config;
         mEnableAlltoall = enable_alltoall;
         mSM = common::getSMVersion();
-        mSorter.updateNumExperts(mNumExpertsPerNode);
 
         mScalingType = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE;
         if (dtype == nvinfer1::DataType::kFP8
@@ -883,7 +855,6 @@ struct GemmProfilerBackend
         cudaStream_t const& stream);
 
     CutlassMoeFCRunnerInterface* mInterface;
-    CubKeyValueSorter mSorter;
 
     GemmToProfile mGemmToProfile = GemmToProfile::Undefined;
     std::vector<Config> mAllTacticsSaved;