diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 64fece5021..fefa64452c 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -119,6 +119,7 @@ class FusedMoeLauncher { int64_t tile_tokens_dim{}; int64_t routing_method_type{}; + bool norm_topk_prob{true}; bool use_shuffled_weight{}; batchedGemm::gemm::MatrixLayout weight_layout{batchedGemm::gemm::MatrixLayout::MajorK}; @@ -130,6 +131,8 @@ class FusedMoeLauncher { btg::Dtype mDtypeWeights{btg::Dtype::Bfloat16}; btg::Dtype mRoutingBiasDtype{ btg::Dtype::Bfloat16}; // Dtype for expert weights in routing, based on routing bias + btg::Dtype mRoutingLogitsDtype{ + btg::Dtype::Bfloat16}; // Dtype for routing logits (Bfloat16 or Fp32) ActivationType activation_type{ActivationType::Swiglu}; int64_t intermediate_size_factor{2}; @@ -165,7 +168,8 @@ class FusedMoeLauncher { // May throw exception from TVM_FFI_ICHECK. void init_common(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, ActivationType activation_type); + int64_t weight_layout, ActivationType activation_type, + bool norm_topk_prob = true); // Routing logits [num_tokens, num_experts] void check_routing_logits_shape() const { @@ -389,7 +393,8 @@ class FusedMoeLauncher { static_cast(cta_idx_xy_to_mn_limit.data_ptr()), static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, - static_cast(routing_method_type), routing_stream); + static_cast(routing_method_type), routing_stream, + mRoutingLogitsDtype, norm_topk_prob); check_moe(); prepare_moe(moe_tactic); @@ -408,7 +413,7 @@ class FusedMoeLauncher { void FusedMoeLauncher::init_common( std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, ActivationType activation_type) { + int64_t weight_layout, 
ActivationType activation_type, bool norm_topk_prob) { // Check devicearchitecture: Blackwell (SM 10.x) required auto device = hidden_states.device().device_id; int major = 0, minor = 0; @@ -427,6 +432,7 @@ void FusedMoeLauncher::init_common( this->args = std::move(args); this->tile_tokens_dim = tile_tokens_dim; this->routing_method_type = routing_method_type; + this->norm_topk_prob = norm_topk_prob; this->use_shuffled_weight = use_shuffled_weight; TVM_FFI_ICHECK(0 <= weight_layout && weight_layout <= 2) << "the value of weight_layout is not recognized"; @@ -451,13 +457,14 @@ class Bf16MoeLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout) { + int64_t weight_layout, bool norm_topk_prob = true) { constexpr ActivationType activation_type = ActivationType::Swiglu; // not exposed in api for now // Do base class init and perform common checks FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, activation_type); + use_shuffled_weight, weight_layout, activation_type, + norm_topk_prob); } void check_routing() const override { @@ -486,6 +493,10 @@ class Bf16MoeLauncher : public FusedMoeLauncher { auto const routing_bias_dtype = routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; + // Set routing logits dtype + auto const routing_logits_dtype = + routing_logits.has_value() ? routing_logits.value().dtype() : dl_bfloat16; + mRoutingLogitsDtype = routing_logits_dtype == dl_float32 ? 
btg::Dtype::Fp32 : btg::Dtype::Bfloat16; // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr bool has_precomputed_indices = expert_indices.ndim() == 2 && expert_indices.size(0) > 0; @@ -596,7 +607,7 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool use_routing_scales_on_input_param, - ActivationType activation_type) { + ActivationType activation_type, bool norm_topk_prob = true) { this->use_routing_scales_on_input = use_routing_scales_on_input_param; auto dtype = hidden_states.dtype(); @@ -612,7 +623,8 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { mDtypeWeights = btg::Dtype::E4m3; FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, activation_type); + use_shuffled_weight, weight_layout, activation_type, + norm_topk_prob); } void check_routing() const override { FusedMoeLauncher::check_routing_common(); } @@ -637,6 +649,10 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { auto const routing_bias_dtype = routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; + { + auto const rl_dtype = routing_logits.has_value() ? routing_logits.value().dtype() : dl_bfloat16; + mRoutingLogitsDtype = rl_dtype == dl_float32 ? 
btg::Dtype::Fp32 : btg::Dtype::Bfloat16; + } expert_weights = alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); @@ -799,7 +815,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout) { + int64_t weight_layout, bool norm_topk_prob = true) { constexpr ActivationType activation_type = ActivationType::Swiglu; if (quantization_type == Fp8QuantizationType::MxFp8) { @@ -825,7 +841,8 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { args->mDtypeOut = btg::Dtype::Bfloat16; FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, activation_type); + use_shuffled_weight, weight_layout, activation_type, + norm_topk_prob); } void check_routing() const override { @@ -864,11 +881,15 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { TVM_FFI_ICHECK_LT(args->top_k, (args->topk_group * args->num_experts / args->n_group)) << "top_k must be less than total number of experts in selected groups"; } else if (static_cast(routing_method_type) == + RoutingMethodType::Default || + static_cast(routing_method_type) == RoutingMethodType::Renormalize || static_cast(routing_method_type) == - RoutingMethodType::RenormalizeNaive) { + RoutingMethodType::RenormalizeNaive || + static_cast(routing_method_type) == + RoutingMethodType::SigmoidRenorm) { TVM_FFI_ICHECK(args->top_k <= 10 && args->top_k > 0) - << "Current routing kernel (no groups, renormalize) only supports top_k<=10 && top_k>0."; + << "Current routing kernel (no groups) only supports top_k<=10 && top_k>0."; } else if (static_cast(routing_method_type) == RoutingMethodType::Llama4) { TVM_FFI_ICHECK_EQ(args->top_k, 1) << "Current routing kernel (no groups, Llama4) only supports top_k=1."; @@ -904,12 +925,15 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { 
static_cast(const_cast(expert_indices.data_ptr())); } else { // Use routing_logits directly - args->routing_logits = static_cast(routing_logits.value().data_ptr()); + args->routing_logits = routing_logits.value().data_ptr(); } // Set expert weights dtype based on routing bias auto const routing_bias_dtype = routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; + mRoutingLogitsDtype = routing_logits.has_value() + ? (routing_logits.value().dtype() == dl_float32 ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16) + : btg::Dtype::Bfloat16; // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr bool has_precomputed_weights = expert_weights.ndim() == 2 && expert_weights.size(0) > 0; if (!has_precomputed_weights) { @@ -1094,7 +1118,8 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { static_cast(cta_idx_xy_to_mn_limit.data_ptr()), static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, - static_cast(routing_method_type), routing_stream); + static_cast(routing_method_type), routing_stream, + mRoutingLogitsDtype, norm_topk_prob); check_moe(); prepare_moe(moe_tactic); @@ -1157,7 +1182,7 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { gemm2_weights_scale(gemm2_weights_scale) {} void init(std::unique_ptr&& args, - int64_t tile_tokens_dim, int64_t routing_method_type) { + int64_t tile_tokens_dim, int64_t routing_method_type, bool norm_topk_prob = true) { // currently only support mxint4 x bf16 auto dtype = hidden_states.dtype(); if (dtype == dl_bfloat16) { @@ -1173,7 +1198,8 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { FusedMoeLauncher::init_common( std::move(args), tile_tokens_dim, routing_method_type, /*use_shuffled_weight=*/true, - static_cast(batchedGemm::gemm::MatrixLayout::BlockMajorK), ActivationType::Swiglu); + 
static_cast(batchedGemm::gemm::MatrixLayout::BlockMajorK), ActivationType::Swiglu, + norm_topk_prob); } void check_routing() const override { FusedMoeLauncher::check_routing_common(); } @@ -1187,6 +1213,9 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { auto const routing_bias_dtype = routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; + mRoutingLogitsDtype = routing_logits.has_value() + ? (routing_logits.value().dtype() == dl_float32 ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16) + : btg::Dtype::Bfloat16; expert_weights = alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); @@ -1332,7 +1361,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, ActivationType activation_type, btg::Dtype dtype_act, - btg::Dtype dtype_weights) { + btg::Dtype dtype_weights, bool norm_topk_prob = true) { static const std::tuple device_props = [this] { int major, minor; cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, @@ -1355,7 +1384,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { mDtypeWeights = dtype_weights; FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, activation_type); + use_shuffled_weight, weight_layout, activation_type, + norm_topk_prob); } void check_routing() const override { @@ -1402,6 +1432,9 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { args->mDtypeElt = mDtypeAct; auto routing_bias_dtype = routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; + mRoutingLogitsDtype = routing_logits.has_value() + ? (routing_logits.value().dtype() == dl_float32 ? 
btg::Dtype::Fp32 : btg::Dtype::Bfloat16) + : btg::Dtype::Bfloat16; } void check_moe() const override { @@ -1557,7 +1590,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { static_cast(cta_idx_xy_to_mn_limit.data_ptr()), static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, - static_cast(routing_method_type), routing_stream); + static_cast(routing_method_type), routing_stream, + mRoutingLogitsDtype, norm_topk_prob); check_moe(); prepare_moe(moe_tactic); @@ -1612,7 +1646,8 @@ Array trtllm_bf16_moe(Optional const& routing_logits, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, - bool enable_pdl, Array moe_tactic) { + bool enable_pdl, Array moe_tactic, + bool norm_topk_prob) { // Just some basic type validation first and leave more checks to the launcher if (routing_logits.has_value()) { TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 || @@ -1661,7 +1696,7 @@ Array trtllm_bf16_moe(Optional const& routing_logits, expert_weights, hidden_states, gemm1_weights, gemm2_weights); launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, - weight_layout); + weight_layout, norm_topk_prob); launchers_map[curr_tile_N] = std::move(launcher); } @@ -1690,17 +1725,13 @@ Array trtllm_fp8_per_tensor_scale_moe( Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, bool use_routing_scales_on_input, int64_t routing_method_type, bool do_finalize, - bool enable_pdl, Array config_index, int64_t activation_type) { + bool enable_pdl, Array config_index, int64_t activation_type, + bool norm_topk_prob) { // Basic type validation auto dtype = hidden_states.dtype(); auto activation = static_cast(activation_type); - if 
(use_routing_scales_on_input) { - TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; - } else if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { - TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float."; - } else { - TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; - } + TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16) + << "FP8 per-tensor MoE: routing_logits must be float or bfloat16."; TVM_FFI_ICHECK(dtype == dl_float8_e4m3fn || dtype == dl_float16 || dtype == dl_bfloat16) << "FP8 MoE: hidden_states must be float8_e4m3fn, float16, or bfloat16."; TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) @@ -1753,7 +1784,7 @@ Array trtllm_fp8_per_tensor_scale_moe( routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar, output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar); launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, - weight_layout, use_routing_scales_on_input, activation); + weight_layout, use_routing_scales_on_input, activation, norm_topk_prob); launchers_map[curr_tile_N] = std::move(launcher); } @@ -1782,7 +1813,8 @@ Array trtllm_fp8_block_scale_moe( Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, - bool enable_pdl, Array config_index, Fp8QuantizationType quantization_type) { + bool enable_pdl, Array config_index, Fp8QuantizationType quantization_type, + bool norm_topk_prob) { // Basic type validation auto dtype = hidden_states.dtype(); @@ -1796,13 +1828,20 @@ Array trtllm_fp8_block_scale_moe( << "Either routing_logits or expert_indices must be provided."; if (use_routing_logits) { - if 
(static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { - TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32) - << "routing_logits must be float."; - } else { - TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_bfloat16) - << "routing_logits must be bfloat16."; - } + TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 || + routing_logits.value().dtype() == dl_bfloat16) + << "FP8 block scale MoE: routing_logits must be float or bfloat16."; + TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D."; + TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), num_experts) + << "routing_logits has incorrect shape."; + } + if (routing_bias.has_value()) { + TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16 || + routing_bias.value().dtype() == dl_float32) + << "FP8 block scale MoE: routing_bias must be bfloat16 or float."; + TVM_FFI_ICHECK_EQ(routing_bias.value().ndim(), 1) << "routing_bias must be 1D."; + TVM_FFI_ICHECK_EQ(routing_bias.value().size(0), num_experts) + << "routing_bias has incorrect shape."; } TVM_FFI_ICHECK(dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) << "FP8 block scale MoE: hidden_states must be fp16, bf16, or fp8."; @@ -1874,7 +1913,7 @@ Array trtllm_fp8_block_scale_moe( gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, expert_indices, expert_weights, quantization_type); launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, - weight_layout); + weight_layout, norm_topk_prob); launchers_map[curr_tile_N] = std::move(launcher); } @@ -1910,7 +1949,7 @@ Array trtllm_fp4_block_scale_moe( Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, bool do_finalize, bool enable_pdl, int64_t act_type, - TensorView output, Array config_index) { + TensorView output, Array config_index, bool norm_topk_prob) { // 
Determine data types based on input format int const num_tokens = hidden_states.size(0); int hidden_size = hidden_states.size(1); @@ -1937,10 +1976,6 @@ Array trtllm_fp4_block_scale_moe( TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D."; TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), num_experts) << "routing_logits has incorrect shape."; - if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { - TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32) - << "routing_logits must be float."; - } } if (routing_bias.has_value()) { TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16 || @@ -2018,7 +2053,7 @@ Array trtllm_fp4_block_scale_moe( output2_scales_scalar, expert_indices, expert_weights); launcher->init(std::move(args), curr_tile_N, routing_method_type, /*use_shuffled_weight=*/true, /*weight_layout=*/0, static_cast(act_type), mDtypeAct, - mDtypeWeights); + mDtypeWeights, norm_topk_prob); launchers_map[curr_tile_N] = std::move(launcher); } @@ -2048,7 +2083,7 @@ Array trtllm_mxint4_block_scale_moe( Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, bool do_finalize, bool enable_pdl, TensorView output, - Array config_index) { + Array config_index, bool norm_topk_prob) { // Determine data types based on input format int const num_tokens = hidden_states.size(0); int hidden_size = hidden_states.size(1); @@ -2064,7 +2099,9 @@ Array trtllm_mxint4_block_scale_moe( TVM_FFI_ICHECK_EQ(routing_logits.ndim(), 2) << "routing_logits must be 2D."; TVM_FFI_ICHECK_EQ(routing_logits.size(1), num_experts) << "routing_logits has incorrect shape."; if (routing_bias.has_value()) { - TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16) << "routing_bias must be bfloat16."; + TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16 || + routing_bias.value().dtype() == dl_float32) + << 
"MxInt4 block scale MoE: routing_bias must be bfloat16 or float."; TVM_FFI_ICHECK_EQ(routing_bias.value().ndim(), 1) << "routing_bias must be 1D."; TVM_FFI_ICHECK_EQ(routing_bias.value().size(0), num_experts) << "routing_bias has incorrect shape."; @@ -2107,7 +2144,7 @@ Array trtllm_mxint4_block_scale_moe( auto launcher = std::make_unique( routing_logits, routing_bias, hidden_states, gemm1_weights, gemm1_weights_scale, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale); - launcher->init(std::move(args), curr_tile_N, routing_method_type); + launcher->init(std::move(args), curr_tile_N, routing_method_type, norm_topk_prob); launchers_map[curr_tile_N] = std::move(launcher); } diff --git a/csrc/trtllm_fused_moe_routing_common.cu b/csrc/trtllm_fused_moe_routing_common.cu new file mode 100644 index 0000000000..14baa556c7 --- /dev/null +++ b/csrc/trtllm_fused_moe_routing_common.cu @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh" +#include "flashinfer/trtllm/fused_moe/RoutingKernel.h" + +namespace moe::dev::routing { +namespace routingCustom { +// Forward declarations of launch functions +void launchBlockKernel(Data const& data, uint32_t numThreadsHist, void* stream); +void launchClusterKernel(Data const& data, void* stream); +void launchCoopKernel(Data const& data, int numBlocksCoop, uint32_t numThreadsHist, void* stream); +void launchInitExpertCounts(Data const& data, uint32_t numThreadsHist, void* stream); +void launchHistogramKernel(Data const& data, int numBlocksHistogram, uint32_t numThreadsHist, + void* stream); +void launchOffsetsKernel(Data const& data, int numBlocksOffsets, uint32_t numThreadsHist, + void* stream); +} // namespace routingCustom + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Implementation of shared post-topK pipeline for all routing methods. +// When topK is already computed (mPtrTopKIds or mPtrTopKPacked), we don't need +// routing-method-specific logic, so all methods can use the same workflow. +// This function handles all path selection: single-block, single-cluster, coop, multi-kernel. +template +void runPostTopKPipeline(DataType const& data, uint32_t /*numThreadsHist*/, void* stream) { + // Convert to routingCustom::Data for launching (kernels are shared) + routingCustom::Data customData; + // Copy base fields + static_cast(customData) = static_cast(data); + // Set routingCustom-specific defaults (not needed for utility kernels) + customData.mDtypeOutput = data.mDtypeOutput; + // The post-TopK kernels don't read routing logits (mPtrInput), only mPtrTopKPacked. + // Set mDtypeInput = mDtypeOutput so the dispatched template is , + // avoiding an unnecessary mixed-type instantiation. 
+ customData.mDtypeInput = data.mDtypeOutput; + customData.mPreprocessType = RoutingPreprocessType::None; + customData.mPostprocessType = RoutingPostprocessType::Softmax; + + // Recompute numThreadsHist using routingCustom's expert tiers, since we launch custom kernels. + // Different routing methods (DeepSeek, Llama4) may have different expert tier thresholds + // that don't match routingCustom's tiers (128, 512, 2048). + uint32_t const numThreadsHist = + std::min(1024u, static_cast(routingCustom::getMaxNumExperts(data.mNumExperts))); + + // Determine which path to use based on token count + bool const useSingleBlock = data.mNumTokens <= routingCustom::BlockKernelMaxNumTokens; + bool const useSingleCluster = data.mNumTokens <= routingCustom::MaxNumTokensSingleClusterScores; + + // PDL overlap control: the LAST routing kernel must disable overlap so the consumer + // GEMM (which may lack cudaGridDependencySynchronize) can't start early. + // Use a separate copy for the last kernel to avoid mutating customData. + routingCustom::Data lastKernelData = customData; + lastKernelData.mPdlOverlapWithNext = false; + + if (useSingleBlock) { + // Single-block path: fuses all steps (histogram, offsets, permutation) + routingCustom::launchBlockKernel(lastKernelData, numThreadsHist, stream); + } else if (useSingleCluster) { + // Single-cluster path: uses distributed shared memory + routingCustom::launchClusterKernel(lastKernelData, stream); + } else { + // Check if we can use the coop path (more efficient for medium token counts) + // Coop kernel requires SM90+ (grid-sync) and MaxNumExperts <= 1024. 
+ static int const smMajor = tensorrt_llm::common::getSMVersion() / 10; + bool const canUseCoop = + (smMajor >= 9) && (data.mNumExperts <= 1024) && (data.mPtrPermutedIdxSize != nullptr); + bool useCoop = false; + int numBlocksCoop = 0; + + if (canUseCoop) { + // Number of blocks we can use in the cooperative kernel + static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); + // WAR: Reserve 8 SMs for overlapping kernels. + numBlocksCoop = smCount - 8; + // Maximum number of tokens supported by the kernel using a cooperative launch. + // The number of blocks must be: + // >= ⌈(numTokens * topK) / (MaxExpandedIdxPerThread * NumThreads)⌉ + // MaxExpandedIdxPerThread = 64 (from coop kernel) + int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; + useCoop = (data.mNumTokens <= maxTokensCoop); + } + + if (useCoop) { + // Coop path: cooperative launch fuses histogram + offsets (more efficient). + // The coop kernel atomicAdds to mPtrExpertCounts, so we must zero it first. 
+ routingCustom::launchInitExpertCounts(customData, numThreadsHist, stream); + routingCustom::launchCoopKernel(lastKernelData, numBlocksCoop, numThreadsHist, stream); + } else { + // Large-token path: multi-kernel pipeline + FLASHINFER_CHECK(data.mPtrExpertCounts != nullptr, + "When #tokens is large, `mPtrExpertCounts` is a required input."); + + // Step 1: Reset expert counts + routingCustom::launchInitExpertCounts(customData, numThreadsHist, stream); + + // Step 2-3: Histogram + Offsets + int32_t const expandedIdxSize = data.mNumTokens * data.mTopK; + int32_t const histogramEltsPerBlock = 8 * numThreadsHist; + int32_t const offsetEltsPerBlock = + routing::NumEltsPerOffsetTilePerThread * numThreadsHist; + int32_t const maxNumBlocks = 1024; + + int const numBlocksHistogram = std::min( + (expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); + int const numBlocksOffsets = std::min( + (expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); + + routingCustom::launchHistogramKernel(customData, numBlocksHistogram, numThreadsHist, stream); + routingCustom::launchOffsetsKernel(lastKernelData, numBlocksOffsets, numThreadsHist, stream); + } + } +} + +// Explicit instantiations for the three routing method Data types +template void runPostTopKPipeline(routingCustom::Data const&, uint32_t, + void*); +template void runPostTopKPipeline(routingDeepSeek::Data const&, uint32_t, + void*); +template void runPostTopKPipeline(routingLlama4::Data const&, uint32_t, + void*); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace moe::dev::routing diff --git a/csrc/trtllm_fused_moe_routing_custom.cu b/csrc/trtllm_fused_moe_routing_custom.cu new file mode 100644 index 0000000000..42020b7d12 --- /dev/null +++ b/csrc/trtllm_fused_moe_routing_custom.cu @@ -0,0 +1,647 @@ +/* + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Custom routing: entry point, kernel definitions, and launch wrappers. +// +// Kernel inventory: +// 1. routingIndicesBlockKernel — single-block fused kernel (≤4 tokens) +// 2. routingIndicesClusterKernel — single-cluster fused kernel (≤256 tokens, SM90+) +// 3. routingIndicesHistogramScoresKernel — TopK + histogram from raw scores +// 4. routingIndicesCoopKernel — cooperative histogram + offsets (defined in RoutingKernel.cuh) +// 5. routingInitExpertCounts — zero expert counts (defined in RoutingKernel.cuh) +// 6. routingIndicesHistogramKernel — histogram from packed TopK (defined in RoutingKernel.cuh) +// 7. routingIndicesOffsetsKernel — prefix-scan + permutation (defined in RoutingKernel.cuh) + +#include "flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh" +#include "tvm_ffi_utils.h" + +namespace moe::dev::routing { +namespace routingCustom { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 1. Block kernel — single-block fused kernel for ≤4 tokens. +// Fuses TopK, histogram, prefix-scan, and permutation in one block. +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void + __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? 
KernelParams::MaxNumExperts : 1024) + routingIndicesBlockKernel(KernelParams params) { + // types used in this kernel + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + using BaseType = typename KernelParams::ExpertSelectPolicy::template BaseType; + using TypePacked = PackedScoreIdx; + static constexpr int MaxNumExperts = KernelParams::MaxNumExperts; + // When MaxNumExperts > 1024, cap actual thread count at 1024 and let each thread handle + // multiple experts. This is needed because CUDA blocks support at most 1024 threads. + static constexpr int NumThreadsBlock = MaxNumExperts <= 1024 ? MaxNumExperts : 1024; + static constexpr int ExpertsPerThread = MaxNumExperts / NumThreadsBlock; + static_assert(MaxNumExperts % NumThreadsBlock == 0, + "MaxNumExperts must be a multiple of NumThreadsBlock"); + + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + int32_t const laneIdx = cutlass::arch::LaneId(); + auto scoreOffset = warpIdx * params.mNumExperts; + bool validToken = warpIdx < params.mNumTokens; + + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; + static constexpr int totalExpertCounts = BlockKernelMaxNumTokens * MaxNumExperts; + __shared__ int8_t __attribute((aligned(128))) smemOffset[totalExpertCounts]; + __shared__ int8_t __attribute((aligned(128))) smemKIdx[totalExpertCounts]; + + using Scan = cub::BlockScan; + __shared__ typename Scan::TempStorage tempStorage; + + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + for (int i = threadIdx.x; i < totalExpertCounts; i += blockDim.x) { + smemOffset[i] = int8_t{-1}; + smemKIdx[i] = int8_t{-1}; + } + __syncthreads(); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // then wait on primary grid + if (params.mUsePdl) { + cudaGridDependencySynchronize(); + } +#endif + + if (params.mPtrTopKIds != nullptr) { + if (validToken) { + if (laneIdx < params.mTopK) { + auto expertIdx = 
params.mPtrTopKIds[warpIdx * params.mTopK + laneIdx]; + if (expertIdx > -1 && expertIdx < params.mNumExperts) { + int offset = warpIdx * MaxNumExperts + expertIdx; + smemKIdx[offset] = static_cast(laneIdx); + } else { + params.mPtrExpandedIdxToPermutedIdx[warpIdx * params.mTopK + laneIdx] = int32_t{-1}; + } + } + } + } else if (params.mPtrScores != nullptr) { + // in this case, each warp represents a token + BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; + int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; + + if (validToken) { + KernelParams::ExpertSelectPolicy::template apply( + warp, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK, + params.mPtrScores + scoreOffset, params); + + if (laneIdx < params.mTopK) { + int offset = warpIdx * MaxNumExperts + warpTopKExpertIdx[laneIdx]; + smemKIdx[offset] = static_cast(laneIdx); + if (params.mPtrTopKWeights != nullptr) { + params.mPtrTopKWeights[warpIdx * params.mTopK + laneIdx] = + OutputT{warpTopKScore[laneIdx]}; + } + } + } // end if (validToken) + } else if (params.mPtrTopKPacked != nullptr) { + if (validToken) { + if (laneIdx < params.mTopK) { + auto const expandedIdx = warpIdx * params.mTopK + laneIdx; + auto const scoreIdx = params.mPtrTopKPacked[expandedIdx]; + int offset = warpIdx * MaxNumExperts + static_cast(scoreIdx.idx); + smemKIdx[offset] = static_cast(laneIdx); + if (params.mPtrTopKWeights != nullptr) { + params.mPtrTopKWeights[expandedIdx] = static_cast(scoreIdx.score); + } + } + } + } + __syncthreads(); + + // Each thread handles ExpertsPerThread contiguous experts. + // Thread i handles experts [i * ExpertsPerThread, (i+1) * ExpertsPerThread). + // Contiguous assignment ensures prefix sum ordering is correct. 
+ int accExpertCount[ExpertsPerThread]; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + auto localExpIdx = expert - params.mLocalExpertsStartIdx; + auto isLocal = localExpIdx >= 0 && localExpIdx < params.mNumLocalExperts && + (localExpIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + + // Get the count of each expert and the offset for each token + accExpertCount[e] = 0; + if (isLocal) { + int offset = expert; + for (int j = 0; j < BlockKernelMaxNumTokens; j++) { + if (smemKIdx[offset] >= 0) { + smemOffset[offset] = static_cast(accExpertCount[e]); + accExpertCount[e]++; + } + offset += MaxNumExperts; + } + } + } + __syncthreads(); + + // Get the number of CTAs and the offset for each CTA. + // Use cub::BlockScan's array overload: each thread holds ExpertsPerThread items, + // and ExclusiveSum computes the prefix sum across all NumThreadsBlock * ExpertsPerThread + // items in thread order — exactly matching our contiguous expert assignment. 
+ int32_t numCtaPerExpert[ExpertsPerThread]; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + if (params.mIsPow2) { + numCtaPerExpert[e] = divUpLog2(accExpertCount[e], params.mPaddingLog2); + } else { + numCtaPerExpert[e] = divUpTileN(accExpertCount[e], params.mTileTokensDim); + } + // Expand from CGA count to CTA count to keep the semantic stable with downstream kernels + numCtaPerExpert[e] *= params.mClusterSizeInBatchDim; + } + int32_t ctaOffsetPerExpert[ExpertsPerThread]; + int32_t numNonExitingCtas; + Scan(tempStorage).ExclusiveSum(numCtaPerExpert, ctaOffsetPerExpert, numNonExitingCtas); + __syncthreads(); // Required barrier before reusing TempStorage for the next BlockScan + + // Compute padded expert scan counts (same array-overload pattern) + int32_t tmpCountPerExpert[ExpertsPerThread]; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + if (params.mIsPow2) { + tmpCountPerExpert[e] = divUpMulLog2(accExpertCount[e], params.mPaddingLog2); + } else { + tmpCountPerExpert[e] = divUpMulTileN(accExpertCount[e], params.mTileTokensDim); + } + } + int32_t expertScanCountsPerExpert[ExpertsPerThread]; + Scan(tempStorage).ExclusiveSum(tmpCountPerExpert, expertScanCountsPerExpert); + __syncthreads(); + + // Write CTA configs for each expert this thread handles +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + auto localExpIdx = expert - params.mLocalExpertsStartIdx; + auto isLocal = localExpIdx >= 0 && localExpIdx < params.mNumLocalExperts && + (localExpIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + + if (isLocal) { + for (int cta = 0; cta < numCtaPerExpert[e]; ++cta) { + int32_t const mappedLocalIdx = + (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; + params.mPtrCtaIdxXyToBatchIdx[ctaOffsetPerExpert[e] + cta] = mappedLocalIdx; + // Write CTA-level MnLimits using ctaTile = cgaTile / clusterSize + int32_t mnLimit1; + int32_t mnLimit2; + 
if (params.mIsPow2) { + int32_t ctaPaddingLog2 = params.mPaddingLog2 - params.mClusterSizeLog2; + mnLimit1 = + mulLog2(ctaOffsetPerExpert[e] + cta + 1, ctaPaddingLog2); + mnLimit2 = mulLog2(ctaOffsetPerExpert[e], ctaPaddingLog2) + + accExpertCount[e]; + } else { + int32_t ctaTile = params.mTileTokensDim / params.mClusterSizeInBatchDim; + mnLimit1 = (ctaOffsetPerExpert[e] + cta + 1) * ctaTile; + mnLimit2 = ctaOffsetPerExpert[e] * ctaTile + accExpertCount[e]; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffsetPerExpert[e] + cta] = min(mnLimit1, mnLimit2); + } + } + } + + // at this point, we can write out padded count + if (threadIdx.x == 0) { + int32_t permutedIdxSize; + if (params.mIsPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas >> params.mClusterSizeLog2, params.mPaddingLog2); + } else { + permutedIdxSize = (numNonExitingCtas / params.mClusterSizeInBatchDim) * params.mTileTokensDim; + } + params.mPtrPermutedIdxSize[0] = permutedIdxSize; + params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; + } + + for (int tokenIdx = 0; tokenIdx < params.mNumTokens; tokenIdx++) { +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + int offset = tokenIdx * MaxNumExperts + expert; + if (smemKIdx[offset] >= 0) { + auto localExpIdx = expert - params.mLocalExpertsStartIdx; + auto isLocal = localExpIdx >= 0 && localExpIdx < params.mNumLocalExperts && + (localExpIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + + int const expandedIdx = tokenIdx * params.mTopK + smemKIdx[offset]; + int const offsetWithinExpert = static_cast(smemOffset[offset]); + int const offsetForExpert = expertScanCountsPerExpert[e]; + int const permutedIdx = + isLocal ? 
offsetForExpert + offsetWithinExpert : int32_t{-1}; + + if (params.mPtrExpandedIdxToPermutedIdx != nullptr) { + params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx; + } + if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocal) { + params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx; + } + if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocal) { + params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; + } + } + } + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Trigger the secondary kernel AFTER all global memory writes (including permutation indices). + // The downstream kernels depend on all routing outputs being visible. + if (params.mUsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif +} + +void launchBlockKernel(Data const& data, uint32_t numThreadsHist, void* stream) { + LAUNCH_ROUTING_CUSTOM(data, false, routingIndicesBlockKernel, 1, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 2. Cluster kernel — single-cluster fused kernel for ≤256 tokens (SM90+). +// Uses distributed shared memory across 8 blocks in a cluster. 
+// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads) + routingIndicesClusterKernel(KernelParams params) { + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + using BaseType = typename KernelParams::ExpertSelectPolicy::template BaseType; + using TypePacked = PackedScoreIdx; + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; + + __shared__ TypePacked __attribute((aligned(128))) + smemPackedScoreIdx[NumWarps * KernelParams::MaxNumTopExperts]; + + uint32_t const clusterBlockRank = blockIdx.x; + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + int32_t const laneIdx = cutlass::arch::LaneId(); + auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx; + auto scoreOffset = warpTokenIdx * params.mNumExperts; + bool validToken = warpTokenIdx < params.mNumTokens; + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + if (params.mUsePdl) { + cudaGridDependencySynchronize(); + } + + if (params.mPtrScores != nullptr) { + BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; + int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; + if (validToken) { + KernelParams::ExpertSelectPolicy::template apply( + warp, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK, + params.mPtrScores + scoreOffset, params); + if (laneIdx < params.mTopK) { + smemPackedScoreIdx[warpIdx * params.mTopK + laneIdx] = + TypePacked{warpTopKScore[laneIdx], + static_cast(warpTopKExpertIdx[laneIdx])}; + } + } + } + + __cluster_barrier_arrive(); + __cluster_barrier_wait(); + + if (params.mPtrScores != nullptr) { + routingPermutation(params, smemPackedScoreIdx, warpIdx, + clusterBlockRank); + } else { + routingPermutation(params, 
smemPackedScoreIdx, warpIdx, + clusterBlockRank); + } +} +#else +__global__ void __launch_bounds__(NumThreads) + routingIndicesClusterKernel(KernelParams /* params */) { + assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); +} +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + +void launchClusterKernel(Data const& data, void* stream) { + LAUNCH_ROUTING_CUSTOM(data, false, routingIndicesClusterKernel, NumBlocksPerCluster, + NumThreads, + /*smemSize=*/0, // No dynamic smem + stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 3. HistogramScores kernel — computes TopK from raw scores and initializes expert counts. +// Used as step 1 of the multi-kernel pipeline when input is raw logits. +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void + __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelParams::MaxNumExperts : 1024) + routingIndicesHistogramScoresKernel(KernelParams params) { + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + using BaseType = typename KernelParams::ExpertSelectPolicy::template BaseType; + // Cap actual thread count at 1024 when MaxNumExperts > 1024. + static constexpr int NumThreadsBlock = + KernelParams::MaxNumExperts <= 1024 ? KernelParams::MaxNumExperts : 1024; + + // VecSize stays based on MaxNumExperts — each warp still processes all experts for one token. 
+ static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; + + int32_t const laneIdx = cutlass::arch::LaneId(); + int32_t const warpIdx = threadIdx.x / WarpSize; + // Use NumThreadsBlock (actual thread count) for grid-stride warp/thread addressing + int32_t const globalWarpIdx = blockIdx.x * NumThreadsBlock / WarpSize + warpIdx; + int32_t const globalWarpStride = gridDim.x * NumThreadsBlock / WarpSize; + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Wait on primary grid. + if (params.mUsePdl) { + cudaGridDependencySynchronize(); + } +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + + // initialize the mPtrExpertCounts — use NumThreadsBlock for grid-stride + int32_t expertCountsNum = 2 * params.mNumExperts; + int32_t globalThreadIdx = blockIdx.x * NumThreadsBlock + threadIdx.x; + int32_t globalThreadStride = gridDim.x * NumThreadsBlock; + initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); + + // in this case, each warp represents a token, and we use a grid-stride loop + // over all warps/tokens + BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; + int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; + for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; + tokenIdx += globalWarpStride) { + auto scoreOffset = tokenIdx * params.mNumExperts; + + KernelParams::ExpertSelectPolicy::template apply( + warp, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK, + params.mPtrScores + scoreOffset, params); + + if (laneIdx < params.mTopK) { + PackedScoreIdx packedScore{static_cast(warpTopKScore[laneIdx]), + static_cast(warpTopKExpertIdx[laneIdx])}; + params.mPtrTopKPacked[tokenIdx * params.mTopK + laneIdx] = packedScore; + } + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Trigger secondary kernel AFTER writing all packed scores, so the next kernel + 
// (routingIndicesHistogramKernel) sees the completed mPtrTopKPacked writes. + if (params.mUsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +} + +void launchHistogramScoresKernel(Data const& data, uint32_t maxNumBlocks, + uint32_t numThreadsHist, void* stream) { + LAUNCH_ROUTING_CUSTOM(data, false, routingIndicesHistogramScoresKernel, maxNumBlocks, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 4. Coop kernel — cooperative histogram + offsets via grid-sync. +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchCoopKernel(Data const& data, int numBlocksCoop, uint32_t numThreadsHist, void* stream) { + if (data.mNumExperts <= NumExperts128Experts) { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, + numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream, + NoOpPreprocess, NoOpPostprocess, NumExperts128Experts, + NumTop8Experts); + } else if (data.mNumExperts <= NumExperts160Experts) { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, + numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream, + NoOpPreprocess, NoOpPostprocess, NumExperts160Experts, + NumTop8Experts); + } else if (data.mNumExperts <= NumExperts256Experts) { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, + numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream, + NoOpPreprocess, NoOpPostprocess, NumExperts256Experts, + NumTop8Experts); + } else if (data.mNumExperts <= NumExperts384Experts) { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, + numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream, + NoOpPreprocess, NoOpPostprocess, NumExperts384Experts, + NumTop8Experts); + } else if (data.mNumExperts 
<= NumExperts512Experts) { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, + numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream, + NoOpPreprocess, NoOpPostprocess, NumExperts512Experts, + NumTop8Experts); + } else if (data.mNumExperts <= NumExperts576Experts) { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, + numBlocksCoop, numThreadsHist, /*smemSize=*/0, stream, + NoOpPreprocess, NoOpPostprocess, NumExperts576Experts, + NumTop8Experts); + } else { + FLASHINFER_WARN("Coop kernel does not support numExperts > %d", NumExperts576Experts); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 5-7. Launch wrappers for shared kernels (defined in RoutingKernel.cuh): +// - InitExpertCounts (zero expert counts) +// - Histogram kernel (histogram from packed TopK) +// - Offsets kernel (prefix-scan + permutation) +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchInitExpertCounts(Data const& data, uint32_t numThreadsHist, void* stream) { + LAUNCH_ROUTING_CUSTOM_NO_POLICY(data, false, routingInitExpertCounts, + (2 * data.mNumExperts - 1) / numThreadsHist + 1, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); +} + +void launchHistogramKernel(Data const& data, int numBlocksHistogram, uint32_t numThreadsHist, + void* stream) { + LAUNCH_ROUTING_CUSTOM_NO_POLICY(data, false, routingIndicesHistogramKernel, + numBlocksHistogram, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); +} + +void launchOffsetsKernel(Data const& data, int numBlocksOffsets, uint32_t numThreadsHist, + void* stream) { + LAUNCH_ROUTING_CUSTOM_NO_POLICY(data, false, routingIndicesOffsetsKernel, numBlocksOffsets, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); +} + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Entry point +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run(Data const& data, void* stream) { + TVM_FFI_ICHECK(data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || + data.mPtrTopKIds != nullptr) + << "Routing kernel requires at least one input parameter"; + + // When topK is already computed (mPtrTopKIds or mPtrTopKPacked without scores), + // delegate to the shared post-topK pipeline which handles all path selection + // (single-block, single-cluster, coop, multi-kernel) automatically. + // No routing-method-specific logic needed. + if (data.mPtrTopKIds != nullptr || + (data.mPtrTopKPacked != nullptr && data.mPtrScores == nullptr)) { + if (data.mPtrTopKIds != nullptr) { + TVM_FFI_ICHECK(data.mPtrTopKWeights != nullptr) + << "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for " + "custom routing."; + } + uint32_t const numThreadsHist = + std::min(1024u, static_cast(getMaxNumExperts(data.mNumExperts))); + runPostTopKPipeline(data, numThreadsHist, stream); + return; + } + + // After this point, input is mPtrScores (raw logits that need topK computation). 
+ TVM_FFI_ICHECK(data.mPtrScores != nullptr) << "Expected mPtrScores to be non-null at this " + "point."; + TVM_FFI_ICHECK(data.mPtrPermutedIdxSize != nullptr && + data.mPtrCtaIdxXyToBatchIdx != nullptr && + data.mPtrCtaIdxXyToMnLimit != nullptr && + data.mPtrNumNonExitingCtas != nullptr) + << "Custom routing kernel expects permuted idx and grouped Gemm launch config buffers"; + TVM_FFI_ICHECK_LE(data.mTopK, static_cast(MaxSupportedTopExperts)) + << "Routing kernel expects topK experts <= " << MaxSupportedTopExperts << ", got " + << data.mTopK; + TVM_FFI_ICHECK_LE(data.mNumExperts, static_cast(MaxSupportedExperts)) + << "Routing kernel expects #experts " << data.mNumExperts << " to be no more than " + << MaxSupportedExperts << "."; + TVM_FFI_ICHECK_EQ(data.mNumExperts % 4, 0) + << "Routing kernel expects #experts " << data.mNumExperts + << " to be a multiple of 4."; + + bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens; + bool const useSingleCluster = data.mNumTokens <= MaxNumTokensSingleClusterScores; + + if (!useSingleCluster && !useSingleBlock) { + TVM_FFI_ICHECK(data.mPtrTopKPacked != nullptr) + << "When #tokens is large, `mPtrTopKPacked` is a required input."; + TVM_FFI_ICHECK(data.mPtrExpertCounts != nullptr) + << "When #tokens is large, `mPtrExpertCounts` is a required input."; + } + + uint32_t const numThreadsHist = + std::min(1024u, static_cast(getMaxNumExperts(data.mNumExperts))); + + // PDL overlap control: intermediate routing kernels allow the next routing kernel to overlap + // (mPdlOverlapWithNext = mUsePdl). The LAST routing kernel disables overlap so the consumer + // GEMM (which may not have cudaGridDependencySynchronize for routing data) can't start early. + // We need a mutable copy since `data` is const. + Data mutableData = data; + bool const pdl = data.mUsePdl; + + if (useSingleBlock) { + //@TODO: For now we use the single block kernel for cases with token number no larger than 4. 
+ // We will future tune this threshold based on the performance. + mutableData.mPdlOverlapWithNext = false; // Last kernel — don't let consumer overlap + launchBlockKernel(mutableData, numThreadsHist, stream); + } else if (useSingleCluster) { + mutableData.mPdlOverlapWithNext = false; // Last kernel — don't let consumer overlap + launchClusterKernel(mutableData, stream); + } else { + // mPtrScores path: compute topK first via fused scores+histogram kernel, + // then use coop or multi-kernel pipeline for histogram + offsets. + uint32_t const maxNumBlocks = 1024; + + // Step 1: Compute topK from raw scores and write packed results to mPtrTopKPacked. + mutableData.mPdlOverlapWithNext = pdl; // Intermediate — allow next routing kernel to + // overlap + launchHistogramScoresKernel(mutableData, maxNumBlocks, numThreadsHist, stream); + + // Step 2+3: Histogram + Offsets — try coop path first, fall back to multi-kernel. + // Coop kernel fuses histogram + offsets into a single cooperative launch. + // Requires SM90+ (kernel uses grid-sync), numExperts <= 1024, and enough SM capacity. + static int const smMajor = tensorrt_llm::common::getSMVersion() / 10; + bool const canUseCoop = (smMajor >= 9) && (data.mNumExperts <= 1024) && + (data.mPtrPermutedIdxSize != nullptr); + bool useCoop = false; + int numBlocksCoop = 0; + + if (canUseCoop) { + static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); + numBlocksCoop = smCount - 8; // Reserve 8 SMs for overlapping kernels + int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; + useCoop = (data.mNumTokens <= maxTokensCoop); + } + + if (useCoop) { + // Coop path: 2 kernels (scores+topK → coop histogram+offsets) instead of 3. 
+ mutableData.mPdlOverlapWithNext = + pdl; // Intermediate — allow next routing kernel to overlap + launchInitExpertCounts(mutableData, numThreadsHist, stream); + mutableData.mPdlOverlapWithNext = false; // Last kernel — don't let consumer overlap + launchCoopKernel(mutableData, numBlocksCoop, numThreadsHist, stream); + } else { + // Multi-kernel path: 3 kernels (scores+topK → histogram → offsets). + // Note: histogramScoresKernel already zeroes expert counts, so no initExpertCounts + // needed. + uint32_t const expandedIdxSize = data.mNumTokens * data.mTopK; + uint32_t const histogramEltsPerBlock = 8 * numThreadsHist; + uint32_t const offsetEltsPerBlock = + NumEltsPerOffsetTilePerThread * numThreadsHist; + + int const numBlocksHistogram = std::min( + (expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); + int const numBlocksOffsets = std::min( + (expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); + + mutableData.mPdlOverlapWithNext = + pdl; // Intermediate — allow next routing kernel to overlap + launchHistogramKernel(mutableData, numBlocksHistogram, numThreadsHist, stream); + mutableData.mPdlOverlapWithNext = false; // Last kernel — don't let consumer overlap + launchOffsetsKernel(mutableData, numBlocksOffsets, numThreadsHist, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingCustom +} // namespace moe::dev::routing diff --git a/csrc/trtllm_fused_moe_routing_deepseek.cu b/csrc/trtllm_fused_moe_routing_deepseek.cu index 5408d2d059..816f28199c 100644 --- a/csrc/trtllm_fused_moe_routing_deepseek.cu +++ b/csrc/trtllm_fused_moe_routing_deepseek.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,16 +14,37 @@ * limitations under the License. */ +// Merged from individual launch*.cu files for readability. +// Contains all kernel definitions and launch wrappers for routingDeepSeek. +// +// Kernel inventory: +// 1. routingMainKernel — DeepSeek-specific main kernel (sigmoid + bias + group TopK) +// 2. routingIndicesClusterKernel — single-cluster fused kernel (SM90+) +// 3. launchCoopKernel — delegates to routingCustom's coop implementation +// 4. launchInitExpertCounts — zero expert counts +// 5. launchHistogramKernel — histogram from packed TopK +// 6. launchOffsetsKernel — prefix-scan + permutation + #include #include #include "flashinfer/exception.h" #include "flashinfer/trtllm/fused_moe/RoutingKernel.cuh" +#include "flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh" namespace moe::dev::routing { +// Forward declaration of routingCustom's coop kernel (used by DeepSeek's coop path) +namespace routingCustom { +void launchCoopKernel(Data const& data, int numBlocksCoop, uint32_t numThreadsHist, void* stream); +} // namespace routingCustom + namespace routingDeepSeek { +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Constants and dispatch macros +// //////////////////////////////////////////////////////////////////////////////////////////////////// static constexpr int NumNemotronExperts = 512; @@ -32,11 +53,73 @@ static constexpr int NumDeepseekExperts = 256; static constexpr int MaxSupportedExpertCount = std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); static constexpr int NumTopGroupScores = 2; -static constexpr int DefaultMaxNumTopExperts = 8; -static constexpr int MaxSupportedTopExperts = 22; static constexpr int MaxNumTopGroups = 4; static constexpr int MaxNumGroups = 8; +static constexpr int NumTop8Experts = 8; +static constexpr int 
NumTop22Experts = 22; +static constexpr int MaxSupportedTopExperts = 32; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline int32_t getMaxNumExperts(int32_t numExperts) { + if (numExperts <= topk::MaxNumExpertsUnit) { + return topk::MaxNumExpertsUnit; + } else if (numExperts <= NumDeepseekExperts) { + return NumDeepseekExperts; + } else if (numExperts <= NumKimiK2Experts) { + return NumKimiK2Experts; + } else if (numExperts <= NumNemotronExperts) { + return NumNemotronExperts; + } else { + FLASHINFER_WARN("Unsupported numExperts"); + return 0; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Helper macro: dispatch on topK tier for a given numExperts tier. +#define LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, forceFloatInput, numExperts) \ + if (data.mTopK <= NumTop8Experts) { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, \ + numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, numExperts, NumTop8Experts); \ + } else if (data.mTopK <= NumTop22Experts) { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, numExperts, NumTop22Experts); \ + } else { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, numExperts, MaxSupportedTopExperts); \ + } + +#define LAUNCH_ROUTING_DEEPSEEK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + extraFlag1, forceFloatInput) \ + if (data.mNumExperts <= topk::MaxNumExpertsUnit) { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + extraFlag1, forceFloatInput, topk::MaxNumExpertsUnit); \ + } else if (data.mNumExperts <= 
NumDeepseekExperts) { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + extraFlag1, forceFloatInput, NumDeepseekExperts); \ + } else if (data.mNumExperts <= NumKimiK2Experts) { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + extraFlag1, forceFloatInput, NumKimiK2Experts); \ + } else if (data.mNumExperts <= NumNemotronExperts) { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + extraFlag1, forceFloatInput, NumNemotronExperts); \ + } else { \ + FLASHINFER_WARN("Unsupported numExperts"); \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 1. Main kernel — DeepSeek-specific routing with sigmoid activation, bias, and group TopK. +// Handles both grouped and non-grouped expert selection. +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + template __global__ void routingMainKernel(KernelParams params) { // declare types @@ -65,7 +148,7 @@ __global__ void routingMainKernel(KernelParams params) { // note that for invalid scores, we use negative infinity, // needed for GLM-style routing where bias can be negative - static constexpr float invalidScoreFloat = -float(INFINITY); + static constexpr float invalidScoreFloat = float{-INFINITY}; const OutputT invalidScore = OutputT{invalidScoreFloat}; // load bias already; each warp represents one expert group @@ -76,8 +159,11 @@ __global__ void routingMainKernel(KernelParams params) { expertSelected = laneIdx < params.mNumExpertsPerGroup; } auto scoreIdx = int64_t{blockIdx.x} * int64_t{params.mNumExperts} + threadExpert; - auto biasVal = - expertSelected ? static_cast(params.mPtrRoutingBias[threadExpert]) : invalidScoreFloat; + auto biasVal = expertSelected + ? 
static_cast( + loadScalar(params.mPtrRoutingBias, threadExpert, params.mDtypeBias)) + : invalidScore; + // initialize the mPtrExpertCounts if (params.mPtrExpertCounts) { int32_t globalThreadIdx = blockIdx.x * blockDim.x + threadIdx.x; @@ -88,7 +174,7 @@ __global__ void routingMainKernel(KernelParams params) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) // trigger the secondary kernel when using PDL, then wait on primary - if constexpr (KernelParams::UsePdl) { + if (params.mUsePdl) { cudaTriggerProgrammaticLaunchCompletion(); cudaGridDependencySynchronize(); } @@ -100,15 +186,15 @@ __global__ void routingMainKernel(KernelParams params) { expertSelected ? static_cast(params.mPtrScores[scoreIdx]) : invalidScoreFloat; // get the sigmoid score // note that for invalid values, we simply use a negative value: - // sigmoig scores are always strictly positive + // sigmoid scores are always strictly positive auto scoreSigmoid = sigmoid_accurate(score); // write the sigmoid score to shared for later use if (expertSelected) { smemScoreSigmoid[threadExpert] = scoreSigmoid; } // get the score with bias - // note: with invalid values, invalidScoreFloat ensures values are always smaller than valid - // ones + // note: with invalid values, because sigmoid is < 1 and bias is -1, + // we must get a negative value, which is smaller than any valid value auto scoreBias = float{scoreSigmoid + float{biasVal}}; if (expertSelected) { @@ -127,7 +213,7 @@ __global__ void routingMainKernel(KernelParams params) { if constexpr (KernelParams::UseGroups) { topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, - /* minValue */ invalidScoreFloat); + /* minValue */ invalidScoreFloat); // get the final group score and write it to shared if (cute::elect_one_sync()) { auto groupScore = topExpGroupScores[0] + topExpGroupScores[1]; @@ -139,8 +225,9 @@ __global__ void routingMainKernel(KernelParams params) { __syncthreads(); auto localExpertExtent = 
params.mNumLocalExperts << params.mLocalExpertsStrideLog2; - if constexpr (KernelParams::UseGroups) { // a single warp performs the selection of top groups, - // and goes on to select the final experts + if constexpr (KernelParams::UseGroups) { + // a single warp performs the selection of top groups, + // and goes on to select the final experts if (warpIdx == 0) { float groupScore = laneIdx < params.mNumExpertGroups ? smemGroupScores[laneIdx] : invalidScoreFloat; @@ -203,19 +290,19 @@ __global__ void routingMainKernel(KernelParams params) { __syncthreads(); if (warpIdx == 0) { int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WarpSize + 1; - float intermidiateScore[NumInterTopKPerThread]; - int32_t intermidiateExpert[NumInterTopKPerThread]; + float intermediateScore[NumInterTopKPerThread]; + int32_t intermediateExpert[NumInterTopKPerThread]; for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) { int ii = i / WarpSize; if (i < NumInterTopK) { - intermidiateScore[ii] = smemInterTopScores[i]; - intermidiateExpert[ii] = smemInterTopExperts[i]; + intermediateScore[ii] = smemInterTopScores[i]; + intermediateExpert[ii] = smemInterTopExperts[i]; } else { - intermidiateScore[ii] = invalidScoreFloat; - intermidiateExpert[ii] = KernelParams::MaxNumExperts - 1; + intermediateScore[ii] = invalidScoreFloat; + intermediateExpert[ii] = KernelParams::MaxNumExperts - 1; } } - topk::reduceTopK(warp, topScores, topExperts, intermidiateScore, intermidiateExpert, + topk::reduceTopK(warp, topScores, topExperts, intermediateScore, intermediateExpert, /* minValue */ invalidScoreFloat, params.mTopK); } } else { @@ -243,7 +330,7 @@ __global__ void routingMainKernel(KernelParams params) { // determine whether our expert is local to this GPU auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + (localExpertIdx & 
((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; float scoreNorm = laneIdx < params.mTopK ? smemScoreSigmoid[expertIdx] : 0.F; auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); @@ -265,6 +352,25 @@ __global__ void routingMainKernel(KernelParams params) { } } +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Launch wrappers +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static void launchMainKernel(Data& data, int numBlocks, int numThreadsMain, void* stream) { + bool const forceFloatInput = (data.mDtypeInput == tg::Dtype::Fp32); + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, forceFloatInput); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 2. Cluster kernel — in a separate section because __cluster_dims__ can affect code generation. +// Ideally this would be in a separate .cu file; kept here for simplicity. 
+// //////////////////////////////////////////////////////////////////////////////////////////////////// template @@ -277,9 +383,8 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); int32_t const clusterBlockRank = blockIdx.x; - //@todo: try to move it into routingPermutation // then wait on primary grid - if constexpr (KernelParams::UsePdl) { + if (params.mUsePdl) { cudaGridDependencySynchronize(); } routingPermutation -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -__global__ void __launch_bounds__(KernelParams::MaxNumExperts) - routingIndicesCoopKernel(KernelParams params) { - // number of experts is bounded by number of threads - int constexpr NumThreads = KernelParams::MaxNumExperts; - __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads]; - __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads]; - // needed for the exclusive sum of token offsets - using Scan = cub::BlockScan; - __shared__ typename Scan::TempStorage tempStorage; - // 64 elements -> 128+ registers. Above that we may start to see spilling to local memory. - static constexpr int MaxExpandedIdxPerThread = 64; - - // Initialize grid. - cg::grid_group grid = cg::this_grid(); - // Note: the following is more efficient than grid.block_index() because we don't use y and z. 
- int32_t const gridBlockIdx = blockIdx.x; - int32_t const gridThreadIdx = NumThreads * gridBlockIdx + threadIdx.x; - int32_t const numBlocks = gridDim.x; - int32_t const numThreadsPerGrid = numBlocks * NumThreads; - - int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); - - auto expandedIdxSize = params.mNumTokens * params.mTopK; - - // pre-fill the counts with 0 - smemExpertCount[threadIdx.x] = 0; - __syncthreads(); - - // then wait on primary grid - if constexpr (KernelParams::UsePdl) { - cudaGridDependencySynchronize(); - } - - // each thread keeps has some number of "expanded indexes" assigned to it - // for each of these, we keep the associated expert and offset within expert in registers - int32_t expertIndexes[MaxExpandedIdxPerThread]; - int32_t expertOffsets[MaxExpandedIdxPerThread]; - auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; - // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a - // time, and branch between a fast path without bound checks and a slow path with bound checks. - int constexpr IterStride = 4; - static_assert(MaxExpandedIdxPerThread % IterStride == 0); - - // Define a lambda to avoid code duplication in both branches. - auto loopBody = [&](int ii, int expandedIdx) { - int32_t expertIdx = params.mPtrTopKIds != nullptr ? params.mPtrTopKIds[expandedIdx] - : params.mPtrTopKPacked[expandedIdx].idx; - expertIndexes[ii] = expertIdx; - // check whether this expert is local to our GPU at all and ignore if not - auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; - auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; - expertOffsets[ii] = isLocalExpert ? 
atomicAdd(smemExpertCount + expertIdx, 1) : 0; - }; - -#pragma unroll - for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride) { - // Whether it's safe to do multiple iterations without bound checks. - bool const takeFastPath = (ii0 + IterStride) * numThreadsPerGrid <= expandedIdxSize; - if (takeFastPath) { -#pragma unroll - for (int32_t jj = 0; jj < IterStride; jj++) { - int const ii = ii0 + jj; - auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; - loopBody(ii, expandedIdx); - } - } else { - bool doBreak = false; -#pragma unroll - for (int32_t jj = 0; jj < IterStride; jj++) { - int const ii = ii0 + jj; - auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; - if (expandedIdx >= expandedIdxSize) { - doBreak = true; - break; - } - loopBody(ii, expandedIdx); - } - if (doBreak) { - break; - } - } - } - - // Make histogram (token counts per expert) available to all threads in the block. - __syncthreads(); - - // - // Each thread now represents one expert - // - - // Add the local bin count to the common bin count and get a per-CTA offset. - int32_t const localExpertCount = smemExpertCount[threadIdx.x]; - - int32_t blockExpertOffset = 0; - if (threadIdx.x < params.mNumExperts) { - blockExpertOffset = atomicAdd(¶ms.mPtrExpertCounts[threadIdx.x], localExpertCount); - } - - // Sync to wait for completion of the histogram reduction. - grid.sync(); - - // Get total count for this expert. - int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0; - - // Note: the scan is redundant in all CTAs, but doing it in only 1 CTA would be worse for latency. - - // Compute the runtime config for projections - // Whether or not an expert is local is taken into account when smemExpertCount is computed - // so we do not need to take it into account here. 
- - int32_t numCta; - if constexpr (KernelParams::isPow2) { - numCta = divUpLog2(count, params.mPaddingLog2); - } else { - numCta = divUpTileN(count, params.mTileTokensDim); - } - - int32_t ctaOffset; - int32_t numNonExitingCtas; - Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); - - for (int32_t cta = gridBlockIdx; cta < numCta; cta += numBlocks) { - const int32_t localExpertIdx = - (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; - params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - int32_t mnLimit1; - int32_t mnLimit2; - if constexpr (KernelParams::isPow2) { - mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); - mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; - } else { - mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); - mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; - } - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); - } - - // get the padded offset associated with this expert - int32_t offset; - if constexpr (KernelParams::isPow2) { - offset = mulLog2(ctaOffset, params.mPaddingLog2); - } else { - offset = mulTileN(ctaOffset, params.mTileTokensDim); - } - int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) { - permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); - } else { - permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); - } - - // write out padded count - if (gridBlockIdx == 0 && warpIdx == NumThreads / WarpSize - 1 && cute::elect_one_sync()) { - params.mPtrPermutedIdxSize[0] = permutedIdxSize; - params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; - } - - // write expert offsets to shared - smemExpertOffset[threadIdx.x] = offset + blockExpertOffset; +static void launchClusterKernel(Data& data, int numThreadsHist, void* stream) { + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, + numThreadsHist, + 
/*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); +} - // make expert offsets available to all threads - __syncthreads(); +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 3-6. Launch wrappers for shared kernels. +// Coop delegates to routingCustom; others use LAUNCH_ROUTING_DEEPSEEK macro. +// +//////////////////////////////////////////////////////////////////////////////////////////////////// - // trigger the secondary kernel when using PDL - // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx, - // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens - // TODO: this is not sufficient to ensure visibility in the next kernel! - if constexpr (KernelParams::UsePdl) { - cudaTriggerProgrammaticLaunchCompletion(); - } +static void launchCoopKernel(Data& data, int numBlocksCoop, int /*numThreadsHist*/, void* stream) { + // Use routingCustom's coop kernel implementation (they are identical). + // Convert DeepSeek Data to Custom Data for launching. + routingCustom::Data customData; + // Copy base fields + static_cast(customData) = static_cast(data); + // Set routingCustom-specific defaults (not needed for coop kernel) + customData.mDtypeOutput = data.mDtypeOutput; + // The coop kernel doesn't read routing logits (mPtrInput), only mPtrTopKPacked. + // Set mDtypeInput = mDtypeOutput so the dispatched template is , + // avoiding an unnecessary mixed-type instantiation. + customData.mDtypeInput = data.mDtypeOutput; + customData.mPreprocessType = RoutingPreprocessType::None; + customData.mPostprocessType = RoutingPostprocessType::Softmax; + + // Recompute numThreadsHist using routingCustom's expert tiers (128, 512, 2048), + // since the custom coop kernel dispatch selects template parameters based on these tiers. + // DeepSeek's getMaxNumExperts uses different tiers (256, 384, 512) which would mismatch. 
+ uint32_t const customNumThreadsHist = + std::min(1024u, static_cast(routingCustom::getMaxNumExperts(data.mNumExperts))); + routingCustom::launchCoopKernel(customData, numBlocksCoop, customNumThreadsHist, stream); +} -// each thread has the same "expanded indexes" assigned to it as above -// at this point, we know the final offsets of experts and the offsets within -// experts, which allows writing the final index values -#pragma unroll - for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii) { - auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; - if (expandedIdx >= expandedIdxSize) { - break; - } - auto expertIdx = expertIndexes[ii]; - // check whether this expert is local to our GPU at all - auto localExpertIdx = static_cast(expertIdx) - params.mLocalExpertsStartIdx; - auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; - auto tokenIdx = expandedIdx / params.mTopK; - auto permutedIdx = - isLocalExpert ? 
int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1}; - if (params.mPtrExpandedIdxToPermutedIdx != nullptr) { - params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx; - } - if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert) { - params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx; - } - if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert) { - params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; - } - } +static void launchInitExpertCounts(Data& data, int numThreadsHist, void* stream) { + LAUNCH_ROUTING_DEEPSEEK(data, false, routingInitExpertCounts, + (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/false); } -#else -__global__ void routingIndicesCoopKernel(KernelParams params) { - assert(false && "routingIndicesCoopKernel is only supported on SM90+ architectures"); + +static void launchHistogramKernel(Data& data, int numBlocksHistogram, int numThreadsHist, + void* stream) { + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); } -#endif -int constexpr getMaxNumExperts(int32_t numExperts) { - if (numExperts <= topk::MaxNumExpertsUnit) { - return topk::MaxNumExpertsUnit; - } else if (numExperts <= NumDeepseekExperts) { - return NumDeepseekExperts; - } else if (numExperts <= NumKimiK2Experts) { - return NumKimiK2Experts; - } else if (numExperts <= NumNemotronExperts) { - return NumNemotronExperts; - } else { - TLLM_LOG_ERROR("Unsupported numExperts"); - return 0; - } +static void launchOffsetsKernel(Data& data, int numBlocksOffsets, int numThreadsHist, + void* stream) { + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, + numThreadsHist, + /*smemSize=*/0, // No dynamic 
smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); } //////////////////////////////////////////////////////////////////////////////////////////////////// -#define LAUNCH_ROUTING_DEEPSEEK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ - extraFlag) \ - if (data.mNumExperts <= topk::MaxNumExpertsUnit) { \ - LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, topk::MaxNumExpertsUnit, \ - DefaultMaxNumTopExperts); \ - } else if (data.mNumExperts <= NumDeepseekExperts) { \ - LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, NumDeepseekExperts, DefaultMaxNumTopExperts); \ - } else if (data.mNumExperts <= NumKimiK2Experts) { \ - LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, NumKimiK2Experts, DefaultMaxNumTopExperts); \ - } else if (data.mNumExperts <= NumNemotronExperts) { \ - if (data.mTopK <= DefaultMaxNumTopExperts) { \ - LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, NumNemotronExperts, \ - DefaultMaxNumTopExperts); \ - } else { \ - LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, NumNemotronExperts, MaxSupportedTopExperts); \ - } \ - } else { \ - TLLM_LOG_ERROR("Unsupported numExperts"); \ - } -void runImpl(Data& data, void* stream) { +void run(Data& data, void* stream) { FLASHINFER_CHECK( data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, "Routing kernel requires at least one input parameter"); - if (data.mPtrTopKIds != nullptr) { - FLASHINFER_CHECK(data.mPtrTopKWeights != nullptr, - "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for " - "DeepSeek routing."); + + // When topK is already computed (mPtrTopKIds or mPtrTopKPacked without scores), + // delegate to 
the shared post-topK pipeline which handles all path selection + // (single-block, single-cluster, coop, multi-kernel) automatically. + // No routing-method-specific logic needed. + if (data.mPtrTopKIds != nullptr || + (data.mPtrTopKPacked != nullptr && data.mPtrScores == nullptr)) { + if (data.mPtrTopKIds != nullptr) { + FLASHINFER_CHECK( + data.mPtrTopKWeights != nullptr, + "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for " + "DeepSeek routing."); + } + int const numThreadsHist = getMaxNumExperts(data.mNumExperts); + runPostTopKPipeline(data, numThreadsHist, stream); + return; } - if (data.mPtrExpandedIdxToPermutedIdx != nullptr || - data.mPtrPermutedIdxToExpandedIdx != nullptr || data.mPtrPermutedIdxToTokenIdx != nullptr) - FLASHINFER_CHECK( - (data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) && data.mPtrPermutedIdxSize, - "If permuted index is required, `mPtrTopKPacked` or `mPtrTopKIds` is also required"); + + // After this point, input is mPtrScores (raw logits that need DeepSeek-specific routing). 
FLASHINFER_CHECK(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet"); - FLASHINFER_CHECK(data.mNumLimitedGroups <= MaxNumTopGroups, - "Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, - data.mNumLimitedGroups); - // Test limits according to values passed in launch, see definition of LAUNCH_ROUTING_DEEPSEEK - if (data.mNumExperts <= NumKimiK2Experts) { - FLASHINFER_CHECK( - data.mTopK <= DefaultMaxNumTopExperts, - "When NumExperts <= NumKimiK2Experts, routing kernel expects topK experts <= %d, got %d", - DefaultMaxNumTopExperts, data.mTopK); - } else { - FLASHINFER_CHECK( - data.mTopK <= MaxSupportedTopExperts, - "When NumExperts > NumKimiK2Experts, routing kernel expects topK experts <= %d, got %d", - MaxSupportedTopExperts, data.mTopK); - } - FLASHINFER_CHECK(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d", - data.mTopK); - FLASHINFER_CHECK(data.mTopK * data.mNumLimitedGroups <= WarpSize, - "Routing kernel expects top K * top groups <= warp size (for now), got %d * %d", - data.mTopK, data.mNumLimitedGroups); - FLASHINFER_CHECK(data.mTopK <= data.mNumExperts, - "Routing kernel expects topK %d to be at most #experts %d", data.mTopK, + FLASHINFER_CHECK(data.mNumExperts >= data.mTopK, + "Routing kernel expects topK (%d) to be <= numExperts (%d)", data.mTopK, data.mNumExperts); FLASHINFER_CHECK(data.mNumExperts <= MaxSupportedExpertCount, "Routing kernel expects #experts %d <= #threads %d", data.mNumExperts, MaxSupportedExpertCount); - FLASHINFER_CHECK(data.mNumExpertGroups >= data.mNumLimitedGroups, - "Routing kernel expects top groups %d to be limited by #expert groups %d", - data.mNumLimitedGroups, data.mNumExpertGroups); + FLASHINFER_CHECK(data.mTopK <= MaxSupportedTopExperts, + "Routing kernel expects topK experts <= %d, got %d", MaxSupportedTopExperts, + data.mTopK); + + if (data.mPtrExpandedIdxToPermutedIdx != nullptr || + data.mPtrPermutedIdxToExpandedIdx != nullptr || 
data.mPtrPermutedIdxToTokenIdx != nullptr) + FLASHINFER_CHECK( + data.mPtrTopKPacked != nullptr && data.mPtrPermutedIdxSize, + "If permuted index is required, `mPtrTopKPacked` is also required"); + + // Routing needs to be executed - validate routing kernel constraints if (data.mNumExpertGroups > 1) { FLASHINFER_CHECK(data.mNumExpertGroups <= MaxNumGroups, - "Routing kernel expects #experts groups %d to be <= #warps %d", + "Routing kernel expects #expert groups %d to be <= max groups %d", data.mNumExpertGroups, MaxNumGroups); FLASHINFER_CHECK(data.mNumExperts % data.mNumExpertGroups == 0, "Routing kernel expects #experts %d to be a multiple of #expert groups %d", data.mNumExperts, data.mNumExpertGroups); FLASHINFER_CHECK( data.mNumExperts / data.mNumExpertGroups <= WarpSize, - "Routing kernel expects #experts per group <= warp size, got %d, data.mNumExpertGroups %d", - data.mNumExperts / data.mNumExpertGroups, data.mNumExpertGroups); + "Routing kernel expects #experts per group <= warp size (%d), got %d experts / %d groups " + "= %d experts per group", + WarpSize, data.mNumExperts, data.mNumExpertGroups, + data.mNumExperts / data.mNumExpertGroups); + FLASHINFER_CHECK(data.mNumLimitedGroups <= MaxNumTopGroups, + "Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, + data.mNumLimitedGroups); + FLASHINFER_CHECK(data.mNumExpertGroups >= data.mNumLimitedGroups, + "Routing kernel expects top groups %d to be limited by #expert groups %d", + data.mNumLimitedGroups, data.mNumExpertGroups); + FLASHINFER_CHECK(data.mNumExperts % 4 == 0, + "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); } - FLASHINFER_CHECK(data.mNumExperts % 4 == 0, - "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); int const numBlocks = data.mNumTokens; int const numThreadsHist = getMaxNumExperts(data.mNumExperts); + bool const pdl = data.mUsePdl; - bool const useSingleCluster = data.mNumTokens <= 1024; - if (!useSingleCluster) { - 
// Reset the global histograms (not used in single-cluster code path). - // Cover both for the cooperative and two-kernel code paths. - FLASHINFER_CHECK(data.mPtrExpertCounts != nullptr, - "When #tokens is large, `mPtrExpertCounts` is a required input."); - } else { - data.mPtrExpertCounts = - nullptr; // Set it to nullptr for single-cluster code path, as it won't be used - } - - // Number of blocks we can use in the cooperative kernel - // The number of blocks must be: - // >= ⌈(numTokens * topK) / (MaxExpandedIdxPerThread * NumThreads)⌉ - // <= numSms, assuming an occupancy of 1 block/SM - // - // If too small for the given numTokens, fall back to the less performant two-step method. - // - // The upper bound is a strict requirement. The number of blocks should be determined by querying - // the device properties, or conservatively low. - // /!\ The following number is not portable!! (but works on H100 and B200) - int const numBlocksCoop = 128; - - // Maximum number of tokens supported by the kernel using a cooperative launch. - int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; - if (data.mPtrTopKIds == nullptr) { - int const numThreadsMain = - max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts)); - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1); - } else { - // Reset the global histograms. 
- LAUNCH_ROUTING_DEEPSEEK(data, false, routingInitExpertCounts, - (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1); - } + // Step 1: Run DeepSeek-specific topK computation (writes to mPtrTopKPacked) + int const numThreadsMain = + max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts)); + data.mPdlOverlapWithNext = pdl; // Intermediate — allow permutation pipeline to overlap + launchMainKernel(data, numBlocks, numThreadsMain, stream); + // Step 2: Permutation pipeline (reads from mPtrTopKPacked written by step 1) if (data.mPtrPermutedIdxSize != nullptr) { + bool const useSingleCluster = data.mNumTokens <= 1024; + if (!useSingleCluster) { + FLASHINFER_CHECK(data.mPtrExpertCounts != nullptr, + "When #tokens is large, `mPtrExpertCounts` is a required input."); + } else { + data.mPtrExpertCounts = + nullptr; // Set it to nullptr for single-cluster code path, as it won't be used + } + + // Number of blocks we can use in the cooperative kernel + static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); + // WAR: Reserve 8 SMs for overlapping kernels. + int const numBlocksCoop = smCount - 8; + // Maximum number of tokens supported by the kernel using a cooperative launch. 
+ int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; + if (useSingleCluster) { - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/false, routingIndicesClusterKernel, - NumBlocksPerCluster, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1); + data.mPdlOverlapWithNext = false; // Last kernel + launchClusterKernel(data, numThreadsHist, stream); } else if (data.mNumTokens <= maxTokensCoop) { - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, - numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1); + data.mPdlOverlapWithNext = false; // Last kernel + launchCoopKernel(data, numBlocksCoop, numThreadsHist, stream); } else { const int32_t expandedIdxSize = data.mNumTokens * data.mTopK; const int32_t histogramEltsPerBlock = 8 * numThreadsHist; @@ -671,25 +573,18 @@ void runImpl(Data& data, void* stream) { int const numBlocksOffsets = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/false, routingIndicesHistogramKernel, - numBlocksHistogram, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1); - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, - numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1); + data.mPdlOverlapWithNext = pdl; // Intermediate + launchHistogramKernel(data, numBlocksHistogram, numThreadsHist, stream); + data.mPdlOverlapWithNext = false; // Last kernel + launchOffsetsKernel(data, numBlocksOffsets, numThreadsHist, stream); } } } //////////////////////////////////////////////////////////////////////////////////////////////////// -void run(Data& data, void* stream) { runImpl(data, stream); } - -//////////////////////////////////////////////////////////////////////////////////////////////////// +#undef 
LAUNCH_DEEPSEEK_WITH_TOPK +#undef LAUNCH_ROUTING_DEEPSEEK } // namespace routingDeepSeek } // namespace moe::dev::routing diff --git a/csrc/trtllm_fused_moe_routing_llama4.cu b/csrc/trtllm_fused_moe_routing_llama4.cu index 31674e0a8e..38eff3036e 100644 --- a/csrc/trtllm_fused_moe_routing_llama4.cu +++ b/csrc/trtllm_fused_moe_routing_llama4.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,7 +25,8 @@ namespace routingLlama4 { static constexpr int NumThreads = 1024; static constexpr int NumWarps = NumThreads / WarpSize; static constexpr int MaxNumTopExperts = 1; -static constexpr int NumExpertsLimit = 128; +// static constexpr int MaxNumExperts = 128; +static constexpr int MaxSupportedExperts = 128; static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps; static constexpr int WarpKernelSmemStride = 33; @@ -101,7 +102,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // then wait on primary grid - if constexpr (KernelParams::UsePdl) { + if (params.mUsePdl) { cudaGridDependencySynchronize(); } #endif @@ -150,9 +151,9 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam scoreIdx = TypePacked{static_cast(params.mPtrTopKPacked[threadIdx.x].score), static_cast(params.mPtrTopKPacked[threadIdx.x].idx)}; if (params.mPtrTopKWeights != nullptr) { - // we also compute the final score here and write it out if required - auto finalScore = OutputT{sigmoid_accurate(float{scoreIdx.score})}; - params.mPtrTopKWeights[threadIdx.x] = finalScore; + // mPtrTopKPacked already contains sigmoid scores (produced 
by the scores-path + // kernels), so we just pass them through — no need to apply sigmoid again. + params.mPtrTopKWeights[threadIdx.x] = scoreIdx.score; } } } @@ -190,13 +191,15 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam for (int ii = 0; ii < ExpertsPerThread; ++ii) { auto count = getBits(expertCount, ii); int32_t num; - if constexpr (KernelParams::isPow2) { + if (params.mIsPow2) { num = divUpLog2(count, params.mPaddingLog2); } else { num = divUpTileN(count, params.mTileTokensDim); } numCta += num; } + // Expand from CGA count to CTA count to keep the semantic stable with downstream kernels + numCta *= params.mClusterSizeInBatchDim; // second, we perform the exclusive sum across the warp int32_t ctaOffset; int32_t numNonExitingCtas; @@ -209,24 +212,28 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam for (int ii = 0; ii < ExpertsPerThread; ++ii) { auto count = getBits(expertCount, ii); int32_t finalNumCta; - if constexpr (KernelParams::isPow2) { + if (params.mIsPow2) { finalNumCta = divUpLog2(count, params.mPaddingLog2); } else { finalNumCta = divUpTileN(count, params.mTileTokensDim); } + finalNumCta *= params.mClusterSizeInBatchDim; auto expertIdx = threadIdx.x * ExpertsPerThread + ii; // during the scan for expert offsets, we can already write out // both `mPtrCtaIdxXyToBatchIdx` and `mPtrCtaIdxXyToMnLimit` for (int cta = 0; cta < finalNumCta; ++cta) { params.mPtrCtaIdxXyToBatchIdx[ctaOffsetExp + cta] = expertIdx; + // Write CTA-level MnLimits using ctaTile = cgaTile / clusterSize int32_t mnLimit1; int32_t mnLimit2; - if constexpr (KernelParams::isPow2) { - mnLimit1 = mulLog2(ctaOffsetExp + cta + 1, params.mPaddingLog2); - mnLimit2 = mulLog2(ctaOffsetExp, params.mPaddingLog2) + count; + if (params.mIsPow2) { + int32_t ctaPaddingLog2 = params.mPaddingLog2 - params.mClusterSizeLog2; + mnLimit1 = mulLog2(ctaOffsetExp + cta + 1, ctaPaddingLog2); + mnLimit2 = mulLog2(ctaOffsetExp, 
ctaPaddingLog2) + count; } else { - mnLimit1 = mulTileN(ctaOffsetExp + cta + 1, params.mTileTokensDim); - mnLimit2 = mulTileN(ctaOffsetExp, params.mTileTokensDim) + count; + int32_t ctaTile = params.mTileTokensDim / params.mClusterSizeInBatchDim; + mnLimit1 = (ctaOffsetExp + cta + 1) * ctaTile; + mnLimit2 = ctaOffsetExp * ctaTile + count; } params.mPtrCtaIdxXyToMnLimit[ctaOffsetExp + cta] = min(mnLimit1, mnLimit2); } @@ -236,18 +243,20 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam // at this point, we can write out padded count from the warp-aggregate if (cute::elect_one_sync()) { int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) { - permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + if (params.mIsPow2) { + permutedIdxSize = + mulLog2(numNonExitingCtas >> params.mClusterSizeLog2, params.mPaddingLog2); } else { - permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + permutedIdxSize = (numNonExitingCtas / params.mClusterSizeInBatchDim) * params.mTileTokensDim; } + params.mPtrPermutedIdxSize[0] = permutedIdxSize; params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // we can trigger the next kernel at this point - if constexpr (KernelParams::UsePdl) { + if (params.mUsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } #endif @@ -258,16 +267,18 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam // here, we keep the local offset for each of the thread's experts in a field // of registers auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; + // Convert CTA-level ctaOffset back to token-space (CGA granularity) int32_t finalExpertOffset[ExpertsPerThread]; - if constexpr (KernelParams::isPow2) { - finalExpertOffset[0] = mulLog2(ctaOffset, params.mPaddingLog2); + if (params.mIsPow2) { + finalExpertOffset[0] = + mulLog2(ctaOffset >> params.mClusterSizeLog2, 
params.mPaddingLog2); } else { - finalExpertOffset[0] = mulTileN(ctaOffset, params.mTileTokensDim); + finalExpertOffset[0] = (ctaOffset / params.mClusterSizeInBatchDim) * params.mTileTokensDim; } #pragma unroll for (int ii = 1; ii < ExpertsPerThread; ++ii) { int32_t tmp; - if constexpr (KernelParams::isPow2) { + if (params.mIsPow2) { tmp = divUpMulLog2(getBits(expertCount, ii - 1), params.mPaddingLog2); } else { tmp = divUpMulTileN(getBits(expertCount, ii - 1), params.mTileTokensDim); @@ -294,7 +305,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam auto expertIdx = threadIdx.x * ExpertsPerThread + ii; auto localExpertIdx = static_cast(expertIdx) - params.mLocalExpertsStartIdx; auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; // the permuted index: we add the local offset relative to this expert and token // to the global offset from the scan for this expert auto permutedIdx = isLocalExpert ? finalExpertOffset[ii] + localOffsetToken : int32_t{-1}; @@ -338,7 +349,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu auto warp = cg::tiled_partition(block); // then wait on primary grid - if constexpr (KernelParams::UsePdl) { + if (params.mUsePdl) { cudaGridDependencySynchronize(); } @@ -420,7 +431,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // Wait on primary grid and trigger secondary kernel. 
- if constexpr (KernelParams::UsePdl) { + if (params.mUsePdl) { cudaGridDependencySynchronize(); cudaTriggerProgrammaticLaunchCompletion(); } @@ -458,25 +469,38 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) } //////////////////////////////////////////////////////////////////////////////////////////////////// -int constexpr getMaxNumExperts(int32_t numExperts) { +int getMaxNumExperts(int32_t numExperts) { if (numExperts <= topk::MaxNumExpertsUnit) { return topk::MaxNumExpertsUnit; } else { - TLLM_LOG_ERROR("Unsupported numExperts"); + FLASHINFER_WARN("Unsupported numExperts"); return 0; } } //////////////////////////////////////////////////////////////////////////////////////////////////// -void runImpl(Data const& data, void* stream) { +void run(Data const& data, void* stream) { FLASHINFER_CHECK( data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, "Routing kernel requires at least one input parameter"); - if (data.mPtrTopKIds != nullptr) { - FLASHINFER_CHECK( - data.mPtrTopKWeights != nullptr, - "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for Llama4 routing."); + // When topK is already computed (mPtrTopKIds or mPtrTopKPacked without scores), + // delegate to the shared post-topK pipeline. This avoids Llama4-specific issues: + // - The Llama4 cluster kernel loads one token per warp but useSingleCluster uses + // the thread-based capacity, causing unprocessed tokens for medium token counts. + // - The Llama4 device kernel applies sigmoid to packed scores that may already + // contain sigmoid values (produced by the scores-path kernels). 
+ if (data.mPtrTopKIds != nullptr || + (data.mPtrTopKPacked != nullptr && data.mPtrScores == nullptr)) { + if (data.mPtrTopKIds != nullptr) { + FLASHINFER_CHECK( + data.mPtrTopKWeights != nullptr, + "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for Llama4 " + "routing."); + } + int const numThreadsHist = getMaxNumExperts(data.mNumExperts); + runPostTopKPipeline(data, numThreadsHist, stream); + return; } FLASHINFER_CHECK( data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && @@ -485,23 +509,26 @@ void runImpl(Data const& data, void* stream) { FLASHINFER_CHECK(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= %d, got %d", MaxNumTopExperts, data.mTopK); - FLASHINFER_CHECK(data.mNumExperts <= NumExpertsLimit, + FLASHINFER_CHECK(data.mNumExperts <= MaxSupportedExperts, "Routing kernel expects #experts %d to be no more than %d", data.mNumExperts, - NumExpertsLimit); + MaxSupportedExperts); + // static_assert(MaxNumExperts <= NumThreads, "#experts must be bounded by #threads"); + // static_assert(MaxNumExperts <= numThreadsHist, "#experts must be bounded by #threads"); FLASHINFER_CHECK(data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); + // After this point, mPtrTopKIds is guaranteed to be nullptr. + // Input is either mPtrScores (raw logits) or mPtrTopKPacked (topK already computed, needs + // sigmoid). bool const useSingleWarp = (data.mPtrScores == nullptr && data.mNumTokens <= WarpKernelMaxNumTokens) || data.mNumTokens < WarpKernelMaxNumTokens; - bool const useSingleCluster = - data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) - ? MaxNumTokensSingleClusterScores - : MaxNumTokensSingleCluster); + bool const useSingleCluster = data.mNumTokens <= ((data.mPtrScores != nullptr) + ? 
MaxNumTokensSingleClusterScores + : MaxNumTokensSingleCluster); if (!useSingleCluster) { - FLASHINFER_CHECK( - (data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr), - "When #tokens is large, `mPtrTopKPacked` or `mPtrTopKIds` is a required input."); + FLASHINFER_CHECK(data.mPtrTopKPacked != nullptr, + "When #tokens is large, `mPtrTopKPacked` is a required input."); FLASHINFER_CHECK(data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input."); } @@ -532,7 +559,7 @@ void runImpl(Data const& data, void* stream) { int const numBlocksOffsets = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); - if (data.mPtrScores != nullptr && data.mPtrTopKIds == nullptr) { + if (data.mPtrScores != nullptr) { LAUNCH_ROUTING_LLAMA4(data, /*coopLaunch=*/false, routingIndicesHistogramScoresKernel, maxNumBlocks, numThreadsHist, @@ -558,23 +585,6 @@ void runImpl(Data const& data, void* stream) { } } -void run(Data const& data, void* stream) { - FLASHINFER_CHECK( - data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, - "Routing kernel requires at least one input parameter"); - FLASHINFER_CHECK( - data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && - data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, - "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"); - FLASHINFER_CHECK(data.mTopK <= MaxNumTopExperts, - "Routing kernel expects topK experts <= ", MaxNumTopExperts, ", got ", - data.mTopK); - FLASHINFER_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got ", - data.mPaddingLog2); - - runImpl(data, stream); -} - //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace routingLlama4 diff --git a/csrc/trtllm_fused_moe_routing_renormalize.cu b/csrc/trtllm_fused_moe_routing_renormalize.cu deleted file 
mode 100644 index 364c267c00..0000000000 --- a/csrc/trtllm_fused_moe_routing_renormalize.cu +++ /dev/null @@ -1,506 +0,0 @@ -/* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "flashinfer/trtllm/fused_moe/RoutingKernel.cuh" -#include "tvm_ffi_utils.h" - -namespace moe::dev::routing { -namespace routingRenormalize { -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static constexpr int NumThreads = 1024; -static constexpr int NumWarps = NumThreads / WarpSize; -static constexpr int MaxNumTopExperts = 10; -static constexpr int NumExpertsLimit = 512; -static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; -static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps; -static constexpr int BlockKernelMaxNumTokens = 4; - -template -__forceinline__ __device__ void routingTopKExperts( - cg::thread_block_tile const& warp, DataType (&score)[VecSize], - int32_t (&idx)[VecSize], DataType (&warpTopKScore)[MaxNumTopExperts], - int32_t (&warpTopKExpertIdx)[MaxNumTopExperts], int32_t const laneIdx, int32_t const numExperts, - int32_t topK, InputType const* ptrScores, bool const normTopkProb, - bool const applySoftmaxAfterTopK) { - DataType minScore = DataType{-INFINITY}; - - for (int i = 0; i < VecSize; i++) { - auto expertIdx = i * WarpSize + laneIdx; - auto newScore = expertIdx < 
numExperts ? static_cast(ptrScores[expertIdx]) : minScore; - score[i] = newScore; - idx[i] = expertIdx; - } - if constexpr (DoSoftmaxBeforeTopK) { - calcSoftmax(warp, score); - } - - // Get the top-k scores and their corresponding expert indices - topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, score, idx, minScore, topK); - - // Normalize the scores - if constexpr (DoSoftmaxBeforeTopK) { - float sum = float{1.f}; - if (normTopkProb) { - sum = static_cast(laneIdx < topK ? warpTopKScore[laneIdx] : 0); - sum = cg::reduce(warp, sum, cg::plus()); - } - if (laneIdx < topK) { - warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum; - } - } else { - if (applySoftmaxAfterTopK) { - auto softmaxScore = - calcSoftmax(warp, laneIdx < topK ? warpTopKScore[laneIdx] : minScore, laneIdx, topK); - if (laneIdx < topK) { - warpTopKScore[laneIdx] = softmaxScore; - } - } - // If applySoftmaxAfterTopK is false, we keep the raw TopK values without softmax - } -} - -template -__global__ void __launch_bounds__(KernelParams::MaxNumExperts) - routingIndicesBlockKernel(KernelParams params) { - // types used in this kernel - using OutputT = typename KernelParams::OutputT; - using InputT = typename KernelParams::InputT; - using BaseType = std::conditional_t; - using TypePacked = PackedScoreIdx; - int constexpr MaxNumExperts = KernelParams::MaxNumExperts; - - int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); - int32_t const laneIdx = cutlass::arch::LaneId(); - int32_t const expert = threadIdx.x; - auto scoreOffset = warpIdx * params.mNumExperts; - bool validToken = warpIdx < params.mNumTokens; - - static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; - static constexpr int totalExpertCounts = BlockKernelMaxNumTokens * MaxNumExperts; - __shared__ int8_t __attribute((aligned(128))) smemOffset[totalExpertCounts]; - __shared__ int8_t __attribute((aligned(128))) smemKIdx[totalExpertCounts]; - - using Scan = cub::BlockScan; - __shared__ typename 
Scan::TempStorage tempStorage; - - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - for (int i = threadIdx.x; i < totalExpertCounts; i += blockDim.x) { - smemOffset[i] = int8_t{-1}; - smemKIdx[i] = int8_t{-1}; - } - __syncthreads(); - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // then wait on primary grid - if constexpr (KernelParams::UsePdl) { - cudaGridDependencySynchronize(); - } -#endif - - if (params.mPtrTopKIds != nullptr) { - if (validToken) { - if (laneIdx < params.mTopK) { - int offset = warpIdx * MaxNumExperts + params.mPtrTopKIds[warpIdx * params.mTopK + laneIdx]; - smemKIdx[offset] = static_cast(laneIdx); - } - } - } else if (params.mPtrScores != nullptr) { - // in this case, each warp represents a token - BaseType score[VecSize]; - int32_t idx[VecSize]; - - BaseType warpTopKScore[MaxNumTopExperts]; - int32_t warpTopKExpertIdx[MaxNumTopExperts]; - - BaseType minScore = BaseType{-INFINITY}; - if (validToken) { - routingTopKExperts( - warp, score, idx, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, - params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb, - params.mApplySoftmaxAfterTopK); - - if (laneIdx < params.mTopK) { - int offset = warpIdx * MaxNumExperts + warpTopKExpertIdx[laneIdx]; - smemKIdx[offset] = static_cast(laneIdx); - if (params.mPtrTopKWeights != nullptr) { - params.mPtrTopKWeights[warpIdx * params.mTopK + laneIdx] = - OutputT{warpTopKScore[laneIdx]}; - } - } - } // end if (validToken) - } else if (params.mPtrTopKPacked != nullptr) { - if (validToken) { - if (laneIdx < params.mTopK) { - int offset = warpIdx * MaxNumExperts + - static_cast(params.mPtrTopKPacked[warpIdx * params.mTopK + laneIdx].idx); - smemKIdx[offset] = static_cast(laneIdx); - if (params.mPtrTopKWeights != nullptr) { - params.mPtrTopKWeights[warpIdx * params.mTopK + laneIdx] = - static_cast(params.mPtrTopKPacked[warpIdx * params.mTopK + laneIdx].score); - } - } - } - } - __syncthreads(); - - // 
set local experts - auto localExpertIdx = expert - params.mLocalExpertsStartIdx; - auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < params.mNumLocalExperts && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; - // Get the count of each expert and the offset for each token - int accExpertCount = 0; - - if (isLocalExpert) { - int offset = expert; - for (int j = 0; j < BlockKernelMaxNumTokens; j++) { - if (smemKIdx[offset] >= 0) { - smemOffset[offset] = static_cast(accExpertCount); - accExpertCount++; - } - offset += MaxNumExperts; - } - } - __syncthreads(); - // Get the number of CTAs and the offset for each CTA - int32_t numCta; - if constexpr (KernelParams::isPow2) { - numCta = divUpLog2(accExpertCount, params.mPaddingLog2); - } else { - numCta = divUpTileN(accExpertCount, params.mTileTokensDim); - } - int32_t ctaOffset = 0; - int32_t numNonExitingCtas; - Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); - - int32_t expertScanCounts = 0; - int32_t tmpCount; - if constexpr (KernelParams::isPow2) { - tmpCount = divUpMulLog2(accExpertCount, params.mPaddingLog2); - } else { - tmpCount = divUpMulTileN(accExpertCount, params.mTileTokensDim); - } - Scan(tempStorage).ExclusiveSum(tmpCount, expertScanCounts); - __syncthreads(); - - if (isLocalExpert) { - for (int cta = 0; cta < numCta; ++cta) { - const int32_t localExpertIdx = - (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; - params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - int32_t mnLimit1; - int32_t mnLimit2; - if constexpr (KernelParams::isPow2) { - mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); - mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + accExpertCount; - } else { - mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); - mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + accExpertCount; - } - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); - } - } - - // at this point, 
we can write out padded count - if (threadIdx.x == 0) { - int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) { - permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); - } else { - permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); - } - params.mPtrPermutedIdxSize[0] = permutedIdxSize; - params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; - } - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // we can trigger the next kernel at this point - if constexpr (KernelParams::UsePdl) { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif - - for (int tokenIdx = 0; tokenIdx < params.mNumTokens; tokenIdx++) { - int offset = tokenIdx * MaxNumExperts + threadIdx.x; - if (smemKIdx[offset] >= 0) { - int const expandedIdx = tokenIdx * params.mTopK + smemKIdx[offset]; - int const offsetWithinExpert = static_cast(smemOffset[offset]); - int const offsetForExpert = expertScanCounts; - int const permutedIdx = isLocalExpert ? offsetForExpert + offsetWithinExpert : int32_t{-1}; - - params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx; - if (isLocalExpert) { - if (params.mPtrPermutedIdxToExpandedIdx != nullptr) { - params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx; - } - params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; - } - } - } -} - -template -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads) - routingIndicesClusterKernel(KernelParams params) { - // number of tokens/expanded idx is bounded by total number of warps - using OutputT = typename KernelParams::OutputT; - using InputT = typename KernelParams::InputT; - - using BaseType = std::conditional_t; - using TypePacked = PackedScoreIdx; - - static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; - - __shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * MaxNumTopExperts]; - - uint32_t const clusterBlockRank = 
blockIdx.x; - - int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); - int32_t const laneIdx = cutlass::arch::LaneId(); - - auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx; - auto scoreOffset = warpTokenIdx * params.mNumExperts; - bool validToken = warpTokenIdx < params.mNumTokens; - - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - // then wait on primary grid - if constexpr (KernelParams::UsePdl) { - cudaGridDependencySynchronize(); - } - - if (params.mPtrScores != nullptr) { - // in this case, each warp represents a token - BaseType score[VecSize]; - int32_t idx[VecSize]; - - BaseType warpTopKScore[MaxNumTopExperts]; - int32_t warpTopKExpertIdx[MaxNumTopExperts]; - - BaseType minScore = BaseType{-INFINITY}; - if (validToken) { - routingTopKExperts( - warp, score, idx, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, - params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb, - params.mApplySoftmaxAfterTopK); - - if (laneIdx < params.mTopK) { - smemPackedScoreIdx[warpIdx * params.mTopK + laneIdx] = - TypePacked{warpTopKScore[laneIdx], static_cast(warpTopKExpertIdx[laneIdx])}; - } - } // end if (validToken) - } - - // make packed scores available to all threads in cluster - __cluster_barrier_arrive(); - __cluster_barrier_wait(); - - if (params.mPtrScores != nullptr) { - routingPermutation(params, smemPackedScoreIdx, warpIdx, - clusterBlockRank); - } else { - routingPermutation(params, smemPackedScoreIdx, warpIdx, - clusterBlockRank); - } -} -#else -__global__ void __launch_bounds__(NumThreads) - routingIndicesClusterKernel(KernelParams /* params */) { - assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); -} -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// this kernel is needed in case we have scores as input for 
the histogram kernel -template -__global__ void __launch_bounds__(KernelParams::MaxNumExperts) - routingIndicesHistogramScoresKernel(KernelParams params) { - using OutputT = typename KernelParams::OutputT; - using InputT = typename KernelParams::InputT; - using BaseType = std::conditional_t; - - static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; - - int32_t const laneIdx = cutlass::arch::LaneId(); - int32_t const warpIdx = threadIdx.x / WarpSize; - int32_t const globalWarpIdx = blockIdx.x * KernelParams::MaxNumExperts / WarpSize + warpIdx; - int32_t const globalWarpStride = gridDim.x * KernelParams::MaxNumExperts / WarpSize; - BaseType minScore = BaseType{-INFINITY}; - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Wait on primary grid. - if constexpr (KernelParams::UsePdl) { - cudaGridDependencySynchronize(); - } -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - - // initialize the mPtrExpertCounts - int32_t expertCountsNum = 2 * params.mNumExperts; - int32_t globalThreadIdx = blockIdx.x * KernelParams::MaxNumExperts + threadIdx.x; - int32_t globalThreadStride = gridDim.x * KernelParams::MaxNumExperts; - initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Trigger secondary kernel. 
- if constexpr (KernelParams::UsePdl) { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - - // in this case, each warp represents a token, and we use a grid-stride loop - // over all warps/tokens - BaseType allScores[VecSize]; - int32_t allExpertIdx[VecSize]; - BaseType warpTopKScore[MaxNumTopExperts]; - int32_t warpTopKExpertIdx[MaxNumTopExperts]; - for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride) { - auto scoreOffset = tokenIdx * params.mNumExperts; - - routingTopKExperts( - warp, allScores, allExpertIdx, warpTopKScore, warpTopKExpertIdx, laneIdx, - params.mNumExperts, params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb, - params.mApplySoftmaxAfterTopK); - - if (laneIdx < params.mTopK) { - PackedScoreIdx packedScore{static_cast(warpTopKScore[laneIdx]), - static_cast(warpTopKExpertIdx[laneIdx])}; - params.mPtrTopKPacked[tokenIdx * params.mTopK + laneIdx] = packedScore; - } - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -int32_t constexpr getMaxNumExperts(int32_t numExperts) { - if (numExperts <= topk::MaxNumExpertsUnit) { - return topk::MaxNumExpertsUnit; - } else if (numExperts <= NumExpertsLimit) { - return NumExpertsLimit; - } else { - TLLM_LOG_ERROR("Unsupported numExperts"); - return 0; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define LAUNCH_ROUTING_RENORNALIZE(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag1) \ - if (data.mNumExperts <= topk::MaxNumExpertsUnit) { \ - LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag1, topk::MaxNumExpertsUnit); \ - } else if (data.mNumExperts <= NumExpertsLimit) { \ - LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, 
extraFlag1, NumExpertsLimit); \ - } else { \ - TLLM_LOG_ERROR("Unsupported numExperts"); \ - } - -//////////////////////////////////////////////////////////////////////////////////////////////////// -void run(Data const& data, void* stream) { - TVM_FFI_ICHECK(data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || - data.mPtrTopKIds != nullptr) - << "Routing kernel requires at least one input parameter"; - if (data.mPtrTopKIds != nullptr) { - TVM_FFI_ICHECK(data.mPtrTopKWeights != nullptr) - << "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for " - "Renormalize routing."; - } - TVM_FFI_ICHECK(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && - data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr) - << "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"; - TVM_FFI_ICHECK_LE(data.mTopK, MaxNumTopExperts) - << "Routing kernel expects topK experts <= " << MaxNumTopExperts << ", got " << data.mTopK; - TVM_FFI_ICHECK_LE(data.mNumExperts, NumExpertsLimit) - << "Routing kernel expects #experts " << data.mNumExperts << " to be no more than " - << NumExpertsLimit << "."; - TVM_FFI_ICHECK_EQ(data.mNumExperts % 4, 0) - << "Routing kernel expects #experts " << data.mNumExperts << " to be a multiple of 4."; - - // FIXME: routingIndicesBlockKernel breaks the vllm + gpt-oss DeepEP - bool const useSingleBlock = - data.mNumTokens <= BlockKernelMaxNumTokens && data.mPtrTopKPacked == nullptr; - - bool const useSingleCluster = - data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) - ? 
MaxNumTokensSingleClusterScores - : MaxNumTokensSingleCluster); - - if (!useSingleCluster && !useSingleBlock) { - TVM_FFI_ICHECK(data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) - << "When #tokens is large, `mPtrTopKPacked` or `mPtrTopKIds` is a required input."; - TVM_FFI_ICHECK(data.mPtrExpertCounts != nullptr) - << "When #tokens is large, `mPtrExpertCounts` is a required input."; - } - uint32_t const numThreadsHist = getMaxNumExperts(data.mNumExperts); - if (useSingleBlock) { - //@TODO: For now we use the single block kernel for cases with token number no larger than 4. - // We will future tune this threshold based on the performance. - LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesBlockKernel, 1, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); - } else if (useSingleCluster) { - LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesClusterKernel, NumBlocksPerCluster, - NumThreads, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); - } else { - uint32_t const expandedIdxSize = data.mNumTokens * data.mTopK; - uint32_t const histogramEltsPerBlock = 8 * numThreadsHist; - uint32_t const offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; - - // Limit grid size (all kernels use a grid-stride loop). - uint32_t const maxNumBlocks = 1024; - - int const numBlocksHistogram = std::min( - (expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); - int const numBlocksOffsets = - std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); - - if (data.mPtrScores != nullptr && data.mPtrTopKIds == nullptr) { - LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesHistogramScoresKernel, maxNumBlocks, - numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); - } else { - // Reset the global histograms. 
- LAUNCH_ROUTING_RENORNALIZE(data, false, routingInitExpertCounts, - (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); - } - LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesHistogramKernel, numBlocksHistogram, - numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); - LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesOffsetsKernel, numBlocksOffsets, - numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingRenormalize -} // namespace moe::dev::routing diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index af48040d0a..d28bf681bc 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -58,16 +58,16 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias, bool useRoutingScalesOnInput, bool useDeepSeekFp8, - RoutingMethodType routingMethodType, cudaStream_t stream) { + RoutingMethodType routingMethodType, cudaStream_t stream, + btg::Dtype dtypeLogits, bool normTopkProb) { if (routingMethodType == RoutingMethodType::DeepSeekV3) { FLASHINFER_CHECK(topK <= 22, "For DeepSeek routing method, must have topK <= 22"); FLASHINFER_CHECK(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4"); moe::dev::routing::routingDeepSeek::Data routingData; - routingData.mDtypeExpW = - btg::Dtype::Bfloat16; // for DeepSeek, the expW is currently always bfloat16 - routingData.mDtypeBias = dtypeBias; // for DeepSeek, the bias can be bfloat16 or fp32 - - routingData.mDtypeScore = btg::Dtype::Fp32; // for DeepSeek, the score is currently always fp32 + 
routingData.mDtypeOutput = + btg::Dtype::Bfloat16; // for DeepSeek, the expW is currently always bfloat16 + routingData.mDtypeInput = dtypeLogits; // routing logits can be bfloat16 or fp32 + routingData.mDtypeBias = dtypeBias; // for DeepSeek, the bias can be bfloat16 or fp32 routingData.mUsePdl = true; // output: @@ -85,7 +85,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 // input: routingData.mPtrRoutingBias = routingBias; - routingData.mPtrScores = reinterpret_cast(routingLogits); + routingData.mPtrScores = routingLogits; // type-erased; InputT selected by forceFloatInput routingData.mNumTokens = numTokens; routingData.mNumExperts = numExperts; routingData.mNumExpertGroups = nGroup; @@ -106,7 +106,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 topkGroup); } moe::dev::routing::routingLlama4::Data routingData; - routingData.mDtypeExpW = btg::Dtype::Bfloat16; + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mDtypeInput = dtypeLogits; // routing logits can be bfloat16 or fp32 routingData.mUsePdl = true; // output: @@ -133,21 +134,46 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; moe::dev::routing::routingLlama4::run(routingData, stream); - } else if (routingMethodType == RoutingMethodType::Renormalize /* default */ - || routingMethodType == RoutingMethodType::RenormalizeNaive /* Softmax -> TopK */ - || routingMethodType == RoutingMethodType::TopK /* TopK only (no softmax) */) { - moe::dev::routing::routingRenormalize::Data routingData; + } else if (routingMethodType == RoutingMethodType::Default /* Softmax -> TopK */ + || routingMethodType == RoutingMethodType::Renormalize /* TopK -> Softmax */ + || routingMethodType == RoutingMethodType::RenormalizeNaive /* Softmax -> TopK -> Renormalize */ + || routingMethodType == RoutingMethodType::TopK /* TopK 
only (no softmax) */ + || routingMethodType == RoutingMethodType::SigmoidRenorm /* Sigmoid -> TopK -> Renormalize */) { + using namespace moe::dev::routing; + routingCustom::Data routingData; // // Config // - routingData.mDtypeExpW = btg::Dtype::Bfloat16; - // routingData.mDtypeElt = dtypeElt; // no-op for now as hidden_state is not input + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mDtypeInput = dtypeLogits; // routing logits can be bfloat16 or fp32 routingData.mUsePdl = true; - routingData.mDoSoftmaxBeforeTopK = routingMethodType == RoutingMethodType::RenormalizeNaive; - routingData.mNormTopkProb = routingMethodType == RoutingMethodType::RenormalizeNaive; - routingData.mApplySoftmaxAfterTopK = routingMethodType == RoutingMethodType::Renormalize; + + // Map routing method types to policy-based routing: + // Note: RenormalizeNaive (Softmax → TopK → SumNormalize) is mathematically equivalent + // to Renormalize (TopK → Softmax), because taking softmax over all experts, selecting + // top-K, and dividing by their sum produces the same result as applying softmax only + // over the top-K values. We therefore use the same Renormalize implementation for both. 
+ if (routingMethodType == RoutingMethodType::Default) { + // Softmax -> TopK (softmax on all scores, then select top-K) + routingData.mPreprocessType = RoutingPreprocessType::Softmax; + routingData.mPostprocessType = RoutingPostprocessType::None; + } else if (routingMethodType == RoutingMethodType::SigmoidRenorm) { + // Sigmoid -> TopK -> SumNormalize (renormalize) + routingData.mPreprocessType = RoutingPreprocessType::Sigmoid; + routingData.mPostprocessType = RoutingPostprocessType::SumNormalize; + routingData.mNormTopkProb = normTopkProb; + } else if (routingMethodType == RoutingMethodType::Renormalize || + routingMethodType == RoutingMethodType::RenormalizeNaive) { + // TopK -> Softmax (also used for RenormalizeNaive, see comment above) + routingData.mPreprocessType = RoutingPreprocessType::None; + routingData.mPostprocessType = RoutingPostprocessType::Softmax; + } else { + // TopK only (no softmax or renormalize) + routingData.mPreprocessType = RoutingPreprocessType::None; + routingData.mPostprocessType = RoutingPostprocessType::None; + } routingData.mPtrScores = routingLogits; @@ -181,7 +207,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; - moe::dev::routing::routingRenormalize::run(routingData, stream); + routingCustom::run(routingData, stream); } else { FLASHINFER_CHECK(false, "Unimplemented routing method ", serializeMoeRoutingMethodType(routingMethodType), " of enum ", diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 7e0760e7b2..5964310c4e 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -71,8 +71,10 @@ class RoutingMethodType(IntEnum): RenormalizeNaive = (4,) # TopK only (no softmax) TopK = (5,) + # SigmoidRenorm: Sigmoid -> TopK -> Renormalize (divide by sum of top-K weights) + SigmoidRenorm = (6,) # Unspecified - Unspecified = 6 + Unspecified = (7,) # Copied from 
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h @@ -1132,6 +1134,7 @@ def forward( kwargs["do_finalize"], kwargs["enable_pdl"], [-1, -1] if tactic == -1 else tactic, + kwargs.get("norm_topk_prob", True), ) elif ( self.dtype_act == DtypeTrtllmGen.E4m3 @@ -1190,6 +1193,7 @@ def forward( kwargs["enable_pdl"], [-1, -1] if tactic == -1 else tactic, self.fp8_quantization_type, + kwargs.get("norm_topk_prob", True), ) else: # FP8 per tensor scale @@ -1217,6 +1221,7 @@ def forward( kwargs["enable_pdl"], [-1, -1] if tactic == -1 else tactic, self.activation_type, + kwargs.get("norm_topk_prob", True), ) elif ( self.dtype_act == DtypeTrtllmGen.Bfloat16 @@ -1246,6 +1251,7 @@ def forward( kwargs["enable_pdl"], output, [-1, -1] if tactic == -1 else tactic, + kwargs.get("norm_topk_prob", True), ) else: moe_op.trtllm_fp4_block_scale_moe( @@ -1281,6 +1287,7 @@ def forward( self.activation_type, output, [-1, -1] if tactic == -1 else tactic, + kwargs.get("norm_topk_prob", True), ) @classmethod @@ -1337,6 +1344,7 @@ def trtllm_bf16_moe_op( do_finalize: bool = True, enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, + norm_topk_prob: bool = True, ) -> List[torch.Tensor]: assert routing_logits is not None or topk_ids is not None, ( "either routing_logits or topk_ids must be provided" @@ -1435,6 +1443,7 @@ def trtllm_bf16_moe_op( do_finalize, enable_pdl, [-1, -1] if tactic == -1 else tactic, + norm_topk_prob, ) if do_finalize: return [output] @@ -1500,6 +1509,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, activation_type: int = ActivationType.Swiglu.value, + norm_topk_prob: bool = True, ) -> List[torch.Tensor]: if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) @@ -1587,6 +1597,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( enable_pdl, [-1, -1] if tactic == -1 else tactic, activation_type, + norm_topk_prob, ) if do_finalize: return [output] @@ -1657,6 +1668,7 
@@ def trtllm_fp8_block_scale_moe_op( enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, + norm_topk_prob: bool = True, ) -> List[torch.Tensor]: # Determine routing mode: compute from logits or use pre-computed if routing_logits is None: @@ -1781,6 +1793,7 @@ def trtllm_fp8_block_scale_moe_op( enable_pdl, [-1, -1] if tactic == -1 else tactic, fp8_quantization_type, + norm_topk_prob, ) if do_finalize: @@ -1867,6 +1880,7 @@ def trtllm_fp4_block_scale_moe_op( activation_type: int = ActivationType.Swiglu.value, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, + norm_topk_prob: bool = True, ) -> List[torch.Tensor]: if routing_logits is None: assert topk_ids is not None, ( @@ -2010,6 +2024,7 @@ def trtllm_fp4_block_scale_moe_op( activation_type, output, [-1, -1] if tactic == -1 else tactic, + norm_topk_prob, ) if do_finalize: return [output] @@ -2088,6 +2103,7 @@ def trtllm_mxint4_block_scale_moe_op( enable_pdl: Optional[bool] = None, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, + norm_topk_prob: bool = True, ) -> List[torch.Tensor]: routing_dtype = routing_logits.dtype hidden_size = hidden_states.shape[-1] @@ -2185,6 +2201,7 @@ def trtllm_mxint4_block_scale_moe_op( enable_pdl, output, [-1, -1] if tactic == -1 else tactic, + norm_topk_prob, ) if do_finalize: return [output] @@ -2255,6 +2272,7 @@ def trtllm_bf16_moe( do_finalize: bool = True, enable_pdl: bool = True, tune_max_num_tokens: int = 8192, + norm_topk_prob: bool = True, ) -> Union[List[torch.Tensor], torch.Tensor]: """BF16 MoE operation with autotuning support. 
@@ -2319,6 +2337,7 @@ def trtllm_bf16_moe( do_finalize, enable_pdl, tune_max_num_tokens, + norm_topk_prob, ) if do_finalize: @@ -2413,6 +2432,7 @@ def trtllm_bf16_routed_moe( do_finalize, enable_pdl, tune_max_num_tokens, + True, # norm_topk_prob: not used for pre-computed routing ) if do_finalize: @@ -2448,6 +2468,7 @@ def trtllm_fp8_per_tensor_scale_moe( enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, activation_type: int = ActivationType.Swiglu.value, + norm_topk_prob: bool = True, ) -> Union[List[torch.Tensor], torch.Tensor]: """FP8 per tensor scale MoE operation. @@ -2507,6 +2528,7 @@ def trtllm_fp8_per_tensor_scale_moe( enable_pdl, tune_max_num_tokens, activation_type, + norm_topk_prob, ) if do_finalize: @@ -2543,6 +2565,7 @@ def trtllm_fp8_block_scale_moe( enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, + norm_topk_prob: bool = True, ) -> Union[List[torch.Tensor], torch.Tensor]: """FP8 block scale MoE operation. @@ -2609,6 +2632,7 @@ def trtllm_fp8_block_scale_moe( enable_pdl, tune_max_num_tokens, fp8_quantization_type, + norm_topk_prob, ) if do_finalize: @@ -2713,6 +2737,7 @@ def trtllm_fp8_block_scale_routed_moe( enable_pdl, tune_max_num_tokens, fp8_quantization_type, + True, # norm_topk_prob: not used for pre-computed routing ) if do_finalize: @@ -2756,6 +2781,7 @@ def trtllm_fp4_block_scale_moe( activation_type: int = ActivationType.Swiglu.value, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, + norm_topk_prob: bool = True, ) -> List[torch.Tensor]: """FP4 block scale MoE operation. 
@@ -2853,6 +2879,7 @@ def trtllm_fp4_block_scale_moe( activation_type, output, tune_max_num_tokens, + norm_topk_prob, ) @@ -2987,6 +3014,7 @@ def trtllm_fp4_block_scale_routed_moe( activation_type, output, tune_max_num_tokens, + True, # norm_topk_prob: not used for pre-computed routing ) @@ -3015,6 +3043,7 @@ def trtllm_mxint4_block_scale_moe( enable_pdl: Optional[bool] = None, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, + norm_topk_prob: bool = True, ) -> List[torch.Tensor]: """MxInt4 block scale MoE operation. @@ -3086,4 +3115,5 @@ def trtllm_mxint4_block_scale_moe( enable_pdl, output, tune_max_num_tokens, + norm_topk_prob, ) diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py index a3240f7353..75bdcd783c 100644 --- a/flashinfer/jit/fused_moe.py +++ b/flashinfer/jit/fused_moe.py @@ -263,7 +263,8 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec: jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_runner.cu", jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_routing_deepseek.cu", jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_routing_llama4.cu", - jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_routing_renormalize.cu", + jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_routing_custom.cu", + jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_routing_common.cu", jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_dev_kernel.cu", jit_env.FLASHINFER_CSRC_DIR / "trtllm_batched_gemm_runner.cu", ], diff --git a/include/flashinfer/trtllm/common/cudaUtils.h b/include/flashinfer/trtllm/common/cudaUtils.h index d10c40550a..c343a94986 100644 --- a/include/flashinfer/trtllm/common/cudaUtils.h +++ b/include/flashinfer/trtllm/common/cudaUtils.h @@ -269,4 +269,26 @@ inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * inline __device__ float2 operator+(float2 a, float b) { return make_float2(a.x + b, a.y + b); } inline __device__ float2 operator-(float2 a, float b) { return make_float2(a.x - b, a.y - b); } 
+//////////////////////////////////////////////////////////////////////////////////////////////////// +// Device query helpers — thin wrappers around CUDA runtime queries. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline int getSMVersion() { + int device{-1}; + cudaGetDevice(&device); + int sm_major = 0; + int sm_minor = 0; + cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device); + cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device); + return sm_major * 10 + sm_minor; +} + +inline int getMultiProcessorCount() { + int device{-1}; + cudaGetDevice(&device); + int count = 0; + cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device); + return count; +} + } // namespace tensorrt_llm::common diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h index 560063c023..b2bf238162 100644 --- a/include/flashinfer/trtllm/fused_moe/DevKernel.h +++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h @@ -148,108 +148,10 @@ namespace moe::dev { LAUNCH_EXPW(data, kernel, 1, numBlocks, numThreads, smemSize, stream); \ } -#define LAUNCH_TILEN(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \ - if (data.mPaddingLog2 > 0) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(types, true), kernel, numBlocks, numThreads, smemSize, \ - stream); \ - } else { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(types, false), kernel, numBlocks, numThreads, \ - smemSize, stream); \ - } - -#define LAUNCH_ROUTING_LLAMA4(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ - if (data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, 128 /* Always 128 for llama4*/), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 
128 /* Always 128 for llama4*/), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else { \ - FLASHINFER_WARN("Unsupported dtypeExpW"); \ - } - -#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, extraFlag, numExperts, \ - numTopExperts) \ - if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(float, float, float, numExperts, numTopExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(float, float, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(float, __nv_bfloat16, float, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN( \ - data, coopLaunch, \ - LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, float, float, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == 
tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN( \ - data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN( \ - data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, \ - numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else { \ - FLASHINFER_WARN("Unsupported dtypeExpW"); \ - } - -#define LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, numExperts, numTopExperts) \ - if (extraFlag) { \ - LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, true, numExperts, numTopExperts); \ - } else { \ - LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, false, numExperts, numTopExperts); \ - } - //////////////////////////////////////////////////////////////////////////////////////////////////// - -#define LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag1, numExperts) \ - if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag1) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } else if (data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), 
kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag1) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else { \ - FLASHINFER_WARN("Unsupported dtypeExpW"); \ - } - +// NOTE: Old routing-specific macros (LAUNCH_TILEN, LAUNCH_ROUTING_LLAMA4, +// LAUNCH_ROUTING_DEEPSEEK_*, LAUNCH_ROUTING_WITH_NUM_EXPERTS) have been moved to +// RoutingDevKernel.h which uses the new template signature with runtime isPow2/UsePdl. //////////////////////////////////////////////////////////////////////////////////////////////////// namespace activation { diff --git a/include/flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh b/include/flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh new file mode 100644 index 0000000000..4cb82e7172 --- /dev/null +++ b/include/flashinfer/trtllm/fused_moe/RoutingCustomPolicy.cuh @@ -0,0 +1,609 @@ +/* + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "RoutingKernel.cuh" + +namespace moe::dev::routing { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Preprocess policies: applied to all expert scores BEFORE topK selection. +// +// Each policy must provide: +// - template using BaseType +// The data type used for intermediate score computation. +// - template struct Params { void set(Data const&); } +// Policy-specific runtime data, populated from the host-side Data struct. +// Empty for policies that don't need extra data (zero register cost). +// - template +// static void apply(warp, score[VecSize], idx[VecSize], numExperts, params) +// Transforms scores in-place before topK selection. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// No-op: scores are passed through unchanged. +struct NoOpPreprocess { + /// BaseType: when no preprocess is applied, use the input type directly. + template + using BaseType = InputT; + + template + struct Params { + void set(routingCustom::Data const& /*data*/) {} + }; + + template + __forceinline__ __device__ static void apply( + cg::thread_block_tile const& /*warp*/, DataType (&/*score*/)[VecSize], + int32_t const (&/*idx*/)[VecSize], int32_t /*numExperts*/, ParamsT const& /*params*/) {} +}; + +/// Softmax: applies softmax over all expert scores before topK selection. +struct SoftmaxPreprocess { + /// BaseType: softmax is always computed in float for numerical stability. + template + using BaseType = float; + + template + struct Params { + void set(routingCustom::Data const& /*data*/) {} + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& warp, + DataType (&score)[VecSize], + int32_t const (&/*idx*/)[VecSize], + int32_t /*numExperts*/, + ParamsT const& /*params*/) { + calcSoftmax(warp, score); + } +}; + +/// Sigmoid: applies sigmoid(score) for topK selection (no bias). 
+struct SigmoidPreprocess { + /// BaseType: sigmoid is computed in float for numerical stability. + template + using BaseType = float; + + template + struct Params { + void set(routingCustom::Data const& /*data*/) {} + }; + + template + __forceinline__ __device__ static void apply( + cg::thread_block_tile const& /*warp*/, DataType (&score)[VecSize], + int32_t const (&idx)[VecSize], int32_t numExperts, ParamsT const& /*params*/) { +#pragma unroll + for (int i = 0; i < VecSize; i++) { + float s = sigmoid_accurate(static_cast(score[i])); + score[i] = idx[i] < numExperts ? static_cast(s) : DataType{-INFINITY}; + } + } +}; + +/// SigmoidBias: applies sigmoid(score) + bias[expertIdx] for topK selection. +/// Used by DeepSeek-style routing where expert selection is based on biased sigmoid scores. +struct SigmoidBiasPreprocess { + /// BaseType: sigmoid is computed in float for numerical stability. + template + using BaseType = float; + + template + struct Params { + // Store as void const* to support any bias dtype (float, bfloat16, etc.) without conversion. + void const* ptrRoutingBias = nullptr; + batchedGemm::trtllm::gen::Dtype dtypeBias = batchedGemm::trtllm::gen::Dtype::Bfloat16; + + void set(routingCustom::Data const& data) { + ptrRoutingBias = data.mPtrRoutingBias; + dtypeBias = data.mDtypeBias; + } + }; + + template + __forceinline__ __device__ static void apply( + cg::thread_block_tile const& /*warp*/, DataType (&score)[VecSize], + int32_t const (&idx)[VecSize], int32_t numExperts, ParamsT const& params) { +#pragma unroll + for (int i = 0; i < VecSize; i++) { + float s = sigmoid_accurate(static_cast(score[i])); + float bias = idx[i] < numExperts + ? loadScalar(params.ptrRoutingBias, idx[i], params.dtypeBias) + : float{-INFINITY}; + score[i] = static_cast(s + bias); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Postprocess policies: applied to the top-K scores AFTER topK selection. 
+// +// Each policy must provide: +// - template struct Params { void set(Data const&); } +// Policy-specific runtime data. Empty when not needed. +// - template +// static void apply(warp, warpTopKScore[K], warpTopKExpertIdx[K], laneIdx, topK, params) +// Transforms top-K scores in-place after topK selection. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// No-op: top-K scores are left unchanged. +struct NoOpPostprocess { + template + struct Params { + void set(routingCustom::Data const& /*data*/) {} + }; + + template + __forceinline__ __device__ static void apply( + cg::thread_block_tile const& /*warp*/, DataType (&/*warpTopKScore*/)[K], + int32_t const (&/*warpTopKExpertIdx*/)[K], int32_t /*laneIdx*/, int32_t /*topK*/, + ParamsT const& /*params*/) {} +}; + +/// Softmax: applies softmax over the top-K scores. +struct SoftmaxPostprocess { + template + struct Params { + void set(routingCustom::Data const& /*data*/) {} + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& warp, + DataType (&warpTopKScore)[K], + int32_t const (&/*warpTopKExpertIdx*/)[K], + int32_t laneIdx, int32_t topK, + ParamsT const& /*params*/) { + DataType minScore = DataType{-INFINITY}; + auto softmaxScore = + calcSoftmax(warp, laneIdx < topK ? warpTopKScore[laneIdx] : minScore, laneIdx, topK); + if (laneIdx < topK) { + warpTopKScore[laneIdx] = softmaxScore; + } + } +}; + +/// SumNormalize: divides each top-K score by the sum of all top-K scores. +/// Used when softmax has already been applied before topK selection. 
+struct SumNormalizePostprocess { + template + struct Params { + bool normTopkProb = true; + + void set(routingCustom::Data const& data) { normTopkProb = data.mNormTopkProb; } + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& warp, + DataType (&warpTopKScore)[K], + int32_t const (&/*warpTopKExpertIdx*/)[K], + int32_t laneIdx, int32_t topK, + ParamsT const& params) { + float sum = float{1.f}; + if (params.normTopkProb) { + sum = static_cast(laneIdx < topK ? warpTopKScore[laneIdx] : 0); + sum = cg::reduce(warp, sum, cg::plus()); + } + if (laneIdx < topK) { + warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum; + } + } +}; + +/// ScaledSumNormalize: recovers un-biased sigmoid scores by subtracting per-expert bias from the +/// selection scores (sigmoid + bias), then normalizes by sum and applies routeScale. +/// Used by DeepSeek-style routing: final_weight = sigmoid(raw) * routeScale / (sum + epsilon). +/// DeepSeek uses epsilon=0 (no guard); MiniMax2 uses epsilon=1e-20 to prevent division by zero. +struct ScaledSumNormalizePostprocess { + template + struct Params { + // Store as void const* to support any bias dtype (float, bfloat16, etc.) without conversion. + void const* ptrRoutingBias = nullptr; + batchedGemm::trtllm::gen::Dtype dtypeBias = batchedGemm::trtllm::gen::Dtype::Bfloat16; + float routeScale = 1.0f; + float sumEpsilon = 0.0f; + + void set(routingCustom::Data const& data) { + ptrRoutingBias = data.mPtrRoutingBias; + dtypeBias = data.mDtypeBias; + routeScale = data.mRouteScale; + sumEpsilon = data.mSumEpsilon; + } + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& warp, + DataType (&warpTopKScore)[K], + int32_t const (&warpTopKExpertIdx)[K], + int32_t laneIdx, int32_t topK, + ParamsT const& params) { + // Recover sigmoid score: selection_score = sigmoid(raw) + bias, so sigmoid = score - bias + float biasVal = + laneIdx < topK + ? 
loadScalar(params.ptrRoutingBias, warpTopKExpertIdx[laneIdx], params.dtypeBias) + : 0.f; + float sigmoidScore = + laneIdx < topK ? (static_cast(warpTopKScore[laneIdx]) - biasVal) : 0.f; + float sum = cg::reduce(warp, sigmoidScore, cg::plus()); + if (laneIdx < topK) { + warpTopKScore[laneIdx] = + static_cast(sigmoidScore * params.routeScale / (sum + params.sumEpsilon)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// ExpertSelectPolicy: encapsulates the entire expert selection logic. +// +// Each policy must provide: +// - template using BaseType +// The data type used for intermediate score computation. +// - template struct Params { void set(Data const&); } +// Policy-specific runtime data, populated from the host-side Data struct. +// Empty for policies that don't need extra data (zero register cost). +// - template +// static void apply(warp, warpTopKScore[K], warpTopKExpertIdx[K], laneIdx, numExperts, topK, +// ptrScores, params) +// Selects the top-K experts and computes their weights. +// +// The default TopKExpertSelect wraps existing PreprocessPolicy + PostprocessPolicy, +// but users can write completely custom policies that bypass the preprocess+topK+postprocess +// pattern (e.g., lookup-table-based expert selection). +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Default ExpertSelectPolicy: preprocess + topK reduction + postprocess. +/// Wraps existing PreprocessPolicy and PostprocessPolicy as internal composition. +template +struct TopKExpertSelect { + /// BaseType: delegated to the preprocess policy. + template + using BaseType = typename PreprocessPolicy_::template BaseType; + + /// Params: combines preprocess and postprocess runtime parameters. 
+ template + struct Params { + typename PreprocessPolicy_::template Params mPreprocessParams; + typename PostprocessPolicy_::template Params mPostprocessParams; + + void set(routingCustom::Data const& data) { + mPreprocessParams.set(data); + mPostprocessParams.set(data); + } + }; + + /// Selects top-K experts using preprocess → topK reduction → postprocess. + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& warp, + DataType (&warpTopKScore)[K], + int32_t (&warpTopKExpertIdx)[K], + int32_t const laneIdx, int32_t const numExperts, + int32_t topK, InputType const* ptrScores, + KP const& params) { + DataType minScore = DataType{-INFINITY}; + DataType score[VecSize]; + int32_t idx[VecSize]; + + for (int i = 0; i < VecSize; i++) { + auto expertIdx = i * WarpSize + laneIdx; + auto newScore = + expertIdx < numExperts ? static_cast(ptrScores[expertIdx]) : minScore; + score[i] = newScore; + idx[i] = expertIdx; + } + + // Apply preprocess (e.g. softmax over all scores, sigmoid + bias, ...) + PreprocessPolicy_::apply(warp, score, idx, numExperts, + params.mExpertSelectParams.mPreprocessParams); + + // Get the top-k scores and their corresponding expert indices + topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, score, idx, minScore, topK); + + // Apply postprocess (e.g. renormalize, softmax over top-K, scaled renormalize, ...) + PostprocessPolicy_::apply(warp, warpTopKScore, warpTopKExpertIdx, laneIdx, topK, + params.mExpertSelectParams.mPostprocessParams); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace routingCustom { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Expert-count tiers (must be multiples of WarpSize=32 and of 4). +// Each tier covers all values ≤ the tier constant. 
+static constexpr int NumExperts128Experts = 128; +static constexpr int NumExperts160Experts = 160; +static constexpr int NumExperts256Experts = 256; +static constexpr int NumExperts384Experts = 384; +static constexpr int NumExperts512Experts = 512; +static constexpr int NumExperts576Experts = 576; +static constexpr int MaxSupportedExperts = 2048; + +// TopK tiers (must be ≤ WarpSize=32). +static constexpr int NumTop4Experts = 4; +static constexpr int NumTop8Experts = 8; +static constexpr int NumTop16Experts = 16; +static constexpr int NumTop22Experts = 22; +static constexpr int MaxSupportedTopExperts = 32; + +static constexpr int NumThreads = 1024; +static constexpr int NumWarps = NumThreads / WarpSize; + +static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; +static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps; + +static constexpr int BlockKernelMaxNumTokens = 4; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline int32_t getMaxNumExperts(int32_t numExperts) { + if (numExperts <= NumExperts128Experts) { + return NumExperts128Experts; + } else if (numExperts <= NumExperts160Experts) { + return NumExperts160Experts; + } else if (numExperts <= NumExperts256Experts) { + return NumExperts256Experts; + } else if (numExperts <= NumExperts384Experts) { + return NumExperts384Experts; + } else if (numExperts <= NumExperts512Experts) { + return NumExperts512Experts; + } else if (numExperts <= NumExperts576Experts) { + return NumExperts576Experts; + } else if (numExperts <= MaxSupportedExperts) { + return MaxSupportedExperts; + } else { + FLASHINFER_WARN("Unsupported numExperts"); + return 0; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// TIER PAIR TYPES — compile-time (MaxNumExperts, MaxNumTopExperts) configuration. +// +// Each Tier declares a supported kernel instantiation. 
+// TierList, ...> is an ordered list tried from first to last. +// The dispatch picks the FIRST pair where numExperts ≤ E AND topK ≤ K. +// +// Pairs must be sorted so that tighter tiers come first: +// - Sort by E ascending, then by K ascending within equal E. +// - A config (numExperts, topK) always matches the tightest available pair. +// - If the tightest expert tier doesn't have a topK that covers the runtime topK, +// the dispatch falls through to the next larger expert tier that does. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Tier { + static constexpr int kExperts = E_; + static constexpr int kTopK = K_; +}; + +template +struct TierList {}; + +// Recursive dispatch: try each tier in order, call `fn` with the first match. +// fn receives (integral_constant, integral_constant) as compile-time args. +// Base case: empty list — no match. +template +inline bool dispatchTierPairs(TierList<>*, Data const& /*data*/, Fn&& /*fn*/) { + return false; +} + +// Recursive case: check First, then recurse on Rest... +template +inline bool dispatchTierPairs(TierList*, Data const& data, Fn&& fn) { + if (data.mNumExperts <= First::kExperts && data.mTopK <= First::kTopK) { + fn(std::integral_constant{}, + std::integral_constant{}); + return true; + } + return dispatchTierPairs(static_cast*>(nullptr), data, + std::forward(fn)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// POLICY TIER CONFIGURATION +// +// PolicyTraits::Pairs declares the supported (expert, topK) pairs. +// Only these pairs are compiled as kernel instantiations. +// To add support for a new model config, add a Tier to the appropriate TierList. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Default: fallback for new/unknown policies. 
+template +struct PolicyTraits { + using Pairs = TierList, Tier<128, 32>, Tier<256, 8>, Tier<256, 32>, Tier<512, 8>, + Tier<512, 32>, Tier<2048, 8>, Tier<2048, 32>>; +}; + +/// Softmax + None (Default: Softmax -> TopK). +template <> +struct PolicyTraits { + using Pairs = TierList, // Small expert counts (≤128 experts) + Tier<256, 8> // Medium expert counts (≤256 experts) + >; +}; + +/// None + Softmax (Renormalize): many model configs. +template <> +struct PolicyTraits { + using Pairs = TierList< + Tier<128, 4>, // Mixtral 8x7B (topK=2), Qwen2-MoE (topK=4), Arctic (topK=2), DBRX + // (topK=4), GPT-OSS + Tier<128, 8>, // DeepSeek-V2-Lite (topK=6), Mixtral 8x22B (topK=2) + Tier<160, 8>, // Qwen3-Coder-480B + Tier<256, 8>, // Mistral Large 3 (topK=8) + Tier<256, 16>, // Models with 256 experts and topK 9..16 + Tier<512, 8>, // Various 512-expert models + Tier<512, 16>, // Various 512-expert models with high topK + Tier<512, 22>, // Nemotron Super V3 (512 experts, topK=22) + Tier<576, 8>, // Customized model with 576 experts + Tier<2048, 32> // Large-expert fallback + >; +}; + +/// Sigmoid + SumNormalize (SigmoidRenorm: Sigmoid -> TopK -> Renormalize). +template <> +struct PolicyTraits { + using Pairs = TierList, // Small expert counts (≤128 experts) + Tier<256, 8> // Medium expert counts (≤256 experts) + >; +}; + +/// SigmoidBias + ScaledSumNormalize (DeepSeek nGroup≤1 / MiniMax2 / Kimi-K2 / Nemotron SuperV3). +template <> +struct PolicyTraits { + using Pairs = TierList, // Small expert counts (≤128 experts, e.g. DeepSeek-V2-Lite) + Tier<256, 8>, // MiniMax M2 (256 experts, topK=6) + Tier<384, 8>, // Kimi K2 (384 experts) + Tier<512, 8>, // DeepSeek nGroup≤1 (256 experts → E512 fallback) + Tier<512, 22> // Nemotron Super V3 (512 experts, topK=22, nGroup≤1) + >; +}; + +/// None + None (TopK only: no softmax or renormalize). 
+template <> +struct PolicyTraits { + using Pairs = TierList, // Small expert counts (≤128 experts) + Tier<256, 8> // Medium expert counts (≤256 experts) + >; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GENERIC DISPATCH MACROS +// +// These macros are fixed infrastructure — they never need editing when adding new +// policies or changing tier support. All configuration lives in PolicyTraits above. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Generic per-policy dispatch. Iterates PolicyTraits::Pairs, +// picking the first (expert, topK) pair that covers the runtime values. +// +// IMPORTANT: numThreads is clamped to at least min(MaxNumExperts, 1024) from the dispatched tier. +#define LAUNCH_ROUTING_FOR_POLICY(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, PreProc, PostProc) \ + [&](auto pt_tag_) { \ + using Pairs_ = typename decltype(pt_tag_)::Pairs; \ + bool dispatched_ = dispatchTierPairs( \ + static_cast(nullptr), data, [&](auto eTag_, auto kTag_) { \ + constexpr int tierMaxExp_ = decltype(eTag_)::value; \ + constexpr int tierThreads_ = tierMaxExp_ <= 1024 ? tierMaxExp_ : 1024; \ + int const effectiveThreads_ = \ + std::max(static_cast(numThreads), tierThreads_); \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, effectiveThreads_, \ + smemSize, stream, PreProc, PostProc, \ + decltype(eTag_)::value, decltype(kTag_)::value); \ + }); \ + if (!dispatched_) { \ + FLASHINFER_WARN("No tier covers numExperts=%d topK=%d", data.mNumExperts, data.mTopK); \ + } \ + }(PolicyTraits{}) + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// CUSTOM EXPERT SELECT DISPATCH +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Generic dispatch for custom ExpertSelectPolicy. PolicyTraits key is . 
+// Same numThreads clamping as LAUNCH_ROUTING_FOR_POLICY — see comment above. +#define LAUNCH_ROUTING_FOR_EXPERT_SELECT(data, coopLaunch, kernel, numBlocks, numThreads, \ + smemSize, stream, ExpertSelect) \ + [&](auto pt_tag_) { \ + using Pairs_ = typename decltype(pt_tag_)::Pairs; \ + bool dispatched_ = dispatchTierPairs( \ + static_cast(nullptr), data, [&](auto eTag_, auto kTag_) { \ + constexpr int tierMaxExp_ = decltype(eTag_)::value; \ + constexpr int tierThreads_ = tierMaxExp_ <= 1024 ? tierMaxExp_ : 1024; \ + int const effectiveThreads_ = \ + std::max(static_cast(numThreads), tierThreads_); \ + LAUNCH_ROUTING_WITH_EXPERT_SELECT(data, coopLaunch, kernel, numBlocks, \ + effectiveThreads_, smemSize, stream, ExpertSelect, \ + decltype(eTag_)::value, decltype(kTag_)::value); \ + }); \ + if (!dispatched_) { \ + FLASHINFER_WARN("No tier covers numExperts=%d topK=%d", data.mNumExperts, data.mTopK); \ + } \ + }(PolicyTraits{}) + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// PUBLIC DISPATCH MACROS +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Lightweight dispatch for utility kernels (histogram, init-counts, offsets) that do NOT use +// expert select policies, InputT, or MaxNumTopExperts. +// - Always uses NoOp expert select (no policy dispatch). +// - Always uses a fixed NumTop8Experts (no topK-tier dispatch). +// - Dispatches only on expert tiers. 
+#define LAUNCH_ROUTING_CUSTOM_NO_POLICY(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream) \ + if (data.mNumExperts <= NumExperts128Experts) { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, NoOpPreprocess, NoOpPostprocess, NumExperts128Experts, \ + NumTop8Experts); \ + } else if (data.mNumExperts <= NumExperts160Experts) { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, NoOpPreprocess, NoOpPostprocess, NumExperts160Experts, \ + NumTop8Experts); \ + } else if (data.mNumExperts <= NumExperts256Experts) { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, NoOpPreprocess, NoOpPostprocess, NumExperts256Experts, \ + NumTop8Experts); \ + } else if (data.mNumExperts <= NumExperts384Experts) { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, NoOpPreprocess, NoOpPostprocess, NumExperts384Experts, \ + NumTop8Experts); \ + } else if (data.mNumExperts <= NumExperts512Experts) { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, NoOpPreprocess, NoOpPostprocess, NumExperts512Experts, \ + NumTop8Experts); \ + } else if (data.mNumExperts <= NumExperts576Experts) { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, NoOpPreprocess, NoOpPostprocess, NumExperts576Experts, \ + NumTop8Experts); \ + } else if (data.mNumExperts <= MaxSupportedExperts) { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, NoOpPreprocess, NoOpPostprocess, MaxSupportedExperts, \ + NumTop8Experts); \ + } else { \ + FLASHINFER_WARN("Unsupported numExperts"); \ + } + +// Top-level dispatch: maps runtime preprocess/postprocess enums to compile-time policy types, +// then delegates to LAUNCH_ROUTING_FOR_POLICY which reads 
PolicyTraits for tier support. +// Use this ONLY for kernels that call ExpertSelectPolicy::apply (block, cluster, histogramScores). +#define LAUNCH_ROUTING_CUSTOM(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ + if (data.mPreprocessType == RoutingPreprocessType::SigmoidBias) { \ + LAUNCH_ROUTING_FOR_POLICY(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + SigmoidBiasPreprocess, ScaledSumNormalizePostprocess); \ + } else if (data.mPreprocessType == RoutingPreprocessType::Sigmoid) { \ + LAUNCH_ROUTING_FOR_POLICY(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + SigmoidPreprocess, SumNormalizePostprocess); \ + } else if (data.mPreprocessType == RoutingPreprocessType::Softmax && \ + data.mPostprocessType == RoutingPostprocessType::None) { \ + LAUNCH_ROUTING_FOR_POLICY(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + SoftmaxPreprocess, NoOpPostprocess); \ + } else if (data.mPreprocessType == RoutingPreprocessType::Softmax) { \ + LAUNCH_ROUTING_FOR_POLICY(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + SoftmaxPreprocess, SumNormalizePostprocess); \ + } else if (data.mPostprocessType == RoutingPostprocessType::Softmax) { \ + LAUNCH_ROUTING_FOR_POLICY(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, SoftmaxPostprocess); \ + } else { \ + LAUNCH_ROUTING_FOR_POLICY(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess); \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingCustom +} // namespace moe::dev::routing diff --git a/include/flashinfer/trtllm/fused_moe/RoutingDevKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingDevKernel.h new file mode 100644 index 0000000000..3c92dd7502 --- /dev/null +++ b/include/flashinfer/trtllm/fused_moe/RoutingDevKernel.h @@ -0,0 +1,170 @@ +/* + * Copyright (c) 
2022-2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "DevKernel.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Routing-specific launch macros. +// These macros build on top of LAUNCH_ESC from DevKernel.h. +// +// Unlike the generic LAUNCH_PDL (which instantiates 2 kernels for UsePdl=true/false), +// LAUNCH_PDL_ROUTING instantiates only 1 kernel and passes UsePdl as a runtime field +// in KernelParams. This halves routing kernel instantiations. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define LAUNCH_PDL_ROUTING(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, \ + stream) \ + do { \ + cudaLaunchConfig_t config{}; \ + config.gridDim = numBlocks; \ + config.blockDim = numThreads; \ + config.dynamicSmemBytes = smemSize; \ + config.stream = (cudaStream_t)stream; \ + \ + cudaLaunchAttribute attributes[2] = {}; \ + attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; \ + /* mUsePdl controls in-kernel sync/trigger; mPdlOverlapWithNext controls whether */ \ + /* the NEXT kernel in the stream is allowed to start before this one finishes. 
*/ \ + /* Both must be true for overlap: mUsePdl ensures PDL is globally enabled, */ \ + /* mPdlOverlapWithNext is false for the last routing kernel so the consumer GEMM */ \ + /* (which may lack cudaGridDependencySynchronize) can't read stale routing data. */ \ + attributes[0].val.programmaticStreamSerializationAllowed = \ + int(data.mUsePdl && data.mPdlOverlapWithNext); \ + attributes[1].id = cudaLaunchAttributeCooperative; \ + attributes[1].val.cooperative = int(coopLaunch); \ + config.attrs = attributes; \ + config.numAttrs = 2; \ + auto params = KernelParams::setKernelParams(data); \ + auto kernelTyped = kernel>; \ + if (smemSize > 48 * 1024) \ + CHECK_CUDA_ERROR( \ + cudaFuncSetAttribute(kernelTyped, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize));\ + CHECK_CUDA_ERROR(cudaLaunchKernelEx(&config, kernelTyped, params)); \ + } while (0) + +// Llama4 dispatch: uses data.mDtypeOutput and data.mDtypeInput. +#define LAUNCH_ROUTING_LLAMA4(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ + if (data.mDtypeOutput == tg::Dtype::Fp32) { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(float, float, 128 /* Always 128 for llama4*/, \ + 1 /* Always 1 for llama4*/), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeOutput == tg::Dtype::Bfloat16 && \ + data.mDtypeInput == tg::Dtype::Fp32) { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, \ + 128 /* Always 128 for llama4*/, 1 /* Always 1 for llama4*/), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeOutput == tg::Dtype::Bfloat16) { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, \ + 128 /* Always 128 for llama4*/, 1 /* Always 1 for llama4*/), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else { \ + FLASHINFER_WARN("Unsupported dtypeOutput"); \ + } + +// DeepSeek dispatch: uses data.mDtypeOutput. 
+#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag, forceFloatInput, \ + numExperts, numTopExperts) \ + if (data.mDtypeOutput == tg::Dtype::Fp32 && extraFlag) { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(float, float, numExperts, numTopExperts, true), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeOutput == tg::Dtype::Fp32) { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(float, float, numExperts, numTopExperts, false), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeOutput == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, true), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeOutput == tg::Dtype::Bfloat16 && extraFlag) { \ + LAUNCH_PDL_ROUTING( \ + data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, true), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeOutput == tg::Dtype::Bfloat16 && forceFloatInput) { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, false), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeOutput == tg::Dtype::Bfloat16) { \ + LAUNCH_PDL_ROUTING( \ + data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, false), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else { \ + FLASHINFER_WARN("Unsupported dtypeOutput"); \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// routingCustom dispatch: uses data.mDtypeOutput (OutputT) and data.mDtypeInput (InputT). +// These are routingCustom::Data fields, NOT used by DeepSeek/Llama4 macros. 
+// Wraps (PreProc, PostProc) into TopKExpertSelect for the standard +// preprocess→topK→postprocess flow. +#define LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, PreProc, PostProc, numExperts, numTopExperts) \ + if (data.mDtypeOutput == tg::Dtype::Fp32) { \ + LAUNCH_PDL_ROUTING( \ + data, coopLaunch, \ + LAUNCH_ESC(float, float, numExperts, numTopExperts, TopKExpertSelect), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeOutput == tg::Dtype::Bfloat16 && \ + data.mDtypeInput == tg::Dtype::Fp32) { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, \ + TopKExpertSelect), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeOutput == tg::Dtype::Bfloat16) { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, \ + TopKExpertSelect), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else { \ + FLASHINFER_WARN("Unsupported dtypeOutput"); \ + } + +// routingCustom dispatch for custom ExpertSelectPolicy types that don't use PreProc/PostProc. +// Use this when the policy does NOT follow the standard preprocess→topK→postprocess pattern. +// ExpertSelect must satisfy the ExpertSelectPolicy concept (see RoutingCustomPolicy.cuh). 
+#define LAUNCH_ROUTING_WITH_EXPERT_SELECT(data, coopLaunch, kernel, numBlocks, numThreads, \ + smemSize, stream, ExpertSelect, numExperts, \ + numTopExperts) \ + if (data.mDtypeOutput == tg::Dtype::Fp32) { \ + LAUNCH_PDL_ROUTING( \ + data, coopLaunch, \ + LAUNCH_ESC(float, float, numExperts, numTopExperts, ExpertSelect), kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } else if (data.mDtypeOutput == tg::Dtype::Bfloat16 && \ + data.mDtypeInput == tg::Dtype::Fp32) { \ + LAUNCH_PDL_ROUTING( \ + data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, ExpertSelect), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeOutput == tg::Dtype::Bfloat16) { \ + LAUNCH_PDL_ROUTING( \ + data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, ExpertSelect), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else { \ + FLASHINFER_WARN("Unsupported dtypeOutput"); \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh b/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh index 17143ab8a4..7511c27549 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -15,19 +15,21 @@ */ #pragma once +#include "RoutingDevKernel.h" +#include "RoutingKernel.h" +#include "RoutingKernelTopK.cuh" + #include #include -#include - #include + #include -#include +#include -#include "DevKernel.h" -#include "RoutingKernel.h" -#include "RoutingKernelTopK.cuh" +#include //////////////////////////////////////////////////////////////////////////////////////////////////// + namespace moe::dev { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -44,6 +46,23 @@ static constexpr int NumEltsPerOffsetTilePerThread = 8; //////////////////////////////////////////////////////////////////////////////////////////////////// +/// Dereference a type-erased pointer at the given index, reading the value in its native dtype. +/// Returns float since routing computations are done in float for numerical stability. +__forceinline__ __device__ float loadScalar(void const* ptr, int idx, + batchedGemm::trtllm::gen::Dtype dtype) { + namespace tg = batchedGemm::trtllm::gen; + switch (dtype) { + case tg::Dtype::Fp32: + return static_cast(ptr)[idx]; + case tg::Dtype::Bfloat16: + return static_cast(static_cast<__nv_bfloat16 const*>(ptr)[idx]); + default: + return 0.f; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + static __device__ inline float sigmoid_accurate(float x) { return 0.5f * tanhf(0.5f * x) + 0.5f; } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -97,7 +116,8 @@ __host__ __device__ constexpr int32_t getBits(int32_t value, int idx) { template __host__ __device__ constexpr void setBits(int32_t& value, int32_t newBits, int idx) { if constexpr (!IsZero) { - int mask = idx == 0 ? 0xFFFFFF00 : idx == 1 ? 0xFFFF00FF : idx == 2 ? 0xFF00FFFF : 0x00FFFFFF; + int mask = + idx == 0 ? 0xFFFFFF00 : idx == 1 ? 0xFFFF00FF : idx == 2 ? 
0xFF00FFFF : 0x00FFFFFF; value &= mask; } value |= (newBits << (idx * 8)); @@ -151,18 +171,22 @@ __device__ void calcSoftmax(cg::thread_block_tile const& warp, template __device__ DataType calcSoftmax(cg::thread_block_tile const& warp, DataType score, int32_t laneIdx, int32_t NumTopExperts) { - DataType maxScore = DataType{-INFINITY}; + // Compute in float to support half/bfloat16 inputs safely. + // cg::reduce with cg::greater only supports float/double and integer types; + // using __nv_bfloat16 or __half directly can generate unsupported redux.sync.max instructions. + float maxScore = -INFINITY; if (laneIdx < NumTopExperts) { - maxScore = score >= maxScore ? score : maxScore; + float si = static_cast(score); + maxScore = si >= maxScore ? si : maxScore; } - maxScore = cg::reduce(warp, maxScore, cg::greater()); + maxScore = cg::reduce(warp, maxScore, cg::greater()); - float sumScore = float{0.f}; - float newScore; + float sumScore = 0.f; + float newScore = 0.f; // Get the summation of scores for each token if (laneIdx < NumTopExperts) { - newScore = static_cast(score) - static_cast(maxScore); - newScore = static_cast(exp(newScore)); + newScore = static_cast(score) - maxScore; + newScore = expf(newScore); sumScore += newScore; } sumScore = cg::reduce(warp, sumScore, cg::plus()); @@ -183,6 +207,13 @@ __device__ void routingPermutation(KernelParams params, using OutputT = typename KernelParams::OutputT; using TypePacked = PackedScoreIdx; + // When MaxNumExperts > NumThreads, each thread handles multiple experts. + static constexpr int MaxNumExperts = KernelParams::MaxNumExperts; + static constexpr int ExpertsPerThread = + MaxNumExperts <= NumThreads ? 1 : MaxNumExperts / NumThreads; + static_assert(MaxNumExperts <= NumThreads || MaxNumExperts % NumThreads == 0, + "MaxNumExperts must be <= NumThreads or a multiple of NumThreads"); + static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; // Number of threads in the cluster. 
static constexpr int NumThreadsPerCluster = NumThreads * NumBlocksPerCluster; @@ -199,13 +230,17 @@ __device__ void routingPermutation(KernelParams params, uint32_t const clusterThreadIdx = NumThreads * clusterBlockRank + threadIdx.x; auto expandedIdxSize = params.mNumTokens * params.mTopK; - // number of experts is bounded by number of threads - __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads]; - __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads]; + // number of experts may exceed number of threads — size by MaxNumExperts + __shared__ int32_t __attribute((aligned(128))) smemExpertCount[MaxNumExperts]; + __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[MaxNumExperts]; - // pre-fill the counts with 0 - if (threadIdx.x < params.mNumExperts) { - smemExpertCount[threadIdx.x] = 0; + // pre-fill the counts with 0 — each thread handles ExpertsPerThread experts +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + if (expert < params.mNumExperts) { + smemExpertCount[expert] = 0; + } } __syncthreads(); @@ -216,12 +251,6 @@ __device__ void routingPermutation(KernelParams params, int32_t expertOffsets[MaxExpandedIdxPerThread]; auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; - // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a - // time, and branch between a fast path without bound checks and a slow path with bound checks. - // TODO(mjoux): potentially add this back for perf tuning - // int constexpr IterStride = 4; - // static_assert(MaxExpandedIdxPerThread % IterStride == 0); - // Define a lambda to avoid code duplication in both branches. 
auto loopBody = [&](int ii, int expandedIdx) { TypePacked scoreIdx; @@ -243,7 +272,7 @@ __device__ void routingPermutation(KernelParams params, // check whether this expert is local to our GPU at all and ignore if not auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx; auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + scoreIdx.idx, 1) : 0; if (params.mPtrTopKWeights != nullptr && params.mPtrTopKIds == nullptr) { params.mPtrTopKWeights[expandedIdx] = OutputT{scoreIdx.score}; @@ -284,30 +313,37 @@ __device__ void routingPermutation(KernelParams params, __cluster_barrier_wait(); // - // Each thread now represents one expert + // Each thread now represents ExpertsPerThread experts // - // Total number of tokens for this expert. - int32_t count = 0; + // Total number of tokens for each expert this thread handles. + int32_t count[ExpertsPerThread]; // Per-expert offset for this block. - int32_t blockExpertOffset = 0; + int32_t blockExpertOffset[ExpertsPerThread]; - if (threadIdx.x < params.mNumExperts) { - // Get the histogram bin from each rank for this expert. - int32_t expertCounts[NumBlocksPerCluster]; #pragma unroll - for (int rank = 0; rank < NumBlocksPerCluster; rank++) { - int32_t const* remoteSmem = cg::cluster_group::map_shared_rank(smemExpertCount, rank); - expertCounts[rank] = rank * NumWarps < params.mNumTokens ? remoteSmem[threadIdx.x] : 0; - } + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + count[e] = 0; + blockExpertOffset[e] = 0; + + if (expert < params.mNumExperts) { + // Get the histogram bin from each rank for this expert. 
+ int32_t expertCounts[NumBlocksPerCluster]; +#pragma unroll + for (int rank = 0; rank < NumBlocksPerCluster; rank++) { + int32_t const* remoteSmem = cg::cluster_group::map_shared_rank(smemExpertCount, rank); + expertCounts[rank] = rank * NumWarps < params.mNumTokens ? remoteSmem[expert] : 0; + } - // Compute an exclusive prefix sum of the block-local count. + // Compute an exclusive prefix sum of the block-local count. #pragma unroll - for (int rank = 0; rank < NumBlocksPerCluster; rank++) { - if (rank == clusterBlockRank) { - blockExpertOffset = count; + for (int rank = 0; rank < NumBlocksPerCluster; rank++) { + if (rank == clusterBlockRank) { + blockExpertOffset[e] = count[e]; + } + count[e] += expertCounts[rank]; } - count += expertCounts[rank]; } } @@ -317,54 +353,66 @@ __device__ void routingPermutation(KernelParams params, // Compute the runtime config for projections // Whether or not an expert is local is taken into account when smemExpertCount is computed // so we do not need to take it into account here. - - int32_t numCta; - if constexpr (KernelParams::isPow2) { - numCta = divUpLog2(count, params.mPaddingLog2); - } else { - numCta = divUpTileN(count, params.mTileTokensDim); + int32_t numCta[ExpertsPerThread]; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + if (params.mIsPow2) { + numCta[e] = divUpLog2(count[e], params.mPaddingLog2); + } else { + numCta[e] = divUpTileN(count[e], params.mTileTokensDim); + } + // Expand from CGA count to CTA count to keep the semantic stable with downstream kernels + numCta[e] *= params.mClusterSizeInBatchDim; } - int32_t ctaOffset; + int32_t ctaOffset[ExpertsPerThread]; int32_t numNonExitingCtas; Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); - if (threadIdx.x < params.mNumExperts) { - // Strided loop to share this work between blocks. 
- for (int32_t cta = clusterBlockRank; cta < numCta; cta += NumBlocksPerCluster) { - const int32_t localExpertIdx = - (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; - params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - int32_t mnLimit1; - int32_t mnLimit2; - if constexpr (KernelParams::isPow2) { - mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); - mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + if (expert < params.mNumExperts) { + // Strided loop to share this work between blocks. + for (int32_t cta = clusterBlockRank; cta < numCta[e]; cta += NumBlocksPerCluster) { + const int32_t localExpertIdx = + (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; + params.mPtrCtaIdxXyToBatchIdx[ctaOffset[e] + cta] = localExpertIdx; + // Write CTA-level MnLimits using ctaTile = cgaTile / clusterSize + int32_t mnLimit1; + int32_t mnLimit2; + if (params.mIsPow2) { + int32_t ctaPaddingLog2 = params.mPaddingLog2 - params.mClusterSizeLog2; + mnLimit1 = mulLog2(ctaOffset[e] + cta + 1, ctaPaddingLog2); + mnLimit2 = mulLog2(ctaOffset[e], ctaPaddingLog2) + count[e]; + } else { + int32_t ctaTile = params.mTileTokensDim / params.mClusterSizeInBatchDim; + mnLimit1 = (ctaOffset[e] + cta + 1) * ctaTile; + mnLimit2 = ctaOffset[e] * ctaTile + count[e]; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset[e] + cta] = min(mnLimit1, mnLimit2); + } + + // get the padded offset associated with this expert (token-space, CGA granularity) + int32_t offset; + if (params.mIsPow2) { + offset = mulLog2(ctaOffset[e] >> params.mClusterSizeLog2, params.mPaddingLog2); } else { - mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); - mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; + offset = (ctaOffset[e] / params.mClusterSizeInBatchDim) * params.mTileTokensDim; } - 
params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); - } - // get the padded offset associated with this expert - int32_t offset; - if constexpr (KernelParams::isPow2) { - offset = mulLog2(ctaOffset, params.mPaddingLog2); - } else { - offset = mulTileN(ctaOffset, params.mTileTokensDim); + // write expert offsets to shared + smemExpertOffset[expert] = offset + blockExpertOffset[e]; } - // write expert offsets to shared - smemExpertOffset[threadIdx.x] = offset + blockExpertOffset; } // write out padded count if (clusterBlockRank == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync()) { int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) { - permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + if (params.mIsPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas >> params.mClusterSizeLog2, params.mPaddingLog2); } else { - permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + permutedIdxSize = (numNonExitingCtas / params.mClusterSizeInBatchDim) * params.mTileTokensDim; } params.mPtrPermutedIdxSize[0] = permutedIdxSize; params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; @@ -378,16 +426,6 @@ __device__ void routingPermutation(KernelParams params, // implement break with EXIT. __cluster_barrier_wait(); - // trigger the secondary kernel when using PDL - // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx, - // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens - // TODO: this is not sufficient to ensure visibility in the next kernel! 
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - if constexpr (KernelParams::UsePdl) { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif - // each thread has the same "expanded indexes" assigned to it as above // at this point, we know the final offsets of experts and the offsets within // experts, which allows writing the final index values @@ -402,7 +440,7 @@ __device__ void routingPermutation(KernelParams params, // check whether this expert is local to our GPU at all auto localExpertIdx = static_cast(expertIdx) - params.mLocalExpertsStartIdx; auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; auto tokenIdx = expandedIdx / params.mTopK; auto permutedIdx = isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1}; @@ -416,6 +454,17 @@ __device__ void routingPermutation(KernelParams params, params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; } } + + // Trigger the secondary kernel AFTER all global memory writes are complete. + // The downstream kernels (permute, FC1 GEMM) depend on mPtrCtaIdxXyToBatchIdx, + // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas, mPtrPermutedIdxSize, AND + // mPtrExpandedIdxToPermutedIdx / mPtrPermutedIdxToTokenIdx. + // Triggering before the permutation writes causes the consumer to read stale data → NaN. +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + if (params.mUsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -428,25 +477,36 @@ __device__ void routingPermutation(KernelParams params, // Note: the histogram calculation could also be fused with routingMainKernel, but this might be // inefficient if we have one CTA per token doing a single global atomic. 
template -__global__ void __launch_bounds__(KernelParams::MaxNumExperts) - routingIndicesHistogramKernel(KernelParams params) { +__global__ void + __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelParams::MaxNumExperts : 1024) + routingIndicesHistogramKernel(KernelParams params) { using OutputT = typename KernelParams::OutputT; + static constexpr int MaxNumExperts = KernelParams::MaxNumExperts; + // Cap actual thread count at 1024 when MaxNumExperts > 1024. + static constexpr int NumThreadsBlock = MaxNumExperts <= 1024 ? MaxNumExperts : 1024; + static constexpr int ExpertsPerThread = MaxNumExperts / NumThreadsBlock; + static_assert(MaxNumExperts % NumThreadsBlock == 0, + "MaxNumExperts must be a multiple of NumThreadsBlock"); - // number of experts is bounded by number of threads - __shared__ int32_t __attribute((aligned(128))) smemExpertCount[KernelParams::MaxNumExperts]; + // number of experts is bounded by MaxNumExperts (may exceed thread count) + __shared__ int32_t __attribute((aligned(128))) smemExpertCount[MaxNumExperts]; // For unrolling. uint32_t constexpr NumEltsPerThread = 8; - // Pre-fill the counts with 0 - if (threadIdx.x < params.mNumExperts) { - smemExpertCount[threadIdx.x] = 0; + // Pre-fill the counts with 0 — each thread handles ExpertsPerThread experts +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + if (expert < params.mNumExperts) { + smemExpertCount[expert] = 0; + } } __syncthreads(); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // Wait on primary grid and trigger secondary kernel. 
- if constexpr (KernelParams::UsePdl) { + if (params.mUsePdl) { cudaGridDependencySynchronize(); cudaTriggerProgrammaticLaunchCompletion(); } @@ -455,8 +515,9 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) uint32_t const expandedIdxSize = params.mNumTokens * params.mTopK; uint32_t const localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; - uint32_t const gridBlockOffset = blockIdx.x * KernelParams::MaxNumExperts; - uint32_t const gridStride = gridDim.x * KernelParams::MaxNumExperts; + // Use NumThreadsBlock (actual thread count) for grid-stride addressing + uint32_t const gridBlockOffset = blockIdx.x * NumThreadsBlock; + uint32_t const gridStride = gridDim.x * NumThreadsBlock; // Define a lambda to avoid code duplication in branches. auto loopBody = [&](int expandedIdx) { @@ -475,25 +536,25 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) // check whether this expert is local to our GPU at all and ignore if not auto localExpertIdx = idx - params.mLocalExpertsStartIdx; auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; if (isLocalExpert) { atomicAdd(&smemExpertCount[idx], 1); } }; - // Grid-stride loop. + // Grid-stride loop using NumThreadsBlock as block width. 
for (uint32_t expandedIdx0 = gridBlockOffset * NumEltsPerThread; expandedIdx0 < expandedIdxSize; expandedIdx0 += gridStride * NumEltsPerThread) { // Fast path if bound checks aren't necessary - if (expandedIdx0 + NumEltsPerThread * KernelParams::MaxNumExperts <= expandedIdxSize) { + if (expandedIdx0 + NumEltsPerThread * NumThreadsBlock <= expandedIdxSize) { #pragma unroll for (uint32_t ii = 0; ii < NumEltsPerThread; ii++) { - uint32_t expandedIdx = expandedIdx0 + ii * KernelParams::MaxNumExperts + threadIdx.x; + uint32_t expandedIdx = expandedIdx0 + ii * NumThreadsBlock + threadIdx.x; loopBody(expandedIdx); } } else { for (uint32_t expandedIdx = expandedIdx0 + threadIdx.x; expandedIdx < expandedIdxSize; - expandedIdx += KernelParams::MaxNumExperts) { + expandedIdx += NumThreadsBlock) { loopBody(expandedIdx); } } @@ -501,33 +562,44 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) __syncthreads(); // - // Each thread now represents one expert + // Each thread now represents ExpertsPerThread experts // // Reduce histograms with atomics. - if (threadIdx.x < params.mNumExperts) { - int32_t const localExpertCount = smemExpertCount[threadIdx.x]; - atomicAdd(¶ms.mPtrExpertCounts[threadIdx.x], localExpertCount); +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + if (expert < params.mNumExperts) { + int32_t const localExpertCount = smemExpertCount[expert]; + atomicAdd(¶ms.mPtrExpertCounts[expert], localExpertCount); + } } } //////////////////////////////////////////////////////////////////////////////////////////////////// template -__global__ void __launch_bounds__(KernelParams::MaxNumExperts) - routingIndicesOffsetsKernel(KernelParams params) { +__global__ void + __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? 
KernelParams::MaxNumExperts : 1024) + routingIndicesOffsetsKernel(KernelParams params) { using OutputT = typename KernelParams::OutputT; - - // number of experts is bounded by number of threads - __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[KernelParams::MaxNumExperts]; - __shared__ int32_t __attribute((aligned(128))) smemExpertCount[KernelParams::MaxNumExperts]; - __shared__ int32_t __attribute((aligned(128))) smemExpertTileOffset[KernelParams::MaxNumExperts]; - // needed for the exclusive sum of token offsets - using Scan = cub::BlockScan; + static constexpr int MaxNumExperts = KernelParams::MaxNumExperts; + // Cap actual thread count at 1024 when MaxNumExperts > 1024. + static constexpr int NumThreadsBlock = MaxNumExperts <= 1024 ? MaxNumExperts : 1024; + static constexpr int ExpertsPerThread = MaxNumExperts / NumThreadsBlock; + static_assert(MaxNumExperts % NumThreadsBlock == 0, + "MaxNumExperts must be a multiple of NumThreadsBlock"); + + // number of experts — shared memory sized by MaxNumExperts (may exceed thread count) + __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[MaxNumExperts]; + __shared__ int32_t __attribute((aligned(128))) smemExpertCount[MaxNumExperts]; + __shared__ int32_t __attribute((aligned(128))) smemExpertTileOffset[MaxNumExperts]; + // BlockScan uses actual thread count; array overload handles ExpertsPerThread items per thread + using Scan = cub::BlockScan; __shared__ typename Scan::TempStorage tempStorage; static constexpr int MaxExpandedIdxPerThread = NumEltsPerOffsetTilePerThread; - static constexpr int MaxExpandedIdxPerBlock = - KernelParams::MaxNumExperts * MaxExpandedIdxPerThread; + // Tile size uses actual thread count + static constexpr int MaxExpandedIdxPerBlock = NumThreadsBlock * MaxExpandedIdxPerThread; int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); @@ -537,82 +609,99 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) #if (defined(__CUDA_ARCH__) && 
(__CUDA_ARCH__ >= 900)) // Wait on primary grid. - if constexpr (KernelParams::UsePdl) { + if (params.mUsePdl) { cudaGridDependencySynchronize(); } #endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // The expert offsets are common to all tiles of all blocks. // Load the histogram, scan it and write offsets to shared memory. - // Note: the scan is redundant in all CTAs. Would it make sense to use an intermediate kernel for - // the scan, with PDL? // - // Each thread represents one expert. + // Each thread represents ExpertsPerThread experts. // - // Get total count for this expert. - int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0; + // Get total count for each expert this thread handles. + int32_t count[ExpertsPerThread]; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + count[e] = (expert < params.mNumExperts) ? params.mPtrExpertCounts[expert] : 0; + } // Compute the runtime config for projections // Whether or not an expert is local is taken into account when the histogram is computed // so we do not need to take it into account here. 
- // const int32_t numCta = divUpLog2(count, params.mPaddingLog2); - int32_t numCta; - if constexpr (KernelParams::isPow2) { - numCta = divUpLog2(count, params.mPaddingLog2); - } else { - numCta = divUpTileN(count, params.mTileTokensDim); + int32_t numCta[ExpertsPerThread]; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + if (params.mIsPow2) { + numCta[e] = divUpLog2(count[e], params.mPaddingLog2); + } else { + numCta[e] = divUpTileN(count[e], params.mTileTokensDim); + } + // Expand from CGA count to CTA count to keep the semantic stable with downstream kernels + numCta[e] *= params.mClusterSizeInBatchDim; } - int32_t ctaOffset; + int32_t ctaOffset[ExpertsPerThread]; int32_t numNonExitingCtas; Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); - if (threadIdx.x < params.mNumExperts) { - // Get the padded offset associated with this expert - int32_t offset; - if constexpr (KernelParams::isPow2) { - offset = mulLog2(ctaOffset, params.mPaddingLog2); - } else { - offset = mulTileN(ctaOffset, params.mTileTokensDim); - } +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + if (expert < params.mNumExperts) { + // Get the padded offset associated with this expert (token-space, CGA granularity) + int32_t offset; + if (params.mIsPow2) { + offset = mulLog2(ctaOffset[e] >> params.mClusterSizeLog2, params.mPaddingLog2); + } else { + offset = (ctaOffset[e] / params.mClusterSizeInBatchDim) * params.mTileTokensDim; + } - // Write expert offsets to shared - smemExpertOffset[threadIdx.x] = offset; + // Write expert offsets to shared + smemExpertOffset[expert] = offset; + } } // Sync to make expert offsets available to all threads. 
__syncthreads(); - // The first block writes out padded count - if (blockIdx.x == 0 && warpIdx == KernelParams::MaxNumExperts / WarpSize - 1 && - cute::elect_one_sync()) { + // The first block writes out padded count (use last warp of actual thread count) + if (blockIdx.x == 0 && warpIdx == NumThreadsBlock / WarpSize - 1 && cute::elect_one_sync()) { int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) { - permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + if (params.mIsPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas >> params.mClusterSizeLog2, params.mPaddingLog2); } else { - permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + permutedIdxSize = (numNonExitingCtas / params.mClusterSizeInBatchDim) * params.mTileTokensDim; } params.mPtrPermutedIdxSize[0] = permutedIdxSize; params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; } - if (threadIdx.x < params.mNumExperts) { - // Strided loop to share this work between blocks. - for (int32_t cta = blockIdx.x; cta < numCta; cta += gridDim.x) { - const int32_t localExpertIdx = - (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; - params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - int32_t mnLimit1; - int32_t mnLimit2; - if constexpr (KernelParams::isPow2) { - mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); - mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; - } else { - mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); - mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + if (expert < params.mNumExperts) { + // Strided loop to share this work between blocks. 
+ for (int32_t cta = blockIdx.x; cta < numCta[e]; cta += gridDim.x) { + const int32_t localExpertIdx = + (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; + params.mPtrCtaIdxXyToBatchIdx[ctaOffset[e] + cta] = localExpertIdx; + // Write CTA-level MnLimits using ctaTile = cgaTile / clusterSize + int32_t mnLimit1; + int32_t mnLimit2; + if (params.mIsPow2) { + int32_t ctaPaddingLog2 = params.mPaddingLog2 - params.mClusterSizeLog2; + mnLimit1 = mulLog2(ctaOffset[e] + cta + 1, ctaPaddingLog2); + mnLimit2 = mulLog2(ctaOffset[e], ctaPaddingLog2) + count[e]; + } else { + int32_t ctaTile = params.mTileTokensDim / params.mClusterSizeInBatchDim; + mnLimit1 = (ctaOffset[e] + cta + 1) * ctaTile; + mnLimit2 = ctaOffset[e] * ctaTile + count[e]; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset[e] + cta] = min(mnLimit1, mnLimit2); } - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); } } @@ -627,9 +716,13 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) __syncthreads(); } - // Pre-fill the counts with 0 - if (threadIdx.x < params.mNumExperts) { - smemExpertCount[threadIdx.x] = 0; + // Pre-fill the counts with 0 — each thread handles ExpertsPerThread experts +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + if (expert < params.mNumExperts) { + smemExpertCount[expert] = 0; + } } __syncthreads(); @@ -646,7 +739,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) // check whether this expert is local to our GPU at all and ignore if not auto localExpertIdx = expertIndexes[ii] - params.mLocalExpertsStartIdx; auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; expertOffsets[ii] = isLocalExpert ? 
atomicAdd(smemExpertCount + expertIndexes[ii], 1) : 0; }; @@ -655,28 +748,25 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) #pragma unroll for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1) { auto expandedIdx = - tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x; + tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsBlock + threadIdx.x; loopBody(ii, expandedIdx); } } else { // For the last tile, we need to exit the loop when out of bounds. - // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a - // time, and branch between a fast path without bound checks and a slow path with bound checks int constexpr IterStride = 4; static_assert(MaxExpandedIdxPerThread % IterStride == 0); #pragma unroll for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride) { - // Whether it's safe to do multiple iterations without bound checks. bool const takeFastPath = - tileIdx * MaxExpandedIdxPerBlock + (ii0 + IterStride) * KernelParams::MaxNumExperts <= + tileIdx * MaxExpandedIdxPerBlock + (ii0 + IterStride) * NumThreadsBlock <= expandedIdxSize; if (takeFastPath) { #pragma unroll for (int32_t jj = 0; jj < IterStride; jj++) { int const ii = ii0 + jj; auto expandedIdx = - tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x; + tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsBlock + threadIdx.x; loopBody(ii, expandedIdx); } } else { @@ -685,7 +775,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) for (int32_t jj = 0; jj < IterStride; jj++) { int const ii = ii0 + jj; auto expandedIdx = - tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x; + tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsBlock + threadIdx.x; if (expandedIdx >= expandedIdxSize) { doBreak = true; break; @@ -703,19 +793,21 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) __syncthreads(); // - // Each thread now 
represents one expert + // Each thread now represents ExpertsPerThread experts // - if (threadIdx.x < params.mNumExperts) { - // Add the local bin count to the common bin count and get a per-CTA offset. We use the second - // half of the histogram buffer for this histogram, because the first half already holds the - // reduced histogram from the previous kernel. - int32_t const localExpertCount = smemExpertCount[threadIdx.x]; - int32_t const tileExpertOffset = - atomicAdd(¶ms.mPtrExpertCounts[params.mNumExperts + threadIdx.x], localExpertCount); - - // Make per-expert tile offsets available to all threads in the block. - smemExpertTileOffset[threadIdx.x] = tileExpertOffset + smemExpertOffset[threadIdx.x]; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + if (expert < params.mNumExperts) { + // Add the local bin count to the common bin count and get a per-CTA offset. + int32_t const localExpertCount = smemExpertCount[expert]; + int32_t const tileExpertOffset = + atomicAdd(¶ms.mPtrExpertCounts[params.mNumExperts + expert], localExpertCount); + + // Make per-expert tile offsets available to all threads in the block. + smemExpertTileOffset[expert] = tileExpertOffset + smemExpertOffset[expert]; + } } __syncthreads(); @@ -725,7 +817,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) // check whether this expert is local to our GPU at all auto localExpertIdx = static_cast(expertIdx) - params.mLocalExpertsStartIdx; auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; auto tokenIdx = expandedIdx / params.mTopK; auto permutedIdx = isLocalExpert ? 
(expertOffsets[ii] + smemExpertTileOffset[expertIdx]) : int32_t{-1}; @@ -744,14 +836,14 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) #pragma unroll for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1) { auto expandedIdx = - tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x; + tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsBlock + threadIdx.x; storeLoopBody(ii, expandedIdx); } } else { #pragma unroll for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1) { auto expandedIdx = - tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x; + tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsBlock + threadIdx.x; if (expandedIdx >= expandedIdxSize) { break; } @@ -762,9 +854,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // Trigger secondary kernel. - // Note: this does not guarantee the visibility of prior writes unless the consumer executes a - // dependency sync. - if constexpr (KernelParams::UsePdl) { + if (params.mUsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } #endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) @@ -773,16 +863,21 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) //////////////////////////////////////////////////////////////////////////////////////////////////// template -__global__ void __launch_bounds__(KernelParams::MaxNumExperts) - routingInitExpertCounts(KernelParams params) { +__global__ void + __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelParams::MaxNumExperts : 1024) + routingInitExpertCounts(KernelParams params) { + // Cap actual thread count at 1024 when MaxNumExperts > 1024. + static constexpr int NumThreadsBlock = + KernelParams::MaxNumExperts <= 1024 ? 
KernelParams::MaxNumExperts : 1024; + // initialize the mPtrExpertCounts int32_t expertCountsNum = 2 * params.mNumExperts; - int32_t globalThreadIdx = blockIdx.x * KernelParams::MaxNumExperts + threadIdx.x; - int32_t globalThreadStride = gridDim.x * KernelParams::MaxNumExperts; + int32_t globalThreadIdx = blockIdx.x * NumThreadsBlock + threadIdx.x; + int32_t globalThreadStride = gridDim.x * NumThreadsBlock; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // Wait on primary grid. - if constexpr (KernelParams::UsePdl) { + if (params.mUsePdl) { cudaGridDependencySynchronize(); } #endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) @@ -790,11 +885,261 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Wait on primary grid. - if constexpr (KernelParams::UsePdl) { + // Trigger secondary kernel. + if (params.mUsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } #endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Cooperative launch kernel: fuses histogram + offsets computation for medium token counts. +// This kernel is shared by routingCustom, routingDeepSeek, and can be used by other routing +// methods. It uses cooperative groups to synchronize across multiple CTAs and compute expert counts, +// offsets, and permutation indices in a single kernel launch. +// +// Requirements: +// - MaxNumExperts <= 1024 (enforced by static_assert) +// - SM90+ architecture (cooperative groups) +// - mPtrPermutedIdxSize must be non-null (needed for permutation) +// +// The kernel handles both mPtrTopKIds and mPtrTopKPacked input formats. 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +template +__global__ void __launch_bounds__(KernelParams::MaxNumExperts) + routingIndicesCoopKernel(KernelParams params) { + // number of experts is bounded by number of threads (coop kernel requires MaxNumExperts <= 1024) + using OutputT = typename KernelParams::OutputT; + static constexpr int MaxNumExperts = KernelParams::MaxNumExperts; + static constexpr int NumThreads = MaxNumExperts; + static_assert(MaxNumExperts <= 1024, "Coop kernel requires MaxNumExperts <= 1024"); + + __shared__ int32_t __attribute((aligned(128))) smemExpertCount[MaxNumExperts]; + __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[MaxNumExperts]; + // needed for the exclusive sum of token offsets + using Scan = cub::BlockScan; + __shared__ typename Scan::TempStorage tempStorage; + // 64 elements -> 128+ registers. Above that we may start to see spilling to local memory. + static constexpr int MaxExpandedIdxPerThread = 64; + + // Initialize grid. 
+ cg::grid_group grid = cg::this_grid(); + int32_t const gridBlockIdx = blockIdx.x; + int32_t const gridThreadIdx = NumThreads * gridBlockIdx + threadIdx.x; + int32_t const numBlocks = gridDim.x; + int32_t const numThreadsPerGrid = numBlocks * NumThreads; + + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + + auto expandedIdxSize = params.mNumTokens * params.mTopK; + + // pre-fill the counts with 0 — each thread represents one expert + smemExpertCount[threadIdx.x] = 0; + __syncthreads(); + + // then wait on primary grid + if (params.mUsePdl) { + cudaGridDependencySynchronize(); + } + + // each thread has some number of "expanded indexes" assigned to it + // for each of these, we keep the associated expert and offset within expert in registers + int32_t expertIndexes[MaxExpandedIdxPerThread]; + int32_t expertOffsets[MaxExpandedIdxPerThread]; + auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; + int constexpr IterStride = 4; + static_assert(MaxExpandedIdxPerThread % IterStride == 0); + + // Define a lambda to avoid code duplication in both branches. + // Use shared device function for expert index extraction. + auto loopBody = [&](int ii, int expandedIdx) { + int32_t expertIdx = + getExpertIdxFromInputWithWeights(params, expandedIdx, params.mPtrTopKWeights); + expertIndexes[ii] = expertIdx; + // check whether this expert is local to our GPU at all and ignore if not + auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + expertOffsets[ii] = isLocalExpert ? 
atomicAdd(smemExpertCount + expertIdx, 1) : 0; + }; + +#pragma unroll + for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride) { + bool const takeFastPath = (ii0 + IterStride) * numThreadsPerGrid <= expandedIdxSize; + if (takeFastPath) { +#pragma unroll + for (int32_t jj = 0; jj < IterStride; jj++) { + int const ii = ii0 + jj; + auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; + loopBody(ii, expandedIdx); + } + } else { + bool doBreak = false; +#pragma unroll + for (int32_t jj = 0; jj < IterStride; jj++) { + int const ii = ii0 + jj; + auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; + if (expandedIdx >= expandedIdxSize) { + doBreak = true; + break; + } + loopBody(ii, expandedIdx); + } + if (doBreak) { + break; + } + } + } + + // Make histogram (token counts per expert) available to all threads in the block. + __syncthreads(); + + // + // Each thread now represents one expert + // + + // Add the local bin count to the common bin count and get a per-CTA offset. + int32_t const localExpertCount = smemExpertCount[threadIdx.x]; + + int32_t blockExpertOffset = 0; + if (threadIdx.x < params.mNumExperts) { + blockExpertOffset = atomicAdd(¶ms.mPtrExpertCounts[threadIdx.x], localExpertCount); + } + + // Sync to wait for completion of the histogram reduction. + grid.sync(); + + // Get total count for this expert. + int32_t count = (threadIdx.x < params.mNumExperts) ? 
params.mPtrExpertCounts[threadIdx.x] : 0; + + int32_t numCta; + if (params.mIsPow2) { + numCta = divUpLog2(count, params.mPaddingLog2); + } else { + numCta = divUpTileN(count, params.mTileTokensDim); + } + // Expand from CGA count to CTA count to keep the semantic stable with downstream kernels + numCta *= params.mClusterSizeInBatchDim; + + int32_t ctaOffset; + int32_t numNonExitingCtas; + Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); + + for (int32_t cta = gridBlockIdx; cta < numCta; cta += numBlocks) { + const int32_t localExpertIdx = + (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; + params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; + // Write CTA-level MnLimits using ctaTile = cgaTile / clusterSize + int32_t mnLimit1; + int32_t mnLimit2; + if (params.mIsPow2) { + int32_t ctaPaddingLog2 = params.mPaddingLog2 - params.mClusterSizeLog2; + mnLimit1 = mulLog2(ctaOffset + cta + 1, ctaPaddingLog2); + mnLimit2 = mulLog2(ctaOffset, ctaPaddingLog2) + count; + } else { + int32_t ctaTile = params.mTileTokensDim / params.mClusterSizeInBatchDim; + mnLimit1 = (ctaOffset + cta + 1) * ctaTile; + mnLimit2 = ctaOffset * ctaTile + count; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); + } + + // get the padded offset associated with this expert (token-space, CGA granularity) + int32_t offset; + if (params.mIsPow2) { + offset = mulLog2(ctaOffset >> params.mClusterSizeLog2, params.mPaddingLog2); + } else { + offset = (ctaOffset / params.mClusterSizeInBatchDim) * params.mTileTokensDim; + } + int32_t permutedIdxSize; + if (params.mIsPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas >> params.mClusterSizeLog2, params.mPaddingLog2); + } else { + permutedIdxSize = (numNonExitingCtas / params.mClusterSizeInBatchDim) * params.mTileTokensDim; + } + + // write out padded count + if (gridBlockIdx == 0 && warpIdx == NumThreads / WarpSize - 1 && cute::elect_one_sync()) { + 
params.mPtrPermutedIdxSize[0] = permutedIdxSize; + params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; + } + + // write expert offsets to shared + smemExpertOffset[threadIdx.x] = offset + blockExpertOffset; + + // make expert offsets available to all threads + __syncthreads(); + + // each thread has the same "expanded indexes" assigned to it as above +#pragma unroll + for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii) { + auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; + if (expandedIdx >= expandedIdxSize) { + break; + } + auto expertIdx = expertIndexes[ii]; + // check whether this expert is local to our GPU at all + auto localExpertIdx = static_cast(expertIdx) - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + auto tokenIdx = expandedIdx / params.mTopK; + auto permutedIdx = + isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1}; + if (params.mPtrExpandedIdxToPermutedIdx != nullptr) { + params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx; + } + if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert) { + params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx; + } + if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert) { + params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; + } + } + + // Trigger the secondary kernel AFTER all global memory writes (including permutation indices). 
+ if (params.mUsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + } +} +#else +template +__global__ void routingIndicesCoopKernel(KernelParams params) { + assert(false && "routingIndicesCoopKernel is only supported on SM90+ architectures"); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Shared device functions for coop kernel (used by both routingCustom and routingDeepSeek) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Device function to extract expert index from either mPtrTopKIds or mPtrTopKPacked. +template +__forceinline__ __device__ int32_t getExpertIdxFromInput(KernelParams const& params, + int32_t expandedIdx) { + if (params.mPtrTopKIds != nullptr) { + return params.mPtrTopKIds[expandedIdx]; + } else { + return params.mPtrTopKPacked[expandedIdx].idx; + } +} + +// Overload for routingCustom that also writes topK weights if needed. +template +__forceinline__ __device__ int32_t getExpertIdxFromInputWithWeights( + KernelParams const& params, int32_t expandedIdx, + typename KernelParams::OutputT* topKWeights) { + if (params.mPtrTopKIds != nullptr) { + return params.mPtrTopKIds[expandedIdx]; + } else { + PackedScoreIdx scoreIdx = params.mPtrTopKPacked[expandedIdx]; + if (topKWeights != nullptr) { + topKWeights[expandedIdx] = static_cast(scoreIdx.score); + } + return scoreIdx.idx; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace routing } // namespace moe::dev diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h index ba90742ce0..fab8737017 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. 
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,6 +40,21 @@ struct PackedScoreIdx { struct DataBase { bool mUsePdl{false}; + // Controls the cudaLaunchAttributeProgrammaticStreamSerialization launch attribute. + // When true, the NEXT kernel in the stream is allowed to start before this kernel completes. + // When false (default), the next kernel waits for this kernel to finish (normal serialization). + // + // This is separate from mUsePdl because: + // - mUsePdl controls IN-KERNEL behavior: cudaGridDependencySynchronize (wait for predecessor) + // and cudaTriggerProgrammaticLaunchCompletion (signal successor). + // - mPdlOverlapWithNext controls the LAUNCH ATTRIBUTE: whether the runtime is allowed to + // dispatch the next kernel before this one finishes. + // + // The LAST routing kernel in a multi-kernel chain should set mPdlOverlapWithNext = false + // to prevent the consumer GEMM (which may not have cudaGridDependencySynchronize for routing + // data) from starting early and reading stale permutation indices. + bool mPdlOverlapWithNext{false}; + // optional: only used as an intermediate buffer when the number of tokens is large. // dim: max([2*NumThreads] = [512], mNumExperts*2) int32_t* mPtrExpertCounts{nullptr}; @@ -57,6 +72,7 @@ struct DataBase { // Note: this array (mPtrPermutedIdxToTokenIdx) is uninitialized // Any out-of-bounds values are undefined. int32_t* mPtrPermutedIdxToTokenIdx{nullptr}; + // optional: if `nullptr`, it is not filled // dim: [mNumTokens, mTopK] // When mPtrTopKIds is provided, mPtrTopKWeights must be also provided as inputs. @@ -95,8 +111,14 @@ struct DataBase { int32_t mNumTokens; int32_t mNumExperts; int32_t mTopK; - int32_t mPaddingLog2; + // Cluster-wide tile size in token dimension. int32_t mTileTokensDim; + // log2() of the padding size in cluster-wide tile. 
+ int32_t mPaddingLog2; + // Cluster size (e.g., 1x2, 2x1, etc.) in batch dimension. + int32_t mClusterSizeInBatchDim{1}; + // log2() of the cluster size in batch dimension. + int32_t mClusterSizeLog2{0}; /// For expert parallelization int32_t mLocalExpertsStartIdx; @@ -104,13 +126,15 @@ struct DataBase { int32_t mNumLocalExperts; }; -template +template struct KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; static constexpr int MaxNumExperts = MaxNumExperts_; - static constexpr bool isPow2 = isPow2_; - static constexpr bool UsePdl = UsePdl_; + static constexpr int MaxNumTopExperts = MaxNumTopExperts_; + + bool mUsePdl = false; + bool mIsPow2 = false; // Public pointer members int32_t* mPtrExpertCounts = nullptr; @@ -131,6 +155,8 @@ struct KernelParamsBase { int32_t mPaddingLog2 = -1; int32_t mTileTokensDim = 0; + int32_t mClusterSizeInBatchDim = 1; + int32_t mClusterSizeLog2 = 0; int32_t mLocalExpertsStartIdx = 0; int32_t mLocalExpertsStrideLog2 = 0; int32_t mNumLocalExperts = 0; @@ -138,6 +164,8 @@ struct KernelParamsBase { // Public initialization function - make it a template to accept different Data types template void setBaseParams(DataType const& data) { + mUsePdl = data.mUsePdl; + mIsPow2 = data.mPaddingLog2 > 0; mPtrExpertCounts = data.mPtrExpertCounts; mPtrPermutedIdxSize = data.mPtrPermutedIdxSize; mPtrExpandedIdxToPermutedIdx = data.mPtrExpandedIdxToPermutedIdx; @@ -155,6 +183,8 @@ struct KernelParamsBase { mPaddingLog2 = data.mPaddingLog2; mTileTokensDim = data.mTileTokensDim; + mClusterSizeInBatchDim = data.mClusterSizeInBatchDim; + mClusterSizeLog2 = data.mClusterSizeLog2; mLocalExpertsStartIdx = data.mLocalExpertsStartIdx; mLocalExpertsStrideLog2 = data.mLocalExpertsStrideLog2; mNumLocalExperts = data.mNumLocalExperts; @@ -165,13 +195,15 @@ namespace routingDeepSeek { //////////////////////////////////////////////////////////////////////////////////////////////////// struct Data : public DataBase { - tg::Dtype 
mDtypeExpW{tg::Dtype::Bfloat16}; - tg::Dtype mDtypeBias{tg::Dtype::Bfloat16}; - tg::Dtype mDtypeScore{tg::Dtype::Fp32}; + tg::Dtype mDtypeOutput{tg::Dtype::Bfloat16}; + tg::Dtype mDtypeInput{tg::Dtype::Fp32}; // InputT: routing logits dtype (Bfloat16 or Fp32) + // // Grouped Gemm Launch Config Buffers // void const* mPtrRoutingBias; + // Dtype of the routing bias buffer (Bfloat16 or Fp32). + tg::Dtype mDtypeBias{tg::Dtype::Bfloat16}; int32_t mHiddenDim; // not used int32_t mNumExpertGroups; @@ -181,15 +213,14 @@ struct Data : public DataBase { bool mUseRoutingSoftmax; }; -template -struct KernelParams : public KernelParamsBase { +template +struct KernelParams + : public KernelParamsBase { using InputT = InputT_; - using BiasT = BiasT_; using OutputT = OutputT_; static constexpr bool UseGroups = UseGroups_; - static constexpr int MaxNumTopExperts = MaxNumTopExperts_; PackedScoreIdx* mPtrTopKPacked = nullptr; @@ -197,7 +228,9 @@ struct KernelParams : public KernelParamsBase*)data.mPtrTopKPacked; - // params.mPtrTopKWeightsFull = static_cast(data.mPtrTopKWeightsFull); - params.mPtrRoutingBias = static_cast(data.mPtrRoutingBias); + params.mPtrRoutingBias = data.mPtrRoutingBias; + params.mDtypeBias = data.mDtypeBias; params.mNumExpertGroups = data.mNumExpertGroups; params.mNumExpertsPerGroup = data.mNumExperts / data.mNumExpertGroups; @@ -236,11 +269,13 @@ namespace routingLlama4 { //////////////////////////////////////////////////////////////////////////////////////////////////// struct Data : public DataBase { - tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16}; + tg::Dtype mDtypeOutput{tg::Dtype::Bfloat16}; + tg::Dtype mDtypeInput{tg::Dtype::Bfloat16}; // InputT: routing logits dtype (Bfloat16 or Fp32) }; -template -struct KernelParams : public KernelParamsBase { +template +struct KernelParams + : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; @@ -264,49 +299,99 @@ void run(Data const& data, void* stream); 
//////////////////////////////////////////////////////////////////////////////////////////////////// -namespace routingRenormalize { +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Routing preprocess/postprocess policy type enums. +// These are used to select the compile-time policy at dispatch time. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class RoutingPreprocessType { + None, // No preprocessing before topK + Softmax, // Apply softmax on all expert scores before topK + Sigmoid, // Apply sigmoid(score) for topK selection (no bias) + SigmoidBias, // Apply sigmoid(score) + bias for topK selection (DeepSeek-style) +}; + +enum class RoutingPostprocessType { + None, // No postprocessing after topK + Softmax, // Apply softmax on top-K scores + SumNormalize, // Normalize top-K scores by their sum + ScaledSumNormalize, // Recover sigmoid scores, normalize by sum and scale (DeepSeek-style) +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace routingCustom { //////////////////////////////////////////////////////////////////////////////////////////////////// struct Data : public DataBase { - tg::Dtype mDtypeExpW{tg::Dtype::Fp32}; - tg::Dtype mDtypeElt{tg::Dtype::Bfloat16}; + tg::Dtype mDtypeOutput{tg::Dtype::Fp32}; // OutputT: expert weights dtype (typically Bfloat16) + tg::Dtype mDtypeInput{tg::Dtype::Bfloat16}; // InputT: routing logits dtype (Bfloat16 or Fp32) - bool mDoSoftmaxBeforeTopK{false}; + RoutingPreprocessType mPreprocessType{RoutingPreprocessType::None}; + RoutingPostprocessType mPostprocessType{RoutingPostprocessType::Softmax}; bool mNormTopkProb{true}; // Default value is true for Qwen3 model - bool mApplySoftmaxAfterTopK{false}; + + // Optional: per-expert routing bias (used by SigmoidBias preprocess). 
+ void const* mPtrRoutingBias{nullptr}; + // Dtype of the routing bias buffer (Bfloat16 or Fp32). Used to read mPtrRoutingBias correctly. + tg::Dtype mDtypeBias{tg::Dtype::Bfloat16}; + // Optional: scaling factor applied to final scores (used by ScaledSumNormalize postprocess). + float mRouteScale{1.0f}; + // Optional: epsilon added to the sum before division to prevent division by zero. + // MiniMax2 uses 1e-20f; DeepSeek uses 0.0f (no epsilon). + float mSumEpsilon{0.0f}; }; -template -struct KernelParams : public KernelParamsBase { +template +struct KernelParams + : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; + using ExpertSelectPolicy = ExpertSelectPolicy_; - static constexpr bool DoSoftmaxBeforeTopK = DoSoftmaxBeforeTopK_; + // Expert select policy params — empty structs have zero register cost. + using ExpertSelectParams = typename ExpertSelectPolicy::template Params; PackedScoreIdx* mPtrTopKPacked = nullptr; int32_t mTopK = 0; - bool mNormTopkProb = true; - bool mApplySoftmaxAfterTopK = false; + ExpertSelectParams mExpertSelectParams; static KernelParams setKernelParams(Data const& data) { KernelParams params; params.setBaseParams(data); params.mPtrTopKPacked = (PackedScoreIdx*)data.mPtrTopKPacked; - params.mNormTopkProb = data.mNormTopkProb; - params.mApplySoftmaxAfterTopK = data.mApplySoftmaxAfterTopK; params.mTopK = data.mTopK; + + // Policy populates only the fields it needs from Data. + params.mExpertSelectParams.set(data); return params; } }; void run(Data const& data, void* stream); -} // namespace routingRenormalize +} // namespace routingCustom + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Shared utility for post-topK pipeline when mPtrTopKIds != nullptr. +// All routing methods (Custom, DeepSeek, Llama4) use the same workflow in this case: +// 1. Reset expert counts +// 2. Run histogram kernel +// 3. 
Run offsets kernel +// Since the kernels are shared and we don't need routing-method-specific logic, +// we can use routingCustom's launch mechanism. +// +// This function works with any Data type that inherits from DataBase. +// Implementation is in trtllm_fused_moe_routing_common.cu +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void runPostTopKPipeline(DataType const& data, uint32_t numThreadsHist, void* stream); //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace routing diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh b/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh index 42aa877d26..9a8d376e2d 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh @@ -18,7 +18,6 @@ #include #include - #include namespace moe::dev::routing { @@ -32,7 +31,7 @@ namespace cg = cooperative_groups; static constexpr int WarpSize = 32; static constexpr int MaxNumExpertsUnit = 128; -static constexpr int MaxNumTopK = 10; +static constexpr int MaxSupportedTopExperts = 32; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -53,8 +52,7 @@ struct TopKRedType { static __host__ __device__ inline TypeCmp makeCmpVal(TypeExpW val, int32_t idx = 0) { auto valueBits = cub::Traits::TwiddleIn( reinterpret_cast::UnsignedBits&>(val)); - TypeCmp compactTmp; - memcpy(&compactTmp, &valueBits, sizeof(valueBits)); + TypeCmp compactTmp = valueBits; compactTmp = (compactTmp << moveBits) | (0xFFFF & (maxIdx - idx)); // Use 65535 minus idx to give higher priority to elements with smaller indices. 
return compactTmp; @@ -78,8 +76,10 @@ struct TopKRedType { __host__ __device__ operator TypeCmp() const noexcept { return compVal; } __device__ inline TypeCmp reduce(cg::thread_block_tile const& warp) { + // Use fast redux.sync.max.u32 on SM100+ for 32-bit packed types only. + // For 64-bit packed types (float scores), fall back to cg::reduce. #ifdef __CUDA_ARCH__ - static constexpr bool hasFastRedux = __CUDA_ARCH__ >= 1000; + static constexpr bool hasFastRedux = (__CUDA_ARCH__ / 100) >= 10; #else static constexpr bool hasFastRedux = false; #endif @@ -95,18 +95,77 @@ struct TopKRedType { //////////////////////////////////////////////////////////////////////////////////////////////////// -#define TOPK_SWAP(I, J) \ - { \ - auto pairMin = min(topK[I].compVal, topK[J].compVal); \ - auto pairMax = max(topK[I].compVal, topK[J].compVal); \ - topK[I].compVal = pairMax; \ - topK[J].compVal = pairMin; \ +#define TOPK_SWAP(I, J) \ + { \ + auto pairMin = min(topK[I].compVal, topK[J].compVal); \ + auto pairMax = max(topK[I].compVal, topK[J].compVal); \ + topK[I].compVal = pairMax; \ + topK[J].compVal = pairMin; \ } -//////////////////////////////////////////////////////////////////////////////////////////////////// +// Helper to check if N is a power of 2 +template +struct IsPowerOf2 { + static constexpr bool value = (N > 0) && ((N & (N - 1)) == 0); +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// template -struct Sort; +struct Sort { + static_assert(N > 0 && N <= 64, "Sort only supports N in range [1, 64]"); + + static __device__ void run(RedType* topK) { + if constexpr (IsPowerOf2::value) { +// Bitonic sort for power-of-2 sizes - more efficient +#pragma unroll + for (int k = 2; k <= N; k *= 2) { +#pragma unroll + for (int j = k / 2; j > 0; j /= 2) { +#pragma unroll + for (int i = 0; i < N; ++i) { + int ixj = i ^ j; + if (ixj > i) { + if ((i & k) == 0) { + if (topK[i].compVal < topK[ixj].compVal) { + auto tmp = 
topK[i].compVal; + topK[i].compVal = topK[ixj].compVal; + topK[ixj].compVal = tmp; + } + } else { + if (topK[i].compVal > topK[ixj].compVal) { + auto tmp = topK[i].compVal; + topK[i].compVal = topK[ixj].compVal; + topK[ixj].compVal = tmp; + } + } + } + } + } + } + } else { +// Odd-even transposition sort for non-power-of-2 sizes +#pragma unroll + for (int pass = 0; pass < N; ++pass) { +#pragma unroll + for (int i = 0; i < N - 1; i += 2) { + if (topK[i].compVal < topK[i + 1].compVal) { + auto tmp = topK[i].compVal; + topK[i].compVal = topK[i + 1].compVal; + topK[i + 1].compVal = tmp; + } + } +#pragma unroll + for (int i = 1; i < N - 1; i += 2) { + if (topK[i].compVal < topK[i + 1].compVal) { + auto tmp = topK[i].compVal; + topK[i].compVal = topK[i + 1].compVal; + topK[i + 1].compVal = tmp; + } + } + } + } + } +}; template struct Sort<1, RedType> { @@ -160,14 +219,14 @@ __forceinline__ __device__ void reduceTopK(cg::thread_block_tile const }; template -__forceinline__ __device__ void reduceTopKFunc(cg::thread_block_tile const& warp, - Type (&out)[K], int32_t (&outIdx)[K], - Type (&value)[N], int32_t (&idx)[N], - Type const minValue, int actualK = K) { +__forceinline__ __device__ void reduceTopK(cg::thread_block_tile const& warp, + Type (&out)[K], int32_t (&outIdx)[K], Type (&value)[N], + int32_t (&idx)[N], Type const minValue, + int actualK = K) { static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < WarpSize, "Top K must have K < WarpSize"); + static_assert(K <= WarpSize, "Top K must have K <= WarpSize"); static_assert(N > 0, "Top K must have N > 0"); - static_assert(N < 5, "Only support candidates number less than or equal to 128"); + static_assert(N <= 64, "Only support candidates number less than or equal to 64*32=2048"); using RedType = TopKRedType; RedType topK[N]; #pragma unroll @@ -178,15 +237,13 @@ __forceinline__ __device__ void reduceTopKFunc(cg::thread_block_tile c Sort::run(topK); typename RedType::TypeCmp packedMax{}; -#pragma unroll - for 
(int kk = 0; kk < actualK; ++kk) //@todo: check if actualK is correct - { + for (int kk = 0; kk < actualK; ++kk) { bool update = kk > 0 && packedMax == topK[0].compVal; #pragma unroll for (int nn = 0; nn < N; ++nn) { - topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} - : update ? topK[nn + 1] - : topK[nn]; + topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} + : update ? topK[nn + 1] + : topK[nn]; } // get the next largest value packedMax = topK[0].reduce(warp); @@ -194,58 +251,6 @@ __forceinline__ __device__ void reduceTopKFunc(cg::thread_block_tile c } }; -template -__forceinline__ __device__ void reduceTopK(cg::thread_block_tile const& warp, - Type (&out)[K], int32_t (&outIdx)[K], Type (&value)[N], - int32_t (&idx)[N], Type const minValue, - int actualK = K) { - static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < WarpSize, "Top K must have K < WarpSize"); - static_assert(N > 0, "Top K must have N > 0"); - static_assert(N <= 16, "Only support candidates number less than or equal to 16*32=512"); - using RedType = TopKRedType; - - if constexpr (N <= 4) { - reduceTopKFunc(warp, out, outIdx, value, idx, minValue, actualK); - } else { - constexpr int numLoops = (N - 1) / 4 + 1; - constexpr int numResults = (numLoops * K - 1) / WarpSize + 1; - - Type topKBufferValue[numResults]; - int32_t topKBufferIdx[numResults]; - int32_t laneIdx = threadIdx.x % WarpSize; - - for (int ii = 0; ii < numResults; ++ii) { - topKBufferValue[ii] = minValue; - topKBufferIdx[ii] = ii * WarpSize - 1; - } - for (int loop = 0; loop < numLoops; ++loop) { - int start = loop * 4; - Type topKValue[K]; - int32_t topKIdx[K]; - Type inValue[4]; - int32_t inIdx[4]; - for (int i = 0; i < 4; ++i) { - inValue[i] = value[start + i]; - inIdx[i] = idx[start + i]; - } - reduceTopKFunc(warp, topKValue, topKIdx, inValue, inIdx, minValue, actualK); - int inOffset = laneIdx % K; - if (laneIdx >= loop * K && laneIdx < (loop + 1) * K) { - topKBufferValue[0] = 
topKValue[inOffset]; - topKBufferIdx[0] = topKIdx[inOffset]; - } - if (loop == numLoops - 1 && (laneIdx < (numLoops * K - WarpSize))) { - topKBufferValue[1] = topKValue[inOffset]; - topKBufferIdx[1] = topKIdx[inOffset]; - } - } - - reduceTopKFunc(warp, out, outIdx, topKBufferValue, topKBufferIdx, minValue, - actualK); - } -}; - #undef TOPK_SWAP } // namespace topk } // namespace moe::dev::routing diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index 46617e5dbd..deb14fb2a6 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -48,8 +48,10 @@ enum class RoutingMethodType : int64_t { RenormalizeNaive = 4, // TopK only (no softmax) TopK = 5, + // SigmoidRenorm: Sigmoid -> TopK -> Renormalize (divide by sum of top-K weights) + SigmoidRenorm = 6, // Unspecified - Unspecified = 6, + Unspecified = 7, }; inline int32_t maybeGetMinTokenCount(int32_t numPaddedTokens, int32_t hiddenSize, @@ -73,6 +75,8 @@ inline std::string serializeMoeRoutingMethodType(RoutingMethodType routingMethod return "RenormalizeNaive"; case RoutingMethodType::TopK: return "TopK"; + case RoutingMethodType::SigmoidRenorm: + return "SigmoidRenorm"; default: return "InvalidRountingMethod"; // TODO throw error }; @@ -128,7 +132,8 @@ class Runner { int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas, batchedGemm::trtllm::gen::Dtype dtypeElt, batchedGemm::trtllm::gen::Dtype dtypeBias, bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType, - cudaStream_t stream); + cudaStream_t stream, + batchedGemm::trtllm::gen::Dtype dtypeLogits, bool normTopkProb = true); private: int32_t mTileTokensDim{8}; diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 62f9860644..ba072543e5 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -214,6 +214,7 @@ def 
_run_moe_computation(self, runtime_args): activation_type=self.config["activation_type"], do_finalize=True, tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, + norm_topk_prob=self.config.get("norm_topk_prob", True), ) return output # Extract tensor from tuple @@ -573,6 +574,8 @@ def call_moe( gemm1_bias = kwargs["gemm1_bias"] gemm2_bias = kwargs["gemm2_bias"] + norm_topk_prob = kwargs.get("norm_topk_prob", True) + # Create CUDA graph configuration config = { "hidden_states_scale_global": hidden_states_scale_global, @@ -587,6 +590,7 @@ def call_moe( "enable_autotune": enable_autotune, "gemm1_bias": gemm1_bias, "gemm2_bias": gemm2_bias, + "norm_topk_prob": norm_topk_prob, } runtime_args = { @@ -793,6 +797,7 @@ def call_moe( routing_method_type = kwargs["routing_method_type"] enable_autotune = kwargs.get("enable_autotune", True) routed_scaling = kwargs.get("routed_scaling", 1.0) + norm_topk_prob = kwargs.get("norm_topk_prob", True) # Use autotuner for optimal kernel selection with autotune(enable_autotune): @@ -817,6 +822,7 @@ def call_moe( routed_scaling, routing_method_type=routing_method_type, tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, + norm_topk_prob=norm_topk_prob, ) return output[0].to(torch.float) @@ -1113,6 +1119,7 @@ def call_moe( enable_pdl = kwargs.get("enable_pdl") hidden_states_scale = kwargs["hidden_states_scale"] hidden_states_quant = kwargs["hidden_states_quant"] + norm_topk_prob = kwargs.get("norm_topk_prob", True) # Generate block scales and quantize hidden states at runtime hidden_states_fp8 = hidden_states_quant.to(torch.float8_e4m3fn) @@ -1154,6 +1161,7 @@ def call_moe( enable_pdl=enable_pdl, tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, fp8_quantization_type=quantization_mode, + norm_topk_prob=norm_topk_prob, ) return output.to(torch.float) @@ -1320,6 +1328,7 @@ def call_moe( routing_method_type = kwargs["routing_method_type"] enable_autotune = kwargs.get("enable_autotune", True) activation_type = kwargs["activation_type"] + norm_topk_prob = 
kwargs.get("norm_topk_prob", True) # Quantize to FP8 per-tensor using pre-computed global scale factor hidden_states_fp8, _ = quant_fp8_per_tensor( @@ -1354,6 +1363,7 @@ def call_moe( routing_method_type, tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, activation_type=activation_type, + norm_topk_prob=norm_topk_prob, ) return output.to(torch.float) @@ -1492,6 +1502,8 @@ def call_moe( routing_method_type = kwargs["routing_method_type"] enable_autotune = kwargs.get("enable_autotune", True) + norm_topk_prob = kwargs.get("norm_topk_prob", True) + # Use autotuner for optimal kernel selection with autotune(enable_autotune): output = trtllm_bf16_moe( @@ -1512,6 +1524,7 @@ def call_moe( weight_layout=static_data["weight_layout"], routing_method_type=routing_method_type, tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, + norm_topk_prob=norm_topk_prob, ) return output.to(torch.float) @@ -1757,6 +1770,24 @@ def routing_reference_no_aux( return permute_info, scores +def routing_reference_default(expert_logits, top_k, num_experts, padding): + """Softmax -> TopK routing reference (Default method).""" + scores = torch.nn.functional.softmax(expert_logits.float(), dim=-1) + topk_values, topk_idx = torch.topk(scores, k=top_k, dim=-1) + + topk_values = topk_values.to(expert_logits.dtype) + + new_mask = torch.zeros_like(expert_logits) + new_mask.scatter_(-1, topk_idx, 1) + scores = expert_logits * new_mask + + for i in range(topk_idx.shape[0]): + for j in range(topk_idx.shape[1]): + scores[i, topk_idx[i, j]] = topk_values[i, j] + permute_info = routing_reference(scores, top_k, padding) + return permute_info, scores + + def routing_reference_renormalize(expert_logits, top_k, num_experts, padding): """TopK -> Softmax routing reference.""" topk_values, topk_idx = torch.topk(expert_logits, k=top_k, dim=-1) @@ -1810,6 +1841,27 @@ def routing_reference_topk(expert_logits, top_k, num_experts, padding): return permute_info, scores +def routing_reference_sigmoid_renorm(expert_logits, top_k, num_experts, 
padding, norm_topk_prob=True): + """Sigmoid -> TopK -> Renormalize routing reference.""" + sigmoid_scores = torch.sigmoid(expert_logits.float()) + topk_values, topk_idx = torch.topk(sigmoid_scores, k=top_k, dim=-1) + + # Renormalize: divide by sum of top-K weights + if norm_topk_prob: + topk_values = topk_values / (topk_values.sum(dim=-1, keepdim=True) + 1e-20) + topk_values = topk_values.to(expert_logits.dtype) + + new_mask = torch.zeros_like(expert_logits) + new_mask.scatter_(-1, topk_idx, 1) + scores = expert_logits * new_mask + + for i in range(topk_idx.shape[0]): + for j in range(topk_idx.shape[1]): + scores[i, topk_idx[i, j]] = topk_values[i, j] + permute_info = routing_reference(scores, top_k, padding) + return permute_info, scores + + def check_accuracy(a, b, atol, rtol, percent): """Unified accuracy checking function with detailed error reporting.""" if not torch.isfinite(a).all(): @@ -2551,6 +2603,7 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): "enable_autotune": kwargs.get("enable_autotune", True), "gemm1_bias": args.gemm1_bias, "gemm2_bias": args.gemm2_bias, + "norm_topk_prob": kwargs.get("norm_topk_prob", True), } return moe_impl.call_moe( @@ -2580,6 +2633,9 @@ def run_moe_test( zero_hidden_states=False, gemm1_bias=None, gemm2_bias=None, + routing_logits_dtype=None, + routing_bias_dtype=None, + norm_topk_prob=True, ): """Common test logic for all routing methods.""" skip_checks( @@ -2620,17 +2676,20 @@ def run_moe_test( assert top_k < (top_k_groups * num_experts / n_groups) # Create test data based on routing method - if routing_method_type == RoutingMethodType.DeepSeekV3: - expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to( - torch.float - ) + # Use explicit dtype if provided; otherwise fall back to default per routing method. 
+ if routing_logits_dtype is not None: + logits_dtype = routing_logits_dtype + elif routing_method_type == RoutingMethodType.DeepSeekV3: + logits_dtype = torch.float else: - expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to( - torch.bfloat16 - ) + logits_dtype = torch.bfloat16 + expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to( + logits_dtype + ) if routing_config["has_routing_bias"]: - routing_bias = torch.randn(num_experts, device="cuda", dtype=torch.bfloat16) + bias_dtype = routing_bias_dtype if routing_bias_dtype is not None else torch.bfloat16 + routing_bias = torch.randn(num_experts, device="cuda", dtype=bias_dtype) else: routing_bias = None @@ -2656,7 +2715,11 @@ def run_moe_test( # Generate routing info use_routing_scales_on_input = routing_method_type == RoutingMethodType.Llama4 - if routing_method_type == RoutingMethodType.DeepSeekV3: + if routing_method_type == RoutingMethodType.Default: + permute_info, scores = routing_reference_default( + expert_logits, top_k, num_experts, padding + ) + elif routing_method_type == RoutingMethodType.DeepSeekV3: permute_info, scores = routing_reference_no_aux( expert_logits, routing_bias, @@ -2672,13 +2735,19 @@ def run_moe_test( expert_logits, top_k, num_experts, padding ) elif routing_method_type == RoutingMethodType.RenormalizeNaive: - permute_info, scores = routing_reference_renormalize_naive( + # RenormalizeNaive (Softmax → TopK → SumNormalize) is mathematically equivalent + # to Renormalize (TopK → Softmax), so we use the same reference implementation. 
+ permute_info, scores = routing_reference_renormalize( expert_logits, top_k, num_experts, padding ) elif routing_method_type == RoutingMethodType.TopK: permute_info, scores = routing_reference_topk( expert_logits, top_k, num_experts, padding ) + elif routing_method_type == RoutingMethodType.SigmoidRenorm: + permute_info, scores = routing_reference_sigmoid_renorm( + expert_logits, top_k, num_experts, padding, norm_topk_prob=norm_topk_prob + ) elif routing_method_type == RoutingMethodType.Llama4: permute_info, scores = routing_reference_no_aux( expert_logits, @@ -2758,6 +2827,7 @@ def run_moe_test( enable_pdl=True, hidden_states_quant=inputs_data["hidden_states"], enable_autotune=enable_autotune, + norm_topk_prob=norm_topk_prob, ) # Compare outputs @@ -2771,6 +2841,119 @@ def run_moe_test( ) +# Test: Default routing (Softmax -> TopK) +@pytest.mark.parametrize("num_tokens", [8, 768, 3072]) +@pytest.mark.parametrize("hidden_size", [1024]) +@pytest.mark.parametrize("intermediate_size", [1024, 768, 512, 384]) +@pytest.mark.parametrize( + "moe_impl", + [ + pytest.param(BF16Moe(), id="BF16xBF16"), + pytest.param( + FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_DEEPSEEK), + id="FP8_Block_DeepSeek", + ), + pytest.param( + FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_MXFP8), + id="FP8_Block_MxFp8", + ), + pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), + pytest.param(MxInt4BlockScaleMoe(), id="MxInt4xBf16"), + ], +) +@pytest.mark.parametrize( + "routing_config", + [ + pytest.param( + { + "num_experts": 128, + "top_k": 8, + "padding": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Default, + 
"compatible_moe_impls": [ + FP8PerTensorMoe, + FP8BlockScaleMoe, + FP4Moe, + BF16Moe, + MxInt4BlockScaleMoe, + ], + "compatible_intermediate_size": [384, 768, 1024], + "enable_autotune": True, + }, + id="Default_128e_top8", + ), + ], +) +@pytest.mark.parametrize( + "weight_processing", + [ + pytest.param( + { + "use_shuffled_weight": False, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP8BlockScaleMoe], + }, + id="NoShuffle_MajorK", + ), + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe], + }, + id="Shuffled_MajorK", + ), + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.BlockMajorK, + "compatible_moe_impls": [ + FP8BlockScaleMoe, + BF16Moe, + MxInt4BlockScaleMoe, + ], + }, + id="Shuffled_BlockMajorK", + ), + ], +) +@pytest.mark.parametrize( + "activation_type", + [ + pytest.param(ActivationType.Swiglu.value, id="Swiglu"), + pytest.param(ActivationType.Geglu.value, id="Geglu"), + ], +) +def test_default_routing( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + activation_type, + cache_permute_indices, +): + """Test Default (Softmax -> TopK) routing configurations.""" + run_moe_test( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + activation_type, + cache_permute_indices, + ) + + # Test: Renormalize routing @pytest.mark.parametrize( "zero_hidden_states", @@ -2937,6 +3120,128 @@ def test_renormalize_routing( ) +# Test: SigmoidRenorm routing +@pytest.mark.parametrize("num_tokens", [8, 768, 3072]) +@pytest.mark.parametrize("hidden_size", [1024]) +@pytest.mark.parametrize("intermediate_size", [1024, 768, 512, 384]) +@pytest.mark.parametrize( + "moe_impl", + [ + pytest.param(BF16Moe(), id="BF16xBF16"), + pytest.param( + FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_DEEPSEEK), + id="FP8_Block_DeepSeek", + 
), + pytest.param( + FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_MXFP8), + id="FP8_Block_MxFp8", + ), + pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), + pytest.param(MxInt4BlockScaleMoe(), id="MxInt4xBf16"), + ], +) +@pytest.mark.parametrize( + "routing_config", + [ + pytest.param( + { + "num_experts": 128, + "top_k": 8, + "padding": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.SigmoidRenorm, + "compatible_moe_impls": [ + FP8PerTensorMoe, + FP8BlockScaleMoe, + FP4Moe, + BF16Moe, + MxInt4BlockScaleMoe, + ], + "compatible_intermediate_size": [384, 768, 1024], + "enable_autotune": True, + }, + id="SigmoidRenorm_128e_top8", + ), + ], +) +@pytest.mark.parametrize( + "weight_processing", + [ + pytest.param( + { + "use_shuffled_weight": False, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP8BlockScaleMoe], + }, + id="NoShuffle_MajorK", + ), + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe], + }, + id="Shuffled_MajorK", + ), + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.BlockMajorK, + "compatible_moe_impls": [ + FP8BlockScaleMoe, + BF16Moe, + MxInt4BlockScaleMoe, + ], + }, + id="Shuffled_BlockMajorK", + ), + ], +) +@pytest.mark.parametrize( + "activation_type", + [ + pytest.param(ActivationType.Swiglu.value, id="Swiglu"), + pytest.param(ActivationType.Geglu.value, id="Geglu"), + ], +) +@pytest.mark.parametrize( + "norm_topk_prob", + [ + pytest.param(True, id="NormTopkProb"), + pytest.param(False, id="NoNormTopkProb"), + ], +) +def test_sigmoid_renorm_routing( + 
num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + activation_type, + cache_permute_indices, + norm_topk_prob, +): + """Test SigmoidRenorm routing configurations.""" + run_moe_test( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + activation_type, + cache_permute_indices, + norm_topk_prob=norm_topk_prob, + ) + + # Test: DeepSeekV3 routing @pytest.mark.parametrize("num_tokens", [8, 768, 3072]) @pytest.mark.parametrize("hidden_size", [1024]) @@ -3330,3 +3635,128 @@ def test_nvfp4_moe_gemm_bias( gemm1_bias=gemm1_bias, gemm2_bias=gemm2_bias, ) + + +# Test: routing_logits and routing_bias dtype flexibility +# Verifies that all MoE backends accept both float32 and bfloat16 +# for routing_logits and routing_bias. +@pytest.mark.parametrize("num_tokens", [4, 32, 768]) +@pytest.mark.parametrize("hidden_size", [1024]) +@pytest.mark.parametrize("intermediate_size", [1024]) +@pytest.mark.parametrize( + "moe_impl", + [ + pytest.param(FP8PerTensorMoe(), id="FP8_PerTensor"), + pytest.param( + FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_DEEPSEEK), + id="FP8_Block_DeepSeek", + ), + pytest.param( + FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_MXFP8), + id="FP8_Block_MxFp8", + ), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), + pytest.param(MxInt4BlockScaleMoe(), id="MxInt4xBf16"), + pytest.param(BF16Moe(), id="Bf16xBf16"), + ], +) +@pytest.mark.parametrize( + "routing_config", + [ + # Renormalize routing (no bias) + pytest.param( + { + "num_experts": 128, + "top_k": 8, + "padding": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize, + "compatible_moe_impls": [ + FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe, BF16Moe, MxInt4BlockScaleMoe + ], + "compatible_intermediate_size": [512, 768, 1024, 2048, 2944], + 
"enable_autotune": True, + }, + id="Renorm_128e_top8", + ), + # DeepSeekV3 routing (with bias) — tests bias dtype flexibility too + pytest.param( + { + "num_experts": 256, + "top_k": 8, + "padding": 8, + "n_groups": 1, + "top_k_groups": 1, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + "compatible_moe_impls": [ + FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe, BF16Moe, MxInt4BlockScaleMoe + ], + "compatible_intermediate_size": [512, 768, 1024, 2048, 2944], + "enable_autotune": True, + }, + id="DeepSeekV3_256e_top8_bias", + ), + ], +) +@pytest.mark.parametrize( + "weight_processing", + [ + pytest.param( + { + "use_shuffled_weight": False, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [ + FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe, MxInt4BlockScaleMoe + ], + }, + id="MajorK", + ), + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.BlockMajorK, + "compatible_moe_impls": [ + FP8BlockScaleMoe, BF16Moe, MxInt4BlockScaleMoe + ], + }, + id="Shuffled_BlockMajorK", + ), + ], +) +@pytest.mark.parametrize("activation_type", [ActivationType.Swiglu]) +@pytest.mark.parametrize( + "routing_logits_dtype", [torch.bfloat16, torch.float32], ids=["logits_bf16", "logits_fp32"] +) +@pytest.mark.parametrize( + "routing_bias_dtype", [torch.bfloat16, torch.float32], ids=["bias_bf16", "bias_fp32"] +) +def test_routing_dtype_flexibility( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + activation_type, + cache_permute_indices, + routing_logits_dtype, + routing_bias_dtype, +): + """Test that routing_logits and routing_bias accept both float32 and bfloat16.""" + run_moe_test( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + activation_type, + cache_permute_indices, + routing_logits_dtype=routing_logits_dtype, + routing_bias_dtype=routing_bias_dtype, + )