
Commit 1980781

support prefix noquant (#722)
Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
1 parent c181e7a commit 1980781

4 files changed: 65 additions & 1 deletion


lightllm/server/router/model_infer/infer_batch.py

Lines changed: 21 additions & 0 deletions
@@ -106,6 +106,27 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finis
             self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
             req.shared_kv_node = None
 
+    def _save_promptcache_kvbuffer(self):
+        """
+        Save the prompt cache kv buffer.
+        This interface saves the non-quantized prompt cache resources. It is intended for
+        customized scenarios and is not called anywhere in the current code.
+        The saved kv is later loaded into the prompt cache of a quantized inference run
+        to improve the accuracy of quantized inference.
+        See the paper:
+        https://arxiv.org/abs/2403.01241
+        """
+        prompt_cache_token_id = list(self.radix_cache.root_node.children.values())[0].token_id_key
+        print(f"prompt_cache_token_id : {prompt_cache_token_id}")
+        if isinstance(self.radix_cache.mem_manager.kv_buffer, list):
+            kv_buffer_list = []
+            for i in range(len(self.radix_cache.mem_manager.kv_buffer)):
+                kv_buffer_list.append(self.radix_cache.mem_manager.kv_buffer[i][: len(prompt_cache_token_id)])
+            torch.save(kv_buffer_list, f"prompt_cache_rank_{dist.get_rank()}.pt")
+        else:
+            torch.save(
+                self.radix_cache.mem_manager.kv_buffer[:, : len(prompt_cache_token_id)],
+                f"prompt_cache_rank_{dist.get_rank()}.pt",
+            )
+
     @torch.no_grad()
     def filter(self, finished_request_ids: List[int]):
         if len(finished_request_ids) == 0:
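
The dump written above can be inspected offline before it is fed to a quantized run. A minimal sketch, assuming a single-rank server already produced prompt_cache_rank_0.pt via _save_promptcache_kvbuffer (the file name and rank count are illustrative):

# Offline inspection of the dumped prompt cache kv (illustrative sketch only).
import torch

kv = torch.load("prompt_cache_rank_0.pt", map_location="cpu", weights_only=True)
if isinstance(kv, list):
    # per-layer list layout: one tensor per layer, truncated to the prompt length
    print(f"{len(kv)} layers, first layer shape: {tuple(kv[0].shape)}")
else:
    # single stacked buffer layout: [num_layers, prompt_len, ...]
    print(f"stacked buffer shape: {tuple(kv.shape)}")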

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 26 additions & 0 deletions
@@ -212,6 +212,10 @@ def init_model(self, kvargs):
             else None
         )
 
+        if "prompt_cache_kv_buffer" in model_cfg:
+            assert self.use_dynamic_prompt_cache
+            self.preload_prompt_cache_kv_buffer(model_cfg)
+
         self.logger.info(f"loaded model class {self.model.__class__}")
         self.init_custom()
 
@@ -256,3 +260,25 @@ def _init_reqs(self, reqs: List[Tuple], init_req_obj=True):
         g_infer_state_lock.release()
         req_ids = [e[0] for e in reqs]
         return req_ids
+
+    def preload_prompt_cache_kv_buffer(self, model_cfg):
+        self.logger.info("Preload prompt cache kv buffer.")
+        cur_rank = dist.get_rank()
+        prompt_cache_kv_buffer_path = os.path.join(
+            self.weight_dir, model_cfg["prompt_cache_kv_buffer"][f"rank_{cur_rank}"]
+        )
+        prompt_cache_kv_buffer = torch.load(prompt_cache_kv_buffer_path, weights_only=True, map_location="cpu")
+        intact_kv_len = len(model_cfg["prompt_cache_token_ids"])
+        intact_kv_index = self.radix_cache.mem_manager.alloc(intact_kv_len)
+        if isinstance(self.radix_cache.mem_manager.kv_buffer, list):
+            for i in range(len(self.radix_cache.mem_manager.kv_buffer)):
+                self.radix_cache.mem_manager.kv_buffer[i][intact_kv_index].copy_(prompt_cache_kv_buffer[i])
+        else:
+            self.radix_cache.mem_manager.kv_buffer[:, intact_kv_index].copy_(prompt_cache_kv_buffer)
+        self.radix_cache.insert(
+            torch.tensor(model_cfg["prompt_cache_token_ids"], dtype=torch.int64, device="cpu"),
+            intact_kv_index,
+        )
+        self.radix_cache.match_prefix(
+            torch.tensor(model_cfg["prompt_cache_token_ids"], dtype=torch.int64, device="cpu"), update_refs=True
+        )
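
For this preload path to run, the served model's config.json must carry the two keys read above: prompt_cache_token_ids and a per-rank prompt_cache_kv_buffer mapping whose files live under the weight directory. A hypothetical fragment, written here as a Python dict for readability (token ids, rank count, and file names are illustrative, matching the dumps produced by _save_promptcache_kvbuffer):

# Hypothetical content merged into the model's config.json (illustrative values only).
prompt_cache_cfg = {
    # token ids of the fixed prefix whose kv was captured by the non-quantized run
    "prompt_cache_token_ids": [1, 3087, 892, 263, 8444, 20255, 29889],
    # per tensor-parallel rank, the kv dump file relative to the weight directory
    "prompt_cache_kv_buffer": {
        "rank_0": "prompt_cache_rank_0.pt",
        "rank_1": "prompt_cache_rank_1.pt",
    },
}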

lightllm/server/router/req_queue/base_queue.py

Lines changed: 6 additions & 1 deletion
@@ -3,6 +3,7 @@
 from ..batch import Batch, Req
 from lightllm.server.core.objs import FinishStatus
 from lightllm.common.basemodel.infer_lock import g_router_lock
+from lightllm.utils.config_utils import get_fixed_kv_len
 
 
 class BaseQueue:
@@ -13,7 +14,11 @@ def __init__(self, args, router, dp_index, dp_size) -> None:
         from lightllm.server.router.manager import RouterManager
 
         self.router: RouterManager = router
-        self.max_total_tokens = args.max_total_token_num
+        # Subtracting get_fixed_kv_len() from max_total_token_num removes the token kv
+        # resources pre-occupied by the specific inference mode, which in the extreme case
+        # reduces the overall available kv resources. Outside that mode, get_fixed_kv_len()
+        # always returns 0, so this has no effect.
+        self.max_total_tokens = args.max_total_token_num - get_fixed_kv_len()
         assert args.batch_max_tokens is not None
         self.batch_max_tokens = args.batch_max_tokens
         self.running_max_req_size = args.running_max_req_size  # maximum number of concurrent requests

lightllm/utils/config_utils.py

Lines changed: 12 additions & 0 deletions
@@ -1,5 +1,7 @@
 import json
 import os
+from functools import lru_cache
+from .envs_utils import get_env_start_args
 from lightllm.utils.log_utils import init_logger
 
 logger = init_logger(__name__)
@@ -46,3 +48,13 @@ def get_dtype(model_path: str):
     except:
         logger.warning("torch_dtype not in config.json, use float16 as default")
         return "float16"
+
+
+@lru_cache(maxsize=None)
+def get_fixed_kv_len():
+    start_args = get_env_start_args()
+    model_cfg = get_config_json(start_args.model_dir)
+    if "prompt_cache_token_ids" in model_cfg:
+        return len(model_cfg["prompt_cache_token_ids"])
+    else:
+        return 0
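
Together with the BaseQueue change above, this helper shrinks the scheduler's token budget by the length of the pinned prefix. A rough sketch of the resulting arithmetic inside a running server, using made-up numbers (a 120000-token pool and a 1024-token prompt_cache_token_ids prefix):

# Effective token budget seen by BaseQueue (illustrative numbers only).
from lightllm.utils.config_utils import get_fixed_kv_len

max_total_token_num = 120000                # value passed via --max_total_token_num
fixed_kv_len = get_fixed_kv_len()           # 1024 here; 0 when prompt_cache_token_ids is absent
schedulable_tokens = max_total_token_num - fixed_kv_len
print(schedulable_tokens)                   # 118976 tokens remain for regular requests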
