Skip to content

Commit a6716f0

Browse files
committed
clean moe_sort
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
1 parent 6ee2634 commit a6716f0

File tree

3 files changed

+11
-35
lines changed

3 files changed

+11
-35
lines changed

cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ std::vector<torch::Tensor> moe_topk_sort_impl(torch::optional<torch::Tensor> con
3030
torch::optional<torch::Tensor> const& token_final_scales, int64_t const num_experts, int64_t const top_k,
3131
std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const local_expert_offset,
3232
int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
33-
int64_t const routing_method_type)
33+
RoutingMethodType const routing_method_type)
3434
{
3535
int64_t const num_tokens
3636
= token_selected_experts.has_value() ? token_selected_experts->size(0) : routing_logits->size(0);
@@ -80,7 +80,7 @@ std::vector<torch::Tensor> moe_topk_sort_impl(torch::optional<torch::Tensor> con
8080
num_tokens_per_expert.data_ptr<int>(), tile_idx_to_expert_idx.data_ptr<int>(),
8181
tile_idx_to_mn_limit.data_ptr<int>(), num_non_exiting_tiles.data_ptr<int>(),
8282
batchedGemm::trtllm::gen::Dtype::Void /* dtypeElt */, false /* use_routing_scales_on_input */,
83-
false /* use_deep_seek_fp8 */, static_cast<RoutingMethodType>(routing_method_type), stream);
83+
false /* use_deep_seek_fp8 */, routing_method_type, stream);
8484

8585
std::vector<torch::Tensor> results{tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx,
8686
permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles};
@@ -106,14 +106,12 @@ std::vector<torch::Tensor> moe_topk_sort(torch::Tensor const& routing_logits,
106106
}
107107
return moe_topk_sort_impl(routing_logits, routing_bias, std::nullopt, std::nullopt, num_experts, top_k, n_group,
108108
topk_group, local_expert_offset, local_num_experts, routed_scaling_factor, tile_tokens_dim,
109-
routing_method_type);
109+
static_cast<RoutingMethodType>(routing_method_type));
110110
}
111111

112112
std::vector<torch::Tensor> moe_sort(torch::Tensor const& token_selected_experts,
113113
torch::Tensor const& token_final_scales, int64_t const num_experts, int64_t const top_k,
114-
std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const local_expert_offset,
115-
int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
116-
int64_t const routing_method_type)
114+
int64_t const local_expert_offset, int64_t const local_num_experts, int64_t const tile_tokens_dim)
117115
{
118116
TORCH_CHECK(token_selected_experts.dim() == 2, "token_selected_experts must be 2D.");
119117
int64_t const num_tokens = token_selected_experts.size(0);
@@ -122,8 +120,8 @@ std::vector<torch::Tensor> moe_sort(torch::Tensor const& token_selected_experts,
122120
TORCH_CHECK(token_final_scales.size(0) == num_tokens, "token_final_scales.size(0) must be num_tokens.");
123121
TORCH_CHECK(token_final_scales.size(1) == top_k, "token_final_scales.size(1) must be top_k.");
124122
return moe_topk_sort_impl(std::nullopt, std::nullopt, token_selected_experts, token_final_scales, num_experts,
125-
top_k, n_group, topk_group, local_expert_offset, local_num_experts, routed_scaling_factor, tile_tokens_dim,
126-
routing_method_type);
123+
top_k, std::nullopt, std::nullopt, local_expert_offset, local_num_experts, std::nullopt, tile_tokens_dim,
124+
RoutingMethodType::Renormalize);
127125
}
128126

