
Commit 2aae9f2

fix
1 parent 63bd6cd commit 2aae9f2

3 files changed: +13 -4 lines changed


lightllm/server/router/model_infer/infer_batch.py

Lines changed: 8 additions & 2 deletions
@@ -106,8 +106,14 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finis
         self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
         req.shared_kv_node = None

-    # save prompt cache kv buffer
-    def save_promptcache_kvbuffer(self):
+    def _save_promptcache_kvbuffer(self):
+        """
+        Save the prompt cache kv buffer.
+        This interface saves the non-quantized prompt cache resources; it is meant for custom
+        scenarios and is not called anywhere in the current code. The kv it saves is later
+        loaded into the quantized inference mode's prompt cache to improve the accuracy of
+        quantized inference, as in the paper: https://arxiv.org/abs/2403.01241
+        """
         prompt_cache_token_id = list(self.radix_cache.root_node.children.values())[0].token_id_key
         print(f"prompt_cache_token_id : {prompt_cache_token_id}")
         if isinstance(self.radix_cache.mem_manager.kv_buffer, list):
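
For readers unfamiliar with this path, the helper's job is to dump the KV entries that back the cached prompt so a quantized run can later reload them. A minimal sketch of that idea follows; the tensor layout, indexing, and file name are assumptions for illustration, not the code this commit adds:

import torch

def save_prompt_cache_kv_sketch(
    kv_buffer: torch.Tensor,
    prompt_cache_token_idx: torch.Tensor,
    path: str = "prompt_cache_kv.pt",
) -> None:
    # Pick out the KV rows that belong to the cached prompt tokens and move them
    # to CPU before persisting, so a later (quantized) run can load them back.
    cached_kv = kv_buffer[prompt_cache_token_idx].cpu()
    torch.save({"prompt_cache_kv": cached_kv}, path)

# Usage (hypothetical shapes): kv_buffer is [num_tokens, 2 * num_heads, head_dim]
# and prompt_cache_token_idx lists the KV slots occupied by the cached prompt.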

lightllm/server/router/req_queue/base_queue.py

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,10 @@ def __init__(self, args, router, dp_index, dp_size) -> None:
         from lightllm.server.router.manager import RouterManager

         self.router: RouterManager = router
+        # max_total_token_num - get_fixed_kv_len() subtracts the token kv resources that a
+        # specific inference mode pre-occupies, which can reduce the overall available kv
+        # resources in extreme cases. Outside that specific mode, get_fixed_kv_len() always
+        # returns 0, so it has no effect.
         self.max_total_tokens = args.max_total_token_num - get_fixed_kv_len()
         assert args.batch_max_tokens is not None
         self.batch_max_tokens = args.batch_max_tokens
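
The effect of the subtraction described in the new comment is easy to see with made-up numbers; the values below are purely illustrative:

# Hypothetical launch configuration, for illustration only.
max_total_token_num = 16384   # KV slots allocated by the memory manager
fixed_kv_len = 1024           # slots pre-occupied by a fixed prompt cache
max_total_tokens = max_total_token_num - fixed_kv_len
print(max_total_tokens)       # 15360 slots remain for scheduling requests

In the default case fixed_kv_len is 0, so max_total_tokens equals max_total_token_num and scheduling is unchanged.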

lightllm/utils/config_utils.py

Lines changed: 1 addition & 2 deletions
@@ -2,7 +2,6 @@
 import os
 from functools import lru_cache
 from .envs_utils import get_env_start_args
-from transformers.configuration_utils import PretrainedConfig
 from lightllm.utils.log_utils import init_logger

 logger = init_logger(__name__)
@@ -54,7 +53,7 @@ def get_dtype(model_path: str):
 @lru_cache(maxsize=None)
 def get_fixed_kv_len():
     start_args = get_env_start_args()
-    model_cfg, _ = PretrainedConfig.get_config_dict(start_args.model_dir)
+    model_cfg = get_config_json(start_args.model_dir)
     if "prompt_cache_token_ids" in model_cfg:
         return len(model_cfg["prompt_cache_token_ids"])
     else:
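
The switch from PretrainedConfig.get_config_dict to get_config_json drops the transformers dependency from this utility. A plausible minimal reading of what such a helper does is sketched below; the actual lightllm implementation may differ in caching and error handling:

import json
import os
from functools import lru_cache

@lru_cache(maxsize=None)
def get_config_json_sketch(model_path: str) -> dict:
    # Load the HuggingFace-style config.json that ships alongside the model weights.
    with open(os.path.join(model_path, "config.json"), "r") as f:
        return json.load(f)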
