
Commit a5f188f

SangChengC and wangzaijun authored
[add] add skip image cache and disable_prompt_cache para (#1061)
Co-authored-by: wangzaijun <[email protected]>
1 parent 26ea376 commit a5f188f

6 files changed: +23 −12 lines changed
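The commit threads a new per-request boolean, disable_prompt_cache, through the sampling parameters and uses it to bypass both the radix prompt cache and the image/audio embed caches. A minimal usage sketch follows; the host and the /generate endpoint path are assumptions and not part of this commit, but the "inputs"/"parameters" payload shape mirrors the request dict visible in the removed _to_req_info helper below.

import requests

# Assumed local deployment address; adjust host/port as needed.
payload = {
    "inputs": "What is AI?",
    "parameters": {
        "max_new_tokens": 64,
        # New in this commit: compute the request from scratch, bypassing the
        # radix prompt cache and the image/audio embed caches.
        "disable_prompt_cache": True,
    },
}
resp = requests.post("http://localhost:8000/generate", json=payload)
print(resp.json())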

lightllm/server/audioserver/manager.py

Lines changed: 6 additions & 1 deletion

@@ -86,6 +86,7 @@ async def loop_for_fwd(self):
             while len(self.waiting_reqs) > 0:
                 group_req_indexes = self.waiting_reqs.pop(0)
                 shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0])
+                disable_prompt_cache = shm_req.sample_params.disable_prompt_cache
                 is_aborted = shm_req.is_aborted
                 self.shm_req_manager.put_back_req_obj(shm_req)
                 if is_aborted:
@@ -98,7 +99,11 @@ async def loop_for_fwd(self):
                 multimodal_params = group_req_indexes.multimodal_params

                 audio_uuids = [audio.uuid for audio in multimodal_params.audios]
-                ready_audio = obtain(self.cache_client.root.get_items_embed(audio_uuids))
+                # disable_prompt_cache is usually used for testing, so the audio cache must be bypassed as well
+                if disable_prompt_cache:
+                    ready_audio = [False] * len(audio_uuids)
+                else:
+                    ready_audio = obtain(self.cache_client.root.get_items_embed(audio_uuids))

                 for audio, ready in zip(multimodal_params.audios, ready_audio):
                     if not ready:

lightllm/server/core/objs/sampling_params.py

Lines changed: 3 additions & 0 deletions

@@ -320,6 +320,7 @@ class SamplingParams(ctypes.Structure):
            ctypes.c_bool,
        ),  # whether to add spaces between special tokens when decoding
        ("print_eos_token", ctypes.c_bool),  # eos_id will be always ignored except the value is set to True
+        ("disable_prompt_cache", ctypes.c_bool),  # whether to disable prompt cache
    ]

    _do_sample: bool = False
@@ -350,6 +351,7 @@ def init(self, tokenizer, **kwargs):
        self.suggested_dp_index = kwargs.get("suggested_dp_index", -1)

        self.skip_special_tokens = kwargs.get("skip_special_tokens", SKIP_SPECIAL_TOKENS)
+        self.disable_prompt_cache = kwargs.get("disable_prompt_cache", False)

        self.add_special_tokens = kwargs.get("add_special_tokens", True)
        self.add_spaces_between_special_tokens = kwargs.get("add_spaces_between_special_tokens", True)
@@ -494,6 +496,7 @@ def to_dict(self):
            "add_special_tokens": self.add_special_tokens,
            "add_spaces_between_special_tokens": self.add_spaces_between_special_tokens,
            "print_eos_token": self.print_eos_token,
+            "disable_prompt_cache": self.disable_prompt_cache,
        }

    def to_origin_dict(self):

lightllm/server/httpserver_for_pd_master/manager.py

Lines changed: 3 additions & 10 deletions

@@ -144,16 +144,6 @@ async def _log_req_header(self, request: Request, group_request_id: int):
        )
        return

-    async def _to_req_info(
-        self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
-    ):
-        req = {
-            "inputs": prompt,
-            "parameters": sampling_params.to_origin_dict(),
-            "multimodal_params": multimodal_params.to_origin_dict(),
-        }
-        return req
-
    async def fetch_stream(
        self,
        p_node: PD_Client_Obj,
@@ -323,6 +313,9 @@ async def _wait_to_token_package(
        multimodal_params: MultimodalParams,
        request: Request,
    ):
+        if sampling_params.disable_prompt_cache:
+            assert False, "pd mode dont support set disable_prompt_cache to True"
+
        out_token_counter = 0
        first_token_cost_ms = float("inf")
        group_request_id = sampling_params.group_request_id

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 3 additions & 0 deletions

@@ -238,6 +238,7 @@ def __init__(
        vocab_size: int,
    ) -> None:
        self.shm_param = shm_req.sample_params
+        self.disable_prompt_cache = self.shm_param.disable_prompt_cache
        if self.shm_param.top_k == -1:
            self.shm_param.top_k = vocab_size

@@ -358,6 +359,8 @@ def _init_all_state(self):
        return

    def _match_radix_cache(self):
+        if self.sampling_param.disable_prompt_cache:
+            return
        if g_infer_context.radix_cache is not None and self.get_cur_total_len() > 1 and self.cur_kv_len == 0:
            input_token_ids = self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()]
            key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu")

lightllm/server/visualserver/manager.py

Lines changed: 6 additions & 1 deletion

@@ -113,6 +113,7 @@ async def loop_for_fwd(self):
                group_req_indexes = self.waiting_reqs.pop(0)
                shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0])
                is_aborted = shm_req.is_aborted
+                disable_prompt_cache = shm_req.sample_params.disable_prompt_cache
                self.shm_req_manager.put_back_req_obj(shm_req)
                if is_aborted:
                    # requests aborted because the connection dropped still need to be forwarded to downstream modules for processing
@@ -124,7 +125,11 @@ async def loop_for_fwd(self):
                multimodal_params = group_req_indexes.multimodal_params

                img_uuids = [img.uuid for img in multimodal_params.images]
-                ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids))
+                # disable_prompt_cache is usually used for testing, so the image cache must be bypassed as well
+                if disable_prompt_cache:
+                    ready_image = [False] * len(img_uuids)
+                else:
+                    ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids))

                for img, ready in zip(multimodal_params.images, ready_image):
                    if not ready:
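The in-diff comments note that disable_prompt_cache is mainly a testing aid: it keeps repeated requests from being served out of cached prefixes or cached image/audio embeddings. A hedged sketch of that testing scenario follows; the host, endpoint path, and prompt are placeholders, not part of this commit.

import time
import requests

URL = "http://localhost:8000/generate"  # placeholder deployment address

def timed_generate(disable_prompt_cache: bool) -> float:
    # Send the same request and measure wall-clock latency of the full response.
    payload = {
        "inputs": "Describe the picture.",
        "parameters": {"max_new_tokens": 32, "disable_prompt_cache": disable_prompt_cache},
    }
    start = time.time()
    requests.post(URL, json=payload).raise_for_status()
    return time.time() - start

timed_generate(False)  # warm the prompt and embed caches
print("cached run:  ", timed_generate(False))
print("uncached run:", timed_generate(True))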

lightllm/utils/shm_size_check.py

Lines changed: 2 additions & 0 deletions

@@ -117,6 +117,8 @@ def _get_recommended_shm_size_gb(args, max_image_resolution=(3940, 2160), dtype_
    )
    fake_image_item.image_w = fake_image_item._data[0]
    fake_image_item.image_h = fake_image_item._data[1]
+    # for internvl model shm check
+    fake_image_item.extra_params["image_patch_max_num"] = 12
    max_image_tokens = tokenizer.get_image_token_length(fake_image_item)

    # estimate the resources needed for image tokens
