
Commit 2134768
support prefix noquant
1 parent ca7d2f9

2 files changed: 45 additions & 0 deletions

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 17 additions & 0 deletions
@@ -106,6 +106,23 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finis
             self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
             req.shared_kv_node = None
 
+        # # save prompt cache kv buffer
+        # prompt_cache_token_id = list(self.radix_cache.root_node.children.values())[0].token_id_key
+        # print(f"prompt_cache_token_id : {prompt_cache_token_id}")
+        # if isinstance(self.radix_cache.mem_manager.kv_buffer, list):
+        #     kv_buffer_list = []
+        #     for i in range(len(self.radix_cache.mem_manager.kv_buffer)):
+        #         kv_buffer_list.append(self.radix_cache.mem_manager.kv_buffer[i][:len(prompt_cache_token_id)])
+        #     torch.save(
+        #         kv_buffer_list,
+        #         f"prompt_cache_rank_{dist.get_rank()}.pt"
+        #     )
+        # else:
+        #     torch.save(
+        #         self.radix_cache.mem_manager.kv_buffer[:, :len(prompt_cache_token_id)],
+        #         f"prompt_cache_rank_{dist.get_rank()}.pt"
+        #     )
+
     @torch.no_grad()
     def filter(self, finished_request_ids: List[int]):
         if len(finished_request_ids) == 0:
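Note: the new block in free_a_req_mem is added fully commented out; it records how the per-rank prompt-cache dump consumed by base_backend.py (below) can be produced: slice the KV buffer to the length of the cached prefix and torch.save it as prompt_cache_rank_{rank}.pt. A minimal standalone sketch of the same idea follows; the helper name dump_prompt_cache_kv and the out_dir parameter are hypothetical, while the attribute names (radix_cache, mem_manager.kv_buffer, token_id_key) come from the commented code itself.

# Sketch only (not part of the commit): dump the cached prefix KV to one file per rank,
# mirroring the commented-out block above. dump_prompt_cache_kv is a hypothetical helper.
import torch
import torch.distributed as dist

def dump_prompt_cache_kv(radix_cache, out_dir="."):
    # Assumption (taken from the commented code): the first child of the radix-cache
    # root node holds the cached prompt prefix.
    prompt_cache_token_id = list(radix_cache.root_node.children.values())[0].token_id_key
    n = len(prompt_cache_token_id)
    kv_buffer = radix_cache.mem_manager.kv_buffer
    if isinstance(kv_buffer, list):
        # Per-layer list of KV buffers: keep only the first n token slots of each layer.
        payload = [kv_buffer[i][:n] for i in range(len(kv_buffer))]
    else:
        # Single packed KV buffer: keep the first n token slots along dim 1.
        payload = kv_buffer[:, :n]
    torch.save(payload, f"{out_dir}/prompt_cache_rank_{dist.get_rank()}.pt")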

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 28 additions & 0 deletions
@@ -212,6 +212,9 @@ def init_model(self, kvargs):
             else None
         )
 
+        if "prompt_cache_kv_buffer" in model_cfg:
+            self.preload_prompt_cache_kv_buffer(model_cfg)
+
         self.logger.info(f"loaded model class {self.model.__class__}")
         self.init_custom()
 
@@ -256,3 +259,28 @@ def _init_reqs(self, reqs: List[Tuple], init_req_obj=True):
         g_infer_state_lock.release()
         req_ids = [e[0] for e in reqs]
         return req_ids
+
+    def preload_prompt_cache_kv_buffer(self, model_cfg):
+        self.logger.info("Preload prompt cache kv buffer.")
+        cur_rank = dist.get_rank()
+        prompt_cache_kv_buffer_path = os.path.join(
+            self.weight_dir, model_cfg["prompt_cache_kv_buffer"][f"rank_{cur_rank}"]
+        )
+        prompt_cache_kv_buffer = torch.load(prompt_cache_kv_buffer_path, weights_only=True, map_location="cpu")
+        if isinstance(self.radix_cache.mem_manager.kv_buffer, list):
+            for i in range(len(self.radix_cache.mem_manager.kv_buffer)):
+                self.radix_cache.mem_manager.kv_buffer[i][: len(model_cfg["prompt_cache_token_ids"])].copy_(
+                    prompt_cache_kv_buffer[i]
+                )
+        else:
+            self.radix_cache.mem_manager.kv_buffer[:, : len(model_cfg["prompt_cache_token_ids"])].copy_(
+                prompt_cache_kv_buffer
+            )
+        self.radix_cache.insert(
+            torch.tensor(model_cfg["prompt_cache_token_ids"], dtype=torch.int64, device="cpu"),
+            torch.tensor(range(len(model_cfg["prompt_cache_token_ids"])), dtype=torch.int32, device="cpu"),
+        )
+        self.radix_cache.mem_manager.mem_state[: len(model_cfg["prompt_cache_token_ids"])] = 1
+        self.radix_cache.match_prefix(
+            torch.tensor(model_cfg["prompt_cache_token_ids"], dtype=torch.int64, device="cpu"), update_refs=True
+        )
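Note: preload_prompt_cache_kv_buffer loads one KV dump per rank, copies it into the head of mem_manager.kv_buffer, inserts the prefix token ids into the radix cache, and marks those slots as occupied (mem_state[:n] = 1, match_prefix(..., update_refs=True)) so the preloaded prefix stays resident for later requests to match. It expects two extra keys in model_cfg: prompt_cache_token_ids (token ids of the cached prefix) and prompt_cache_kv_buffer (a per-rank map of file names relative to weight_dir). Assuming model_cfg is read from the model's config.json under weight_dir (suggested by the os.path.join(self.weight_dir, ...) call above), the entries could be added as in the sketch below; file names and token ids are placeholders.

# Sketch only: add the keys read by preload_prompt_cache_kv_buffer to the model config.
# Assumes model_cfg comes from config.json under weight_dir; all values are placeholders.
import json
import os

weight_dir = "/path/to/model"  # hypothetical model directory
cfg_path = os.path.join(weight_dir, "config.json")
with open(cfg_path) as f:
    model_cfg = json.load(f)

model_cfg["prompt_cache_token_ids"] = [1, 15043, 29892]  # token ids of the shared prefix (placeholder)
model_cfg["prompt_cache_kv_buffer"] = {
    "rank_0": "prompt_cache_rank_0.pt",  # one dump per rank, as written by the commented-out code
    "rank_1": "prompt_cache_rank_1.pt",
}

with open(cfg_path, "w") as f:
    json.dump(model_cfg, f, indent=2)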
