43 changes: 25 additions & 18 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
@@ -109,6 +109,10 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
args.routing_logits = routing_logits->data;
auto const routing_bias_dtype =
routing_bias.has_value() ? routing_bias.value()->dtype : dl_bfloat16;
Contributor:

When would routing_bias not have a value? Can we combine the statements and ensure that we always follow the dtype of the routing_bias tensor input?
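A minimal sketch of the combined form the comment seems to ask for (an illustration, not code from this PR); it keeps the launcher's current default of bf16 when no bias tensor is provided:

    // Sketch: derive the kernel-side bias dtype in one expression. Any
    // routing_bias tensor that is not bf16 is treated as fp32, matching the
    // two-step logic added below.
    auto const bias_dtype =
        (routing_bias.has_value() && routing_bias.value()->dtype != dl_bfloat16)
            ? btg::Dtype::Fp32
            : btg::Dtype::Bfloat16;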

+  auto bias_dtype = btg::Dtype::Fp32;
+  if (routing_bias_dtype == dl_bfloat16) {
+    bias_dtype = btg::Dtype::Bfloat16;
+  }
args.routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr;
args.hidden_states = hidden_states->data;
args.gemm1_weights = gemm1_weights->data;
@@ -172,20 +176,21 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(

tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
cudaStream_t stream = get_stream(routing_logits->device);
-  routing_runner.run(
-      routing_logits->data, args.routing_bias, args.num_tokens, args.num_experts, args.top_k,
-      args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts,
-      args.routed_scaling_factor, static_cast<int*>(expert_indexes->data),
-      static_cast<int*>(expert_count_histogram->data),
-      static_cast<int*>(total_num_padded_tokens->data),
-      static_cast<int*>(expanded_idx_to_permuted_idx->data),
-      nullptr /*static_cast<int*>(permuted_idx_to_expanded_idx->data)*/,
-      static_cast<int*>(permuted_idx_to_token_idx->data), expert_weights->data,
-      static_cast<int*>(num_tokens_per_expert->data),
-      static_cast<int*>(cta_idx_xy_to_batch_idx->data),
-      static_cast<int*>(cta_idx_xy_to_mn_limit->data),
-      static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt, use_routing_scales_on_input,
-      false /* use_deep_seek_fp8 */, static_cast<RoutingMethodType>(routing_method_type), stream);
+  routing_runner.run(routing_logits->data, args.routing_bias, args.num_tokens, args.num_experts,
+                     args.top_k, args.n_group, args.topk_group, args.local_expert_offset,
+                     args.local_num_experts, args.routed_scaling_factor,
+                     static_cast<int*>(expert_indexes->data),
+                     static_cast<int*>(expert_count_histogram->data),
+                     static_cast<int*>(total_num_padded_tokens->data),
+                     static_cast<int*>(expanded_idx_to_permuted_idx->data),
+                     nullptr /*static_cast<int*>(permuted_idx_to_expanded_idx->data)*/,
+                     static_cast<int*>(permuted_idx_to_token_idx->data), expert_weights->data,
+                     static_cast<int*>(num_tokens_per_expert->data),
+                     static_cast<int*>(cta_idx_xy_to_batch_idx->data),
+                     static_cast<int*>(cta_idx_xy_to_mn_limit->data),
+                     static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt, bias_dtype,
+                     use_routing_scales_on_input, false /* use_deep_seek_fp8 */,
+                     static_cast<RoutingMethodType>(routing_method_type), stream);

// MoE kernel except routing
TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_float8_e4m3fn) << "hidden_states must be fp8.";
@@ -366,7 +371,8 @@ void trtllm_fp8_block_scale_moe_launcher(

auto const routing_bias_dtype =
routing_bias.has_value() ? routing_bias.value()->dtype : dl_bfloat16;
-  args.mDtypeExpW = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
+  args.mDtypeExpW = btg::Dtype::Bfloat16;
+  // routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
args.routing_logits = static_cast<float*>(routing_logits->data);
args.routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr;
args.hidden_states = hidden_states->data;
@@ -398,6 +404,7 @@ void trtllm_fp8_block_scale_moe_launcher(
alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits->device);
Tensor permuted_idx_to_token_idx =
alloc_tensor({max_num_padded_tokens}, dl_int32, routing_logits->device);
+  // routing_bias_dtype,
Tensor expert_weights =
alloc_tensor({args.num_tokens, args.top_k}, routing_bias_dtype, routing_logits->device);
Tensor expert_indexes =
@@ -444,8 +451,8 @@ void trtllm_fp8_block_scale_moe_launcher(
static_cast<int*>(num_tokens_per_expert->data),
static_cast<int*>(cta_idx_xy_to_batch_idx->data),
static_cast<int*>(cta_idx_xy_to_mn_limit->data),
-      static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt, false, true,
-      static_cast<RoutingMethodType>(routing_method_type), stream);
+      static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt, args.mDtypeExpW,
+      false, true, static_cast<RoutingMethodType>(routing_method_type), stream);

// MoE kernel except routing
TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_float8_e4m3fn) << "hidden_states must be fp8.";
@@ -827,7 +834,7 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
static_cast<int*>(num_tokens_per_expert->data),
static_cast<int*>(cta_idx_xy_to_batch_idx->data),
static_cast<int*>(cta_idx_xy_to_mn_limit->data),
-      static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt,
+      static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt, args.mDtypeExpW,
false /* use_routing_scales_on_input */, false /* use_deep_seek_fp8 */,
static_cast<RoutingMethodType>(routing_method_type), stream);

7 changes: 4 additions & 3 deletions csrc/trtllm_fused_moe_runner.cu
@@ -55,13 +55,14 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert,
int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit,
-                 int32_t* numNonExitingCtas, btg::Dtype dtypeElt, bool useRoutingScalesOnInput,
-                 bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream) {
+                 int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias,
+                 bool useRoutingScalesOnInput, bool useDeepSeekFp8,
+                 RoutingMethodType routingMethodType, cudaStream_t stream) {
if (routingMethodType == RoutingMethodType::DeepSeekV3) {
FLASHINFER_CHECK(topK <= 8, "For DeepSeek routing method, must have topK <= 8");
FLASHINFER_CHECK(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4");
moe::dev::routing::routingDeepSeek::Data routingData;
-    routingData.mDtypeExpW = btg::Dtype::Bfloat16;
+    routingData.mDtypeExpW = btg::Dtype::Bfloat16;
routingData.mUsePdl = true;

// output:
5 changes: 3 additions & 2 deletions include/flashinfer/trtllm/fused_moe/runner.h
@@ -113,8 +113,9 @@ class Runner {
int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert,
int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas,
-           batchedGemm::trtllm::gen::Dtype dtypeElt, bool useRoutingScalesOnInput,
-           bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream);
+           batchedGemm::trtllm::gen::Dtype dtypeElt, batchedGemm::trtllm::gen::Dtype dtypeBias,
+           bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType,
+           cudaStream_t stream);

private:
int32_t mTileTokensDim{8};
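Since both launchers repeat the same DLPack-to-btg dtype mapping for the routing bias, it could be factored into a single helper. A self-contained sketch under that assumption (the enums below are illustrative stand-ins, not the real DLPack or trtllm-gen types):

    #include <cassert>

    // Illustrative stand-ins for the real dtype enums.
    enum class DLDtype { Bfloat16, Float32 };
    enum class BtgDtype { Bfloat16, Fp32 };

    // Map the routing-bias tensor dtype to the kernel dtype; default to bf16
    // when no bias tensor is supplied, mirroring the launchers above.
    inline BtgDtype routingBiasDtype(bool has_bias, DLDtype dtype) {
      if (!has_bias) return BtgDtype::Bfloat16;
      return dtype == DLDtype::Bfloat16 ? BtgDtype::Bfloat16 : BtgDtype::Fp32;
    }

    int main() {
      assert(routingBiasDtype(false, DLDtype::Float32) == BtgDtype::Bfloat16);
      assert(routingBiasDtype(true, DLDtype::Float32) == BtgDtype::Fp32);
      return 0;
    }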