@@ -28,7 +28,6 @@ class InferenceContext:
2828 radix_cache : RadixCache = None
2929 shm_req_manager : ShmReqManager = None # 共享内存请求对象管理
3030 requests_mapping : Dict [int , "InferReq" ] = None
31- group_mapping = None # 只有进行多输出模式下才有真的使用
3231 infer_req_ids = None
3332 vocab_size = None
3433
@@ -48,7 +47,6 @@ def register(
4847 self .shm_req_manager = shm_req_manager
4948
5049 self .requests_mapping = {}
51- self .group_mapping : Dict [int , InferReqGroup ] = {}
5250 self .infer_req_ids = []
5351
5452 self .vocab_size = vocab_size
@@ -84,46 +82,42 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache:
8482
8583 self .infer_req_ids .extend (request_ids )
8684
87- # 多输出模式下需要将请求添加到各自的组对象 InferReqGroup 中
85+ # diverse mode 下,建立一组请求间的主从关系
8886 if get_env_start_args ().diverse_mode :
87+ group_reqs : Dict [int , InferReq ] = collections .defaultdict (lambda : [None , list ()])
8988 for r_id in request_ids :
9089 req : InferReq = g_infer_context .requests_mapping [r_id ]
9190 group_req_id = req .shm_req .group_req_id
92- if group_req_id not in g_infer_context .group_mapping :
93- g_infer_context .group_mapping [group_req_id ] = InferReqGroup (group_req_id = group_req_id )
94- g_infer_context .group_mapping [group_req_id ].add_req (r_id )
91+ if req .req_id == group_req_id :
92+ group_reqs [group_req_id ][0 ] = req
93+ else :
94+ group_reqs [group_req_id ][1 ].append (req )
95+
96+ for group_req_id , (master_req , slave_reqs ) in group_reqs .items ():
97+ master_req : InferReq = master_req
98+ master_req .slave_reqs .extend (slave_reqs )
99+ for slave_req in slave_reqs :
100+ slave_req : InferReq = slave_req
101+ slave_req .related_master_req = master_req
95102
96103 return req_objs
97104
def free_a_req_mem(self, free_token_index: List, req: "InferReq"):
    """
    Release the KV-cache token slots held by `req`.

    The freed slot indices are appended to `free_token_index` (an output
    parameter; the caller performs the actual batched free).

    - Without a radix cache: all slots in [0, req.cur_kv_len) are released.
    - With a radix cache: the request's computed KV is first inserted into
      the radix tree so it can be shared, then the slots in
      [old_prefix_len, prefix_len) are released and the reference on the
      previously shared prefix node is dropped.
      NOTE(review): assumes `radix_cache.insert` returns the length of the
      prefix already present in the tree, so that span is duplicated and
      safe to free while slots beyond it stay owned — confirm against the
      RadixCache implementation.
    """
    if self.radix_cache is None:
        free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len])
    else:
        input_token_ids = req.get_input_token_ids()
        key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu")
        # .cpu() is an in-stream blocking operation
        value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu()

        prefix_len, _ = self.radix_cache.insert(key, value)
        # Slots inside the old shared prefix were never owned by this request,
        # so freeing starts after them.
        old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len
        free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len])
        if req.shared_kv_node is not None:
            assert req.shared_kv_node.node_prefix_total_len <= prefix_len
            self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
            req.shared_kv_node = None
