Commit 6021a43

Make moe permute and final as custom op (NVIDIA#5412)
Signed-off-by: Mindy Li <[email protected]>
1 parent 5773cfd commit 6021a43

File tree

11 files changed: +2100 −2 lines changed

cpp/tensorrt_llm/kernels/moeUtilOp.cu

Lines changed: 893 additions & 0 deletions
Large diffs are not rendered by default.
cpp/tensorrt_llm/kernels/moeUtilOp.h

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "cutlass_kernels/include/moe_kernels.h"
+#include "tensorrt_llm/common/cudaUtils.h"
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+namespace tensorrt_llm::kernels
+{
+bool fusedBuildExpertMapsSortFirstToken(int const* token_selected_experts, int* unpermuted_token_selected_experts,
+    int* permuted_source_token_ids, int64_t* expert_first_token_offset, int64_t const num_tokens,
+    int const num_experts_per_node, int const experts_per_token, int const start_expert, int const end_expert,
+    cudaStream_t stream);
+
+void buildExpertMaps(int const* token_selected_experts, int* unpermuted_token_selected_experts,
+    int* unpermuted_source_token_ids, int64_t const num_tokens, int const num_experts_per_node,
+    int const experts_per_token, int const start_expert, int const end_expert, cudaStream_t stream);
+
+void generateTokenPermutation(int const* unpermuted_token_selected_experts, int const* unpermuted_source_token_ids,
+    int* permuted_token_selected_experts, int* permuted_source_token_ids, int64_t* expert_first_token_offset,
+    int64_t num_rows, int64_t num_experts_per_node, int64_t k, cutlass_kernels::CubKeyValueSorter& sorter,
+    void* sorter_ws, cudaStream_t stream);
+
+template <class InputActivationsType, class ExpandedActivationsType>
+void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
+    ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales,
+    int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row,
+    int64_t const num_rows, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
+    int const num_experts_per_node, float const* fc1_act_global_scale, int64_t* expert_first_token_offset,
+    cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
+    cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream);
+
+template <class OutputType, class GemmOutputType, class ScaleBiasType>
+void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_rows,
+    OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* final_scales,
+    int const* expanded_source_row_to_expanded_dest_row, int const* expert_for_source_row, int64_t const num_rows,
+    int64_t const cols, int64_t const experts_per_token, int64_t const* num_valid_ptr,
+    cutlass_kernels::MOEParallelismConfig parallelism_config, cudaStream_t stream);
+
+} // namespace tensorrt_llm::kernels
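
These declarations decompose MoE routing into stages: build the token-to-expert maps (buildExpertMaps, or the fused fusedBuildExpertMapsSortFirstToken fast path), sort the expanded token slots so each expert's rows are contiguous (generateTokenPermutation, backed by cutlass_kernels::CubKeyValueSorter), scatter the input rows into that order (expandInputRowsKernelLauncher), and finally un-permute and scale-reduce the expert GEMM output (finalizeMoeRoutingKernelLauncher). As a mental model of the metadata the sort stage produces, here is a hedged CPU reference sketch in PyTorch; it is illustrative only, and the stable-sort tie-breaking and offset layout are assumptions read off the declaration names, not the CUDA implementation:

import torch

def reference_token_permutation(token_selected_experts: torch.Tensor,
                                num_experts_per_node: int):
    # token_selected_experts: [num_tokens, k], one expert id per (token, slot).
    num_tokens, k = token_selected_experts.shape
    flat_experts = token_selected_experts.reshape(-1).to(torch.int64)
    source_token_ids = torch.arange(num_tokens * k, dtype=torch.int32)

    # A stable sort by expert id groups all slots of one expert together,
    # mirroring what the CUB-based sorter computes on device.
    sorted_experts, order = torch.sort(flat_experts, stable=True)
    permuted_source_token_ids = source_token_ids[order]

    # expert_first_token_offset[e] = first permuted slot owned by expert e;
    # the final entry equals the total number of expanded rows.
    counts = torch.bincount(sorted_experts, minlength=num_experts_per_node)
    expert_first_token_offset = torch.zeros(num_experts_per_node + 1,
                                            dtype=torch.int64)
    expert_first_token_offset[1:] = torch.cumsum(counts, dim=0)
    return sorted_experts, permuted_source_token_ids, expert_first_token_offset

For example, token_selected_experts = [[0, 2], [1, 0]] with num_experts_per_node = 4 gives expert_first_token_offset = [0, 2, 3, 4, 4]: expert 0 owns permuted slots 0 and 1, expert 1 owns slot 2, expert 2 owns slot 3, and expert 3 owns no rows.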

cpp/tensorrt_llm/kernels/quantization.cuh

Lines changed: 1 addition & 1 deletion
@@ -275,7 +275,7 @@ __global__ void perTokenQuantization(QuantT* dst, T const* src, int64_t const nu
 // FP4 Quantization

 constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
-// constexpr int CVT_FP4_SF_VEC_SIZE = 16;
+constexpr int CVT_FP4_SF_VEC_SIZE = 16;
 constexpr int CVT_FP4_THREADS_PER_WARP = 32;
 constexpr int CVT_FP8_TO_FP4_ELTS_PER_THREAD = 16;
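
Un-commenting CVT_FP4_SF_VEC_SIZE makes the FP4 scale-factor group size a live constant: one scale factor covers 16 consecutive FP4 elements, which the new permute path presumably needs when sizing the flat fc1_act_sf_flat / input_sf buffers it shuffles alongside the activations. Quick illustrative arithmetic, with every size made up:

CVT_FP4_SF_VEC_SIZE = 16   # elements covered by one FP4 scale factor (from the diff)

num_expanded_rows = 256    # hypothetical: num_tokens * experts_per_token
hidden_size = 4096         # hypothetical activation width

sf_per_row = hidden_size // CVT_FP4_SF_VEC_SIZE   # 256 scale factors per row
total_sf = num_expanded_rows * sf_per_row         # 65536 entries in a flat SF buffer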

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ add_library(
   logitsBitmaskOp.cpp
   mambaConv1dOp.cpp
   moeOp.cpp
+  moeUtilOp.cpp
   moeCommOp.cpp
   moeLoadBalanceOp.cpp
   fp8BlockScaleMoe.cpp

cpp/tensorrt_llm/thop/moeUtilOp.cpp

Lines changed: 449 additions & 0 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py

Lines changed: 76 additions & 0 deletions
@@ -397,3 +397,79 @@ def _(
     pad_slot_id: int,
 ) -> None:
     pass
+
+@torch.library.register_fake("trtllm::moe_permute_op")
+def _(
+    input: torch.Tensor,
+    token_selected_experts: torch.Tensor,
+    token_final_scales: torch.Tensor,
+    fc1_expert_weights: torch.Tensor,
+    fc2_expert_weights: torch.Tensor,
+    quant_scales: List[torch.Tensor],
+    input_sf: Optional[torch.Tensor],
+    num_experts_per_node: int,
+    tp_size: int,
+    tp_rank: int,
+    ep_size: int,
+    ep_rank: int,
+    cluster_size: int,
+    cluster_rank: int,
+    min_latency_mode: bool,
+    use_fp8_block_scaling: bool,
+):
+
+    experts_per_token = token_selected_experts.shape[1]
+    num_rows = input.shape[0]
+    hidden_size = input.shape[1]
+
+    num_moe_inputs = experts_per_token * num_rows
+
+    unpermuted_token_selected_experts_tensor = token_selected_experts.new_empty(
+        (num_moe_inputs, ), dtype=torch.int32)
+    unpermuted_source_token_ids_tensor = token_selected_experts.new_empty(
+        (num_moe_inputs, ), dtype=torch.int32)
+    permuted_source_token_ids_tensor = token_selected_experts.new_empty(
+        (num_moe_inputs, ), dtype=torch.int32)
+    permuted_token_selected_experts_tensor = token_selected_experts.new_empty(
+        (num_moe_inputs, ), dtype=torch.int32)
+    permuted_data_tensor = input.new_empty((num_moe_inputs, hidden_size),
+                                           dtype=torch.float32)
+    expert_first_token_offset_tensor = token_selected_experts.new_empty(
+        (num_experts_per_node + 1, ), dtype=torch.int64)
+    permuted_token_final_scales_tensor = token_selected_experts.new_empty(
+        (num_moe_inputs, ), dtype=torch.float32)
+    src_to_dest_map_tensor = token_selected_experts.new_empty(
+        (num_moe_inputs, ), dtype=torch.int32)
+
+    return (
+        unpermuted_token_selected_experts_tensor,
+        unpermuted_source_token_ids_tensor,
+        permuted_source_token_ids_tensor,
+        permuted_token_selected_experts_tensor,
+        permuted_data_tensor,
+        expert_first_token_offset_tensor,
+        permuted_token_final_scales_tensor,
+        src_to_dest_map_tensor,
+    )
+
+@torch.library.register_fake("trtllm::moe_finalize_scale_op")
+def _(
+    gemm2_output: torch.Tensor,
+    fc2_expert_biases: torch.Tensor,
+    unpermuted_final_scales: torch.Tensor,
+    expanded_source_row_to_expanded_dest_row: torch.Tensor,
+    expert_for_source_row: torch.Tensor,
+    expert_first_token_offset_tensor: torch.Tensor,
+    num_rows: torch.SymInt,
+    hidden_size: torch.SymInt,
+    experts_per_token: int,
+    num_experts_per_node: int,
+    tp_size: int,
+    tp_rank: int,
+    ep_size: int,
+    ep_rank: int,
+):
+    num_rows_val = int(num_rows)
+    hidden_size_val = int(hidden_size)
+    return gemm2_output.new_empty((num_rows_val, hidden_size_val),
+                                  dtype=gemm2_output.dtype)
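
These fake (meta) kernels give the two new custom ops shape and dtype semantics, which is what lets FakeTensor tracing and torch.compile propagate shapes through the graph without launching any CUDA kernel. A hedged smoke test of the idea; it assumes TensorRT-LLM's C++ op library is loaded so that torch.ops.trtllm.moe_permute_op resolves, and every tensor size below is made up:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # 32 tokens, hidden size 4096, top-8 routing, 64 experts on this node.
    hidden_states = torch.empty(32, 4096, dtype=torch.bfloat16)
    token_selected_experts = torch.empty(32, 8, dtype=torch.int32)
    token_final_scales = torch.empty(32, 8, dtype=torch.float32)
    fc1_w = torch.empty(64, 8192, 4096, dtype=torch.bfloat16)  # weight shapes are guesses
    fc2_w = torch.empty(64, 4096, 4096, dtype=torch.bfloat16)

    outs = torch.ops.trtllm.moe_permute_op(
        hidden_states, token_selected_experts, token_final_scales,
        fc1_w, fc2_w,
        [],      # quant_scales
        None,    # input_sf
        64,      # num_experts_per_node
        1, 0,    # tp_size, tp_rank
        1, 0,    # ep_size, ep_rank
        1, 0,    # cluster_size, cluster_rank
        False,   # min_latency_mode
        False,   # use_fp8_block_scaling
    )

    # Only the register_fake body above ran: no kernel launch, just shapes.
    assert outs[4].shape == (32 * 8, 4096)   # permuted_data_tensor
    assert outs[5].shape == (64 + 1,)        # expert_first_token_offset_tensor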

tensorrt_llm/_torch/modules/fused_moe/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,5 @@
 from .create_moe import create_moe, get_moe_cls
+from .fused_moe_cute_dsl import CuteDslFusedMoE
 from .fused_moe_cutlass import CutlassFusedMoE
 from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
 from .fused_moe_vanilla import VanillaMoE
@@ -17,6 +18,7 @@
 __all__ = [
     "BaseMoeRoutingMethod",
     "create_moe",
+    "CuteDslFusedMoE",
     "CutlassFusedMoE",
     "DeepSeekV3MoeRoutingMethod",
     "DefaultMoeRoutingMethod",

tensorrt_llm/_torch/modules/fused_moe/create_moe.py

Lines changed: 17 additions & 0 deletions
@@ -6,6 +6,7 @@
 from tensorrt_llm.models.modeling_utils import QuantConfig

 from ...model_config import ModelConfig
+from .fused_moe_cute_dsl import CuteDslFusedMoE
 from .fused_moe_cutlass import CutlassFusedMoE
 from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
 from .fused_moe_vanilla import VanillaMoE
@@ -28,6 +29,8 @@ def get_moe_cls(
         return CutlassFusedMoE
     elif moe_backend.upper() == "VANILLA":
         return VanillaMoE
+    elif moe_backend.upper() == "CUTEDSL":
+        return CuteDslFusedMoE
     elif moe_backend.upper() == "TRTLLM":
         if quant_config is not None and (
             quant_config.quant_mode.has_fp8_block_scales()
@@ -122,5 +125,19 @@ def create_moe(
             weight_loading_mode=weight_loading_mode,
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
+    elif moe_cls == CuteDslFusedMoE:
+        return moe_cls(
+            routing_method=routing_method,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            dtype=dtype,
+            reduce_results=reduce_results,
+            model_config=model_config,
+            aux_stream=aux_stream,
+            weight_loading_mode=weight_loading_mode,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            layer_idx=layer_idx,
+        )
     else:
         raise ValueError(f"Unsupported moe backend: {moe_cls}")
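
The new branch mirrors the CutlassFusedMoE construction above and forwards one extra argument, layer_idx. A hedged construction sketch; every value is a placeholder, and routing_method, model_config, aux_stream, and weight_loading_mode would come from the surrounding model code that this diff does not show:

import torch
from tensorrt_llm._torch.modules.fused_moe import CuteDslFusedMoE

# Placeholder arguments; only the keyword set is taken from the diff above.
moe = CuteDslFusedMoE(
    routing_method=routing_method,    # e.g. a DefaultMoeRoutingMethod instance
    num_experts=64,
    hidden_size=4096,
    intermediate_size=14336,
    dtype=torch.bfloat16,
    reduce_results=True,
    model_config=model_config,
    aux_stream=aux_stream,
    weight_loading_mode=weight_loading_mode,
    apply_router_weight_on_input=False,
    layer_idx=0,                      # the one kwarg the Cutlass branch does not pass
)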
