
Commit 4f744f3

Author: niushengxiao
Commit message: refine3
1 parent 3641b5c commit 4f744f3

File tree: 5 files changed, +114 −131 lines changed


lightllm/server/multi_level_kv_cache/cpu_cache_client.py

Lines changed: 6 additions & 7 deletions
@@ -57,8 +57,6 @@ def get_one_empty_page(self, hash_key: int, disk_offload_enable: bool) -> Option
         if cur_page.self_index == tail.self_index:
             return None
 
-        assert cur_page.is_empty() or cur_page.is_ready_recycle()
-        assert cur_page.ref_count == 0
         if cur_page.can_realloc(disk_offload_enable=disk_offload_enable):
             page_index = cur_page.self_index
             cur_page.del_self_from_list()

@@ -130,8 +128,11 @@ def update_pages_status_to_ready(
             cur_page = page_items[page_index]
             if cur_page.status < _CpuPageStatus.READY:
                 cur_page.status = _CpuPageStatus.READY
+
+            # offload everything; any prefix already persisted to disk is automatically skipped during offloading
             if disk_offload_enable:
                 offload_candidates.append(cur_page.self_index)
+
             if deref:
                 assert cur_page.ref_count > 0
                 cur_page.ref_count -= 1

@@ -202,13 +203,13 @@ def deref_pages(self, page_list: List[int]):
         for page_index in page_list:
             if page_index != -1:
                 page_item = page_items[page_index]
-                assert page_item.ref_count == 1
+                assert page_item.ref_count > 0
                 page_item.ref_count -= 1
         return
 
     def deref_one_page(self, page_index: int):
         page_item: _CpuPageStatus = self.page_items.get_item_by_index(page_index)
-        assert page_item.ref_count == 1
+        assert page_item.ref_count > 0
         page_item.ref_count -= 1
         return
 
@@ -220,7 +221,6 @@ def get_pages_to_offloading(self) -> List[List[int]]:
         if page_list is None:
             return groups
 
-        # cache constants and object references to reduce attribute lookups
         page_items = self.page_items.linked_items
         for value in page_list:
            page_index, is_group_head = self._decode_offload_value(value)

@@ -243,10 +243,9 @@ def update_pages_status_to_ready_recycle(self, page_list: List[int], deref: bool
         for page_index in page_list:
             if page_index != -1:
                 cur_page = page_items[page_index]
-                assert cur_page.is_offloading()
                 cur_page.status = _CpuPageStatus.READY_RECYCLE
                 if deref:
-                    assert cur_page.ref_count == 1
+                    assert cur_page.ref_count > 0
                     cur_page.ref_count -= 1
         return
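The assertions in deref_pages, deref_one_page and update_pages_status_to_ready_recycle are relaxed from ref_count == 1 to ref_count > 0, so dereferencing now tolerates pages that still carry additional references. A minimal self-contained sketch of the deref loop with that relaxed check (the Page dataclass here is a stand-in for the real _CpuPageStatus items, not the actual type):

from dataclasses import dataclass
from typing import List

@dataclass
class Page:
    ref_count: int

def deref_pages(page_items: List[Page], page_list: List[int]) -> None:
    # mirrors the loop above: -1 entries are placeholders and are skipped
    for page_index in page_list:
        if page_index != -1:
            page_item = page_items[page_index]
            assert page_item.ref_count > 0
            page_item.ref_count -= 1

pages = [Page(ref_count=2), Page(ref_count=1)]
deref_pages(pages, [0, -1, 1])
assert [p.ref_count for p in pages] == [1, 0]
# under the previous "ref_count == 1" assertion, dereferencing page 0 (ref_count 2) would have failed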

lightllm/server/multi_level_kv_cache/disk_cache_worker.py

Lines changed: 19 additions & 8 deletions
@@ -42,8 +42,11 @@ def __init__(
 
         assert disk_cache_storage_size > 0
         storage_size = int(disk_cache_storage_size * (1024 ** 3))
-        num_shard = 64
+        # num_shard is tied to KVCACHE_MAX_BLOCK_SIZE; with the default KVCACHE_MAX_BLOCK_SIZE of 64MB,
+        # num_shard = 32 brings the disk cache capacity utilization to about 90%; increasing num_shard further lowers it
+        num_shard = 32
         num_worker = 48
+        # when reads and writes run concurrently, 16 threads are reserved for writing and 32 for reading
         max_concurrent_write_tasks = 16
 
         cache_dir = disk_cache_dir

@@ -134,16 +137,24 @@ def _persist_pages_to_disk(self, payloads: List[_PagePayload]) -> None:
         self.cpu_cache_client.update_pages_status_to_ready_recycle(page_list=page_indexes, deref=True)
         self.cpu_cache_client.lock.release()
 
-    def blocks_exist(self, tokens: List[int], start_pos: int = 0) -> bool:
+    def query_loadable_pages(self, tokens: List[int], start_pos: int) -> int:
+        """
+        Query the longest prefix, starting at start_pos, that can be loaded from the disk cache.
+        Returns:
+            loadable_len: the length loadable starting from start_pos
+        """
         if not tokens or start_pos < 0 or start_pos >= len(tokens):
-            return False
+            return 0
 
         query_result = self.service.query(tokens)
-        block_start = start_pos // self.service._n
-        block_end = math.ceil(len(tokens) / self.service._n)
-        if block_start >= block_end:
-            return False
-        return all(query_result[block_start:block_end])
+        n = self.service._n
+        start_block = start_pos // n
+        try:
+            first_false_idx = start_block + query_result[start_block:].index(False)
+        except ValueError:
+            return len(tokens) - start_pos
+        first_missing_pos = first_false_idx * n
+        return max(0, first_missing_pos - start_pos)
 
     # load data from disk into memory
     def load_pages(self, tokens: List[int], page_indexes: List[int], start_pos: int = 0) -> bool:
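The new query_loadable_pages turns the per-block hit mask returned by service.query into a token-level prefix length: it finds the first missing block at or after the block containing start_pos and counts the positions up to it. A minimal standalone sketch of that computation, with the hit mask passed in directly and block_size standing in for self.service._n (both names are illustrative):

from typing import List

def loadable_prefix_len(hit_blocks: List[bool], num_tokens: int, start_pos: int, block_size: int) -> int:
    # hit_blocks holds one bool per block of block_size consecutive positions
    if num_tokens == 0 or start_pos < 0 or start_pos >= num_tokens:
        return 0
    start_block = start_pos // block_size
    try:
        # index of the first missing block at or after start_block
        first_false_idx = start_block + hit_blocks[start_block:].index(False)
    except ValueError:
        # every block from start_block onward is present on disk
        return num_tokens - start_pos
    first_missing_pos = first_false_idx * block_size
    return max(0, first_missing_pos - start_pos)

# 10 tokens, block_size 4, blocks [hit, hit, miss]: from start_pos=2, positions 2..7 are loadable
assert loadable_prefix_len([True, True, False], 10, 2, 4) == 6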

lightllm/server/multi_level_kv_cache/manager.py

Lines changed: 85 additions & 110 deletions
@@ -34,6 +34,7 @@ def __init__(
         logger.info(f"send_to_router sendhwm {self.send_to_router.getsockopt(zmq.SNDHWM)}")
         self.cpu_cache_client = CpuKvCacheClient(only_create_meta_data=False, init_shm_data=True)
         self.shm_req_manager = ShmReqManager()
+        # disk I/O on NVMe SSDs needs heavy concurrency to reach full performance
         self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=500)
         # bound the time spent matching cpu cache pages; once exceeded, stop matching and forward directly.
         self.cpu_cache_time_out = 0.5
@@ -60,21 +61,88 @@ def cpu_cache_hanle_loop(self):
             try:
                 current_group_req = self.recv_queue.get()
 
-                self.executor.submit(self._handle_group_req_cpu_cache_match, current_group_req, time.time())
+                self.executor.submit(self._handle_group_req_multi_cache_match, current_group_req, time.time())
             except BaseException as e:
                 logger.exception(str(e))
 
-    # blueswhen TODO: consider splitting this function to simplify the logic
-    def _handle_group_req_cpu_cache_match(self, group_req_indexes: GroupReqIndexes, start_time: float):
+    def _cpu_cache_match(self, token_hash_list: List[int]) -> List[int]:
         """
-        match cpu cache pages
+        Match the CPU cache and return the hit pages (longest prefix).
+        Returns:
+            all_pages: the hit page index list; len(all_pages) is the hit length
         """
-        # on timeout, give up matching cpu cache pages.
+        all_pages = []
+        self.cpu_cache_client.lock.acquire_sleep1ms()
+        for token_hash in token_hash_list:
+            page_index, _ = self.cpu_cache_client.query_one_page(token_hash)
+            if page_index is None:
+                break
+            all_pages.append(page_index)
+        self.cpu_cache_client.lock.release()
+        return all_pages
+
+    def _disk_cache_match(self, token_hash_list: List[int], all_pages: List[int]) -> tuple[List[int], int]:
+        """
+        Match the disk cache, load the missing pages, and append them to all_pages in place.
+        Returns:
+            (finded_page_indexes, disk_page_num): the final matched page index list (longest prefix) and the number of pages loaded from disk
+        """
+        cpu_hit_len = len(all_pages)
+        loadable_len = self.disk_cache_worker.query_loadable_pages(tokens=token_hash_list, start_pos=cpu_hit_len)
+        if loadable_len == 0:
+            return all_pages, 0
+
+        missing_hash_keys = token_hash_list[cpu_hit_len : cpu_hit_len + loadable_len]
+        self.cpu_cache_client.lock.acquire_sleep1ms()
+        allocated_pages, _ = self.cpu_cache_client.allocate_pages(
+            hash_keys=missing_hash_keys, disk_offload_enable=self.args.enable_disk_cache
+        )
+        self.cpu_cache_client.lock.release()
+
+        # collect the successfully allocated pages and append them directly to all_pages
+        new_page_indexes = []
+        for page_index in allocated_pages:
+            if page_index == -1:
+                break
+            all_pages.append(page_index)
+            new_page_indexes.append(page_index)
+
+        if not new_page_indexes:
+            return all_pages, 0
+
+        # compute the range to load from disk; it must be aligned to block boundaries
+        block_size = self.disk_cache_worker.service._n
+        start_block = cpu_hit_len // block_size
+        load_start_pos = start_block * block_size
+
+        load_tokens = token_hash_list[: cpu_hit_len + len(new_page_indexes)]
+        if not self.disk_cache_worker.load_pages(tokens=load_tokens, page_indexes=all_pages, start_pos=load_start_pos):
+            self.cpu_cache_client.lock.acquire_sleep1ms()
+            self.cpu_cache_client.recycle_pages(new_page_indexes)
+            self.cpu_cache_client.lock.release()
+            return all_pages[:cpu_hit_len], 0
+
+        self.cpu_cache_client.lock.acquire_sleep1ms()
+        self.cpu_cache_client.update_pages_status_to_ready(
+            page_list=all_pages,
+            deref=False,
+            disk_offload_enable=False,
+        )
+        if self.args.enable_disk_cache:
+            self.cpu_cache_client.mark_pages_recyclable(new_page_indexes)
+        self.cpu_cache_client.lock.release()
+        return all_pages, len(new_page_indexes)
+
+    def _handle_group_req_multi_cache_match(self, group_req_indexes: GroupReqIndexes, start_time: float):
+        """
+        match cpu cache and disk cache pages
+        """
+        # on timeout, give up matching cache pages.
         current_time = time.time()
         if current_time - start_time >= self.cpu_cache_time_out:
             self.send_to_router.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
             logger.warning(
-                f"cpu cache match time out {current_time - start_time}s, "
+                f"cache matching time out {current_time - start_time}s, "
                 f"group_req_id: {group_req_indexes.group_req_id}"
             )
             return
@@ -96,119 +164,26 @@ def _handle_group_req_cpu_cache_match(self, group_req_indexes: GroupReqIndexes,
             if len(token_hash_list) == 0:
                 continue
 
-            req.disk_prompt_cache_len = 0
             finded_page_indexes: List[int] = []
             disk_service = (
                 self.disk_cache_worker.service
                 if (self.disk_cache_worker is not None and self.disk_cache_worker.service is not None)
                 else None
             )
-            block_capacity = disk_service._n if disk_service is not None else 1
-            if block_capacity <= 0:
-                block_capacity = 1
-
-            disk_loaded_page_indexes: List[int] = []
-            idx = 0
-            while idx < len(token_hash_list):
-                chunk_len = min(block_capacity, len(token_hash_list) - idx)
-                chunk_tokens = token_hash_list[idx : idx + chunk_len]
-                if not chunk_tokens:
-                    break
-
-                block_pages: List[int] = []
-                missing_positions: List[int] = []
-
-                self.cpu_cache_client.lock.acquire_sleep1ms()
-                for pos, token_hash_value in enumerate(chunk_tokens):
-                    page_index, ready = self.cpu_cache_client.query_one_page(token_hash_value)
-                    if page_index is not None:
-                        block_pages.append(page_index)
-                        continue
-
-                    # -1 is only a placeholder
-                    block_pages.append(-1)
-                    missing_positions.append(pos)
-                self.cpu_cache_client.lock.release()
-
-                if not missing_positions:
-                    finded_page_indexes.extend(block_pages)
-                    idx += chunk_len
-                    continue
-
-                if disk_service is None:
-                    finded_page_indexes.extend(block_pages)
-                    break
-
-                prefix_len = idx + chunk_len
-                prefix_tokens = token_hash_list[:prefix_len]
-                if not self.disk_cache_worker.blocks_exist(tokens=prefix_tokens, start_pos=idx):
-                    finded_page_indexes.extend(block_pages)
-                    break
-
-                self.cpu_cache_client.lock.acquire_sleep1ms()
-                new_page_indexes: List[int] = []
-                allocation_failed = False
-                page_items = self.cpu_cache_client.page_items.linked_items
-                for pos in missing_positions:
-                    token_hash_value = chunk_tokens[pos]
-                    page_index, ready = self.cpu_cache_client.allocate_one_page(
-                        page_items=page_items,
-                        hash_key=token_hash_value,
-                        disk_offload_enable=self.args.enable_disk_cache,
-                    )
-                    if page_index is None:
-                        allocation_failed = True
-                        break
-                    block_pages[pos] = page_index
-                    if not ready:
-                        new_page_indexes.append(page_index)
-                if allocation_failed and new_page_indexes:
-                    self.cpu_cache_client.recycle_pages(new_page_indexes)
-                self.cpu_cache_client.lock.release()
-
-                if allocation_failed:
-                    hit_pages = [p for p in block_pages if p not in new_page_indexes]
-                    finded_page_indexes.extend(hit_pages)
-                    break
-
-                pages_to_load = new_page_indexes
-                if pages_to_load:
-                    prefix_len = idx + chunk_len
-                    prefix_tokens = token_hash_list[:prefix_len]
-                    prefix_pages = finded_page_indexes + block_pages
-
-                    if not self.disk_cache_worker.load_pages(
-                        tokens=prefix_tokens, page_indexes=prefix_pages, start_pos=idx
-                    ):
-                        self.cpu_cache_client.lock.acquire_sleep1ms()
-                        self.cpu_cache_client.recycle_pages(pages_to_load)
-                        self.cpu_cache_client.lock.release()
-                        hit_pages = [p for p in block_pages if p not in pages_to_load]
-                        finded_page_indexes.extend(hit_pages)
-                        break
-
-                    self.cpu_cache_client.lock.acquire_sleep1ms()
-                    self.cpu_cache_client.update_pages_status_to_ready(
-                        page_list=block_pages,
-                        deref=False,
-                        disk_offload_enable=False,
-                    )
-                    if self.args.enable_disk_cache and pages_to_load:
-                        self.cpu_cache_client.mark_pages_recyclable(pages_to_load)
-                    self.cpu_cache_client.lock.release()
-
-                    disk_loaded_page_indexes.extend(pages_to_load)
-
-                finded_page_indexes.extend(block_pages)
-                idx += chunk_len
-
-            finded_page_indexes = [p for p in finded_page_indexes if p != -1]
+            req.disk_prompt_cache_len = 0
+
+            # match the CPU cache
+            all_pages = self._cpu_cache_match(token_hash_list)
+            if len(all_pages) == len(token_hash_list) or disk_service is None:
+                finded_page_indexes = all_pages
+            else:
+                # match the disk cache and load the hits into the cpu cache
+                finded_page_indexes, disk_page_num = self._disk_cache_match(token_hash_list, all_pages)
+                req.disk_prompt_cache_len = disk_page_num * self.args.cpu_cache_token_page_size
+
             while not self.cpu_cache_client.check_allpages_ready(finded_page_indexes):
                 time.sleep(0.01)
 
-            if disk_loaded_page_indexes:
-                req.disk_prompt_cache_len = len(disk_loaded_page_indexes) * self.args.cpu_cache_token_page_size
-
             req.cpu_cache_match_page_indexes.fill(finded_page_indexes)
 
         for req in reqs:
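One detail worth calling out in _disk_cache_match is the block-boundary alignment: even when the CPU cache already covers cpu_hit_len positions, the disk read has to start at the beginning of the block that contains position cpu_hit_len, because the disk cache is addressed in blocks of service._n positions. A small worked example of that arithmetic (the concrete numbers are illustrative, not taken from the code):

# illustrative numbers only; block_size stands for self.disk_cache_worker.service._n
block_size = 64
cpu_hit_len = 100          # prefix already matched in the CPU cache

start_block = cpu_hit_len // block_size      # 100 // 64 == 1
load_start_pos = start_block * block_size    # 1 * 64 == 64
assert (start_block, load_start_pos) == (1, 64)

# the disk read window therefore starts at position 64 even though the CPU cache
# already covers positions 0..99; the pages newly allocated in _disk_cache_match
# only correspond to positions 100 and onward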

lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py

Lines changed: 1 addition & 1 deletion
@@ -252,10 +252,10 @@ def update_cpu_cache_task_states(self):
         page_array_list = [task.page_indexes.tolist() for task in trans_ok_tasks]
         if self.backend.is_master_in_dp:
             self.cpu_cache_client.lock.acquire_sleep1ms()
+            # update group by group, so pages of different requests do not interleave and make the disk cache hashes inconsistent
             for pages in page_array_list:
                 if not pages:
                     continue
-                # Keep per-req grouping so disk cache hashes stay aligned with req prefixes.
                 self.cpu_cache_client.update_pages_status_to_ready(
                     page_list=pages, deref=True, disk_offload_enable=self.args.enable_disk_cache
                 )
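The comment added above is the reason the ready-update stays inside the per-task loop: each call to update_pages_status_to_ready should see only one request's pages, not a flattened list that interleaves several requests. A rough, self-contained illustration with a stand-in function (the page numbers and the stub are made up for the example):

from typing import List

calls: List[List[int]] = []
page_array_list = [[3, 4, 5], [9, 10]]  # pages per finished transfer task, one task per request

def update_pages_status_to_ready(page_list: List[int]) -> None:
    # stand-in for CpuKvCacheClient.update_pages_status_to_ready; it just records each call
    calls.append(list(page_list))

# per-request grouping, as in the loop above: two calls, each covering one request's own page prefix
for pages in page_array_list:
    if not pages:
        continue
    update_pages_status_to_ready(page_list=pages)

assert calls == [[3, 4, 5], [9, 10]]
# a single flattened call, update_pages_status_to_ready(page_list=[3, 4, 5, 9, 10]),
# would mix the two requests' pages in one update, which is exactly what the comment warns against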

test/benchmark/service/benchmark_qps.py

Lines changed: 3 additions & 5 deletions
@@ -48,7 +48,6 @@ def gen_random_input_text(tokenizer, input_len) -> str:
 
 
 def gen_random_input_text_with_seed(tokenizer, input_len, seed) -> str:
-    """Generate random input text with a specific seed"""
     rng = random.Random(seed)
     random_ids = [rng.randint(0, tokenizer.vocab_size) for _ in range(input_len)]
     random_text = tokenizer.decode(random_ids)

@@ -68,15 +67,12 @@ def gen_random_data(
     output_lens = get_random_length(reqs_num, output_len, range_ratio)
     input_lens = get_random_length(reqs_num, input_len, range_ratio)
 
-    # Generate input_len2 lengths if input_len2 > 0
     if input_len2 > 0:
         input_lens2 = get_random_length(reqs_num, input_len2, range_ratio)
 
     for i in range(reqs_num):
-        # Generate first part with main random state
         input_text = gen_random_input_text(tokenizer, input_lens[i])
 
-        # Generate second part with seed2 if specified
         if input_len2 > 0 and seed2 is not None:
             input_text2 = gen_random_input_text_with_seed(tokenizer, input_lens2[i], seed2 + i)
             input_text = input_text + input_text2

@@ -339,7 +335,9 @@ def main():
     parser.add_argument("--input_num", type=int, default=2000)
     parser.add_argument("--input_qps", type=float, default=30.0)
     parser.add_argument("--input_len", type=int, default=1024)
-    parser.add_argument("--input_len2", type=int, default=0, help="Length of second part to append, 0 means disabled")
+    parser.add_argument(
+        "--input_len2", type=int, default=0, help="Length of second part to append behind input_len, 0 means disabled"
+    )
     parser.add_argument("--output_len", type=int, default=128)
     parser.add_argument("--server_api", type=str, default="lightllm")
     parser.add_argument("--dump_file", type=str, default="")