129127
// Permute
@@ -259,9 +257,8 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
259257
"int? topk_group, int local_expert_offset, int local_num_experts, float? routed_scaling_factor, int "
260258
"tile_tokens_dim, int routing_method_type) -> Tensor[]");
261259
m.def(
262-
"moe_sort(Tensor token_selected_experts, Tensor token_final_scales, int num_experts, int top_k, int? n_group, "
263-
"int? topk_group, int local_expert_offset, int local_num_experts, float? routed_scaling_factor, int "
264-
"tile_tokens_dim, int routing_method_type) -> Tensor[]");
260+
"moe_sort(Tensor token_selected_experts, Tensor token_final_scales, int num_experts, int top_k, "
261+
"int local_expert_offset, int local_num_experts, int tile_tokens_dim) -> Tensor[]");
265262
m.def(
266263
"moe_permute(Tensor input, Tensor? input_sf, Tensor tile_idx_to_mn_limit, Tensor permuted_idx_to_expanded_idx, "
267264
"Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k) -> (Tensor, Tensor?)");

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ...utils import AuxStreamType, Fp4QuantizedTensor, ceil_div
1212
from .fused_moe_cutlass import CutlassFusedMoE
1313
from .quantization import MoEWeightLoadingMode
14-
from .routing import BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod
14+
from .routing import BaseMoeRoutingMethod
1515

1616

1717
@torch.compile(options={"max-autotune": True})
@@ -358,31 +358,15 @@ def forward_chunk_nvfp4(
358358
dim=0,
359359
sizes=None if use_dp_padding else all_rank_num_tokens)
360360

361-
# DeepSeekV3 style routing
362-
if isinstance(self.routing_method, DeepSeekV3MoeRoutingMethod):
363-
top_k = self.routing_method.routing_impl.top_k
364-
n_group = self.routing_method.routing_impl.n_group
365-
topk_group = self.routing_method.routing_impl.topk_group
366-
routed_scaling_factor = self.routing_method.routing_impl.routed_scaling_factor
367-
else:
368-
top_k = self.routing_method.top_k
369-
n_group = None
370-
topk_group = None
371-
routed_scaling_factor = None
372-
373361
tile_size = 128
374362
tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx, permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles = torch.ops.trtllm.moe_sort(
375363
token_selected_experts=token_selected_experts,
376364
token_final_scales=token_final_scales,
377365
num_experts=self.num_slots,
378-
top_k=top_k,
379-
n_group=n_group,
380-
topk_group=topk_group,
366+
top_k=self.routing_method.experts_per_token,
381367
local_expert_offset=self.slot_start,
382368
local_num_experts=self.expert_size_per_partition,
383-
routed_scaling_factor=routed_scaling_factor,
384369
tile_tokens_dim=tile_size,
385-
routing_method_type=self.routing_method.routing_method_type,
386370
)
387371

388372
permuted_x, permuted_sf = torch.ops.trtllm.moe_permute(
@@ -392,7 +376,7 @@ def forward_chunk_nvfp4(
392376
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
393377
num_non_exiting_tiles=num_non_exiting_tiles,
394378
tile_tokens_dim=tile_size,
395-
top_k=top_k,
379+
top_k=self.routing_method.experts_per_token,
396380
)
397381
h1 = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
398382
input=permuted_x.view(torch.float4_e2m1fn_x2),

tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import torch
33

44
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import cute_dsl_nvfp4_grouped_gemm_ref
5-
from tensorrt_llm._torch.modules.fused_moe.routing import RoutingMethodType
65
from tensorrt_llm._torch.utils import unswizzle_sf
76
from tensorrt_llm._utils import get_sm_version
87

@@ -45,13 +44,9 @@ def test_moe_sort(num_tokens: int, top_k: int, ep_size: int, tile_size: int):
4544
token_final_scales=token_final_scales,
4645
num_experts=num_experts,
4746
top_k=top_k,
48-
n_group=1,
49-
topk_group=1,
5047
local_expert_offset=0,
5148
local_num_experts=num_local_experts,
52-
routed_scaling_factor=1.0,
5349
tile_tokens_dim=tile_size,
54-
routing_method_type=RoutingMethodType.DeepSeekV3,
5550
)
5651

5752
num_tokens_per_expert = torch.bincount(token_selected_experts.flatten(), minlength=num_experts)

0 commit comments

Comments
 (0)