Commit ea8673e

Support preload prompt cache kv buffer
1 parent c757cf5

2 files changed: +25 −0

lightllm/server/router/model_infer/infer_batch.py (4 additions, 0 deletions)

@@ -375,6 +375,10 @@ def free_self(self):
             if len(group_mapping) == 0:
                 group_mapping.clear()

+        # # save prompt cache kv buffer
+        # print(f"prompt_cache_token_id : {list(self.radix_cache.root_node.children.values())[0].token_id_key}")
+        # torch.save(self.radix_cache.mem_manager.kv_buffer, f"prompt_cache_rank_{dist.get_rank()}.pt")
+
         if self.radix_cache is not None:
             logger.debug(
                 f"free a batch state:\n"

lightllm/server/router/model_infer/mode_backend/base_backend.py (21 additions, 0 deletions)

@@ -213,6 +213,9 @@ def init_model(self, kvargs):
             else None
         )

+        if "prompt_cache_kv_buffer" in model_cfg:
+            self.preload_prompt_cache_kv_buffer(model_cfg)
+
         self.logger.info(f"loaded model class {self.model.__class__}")
         self.init_custom()
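
The preload path is gated on the model's config dict (model_cfg). Judging only from the keys this diff reads, a hypothetical fragment enabling the feature could look like the sketch below; the .pt file names are illustrative (they follow the dump naming above) and the token ids are placeholders:

# Hypothetical model_cfg fragment; key names come from the diff, values are made up.
model_cfg = {
    # ... the usual model config fields ...
    "prompt_cache_kv_buffer": {
        "rank_0": "prompt_cache_rank_0.pt",  # resolved relative to self.weight_dir
        "rank_1": "prompt_cache_rank_1.pt",  # one entry per rank
    },
    "prompt_cache_token_ids": [1, 3632, 2043, 338],  # token ids of the prompt to cache
}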

@@ -313,3 +316,21 @@ def remove_batch(self, batch_id):
         del batch
         g_infer_state_lock.release()
         return
+
+    def preload_prompt_cache_kv_buffer(self, model_cfg):
+        self.logger.info("Preload prompt cache kv buffer.")
+        cur_rank = dist.get_rank()
+        prompt_cache_kv_buffer_path = os.path.join(
+            self.weight_dir, model_cfg["prompt_cache_kv_buffer"][f"rank_{cur_rank}"]
+        )
+        prompt_cache_kv_buffer = torch.load(prompt_cache_kv_buffer_path, weights_only=True, map_location="cpu")
+        if isinstance(self.radix_cache.mem_manager.kv_buffer, list):
+            for i in range(len(self.radix_cache.mem_manager.kv_buffer)):
+                self.radix_cache.mem_manager.kv_buffer[i].copy_(prompt_cache_kv_buffer[i])
+        else:
+            self.radix_cache.mem_manager.kv_buffer.copy_(prompt_cache_kv_buffer)
+        self.radix_cache.insert(
+            torch.tensor(model_cfg["prompt_cache_token_ids"], dtype=torch.int64, device="cpu"),
+            torch.tensor(range(len(model_cfg["prompt_cache_token_ids"])), dtype=torch.int32, device="cpu"),
+        )
+        self.radix_cache.mem_manager.mem_state[: len(model_cfg["prompt_cache_token_ids"])] = 1
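
The new method does four things: load this rank's dump, copy it in place into the preallocated kv buffer, insert the prompt's token ids into the radix cache mapped to the first N buffer slots, and mark those N slots as occupied. A toy sketch of those mechanics with stand-in tensors and made-up shapes; these are not the real lightllm mem_manager or radix-cache classes:

import torch

num_slots, kv_dim = 16, 8
kv_buffer = torch.zeros(num_slots, kv_dim)             # stand-in for mem_manager.kv_buffer
mem_state = torch.zeros(num_slots, dtype=torch.int32)  # stand-in for mem_manager.mem_state

loaded = torch.randn(num_slots, kv_dim)                # stands in for what torch.load returns
kv_buffer.copy_(loaded)                                # in-place copy into the existing buffer

prompt_cache_token_ids = [1, 3632, 2043, 338]          # placeholder prompt tokens
n = len(prompt_cache_token_ids)
slot_idx = torch.arange(n, dtype=torch.int32)          # same values as torch.tensor(range(n))
mem_state[:n] = 1                                      # pin slots 0..n-1 as occupied

The in-place copy_ matters here: the kv buffer is preallocated by the memory manager, so assigning a new tensor would leave the manager pointing at the old storage, while copy_ fills the storage it already tracks.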
