
Commit 965cdae

overlap mtp

1 parent: ac47e1f

File tree

6 files changed: +319 -36 lines

lightllm/common/basemodel/batch_objs.py

Lines changed: 6 additions & 2 deletions
@@ -10,14 +10,17 @@ class ModelInput:
     total_token_num: int
     max_len_in_batch: int
     input_ids: torch.Tensor
-    mem_indexes: torch.Tensor
     b_req_idx: torch.Tensor
     b_mtp_index: torch.Tensor
     b_seq_len: torch.Tensor
+    mem_indexes: torch.Tensor = None
     is_prefill: bool = False
     b_ready_cache_len: torch.Tensor = None
     multimodal_params: list = field(default_factory=list)
 
+    # cpu-side variables
+    mem_indexes_cpu: torch.Tensor = None
+
     # Dedicated variables, used by some special models in special modes to pass
     # extra inputs. They are only used and take effect in those model modes.
 
@@ -28,7 +31,8 @@ class ModelInput:
     def to_cuda(self):
         if self.input_ids is not None:
             self.input_ids = self.input_ids.cuda(non_blocking=True)
-        self.mem_indexes = self.mem_indexes.cuda(non_blocking=True)
+        if self.mem_indexes is None:
+            self.mem_indexes = self.mem_indexes_cpu.cuda(non_blocking=True)
         self.b_req_idx = self.b_req_idx.cuda(non_blocking=True)
         self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)
         self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True)
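
Note on the change above: mem_indexes can now stay on the host (as mem_indexes_cpu) until to_cuda() runs, presumably so the host-to-device copy is issued asynchronously and can overlap with other work (the commit title is "overlap mtp"). Below is a minimal, self-contained sketch of that pattern; it is illustrative only, and the _TinyInput class plus the pinned-memory staging are assumptions, not code from this commit.

# Illustrative sketch only -- _TinyInput is a made-up stand-in for ModelInput.
import torch
from dataclasses import dataclass


@dataclass
class _TinyInput:
    input_ids: torch.Tensor
    mem_indexes: torch.Tensor = None      # GPU tensor, filled lazily
    mem_indexes_cpu: torch.Tensor = None  # staged on the host (ideally pinned)

    def to_cuda(self):
        self.input_ids = self.input_ids.cuda(non_blocking=True)
        if self.mem_indexes is None:
            # with pinned host memory, non_blocking=True lets this H2D copy
            # overlap with host-side work instead of stalling the CPU
            self.mem_indexes = self.mem_indexes_cpu.cuda(non_blocking=True)


inp = _TinyInput(
    input_ids=torch.arange(8),
    mem_indexes_cpu=torch.arange(8).pin_memory(),
)
inp.to_cuda()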
lightllm/common/basemodel/triton_kernel/mtp_verify.py (new file)

Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@
+import triton
+import triton.language as tl
+import torch
+
+
+@triton.jit
+def _fwd_kernel_mtp_verify(
+    req_to_next_token_ids,
+    req_to_next_token_ids_stride,
+    new_next_token_ids,
+    mtp_accept_len,
+    b_req_mtp_start_loc,
+    b_req_idx,
+    b_mtp_index,
+    accepted_index,
+    batch_size: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    cur_index = tl.program_id(0)
+    req_start_loc = tl.load(b_req_mtp_start_loc + cur_index)
+    cur_req_idx = tl.load(b_req_idx + req_start_loc)
+    offset = tl.arange(0, BLOCK_SIZE)
+    req_offset = req_start_loc + offset
+    cur_mtp_index = tl.load(b_mtp_index + req_offset, mask=req_offset < batch_size)
+
+    mask = cur_mtp_index == offset
+
+    cur_next_token_id = tl.load(
+        req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + offset + 1, mask=mask, other=-1
+    )
+    cur_new_next_token_id = tl.load(new_next_token_ids + req_offset, mask=mask, other=-2)
+
+    match_mask = cur_next_token_id == cur_new_next_token_id
+
+    first_false = tl.where(~match_mask, offset, BLOCK_SIZE - 1)
+    accept_len = tl.min(first_false)
+    tl.store(mtp_accept_len + cur_index, accept_len)
+    accept_flags = tl.where((offset < accept_len + 1), 1, 0)
+    tl.store(accepted_index + req_offset, accept_flags, mask=mask)
+    return
+
+
+def mtp_verify(
+    req_to_next_token_ids: torch.Tensor,
+    b_req_mtp_start_loc: torch.Tensor,
+    new_next_token_ids: torch.Tensor,
+    b_req_idx: torch.Tensor,
+    b_mtp_index: torch.Tensor,
+):
+    """
+    Verify the draft tokens and compute accept_len for each request.
+    Args:
+        req_to_next_token_ids: (max_req_num, max_mtp_step)
+        b_req_mtp_start_loc: (num_reqs,)
+        new_next_token_ids: (batch_size,)
+        b_req_idx: (batch_size,)
+        b_mtp_index: (batch_size,)
+    Returns:
+        mtp_accept_len: (num_reqs,)
+        accepted_index: (batch_size,)
+        accepted_index, e.g. [1, 0, 1, 1, 0]: 1 means the token is accepted, 0 means it is not.
+    """
+    max_mtp_step = req_to_next_token_ids.shape[1]
+    BLOCK_SIZE = 16
+    assert max_mtp_step <= BLOCK_SIZE, f"max_mtp_step must not exceed {BLOCK_SIZE}"
+    num_reqs = b_req_mtp_start_loc.shape[0]
+    batch_size = b_req_idx.shape[0]
+    mtp_accept_len = torch.empty((num_reqs,), dtype=torch.int32, device=req_to_next_token_ids.device)
+    accepted_index = torch.empty((batch_size,), dtype=torch.int32, device=req_to_next_token_ids.device)
+
+    grid = (num_reqs,)
+    num_warps = 1
+    _fwd_kernel_mtp_verify[grid](
+        req_to_next_token_ids,
+        req_to_next_token_ids.stride(0),
+        new_next_token_ids,
+        mtp_accept_len,
+        b_req_mtp_start_loc,
+        b_req_idx,
+        b_mtp_index,
+        accepted_index,
+        batch_size,
+        BLOCK_SIZE,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return mtp_accept_len, accepted_index
+
+
+@triton.jit
+def _fwd_kernel_mtp_scatter_next_token_ids(
+    req_to_next_token_ids,
+    req_to_next_token_ids_stride,
+    all_next_token_ids,
+    all_next_token_ids_stride,
+    mtp_accept_len,
+    b_req_mtp_start_loc,
+    b_req_idx,
+    b_mtp_index,
+    mtp_step: tl.constexpr,
+    batch_size: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+
+    cur_index = tl.program_id(0)
+    req_start_loc = tl.load(b_req_mtp_start_loc + cur_index)
+    accept_len = tl.load(mtp_accept_len + cur_index)
+    cur_req_idx = tl.load(b_req_idx + req_start_loc)
+    offset = tl.arange(0, BLOCK_SIZE)
+    req_offset = req_start_loc + offset
+    cur_mtp_index = tl.load(b_mtp_index + req_offset, mask=req_offset < batch_size)
+
+    mask = cur_mtp_index == offset
+    scatter_next_token_ids = tl.load(
+        all_next_token_ids + (req_start_loc + accept_len) * all_next_token_ids_stride + offset,
+        mask=offset < mtp_step,
+        other=0,
+    )
+    scatter_next_token_ids = tl.where(mask, scatter_next_token_ids, -1)
+    tl.store(
+        req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + offset,
+        scatter_next_token_ids,
+        mask=offset < mtp_step,
+    )
+    return
+
+
+def mtp_scatter_next_token_ids(
+    req_to_next_token_ids: torch.Tensor,
+    b_req_mtp_start_loc: torch.Tensor,
+    all_next_token_ids: torch.Tensor,
+    b_req_idx: torch.Tensor,
+    b_mtp_index: torch.Tensor,
+    mtp_accept_len: torch.Tensor,
+):
+    max_mtp_step = req_to_next_token_ids.shape[1]
+    BLOCK_SIZE = 16
+    assert max_mtp_step <= BLOCK_SIZE, f"max_mtp_step must not exceed {BLOCK_SIZE}"
+    num_reqs = b_req_mtp_start_loc.shape[0]
+    batch_size = b_req_idx.shape[0]
+    mtp_step = all_next_token_ids.shape[1]
+    grid = (num_reqs,)
+    num_warps = 1
+    _fwd_kernel_mtp_scatter_next_token_ids[grid](
+        req_to_next_token_ids,
+        req_to_next_token_ids.stride(0),
+        all_next_token_ids,
+        all_next_token_ids.stride(0),
+        mtp_accept_len,
+        b_req_mtp_start_loc,
+        b_req_idx,
+        b_mtp_index,
+        mtp_step,
+        batch_size,
+        BLOCK_SIZE,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+
+@triton.jit
+def _fwd_kernel_gen_b_req_mtp_start_loc(
+    b_mtp_index,
+    b_req_mtp_start_loc,
+    num_reqs: tl.constexpr,
+    batch_size: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    offset = tl.arange(0, BLOCK_SIZE)
+    cur_mtp_index = tl.load(b_mtp_index + offset, mask=offset < batch_size, other=-1)
+    non_zero_mask = tl.where(cur_mtp_index == 0, 1, 0)  # e.g. 1 0 1 0 0
+    output_offset = tl.cumsum(non_zero_mask) - 1
+    tl.store(b_req_mtp_start_loc + output_offset, offset, mask=non_zero_mask == 1)
+    return
+
+
+def gen_b_req_mtp_start_loc(b_mtp_index: torch.Tensor, num_reqs: int):
+    b_req_mtp_start_loc = torch.empty((num_reqs,), dtype=torch.int32, device=b_mtp_index.device)
+    BLOCK_SIZE = triton.next_power_of_2(b_mtp_index.shape[0])
+    batch_size = b_mtp_index.shape[0]
+    grid = (1,)
+    _fwd_kernel_gen_b_req_mtp_start_loc[grid](
+        b_mtp_index=b_mtp_index,
+        b_req_mtp_start_loc=b_req_mtp_start_loc,
+        num_reqs=num_reqs,
+        batch_size=batch_size,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=8,
+    )
+    return b_req_mtp_start_loc
+
+
+def test_mtp_verify():
+    req_to_next_token_ids = torch.tensor(
+        [[1, 2, -2, -1, -1], [1, 2, 0, -1, -1], [1, 3, 4, 4, 5]], dtype=torch.int32, device="cuda"
+    )
+    b_req_idx = torch.tensor([0, 0, 2, 2, 2], dtype=torch.int32, device="cuda")
+    b_mtp_index = torch.tensor([0, 1, 0, 1, 2], dtype=torch.int32, device="cuda")
+    b_req_mtp_start_loc = torch.tensor([0, 2], dtype=torch.int32, device="cuda")
+    new_next_token_ids = torch.tensor([1, 4, 3, 4, 13], dtype=torch.int32, device="cuda")
+    all_next_token_ids = torch.tensor(
+        [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]], dtype=torch.int32, device="cuda"
+    )
+    mtp_accept_len, accepted_index = mtp_verify(
+        req_to_next_token_ids, b_req_mtp_start_loc, new_next_token_ids, b_req_idx, b_mtp_index
+    )
+    mtp_scatter_next_token_ids(
+        req_to_next_token_ids, b_req_mtp_start_loc, all_next_token_ids, b_req_idx, b_mtp_index, mtp_accept_len
+    )
+    print(mtp_accept_len)
+    print(req_to_next_token_ids)
+    print(accepted_index)
+
+
+def test_gen_b_req_mtp_start_loc():
+    b_mtp_index = torch.tensor([0, 1, 0, 1, 2], dtype=torch.int32, device="cuda")
+    gt_output = torch.where(b_mtp_index == 0)[0]
+    b_req_mtp_start_loc = gen_b_req_mtp_start_loc(b_mtp_index, 2)
+    print(b_req_mtp_start_loc, gt_output)
+
+
+if __name__ == "__main__":
+    # test_mtp_verify()
+    test_gen_b_req_mtp_start_loc()
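
For readers who would rather not trace the Triton masking, here is a rough pure-PyTorch reference of the semantics the verify kernel appears to implement. It is an illustrative sketch, not part of the commit, and mtp_verify_reference is a made-up name. As I read the kernel, on the test_mtp_verify data above it yields mtp_accept_len = [0, 2] and accepted_index = [1, 0, 1, 1, 1]: request 0's first draft token (2) already differs from the freshly sampled token (1), while request 2's drafts [3, 4] match before the mismatch at the third position.

# Illustrative reference only (not in the commit); CPU torch, no Triton required.
import torch


def mtp_verify_reference(req_to_next_token_ids, b_req_mtp_start_loc, new_next_token_ids, b_req_idx, b_mtp_index):
    num_reqs = b_req_mtp_start_loc.shape[0]
    batch_size = b_req_idx.shape[0]
    mtp_accept_len = torch.zeros(num_reqs, dtype=torch.int32)
    accepted_index = torch.zeros(batch_size, dtype=torch.int32)
    for i in range(num_reqs):
        start = int(b_req_mtp_start_loc[i])
        end = int(b_req_mtp_start_loc[i + 1]) if i + 1 < num_reqs else batch_size
        req = int(b_req_idx[start])
        accept_len = 0
        for j in range(end - start):
            draft = int(req_to_next_token_ids[req, j + 1])  # draft token j of this request
            sampled = int(new_next_token_ids[start + j])    # main model's token at the same position
            if draft != sampled:
                break
            accept_len += 1
        mtp_accept_len[i] = accept_len
        # the matched drafts plus the first non-matching (or bonus) sample count as accepted
        accepted_index[start : min(start + accept_len + 1, end)] = 1
    return mtp_accept_len, accepted_index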

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 4 additions & 6 deletions
@@ -298,11 +298,9 @@ def __init__(
         self.need_out_token_id_statistics = True
         self.out_token_id_count: Dict[int, int] = None
 
-        # mtp_gen_token_ids supports pre-generating extra tokens for a request through
-        # mtp. When the mtp feature is not enabled, this member has no practical use.
-        # When it is enabled, mtp_gen_token_ids holds the extra generated token ids,
-        # which still need to be re-verified in later steps.
-        self.mtp_gen_token_ids: List[int] = []
+        # mtp_step records how many tokens the draft model generates per step for a request.
+        # In normal mode it is 0; in mtp mode it equals the draft model's tokens per step.
+        self.mtp_step: int = get_env_start_args().mtp_step
 
         self._init_all_state()
         if init_prefix_cache:
@@ -417,7 +415,7 @@ def prefill_need_token_num(self, is_chuncked_prefill: bool):
         return input_token_len
 
     def decode_need_token_num(self):
-        return 1 + len(self.mtp_gen_token_ids)
+        return 1 + self.mtp_step
 
 
 class InferReqGroup:
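
As a quick sanity check on the token budgeting above (illustrative arithmetic, not code from the commit): with mtp_step set to 2, each decode step has to reserve kv-cache slots for the main-model token plus the two draft tokens.

# Illustrative only: kv-cache slots needed per request per decode step.
mtp_step = 2                            # draft tokens per step (0 when mtp is disabled)
decode_need_token_num = 1 + mtp_step    # mirrors InferReq.decode_need_token_num()
assert decode_need_token_num == 3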

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 35 additions & 2 deletions
@@ -14,7 +14,8 @@
 from lightllm.server.router.token_load import TokenLoad
 from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock
 from lightllm.common.basemodel.basemodel import TpPartBaseModel
-from lightllm.common.basemodel.batch_objs import ModelOutput
+from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput
+from lightllm.common.basemodel.triton_kernel.mtp_verify import mtp_verify
 from lightllm.utils.dist_utils import init_distributed_env
 from lightllm.utils.envs_utils import get_unique_server_name
 from lightllm.server.core.objs import ShmReqManager, StartArgs
@@ -253,7 +254,7 @@ def init_mtp_draft_model(self, main_kvargs: dict):
 
     def _save_next_token_ids_and_logprobs(self, next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor):
         """
-        This function saves the next token ids and logprobs into pinned memory and returns a sync event,
+        This function saves the next token ids and logprobs into pinned memory,
         so that the post_handle function is guaranteed to read the correct output results.
         """
         next_token_ids_cpu = g_pin_mem_manager.alloc_pin_tensor(
@@ -521,6 +522,38 @@ def _filter_reqs(self, reqs: List[InferReq]):
     def _trans_req_ids_to_req_objs(self, req_ids: List[int]) -> List[InferReq]:
         return [g_infer_context.requests_mapping[req_id] for req_id in req_ids]
 
+    def _verify_mtp_v2(
+        self, new_next_token_ids: torch.Tensor, model_input: ModelInput, b_req_mtp_start_loc: torch.Tensor
+    ):
+        b_mtp_index = model_input.b_mtp_index
+        mtp_accept_len, accepted_index = mtp_verify(
+            req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
+            b_req_mtp_start_loc=b_req_mtp_start_loc,
+            new_next_token_ids=new_next_token_ids,
+            b_req_idx=model_input.b_req_idx,
+            b_mtp_index=b_mtp_index,
+        )
+        return mtp_accept_len, accepted_index
+
+    def _get_need_free_mem_indexes(
+        self,
+        run_reqs: List[InferReq],
+        accepted_index_cpu: torch.Tensor,
+        mtp_accept_len_cpu: torch.Tensor,
+        mem_indexes_cpu: torch.Tensor,
+    ) -> List[int]:
+        need_free_mem_indexes = []
+        start_idx = 0
+        for i in range(mtp_accept_len_cpu.shape[0]):
+            req = run_reqs[start_idx]
+            accept_len = mtp_accept_len_cpu[i]
+            end_idx = start_idx + req.mtp_step + 1
+            need_free_mem_indexes.extend(mem_indexes_cpu[start_idx + accept_len + 1 : end_idx])
+            start_idx = end_idx
+            if self.is_master_in_dp:
+                req.update_mtp_accepted_token_num(accept_token_num=accept_len)
+        return need_free_mem_indexes
+
     # Verify and filter requests in mtp mode: keep the requests that pass verification, free unused kv mem_index entries.
     def _verify_mtp(self, run_reqs: List[InferReq], next_token_ids_cpu: np.ndarray, input_mem_indexes_cpu: np.ndarray):
         verify_ok_reqs = []
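
To make the bookkeeping in _get_need_free_mem_indexes concrete, here is a tiny standalone illustration (an assumption-laden sketch, not commit code): each request owns mtp_step + 1 contiguous entries of mem_indexes_cpu, the accepted prefix keeps accept_len + 1 of them, and the remainder is returned to the kv pool.

# Illustrative only: which mem indexes get freed for two requests with mtp_step = 2.
import torch

mtp_step = 2                               # draft tokens per step
mem_indexes_cpu = torch.arange(6)          # two requests * (mtp_step + 1) slots
mtp_accept_len_cpu = torch.tensor([0, 2])  # as in the mtp_verify test above

need_free = []
start = 0
for accept_len in mtp_accept_len_cpu.tolist():
    end = start + mtp_step + 1
    need_free.extend(mem_indexes_cpu[start + accept_len + 1 : end].tolist())
    start = end
print(need_free)  # [1, 2] -> request 0 keeps 1 slot, request 1 keeps all 3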
