     prepare_prefill_inputs,
 )
 from lightllm.server.router.model_infer.mode_backend.mtp_pre_process import (
-    prepare_mtp_prefill_inputs,
+    prepare_mtp_chunked_prefill_inputs,
     prepare_draft_main_model_decode_inputs,
 )
 from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
@@ -38,7 +38,7 @@ def decode(self):
 
         if prefill_reqs:
             model_input, run_reqs = prepare_prefill_inputs(
-                prefill_reqs, is_chuncked_mode=False, is_multimodal=self.is_multimodal
+                prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
             )
             model_output = self.model.forward(model_input)
 
@@ -49,27 +49,37 @@ def decode(self):
             next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id)
             next_token_ids = next_token_ids.detach().cpu().numpy()
             next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-            self._post_handle(
-                run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
-            )
+            prev_step_has_output = [
+                req_obj.get_chuncked_input_token_len() == req_obj.get_cur_total_len() for req_obj in prefill_reqs
+            ]
             # spec prefill: MTP
             last_input_ids_cpu = None
             draft_model_input = model_input
             last_hidden_states = model_output.hidden_states
+            draft_next_token_ids = next_token_ids
             for draft_model_idx in range(self.spec_step):
-                device0_print(f"main {draft_model_input}")
-                draft_model_input, last_input_ids_cpu = prepare_mtp_prefill_inputs(
-                    prefill_reqs, model_input, last_hidden_states, next_token_ids, last_input_ids_cpu
+
+                draft_model_input, last_input_ids_cpu, prev_step_has_output = prepare_mtp_chunked_prefill_inputs(
+                    prefill_reqs,
+                    model_input,
+                    last_hidden_states,
+                    draft_next_token_ids,
+                    draft_model_idx + 1,
+                    prev_step_has_output,
+                    last_input_ids_cpu,
                 )
-                device0_print(f"draft_model_input {draft_model_input}")
+
                 draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input)
                 draft_next_token_ids, _ = sample(draft_model_output.logits, run_reqs, self.eos_id)
                 draft_next_token_ids = draft_next_token_ids.detach().cpu().numpy()
 
                 last_hidden_states = draft_model_output.hidden_states
-                next_token_ids = draft_next_token_ids
                 self._save_draft_token_ids(draft_next_token_ids, run_reqs, draft_model_idx)
 
+            self._post_handle(
+                run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False
+            )
+
         if decode_reqs:
             model_input, run_reqs, mem_indexes_cpu = prepare_draft_main_model_decode_inputs(
                 decode_reqs, self.draft_token_id_map
@@ -93,9 +103,11 @@ def decode(self):
                 accepted_reqs,
                 next_token_ids[accepted_index],
                 next_token_logprobs[accepted_index],
-                is_chuncked_mode=False,
+                is_chuncked_mode=True,
                 do_filter_finished_reqs=False,
             )
+            self.main_step += 1
+
             # share some inference info with the main model
             draft_model_input = model_input
             draft_model_input.input_ids = next_token_ids_cuda
@@ -118,37 +130,3 @@ def decode(self):
 
         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
         return
-
-    def verify(self, next_token_ids, run_reqs, draft_mem_indexes):
-        accepted_reqs = []
-        accepted_index = []
-        need_free_mem_indexes = []
-        assert next_token_ids.shape[0] % self.spec_stride == 0
-
-        for i, req in enumerate(run_reqs):
-            # main model output
-            if i % self.spec_stride == 0:
-                accepted_reqs.append(req)
-                accepted_index.append(i)
-                continue
-            draft_model_idx = i % self.spec_stride - 1
-            if (
-                self.draft_token_id_map[req.req_idx][draft_model_idx] == next_token_ids[i - 1]
-                and req.cur_accepted_len == draft_model_idx
-            ):
-                accepted_reqs.append(req)
-                accepted_index.append(i)
-                req.cur_accepted_len += 1
-                device0_print(f"req {req.req_idx} accepted, cur_accepted_len {req.cur_accepted_len}")
-            else:
-                need_free_mem_indexes.append(draft_mem_indexes[i])
-        return accepted_reqs, accepted_index, need_free_mem_indexes
-
-    def _save_draft_token_ids(self, draft_next_token_ids, run_reqs, draft_model_idx):
-        batch_size = len(run_reqs) // self.spec_stride
-        for i in range(batch_size):
-            req = run_reqs[self.spec_stride * i]
-            self.draft_token_id_map[req.req_idx][draft_model_idx] = draft_next_token_ids[i + req.cur_accepted_len]
-            # reset the cur_accepted_len
-            if draft_model_idx == self.spec_step - 1:
-                req.cur_accepted_len = 0
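
For readers of this diff: switching prefill to `is_chuncked_mode=True` means one prefill step may cover only part of a request's prompt, so the new `prev_step_has_output` list records, per request, whether this chunk actually produced a usable next token before the MTP draft models run. The sketch below illustrates that gating under assumed request semantics; it is not the library code, and `FakeReq`, `prompt_ids`, and `chunked_len` are hypothetical stand-ins, with only the two getter names taken from the diff above.

```python
# Minimal sketch (not lightllm's implementation) of the prev_step_has_output gating:
# in chunked prefill, a request only yields a genuine "next token" for the draft
# models once the current chunk reaches the end of its prompt, i.e. the chunked
# input length equals the request's current total length.

from dataclasses import dataclass
from typing import List


@dataclass
class FakeReq:
    prompt_ids: List[int]   # tokens of the prompt
    chunked_len: int        # prompt tokens prefilled after this chunk

    def get_chuncked_input_token_len(self) -> int:
        return self.chunked_len

    def get_cur_total_len(self) -> int:
        return len(self.prompt_ids)


def compute_prev_step_has_output(prefill_reqs: List[FakeReq]) -> List[bool]:
    # True only for requests whose prompt is fully consumed by this chunk, so the
    # main model's sampled token for them is real and can feed the draft models.
    return [
        req.get_chuncked_input_token_len() == req.get_cur_total_len()
        for req in prefill_reqs
    ]


if __name__ == "__main__":
    reqs = [
        FakeReq(prompt_ids=[1, 2, 3, 4], chunked_len=4),     # finished its prompt this chunk
        FakeReq(prompt_ids=[1, 2, 3, 4, 5], chunked_len=3),  # still mid-prompt
    ]
    print(compute_prev_step_has_output(reqs))  # [True, False]
```

In the diff, that per-request flag is then threaded through `prepare_mtp_chunked_prefill_inputs` on every draft step, presumably so requests still mid-prompt are not fed a sampled token they have not actually produced.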