44import random
55from threading import Thread , Event
66from queue import Queue
7- from lightllm .server .router .dynamic_prompt .cache_controller import HiCacheController , CacheNode , BLOCK_SIZE , HiHostService , HiHostTask
7+ from lightllm .server .router .dynamic_prompt .cache_controller import (
8+ HiCacheController ,
9+ CacheNode ,
10+ BLOCK_SIZE ,
11+ HiHostService ,
12+ HiHostTask ,
13+ )
14+
815
916class MockMemoryManager :
1017 """模拟内存管理器,仅返回连续的索引值"""
18+
def __init__(self):
    """Create an empty mock store with the allocation cursor at zero."""
    # Next index that ``alloc`` will hand out.
    self.current_idx = 0
    # Maps an allocated index to the value stored for it.
    self.kvcache_store = {}
1422
def alloc(self, size):
    """Reserve ``size`` consecutive indices and fill each with random data.

    The indices start at ``self.current_idx``, which is advanced by ``size``.
    Every new index is populated (through ``self.store``) with a row of 512
    random integers drawn from ``[0, 0xFFFF]``.

    Returns:
        The list of newly reserved indices.
    """
    indices = list(range(self.current_idx, self.current_idx + size))
    self.current_idx += size
    # One 512-element row of random values per reserved index.
    payload = torch.tensor([[random.randint(0, 0xFFFF) for __ in range(512)] for _ in range(size)])
    self.store(indices, payload)
    return indices
20-
28+
def load_index_kv_buffer(self, index, load_tensor_dict):
    """Store ``load_tensor_dict["kv_buffer"]`` in the mock store under ``index``.

    Raises:
        KeyError: if ``load_tensor_dict`` has no ``"kv_buffer"`` entry.
    """
    self.kvcache_store[index] = load_tensor_dict["kv_buffer"]
23-
31+
def get_index_kv_buffer(self, index):
    """Return the value stored under ``index`` wrapped as ``{"kv_buffer": value}``.

    Raises:
        KeyError: if nothing has been stored under ``index``.
    """
    return {"kv_buffer": self.kvcache_store[index]}
26-
34+
def to_kvcache(self, indices):
    """Gather the stored rows for ``indices`` into a single tensor.

    Rows appear in the result in the same order as ``indices``.

    Raises:
        AssertionError: if any index has nothing stored for it; the message
            lists the missing indices.
    """
    # Report exactly which indices are absent instead of a vague
    # double-negative message.
    missing = [idx for idx in indices if idx not in self.kvcache_store]
    assert not missing, f"Indices {missing} (of {indices}) are not found in kvcache_store"
    return torch.tensor([self.kvcache_store[idx].tolist() for idx in indices])
30-
40+
def store(self, indices, value):
    """Store row ``i`` of ``value`` under ``indices[i]`` and return ``indices``.

    Pairing stops at the shorter of ``indices`` and the rows of ``value``.
    Progress is printed for the whole batch and for every stored row.
    """
    print(f"[TEST:MemManager] Storing {value.shape} at {indices}")
    # Iterating ``value`` yields its rows along the first dimension.
    for idx, row in zip(indices, value):
        self.kvcache_store[idx] = row
        print(f"[TEST:MemManager] Stored {row.shape} at {idx}")
    return indices
