
Commit 4b70b7a

Qwen3 mtp dense (#1159)

Authored by: hiworldwzj
Co-authored-by: shihaobai <[email protected]>
Co-authored-by: wangzaijun <[email protected]>
Co-authored-by: root <[email protected]>

1 parent b686595 · commit 4b70b7a
File tree: 132 files changed (+1609 / −1501 lines)


docs/CN/source/tutorial/api_server_args_zh.rst

Lines changed: 5 additions & 3 deletions
@@ -447,10 +447,12 @@ MTP multi-prediction parameters (translated from Chinese)
 
 .. option:: --mtp_mode
 
-    Supported mtp modes; deepseekv3_eagle is recommended for a better performance experience. Optional values:
+    Supported mtp modes; eagle_with_att is recommended for a better performance experience. Optional values:
 
-    * ``deepseekv3_vanilla``
-    * ``deepseekv3_eagle``
+    * ``vanilla_with_att``
+    * ``eagle_with_att``
+    * ``vanilla_no_att``
+    * ``eagle_no_att``
     * ``None``: do not enable mtp (default)
 
 .. option:: --mtp_draft_model_dir

docs/EN/source/tutorial/api_server_args_zh.rst

Lines changed: 5 additions & 3 deletions
@@ -444,10 +444,12 @@ MTP Multi-Prediction Parameters
 
 .. option:: --mtp_mode
 
-    Supported mtp modes, it is recommended to use deepseekv3_eagle for better performance, optional values:
+    Supported mtp modes, it is recommended to use eagle_with_att for better performance, optional values:
 
-    * ``deepseekv3_vanilla``
-    * ``deepseekv3_eagle``
+    * ``vanilla_with_att``
+    * ``eagle_with_att``
+    * ``vanilla_no_att``
+    * ``eagle_no_att``
     * ``None``: Do not enable mtp (default)
 
 .. option:: --mtp_draft_model_dir
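To try the renamed modes, a hedged launch sketch: the api_server entry point and both directory paths are assumptions on my part; only --mtp_mode, --mtp_draft_model_dir, and the mode names come from this commit's docs.

python -m lightllm.server.api_server \
    --model_dir /path/to/main_model \
    --mtp_mode eagle_with_att \
    --mtp_draft_model_dir /path/to/mtp_draft_model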

lightllm/common/basemodel/basemodel.py

Lines changed: 68 additions & 68 deletions
Large diffs are not rendered by default.

lightllm/common/basemodel/batch_objs.py

Lines changed: 7 additions & 7 deletions
@@ -46,9 +46,9 @@ class ModelInput:
     # Dedicated variables: used by some special models in special modes to
     # pass some special input values. Only used and effective in those model modes.
 
-    # deepseekv3_mtp_draft_input_hiddens is used in the deepseekv3 model's mtp mode
+    # mtp_draft_input_hiddens is used in mtp mode
     # as the input of the draft model
-    deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None
+    mtp_draft_input_hiddens: Optional[torch.Tensor] = None
 
     def to_cuda(self):
         if self.input_ids is not None:
@@ -90,12 +90,12 @@ class ModelOutput:
     # Dedicated variables: used by some special models in special modes to
     # pass some special output values. Only used and effective in those model modes.
 
-    # deepseekv3_mtp_main_output_hiddens: in mtp mode, the llm main model
-    # outputs its last layer's hidden state for the draft model's deepseekv3_mtp_draft_input_hiddens
+    # mtp_main_output_hiddens: in mtp mode, the llm main model
+    # outputs its last layer's hidden state for the draft model's mtp_draft_input_hiddens
     # input
-    deepseekv3_mtp_main_output_hiddens: Optional[torch.Tensor] = None
+    mtp_main_output_hiddens: Optional[torch.Tensor] = None
 
     def to_no_ref_tensor(self):
         self.logits = tensor_to_no_ref_tensor(self.logits)
-        if self.deepseekv3_mtp_main_output_hiddens is not None:
-            self.deepseekv3_mtp_main_output_hiddens = tensor_to_no_ref_tensor(self.deepseekv3_mtp_main_output_hiddens)
+        if self.mtp_main_output_hiddens is not None:
+            self.mtp_main_output_hiddens = tensor_to_no_ref_tensor(self.mtp_main_output_hiddens)
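Since the rename only drops the deepseekv3_ prefix, a self-contained sketch of the intended hand-off may help; the field names mirror this diff, while the dataclass bodies and the two stub models are invented purely for illustration.

import torch
from dataclasses import dataclass
from typing import Optional

# Toy stand-ins for lightllm's ModelInput/ModelOutput, showing only how the
# renamed hidden-state fields pass from the main model to the draft model.
@dataclass
class ModelInput:
    input_ids: torch.Tensor = None
    mtp_draft_input_hiddens: Optional[torch.Tensor] = None

@dataclass
class ModelOutput:
    logits: torch.Tensor = None
    mtp_main_output_hiddens: Optional[torch.Tensor] = None

def main_model_forward(inp: ModelInput) -> ModelOutput:
    n = inp.input_ids.shape[0]
    # in mtp mode the main model also surfaces its last-layer hidden states
    return ModelOutput(logits=torch.randn(n, 32), mtp_main_output_hiddens=torch.randn(n, 16))

def draft_model_forward(inp: ModelInput) -> ModelOutput:
    assert inp.mtp_draft_input_hiddens is not None  # required in mtp mode
    return ModelOutput(logits=torch.randn(inp.input_ids.shape[0], 32))

main_out = main_model_forward(ModelInput(input_ids=torch.tensor([1, 2, 3])))
draft_out = draft_model_forward(
    ModelInput(input_ids=torch.tensor([7, 8, 9]),
               mtp_draft_input_hiddens=main_out.mtp_main_output_hiddens)
)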

lightllm/common/basemodel/cuda_graph.py

Lines changed: 20 additions & 31 deletions
@@ -62,9 +62,10 @@ def find_closest_graph_batch_size(self, batch_size):
         else:
             return None
 
-    def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: InferStateInfo):
+    def _capture_decode(self, decode_func, infer_state: InferStateInfo):
         dist_group: CustomProcessGroup = infer_state.dist_group
         graph_obj = torch.cuda.CUDAGraph()
+        input_ids = infer_state.input_ids
         batch_size = input_ids.shape[0]
         infer_state.max_len_in_batch = self.graph_max_len_in_batch
         infer_state.total_token_num = self.graph_max_len_in_batch * batch_size
@@ -78,27 +79,26 @@ def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: Inf
         # tensors in it.
         for _ in range(1):
             torch.cuda.synchronize()
-            decode_func(input_ids, copy.copy(infer_state))
+            decode_func(copy.copy(infer_state))
             torch.cuda.synchronize()
 
         with lightllm_capture_graph(dist_group):
             with torch.cuda.graph(graph_obj, pool=self.mempool):
-                model_output = decode_func(input_ids, infer_state)
-            self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output)
+                model_output = decode_func(infer_state)
+            self.graph[batch_size] = (graph_obj, infer_state, model_output)
         graph_obj.replay()
         return model_output
 
     def _capture_decode_overlap(
         self,
         decode_func,
-        input_ids: torch.Tensor,
         infer_state: InferStateInfo,
-        input_ids1: torch.Tensor,
         infer_state1: InferStateInfo,
     ):
         dist_group: CustomProcessGroup = infer_state.dist_group
         dist_group1 = infer_state1.dist_group
         graph_obj = torch.cuda.CUDAGraph()
+        input_ids = infer_state.input_ids
         batch_size = input_ids.shape[0]
         infer_state.max_len_in_batch = self.graph_max_len_in_batch
         infer_state.total_token_num = self.graph_max_len_in_batch * batch_size
@@ -107,17 +107,15 @@ def _capture_decode_overlap(
         # warmup
         for _ in range(1):
             torch.cuda.synchronize()
-            decode_func(input_ids, copy.copy(infer_state), input_ids1, copy.copy(infer_state1))
+            decode_func(copy.copy(infer_state), copy.copy(infer_state1))
             torch.cuda.synchronize()
         with lightllm_capture_graph(dist_group1):
             with lightllm_capture_graph(dist_group):
                 with torch.cuda.graph(graph_obj, pool=self.mempool):
-                    model_output, model_output1 = decode_func(input_ids, infer_state, input_ids1, infer_state1)
+                    model_output, model_output1 = decode_func(infer_state, infer_state1)
                 self.graph[batch_size] = (
                     graph_obj,
-                    input_ids,
                     infer_state,
-                    input_ids1,
                     infer_state1,
                     model_output,
                     model_output1,
@@ -128,59 +126,50 @@ def _capture_decode_overlap(
     def capture_decode(
         self,
         decode_func,
-        input_ids: torch.Tensor,
         infer_state: InferStateInfo,
-        input_ids1: Optional[torch.Tensor] = None,
-        infer_state1: Optional[torch.Tensor] = None,
+        infer_state1: Optional[InferStateInfo] = None,
     ):
         """
         Capture the cuda graph for the decoding stage.
         input_ids1 and infer_state1 is used for the overlap.
         """
         if self.enable_decode_microbatch_overlap:
-            return self._capture_decode_overlap(decode_func, input_ids, infer_state, input_ids1, infer_state1)
+            return self._capture_decode_overlap(decode_func, infer_state, infer_state1)
         else:
-            assert input_ids1 is None and infer_state1 is None
-            return self._capture_decode(decode_func, input_ids, infer_state)
+            assert infer_state1 is None
+            return self._capture_decode(decode_func, infer_state)
 
-    def _replay(self, input_ids: torch.Tensor, infer_state: InferStateInfo):
-        batch_size = input_ids.shape[0]
-        graph_obj, graph_input_ids, graph_infer_state, graph_output = self.graph[batch_size]
-        graph_input_ids.copy_(input_ids)
+    def _replay(self, infer_state: InferStateInfo):
+        batch_size = infer_state.input_ids.shape[0]
+        graph_obj, graph_infer_state, graph_output = self.graph[batch_size]
         graph_infer_state.copy_for_cuda_graph(infer_state)
         graph_obj.replay()
         return graph_output
 
     def _replay_overlap(
         self,
-        input_ids: torch.Tensor,
         infer_state: InferStateInfo,
-        input_ids1: torch.Tensor,
         infer_state1: InferStateInfo,
     ):
-        batch_size = input_ids.shape[0]
+        batch_size = infer_state.input_ids.shape[0]
         (
             graph_obj,
-            graph_input_ids,
             graph_infer_state,
-            graph_input_ids1,
             graph_infer_state1,
             graph_model_output,
            graph_model_output1,
        ) = self.graph[batch_size]
-        graph_input_ids.copy_(input_ids)
         graph_infer_state.copy_for_cuda_graph(infer_state)
-        graph_input_ids1.copy_(input_ids1)
         graph_infer_state1.copy_for_cuda_graph(infer_state1)
         graph_obj.replay()
         return graph_model_output, graph_model_output1
 
-    def replay(self, input_ids, infer_state, input_ids1=None, infer_state1=None):
+    def replay(self, infer_state, infer_state1=None):
         if self.enable_decode_microbatch_overlap:
-            return self._replay_overlap(input_ids, infer_state, input_ids1, infer_state1)
+            return self._replay_overlap(infer_state, infer_state1)
         else:
-            assert input_ids1 is None and infer_state1 is None
-            return self._replay(input_ids, infer_state)
+            assert infer_state1 is None
+            return self._replay(infer_state)
 
     @torch.no_grad()
     def warmup(self, model):
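A minimal sketch of the new calling convention: input_ids now rides on the infer state, so capture and replay take one state object per microbatch. The classes below are toys invented for illustration, not lightllm's CudaGraph, and decode_func stands in for the real captured computation.

import torch

class InferState:
    def __init__(self, input_ids: torch.Tensor):
        self.input_ids = input_ids

    def copy_for_cuda_graph(self, other: "InferState"):
        self.input_ids.copy_(other.input_ids)  # refresh captured buffers in place

def decode_func(infer_state: InferState) -> torch.Tensor:
    return infer_state.input_ids * 2  # stand-in for the real decode step

graph_cache = {}

def capture(state: InferState):
    # before this commit: capture(decode_func, input_ids, infer_state)
    graph_cache[state.input_ids.shape[0]] = (state, decode_func(state))

def replay(state: InferState) -> torch.Tensor:
    cached_state, _ = graph_cache[state.input_ids.shape[0]]  # keyed by batch size
    cached_state.copy_for_cuda_graph(state)
    return decode_func(cached_state)  # the real code replays the CUDA graph here

capture(InferState(torch.tensor([1, 2, 3])))
print(replay(InferState(torch.tensor([4, 5, 6]))))  # tensor([ 8, 10, 12])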

lightllm/common/basemodel/infer_struct.py

Lines changed: 10 additions & 5 deletions
@@ -19,6 +19,7 @@ class InferStateInfo:
     """
 
     def __init__(self):
+        self.input_ids: torch.Tensor = None
         self.batch_size: int = None
         self.total_token_num: int = None
         self.b_req_idx: torch.Tensor = None
@@ -71,10 +72,10 @@ def __init__(self):
         # the inferstate base class, but for code simplicity and convenience they
         # are all managed in the base class. Note these members only take effect for specific models and modes.
 
-        # extra input parameters used by the deepseekv3 mtp draft model;
-        # when mtp_mode == deepseekv3 is enabled, the mtp draft model
+        # extra input parameters used by the mtp draft model;
+        # when mtp_mode is enabled, the mtp draft model
         # consumes them as input; no other model or scenario uses them
-        self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None
+        self.mtp_draft_input_hiddens: Optional[torch.Tensor] = None
 
         # In single-node multi-dp mode, if prefill data becomes unbalanced across
         # dp ranks, the inference data can be redistributed to the dp ranks and, before att, all-to-all'ed back to each
@@ -88,7 +89,8 @@ def __init__(self):
         self.dp_output_split_sizes: List[List[int]] = None
         self.dp_input_split_sizes: List[List[int]] = None
 
-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+    def init_some_extra_state(self, model):
+
         if self.is_prefill:
             (
                 self.b_q_seq_len,
@@ -97,7 +99,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 self.b1_cu_kv_seq_len,
                 self.position_ids,
             ) = gen_prefill_params(
-                input_token_num=input_ids.shape[0],
+                input_token_num=self.input_ids.shape[0],
                 b_ready_cache_len=self.b_ready_cache_len,
                 b_seq_len=self.b_seq_len,
             )
@@ -211,6 +213,9 @@ def prefill_dp_balance(self, input_ids: torch.Tensor):
 
         self.position_sin = self._all_to_all_balance_get(self.position_sin)
 
+        self._unbalance_input_ids = self.input_ids
+        self.input_ids = new_input_ids
+
         return new_input_ids
 
     def _all_to_all_balance_get(self, data: torch.Tensor):
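The three added lines in prefill_dp_balance just stash the pre-balance ids; a toy illustration of that bookkeeping (the State class is invented and the all-to-all itself is omitted):

import torch

# After the commit, prefill_dp_balance keeps the original ids on
# _unbalance_input_ids and points input_ids at the rebalanced tensor, so
# later code (e.g. init_some_extra_state) reads the rebalanced ids by default.
class State:
    def __init__(self, input_ids: torch.Tensor):
        self.input_ids = input_ids

    def prefill_dp_balance(self, new_input_ids: torch.Tensor) -> torch.Tensor:
        self._unbalance_input_ids = self.input_ids
        self.input_ids = new_input_ids
        return new_input_ids

s = State(torch.tensor([1, 2, 3, 4]))
s.prefill_dp_balance(torch.tensor([3, 4, 1, 2]))  # pretend all-to-all result
assert s.input_ids.tolist() == [3, 4, 1, 2]
assert s._unbalance_input_ids.tolist() == [1, 2, 3, 4]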

lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py

Lines changed: 0 additions & 2 deletions
@@ -8,8 +8,6 @@ class PreLayerInferTpl(PreLayerInfer):
     def __init__(self, network_config, mode):
         super().__init__(network_config, mode)
         self.eps_ = 1e-5
-        self.vob_start_id_ = -1
-        self.vob_end_id_ = -1
         return
 
     def _norm(self, input, infer_state, layer_weight) -> torch.Tensor:

lightllm/common/basemodel/layer_weights/meta_weights/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -6,6 +6,8 @@
     COLMMWeight,
     ROWBMMWeight,
 )
-from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
+from .norm_weight import NoTpGEMMANormWeight, TpVitPadNormWeight, NoTpNormWeight, TpHeadNormWeight
 from .fused_moe_weight_tp import create_tp_moe_wegiht_obj
 from .fused_moe_weight_ep import FusedMoeWeightEP
+from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight
+from .att_sink_weight import TpAttSinkWeight
lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py (new file; path inferred from the import added above)

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+import torch
+from typing import Dict
+from .base_weight import BaseWeightTpl
+from lightllm.utils.dist_utils import get_current_device_id
+
+
+class TpAttSinkWeight(BaseWeightTpl):
+    def __init__(self, weight_name: str, data_type):
+        super().__init__()
+        self.weight_name = weight_name
+        self.data_type_ = data_type
+        self.weight: torch.Tensor = None
+
+    def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
+        if self.weight_name not in weights or self.weight is not None:
+            return
+
+        t_weight = weights[self.weight_name]
+        start_head_index, end_head_index = self._get_head_tp_split_params(weight=t_weight)
+        self.weight = t_weight[start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id())
+
+    def verify_load(self):
+        return self.weight is not None
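A hedged usage sketch of the new weight class: TpAttSinkWeight gets tp_rank_/tp_world_size_ from BaseWeightTpl's distributed context and moves the slice to CUDA, so this CPU toy re-creates just the slicing, and the checkpoint key name is hypothetical.

import torch

# CPU toy mirroring TpAttSinkWeight.load_hf_weights: each rank keeps only its
# own heads' rows of the sink tensor; "model.layers.0.attn.sinks" is a made-up key.
class ToySinkWeight:
    def __init__(self, weight_name: str, data_type, tp_rank: int, tp_world_size: int):
        self.weight_name = weight_name
        self.data_type_ = data_type
        self.tp_rank_, self.tp_world_size_ = tp_rank, tp_world_size
        self.weight = None

    def load_hf_weights(self, weights):
        if self.weight_name not in weights or self.weight is not None:
            return
        t_weight = weights[self.weight_name]
        per_rank = t_weight.shape[0] // self.tp_world_size_
        start, end = self.tp_rank_ * per_rank, (self.tp_rank_ + 1) * per_rank
        self.weight = t_weight[start:end].to(self.data_type_)  # .cuda(...) in the real class

w = ToySinkWeight("model.layers.0.attn.sinks", torch.bfloat16, tp_rank=1, tp_world_size=4)
w.load_hf_weights({"model.layers.0.attn.sinks": torch.arange(16.0).reshape(8, 2)})
assert w.weight.shape == (2, 2)  # heads 2..3 of 8 land on rank 1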

lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py

Lines changed: 28 additions & 1 deletion
@@ -1,6 +1,6 @@
 import torch
 from abc import ABC, abstractmethod
-from typing import Dict
+from typing import Dict, Tuple
 from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp, get_current_device_id
 
 
@@ -29,3 +29,30 @@ def load_hf_weights(self, weights):
 
     def verify_load(self) -> bool:
         raise NotImplementedError("verify_load must implement this method")
+
+    def _get_head_tp_split_params(self, weight: torch.Tensor) -> Tuple[int, int]:
+        """
+        Docstring for _get_head_tp_split_params:
+        a common tp helper that splits heads across ranks and returns this
+        rank's head_index range; some subclasses may use it.
+        :param weight: the weight tensor whose first dim is the head count
+        :type weight: torch.Tensor
+        :return: (start_head_index, end_head_index)
+        :rtype: Tuple[int, int]
+        """
+        assert weight.ndim == 2
+
+        all_head_num = weight.shape[0]
+        tp_head_num = all_head_num // self.tp_world_size_
+
+        if tp_head_num > 0:
+            start_head_index = self.tp_rank_ * tp_head_num
+            end_head_index = (self.tp_rank_ + 1) * tp_head_num
+        else:
+            # special handling when tp_world_size is larger than all_head_num
+            scale_size = self.tp_world_size_ // all_head_num
+            assert self.tp_world_size_ % all_head_num == 0
+            start_head_index = self.tp_rank_ // scale_size
+            end_head_index = start_head_index + 1
+
+        return start_head_index, end_head_index
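To make the two branches concrete, a standalone re-derivation of the split arithmetic; the function mirrors the diff's logic outside the class, and the sample numbers are illustrative.

# Same arithmetic as _get_head_tp_split_params, lifted out of the class.
def head_tp_split(all_head_num: int, tp_rank: int, tp_world_size: int):
    tp_head_num = all_head_num // tp_world_size
    if tp_head_num > 0:
        # enough heads: each rank owns a contiguous block of tp_head_num heads
        return tp_rank * tp_head_num, (tp_rank + 1) * tp_head_num
    # more ranks than heads: groups of scale_size ranks share (replicate) one head
    assert tp_world_size % all_head_num == 0
    scale_size = tp_world_size // all_head_num
    start = tp_rank // scale_size
    return start, start + 1

assert head_tp_split(8, tp_rank=1, tp_world_size=4) == (2, 4)  # 2 heads per rank
assert head_tp_split(2, tp_rank=5, tp_world_size=8) == (1, 2)  # ranks 4-7 share head 1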
