Commit c8fa244

mtp
1 parent 2e559dc commit c8fa244

File tree: 19 files changed, +715 −19 lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 20 additions & 3 deletions

@@ -20,6 +20,8 @@
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed.communication_op import CustomProcessGroup, dist_group_manager
 from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch, PrefillMicroBatch
+from lightllm.common.spec_info import SpeculativeDecodeAlgorithm
+

 logger = init_logger(__name__)

@@ -71,6 +73,10 @@ def __init__(self, kvargs):
         self.tp_world_size_ = get_dp_world_size()
         self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode

+        # Speculative decoding
+        self.spec_algo = SpeculativeDecodeAlgorithm.from_string(kvargs.get("spec_algo", "NONE"))
+        self.spec_info = None
+
         self._init_datatype()
         self._init_config()
         self._verify_must()
@@ -279,6 +285,8 @@ def _prefill(
     ):
         infer_state = self.infer_state_class()
         infer_state.is_prefill = True
+        infer_state.spec_algo = self.spec_algo
+        infer_state.spec_info = self.spec_info
         infer_state.is_token_healing = self.is_token_healing
         infer_state.return_all_prompt_logics = self.return_all_prompt_logics
         infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
@@ -330,6 +338,8 @@ def _decode(
     ):
         infer_state = self.infer_state_class()
         infer_state.is_prefill = False
+        infer_state.spec_algo = self.spec_algo
+        infer_state.spec_info = self.spec_info
         infer_state.batch_size = batch_size
         infer_state.total_token_num = total_token_num
         infer_state.max_len_in_batch = max_len_in_batch
@@ -343,12 +353,13 @@ def _decode(
         infer_state.req_manager = self.req_manager

         infer_state.mem_index = mem_indexes
+        decode_len = self.spec_algo.decode_len()
         infer_state.kv_buffer_shapedtype = (
-            (batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
+            (batch_size * decode_len, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
             self.data_type,
         )
         infer_state.dist_group = dist_group_manager.get_default_group()
-        copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)
+        copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index, decode_len)

         infer_state.init_some_extra_state(self, input_ids)
         if self.graph is not None and self.graph.can_run(batch_size, max_len_in_batch):
@@ -498,6 +509,9 @@ def _context_forward(self, input_ids, infer_state: InferStateInfo):
             layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
             input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])

+        if self.spec_algo.is_mtp():
+            self.spec_info = input_embs
+
         post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
         predict_logics = post_method(input_embs, infer_state, self.pre_post_weight)

@@ -519,7 +533,10 @@ def _token_forward(self, input_ids, infer_state: InferStateInfo):
             layer = self.layers_infer[i]
             layer_method = (layer.token_forward, layer.tpsp_token_forward)[run_mode_index]
             input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
-
+
+        if self.spec_algo.is_mtp():
+            self.spec_info = input_embs
+
         post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
         predict_logics = post_method(input_embs, infer_state, self.pre_post_weight)
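
With MTP enabled, one decode step now carries decode_len tokens per request instead of one, which is why the KV buffer shape and the mem_index allocation above scale by decode_len. A minimal sketch of the shape arithmetic (batch size and head counts are illustrative, not taken from the commit):

# Illustrative numbers only; decode_len() == 2 for MTP comes from spec_info.py below.
batch_size = 4
tp_k_head_num, tp_v_head_num, head_dim = 16, 16, 128

decode_len = 2  # SpeculativeDecodeAlgorithm.MTP.decode_len()
kv_buffer_shape = (batch_size * decode_len, tp_k_head_num + tp_v_head_num, head_dim)
# -> (8, 32, 128): one KV slot per (request, speculative position) pair in this decode step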

lightllm/common/basemodel/cuda_graph.py

Lines changed: 6 additions & 1 deletion

@@ -5,6 +5,8 @@
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
 from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch
+from lightllm.common.spec_info import SpeculativeDecodeAlgorithm
+from lightllm.common.basemodel.basemodel import TpPartBaseModel

 logger = init_logger(__name__)

@@ -126,7 +128,7 @@ def replay(self, input_ids, infer_state, input_ids1=None, infer_state1=None):
             return self._replay(input_ids, infer_state)

     @torch.no_grad()
-    def warmup(self, model):
+    def warmup(self, model: TpPartBaseModel):
         logger.info("Begin capture cudagraph, use the --disable_cudagraph to disable it.")
         for batch_size in range(self.max_batch_size, 0, -1):
             # dummy prefill
@@ -160,6 +162,9 @@ def warmup(self, model):
             torch.cuda.empty_cache()

             # dummy decoding, capture the cudagraph
+            decode_len = model.spec_algo.decode_len()
+            predict_ids = predict_ids.repeat(decode_len)
+            b_start_loc = b_start_loc + torch.arange(0, batch_size * decode_len, decode_len, dtype=torch.int32, device="cuda")
             total_token_num += batch_size
             b_seq_len += 1
             mem_indexes = model.mem_manager.alloc(len(predict_ids)).cuda()
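
Because each dummy decode request now contributes decode_len positions during graph capture, the warmup inputs are widened: predict_ids is tiled and request i's start location is pushed forward by i * decode_len. A small sketch of that arithmetic with illustrative values (not from the commit):

import torch

batch_size, decode_len = 3, 2
predict_ids = torch.tensor([7, 8, 9])         # one dummy token per request
predict_ids = predict_ids.repeat(decode_len)  # tensor([7, 8, 9, 7, 8, 9])

b_start_loc = torch.tensor([0, 1, 2], dtype=torch.int32)
offsets = torch.arange(0, batch_size * decode_len, decode_len, dtype=torch.int32)  # tensor([0, 2, 4])
b_start_loc = b_start_loc + offsets  # tensor([0, 3, 6]): each request's block widens by decode_len slots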

lightllm/common/basemodel/infer_struct.py

Lines changed: 5 additions & 0 deletions

@@ -5,6 +5,7 @@
 from typing import Tuple, Any
 from .triton_kernel.gen_prefill_params import gen_prefill_params
 from .triton_kernel.gen_decode_params import gen_decode_params
+from lightllm.common.spec_info import SpeculativeDecodeAlgorithm


 class InferStateInfo:
@@ -53,6 +54,10 @@ def __init__(self):
         self.position_ids: torch.Tensor = None
         self.max_q_seq_len: int = None
         self.max_kv_seq_len: int = None
+
+        # Speculative decoding
+        self.spec_algo = SpeculativeDecodeAlgorithm.NONE
+        self.spec_info = None

     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         if self.is_prefill:

lightllm/common/basemodel/triton_kernel/copy_kv_index_to_req.py

Lines changed: 36 additions & 12 deletions

@@ -2,7 +2,7 @@

 import triton
 import triton.language as tl
-
+import copy

 @triton.jit
 def _fwd_kernel_copy_kv_index_to_req(
@@ -19,16 +19,40 @@ def _fwd_kernel_copy_kv_index_to_req(


 @torch.no_grad()
-def copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_seq_len, memindex):
-    seq_len = b_seq_len.shape[0]
-    assert b_seq_len.shape[0] == memindex.shape[0] and b_req_idx.shape[0] == b_seq_len.shape[0]
-    grid = (seq_len,)
+def copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_seq_len, memindex, decode_len=1):
+    batch_size = b_seq_len.shape[0]
+    assert b_seq_len.shape[0] * decode_len == memindex.shape[0] and b_req_idx.shape[0] == b_seq_len.shape[0]
+    grid = (batch_size,)
     num_warps = 1
-
-    _fwd_kernel_copy_kv_index_to_req[grid](
-        req_to_token_indexs, b_req_idx, b_seq_len, memindex,
-        req_to_token_indexs.stride(0), req_to_token_indexs.stride(1),
-        num_warps=num_warps,
-        num_stages=1,
-    )
+    _b_seq_len = copy.deepcopy(b_seq_len)
+    for i in range(decode_len):
+        _fwd_kernel_copy_kv_index_to_req[grid](
+            req_to_token_indexs, b_req_idx, _b_seq_len, memindex[batch_size * i : batch_size * (i + 1)],
+            req_to_token_indexs.stride(0), req_to_token_indexs.stride(1),
+            num_warps=num_warps,
+            num_stages=1,
+        )
+        _b_seq_len = _b_seq_len + 1
     return
+
+
+if __name__ == "__main__":
+    for decode_len in [1, 2]:
+        max_request_num = 100
+        max_sequence_length = 1000
+        req_to_token_indexs = torch.zeros(
+            (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda"
+        )
+        bs = 8
+        b_req_idx = torch.randint(low=0, high=max_request_num - 1, size=(bs,)).cuda()
+        b_seq_len = torch.randint(low=1, high=max_sequence_length, size=(bs,)).cuda()
+        memindex = torch.randint(low=0, high=10000, size=(bs * decode_len,)).cuda()
+        copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_seq_len, memindex, decode_len)
+
+        for i in range(bs):
+            for j in range(decode_len):
+                if req_to_token_indexs[b_req_idx[i]][b_seq_len[i] + j - 1] != memindex[j * bs + i]:
+                    print("ERROR")
+                    exit(1)
+
+    print("PASS")
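
For reference, the write pattern the updated wrapper produces can be expressed in plain PyTorch: memindex is laid out step-major (all requests' step-0 slots first, then step-1, and so on), and request r's slot for step j lands in column b_seq_len[r] - 1 + j of its row. The sketch below is a hypothetical equivalent for clarity, not the Triton kernel itself:

import torch

def copy_kv_index_to_req_ref(req_to_token_indexs, b_req_idx, b_seq_len, memindex, decode_len=1):
    """Pure-PyTorch sketch of the same write pattern (not the Triton kernel).

    memindex[step * batch_size + r] holds the KV-cache slot for request r at
    speculative step `step`; it is written into column (b_seq_len[r] - 1 + step)
    of that request's row in req_to_token_indexs.
    """
    batch_size = b_seq_len.shape[0]
    for step in range(decode_len):
        cols = b_seq_len - 1 + step
        req_to_token_indexs[b_req_idx, cols] = memindex[step * batch_size:(step + 1) * batch_size]

# Tiny CPU check mirroring the __main__ test above.
req_to_token = torch.zeros((4, 16), dtype=torch.int32)
b_req_idx = torch.tensor([0, 2])
b_seq_len = torch.tensor([3, 5])
memindex = torch.tensor([10, 11, 20, 21], dtype=torch.int32)  # step 0: [10, 11], step 1: [20, 21]
copy_kv_index_to_req_ref(req_to_token, b_req_idx, b_seq_len, memindex, decode_len=2)
assert req_to_token[0, 2] == 10 and req_to_token[0, 3] == 20
assert req_to_token[2, 4] == 11 and req_to_token[2, 5] == 21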

lightllm/common/spec_info.py

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+from enum import IntEnum, auto
+
+
+class SpeculativeDecodeAlgorithm(IntEnum):
+    NONE = auto()
+    MTP = auto()
+    MTP_MOUDLE = auto()
+
+    def is_none(self):
+        return self == SpeculativeDecodeAlgorithm.NONE
+
+    def is_mtp(self):
+        return self == SpeculativeDecodeAlgorithm.MTP
+
+    @staticmethod
+    def from_string(name: str):
+        name_map = {
+            "MTP": SpeculativeDecodeAlgorithm.MTP,
+            "MTP_MOUDLE": SpeculativeDecodeAlgorithm.MTP_MOUDLE,
+            "NONE": SpeculativeDecodeAlgorithm.NONE,
+        }
+        if name is not None:
+            name = name.upper()
+        return name_map[name]
+
+    def decode_len(self):
+        if self == SpeculativeDecodeAlgorithm.NONE:
+            return 1
+        if self == SpeculativeDecodeAlgorithm.MTP:
+            return 2
+        if self == SpeculativeDecodeAlgorithm.MTP_MOUDLE:
+            return 2
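
A quick usage sketch of the new enum, mirroring how basemodel.py resolves it from kvargs (the literal value passed here is illustrative):

from lightllm.common.spec_info import SpeculativeDecodeAlgorithm

algo = SpeculativeDecodeAlgorithm.from_string("mtp")  # case-insensitive via .upper()
assert algo is SpeculativeDecodeAlgorithm.MTP
assert algo.is_mtp() and not algo.is_none()
assert algo.decode_len() == 2                         # two tokens handled per decode step
assert SpeculativeDecodeAlgorithm.NONE.decode_len() == 1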

lightllm/models/deepseek2/infer_struct.py

Lines changed: 13 additions & 2 deletions

@@ -3,7 +3,7 @@
 import numpy as np
 import torch.distributed as dist
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
-
+from lightllm.common.spec_info import SpeculativeDecodeAlgorithm

 class Deepseek2InferStateInfo(LlamaInferStateInfo):
     def __init__(self):
@@ -18,4 +18,15 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         if self.is_prefill:
             self.b1_kv_start_loc = self.b1_cu_kv_seq_len
             self.max_value_in_b_seq_len = self.b_seq_len.max().item()
-        return
+
+        if not self.is_prefill and not self.spec_algo.is_none():
+            b_seq_len_numpy = self.b_seq_len.cpu().numpy()
+            position_ids = torch.from_numpy(
+                np.concatenate(
+                    [np.arange(b_seq_len_numpy[i] - 2, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))],
+                    axis=0,
+                )
+            ).cuda()
+            self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
+            self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
+        return
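
In this MTP decode path, each request contributes its last two positions (b_seq_len - 2 and b_seq_len - 1), and the rotary cos/sin rows are gathered for that flattened id list. A small numeric sketch of the position_ids construction (the sequence lengths are made up):

import numpy as np

b_seq_len_numpy = np.array([5, 9, 12])
position_ids = np.concatenate([np.arange(s - 2, s) for s in b_seq_len_numpy], axis=0)
# array([ 3,  4,  7,  8, 10, 11]) -> these rows are then gathered from the cos/sin caches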

lightllm/models/deepseek_mtp/__init__.py

Whitespace-only changes.

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+import torch
+from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.dist_utils import get_current_rank_in_node
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+
+logger = init_logger(__name__)
+
+
+class Deepseek3MTPMemoryManager(Deepseek2MemoryManager):
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
+        self.size = size
+        self.head_num = head_num
+        self.head_dim = head_dim
+        self.layer_num = layer_num
+        self.always_copy = always_copy
+        self.dtype = dtype
+        # profile the max total token num if the size is None
+        self.profile_size(mem_fraction)
+
+        self.mem_state = torch.arange(
+            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
+        )
+        self.mark_start = 0
+        self.mark_end = self.size
+
+        self.can_use_mem_size = self.size
+
+        rank_in_node = get_current_rank_in_node()
+        self.shared_can_use_token_num = SharedInt(
+            f"MTP_mem_manger_can_use_token_num_{rank_in_node}"
+        )
+
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+
+        self._init_buffers(
+            self.size,
+            dtype,
+            head_num,
+            head_dim,
+            layer_num,
+        )
+        self.HOLD_TOKEN_MEMINDEX = self.size

lightllm/models/deepseek_mtp/layer_infer/__init__.py

Whitespace-only changes.
