Commit b783ebf

feat: optimize hybrid radix cache buffer insertion strategy
Add mamba_model_match_len based optimization for buffer insertion:
- Only insert buffer at actual branch points instead of fixed intervals
- Use threshold (chunked_prefill_size // 2) to decide strategy
- Reduce buffer storage overhead while maintaining cache hit rate
1 parent fb6f960 commit b783ebf

File tree

2 files changed (+55, -6 lines):
- lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py
- lightllm/server/router/model_infer/infer_batch.py

lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py

Lines changed: 32 additions & 4 deletions
```diff
@@ -80,14 +80,38 @@ def _evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_toke
             self.evict_tree_set.add(parent_node)
         return
 
+    def _should_insert_buffer(self, req) -> bool:
+        """Decide whether a buffer needs to be inserted at the current position."""
+        # case 1: prefill is finished (about to enter decode), must insert
+        if req.cur_kv_len >= req.get_cur_total_len():
+            return True
+
+        # case 2: the optimized strategy is in use
+        if req.use_mamba_match_len_strategy:
+            # only insert at the mamba_model_match_len position
+            if req.cur_kv_len == req.mamba_model_match_len and not req.mamba_buffer_inserted:
+                return True
+            return False
+
+        # case 3: original strategy (insert after every chunk)
+        return True
+
     def insert_for_hybrid_radix_cache(self, reqs):
         from lightllm.server.router.model_infer.infer_batch import g_infer_context
 
-        self.free_radix_cache_to_get_enough_buffer(len(reqs))
-        new_buffer_indexes = self.mem_manager.alloc_buffer(len(reqs))
-        # req_ids_gpu = req_ids.cuda()
+        # filter down to the requests that actually need an insertion
+        reqs_to_insert = []
+        for req in reqs:
+            if self._should_insert_buffer(req):
+                reqs_to_insert.append(req)
+
+        if len(reqs_to_insert) == 0:
+            return
 
-        for i, req in enumerate(reqs):
+        self.free_radix_cache_to_get_enough_buffer(len(reqs_to_insert))
+        new_buffer_indexes = self.mem_manager.alloc_buffer(len(reqs_to_insert))
+
+        for i, req in enumerate(reqs_to_insert):
             input_token_ids = req.get_input_token_ids()
             key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu")
             value = g_infer_context.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].cpu()
@@ -105,6 +129,10 @@ def insert_for_hybrid_radix_cache(self, reqs):
             # free_a_req_mem releases [prompt_cache_len:prefix_len]; after this update that range is empty
             req.shm_req.prompt_cache_len = req.cur_kv_len
 
+            # mark that the buffer has been inserted at the mamba_model_match_len position
+            if req.cur_kv_len == req.mamba_model_match_len:
+                req.mamba_buffer_inserted = True
+
     def match_prefix(self, key, update_refs=False):
         assert len(key) != 0
         self.match_count = (self.match_count + 1) % self.log_interval
```

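As a quick illustration of the three cases handled by `_should_insert_buffer`, here is a minimal, self-contained sketch of the same decision logic; the `DummyReq` stand-in and the concrete lengths are illustrative only and are not part of this commit:

```python
from dataclasses import dataclass


@dataclass
class DummyReq:
    # Minimal stand-in for the real request object; only the fields that the
    # decision reads are modeled here.
    cur_kv_len: int
    total_len: int
    mamba_model_match_len: int = 0
    use_mamba_match_len_strategy: bool = False
    mamba_buffer_inserted: bool = False

    def get_cur_total_len(self) -> int:
        return self.total_len


def should_insert_buffer(req: DummyReq) -> bool:
    # Mirrors the logic of _should_insert_buffer from the diff above.
    if req.cur_kv_len >= req.get_cur_total_len():
        return True  # case 1: prefill finished, about to enter decode
    if req.use_mamba_match_len_strategy:
        # case 2: only at the branch point, and only once
        return req.cur_kv_len == req.mamba_model_match_len and not req.mamba_buffer_inserted
    return True  # case 3: original strategy, insert after every chunk


# original strategy: a mid-prefill chunk boundary still triggers an insert
print(should_insert_buffer(DummyReq(cur_kv_len=8192, total_len=30000)))  # True

# optimized strategy: a chunk boundary that is not the branch point is skipped
print(should_insert_buffer(DummyReq(8192, 30000, mamba_model_match_len=12000,
                                    use_mamba_match_len_strategy=True)))  # False

# optimized strategy: the chunk that lands exactly on the branch point does insert
print(should_insert_buffer(DummyReq(12000, 30000, mamba_model_match_len=12000,
                                    use_mamba_match_len_strategy=True)))  # True
```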
lightllm/server/router/model_infer/infer_batch.py

Lines changed: 23 additions & 2 deletions
```diff
@@ -409,6 +409,10 @@ def __init__(
 
         # when radix cache is enabled, this marks the hit status and is used by the insertion algorithm
         self.mamba_model_match_len = 0
+        # whether to use the mamba_model_match_len based optimization strategy
+        self.use_mamba_match_len_strategy = False
+        # whether a buffer has already been inserted at the mamba_model_match_len position
+        self.mamba_buffer_inserted = False
 
         # when enable_cpu_cache is enabled, the request's kv cache is offloaded to the
         # cpu cache after the request finishes; this flag tracks the status of that offload task
```
```diff
@@ -471,6 +475,12 @@ def _match_radix_cache(self):
         self.cur_kv_len = int(ready_cache_len)  # serialization issue: the object may be numpy.int64, convert with int(*)
         self.shm_req.prompt_cache_len = self.cur_kv_len  # record the prompt cache hit length
 
+        # decide whether to use the mamba_model_match_len based optimization strategy:
+        # when the increment that has to be recomputed is large enough, it is worth
+        # saving a buffer at the branch point on its own
+        increment = self.mamba_model_match_len - ready_cache_len
+        threshold = self.shm_req.chunked_prefill_size // 2
+        self.use_mamba_match_len_strategy = increment >= threshold
+
         self.shm_req.shm_cur_kv_len = self.cur_kv_len
         return
```

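To make the threshold rule concrete, here is a small worked example. The lengths are made up and do not come from the commit or from lightllm defaults; it assumes mamba_model_match_len marks the branch point of the token match while ready_cache_len is the prefix that is already usable from the cache:

```python
# Illustrative numbers only.
chunked_prefill_size = 8192
ready_cache_len = 2000        # prefix already usable from the cache
mamba_model_match_len = 7000  # branch point of the token match

increment = mamba_model_match_len - ready_cache_len  # 5000 tokens would have to be recomputed
threshold = chunked_prefill_size // 2                # 4096

use_mamba_match_len_strategy = increment >= threshold
print(use_mamba_match_len_strategy)  # True: large enough to justify a dedicated buffer at the branch point
```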
```diff
@@ -518,13 +528,24 @@ def get_input_token_ids(self):
         return self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()]
 
     def get_chuncked_input_token_ids(self):
-        chunked_start = self.cur_kv_len
-        chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size)
+        # reuse the logic of get_chuncked_input_token_len to keep the two consistent
+        chunked_end = self.get_chuncked_input_token_len()
         return self.shm_req.shm_prompt_ids.arr[0:chunked_end]
 
     def get_chuncked_input_token_len(self):
         chunked_start = self.cur_kv_len
         chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size)
+
+        # optimized strategy: run the first chunk directly to mamba_model_match_len (the branch point),
+        # so a buffer can be saved at the branch point, improving the cache hit rate of later requests
+        if (
+            self.use_mamba_match_len_strategy
+            and not self.mamba_buffer_inserted
+            and self.mamba_model_match_len > chunked_start
+            and self.mamba_model_match_len <= self.get_cur_total_len()
+        ):
+            chunked_end = self.mamba_model_match_len
+
         return chunked_end
 
     def set_next_gen_token_id(self, next_token_id: int, logprob: float, output_len: int):
```

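The effect of this change on chunk boundaries can be sketched as follows; `chunk_size` stands in for shm_req.chunked_prefill_size and all lengths are invented for illustration:

```python
def chunk_ends(total_len, chunk_size, start, match_len=None):
    """Return the successive chunked_end values produced during a chunked prefill."""
    ends, cur, first = [], start, True
    while cur < total_len:
        end = min(total_len, cur + chunk_size)
        # optimized strategy: cut the first chunk exactly at the branch point
        if match_len is not None and first and start < match_len <= total_len:
            end = match_len
        ends.append(end)
        cur, first = end, False
    return ends


# original strategy: fixed-size chunks from the cached prefix onward
print(chunk_ends(total_len=20000, chunk_size=8192, start=2000))
# -> [10192, 18384, 20000]

# optimized strategy: the first chunk stops at the branch point (match_len=7000),
# so a buffer can be inserted exactly there before normal chunking resumes
print(chunk_ends(total_len=20000, chunk_size=8192, start=2000, match_len=7000))
# -> [7000, 15192, 20000]
```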