Skip to content

Commit 13b0c83

Browse files
committed
[add] add disable_prompt_cache
1 parent b852070 commit 13b0c83

File tree

6 files changed

+15
-9
lines changed

6 files changed

+15
-9
lines changed

lightllm/server/audioserver/manager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ async def loop_for_fwd(self):
8686
while len(self.waiting_reqs) > 0:
8787
group_req_indexes = self.waiting_reqs.pop(0)
8888
shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0])
89+
disable_prompt_cache = shm_req.sample_params.disable_prompt_cache
8990
is_aborted = shm_req.is_aborted
9091
self.shm_req_manager.put_back_req_obj(shm_req)
9192
if is_aborted:
@@ -98,7 +99,11 @@ async def loop_for_fwd(self):
9899
multimodal_params = group_req_indexes.multimodal_params
99100

100101
audio_uuids = [audio.uuid for audio in multimodal_params.audios]
101-
ready_audio = obtain(self.cache_client.root.get_items_embed(audio_uuids))
102+
# disable prompt cache通常用来测试,需要也去掉audio cache的影响
103+
if disable_prompt_cache:
104+
ready_audio = [False] * len(audio_uuids)
105+
else:
106+
ready_audio = obtain(self.cache_client.root.get_items_embed(audio_uuids))
102107

103108
for audio, ready in zip(multimodal_params.audios, ready_audio):
104109
if not ready:

lightllm/server/core/objs/sampling_params.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ class SamplingParams(ctypes.Structure):
307307
ctypes.c_bool,
308308
), # whether to add spaces between special tokens when decoding
309309
("print_eos_token", ctypes.c_bool), # eos_id will be always ignored except the value is set to True
310-
("disable_prompt_cache", ctypes.c_bool), # eos_id will be always ignored except the value is set to True
310+
("disable_prompt_cache", ctypes.c_bool), # whether to disable prompt cache
311311
]
312312

313313
_do_sample: bool = False

lightllm/server/httpserver_for_pd_master/manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,9 @@ async def fetch_stream(
153153
):
154154
group_request_id = sampling_params.group_request_id
155155

156+
# PD分离模式下,use prompt cache必须为True
157+
sampling_params.disable_prompt_cache = False
158+
156159
req_status = ReqStatus(group_request_id, p_node, d_node)
157160
self.req_id_to_out_inf[group_request_id] = req_status
158161

lightllm/server/multimodal_params.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,9 @@ def __init__(
137137
self,
138138
images: List[dict] = [],
139139
audios: List[dict] = [],
140-
skip_image_cache: bool = False,
141140
) -> None:
142141
self.images = [ImageItem(**i) for i in images]
143142
self.audios = [AudioItem(**a) for a in audios]
144-
self.skip_image_cache = skip_image_cache
145143
return
146144

147145
async def verify_and_preload(self, request: Request):

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def __init__(
309309
self.mtp_step: int = get_env_start_args().mtp_step
310310

311311
self._init_all_state()
312-
if init_prefix_cache and not self.sampling_param.disable_prompt_cache:
312+
if init_prefix_cache:
313313
self._match_radix_cache()
314314
return
315315

@@ -337,9 +337,6 @@ def _init_all_state(self):
337337

338338
def _match_radix_cache(self):
339339
if self.sampling_param.disable_prompt_cache:
340-
self.shared_kv_node = None
341-
self.shm_req.prompt_cache_len = 0
342-
self.shm_req.shm_cur_kv_len = self.cur_kv_len
343340
return
344341
if g_infer_context.radix_cache is not None and self.get_cur_total_len() > 1 and self.cur_kv_len == 0:
345342
input_token_ids = self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()]

lightllm/server/visualserver/manager.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ async def loop_for_fwd(self):
113113
group_req_indexes = self.waiting_reqs.pop(0)
114114
shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0])
115115
is_aborted = shm_req.is_aborted
116+
disable_prompt_cache = shm_req.sample_params.disable_prompt_cache
116117
self.shm_req_manager.put_back_req_obj(shm_req)
117118
if is_aborted:
118119
# 因为连接断开 aborted 掉的请求也需要传输到后续的模块进行处理
@@ -124,10 +125,12 @@ async def loop_for_fwd(self):
124125
multimodal_params = group_req_indexes.multimodal_params
125126

126127
img_uuids = [img.uuid for img in multimodal_params.images]
127-
if multimodal_params.skip_image_cache:
128+
# disable prompt cache通常用来测试,需要也去掉image cache的影响
129+
if disable_prompt_cache:
128130
ready_image = [False] * len(img_uuids)
129131
else:
130132
ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids))
133+
131134
for img, ready in zip(multimodal_params.images, ready_image):
132135
if not ready:
133136
images_need_infer.append(img)

0 commit comments

Comments (0)