 
 logger = init_logger(__name__)
 
-# TODO: optim
-def update_draft_token_mem_indexes(draft_token_memindex_map, run_reqs, mem_indexes):
-    for i, req in enumerate(run_reqs):
-        draft_token_memindex_map[req.req_idx] = mem_indexes[i]
-

 class ContinuesBatchWithMTPBackend(ModeBackend):
     def __init__(self) -> None:
@@ -81,6 +76,7 @@ def init_model(self, kvargs):
         self.mtp_draft_token_memindex_map = torch.full(
             (max_req_num,), fill_value=IS_NONE, dtype=torch.int32, device="cpu"
         )
+        self.accept_len = 0

     def prefill(self, reqs: List[Tuple]):
         self._init_reqs(reqs, init_req_obj=False)
@@ -107,10 +103,8 @@ def decode(self):
             next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id)
             next_token_ids = next_token_ids.detach().cpu().numpy()
             next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-            # spec decode: MTP
-            draft_model_input = prepare_mtp_prefill_inputs(prefill_reqs, next_token_ids, self.draft_model.mem_manager)
-            # mtp embedding
-            draft_model_input.hidden_states = model_output.hidden_states
+            # spec prefill: MTP
+            draft_model_input = prepare_mtp_prefill_inputs(prefill_reqs, model_input, model_output, next_token_ids)
             draft_model_output = self.draft_model.forward(draft_model_input)
             draft_next_token_ids, _ = sample(draft_model_output.logits, run_reqs, self.eos_id)
             draft_next_token_ids = draft_next_token_ids.detach().cpu().numpy()
@@ -121,9 +115,8 @@ def decode(self):
             )

         if decode_reqs:
-            model_input, run_reqs = prepare_draft_main_model_decode_inputs(decode_reqs, self.draft_token_id_map)
-            update_draft_token_mem_indexes(
-                self.main_draft_token_memindex_map, run_reqs[1::2], model_input.mem_indexes[1::2]
+            model_input, run_reqs, mem_indexes_cpu = prepare_draft_main_model_decode_inputs(
+                decode_reqs, self.draft_token_id_map
             )
             model_output = self.model.forward(model_input)
             assert model_output.logits.shape[0] % 2 == 0
@@ -148,7 +141,9 @@ def decode(self):
             next_token_ids1 = next_token_ids[1::2]
             next_token_logprobs1 = next_token_logprobs[1::2]

-            accepted_reqs, accepted_index = self.verify(next_token_ids0, run_reqs[::2])
+            accepted_reqs, accepted_index, need_free_mem_indexes = self.verify(
+                next_token_ids0, run_reqs[::2], mem_indexes_cpu[1::2]
+            )
             self._post_handle(
                 accepted_reqs,
                 next_token_ids1[accepted_index],
@@ -157,22 +152,23 @@ def decode(self):
                 do_filter_finished_reqs=False,
             )
             # spec decode: MTP
-            draft_model_input = copy.deepcopy(model_input)
+            draft_model_input = model_input
             draft_model_input.input_ids = torch.tensor(next_token_ids, dtype=torch.int64, device="cuda")
-            mtp_mem_indexes = self.draft_model.mem_manager.alloc(next_token_ids.shape[0]).cuda()
-            draft_model_input.mem_indexes = mtp_mem_indexes
             draft_model_input.hidden_states = model_output.hidden_states
-            update_draft_token_mem_indexes(self.mtp_draft_token_memindex_map, run_reqs[1::2], mtp_mem_indexes[1::2])
             draft_model_output = self.draft_model.forward(draft_model_input)
             draft_next_token_ids, _ = sample(draft_model_output.logits, run_reqs, self.eos_id)

             accepted_req_idxs = [req.req_idx for req in accepted_reqs]
             self._save_draft_token_ids(draft_next_token_ids, run_reqs[::2], accepted_req_idxs)
+            if need_free_mem_indexes:
+                g_infer_state_lock.acquire()
+                g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes)
+                g_infer_state_lock.release()

         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
         return

-    def verify(self, next_token_ids0, run_reqs):
+    def verify(self, next_token_ids0, run_reqs, draft_mem_indexes):
         accepted_reqs = []
         accepted_index = []
         need_free_mem_indexes = []
@@ -181,66 +177,20 @@ def verify(self, next_token_ids0, run_reqs):
             if self.draft_token_id_map[req.req_idx] == next_token_ids0[i]:
                 accepted_reqs.append(req)
                 accepted_index.append(i)
-                self.main_draft_token_memindex_map[req.req_idx] = IS_NONE
+                self.accept_len += 1
+                device0_print(f"self.accept_len: {self.accept_len}")
             else:
-                need_free_mem_indexes.append(self.main_draft_token_memindex_map[req.req_idx])
-        if need_free_mem_indexes:
-            g_infer_state_lock.acquire()
-            g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes)
-            g_infer_state_lock.release()
-        return accepted_reqs, accepted_index
+                need_free_mem_indexes.append(draft_mem_indexes[i])
+        return accepted_reqs, accepted_index, need_free_mem_indexes

     def _save_draft_token_ids(self, draft_next_token_ids, run_reqs, accepted_reqs=None):
         assert accepted_reqs is None or draft_next_token_ids.shape[0] == 2 * len(run_reqs)
-        need_free_mem_indexes = []
         for i, req in enumerate(run_reqs):
             if accepted_reqs is None:
                 self.draft_token_id_map[req.req_idx] = draft_next_token_ids[i]
             else:
                 if req.req_idx in accepted_reqs:
                     self.draft_token_id_map[req.req_idx] = draft_next_token_ids[2 * i + 1]
-                    self.mtp_draft_token_memindex_map[req.req_idx] = IS_NONE
                 else:
                     self.draft_token_id_map[req.req_idx] = draft_next_token_ids[2 * i]
-                    need_free_mem_indexes.append(self.mtp_draft_token_memindex_map[req.req_idx])
-
-        req = run_reqs[0]
-        if need_free_mem_indexes:
-            g_infer_state_lock.acquire()
-            self.draft_model.mem_manager.free(need_free_mem_indexes)
-            g_infer_state_lock.release()
-        return
-
-    def _overlap_req_init_and_filter(
-        self, uninit_reqs: List[InferReq], ok_finished_reqs: List[InferReq], clear_list=False
-    ):
-        if uninit_reqs or ok_finished_reqs:
-            with torch.cuda.stream(g_infer_context.get_overlap_stream()):
-                if ok_finished_reqs:
-                    g_infer_state_lock.acquire()
-                    self._free_mtp_model_memindex(ok_finished_reqs)
-                    g_infer_context.filter_reqs(ok_finished_reqs)
-                    g_infer_state_lock.release()
-
-                if uninit_reqs:
-                    g_infer_state_lock.acquire()
-                    self._post_init_reqs(uninit_reqs)
-                    g_infer_state_lock.release()
-
-            torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream())
-
-        if clear_list:
-            uninit_reqs.clear()
-            ok_finished_reqs.clear()
         return
-
-    def _free_mtp_model_memindex(self, ok_finished_reqs):
-        mtp_free_mem_indexes = []
-        for req in ok_finished_reqs:
-            mtp_free_mem_indexes.append(
-                self.draft_model.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]
-            )
-        free_memindexes = torch.cat(mtp_free_mem_indexes, dim=0)
-        g_infer_state_lock.acquire()
-        self.draft_model.req_manager.mem_manager.free(free_memindexes)
-        g_infer_state_lock.release()
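
Note on the new bookkeeping: the sketch below is a minimal, self-contained illustration (plain Python, not the PR code) of the accept/reject logic that the reworked `verify` / `_save_draft_token_ids` pair implements. The even/odd interleaving (`[::2]` official rows, `[1::2]` speculative rows) and the map names mirror the diff; the function names and the toy driver at the bottom are hypothetical and exist only for demonstration.

```python
# Sketch of MTP speculative-decode acceptance bookkeeping (illustrative only).
# Assumption: each request contributes two interleaved rows per decode step --
# an official row at even positions and a speculative row at odd positions;
# draft_token_id_map[req_idx] holds the token speculated on the previous step,
# and spec_mem_indexes are the KV-cache slots allocated for the speculative rows.

def verify_sketch(next_token_ids0, run_req_idxs, draft_token_id_map, spec_mem_indexes):
    """A speculation is accepted when the main model's token at the official row
    equals the stored draft token; rejected rows hand back their KV slots."""
    accepted_reqs, accepted_index, need_free_mem_indexes = [], [], []
    for i, req_idx in enumerate(run_req_idxs):
        if draft_token_id_map[req_idx] == next_token_ids0[i]:
            accepted_reqs.append(req_idx)
            accepted_index.append(i)
        else:
            need_free_mem_indexes.append(spec_mem_indexes[i])
    return accepted_reqs, accepted_index, need_free_mem_indexes


def save_draft_token_ids_sketch(draft_next_token_ids, run_req_idxs, accepted_req_idxs, draft_token_id_map):
    """Pick next step's draft token from the interleaved draft-model outputs:
    index 2*i + 1 (after the speculative row) if the speculation was accepted,
    otherwise index 2*i (after the official row)."""
    for i, req_idx in enumerate(run_req_idxs):
        if req_idx in accepted_req_idxs:
            draft_token_id_map[req_idx] = draft_next_token_ids[2 * i + 1]
        else:
            draft_token_id_map[req_idx] = draft_next_token_ids[2 * i]


if __name__ == "__main__":
    # Two requests: req 0's speculation matches the main model, req 1's does not.
    draft_map = {0: 11, 1: 42}
    accepted, idx, to_free = verify_sketch(
        next_token_ids0=[11, 7], run_req_idxs=[0, 1],
        draft_token_id_map=draft_map, spec_mem_indexes=[100, 101],
    )
    save_draft_token_ids_sketch([21, 22, 23, 24], [0, 1], accepted, draft_map)
    print(accepted, idx, to_free, draft_map)  # [0] [0] [101] {0: 22, 1: 23}
```

With this scheme the per-request memindex maps become unnecessary: the speculative rows' mem indexes travel alongside the batch (`mem_indexes_cpu[1::2]`), and only the rejected ones are freed, in one batched call after the draft forward.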