diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu
index 5741611644..e875019174 100644
--- a/csrc/trtllm_fused_moe_kernel_launcher.cu
+++ b/csrc/trtllm_fused_moe_kernel_launcher.cu
@@ -109,6 +109,10 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
   args.routing_logits = routing_logits->data;
   auto const routing_bias_dtype =
       routing_bias.has_value() ? routing_bias.value()->dtype : dl_bfloat16;
+  auto bias_dtype = btg::Dtype::Fp32;
+  if (routing_bias_dtype == dl_bfloat16) {
+    bias_dtype = btg::Dtype::Bfloat16;
+  }
   args.routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr;
   args.hidden_states = hidden_states->data;
   args.gemm1_weights = gemm1_weights->data;
@@ -172,20 +176,21 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
   tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
   cudaStream_t stream = get_stream(routing_logits->device);
-  routing_runner.run(
-      routing_logits->data, args.routing_bias, args.num_tokens, args.num_experts, args.top_k,
-      args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts,
-      args.routed_scaling_factor, static_cast<int32_t*>(expert_indexes->data),
-      static_cast<int32_t*>(expert_count_histogram->data),
-      static_cast<int32_t*>(total_num_padded_tokens->data),
-      static_cast<int32_t*>(expanded_idx_to_permuted_idx->data),
-      nullptr /*static_cast<int32_t*>(permuted_idx_to_expanded_idx->data)*/,
-      static_cast<int32_t*>(permuted_idx_to_token_idx->data), expert_weights->data,
-      static_cast<int32_t*>(num_tokens_per_expert->data),
-      static_cast<int32_t*>(cta_idx_xy_to_batch_idx->data),
-      static_cast<int32_t*>(cta_idx_xy_to_mn_limit->data),
-      static_cast<int32_t*>(num_non_exiting_ctas->data), args.mDtypeElt, use_routing_scales_on_input,
-      false /* use_deep_seek_fp8 */, static_cast<RoutingMethodType>(routing_method_type), stream);
+  routing_runner.run(routing_logits->data, args.routing_bias, args.num_tokens, args.num_experts,
+                     args.top_k, args.n_group, args.topk_group, args.local_expert_offset,
+                     args.local_num_experts, args.routed_scaling_factor,
+                     static_cast<int32_t*>(expert_indexes->data),
+                     static_cast<int32_t*>(expert_count_histogram->data),
+                     static_cast<int32_t*>(total_num_padded_tokens->data),
+                     static_cast<int32_t*>(expanded_idx_to_permuted_idx->data),
+                     nullptr /*static_cast<int32_t*>(permuted_idx_to_expanded_idx->data)*/,
+                     static_cast<int32_t*>(permuted_idx_to_token_idx->data), expert_weights->data,
+                     static_cast<int32_t*>(num_tokens_per_expert->data),
+                     static_cast<int32_t*>(cta_idx_xy_to_batch_idx->data),
+                     static_cast<int32_t*>(cta_idx_xy_to_mn_limit->data),
+                     static_cast<int32_t*>(num_non_exiting_ctas->data), args.mDtypeElt, bias_dtype,
+                     use_routing_scales_on_input, false /* use_deep_seek_fp8 */,
+                     static_cast<RoutingMethodType>(routing_method_type), stream);

   // MoE kernel except routing
   TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_float8_e4m3fn) << "hidden_states must be fp8.";
@@ -366,7 +371,8 @@ void trtllm_fp8_block_scale_moe_launcher(
   auto const routing_bias_dtype =
       routing_bias.has_value() ? routing_bias.value()->dtype : dl_bfloat16;
-  args.mDtypeExpW = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
+  args.mDtypeExpW = btg::Dtype::Bfloat16;
+  // routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
   args.routing_logits = static_cast<float*>(routing_logits->data);
   args.routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr;
   args.hidden_states = hidden_states->data;
@@ -398,6 +404,7 @@ void trtllm_fp8_block_scale_moe_launcher(
       alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits->device);
   Tensor permuted_idx_to_token_idx =
       alloc_tensor({max_num_padded_tokens}, dl_int32, routing_logits->device);
+  // routing_bias_dtype,
   Tensor expert_weights =
       alloc_tensor({args.num_tokens, args.top_k}, routing_bias_dtype, routing_logits->device);
   Tensor expert_indexes =
@@ -444,8 +451,8 @@ void trtllm_fp8_block_scale_moe_launcher(
       static_cast<int32_t*>(num_tokens_per_expert->data),
       static_cast<int32_t*>(cta_idx_xy_to_batch_idx->data),
       static_cast<int32_t*>(cta_idx_xy_to_mn_limit->data),
-      static_cast<int32_t*>(num_non_exiting_ctas->data), args.mDtypeElt, false, true,
-      static_cast<RoutingMethodType>(routing_method_type), stream);
+      static_cast<int32_t*>(num_non_exiting_ctas->data), args.mDtypeElt, args.mDtypeExpW,
+      false, true, static_cast<RoutingMethodType>(routing_method_type), stream);

   // MoE kernel except routing
   TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_float8_e4m3fn) << "hidden_states must be fp8.";
@@ -827,7 +834,7 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
       static_cast<int32_t*>(num_tokens_per_expert->data),
       static_cast<int32_t*>(cta_idx_xy_to_batch_idx->data),
       static_cast<int32_t*>(cta_idx_xy_to_mn_limit->data),
-      static_cast<int32_t*>(num_non_exiting_ctas->data), args.mDtypeElt,
+      static_cast<int32_t*>(num_non_exiting_ctas->data), args.mDtypeElt, args.mDtypeExpW,
       false /* use_routing_scales_on_input */, false /* use_deep_seek_fp8 */,
       static_cast<RoutingMethodType>(routing_method_type), stream);
diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu
index 931431fa2f..2558706d2c 100644
--- a/csrc/trtllm_fused_moe_runner.cu
+++ b/csrc/trtllm_fused_moe_runner.cu
@@ -55,13 +55,14 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
                  int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
                  int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert,
                  int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit,
-                 int32_t* numNonExitingCtas, btg::Dtype dtypeElt, bool useRoutingScalesOnInput,
-                 bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream) {
+                 int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias,
+                 bool useRoutingScalesOnInput, bool useDeepSeekFp8,
+                 RoutingMethodType routingMethodType, cudaStream_t stream) {
   if (routingMethodType == RoutingMethodType::DeepSeekV3) {
     FLASHINFER_CHECK(topK <= 8, "For DeepSeek routing method, must have topK <= 8");
     FLASHINFER_CHECK(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4");
     moe::dev::routing::routingDeepSeek::Data routingData;
-    routingData.mDtypeExpW = btg::Dtype::Bfloat16;
+    routingData.mDtypeExpW = btg::Dtype::Bfloat16;
     routingData.mUsePdl = true;

     // output:
diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h
index 5f066468e6..bff48fdb3c 100644
--- a/include/flashinfer/trtllm/fused_moe/runner.h
+++ b/include/flashinfer/trtllm/fused_moe/runner.h
@@ -113,8 +113,9 @@ class Runner {
            int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
            int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert,
            int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas,
-           batchedGemm::trtllm::gen::Dtype dtypeElt, bool useRoutingScalesOnInput,
-           bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream);
+           batchedGemm::trtllm::gen::Dtype dtypeElt, batchedGemm::trtllm::gen::Dtype dtypeBias,
+           bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType,
+           cudaStream_t stream);

  private:
   int32_t mTileTokensDim{8};
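
For reference, the new `dtypeBias` argument threaded through `Routing::Runner::run` is derived in the per-tensor launcher from the routing bias tensor's dtype, defaulting to bf16 when no bias is supplied. The sketch below only restates that mapping in isolation; the enum and helper names are simplified stand-ins, not the actual FlashInfer/TensorRT-LLM types.

```cpp
// Sketch only: Dtype and DLDtype below stand in for btg::Dtype and the DLPack dtype
// constants (dl_bfloat16 / dl_float32) used in the launcher; they are not the real
// FlashInfer definitions.
#include <cstdio>

enum class Dtype { Bfloat16, Fp32 };       // stand-in for btg::Dtype
enum class DLDtype { Bfloat16, Float32 };  // stand-in for dl_bfloat16 / dl_float32

// Mirrors the launcher logic: default to bf16 when no routing bias is supplied,
// otherwise map the bias tensor's dtype onto the kernel-side enum.
Dtype resolveBiasDtype(bool hasRoutingBias, DLDtype routingBiasDtype) {
  if (!hasRoutingBias || routingBiasDtype == DLDtype::Bfloat16) {
    return Dtype::Bfloat16;
  }
  return Dtype::Fp32;
}

int main() {
  // An fp32 bias tensor is present, so the routing kernel would receive Fp32 here.
  Dtype dtypeBias = resolveBiasDtype(/*hasRoutingBias=*/true, DLDtype::Float32);
  std::printf("dtypeBias = %s\n", dtypeBias == Dtype::Bfloat16 ? "Bfloat16" : "Fp32");
  return 0;
}
```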