
Commit 1fe1574

[add]add disable_prompt_cache parameters
1 parent 943bba5 commit 1fe1574

5 files changed, +16 -5 lines changed

lightllm/server/core/objs/sampling_params.py

Lines changed: 3 additions & 0 deletions

@@ -307,6 +307,7 @@ class SamplingParams(ctypes.Structure):
             ctypes.c_bool,
         ),  # whether to add spaces between special tokens when decoding
         ("print_eos_token", ctypes.c_bool),  # eos_id will be always ignored except the value is set to True
+        ("disable_prompt_cache", ctypes.c_bool),  # whether to disable the prompt (radix) cache for this request
     ]

     _do_sample: bool = False

@@ -337,6 +338,7 @@ def init(self, tokenizer, **kwargs):
         self.suggested_dp_index = kwargs.get("suggested_dp_index", -1)

         self.skip_special_tokens = kwargs.get("skip_special_tokens", SKIP_SPECIAL_TOKENS)
+        self.disable_prompt_cache = kwargs.get("disable_prompt_cache", False)

         self.add_special_tokens = kwargs.get("add_special_tokens", True)
         self.add_spaces_between_special_tokens = kwargs.get("add_spaces_between_special_tokens", True)

@@ -477,6 +479,7 @@ def to_dict(self):
             "add_special_tokens": self.add_special_tokens,
             "add_spaces_between_special_tokens": self.add_spaces_between_special_tokens,
             "print_eos_token": self.print_eos_token,
+            "disable_prompt_cache": self.disable_prompt_cache,
         }

     def to_origin_dict(self):
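For context, a minimal sketch of how the new flag can reach SamplingParams.init, which reads it from the request kwargs with a default of False. The build_params helper and its tokenizer argument are illustrative placeholders, not part of the commit:

# Minimal usage sketch (not part of the commit): build_params and its tokenizer
# argument are hypothetical; only SamplingParams and its init() kwargs come
# from the diff above.
from lightllm.server.core.objs.sampling_params import SamplingParams

def build_params(tokenizer):
    params = SamplingParams()
    # init() pulls "disable_prompt_cache" from kwargs, defaulting to False,
    # so a request can opt out of the prompt cache per call.
    params.init(tokenizer, disable_prompt_cache=True)
    return params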

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 8 additions & 2 deletions

@@ -86,7 +86,7 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache:
         return req_objs

     def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finished: bool):
-        if self.radix_cache is None:
+        if self.radix_cache is None or req.sampling_param.disable_prompt_cache:
             if is_group_finished:
                 free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len])
             else:

@@ -236,6 +236,7 @@ def __init__(
         vocab_size: int,
     ) -> None:
         self.shm_param = shm_req.sample_params
+        self.disable_prompt_cache = self.shm_param.disable_prompt_cache
         if self.shm_param.top_k == -1:
             self.shm_param.top_k = vocab_size

@@ -308,7 +309,7 @@ def __init__(
         self.mtp_step: int = get_env_start_args().mtp_step

         self._init_all_state()
-        if init_prefix_cache:
+        if init_prefix_cache and not self.sampling_param.disable_prompt_cache:
             self._match_radix_cache()
         return

@@ -335,6 +336,11 @@ def _init_all_state(self):
         return

     def _match_radix_cache(self):
+        if self.sampling_param.disable_prompt_cache:
+            self.shared_kv_node = None
+            self.shm_req.prompt_cache_len = 0
+            self.shm_req.shm_cur_kv_len = self.cur_kv_len
+            return
         if g_infer_context.radix_cache is not None and self.get_cur_total_len() > 1 and self.cur_kv_len == 0:
             input_token_ids = self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()]
             key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu")
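Taken together, these guards mean a request flagged with disable_prompt_cache neither matches against the shared radix cache nor writes its KV back into it. A self-contained toy of that decision flow follows; nothing here is a lightllm API, only an illustration of the policy:

# Toy, self-contained sketch of the disable_prompt_cache decision; all names
# here are illustrative stand-ins, not lightllm types.
class ToyRadixCache:
    def __init__(self):
        self.store = {}                                   # prompt tuple -> cached prefix length
    def match(self, prompt_ids):
        return self.store.get(tuple(prompt_ids), 0)
    def insert(self, prompt_ids):
        self.store[tuple(prompt_ids)] = len(prompt_ids)

def reusable_prefix_len(radix_cache, prompt_ids, disable_prompt_cache):
    use_cache = radix_cache is not None and not disable_prompt_cache
    if not use_cache:
        return 0                                          # recompute the whole prompt; KV is freed directly later
    matched = radix_cache.match(prompt_ids)
    radix_cache.insert(prompt_ids)                        # make this prompt reusable for later requests
    return matched

cache = ToyRadixCache()
print(reusable_prefix_len(cache, [1, 2, 3], disable_prompt_cache=False))  # 0 (first time seen)
print(reusable_prefix_len(cache, [1, 2, 3], disable_prompt_cache=False))  # 3 (prefix reused)
print(reusable_prefix_len(cache, [1, 2, 3], disable_prompt_cache=True))   # 0 (cache bypassed)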

lightllm/server/visualserver/manager.py

Lines changed: 2 additions & 2 deletions

@@ -123,12 +123,12 @@ async def loop_for_fwd(self):

                 multimodal_params = group_req_indexes.multimodal_params

-                img_uuids = list(dict.fromkeys(img.uuid for img in multimodal_params.images))
+                img_uuids = [img.uuid for img in multimodal_params.images]
                 if multimodal_params.skip_image_cache:
                     ready_image = [False] * len(img_uuids)
                 else:
                     ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids))
-                for img, ready in zip(img_uuids, ready_image):
+                for img, ready in zip(multimodal_params.images, ready_image):
                     if not ready:
                         images_need_infer.append(img)
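This change fixes an alignment and type bug: deduplicating the uuid list could leave ready_image shorter than multimodal_params.images, and the old loop appended bare uuids instead of the ImageItem objects the encoder expects. A plain-Python illustration (FakeImage is a stand-in for ImageItem):

# Plain-Python illustration; FakeImage is a hypothetical stand-in for ImageItem.
class FakeImage:
    def __init__(self, uuid):
        self.uuid = uuid

images = [FakeImage("A"), FakeImage("B"), FakeImage("A")]      # the same image sent twice

old_uuids = list(dict.fromkeys(img.uuid for img in images))    # ["A", "B"]: shorter than images
new_uuids = [img.uuid for img in images]                       # ["A", "B", "A"]: one entry per image

# The cache answers one flag per uuid it is asked about, so only the
# non-deduplicated list pairs up one-to-one with `images`:
ready_image = [False, True, False]                             # len == len(new_uuids) == len(images)
images_need_infer = [img for img, ready in zip(images, ready_image) if not ready]
print([img.uuid for img in images_need_infer])                 # ['A', 'A'], as ImageItem-like objects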

lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 2 additions & 1 deletion

@@ -100,8 +100,9 @@ def exposed_encode(self, images: List[ImageItem]):
         all_img_embeds = all_img_embeds.to(torch.device("cpu"))

         if self.tp_rank_id == 0:
+            ready_flags = obtain(self.cache_client.root.get_items_embed(uuids))
             ids_to_set = []
-            for i, ready in uuids:
+            for i, ready in enumerate(ready_flags):
                 if ready:
                     continue
                 uid = uuids[i]
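The old loop tried to unpack (index, flag) pairs out of a plain uuid list, which would raise at runtime; the readiness flags now come from an explicit cache query. A hedged, isolated sketch of the repaired pattern, where cache_lookup and store are placeholders for the rpyc calls used in the file:

# Isolated sketch; cache_lookup stands in for obtain(cache_client.root.get_items_embed(uuids))
# and store for whatever writes an embedding back to the cache.
def store_missing_embeds(uuids, embeds, cache_lookup, store):
    ready_flags = cache_lookup(uuids)            # one bool per uuid
    for i, ready in enumerate(ready_flags):      # old code iterated the uuid list itself here
        if ready:
            continue                             # embedding already cached by an earlier request
        store(uuids[i], embeds[i])               # only write the ones still missing

# Example with in-memory stand-ins:
cache = {"A": True}
store_missing_embeds(
    uuids=["A", "B"],
    embeds=["embed_A", "embed_B"],
    cache_lookup=lambda us: [u in cache for u in us],
    store=lambda u, e: cache.update({u: True}),
)
print(cache)                                     # {'A': True, 'B': True}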

lightllm/utils/shm_size_check.py

Lines changed: 1 addition & 0 deletions

@@ -117,6 +117,7 @@ def _get_recommended_shm_size_gb(args, max_image_resolution=(3940, 2160), dtype_
     )
     fake_image_item.image_w = fake_image_item._data[0]
     fake_image_item.image_h = fake_image_item._data[1]
+    fake_image_item.extra_params["image_patch_max_num"] = 12
     max_image_tokens = tokenizer.get_image_token_length(fake_image_item)

     # Estimate the resources required for image tokens