127121
128122 def _save_promptcache_kvbuffer (self ):
129123 """
@@ -148,14 +142,10 @@ def _filter(self, finished_request_ids: List[int]):
148142 free_token_index = []
149143 for request_id in finished_request_ids :
150144 req : InferReq = self .requests_mapping .pop (request_id )
151- group_req_id = convert_sub_id_to_group_id (req .shm_req .request_id )
152- if group_req_id in self .group_mapping :
153- is_group_finished = self .group_mapping [group_req_id ].remove_req (req .shm_req .request_id )
154- if is_group_finished :
155- del self .group_mapping [group_req_id ]
156- self .free_a_req_mem (free_token_index , req , is_group_finished )
157- else :
158- self .free_a_req_mem (free_token_index , req , True )
145+ if self .args .diverse_mode :
146+ req .clear_master_slave_state ()
147+ self .free_a_req_mem (free_token_index , req )
148+
159149 free_req_index .append (req .req_idx )
160150 # logger.info(f"infer release req id {req.shm_req.request_id}")
161151 req .shm_req .shm_infer_released = True
@@ -192,8 +182,10 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool):
192182
193183 free_token_index = []
194184 for req in pause_reqs :
195- # 不支持多输出的情况的暂停, 不能支持 diverse 输出模式。
196- self .free_a_req_mem (free_token_index , req , is_group_finished = True )
185+ if self .args .diverse_mode :
186+ # 发生暂停的时候,需要清除 diverse 模式下的主从关系
187+ req .clear_master_slave_state ()
188+ self .free_a_req_mem (free_token_index , req )
197189 req .cur_kv_len = 0
198190 req .shm_req .shm_cur_kv_len = req .cur_kv_len
199191 assert req .wait_pause is True
@@ -337,6 +329,10 @@ def __init__(
337329 self .need_out_token_id_statistics = True
338330 self .out_token_id_count : Dict [int , int ] = None
339331
332+ # diverse mode 下,用于标记请求组之间的依赖关系
333+ self .slave_reqs : List [InferReq ] = []
334+ self .related_master_req : InferReq = None
335+
340336 # nixl pd 分离模式使用的变量, 普通模式下这些变量没有具体用途
341337 self .nixl_trans_kv_start_index : int = 0
342338 self .nixl_pd_task_num : int = 0
@@ -407,6 +403,37 @@ def _match_radix_cache(self):
407403 self .shm_req .shm_cur_kv_len = self .cur_kv_len
408404 return
409405
def is_master_req(self):
    """
    In diverse mode, report whether this request is an independent master
    request. A master runs its own prefill and then shares its KV with its
    slave requests through the radix cache; once shared, a slave is itself
    promoted to master status and can be inferred or paused independently.
    """
    linked_master = self.related_master_req
    return linked_master is None
413+
def is_slave_req(self):
    """True while this request still depends on a diverse-mode master request."""
    linked_master = self.related_master_req
    return linked_master is not None
416+
def clear_master_slave_state(self):
    """
    Dissolve any diverse-mode master/slave links involving this request:
    a slave detaches from its master; a master detaches every slave.
    """
    if self.is_slave_req():
        self.remove_master_req()
        return
    if self.is_master_req():
        # Iterate over a snapshot: remove_master_req mutates self.slave_reqs.
        for follower in list(self.slave_reqs):
            follower.remove_master_req()
424+
def remove_master_req(self):
    """
    Detach this slave request from its master. Once the dependency is
    broken, the request is promoted to master status and gains the ability
    to run inference and be paused on its own.
    """
    master = self.related_master_req
    if master is None:
        logger.warning(f"try to remove master req, but related_master_req is None, req id {self.req_id} ")
        return
    master.slave_reqs.remove(self)
    self.related_master_req = None
436+
410437 def get_output_len (self ):
411438 return self .cur_output_len
412439
@@ -482,49 +509,6 @@ def _mtp_decode_need_token_num(self) -> int:
482509 return (1 + self .mtp_step ) * 2
483510
484511
class InferReqGroup:
    """
    Groups the sub-requests produced by one diverse-mode (multi-output)
    request under a shared `group_req_id`, and lets the group share the
    master request's prefilled KV cache.
    """

    def __init__(
        self,
        group_req_id: int,
    ) -> None:
        # id shared by every sub-request of this group
        self.group_req_id = group_req_id
        # request ids currently belonging to the group
        self.req_ids_group = []

    def get_req(self, index):
        # Resolve the index-th member id to its InferReq object.
        return g_infer_context.requests_mapping[self.req_ids_group[index]]

    def get_all_reqs(self):
        # Resolve every member id to its InferReq object, in group order.
        return [g_infer_context.requests_mapping[self.req_ids_group[i]] for i in range(len(self.req_ids_group))]

    def add_req(self, req_id):
        self.req_ids_group.append(req_id)

    def remove_req(self, req_id):
        # Returns True when the group became empty (i.e. fully finished).
        assert req_id in self.req_ids_group
        self.req_ids_group.remove(req_id)
        return len(self.req_ids_group) == 0

    def best_of(self):
        # Number of parallel outputs requested (sampling "best_of" width).
        return len(self.req_ids_group)

    def diverse_copy(self, req_manager, is_prefill):
        # record previous status
        # NOTE(review): the master is found by converting the first member's
        # sub id back to the group id — assumes the master's req id equals
        # the group id; confirm against convert_sub_id_to_group_id.
        master_req = g_infer_context.requests_mapping[convert_sub_id_to_group_id(self.req_ids_group[0])]
        new_kv_len = master_req.get_chuncked_input_token_len()

        # update the InferReq status and mem_manager status for cache sharing
        for req_id in self.req_ids_group[:]:
            if req_id == convert_sub_id_to_group_id(req_id):
                # skip the master itself; only slaves receive copied indices
                continue
            req = g_infer_context.requests_mapping[req_id]
            req.finish_status.set_status(FinishStatus.NO_FINISH)
            assert req.cur_kv_len <= master_req.cur_kv_len
            # Share the master's KV slots by copying its token-index rows
            # into the slave's row for the span the slave does not yet have.
            copy_token_index = req_manager.req_to_token_indexs[master_req.req_idx][req.cur_kv_len : new_kv_len]

            req_manager.req_to_token_indexs[req.req_idx][req.cur_kv_len : new_kv_len] = copy_token_index
            req.cur_kv_len = master_req.cur_kv_len
526-
527-
528512class InferReqUpdatePack :
529513 """
530514 用于延迟InferReq的请求更新,主要是为了方便更高效的overlap机制实现。解耦
0 commit comments