
Commit ef35cf6

add mtp index

1 parent a7fbb15

6 files changed: +61 −23 lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 5 additions & 0 deletions

@@ -353,6 +353,7 @@ def _decode(
         model_input.input_ids = gather_token(
             self.req_manager.req_sampling_params_manager.req_to_next_token_ids,
             model_input.b_req_idx,
+            model_input.b_mtp_index,
         )

         if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch):
@@ -668,6 +669,7 @@ def _check_max_len_infer(self):
         b_seq_len[:] = self.batch_max_tokens
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
         total_token_num = self.batch_max_tokens
+        b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
             batch_size=1,
             total_token_num=total_token_num,
@@ -676,6 +678,7 @@ def _check_max_len_infer(self):
             mem_indexes=mem_indexes,
             b_req_idx=b_req_idx,
             b_seq_len=b_seq_len,
+            b_mtp_index=b_mtp_index,
             is_prefill=True,
             b_ready_cache_len=b_ready_cache_len,
         )
@@ -723,13 +726,15 @@ def _init_padded_req(self):
         b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
         b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
         total_token_num = prefill_input_len * batch_size
+        b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
             batch_size=batch_size,
             total_token_num=total_token_num,
             max_len_in_batch=prefill_input_len,
             input_ids=dummy_input_ids,
             mem_indexes=mem_indexes,
             b_req_idx=b_req_idx,
+            b_mtp_index=b_mtp_index,
             b_seq_len=b_seq_len,
             b_ready_cache_len=b_ready_cache_len,
             is_prefill=True,
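
Note: with this change the decode path stops treating req_to_next_token_ids as a flat per-request slot and instead reads the id stored at row b_req_idx[i], column b_mtp_index[i]. A minimal pure-PyTorch sketch of that indexing, as a reference for the Triton kernel further down (illustrative only, not code from this commit):

import torch

def gather_token_reference(req_to_next_token_ids: torch.Tensor,
                           b_req_idx: torch.Tensor,
                           b_mtp_index: torch.Tensor) -> torch.Tensor:
    # req_to_next_token_ids: (max_req_num, max_mtp_step)
    # b_req_idx, b_mtp_index: (batch_size,)
    return req_to_next_token_ids[b_req_idx.long(), b_mtp_index.long()]

# e.g. a 4-request table with 2 MTP slots per request
table = torch.tensor([[11, 12], [21, 22], [31, 32], [41, 42]])
print(gather_token_reference(table, torch.tensor([1, 1, 3]), torch.tensor([0, 1, 0])))
# tensor([21, 22, 41])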

lightllm/common/basemodel/batch_objs.py

Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,7 @@ class ModelInput:
     input_ids: torch.Tensor
     mem_indexes: torch.Tensor
     b_req_idx: torch.Tensor
+    b_mtp_index: torch.Tensor
     b_seq_len: torch.Tensor
     is_prefill: bool = False
     b_ready_cache_len: torch.Tensor = None
@@ -30,6 +31,7 @@ def to_cuda(self):
         self.mem_indexes = self.mem_indexes.cuda(non_blocking=True)
         self.b_req_idx = self.b_req_idx.cuda(non_blocking=True)
         self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)
+        self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True)
         if self.b_ready_cache_len is not None:
             self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True)
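
Note: b_mtp_index is declared without a default, so every site that builds a ModelInput now has to pass it, and to_cuda() moves it to the GPU together with the other per-request tensors. A hedged construction sketch, assuming ModelInput is imported from this module and using only keyword arguments that appear in this commit (the tensor dtypes and sizes here are illustrative):

import torch
from lightllm.common.basemodel.batch_objs import ModelInput

batch_size = 4
model_input = ModelInput(
    batch_size=batch_size,
    total_token_num=batch_size * 32,
    max_len_in_batch=32,
    input_ids=torch.zeros(batch_size, dtype=torch.int64),
    mem_indexes=torch.zeros(batch_size, dtype=torch.int64),
    b_req_idx=torch.arange(batch_size, dtype=torch.int32),
    b_mtp_index=torch.zeros(batch_size, dtype=torch.int32),  # 0 = the main decode slot
    b_seq_len=torch.full((batch_size,), 32, dtype=torch.int32),
    is_prefill=False,
)
model_input.to_cuda()  # b_mtp_index is moved alongside b_req_idx and b_seq_len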

lightllm/common/basemodel/cuda_graph.py

Lines changed: 2 additions & 0 deletions

@@ -202,6 +202,7 @@ def warmup(self, model):
             )
             b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
             b_seq_len.fill_(seq_len)
+            b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")

             model_input = ModelInput(
                 batch_size=batch_size,
@@ -211,6 +212,7 @@ def warmup(self, model):
                 mem_indexes=mem_indexes,
                 b_req_idx=b_req_idx,
                 b_seq_len=b_seq_len,
+                b_mtp_index=b_mtp_index,
                 is_prefill=False,
                 **model._gen_special_model_input(batch_size),
             )
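
Note: the warmup change matters because a captured CUDA graph replays fixed kernel launches that read their inputs by address, so b_mtp_index has to exist as a stable buffer at capture time and be refilled in place before each replay. A generic torch.cuda.CUDAGraph sketch of that pattern (illustrative only, not lightllm's graph wrapper):

import torch

table = torch.arange(16, dtype=torch.int64, device="cuda").reshape(8, 2)
rows = torch.arange(8, device="cuda")
static_mtp_index = torch.zeros(8, dtype=torch.int64, device="cuda")  # placeholder input
static_out = torch.empty(8, dtype=torch.int64, device="cuda")

# eager warmup on a side stream, as the torch CUDA-graphs docs recommend
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(table[rows, static_mtp_index])
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(table[rows, static_mtp_index])  # captured work reads the buffer by address

static_mtp_index.fill_(1)  # update the placeholder in place for the next batch ...
g.replay()                 # ... and replay; static_out now holds column 1 of each row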

lightllm/common/basemodel/triton_kernel/gather_token_id.py

Lines changed: 41 additions & 22 deletions

@@ -72,37 +72,48 @@ def gather_and_scatter_token_to_cpu(

 @triton.jit
 def _fwd_kernel_scatter(
-    token_info,
-    req_to_token_info,
+    next_token_ids,
+    req_to_next_token_ids,
     b_req_idx,
-    req_to_token_info_stride,
+    b_mtp_index,
+    req_to_next_token_ids_stride,
+    req_to_next_token_ids_stride_1,
 ):
     cur_index = tl.program_id(0)
     cur_req_idx = tl.load(b_req_idx + cur_index)
-    cur_token_info = tl.load(token_info + cur_index)
-    tl.store(req_to_token_info + cur_req_idx * req_to_token_info_stride, cur_token_info)
+    cur_mtp_index = tl.load(b_mtp_index + cur_index)
+    cur_next_token_id = tl.load(next_token_ids + cur_index)
+    tl.store(req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index, cur_next_token_id)
     return


 @torch.no_grad()
-def scatter_token(token_info: torch.Tensor, req_to_token_info: torch.Tensor, b_req_idx: torch.Tensor):
+def scatter_token(
+    next_token_ids: torch.Tensor,
+    req_to_next_token_ids: torch.Tensor,
+    b_req_idx: torch.Tensor,
+    b_mtp_index: torch.Tensor,
+):
     """
     This function is used to scatter the token_info(GPU tensor) to the req_to_token_info(CPU tensor).
     Args:
-        token_info: (batch_size, vocab_size)
-        req_to_token_info: (max_req_num,)
+        next_token_ids: (batch_size,)
+        req_to_next_token_ids: (max_req_num, max_mtp_step)
         b_req_idx: (batch_size,)
+        b_mtp_index: (batch_size,)
     """
-    assert token_info.shape[0] == b_req_idx.shape[0]
+    assert next_token_ids.shape[0] == b_req_idx.shape[0]
     batch_size = b_req_idx.shape[0]
     grid = (batch_size,)
     num_warps = 1

     _fwd_kernel_scatter[grid](
-        token_info,
-        req_to_token_info,
+        next_token_ids,
+        req_to_next_token_ids,
         b_req_idx,
-        req_to_token_info.stride(0),
+        b_mtp_index,
+        req_to_next_token_ids.stride(0),
+        req_to_next_token_ids.stride(1),
         num_warps=num_warps,
         num_stages=1,
     )
@@ -111,24 +122,28 @@ def scatter_token(token_info: torch.Tensor, req_to_token_info: torch.Tensor, b_r

 @triton.jit
 def _fwd_kernel_gather(
-    req_to_token_info,
-    req_to_token_info_stride,
+    req_to_next_token_ids,
+    req_to_next_token_ids_stride,
+    req_to_next_token_ids_stride_1,
     output,
     b_req_idx,
+    b_mtp_index,
 ):
     cur_index = tl.program_id(0)
     cur_req_idx = tl.load(b_req_idx + cur_index)
-    cur_token_info = tl.load(req_to_token_info + cur_req_idx * req_to_token_info_stride)
-    tl.store(output + cur_index, cur_token_info)
+    cur_mtp_index = tl.load(b_mtp_index + cur_index)
+    cur_next_token_id = tl.load(req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index)
+    tl.store(output + cur_index, cur_next_token_id)
     return


-def gather_token(req_to_token_info: torch.Tensor, b_req_idx: torch.Tensor):
+def gather_token(req_to_next_token_ids: torch.Tensor, b_req_idx: torch.Tensor, b_mtp_index: torch.Tensor):
     """
     This function is used to gather the token_info(CPU tensor) to the token_info(GPU tensor).
     Args:
         req_to_token_info: (max_req_num, max_mtp_step)
         b_req_idx: (batch_size,)
+        b_mtp_index: (batch_size,)
     Returns:
         output: (batch_size,)
     """
@@ -137,10 +152,12 @@ def gather_token(req_to_token_info: torch.Tensor, b_req_idx: torch.Tensor):
     grid = (batch_size,)
     num_warps = 1
     _fwd_kernel_gather[grid](
-        req_to_token_info,
-        req_to_token_info.stride(0),
+        req_to_next_token_ids,
+        req_to_next_token_ids.stride(0),
+        req_to_next_token_ids.stride(1),
         output,
         b_req_idx,
+        b_mtp_index,
         num_warps=num_warps,
         num_stages=1,
     )
@@ -187,7 +204,8 @@ def test_scatter_token_to_cpu():
     req_to_token_info = torch.zeros((1000,), dtype=torch.float32, pin_memory=True)
     token_info = torch.randn((batch_size,)).cuda()
     req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda()
-    scatter_token(token_info, req_to_token_info, req_ids)
+    mtp_index = torch.zeros((batch_size,), dtype=torch.int32).cuda()
+    scatter_token(token_info, req_to_token_info, req_ids, mtp_index)
     diff = (req_to_token_info[20 : 20 + batch_size].cuda() - token_info).abs().max()
     assert diff < 1e-6
     print("test_scatter_token_to_cpu passed")
@@ -198,8 +216,9 @@ def test_gather_token():
     req_to_token_info = torch.zeros((1000,), dtype=torch.int32, pin_memory=True)
     token_info = torch.randn((batch_size,)).cuda()
     req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda()
-    scatter_token(token_info, req_to_token_info, req_ids)
-    output = gather_token(req_to_token_info, req_ids)
+    mtp_index = torch.zeros((batch_size,), dtype=torch.int32).cuda()
+    scatter_token(token_info, req_to_token_info, req_ids, mtp_index)
+    output = gather_token(req_to_token_info, req_ids, mtp_index)
     diff = (token_info - output).abs().max()
     assert diff < 1e-6
     print("test_gather_token passed")

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 2 additions & 0 deletions

@@ -102,6 +102,7 @@ def prefill_normal(
             next_token_ids,
             self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
             model_input.b_req_idx,
+            model_input.b_mtp_index,
         )
         next_token_ids_cpu = g_pin_mem_manager.alloc_pin_tensor(
             "next_token_ids", next_token_ids.shape[0], next_token_ids.dtype
@@ -149,6 +150,7 @@ def decode_normal(
             next_token_ids,
             self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
             model_input.b_req_idx,
+            model_input.b_mtp_index,
         )
         next_token_ids_cpu = g_pin_mem_manager.alloc_pin_tensor(
             "next_token_ids", next_token_ids.shape[0], next_token_ids.dtype

lightllm/server/router/model_infer/mode_backend/generic_pre_process.py

Lines changed: 9 additions & 1 deletion

@@ -17,6 +17,7 @@ def prepare_prefill_inputs(
     b_seq_len = []
     batch_multimodal_params = []
     b_ready_cache_len = []
+    b_mtp_index = []
     for req in req_objs:
         run_reqs.append(req)
         batch_multimodal_params.append(req.multimodal_params)
@@ -37,12 +38,14 @@ def prepare_prefill_inputs(
         total_token_num += seq_len
         max_len_in_batch = max(max_len_in_batch, input_token_len)
         b_ready_cache_len.append(req.cur_kv_len)
+        b_mtp_index.append(0)

     input_ids = np.concatenate(input_ids, dtype=np.int64)

     input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cpu")
     b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu")
     b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
+    b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")
     b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cpu")

     # dynamic prompt cache 准备 token
@@ -59,6 +62,7 @@ def prepare_prefill_inputs(
         input_ids=input_ids,
         mem_indexes=mem_indexes,
         b_req_idx=b_req_idx,
+        b_mtp_index=b_mtp_index,
         b_seq_len=b_seq_len,
         b_ready_cache_len=b_ready_cache_len,
         is_prefill=True,
@@ -74,6 +78,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
     total_token_num = 0
     max_len_in_batch = 0
     b_req_idx = []
+    b_mtp_index = []
     b_seq_len = []
     for req in req_objs:
         run_reqs.append(req)
@@ -83,7 +88,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
         b_seq_len.append(seq_len)
         total_token_num += seq_len
         max_len_in_batch = max(max_len_in_batch, seq_len)
-
+        b_mtp_index.append(0)
         # process the draft tokens.
         for step in range(len(req.mtp_gen_token_ids)):
             run_reqs.append(req)
@@ -92,9 +97,11 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
             b_seq_len.append(seq_len)
             total_token_num += seq_len
             max_len_in_batch = max(max_len_in_batch, seq_len)
+            b_mtp_index.append(step + 1)

     b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu")
     b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
+    b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")

     # dynamic prompt cache 准备 token
     g_infer_state_lock.acquire()
@@ -110,6 +117,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
         input_ids=None,
         mem_indexes=mem_indexes,
         b_req_idx=b_req_idx,
+        b_mtp_index=b_mtp_index,
         b_seq_len=b_seq_len,
         is_prefill=False,
     )
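
Note: in prepare_decode_inputs each request now contributes one batch slot for the verified next token (mtp index 0) followed by one slot per draft token (mtp index step + 1), so b_mtp_index lines up with the flattened batch. A small worked example of that layout, assuming a request object that only needs a mtp_gen_token_ids list for this purpose:

from dataclasses import dataclass, field
from typing import List

@dataclass
class FakeReq:  # stand-in for InferReq; only the field this layout needs
    mtp_gen_token_ids: List[int] = field(default_factory=list)

def build_b_mtp_index(req_objs):
    b_mtp_index = []
    for req in req_objs:
        b_mtp_index.append(0)                        # main decode slot
        for step in range(len(req.mtp_gen_token_ids)):
            b_mtp_index.append(step + 1)             # one slot per draft token
    return b_mtp_index

# two drafts for the first request, none for the second, one for the third
print(build_b_mtp_index([FakeReq([7, 8]), FakeReq(), FakeReq([9])]))
# [0, 1, 2, 0, 0, 1]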
