@@ -53,31 +53,22 @@ def get_overlap_stream(self) -> torch.cuda.Stream:
             self.overlap_stream = torch.cuda.Stream()
         return self.overlap_stream
 
-    def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_req_obj=True):
+    def add_reqs(self, requests: List[Tuple[int, int, Any, int]]):
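+        # Register newly arrived requests; every request id is expected to be new
+        # here, so a fresh InferReq object is always allocated for it.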
         request_ids = []
         for r in requests:
-
             r_id, r_index, multimodal_params, _ = r
-            if r_id not in self.requests_mapping.keys():
-                r_obj = InferReq(
-                    req_id=r_id,
-                    req_idx=self.req_manager.alloc(),
-                    shm_index=r_index,
-                    multimodal_params=multimodal_params,
-                    vocab_size=self.vocab_size,
-                )
-                self.requests_mapping[r_id] = r_obj
-            else:
-                r_obj: InferReq = self.requests_mapping[r_id]
-                assert r_obj.paused is True
-
+            assert r_id not in self.requests_mapping.keys()
+            r_obj = InferReq(
+                req_id=r_id,
+                req_idx=self.req_manager.alloc(),
+                shm_index=r_index,
+                multimodal_params=multimodal_params,
+                vocab_size=self.vocab_size,
+            )
+            self.requests_mapping[r_id] = r_obj
             request_ids.append(r_id)
 
-            if init_req_obj:
-                r_obj.init_all()
-
         self.infer_req_ids.extend(request_ids)
-
         return
 
     def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finished: bool):
@@ -169,27 +160,47 @@ def filter_reqs(self, finished_reqs: List["InferReq"]):
         return
 
     @torch.no_grad()
-    def pause_reqs(self, pause_req_ids: List[int]):
-        free_token_index = []
-        for request_id in pause_req_ids:
-            req: InferReq = self.requests_mapping[request_id]
-            self.infer_req_ids.remove(request_id)
+    def pause_reqs(self, pause_reqs: List["InferReq"]):
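+        # Pause the given requests: reclaim their KV token slots via free_a_req_mem
+        # and flip them from wait_pause to paused, under g_infer_state_lock.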
+        if pause_reqs:
+            g_infer_state_lock.acquire()
 
-            if req.initialized:
-                # Pausing is not supported for multi-output requests
+            free_token_index = []
+            for req in pause_reqs:
+                # Pausing is not supported for multi-output requests; diverse output mode is not supported.
                 self.free_a_req_mem(free_token_index, req, is_group_finished=True)
                 req.cur_kv_len = 0
                 req.shm_req.shm_cur_kv_len = req.cur_kv_len
-                req.paused = True  # mark the request as paused.
-            else:
+                assert req.wait_pause is True
+                req.wait_pause = False
                 req.paused = True
 
-        if len(free_token_index) != 0:
-            free_token_index = custom_cat(free_token_index)
-            self.req_manager.free_token(free_token_index)
+            if len(free_token_index) != 0:
+                free_token_index = custom_cat(free_token_index)
+                self.req_manager.free_token(free_token_index)
 
+            g_infer_state_lock.release()
         return self
 
+    def recover_paused_reqs(self, paused_reqs: List["InferReq"]):
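+        # Resume previously paused requests: re-match their prompt prefix against the
+        # radix cache (re-acquiring shared KV entries) and clear the paused flag.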
+        if paused_reqs:
+            g_infer_state_lock.acquire()
+
+            for req in paused_reqs:
+                req._match_radix_cache()
+                assert req.paused is True
+                req.paused = False
+
+            g_infer_state_lock.release()
+        return
+
+    def get_can_alloc_token_num(self):
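+        # Tokens that can still be allocated: free slots in the memory manager plus
+        # radix-cache tokens that are not currently referenced (and so can be evicted).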
+        radix_cache_unref_token_num = 0
+        if self.radix_cache is not None:
+            radix_cache_unref_token_num = (
+                self.radix_cache.get_tree_total_tokens_num() - self.radix_cache.get_refed_tokens_num()
+            )
+        return self.req_manager.mem_manager.can_use_mem_size + radix_cache_unref_token_num
+
 
 g_infer_context = InferenceContext()
 
@@ -256,9 +267,14 @@ def __init__(
         self.shm_index = shm_index
         self.multimodal_params = multimodal_params
         self.vocab_size = vocab_size
-        self.initialized = False
+
+        # the request needs to be paused
+        self.wait_pause = False
+        # the request has already been paused
         self.paused = False
+
         self.infer_aborted = False
+        self.filter_mark = False
         self.need_out_token_id_statistics = True
         self.out_token_id_count: Dict[int, int] = None
 
@@ -268,51 +284,48 @@ def __init__(
         # step, where they need to be re-verified.
         self.mtp_gen_token_ids: List[int] = []
 
-    def init_all(self):
-        if self.initialized is False:
-            self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index)
-            self.shm_req.link_prompt_ids_shm_array()
-            self.shm_req.link_logprobs_shm_array()
-            self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size)
-            self.cur_kv_len = 0
-            self.cur_output_len = 0
-
-            g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self)
-
-            self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list()
-            # management object only used in token healing mode
-            if self.shm_req.prefix_token_ids.size != 0:
-                self.prefix_token_ids = self.shm_req.prefix_token_ids.get_token_ids()
-            else:
-                self.prefix_token_ids = []
-            self.multimodal_params = self.multimodal_params.to_dict()
-            self.shared_kv_node: TreeNode = None
-
-            self.finish_status = FinishStatus()
-
-        if self.paused or not self.initialized:
-            # If the prompt cache feature is used, pre-fill and recovery have to be done ahead of time.
-            if g_infer_context.radix_cache is not None and self.get_cur_total_len() > 1:
-                input_token_ids = self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()]
-                key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu")
-                key = key[0 : len(key) - 1]  # the last token is not needed: an extra token is kept so prefill can produce the next token's output
-                share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True)
-                if share_node is not None:
-                    self.shared_kv_node = share_node
-                    ready_cache_len = share_node.node_prefix_total_len
-                    # the copy from cpu to gpu is a stream-blocking operation
-                    g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor
-                    self.cur_kv_len = int(ready_cache_len)  # serialization issue: this may be a numpy.int64, convert it with int(*)
-                    self.shm_req.prompt_cache_len = self.cur_kv_len  # record the prompt cache hit length
-
-            self.shm_req.shm_cur_kv_len = self.cur_kv_len
-
-        self.initialized = True
-        self.paused = False
+        self._init_all_state()
+        self._match_radix_cache()
         return
 
-    def is_uninitialized(self):
-        return not self.initialized or self.paused
+    def _init_all_state(self):
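+        # Bind the shared-memory request object for this slot and reset the
+        # per-request decoding state (sampling params, KV length, output length).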
+        self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index)
+        self.shm_req.link_prompt_ids_shm_array()
+        self.shm_req.link_logprobs_shm_array()
+        self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size)
+        self.cur_kv_len = 0
+        self.cur_output_len = 0
+
+        g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self)
+
+        self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list()
+        # management object only used in token healing mode
+        if self.shm_req.prefix_token_ids.size != 0:
+            self.prefix_token_ids = self.shm_req.prefix_token_ids.get_token_ids()
+        else:
+            self.prefix_token_ids = []
+        self.multimodal_params = self.multimodal_params.to_dict()
+        self.shared_kv_node: TreeNode = None
+
+        self.finish_status = FinishStatus()
+        return
+
+    def _match_radix_cache(self):
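+        # Look up the longest cached prefix of the prompt in the radix cache and, on a
+        # hit, reuse its KV entries so prefill only has to compute the remaining tokens.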
+        if g_infer_context.radix_cache is not None and self.get_cur_total_len() > 1 and self.cur_kv_len == 0:
+            input_token_ids = self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()]
+            key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu")
+            key = key[0 : len(key) - 1]  # the last token is not needed: an extra token is kept so prefill can produce the next token's output
+            share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True)
+            if share_node is not None:
+                self.shared_kv_node = share_node
+                ready_cache_len = share_node.node_prefix_total_len
+                # the copy from cpu to gpu is a stream-blocking operation
+                g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor
+                self.cur_kv_len = int(ready_cache_len)  # serialization issue: this may be a numpy.int64, convert it with int(*)
+                self.shm_req.prompt_cache_len = self.cur_kv_len  # record the prompt cache hit length
+
+        self.shm_req.shm_cur_kv_len = self.cur_kv_len
+        return
 
     def get_output_len(self):
         return self.cur_output_len
@@ -372,6 +385,19 @@ def _stop_sequences_matched(self, output_len: int):
                     return True
         return False
 
+    def prefill_need_token_num(self, is_chuncked_prefill: bool):
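+        # Number of new KV cache slots this request needs for the next prefill step:
+        # the (chunked or full) input length minus the tokens already held in KV cache.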
+        if is_chuncked_prefill:
+            input_token_ids = self.get_chuncked_input_token_ids()
+        else:
+            input_token_ids = self.get_input_token_ids()
+
+        seq_len = len(input_token_ids)
+        input_token_len = seq_len - self.cur_kv_len
+        return input_token_len
+
+    def decode_need_token_num(self):
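+        # One slot for the regular decode token plus one per pending MTP draft token.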
+        return 1 + len(self.mtp_gen_token_ids)
+
 
 class InferReqGroup:
     def __init__(