@@ -30,7 +30,7 @@ std::vector<torch::Tensor> moe_topk_sort_impl(torch::optional<torch::Tensor> con
3030 torch::optional<torch::Tensor> const & token_final_scales, int64_t const num_experts, int64_t const top_k,
3131 std::optional<int64_t > const n_group, std::optional<int64_t > const topk_group, int64_t const local_expert_offset,
3232 int64_t const local_num_experts, std::optional<double > const routed_scaling_factor, int64_t const tile_tokens_dim,
33- int64_t const routing_method_type)
33+ RoutingMethodType const routing_method_type)
3434{
3535 int64_t const num_tokens
3636 = token_selected_experts.has_value () ? token_selected_experts->size (0 ) : routing_logits->size (0 );
@@ -80,7 +80,7 @@ std::vector<torch::Tensor> moe_topk_sort_impl(torch::optional<torch::Tensor> con
8080 num_tokens_per_expert.data_ptr <int >(), tile_idx_to_expert_idx.data_ptr <int >(),
8181 tile_idx_to_mn_limit.data_ptr <int >(), num_non_exiting_tiles.data_ptr <int >(),
8282 batchedGemm::trtllm::gen::Dtype::Void /* dtypeElt */ , false /* use_routing_scales_on_input */ ,
83- false /* use_deep_seek_fp8 */ , static_cast <RoutingMethodType>( routing_method_type) , stream);
83+ false /* use_deep_seek_fp8 */ , routing_method_type, stream);
8484
8585 std::vector<torch::Tensor> results{tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx,
8686 permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles};
@@ -106,14 +106,12 @@ std::vector<torch::Tensor> moe_topk_sort(torch::Tensor const& routing_logits,
106106 }
107107 return moe_topk_sort_impl (routing_logits, routing_bias, std::nullopt , std::nullopt , num_experts, top_k, n_group,
108108 topk_group, local_expert_offset, local_num_experts, routed_scaling_factor, tile_tokens_dim,
109- routing_method_type);
109+ static_cast <RoutingMethodType>( routing_method_type) );
110110}
111111
112112std::vector<torch::Tensor> moe_sort (torch::Tensor const & token_selected_experts,
113113 torch::Tensor const & token_final_scales, int64_t const num_experts, int64_t const top_k,
114- std::optional<int64_t > const n_group, std::optional<int64_t > const topk_group, int64_t const local_expert_offset,
115- int64_t const local_num_experts, std::optional<double > const routed_scaling_factor, int64_t const tile_tokens_dim,
116- int64_t const routing_method_type)
114+ int64_t const local_expert_offset, int64_t const local_num_experts, int64_t const tile_tokens_dim)
117115{
118116 TORCH_CHECK (token_selected_experts.dim () == 2 , " token_selected_experts must be 2D." );
119117 int64_t const num_tokens = token_selected_experts.size (0 );
@@ -122,8 +120,8 @@ std::vector<torch::Tensor> moe_sort(torch::Tensor const& token_selected_experts,
122120 TORCH_CHECK (token_final_scales.size (0 ) == num_tokens, " token_final_scales.size(0) must be num_tokens." );
123121 TORCH_CHECK (token_final_scales.size (1 ) == top_k, " token_final_scales.size(1) must be top_k." );
124122 return moe_topk_sort_impl (std::nullopt , std::nullopt , token_selected_experts, token_final_scales, num_experts,
125- top_k, n_group, topk_group , local_expert_offset, local_num_experts, routed_scaling_factor , tile_tokens_dim,
126- routing_method_type );
123+ top_k, std:: nullopt , std:: nullopt , local_expert_offset, local_num_experts, std:: nullopt , tile_tokens_dim,
124+ RoutingMethodType::Renormalize );
127125}
128126
129127// Permute
@@ -259,9 +257,8 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
259257 " int? topk_group, int local_expert_offset, int local_num_experts, float? routed_scaling_factor, int "
260258 " tile_tokens_dim, int routing_method_type) -> Tensor[]" );
261259 m.def (
262- " moe_sort(Tensor token_selected_experts, Tensor token_final_scales, int num_experts, int top_k, int? n_group, "
263- " int? topk_group, int local_expert_offset, int local_num_experts, float? routed_scaling_factor, int "
264- " tile_tokens_dim, int routing_method_type) -> Tensor[]" );
260+ " moe_sort(Tensor token_selected_experts, Tensor token_final_scales, int num_experts, int top_k, "
261+ " int local_expert_offset, int local_num_experts, int tile_tokens_dim) -> Tensor[]" );
265262 m.def (
266263 " moe_permute(Tensor input, Tensor? input_sf, Tensor tile_idx_to_mn_limit, Tensor permuted_idx_to_expanded_idx, "
267264 " Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k) -> (Tensor, Tensor?)" );
0 commit comments