Commit a9f456e

feat - support normal no quant multi-dp moe strategy
1 parent fc58ee8 commit a9f456e

File tree

7 files changed: +192 −25 lines

rtp_llm/models_py/modules/factory/attention/cuda_impl/flash_infer.py

Lines changed: 3 additions & 7 deletions

@@ -18,9 +18,7 @@
 class FlashInferPrefillImpl(FMHAPrefillImplBase):
 
     def __init__(
-        self,
-        attn_configs: AttentionConfigs,
-        attn_inputs: PyAttentionInputs
+        self, attn_configs: AttentionConfigs, attn_inputs: PyAttentionInputs
     ) -> None:
         super().__init__(
             FlashInferPrefillOp(attn_configs),
@@ -40,16 +38,14 @@ def support_cuda_graph(self) -> bool:
 class FlashInferDecodeImpl(FMHADecodeImplBase):
 
     def __init__(
-        self,
-        attn_configs: AttentionConfigs,
-        attn_inputs: PyAttentionInputs
+        self, attn_configs: AttentionConfigs, attn_inputs: PyAttentionInputs
     ) -> None:
         super().__init__(
             FlashInferDecodeOp(attn_configs),
             FusedRopeKVCacheDecodeOp(attn_configs),
             attn_inputs,
         )
-        self.seq_size_per_block = config.seq_size_per_block
+        self.seq_size_per_block = attn_configs.tokens_per_block
         self.support_ = self.support_ and (not attn_configs.use_mla)
 
     @staticmethod

rtp_llm/models_py/modules/factory/fused_moe/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -61,6 +61,7 @@
     CudaFp8PerTensorEpLowLatencyStrategy,
     CudaFp8PerTensorEpNormalStrategy,
     CudaFp8PerTensorNoDPStrategy,
+    CudaNoQuantDpNormalStrategy,
     CudaNoQuantEpLowLatencyStrategy,
 )
 
@@ -72,5 +73,6 @@
     registry.register(CudaFp8PerBlockNoDPStrategy())
     registry.register(CudaFp8PerTensorNoDPStrategy())
     registry.register(CudaNoQuantEpLowLatencyStrategy())
+    registry.register(CudaNoQuantDpNormalStrategy())
     registry.register(BatchedTritonStrategy())
     FusedMoeFactory.set_registry(registry)
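
For readers skimming the registration above: each strategy is appended to a shared registry, and the factory can later pick a registered strategy whose conditions (see check_conditions on the executors) accept the current config. A minimal, self-contained sketch of one way such a registry could work (hypothetical classes, not the actual FusedMoeFactory/registry API):

# Sketch of a strategy registry (hypothetical names, illustration only).
from typing import Any, List, Optional

class SketchStrategy:
    def supports(self, config: Any) -> bool:
        raise NotImplementedError

class SketchRegistry:
    def __init__(self) -> None:
        self._strategies: List[SketchStrategy] = []

    def register(self, strategy: SketchStrategy) -> None:
        # Keep strategies in registration order.
        self._strategies.append(strategy)

    def select(self, config: Any) -> Optional[SketchStrategy]:
        # Return the first registered strategy that accepts this config.
        for strategy in self._strategies:
            if strategy.supports(config):
                return strategy
        return None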
rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/f16_cpp_executor.py (new file)

Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
+# Adapt from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/moe/ep_moe/kernels.py
+# but make some modifications for RTP-LLM
+# Licensed under the Apache License, Version 2.0
+from typing import Any, Dict, Optional
+
+import torch
+
+import rtp_llm.models_py.modules.factory.fused_moe.defs.fused_moe as mm
+from rtp_llm.models_py.modules.factory.fused_moe.defs.config_adapter import (
+    MoEConfigAdapter,
+)
+from rtp_llm.models_py.modules.factory.fused_moe.defs.quant_config import (
+    FusedMoEQuantConfig,
+)
+from rtp_llm.models_py.modules.factory.fused_moe.defs.type import ExecutorType
+from rtp_llm.models_py.modules.factory.fused_moe.utils.config_resolver import (
+    MoeConfigResolver,
+)
+from rtp_llm.ops.compute_ops import FusedMoEOp
+from rtp_llm.utils.model_weight import W
+
+
+class CppMoeExecutor(mm.FusedMoeExpertExecutor):
+    @classmethod
+    def executor_type(cls):
+        return ExecutorType.FUSED_MOE
+
+    @classmethod
+    def check_conditions(cls, checker: Any, config: MoEConfigAdapter) -> None:
+        resolver = MoeConfigResolver()
+        checker.check(not resolver.has_quantization(config))
+
+    def __init__(
+        self,
+        config: MoEConfigAdapter,
+        weights: Dict[str, torch.Tensor],
+    ):
+        super().__init__(FusedMoEQuantConfig())
+        self.config = config
+        self.ep_size = config.ep_size
+        self.ep_rank = config.ep_rank
+        self.num_experts = config.expert_num
+        assert self.num_experts % self.ep_size == 0
+        self.num_experts_per_partition = self.num_experts // self.ep_size
+        self.start_expert_id = self.ep_rank * self.num_experts_per_partition
+        self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
+        self.top_k = config.moe_k
+        self.intermediate_size = config.model_config.moe_inter_size
+        self.activation = config.activation_type
+        self.renormalize = True
+        self.use_fp8_w8a8 = True
+        self.use_block_quant = True
+        # Weight initialization
+        self.w13_weight = weights[W.moe_w1]
+        self.w2_weight = weights[W.moe_w2]
+        self.moe_op = FusedMoEOp(config.model_config, config.parallelism_config)
+
+    @property
+    def topk_ids_dtype(self) -> torch.dtype:
+        return torch.int32
+
+    @property
+    def local_num_experts(self) -> int:
+        return self.num_experts_per_partition
+
+    def execute(
+        self,
+        payload: mm.ExpertForwardPayload,
+        activation: str,
+        expert_map: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
+        extra_expert_args: Optional[dict[str, Any]],
+    ) -> torch.Tensor:
+        output = torch.zeros_like(payload.expert_x)
+        assert payload.expert_topk_weights is not None, "expert_topk_weights is None"
+        assert payload.expert_topk_ids is not None, "expert_topk_ids is None"
+        payload.expert_topk_ids = payload.expert_topk_ids.to(torch.int32)
+        self.moe_op.forward(
+            payload.expert_x,
+            self.w13_weight,
+            self.w2_weight,
+            payload.expert_topk_weights,
+            payload.expert_topk_ids,
+            output,
+        )
+        return output

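The expert-partition bookkeeping in __init__ above is plain integer arithmetic: the global expert table is split evenly across EP ranks and each rank owns one contiguous slice. A standalone restatement of that arithmetic (hypothetical helper, illustration only; it mirrors the fields set in CppMoeExecutor.__init__):

# Sketch of the EP partition arithmetic used in CppMoeExecutor.__init__ above.
def local_expert_range(num_experts: int, ep_size: int, ep_rank: int) -> tuple[int, int]:
    assert num_experts % ep_size == 0, "experts must divide evenly across EP ranks"
    per_rank = num_experts // ep_size   # num_experts_per_partition
    start_id = ep_rank * per_rank       # start_expert_id
    end_id = start_id + per_rank - 1    # end_expert_id (inclusive)
    return start_id, end_id

# Example: 64 experts over ep_size=8 -> rank 3 owns global experts 24..31.
assert local_expert_range(64, 8, 3) == (24, 31)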
rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/routers/deepep_normal_router.py

Lines changed: 14 additions & 2 deletions

@@ -9,6 +9,9 @@
     scaled_fp8_per_token_quant,
     sgl_per_token_group_quant_fp8,
 )
+from rtp_llm.models_py.modules.factory.fused_moe.defs.config_adapter import (
+    MoEConfigAdapter,
+)
 from rtp_llm.models_py.modules.factory.fused_moe.defs.fused_moe import (
     ExpertForwardPayload,
     ExpertTokensMetadata,
@@ -19,7 +22,7 @@
 )
 from rtp_llm.models_py.modules.factory.fused_moe.defs.type import RouterType
 from rtp_llm.ops.compute_ops import trt_fp8_quantize_128
-from rtp_llm.models_py.modules.factory.fused_moe.defs.config_adapter import MoEConfigAdapter
+
 
 class DeepepNormalRouter(FusedMoeDataRouter):
     @classmethod
@@ -157,12 +160,21 @@ def prepare(
             num_recv_tokens_per_expert_list, device=expert_x.device, dtype=torch.int32
         )
 
+        if recv_topk_idx.numel() != 0 and (not self.use_fp8):
+            expert_topk_ids = torch.where(
+                recv_topk_idx == -1,
+                self.expert_num - 1 if self.rank_expert_offset == 0 else 0,
+                recv_topk_idx + self.rank_expert_offset,
+            )
+        else:
+            expert_topk_ids = recv_topk_idx
+
         return ExpertForwardPayload(
            expert_x,
            act_dtype,
            expert_x_scale,
            ExpertTokensMetadata(expert_num_tokens, num_recv_tokens_per_expert_list),
-            recv_topk_idx,
+            expert_topk_ids,
            recv_topk_weights,
        )

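The branch added in prepare() shifts the locally received top-k indices by this rank's expert offset to get global expert ids; the -1 entries that pad unrouted slots are redirected to an expert id outside this rank's own partition (the last expert when the rank hosts partition 0, expert 0 otherwise) so they stay valid indices. A toy reproduction with made-up offsets and indices (illustration only):

# Toy reproduction of the expert_topk_ids remapping added above (assumed values).
import torch

rank_expert_offset = 16   # assumed: this rank hosts global experts 16..23
expert_num = 64           # assumed total expert count
recv_topk_idx = torch.tensor([[0, 3, -1], [7, -1, 2]])  # -1 marks padded slots

expert_topk_ids = torch.where(
    recv_topk_idx == -1,
    # Padded slots point at an expert this rank does not own.
    expert_num - 1 if rank_expert_offset == 0 else 0,
    recv_topk_idx + rank_expert_offset,
)
print(expert_topk_ids)  # tensor([[16, 19,  0], [23,  0, 18]])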
rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/strategy/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -10,11 +10,12 @@
     CudaFp8PerTensorEpNormalStrategy,
     CudaFp8PerTensorNoDPStrategy,
 )
-from .no_quant import CudaNoQuantEpLowLatencyStrategy
+from .no_quant import CudaNoQuantDpNormalStrategy, CudaNoQuantEpLowLatencyStrategy
 
 __all__ = [
     # No quantization
     "CudaNoQuantEpLowLatencyStrategy",
+    "CudaNoQuantDpNormalStrategy",
     # FP8 PerBlock
     "CudaFp8PerBlockNoDPStrategy",
     "CudaFp8PerBlockEpLowLatencyStrategy",

rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/strategy/no_quant.py

Lines changed: 39 additions & 1 deletion

@@ -4,7 +4,9 @@
 
 import torch
 
-from rtp_llm.models_py.modules.factory.fused_moe.defs.config_adapter import MoEConfigAdapter
+from rtp_llm.models_py.modules.factory.fused_moe.defs.config_adapter import (
+    MoEConfigAdapter,
+)
 from rtp_llm.models_py.modules.factory.fused_moe.defs.priority_attributes import (
     StrategyAttributes,
 )
@@ -67,3 +69,39 @@ def get_attributes(self) -> StrategyAttributes:
             router_class=DeepEpLowLatencyRouter,
             executor_class=DeepGemmMaskedExecutor,
         )
+
+
+class CudaNoQuantDpNormalStrategy(MoeStrategy):
+    """CUDA CPP mode without quantization strategy and dp normal mode"""
+
+    def create_router(self, config: MoEConfigAdapter) -> Any:
+        from rtp_llm.models_py.modules.factory.fused_moe.impl.cuda.routers.deepep_normal_router import (
+            DeepepNormalRouter,
+        )
+
+        return DeepepNormalRouter(config, use_fp8=False)
+
+    def create_executor(
+        self, config: MoEConfigAdapter, weights: Dict[str, torch.Tensor]
+    ) -> Any:
+        from rtp_llm.models_py.modules.factory.fused_moe.impl.cuda.executors.f16_cpp_executor import (
+            CppMoeExecutor,
+        )
+
+        return CppMoeExecutor(
+            config,
+            weights,
+        )
+
+    def get_attributes(self) -> StrategyAttributes:
+        from rtp_llm.models_py.modules.factory.fused_moe.impl.cuda.executors.f16_cpp_executor import (
+            CppMoeExecutor,
+        )
+        from rtp_llm.models_py.modules.factory.fused_moe.impl.cuda.routers.deepep_normal_router import (
+            DeepepNormalRouter,
+        )
+
+        return StrategyAttributes(
+            router_class=DeepepNormalRouter,
+            executor_class=CppMoeExecutor,
+        )

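The new strategy only wires existing pieces together: DeepepNormalRouter with use_fp8=False plus the new CppMoeExecutor, with the imports deferred into the methods (presumably to keep package-import time light or avoid import cycles). A hypothetical driver showing how a caller might consume the three methods above (illustration only; the real FusedMoeFactory plumbing is outside this diff):

# Hypothetical driver, illustration only.
def build_no_quant_dp_normal(strategy, config, weights):
    router = strategy.create_router(config)               # DeepepNormalRouter(config, use_fp8=False)
    executor = strategy.create_executor(config, weights)  # CppMoeExecutor(config, weights)
    attrs = strategy.get_attributes()                     # advertises the router/executor classes
    assert isinstance(router, attrs.router_class)
    assert isinstance(executor, attrs.executor_class)
    return router, executor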
rtp_llm/models_py/modules/hybrid/test/mla_reuse_cache_test.py

Lines changed: 45 additions & 14 deletions

@@ -17,12 +17,12 @@
 from rtp_llm.models.rotary_embedding.deepseek_rotary_embedding import (
     DeepseekV3YarnRotaryEmbedding,
 )
-from rtp_llm.ops import ParallelismConfig
 from rtp_llm.models_py.modules import LinearFactory
 from rtp_llm.models_py.modules.factory.attention.cuda_mla_impl.flashinfer_mla_wrapper import (
     MlaFlashInferPrefillImpl,
 )
 from rtp_llm.models_py.modules.hybrid.test.mla_attention_ref import attention_ref
+from rtp_llm.ops import ParallelismConfig
 from rtp_llm.ops.compute_ops import KVCache, PyAttentionInputs
 from rtp_llm.utils.model_weight import W
 
@@ -115,8 +115,11 @@ def _run_mla_test(
         self.config.attn_config.softmax_extra_scale = 1.0
         self.config.attn_config.use_mla = True
         self.config.attn_config.size_per_head = 192
-        self.scaling = (self.config.attn_config.nope_head_dim + self.config.attn_config.rope_head_dim) ** (-0.5)
-
+        self.scaling = (
+            self.config.attn_config.nope_head_dim
+            + self.config.attn_config.rope_head_dim
+        ) ** (-0.5)
+
         self.parallelism_config = ParallelismConfig()
         self.parallelism_config.tp_size = 1
         self.parallelism_config.tp_rank = 0
@@ -146,15 +149,20 @@ def _run_mla_test(
         cos_sin_cache = create_cos_sin_cache()
 
         fmha_impl = MlaFlashInferPrefillImpl(
-            self.config.attn_config, attn_inputs, layer_weights, cos_sin_cache, quant_config=self.config.quant_config
+            self.config.attn_config,
+            attn_inputs,
+            layer_weights,
+            cos_sin_cache,
+            quant_config=self.config.quant_config,
         )
         fmha_impl.prepare(attn_inputs)
 
         q = torch.randn(
             [
                 num_tokens,
                 self.config.attn_config.head_num,
-                self.config.attn_config.nope_head_dim + self.config.attn_config.rope_head_dim,
+                self.config.attn_config.nope_head_dim
+                + self.config.attn_config.rope_head_dim,
             ],
             dtype=torch.bfloat16,
             device=device,
@@ -176,7 +184,8 @@ def _run_mla_test(
             [
                 mock_page_num,
                 page_size,
-                self.config.attn_config.kv_lora_rank + self.config.attn_config.rope_head_dim,
+                self.config.attn_config.kv_lora_rank
+                + self.config.attn_config.rope_head_dim,
             ],
             dtype=torch.bfloat16,
             device=device,
@@ -187,7 +196,10 @@ def _run_mla_test(
 
         k_cache, v_cache = torch.split(
             kv_cache.k_cache_base,
-            [self.config.attn_config.kv_lora_rank, self.config.attn_config.rope_head_dim],
+            [
+                self.config.attn_config.kv_lora_rank,
+                self.config.attn_config.rope_head_dim,
+            ],
             dim=-1,
         )
         page.append_paged_mla_kv_cache(
@@ -197,7 +209,7 @@ def _run_mla_test(
             fmha_impl.rope_params.positions_d,
             k_cache,
             v_cache,
-            fmha_impl.rope_kvcache_impl.cuda_graph_kv_indices,
+            fmha_impl.rope_params.page_indice_d,
             fmha_impl.rope_params.decode_page_indptr_d,
             fmha_impl.rope_params.paged_kv_last_page_len_d,
         )
@@ -228,15 +240,18 @@ def _run_mla_test(
         k_nope = self.k_nope_proj(compressed_kv)
         value_states = self.v_proj(compressed_kv)
 
-        k_nope = k_nope.view(-1, self.config.attn_config.head_num, self.config.attn_config.nope_head_dim)
+        k_nope = k_nope.view(
+            -1, self.config.attn_config.head_num, self.config.attn_config.nope_head_dim
+        )
         value_states = value_states.view(
             -1, self.config.attn_config.head_num, self.config.attn_config.v_head_dim
         )
 
         k = k_pe.new_empty(
             k_pe.size(0),
             self.config.attn_config.head_num,
-            self.config.attn_config.rope_head_dim + self.config.attn_config.nope_head_dim,
+            self.config.attn_config.rope_head_dim
+            + self.config.attn_config.nope_head_dim,
         )
         k[..., : self.config.attn_config.nope_head_dim] = k_nope
         k[..., self.config.attn_config.nope_head_dim :] = k_pe
@@ -285,13 +300,21 @@ def _create_weights(self, config, hidden_size):
         )
 
         weights[W.mla_kc] = torch.randn(
-            [config.attn_config.head_num, config.attn_config.nope_head_dim, config.attn_config.kv_lora_rank],
+            [
+                config.attn_config.head_num,
+                config.attn_config.nope_head_dim,
+                config.attn_config.kv_lora_rank,
+            ],
             dtype=torch.bfloat16,
             device=device,
         )
 
         weights[W.mla_vc] = torch.randn(
-            [config.attn_config.head_num, config.attn_config.kv_lora_rank, config.attn_config.v_head_dim],
+            [
+                config.attn_config.head_num,
+                config.attn_config.kv_lora_rank,
+                config.attn_config.v_head_dim,
+            ],
             dtype=torch.bfloat16,
             device=device,
         )
@@ -310,13 +333,21 @@ def _create_weights(self, config, hidden_size):
 
         weights[W.mla_kc] = (
             weights[W.mla_k_nope_w]
-            .view(config.attn_config.kv_lora_rank, config.attn_config.head_num, config.attn_config.nope_head_dim)
+            .view(
+                config.attn_config.kv_lora_rank,
+                config.attn_config.head_num,
+                config.attn_config.nope_head_dim,
+            )
             .transpose(0, 1)
             .transpose(1, 2)
         )
         weights[W.mla_vc] = (
             weights[W.mla_v_w]
-            .view(config.attn_config.kv_lora_rank, config.attn_config.head_num, config.attn_config.v_head_dim)
+            .view(
+                config.attn_config.kv_lora_rank,
+                config.attn_config.head_num,
+                config.attn_config.v_head_dim,
+            )
             .transpose(0, 1)
         )
