
Commit 64fda2b

clean code
1 parent 846ae0c commit 64fda2b

File tree

7 files changed (+28, -41 lines)


lightllm/common/basemodel/batch_objs.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ class ModelInput:
     # Parameters used in the prefill stage; they are not consumed by the inference
     # pass itself, but by resource management outside of inference.
     b_prefill_has_output_cpu: List[bool] = None  # marks whether each prefill request produces output
-    b_chunked_prefill_next_token_ids_cpu: List[int] = None  # for chunked prefill mtp
+    b_next_chunck_first_token_ids_cpu: List[int] = None  # for chunked prefill mtp

     # Dedicated variables, used by some special models in special modes to pass
     # extra input variables. They are only used and take effect in those model modes.
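For readers skimming the diff, here is a minimal sketch of the two CPU-side bookkeeping fields as they now stand on the batch object (illustrative only; the real ModelInput dataclass carries many more fields):

from dataclasses import dataclass
from typing import List

@dataclass
class ModelInput:  # heavily abridged sketch, not the full class
    # does this request's current prefill chunk produce a real output token?
    b_prefill_has_output_cpu: List[bool] = None
    # first token id of each request's next chunk (used by chunked-prefill MTP)
    b_next_chunck_first_token_ids_cpu: List[int] = None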

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 2 additions & 1 deletion
@@ -397,7 +397,8 @@ def get_chuncked_input_token_ids(self):
         if chunked_end < self.get_cur_total_len():
             next_token_id = self.shm_req.shm_prompt_ids.arr[chunked_end]
         else:
-            next_token_id = -1  # last chunk
+            # padding id for the last chunk, will be discarded.
+            next_token_id = self.shm_req.shm_prompt_ids.arr[0]

         return self.shm_req.shm_prompt_ids.arr[0:chunked_end], next_token_id
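The switch from -1 to shm_prompt_ids.arr[0] matters presumably because the placeholder can now flow into downstream lookups (e.g. the draft model's embedding) before the has-output selection discards it, and a real token id is always a legal vocabulary index while -1 is not. A simplified standalone sketch of the new behavior, using a plain list instead of the shared-memory prompt array:

def get_chunked_input_token_ids(prompt_ids, chunked_end):
    if chunked_end < len(prompt_ids):
        next_token_id = prompt_ids[chunked_end]   # first token of the next chunk
    else:
        # last chunk: any valid token id works as a placeholder; it is discarded
        # later by the has-output selection in _sample_and_scatter_token
        next_token_id = prompt_ids[0]
    return prompt_ids[0:chunked_end], next_token_id

print(get_chunked_input_token_ids([11, 22, 33, 44], 2))  # ([11, 22], 33)
print(get_chunked_input_token_ids([11, 22, 33, 44], 4))  # ([11, 22, 33, 44], 11)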

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 6 additions & 0 deletions
@@ -658,6 +658,7 @@ def _sample_and_scatter_token(
         is_prefill: bool,
         b_prefill_has_output_cpu: torch.Tensor = None,
         mask_func: Optional[Callable] = None,
+        b_next_chunck_first_token_ids_cpu: torch.Tensor = None,
     ):

         if mask_func is not None:
@@ -670,6 +671,11 @@ def _sample_and_scatter_token(
             b_has_out = g_pin_mem_manager.gen_from_list(
                 key="b_has_out", data=b_prefill_has_output_cpu, dtype=torch.bool
             ).cuda(non_blocking=True)
+            if b_next_chunck_first_token_ids_cpu is not None:
+                b_next_chunck_first_token_ids = g_pin_mem_manager.gen_from_list(
+                    key="b_next_chunck_first_token_ids", data=b_next_chunck_first_token_ids_cpu, dtype=torch.int64
+                ).cuda(non_blocking=True)
+                next_token_ids = torch.where(b_has_out, next_token_ids, b_next_chunck_first_token_ids)

         scatter_token(
             next_token_ids=next_token_ids,
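The core of the change is that the next-chunk substitution now happens once, inside _sample_and_scatter_token, right after sampling. A minimal standalone sketch of that selection (hand-made CPU tensors; the real code builds them through g_pin_mem_manager and moves them to CUDA):

import torch

# sampled token per request in the prefill batch
next_token_ids = torch.tensor([101, 102, 103, 104])
# True where the current chunk reaches the end of the prompt and really produces output
b_has_out = torch.tensor([True, False, True, False])
# first token of each request's next chunk (placeholder where there is none)
b_next_chunck_first_token_ids = torch.tensor([7, 8, 9, 10])

# keep the sampled token where there is real output, otherwise take the next chunk's first token
next_token_ids = torch.where(b_has_out, next_token_ids, b_next_chunck_first_token_ids)
print(next_token_ids.tolist())  # [101, 8, 103, 10]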

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 2 additions & 8 deletions
@@ -190,16 +190,10 @@ def prefill_mtp(
             is_prefill=True,
             b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu,
             mask_func=self.prefill_mask_func,
+            b_next_chunck_first_token_ids_cpu=model_input.b_next_chunck_first_token_ids_cpu,
         )
-        # mtp kv fill
-        b_has_out = torch.tensor(model_input.b_prefill_has_output_cpu, dtype=torch.bool, device="cuda")
-        b_chunked_next_token_ids = torch.tensor(
-            model_input.b_chunked_prefill_next_token_ids_cpu, dtype=torch.int64, device="cuda"
-        )
-        mtp_next_token_ids = torch.where(b_has_out, next_token_ids, b_chunked_next_token_ids)
-
         self._draft_prefill_forward(
-            model_input=model_input, model_output=model_output, next_token_ids=mtp_next_token_ids
+            model_input=model_input, model_output=model_output, next_token_ids=next_token_ids
         )
         sync_event = torch.cuda.Event()
         sync_event.record()

lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py

Lines changed: 9 additions & 24 deletions
@@ -354,13 +354,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]
         # mtp kv fill
         draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda")
         if req_num > 0:
-            b_has_out = torch.tensor(b_has_out_cpu, dtype=torch.bool, device="cuda")
-            b_chunked_next_token_ids = torch.tensor(
-                model_input.b_chunked_prefill_next_token_ids_cpu[0:req_num], dtype=torch.int64, device="cuda"
-            )
-            mtp_next_token_ids = torch.where(b_has_out, next_token_ids, b_chunked_next_token_ids)
-            draft_next_token_ids_gpu[0:req_num].copy_(mtp_next_token_ids)
-
+            draft_next_token_ids_gpu[0:req_num].copy_(next_token_ids)
         self._draft_prefill_forward(
             model_input=model_input,
             model_output=model_output,
@@ -622,6 +616,10 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I
         b_has_out_cpu = (
             model_input0.b_prefill_has_output_cpu[0:req_num0] + model_input1.b_prefill_has_output_cpu[0:req_num1]
         )
+        b_next_chunck_first_token_ids_cpu = (
+            model_input0.b_next_chunck_first_token_ids_cpu[0:req_num0]
+            + model_input1.b_next_chunck_first_token_ids_cpu[0:req_num1]
+        )
         b_mtp_index = torch.cat((model_input0.b_mtp_index[0:req_num0], model_input1.b_mtp_index[0:req_num1]), dim=0)
         b_req_idx = torch.cat((model_input0.b_req_idx[0:req_num0], model_input1.b_req_idx[0:req_num1]), dim=0)

@@ -633,33 +631,20 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I
             b_mtp_index=b_mtp_index,
             is_prefill=True,
             b_prefill_has_output_cpu=b_has_out_cpu,
+            b_next_chunck_first_token_ids_cpu=b_next_chunck_first_token_ids_cpu,
         )

         # spec prefill: MTP
         draft_model_input0, draft_model_input1 = model_input0, model_input1
         draft_next_token_ids_gpu0 = torch.zeros((model_input0.batch_size), dtype=torch.int64, device="cuda")
         if req_num0 > 0:
-            b_has_out0 = torch.tensor(
-                model_input0.b_prefill_has_output_cpu[0:req_num0], dtype=torch.bool, device="cuda"
-            )
-            b_chunked_next_token_ids0 = torch.tensor(
-                model_input0.b_chunked_prefill_next_token_ids_cpu[0:req_num0], dtype=torch.int64, device="cuda"
-            )
-            mtp_next_token_ids0 = torch.where(b_has_out0, next_token_ids[0:req_num0], b_chunked_next_token_ids0)
-            draft_next_token_ids_gpu0[0:req_num0].copy_(mtp_next_token_ids0, non_blocking=True)
+            draft_next_token_ids_gpu0[0:req_num0].copy_(next_token_ids[0:req_num0], non_blocking=True)

         draft_next_token_ids_gpu1 = torch.zeros((model_input1.batch_size), dtype=torch.int64, device="cuda")
         if req_num1 > 0:
-            b_has_out1 = torch.tensor(
-                model_input1.b_prefill_has_output_cpu[0:req_num1], dtype=torch.bool, device="cuda"
-            )
-            b_chunked_next_token_ids1 = torch.tensor(
-                model_input1.b_chunked_prefill_next_token_ids_cpu[0:req_num1], dtype=torch.int64, device="cuda"
-            )
-            mtp_next_token_ids1 = torch.where(
-                b_has_out1, next_token_ids[req_num0 : (req_num0 + req_num1)], b_chunked_next_token_ids1
+            draft_next_token_ids_gpu1[0:req_num1].copy_(
+                next_token_ids[req_num0 : (req_num0 + req_num1)], non_blocking=True
             )
-            draft_next_token_ids_gpu1[0:req_num1].copy_(mtp_next_token_ids1, non_blocking=True)

         draft_model_output0, draft_model_output1 = model_output0, model_output1
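In the overlapped path, the two micro-batches' CPU lists are concatenated before the shared sampling step, and the combined next_token_ids are then split back by request count. A simplified sketch of that split (made-up sizes, CPU tensors for brevity):

import torch

req_num0, req_num1 = 3, 2
# combined sampled tokens for both micro-batches, in order
next_token_ids = torch.tensor([100, 101, 102, 200, 201])

# padded per-micro-batch buffers (batch_size may exceed the real request count)
draft_next_token_ids_gpu0 = torch.zeros(4, dtype=torch.int64)
draft_next_token_ids_gpu1 = torch.zeros(4, dtype=torch.int64)

draft_next_token_ids_gpu0[0:req_num0].copy_(next_token_ids[0:req_num0])
draft_next_token_ids_gpu1[0:req_num1].copy_(next_token_ids[req_num0 : req_num0 + req_num1])
print(draft_next_token_ids_gpu0.tolist())  # [100, 101, 102, 0]
print(draft_next_token_ids_gpu1.tolist())  # [200, 201, 0, 0]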

lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py

Lines changed: 4 additions & 4 deletions
@@ -36,7 +36,7 @@ def padded_prepare_prefill_inputs(
     b_ready_cache_len = []
     b_mtp_index = []
     b_prefill_has_output = []
-    b_chunked_prefill_next_token_ids = []
+    b_next_chunck_first_token_ids = []

     for req in req_objs:

@@ -45,7 +45,7 @@ def padded_prepare_prefill_inputs(
         b_req_idx.append(req.req_idx)

         input_token_ids, next_token_id = req.get_chuncked_input_token_ids()
-        b_chunked_prefill_next_token_ids.append(next_token_id)
+        b_next_chunck_first_token_ids.append(next_token_id)
         b_prefill_has_output.append(False if len(input_token_ids) < req.get_cur_total_len() else True)
         seq_len = len(input_token_ids)
         input_token_len = seq_len - req.cur_kv_len
@@ -67,7 +67,7 @@ def padded_prepare_prefill_inputs(
         b_q_seq_len.append(1)
         b_mtp_index.append(0)
         b_prefill_has_output.append(False)
-        b_chunked_prefill_next_token_ids.append(-1)
+        b_next_chunck_first_token_ids.append(0)
         b_ready_cache_len.append(0)
         total_token_num += 1
         prefix_total_token_num += 0
@@ -115,7 +115,7 @@ def padded_prepare_prefill_inputs(
         b_ready_cache_len=b_ready_cache_len,
         is_prefill=True,
         b_prefill_has_output_cpu=b_prefill_has_output,
-        b_chunked_prefill_next_token_ids_cpu=b_chunked_prefill_next_token_ids,
+        b_next_chunck_first_token_ids_cpu=b_next_chunck_first_token_ids,
     )
     if is_multimodal:
         model_input.multimodal_params = batch_multimodal_params

lightllm/server/router/model_infer/mode_backend/generic_pre_process.py

Lines changed: 4 additions & 3 deletions
@@ -20,7 +20,7 @@ def prepare_prefill_inputs(
     b_ready_cache_len = []
     b_mtp_index = []
    b_prefill_has_output = []
-    b_chunked_prefill_next_token_ids = []
+    b_next_chunck_first_token_ids = []

     for req in req_objs:
         run_reqs.append(req)
@@ -29,7 +29,7 @@ def prepare_prefill_inputs(

         if is_chuncked_mode:
             input_token_ids, next_token_id = req.get_chuncked_input_token_ids()
-            b_chunked_prefill_next_token_ids.append(next_token_id)
+            b_next_chunck_first_token_ids.append(next_token_id)
         else:
             input_token_ids = req.get_input_token_ids()

@@ -59,6 +59,7 @@ def prepare_prefill_inputs(
     b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
     b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")
     b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cpu")
+    b_next_chunck_first_token_ids = torch.tensor(b_next_chunck_first_token_ids, dtype=torch.int64, device="cpu")

     # prepare tokens for the dynamic prompt cache
     g_infer_state_lock.acquire()
@@ -82,7 +83,7 @@ def prepare_prefill_inputs(
         b_ready_cache_len=b_ready_cache_len,
         is_prefill=True,
         b_prefill_has_output_cpu=b_prefill_has_output,
-        b_chunked_prefill_next_token_ids_cpu=b_chunked_prefill_next_token_ids,
+        b_next_chunck_first_token_ids_cpu=b_next_chunck_first_token_ids,
         prefix_total_token_num=prefix_total_token_num,
     )
     if is_multimodal:
