File tree (Expand file tree / Collapse file tree): 6 files changed, +15 -9 lines changed
(Expand file tree / Collapse file tree): 6 files changed, +15 -9 lines changed
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,7 @@ async def loop_for_fwd(self):
8686 while len (self .waiting_reqs ) > 0 :
8787 group_req_indexes = self .waiting_reqs .pop (0 )
8888 shm_req = self .shm_req_manager .get_req_obj_by_index (group_req_indexes .shm_req_indexes [0 ])
89+ disable_prompt_cache = shm_req .sample_params .disable_prompt_cache
8990 is_aborted = shm_req .is_aborted
9091 self .shm_req_manager .put_back_req_obj (shm_req )
9192 if is_aborted :
@@ -98,7 +99,11 @@ async def loop_for_fwd(self):
9899 multimodal_params = group_req_indexes .multimodal_params
99100
100101 audio_uuids = [audio .uuid for audio in multimodal_params .audios ]
101- ready_audio = obtain (self .cache_client .root .get_items_embed (audio_uuids ))
102+ # disable prompt cache通常用来测试,需要也去掉audio cache的影响
103+ if disable_prompt_cache :
104+ ready_audio = [False ] * len (audio_uuids )
105+ else :
106+ ready_audio = obtain (self .cache_client .root .get_items_embed (audio_uuids ))
102107
103108 for audio , ready in zip (multimodal_params .audios , ready_audio ):
104109 if not ready :
Original file line number | Diff line number | Diff line change
@@ -307,7 +307,7 @@ class SamplingParams(ctypes.Structure):
307307 ctypes .c_bool ,
308308 ), # whether to add spaces between special tokens when decoding
309309 ("print_eos_token" , ctypes .c_bool ), # eos_id will be always ignored except the value is set to True
310- ("disable_prompt_cache" , ctypes .c_bool ), # eos_id will be always ignored except the value is set to True
310+ ("disable_prompt_cache" , ctypes .c_bool ), # whether to disable prompt cache
311311 ]
312312
313313 _do_sample : bool = False
Original file line number | Diff line number | Diff line change
@@ -153,6 +153,9 @@ async def fetch_stream(
153153 ):
154154 group_request_id = sampling_params .group_request_id
155155
156+ # PD分离模式下,use prompt cache必须为True
157+ sampling_params .disable_prompt_cache = False
158+
156159 req_status = ReqStatus (group_request_id , p_node , d_node )
157160 self .req_id_to_out_inf [group_request_id ] = req_status
158161
Original file line number | Diff line number | Diff line change
@@ -137,11 +137,9 @@ def __init__(
137137 self ,
138138 images : List [dict ] = [],
139139 audios : List [dict ] = [],
140- skip_image_cache : bool = False ,
141140 ) -> None :
142141 self .images = [ImageItem (** i ) for i in images ]
143142 self .audios = [AudioItem (** a ) for a in audios ]
144- self .skip_image_cache = skip_image_cache
145143 return
146144
147145 async def verify_and_preload (self , request : Request ):
Original file line number | Diff line number | Diff line change
@@ -309,7 +309,7 @@ def __init__(
309309 self .mtp_step : int = get_env_start_args ().mtp_step
310310
311311 self ._init_all_state ()
312- if init_prefix_cache and not self . sampling_param . disable_prompt_cache :
312+ if init_prefix_cache :
313313 self ._match_radix_cache ()
314314 return
315315
@@ -337,9 +337,6 @@ def _init_all_state(self):
337337
338338 def _match_radix_cache (self ):
339339 if self .sampling_param .disable_prompt_cache :
340- self .shared_kv_node = None
341- self .shm_req .prompt_cache_len = 0
342- self .shm_req .shm_cur_kv_len = self .cur_kv_len
343340 return
344341 if g_infer_context .radix_cache is not None and self .get_cur_total_len () > 1 and self .cur_kv_len == 0 :
345342 input_token_ids = self .shm_req .shm_prompt_ids .arr [0 : self .get_cur_total_len ()]
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,7 @@ async def loop_for_fwd(self):
113113 group_req_indexes = self .waiting_reqs .pop (0 )
114114 shm_req = self .shm_req_manager .get_req_obj_by_index (group_req_indexes .shm_req_indexes [0 ])
115115 is_aborted = shm_req .is_aborted
116+ disable_prompt_cache = shm_req .sample_params .disable_prompt_cache
116117 self .shm_req_manager .put_back_req_obj (shm_req )
117118 if is_aborted :
118119 # 因为连接断开 aborted 掉的请求也需要传输到后续的模块进行处理
@@ -124,10 +125,12 @@ async def loop_for_fwd(self):
124125 multimodal_params = group_req_indexes .multimodal_params
125126
126127 img_uuids = [img .uuid for img in multimodal_params .images ]
127- if multimodal_params .skip_image_cache :
128+ # disable prompt cache通常用来测试,需要也去掉image cache的影响
129+ if disable_prompt_cache :
128130 ready_image = [False ] * len (img_uuids )
129131 else :
130132 ready_image = obtain (self .cache_client .root .get_items_embed (img_uuids ))
133+
131134 for img , ready in zip (multimodal_params .images , ready_image ):
132135 if not ready :
133136 images_need_infer .append (img )
You can’t perform that action at this time.
0 commit comments