68 changes: 40 additions & 28 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
@@ -340,10 +340,11 @@ class FusedMoeLauncher {
check_routing();
prepare_routing();

// Execute routing
tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
cudaStream_t routing_stream = get_stream(hidden_states.device());

// Execute routing (handles both pre-computed and from-logits paths)
tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);

routing_runner.run(
args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k,
args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts,
@@ -353,6 +354,7 @@ class FusedMoeLauncher {
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/,
static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), expert_weights.data_ptr(),
nullptr /*expertIds - not used when computing from logits*/,
static_cast<int*>(num_tokens_per_expert.data_ptr()),
static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()),
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
@@ -1114,8 +1116,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
TensorView const& gemm2_weights, TensorView const& gemm2_weights_scale,
Optional<TensorView> const& gemm2_bias, Optional<TensorView> const& output1_scales_scalar,
Optional<TensorView> const& output1_scales_gate_scalar,
Optional<TensorView> const& output2_scales_scalar, TensorView const& expert_indices,
TensorView const& expert_weights)
Optional<TensorView> const& output2_scales_scalar, TensorView const& topk_ids,
Optional<TensorView> const& topk_weights)
: FusedMoeLauncher(routing_logits, routing_bias, hidden_states, gemm1_weights,
output1_scales_scalar, output1_scales_gate_scalar, gemm2_weights,
output2_scales_scalar),
@@ -1127,8 +1129,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
gemm1_clamp_limit(gemm1_clamp_limit),
gemm2_weights_scale(gemm2_weights_scale),
gemm2_bias(gemm2_bias),
expert_indices(expert_indices),
expert_weights(expert_weights) {}
topk_ids(topk_ids),
topk_weights(topk_weights) {}

void init(std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>&& args,
int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight,
@@ -1189,9 +1191,9 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
workspace.total_num_padded_tokens = static_cast<int*>(total_num_padded_tokens.data_ptr());
workspace.total_max_padded_tokens = max_num_padded_tokens;
workspace.ProjUpTileN = tile_tokens_dim;
workspace.routing_expert_indexes =
static_cast<int*>(const_cast<void*>(expert_indices.data_ptr()));
workspace.expert_weights = const_cast<void*>(expert_weights.data_ptr());
workspace.routing_expert_indexes = static_cast<int*>(const_cast<void*>(topk_ids.data_ptr()));
workspace.expert_weights =
topk_weights.has_value() ? const_cast<void*>(topk_weights.value().data_ptr()) : nullptr;
workspace.permuted_idx_size = static_cast<int*>(total_num_padded_tokens.data_ptr());
workspace.expanded_idx_to_permuted_idx =
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr());
@@ -1330,8 +1332,9 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
int32_t max_num_padded_tokens_gemm1{};
int32_t max_num_padded_tokens_gemm2{};
Optional<Tensor> gemm1_output_scale;
TensorView expert_indices;
TensorView expert_weights;
TensorView topk_ids; // [num_tokens, top_k] - pre-computed or output top-k expert indices
Optional<TensorView> topk_weights; // [num_tokens, top_k] - pre-computed or output top-k routing
// weights (optional for packed format)

public:
Array<Tensor> run(int64_t moe_tactic, bool enable_pdl = true,
@@ -1344,21 +1347,30 @@
tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
cudaStream_t routing_stream = get_stream(hidden_states.device());

routing_runner.run(
args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k,
args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts,
args->routed_scaling_factor, static_cast<int*>(expert_indices.data_ptr()),
static_cast<int*>(expert_count_histogram.data_ptr()),
static_cast<int*>(total_num_padded_tokens.data_ptr()),
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/,
static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), expert_weights.data_ptr(),
static_cast<int*>(num_tokens_per_expert.data_ptr()),
static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()),
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype,
use_routing_scales_on_input, use_deep_seek_fp8,
static_cast<RoutingMethodType>(routing_method_type), routing_stream);
    // Determine routing mode based on inputs:
    // - If routing_logits != nullptr: compute routing from logits
    // - If routing_logits == nullptr && topk_weights has value: unpacked pre-computed
    //   (mPtrTopKIds + mPtrTopKWeights)
    // - If routing_logits == nullptr && topk_weights is None: packed pre-computed (mPtrTopKPacked)
bool use_unpacked_prerouted = args->routing_logits == nullptr && topk_weights.has_value();
int32_t* precomputed_topk_ids =
use_unpacked_prerouted ? static_cast<int32_t*>(topk_ids.data_ptr()) : nullptr;
void* topk_weights_ptr = topk_weights.has_value() ? topk_weights.value().data_ptr() : nullptr;
routing_runner.run(args->routing_logits, args->routing_bias, args->num_tokens,
args->num_experts, args->top_k, args->n_group, args->topk_group,
args->local_expert_offset, args->local_num_experts,
args->routed_scaling_factor, static_cast<int*>(topk_ids.data_ptr()),
static_cast<int*>(expert_count_histogram.data_ptr()),
static_cast<int*>(total_num_padded_tokens.data_ptr()),
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
nullptr /*permuted_idx_to_expanded_idx*/,
static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), topk_weights_ptr,
precomputed_topk_ids, static_cast<int*>(num_tokens_per_expert.data_ptr()),
static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()),
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt,
mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8,
static_cast<RoutingMethodType>(routing_method_type), routing_stream);

check_moe();
prepare_moe(moe_tactic);
@@ -1648,7 +1660,7 @@ Tensor trtllm_fp8_block_scale_moe(
}

Array<Tensor> trtllm_fp4_block_scale_moe(
Optional<TensorView> routing_logits, TensorView topk_ids, TensorView expert_weights,
Optional<TensorView> routing_logits, TensorView topk_ids, Optional<TensorView> topk_weights,
Optional<TensorView> routing_bias, TensorView hidden_states,
Optional<TensorView> hidden_states_scale, TensorView gemm1_weights,
TensorView gemm1_weights_scale, Optional<TensorView> gemm1_bias,
@@ -1762,7 +1774,7 @@ Array<Tensor> trtllm_fp4_block_scale_moe(
routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights,
gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights,
gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar,
output2_scales_scalar, topk_ids, expert_weights);
output2_scales_scalar, topk_ids, topk_weights);
launcher->init(std::move(args), curr_tile_N, routing_method_type, /*use_shuffled_weight=*/true,
/*weight_layout=*/0, gated_act_type, mDtypeAct, mDtypeWeights);

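Note: the launcher now picks one of three routing modes purely from which inputs are non-null. A minimal Python sketch of the same dispatch rule (the helper name is illustrative, not part of this PR):

```python
from typing import Optional
import torch

def select_routing_mode(
    routing_logits: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
) -> str:
    """Mirror of the C++ dispatch in FP4BlockScaleLauncher::run."""
    if routing_logits is not None:
        return "from_logits"         # kernel computes top-k from scores
    if topk_weights is not None:
        return "prerouted_unpacked"  # mPtrTopKIds + mPtrTopKWeights
    return "prerouted_packed"        # mPtrTopKPacked (deprecated int32 format)

assert select_routing_mode(torch.randn(4, 8), None) == "from_logits"
assert select_routing_mode(None, torch.ones(4, 2, dtype=torch.bfloat16)) == "prerouted_unpacked"
assert select_routing_mode(None, None) == "prerouted_packed"
```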
21 changes: 14 additions & 7 deletions csrc/trtllm_fused_moe_runner.cu
@@ -54,10 +54,10 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
int32_t localNumExperts, float routedScalingFactor, int32_t* routingExpertIndexes,
int32_t* expertCountHistogram, int32_t* permutedIdxSize,
int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert,
int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit,
int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias,
bool useRoutingScalesOnInput, bool useDeepSeekFp8,
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* expertIds,
int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx,
int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas, btg::Dtype dtypeElt,
btg::Dtype dtypeBias, bool useRoutingScalesOnInput, bool useDeepSeekFp8,
RoutingMethodType routingMethodType, cudaStream_t stream) {
if (routingMethodType == RoutingMethodType::DeepSeekV3) {
FLASHINFER_CHECK(topK <= 8, "For DeepSeek routing method, must have topK <= 8");
@@ -84,7 +84,10 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3

// input:
routingData.mPtrRoutingBias = routingBias;
routingData.mPtrScores = reinterpret_cast<float*>(routingLogits);
// Pre-computed routing support: when expertIds is provided, use it directly
routingData.mPtrScores =
expertIds == nullptr ? reinterpret_cast<float*>(routingLogits) : nullptr;
routingData.mPtrTopKIds = expertIds;
routingData.mNumTokens = numTokens;
routingData.mNumExperts = numExperts;
routingData.mNumExpertGroups = nGroup;
@@ -121,7 +124,9 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mPtrNumNonExitingCtas = numNonExitingCtas;

// input:
routingData.mPtrScores = routingLogits;
// Pre-computed routing support: when expertIds is provided, use it directly
routingData.mPtrScores = expertIds == nullptr ? routingLogits : nullptr;
routingData.mPtrTopKIds = expertIds;
routingData.mNumTokens = numTokens;
routingData.mNumExperts = numExperts;
routingData.mTopK = topK;
@@ -147,7 +152,9 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mNormTopkProb = routingMethodType == RoutingMethodType::RenormalizeNaive;
routingData.mApplySoftmaxAfterTopK = routingMethodType == RoutingMethodType::Renormalize;

routingData.mPtrScores = routingLogits;
// Pre-computed routing support: when expertIds is provided, use it directly
routingData.mPtrScores = expertIds == nullptr ? routingLogits : nullptr;
routingData.mPtrTopKIds = expertIds;

//
// Outputs
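All three routing branches in `Runner::run` get the same treatment: `mPtrScores` and `mPtrTopKIds` are set so that at most one is non-null. A hedged sketch of that assignment (illustrative helper, not part of the diff):

```python
from typing import Optional, Tuple
import torch

def routing_inputs(
    routing_logits: Optional[torch.Tensor],
    expert_ids: Optional[torch.Tensor],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Mirrors Runner::run: pre-computed expert ids take precedence, and the
    # scores pointer is cleared so the kernel skips the top-k computation.
    scores = routing_logits if expert_ids is None else None
    return scores, expert_ids

# From-logits path: scores flow through, no pre-computed ids.
assert routing_inputs(torch.randn(4, 8), None)[0] is not None
# Pre-computed path: ids take precedence, scores are dropped.
assert routing_inputs(torch.randn(4, 8), torch.zeros(4, 2, dtype=torch.int32))[0] is None
```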
1 change: 1 addition & 0 deletions flashinfer/__init__.py
@@ -80,6 +80,7 @@
cutlass_fused_moe,
reorder_rows_for_gated_act_gemm,
trtllm_fp4_block_scale_moe,
trtllm_fp4_block_scale_prerouted_moe,
trtllm_fp4_block_scale_routed_moe,
trtllm_fp8_block_scale_moe,
trtllm_fp8_per_tensor_scale_moe,
2 changes: 2 additions & 0 deletions flashinfer/fused_moe/__init__.py
@@ -27,6 +27,7 @@
gen_trtllm_gen_fused_moe_sm100_module,
reorder_rows_for_gated_act_gemm,
trtllm_fp4_block_scale_moe,
trtllm_fp4_block_scale_prerouted_moe,
trtllm_fp4_block_scale_routed_moe,
trtllm_fp8_block_scale_moe,
trtllm_fp8_per_tensor_scale_moe,
@@ -52,6 +53,7 @@
"reorder_rows_for_gated_act_gemm",
"trtllm_bf16_moe",
"trtllm_fp4_block_scale_moe",
"trtllm_fp4_block_scale_prerouted_moe",
"trtllm_fp4_block_scale_routed_moe",
"trtllm_fp8_block_scale_moe",
"trtllm_fp8_per_tensor_scale_moe",
126 changes: 117 additions & 9 deletions flashinfer/fused_moe/core.py
@@ -2448,8 +2448,9 @@ def trtllm_fp4_block_scale_moe(


@flashinfer_api
def trtllm_fp4_block_scale_routed_moe(
def trtllm_fp4_block_scale_prerouted_moe(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
hidden_states_scale: Optional[torch.Tensor],
@@ -2480,15 +2481,22 @@ def trtllm_fp4_block_scale_routed_moe(
output: Optional[torch.Tensor] = None,
tune_max_num_tokens: int = 8192,
) -> List[torch.Tensor]:
"""FP4 block scale MoE operation.
"""FP4 block scale MoE operation with pre-computed routing.

This function accepts pre-computed top-k expert assignments (topk_ids and topk_weights)
instead of computing routing from logits. Use this when routing decisions are made
externally (e.g., for load balancing across GPUs or custom routing logic).

Args:
topk_ids (torch.Tensor): shape [seq_len, top_k]
Tensor of top-k indices and expert weights. Dtype must be int32.
It must represent a packed value. The most significant 16/32 bits represent the score and
the least significant 16 bits represent the index of the chosen expert (unsigned).
Tensor of top-k expert indices. Dtype must be int32.
Each element contains the index of the selected expert.
topk_weights (torch.Tensor): shape [seq_len, top_k]
Tensor of top-k routing weights. Dtype must be bfloat16.
Each element contains the routing weight for the corresponding expert.
These weights are used in the final weighted sum of expert outputs.
routing_bias (Optional[torch.Tensor]): shape [num_experts]
Tensor of routing bias. Can be None for some routing methods. Must be the same type as routing logits.
Tensor of routing bias. Can be None for pre-routed mode.
hidden_states (torch.Tensor): shape [seq_len, hidden_size // 2 if nvfp4 else hidden_size]
Tensor of input hidden states. Supports bfloat16, mxfp8, and nvfp4 (packed into uint8)
hidden_states_scale (Optional[torch.Tensor]): shape [seq_len, hidden_size // (32 if mxfp8, 16 if mxfp4)]
@@ -2541,12 +2549,112 @@

Returns:
List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output.
Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing.
Otherwise, returns intermediate results that need further processing.
"""
# routing_logits=None triggers pre-computed routing path in the kernel
return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe(
None,
None, # routing_logits
topk_ids,
None,
topk_weights,
routing_bias,
hidden_states,
hidden_states_scale,
gemm1_weights,
gemm1_weights_scale,
gemm1_bias,
gemm1_alpha,
gemm1_beta,
gemm1_clamp_limit,
gemm2_weights,
gemm2_weights_scale,
gemm2_bias,
output1_scale_scalar,
output1_scale_gate_scalar,
output2_scale_scalar,
num_experts,
top_k,
n_group,
topk_group,
intermediate_size,
local_expert_offset,
local_num_experts,
routed_scaling_factor,
routing_method_type,
do_finalize,
enable_pdl,
gated_act_type,
output,
tune_max_num_tokens,
)
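
A short usage sketch for the new routing inputs (shapes and the softmax normalization are assumptions for illustration; every other argument is the same quantized weight/scale tensor documented above):

```python
import torch

num_tokens, top_k, num_experts = 128, 2, 32

# Expert choices made outside the kernel, e.g. by a capacity-aware balancer.
topk_ids = torch.randint(
    0, num_experts, (num_tokens, top_k), dtype=torch.int32, device="cuda"
)
# Per-expert combine weights; dtype must be bfloat16 per the docstring.
topk_weights = torch.softmax(
    torch.randn(num_tokens, top_k, device="cuda"), dim=-1
).to(torch.bfloat16)

# These two tensors replace routing_logits; the remaining arguments are
# unchanged from trtllm_fp4_block_scale_moe:
# out = trtllm_fp4_block_scale_prerouted_moe(topk_ids, topk_weights, None, ...)
```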


def trtllm_fp4_block_scale_routed_moe(
topk_ids: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
hidden_states_scale: Optional[torch.Tensor],
gemm1_weights: torch.Tensor,
gemm1_weights_scale: torch.Tensor,
gemm1_bias: Optional[torch.Tensor],
gemm1_alpha: Optional[torch.Tensor],
gemm1_beta: Optional[torch.Tensor],
gemm1_clamp_limit: Optional[torch.Tensor],
gemm2_weights: torch.Tensor,
gemm2_weights_scale: torch.Tensor,
gemm2_bias: Optional[torch.Tensor],
output1_scale_scalar: Optional[torch.Tensor],
output1_scale_gate_scalar: Optional[torch.Tensor],
output2_scale_scalar: Optional[torch.Tensor],
num_experts: int,
top_k: int,
n_group: Optional[int],
topk_group: Optional[int],
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routed_scaling_factor: Optional[float],
routing_method_type: int = 0,
do_finalize: bool = True,
enable_pdl: Optional[bool] = None,
gated_act_type: int = 0,
output: Optional[torch.Tensor] = None,
tune_max_num_tokens: int = 8192,
) -> List[torch.Tensor]:
"""FP4 block scale MoE operation with packed routing input.

.. deprecated::
This function is deprecated and will be removed in a future version.
Use :func:`trtllm_fp4_block_scale_prerouted_moe` instead, which accepts
separate ``topk_ids`` and ``topk_weights`` tensors.

Args:
topk_ids (torch.Tensor): shape [seq_len, top_k]
Tensor of packed top-k indices and weights. Dtype must be int32.
The most significant 16/32 bits represent the score and
the least significant 16 bits represent the expert index.
routing_bias (Optional[torch.Tensor]): shape [num_experts]
Tensor of routing bias. Can be None.
hidden_states (torch.Tensor): Input hidden states.
... (other args same as trtllm_fp4_block_scale_prerouted_moe)

Returns:
List[torch.Tensor]: Output tensors.
"""
import warnings

warnings.warn(
"trtllm_fp4_block_scale_routed_moe is deprecated and will be removed in version 0.8. "
"Use trtllm_fp4_block_scale_prerouted_moe instead, which accepts separate "
"topk_ids (int32) and topk_weights (bfloat16) tensors.",
DeprecationWarning,
stacklevel=2,
)
# Old API uses packed format - pass to kernel with topk_weights=None
# The kernel will use mPtrTopKPacked path instead of mPtrTopKIds + mPtrTopKWeights
return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe(
None, # routing_logits
topk_ids, # packed format
None, # topk_weights (None for packed format)
routing_bias,
hidden_states,
hidden_states_scale,
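For callers migrating off the packed format, a hedged sketch of the bit layout described in the deprecated docstring, assuming the 16-bit (bf16) score variant; the helper names are illustrative, and the round-trip relies on bf16 being the top 16 bits of an fp32:

```python
import torch

def unpack_topk(packed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Low 16 bits: expert index; high 16 bits: bf16 weight bit pattern.
    assert packed.dtype == torch.int32
    topk_ids = (packed & 0xFFFF).to(torch.int32)
    weight_bits = (packed.to(torch.int64) >> 16) & 0xFFFF  # avoid sign extension
    # Rebuild fp32 from the bf16 bits, then narrow (exact: low mantissa is zero).
    topk_weights = (
        (weight_bits << 16).to(torch.int32).view(torch.float32).to(torch.bfloat16)
    )
    return topk_ids, topk_weights

def pack_topk(topk_ids: torch.Tensor, topk_weights: torch.Tensor) -> torch.Tensor:
    w_bits = topk_weights.to(torch.float32).view(torch.int32) >> 16  # bf16 bits
    return (w_bits << 16) | (topk_ids.to(torch.int32) & 0xFFFF)

ids = torch.tensor([[3, 7]], dtype=torch.int32)
w = torch.tensor([[0.75, 0.25]], dtype=torch.bfloat16)
rids, rw = unpack_topk(pack_topk(ids, w))
assert torch.equal(rids, ids) and torch.equal(rw, w)
```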
10 changes: 5 additions & 5 deletions include/flashinfer/trtllm/fused_moe/runner.h
@@ -124,11 +124,11 @@ class Runner {
int32_t localNumExperts, float routedScalingFactor, int32_t* routingExpertIndexes,
int32_t* expertCountHistogram, int32_t* permutedIdxSize,
int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert,
int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas,
batchedGemm::trtllm::gen::Dtype dtypeElt, batchedGemm::trtllm::gen::Dtype dtypeBias,
bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType,
cudaStream_t stream);
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* expertIds,
int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit,
int32_t* numNonExitingCtas, batchedGemm::trtllm::gen::Dtype dtypeElt,
batchedGemm::trtllm::gen::Dtype dtypeBias, bool useRoutingScalesOnInput,
bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream);

private:
int32_t mTileTokensDim{8};