34 changes: 17 additions & 17 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
@@ -333,18 +333,18 @@ void trtllm_fp8_block_scale_moe_launcher(
<< "routing_bias has incorrect shape.";
}

if (n_group <= 0 || topk_group <= 0) {
TVM_FFI_ICHECK_EQ(top_k, 1) << "Current routing kernel (no groups) only supports top_k=1.";
} else {
TVM_FFI_ICHECK_LE(top_k, 8) << "Current routing kernel (with groups) only supports top_k<=8.";
TVM_FFI_ICHECK_LE(topk_group, 4)
<< "Current routing kernel (with groups) only supports topk_group<=4.";
TVM_FFI_ICHECK_LE(topk_group, n_group) << "n_group must not be smaller than topk_group.";
TVM_FFI_ICHECK_EQ(num_experts % n_group, 0) << "num_experts must be divisible by n_group";
// This check ensures we have enough experts in the selected groups to handle the top_k routing
TVM_FFI_ICHECK_LT(top_k, (topk_group * num_experts / n_group))
<< "top_k must be less than total number of experts in selected groups";
}
// if (n_group <= 0 || topk_group <= 0) {
// TVM_FFI_ICHECK_EQ(top_k, 1) << "Current routing kernel (no groups) only supports top_k=1.";
// } else {
// TVM_FFI_ICHECK_LE(top_k, 8) << "Current routing kernel (with groups) only supports
// top_k<=8."; TVM_FFI_ICHECK_LE(topk_group, 4)
// << "Current routing kernel (with groups) only supports topk_group<=4.";
// TVM_FFI_ICHECK_LE(topk_group, n_group) << "n_group must not be smaller than topk_group.";
// TVM_FFI_ICHECK_EQ(num_experts % n_group, 0) << "num_experts must be divisible by n_group";
// // This check ensures we have enough experts in the selected groups to handle the top_k
// routing TVM_FFI_ICHECK_LT(top_k, (topk_group * num_experts / n_group))
// << "top_k must be less than total number of experts in selected groups";
// }
TVM_FFI_ICHECK_EQ(num_experts % 4, 0)
<< "Routing kernel expects that num_experts must be divisible by 4";
TVM_FFI_ICHECK_GT(num_experts, top_k) << "num_experts must be greater than top_k";
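The commented-out block above removes the launcher-side checks for grouped routing in trtllm_fp8_block_scale_moe_launcher. For reference, a minimal standalone sketch that restates those same invariants; the helper name and the use of exceptions are illustrative only and not part of this PR:

```cpp
#include <cstdint>
#include <stdexcept>

// Illustrative restatement of the group-routing invariants that the
// commented-out TVM_FFI_ICHECK block used to enforce.
inline void checkGroupedRoutingArgs(int32_t num_experts, int32_t n_group,
                                    int32_t topk_group, int32_t top_k) {
  if (n_group <= 0 || topk_group <= 0) {
    // No-group routing path.
    if (top_k != 1) throw std::invalid_argument("routing without groups expects top_k == 1");
    return;
  }
  if (top_k > 8) throw std::invalid_argument("grouped routing expects top_k <= 8");
  if (topk_group > 4) throw std::invalid_argument("grouped routing expects topk_group <= 4");
  if (topk_group > n_group) throw std::invalid_argument("topk_group must not exceed n_group");
  if (num_experts % n_group != 0)
    throw std::invalid_argument("num_experts must be divisible by n_group");
  // The selected groups must contain enough experts to cover top_k.
  if (top_k >= topk_group * (num_experts / n_group))
    throw std::invalid_argument("top_k must be less than the expert count of the selected groups");
}
```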
@@ -684,8 +684,8 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
TVM_FFI_ICHECK(topk_group.has_value()) << "if n_group is given, topk_group must be given";
TVM_FFI_ICHECK_EQ(num_experts % n_group.value(), 0)
<< "num_experts must be divisible by n_group";
TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
<< "Current routing kernel (with groups) only supports top_k<=8 && top_k>0.";
// TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
// << "Current routing kernel (with groups) only supports top_k<=8 && top_k>0.";
TVM_FFI_ICHECK(topk_group.value() <= 4 && topk_group.value() > 0)
<< "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0.";
TVM_FFI_ICHECK_LE(topk_group.value(), n_group.value())
@@ -698,9 +698,9 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
static_cast<RoutingMethodType>(routing_method_type) ==
RoutingMethodType::RenormalizeNaive ||
static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::TopK) {
TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
<< "Current routing kernel (no groups, renormalize/topk) only supports top_k<=8 && "
"top_k>0.";
// TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
// << "Current routing kernel (no groups, renormalize/topk) only supports top_k<=8 && "
// "top_k>0.";
} else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) {
TVM_FFI_ICHECK_EQ(top_k, 1)
<< "Current routing kernel (no groups, Llama4) only supports top_k=1.";
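A compact summary of the fp4 launcher checks visible in these hunks after the change; the enum and helper below are illustrative only (the real code switches on RoutingMethodType with TVM_FFI_ICHECK), and the relaxed top_k bounds are recorded as comments rather than enforced:

```cpp
#include <stdexcept>

// Illustrative: per-method routing checks in trtllm_fp4_block_scale_moe_launcher
// as they stand in these hunks. The former top_k <= 8 checks are commented out
// in the PR, so they appear here only as comments.
enum class RoutingMethod { Renormalize, RenormalizeNaive, TopK, Llama4, Other };

inline void checkFp4RoutingArgs(bool has_groups, int n_group, int topk_group,
                                int num_experts, int top_k, RoutingMethod method) {
  if (has_groups) {
    if (num_experts % n_group != 0)
      throw std::invalid_argument("num_experts must be divisible by n_group");
    // (top_k <= 8 && top_k > 0 check is commented out in this PR)
    if (topk_group <= 0 || topk_group > 4)
      throw std::invalid_argument("grouped routing supports 0 < topk_group <= 4");
    if (topk_group > n_group)
      throw std::invalid_argument("topk_group must not exceed n_group");
  } else if (method == RoutingMethod::Renormalize ||
             method == RoutingMethod::RenormalizeNaive || method == RoutingMethod::TopK) {
    // (top_k <= 8 && top_k > 0 check is commented out in this PR)
    (void)top_k;
  } else if (method == RoutingMethod::Llama4) {
    if (top_k != 1) throw std::invalid_argument("Llama4 routing supports only top_k == 1");
  }
}
```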
453 changes: 272 additions & 181 deletions csrc/trtllm_fused_moe_routing_deepseek.cu

Large diffs are not rendered by default.

227 changes: 143 additions & 84 deletions csrc/trtllm_fused_moe_routing_llama4.cu

Large diffs are not rendered by default.

336 changes: 255 additions & 81 deletions csrc/trtllm_fused_moe_routing_renormalize.cu

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions csrc/trtllm_fused_moe_runner.cu
@@ -65,12 +65,12 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mUsePdl = true;

// output:
routingData.mPtrExpertIdx = routingExpertIndexes;
routingData.mPtrTopKPacked = routingExpertIndexes;
routingData.mPtrExpertCounts = expertCountHistogram;
routingData.mPtrPermutedIdxSize = permutedIdxSize;
routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
routingData.mPtrExpertWeights = expertWeights;
routingData.mPtrTopKWeights = expertWeights;

routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx;
routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit;
@@ -102,12 +102,12 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mUsePdl = true;

// output:
routingData.mPtrExpertIdx = routingExpertIndexes;
routingData.mPtrTopKPacked = routingExpertIndexes;
routingData.mPtrExpertCounts = expertCountHistogram;
routingData.mPtrPermutedIdxSize = permutedIdxSize;
routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
routingData.mPtrExpertWeights = expertWeights;
routingData.mPtrTopKWeights = expertWeights;

routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx;
routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit;
@@ -144,12 +144,12 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
//
// Outputs
//
routingData.mPtrExpertIdx = routingExpertIndexes;
routingData.mPtrTopKPacked = routingExpertIndexes;
routingData.mPtrExpertCounts = expertCountHistogram;
routingData.mPtrPermutedIdxSize = permutedIdxSize;
routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
routingData.mPtrExpertWeights = expertWeights;
routingData.mPtrTopKWeights = expertWeights;

//
// Grouped Gemm Launch Config Buffers
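All three Runner::run hunks above apply the same rename of the routing-output members; the buffers being wired are unchanged. A minimal self-contained sketch of the new wiring, with member semantics inferred from the names and the real member types in DevKernel.h possibly differing:

```cpp
// Stand-in for the routing data struct, limited to the members touched here.
struct RoutingDataSketch {
  void* mPtrTopKPacked;    // was mPtrExpertIdx: per-token top-k routing entries
  void* mPtrTopKWeights;   // was mPtrExpertWeights: per-token top-k routing weights
  void* mPtrExpertCounts;  // unchanged
};

inline void wireRoutingOutputs(RoutingDataSketch& routingData, void* routingExpertIndexes,
                               void* expertWeights, void* expertCountHistogram) {
  routingData.mPtrTopKPacked = routingExpertIndexes;    // was routingData.mPtrExpertIdx
  routingData.mPtrTopKWeights = expertWeights;          // was routingData.mPtrExpertWeights
  routingData.mPtrExpertCounts = expertCountHistogram;  // unchanged
}
```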
9 changes: 5 additions & 4 deletions flashinfer/fused_moe/core.py
@@ -1340,7 +1340,7 @@ def _fake_trtllm_fp8_per_tensor_scale_moe(
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routed_scaling_factor: float,
routed_scaling_factor: Optional[float],
use_routing_scales_on_input: bool,
tile_tokens_dim: int = 8,
routing_method_type: int = 0,
@@ -1372,7 +1372,7 @@ def trtllm_fp8_block_scale_moe_op(
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routed_scaling_factor: float,
routed_scaling_factor: Optional[float],
tile_tokens_dim: int,
routing_method_type: int,
use_shuffled_weight: bool = False,
@@ -1381,6 +1381,7 @@
) -> torch.Tensor:
if enable_pdl is None:
enable_pdl = device_support_pdl(hidden_states.device)

# Call the C++ function for block scale MoE
moe_op.trtllm_fp8_block_scale_moe(
routing_logits,
@@ -1427,7 +1428,7 @@ def _fake_trtllm_fp8_block_scale_moe(
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routed_scaling_factor: float,
routed_scaling_factor: Optional[float],
tile_tokens_dim: int = 8,
routing_method_type: int = 0,
use_shuffled_weight: bool = False,
@@ -1755,7 +1756,7 @@ def trtllm_fp8_block_scale_moe(
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routed_scaling_factor: float,
routed_scaling_factor: Optional[float],
tile_tokens_dim: int = 8,
routing_method_type: int = 0,
use_shuffled_weight: bool = False,
85 changes: 54 additions & 31 deletions include/flashinfer/trtllm/fused_moe/DevKernel.h
@@ -31,6 +31,7 @@
#include "../../exception.h"
// #include <tensorrt_llm/common/assert.h>
#include "flashinfer/trtllm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"

namespace moe::dev {

@@ -112,39 +113,61 @@ namespace moe::dev {
FLASHINFER_WARN("Unsupported pair"); \
}

#define LAUNCH_ROUTING(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \
if (data.mDtypeExpW == tg::Dtype::Fp32) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float), kernel, numBlocks, numThreads, \
smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16), kernel, numBlocks, \
numThreads, smemSize, stream); \
} else { \
FLASHINFER_WARN("Unsupported dtypeExpW"); \
#define LAUNCH_ROUTING_LLAMA4(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \
if (data.mDtypeExpW == tg::Dtype::Fp32) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, 128 /* Always 128 for llama4*/), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \
LAUNCH_PDL(data, coopLaunch, \
LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128 /* Always 128 for llama4*/), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else { \
TLLM_LOG_ERROR("Unsupported dtypeExpW"); \
}

#define LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
stream, extraFlag, forceFloatInput) \
if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, true), kernel, numBlocks, numThreads, \
smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Fp32) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, false), kernel, numBlocks, numThreads, \
smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, true), kernel, numBlocks, \
numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, true), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && forceFloatInput) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, false), kernel, numBlocks, \
numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, false), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else { \
FLASHINFER_WARN("Unsupported dtypeExpW"); \
#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, \
numThreads, smemSize, stream, extraFlag, \
forceFloatInput, numExperts) \
if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, \
numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Fp32) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, \
numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, true), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), \
kernel, numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && forceFloatInput) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, false), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), \
kernel, numBlocks, numThreads, smemSize, stream); \
} else { \
TLLM_LOG_ERROR("Unsupported dtypeExpW"); \
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#define LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
stream, extraFlag1, numExperts) \
if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag1) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, \
numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Fp32) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, \
numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag1) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), \
kernel, numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), \
kernel, numBlocks, numThreads, smemSize, stream); \
} else { \
TLLM_LOG_ERROR("Unsupported dtypeExpW"); \
}

////////////////////////////////////////////////////////////////////////////////////////////////////
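The rewritten macros above all follow the same shape: pick compile-time kernel parameters from the runtime mDtypeExpW (plus flags and, now, numExperts), then launch through LAUNCH_PDL. A simplified, self-contained sketch of that dispatch pattern; the template-parameter order mirrors the LAUNCH_ESC argument lists and is an inference, the kernel body is elided, and the PDL/cooperative-launch plumbing is omitted:

```cuda
#include <cuda_bf16.h>
#include <cuda_runtime.h>

// Local stand-in for tg::Dtype, just for this sketch.
enum class DtypeSketch { Fp32, Bfloat16 };

// Stand-in routing kernel: the template parameters mirror the
// LAUNCH_ESC(InputT, ExpWT, numExperts, flag) argument lists above.
template <typename InputT, typename ExpWT, int NumExperts, bool ExtraFlag>
__global__ void routingKernelSketch(int numTokens) {
  (void)numTokens;  // real routing work elided
}

// Simplified version of the runtime-dtype -> compile-time-template dispatch
// performed by LAUNCH_ROUTING_WITH_NUM_EXPERTS.
inline void launchRoutingSketch(DtypeSketch dtypeExpW, bool extraFlag, int numBlocks,
                                int numThreads, cudaStream_t stream, int numTokens) {
  constexpr int kNumExperts = 128;  // assumption: expert count fixed at the call site
  if (dtypeExpW == DtypeSketch::Fp32) {
    if (extraFlag)
      routingKernelSketch<float, float, kNumExperts, true>
          <<<numBlocks, numThreads, 0, stream>>>(numTokens);
    else
      routingKernelSketch<float, float, kNumExperts, false>
          <<<numBlocks, numThreads, 0, stream>>>(numTokens);
  } else if (dtypeExpW == DtypeSketch::Bfloat16) {
    if (extraFlag)
      routingKernelSketch<__nv_bfloat16, __nv_bfloat16, kNumExperts, true>
          <<<numBlocks, numThreads, 0, stream>>>(numTokens);
    else
      routingKernelSketch<__nv_bfloat16, __nv_bfloat16, kNumExperts, false>
          <<<numBlocks, numThreads, 0, stream>>>(numTokens);
  } else {
    // unsupported dtypeExpW (mirrors the TLLM_LOG_ERROR branch)
  }
}
```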