37-
47+
3848 def free (self , indices ):
3949 print (f"[TEST:MemManager] Freeing { indices } " )
4050 for idx in indices :
@@ -46,87 +56,91 @@ def setup():
4656 service = HiHostService ()
4757 hicache = HiCacheController (mem_manager )
4858 hicache .service = service # 注入模拟服务
49-
59+
5060 indices = mem_manager .alloc (5 )
5161 print (mem_manager .to_kvcache (indices ))
52-
62+
5363 # 预先计算单token大小
5464 dummy_indices = mem_manager .alloc (1 )
5565 kvcache = mem_manager .to_kvcache (dummy_indices [:1 ])
5666 token_size = kvcache .nelement () * kvcache .element_size ()
5767 print (f"[TEST] Single token KV cache size: { token_size } bytes, Block size: { BLOCK_SIZE } " )
58-
68+
5969 return mem_manager , service , hicache , token_size
6070
71+
def test_basic_write_read(mem_manager, hicache, token_size):
    """Write exactly one block of tokens to the cache and read it back.

    The source slots are freed in ``mem_manager`` after the write, then the
    indices returned by ``hicache.read`` are resolved through ``mem_manager``
    and compared with the KV data captured before the write.
    """
    # How many tokens fit into a single block.
    tokens_per_block = BLOCK_SIZE // token_size
    print(f"[TEST] Each block can hold {tokens_per_block} tokens")

    # Test data that fills exactly one block.
    token_ids = list(range(tokens_per_block))
    indices = mem_manager.alloc(len(token_ids))
    kvcache = mem_manager.to_kvcache(indices)
    print(f"[TEST] Generated KV cache with shape: {kvcache.shape}, type: {kvcache.dtype}")

    # Write to the cache.
    hicache.write(torch.tensor(token_ids), torch.tensor(indices))
    # NOTE(review): presumably gives the background service time to pick up
    # the write before waiting on it; confirm whether the wait alone suffices.
    time.sleep(2)

    # Wait for outstanding tasks to complete.
    hicache.service.wait_till_all_finished()

    mem_manager.free(indices)

    # Read back and verify.
    result = hicache.read(torch.tensor(token_ids))
    result = mem_manager.to_kvcache(result.tolist())
    assert result.eq(kvcache).all(), f"Retrieved kvcache: {result}, Expected kvcache: {kvcache}"
    print("[TEST] Basic test passed. Retrieved kvcache\n\n")
97+
8698
def test_node_splitting(mem_manager, hicache, token_size):
    """Write a sequence longer than three blocks and verify the full read-back.

    Also checks that the cache's root node has at least one child after the
    write has finished.
    """
    tokens_per_block = BLOCK_SIZE // token_size
    # Data spanning more than one block: three full blocks plus one token.
    token_ids = list(range(12, 12 + tokens_per_block * 3 + 1))
    indices = mem_manager.alloc(len(token_ids))
    kvcache = mem_manager.to_kvcache(indices)

    hicache.write(torch.tensor(token_ids), torch.tensor(indices))
    # NOTE(review): presumably lets the background service pick up the write
    # before waiting on it; confirm whether the wait alone suffices.
    time.sleep(2)
    hicache.service.wait_till_all_finished()

    # The root node should now have children.
    root = hicache.root
    assert len(root.children) > 0
    print(f"\nRoot node has {len(root.children)} children")

    # Read the complete sequence back.
    result = hicache.read(torch.tensor(token_ids))
    result = mem_manager.to_kvcache(result.tolist())
    assert result.eq(kvcache).all(), f"Retrieved kvcache: {result}, Expected kvcache: {kvcache}"
    print(f"[TEST] Node splitting test passed. Retrieved kvcache: {result.shape}\n\n")
108120
121+
def test_partial_read(mem_manager, hicache):
    """Check reads of a stored prefix and of a sequence that diverges from it.

    After writing six tokens, reading the first three must return their KV
    data, and reading ``[97, 98, 100]`` must return only the two-token
    matching prefix.
    """
    token_ids = [97, 98, 99, 100, 101, 102]
    indices = mem_manager.alloc(len(token_ids))
    kvcache = mem_manager.to_kvcache(indices)
    hicache.write(torch.tensor(token_ids), torch.tensor(indices))
    # NOTE(review): presumably lets the background service pick up the write
    # before waiting on it; confirm whether the wait alone suffices.
    time.sleep(2)
    hicache.service.wait_till_all_finished()

    # Query a prefix that exists in the cache.
    result = hicache.read(torch.tensor([97, 98, 99]))
    result = mem_manager.to_kvcache(result.tolist())
    assert result.eq(kvcache[:3]).all()
    print("[TEST] Partial read passed")

    # Query a sequence whose third token was never written after [97, 98].
    result = hicache.read(torch.tensor([97, 98, 100]))
    assert len(result) == 2
    result = mem_manager.to_kvcache(result.tolist())
    assert result.eq(kvcache[:2]).all()
    print(f"[TEST] Non-existent prefix returned: {result.tolist()}")
129142
143+
130144def main ():
131145 mem_manager , service , hicache , token_size = setup ()
132146 try :
@@ -136,5 +150,6 @@ def main():
136150 finally :
137151 service .shutdown ()
138152
153+
# Run the test sequence only when executed as a script, not on import.
if __name__ == "__main__":
    main()
0 commit comments