@@ -96,7 +96,7 @@ class RadixCache:
9696 unique_name 主要用于解决单机,多实列部署时的shm冲突
9797 """
9898
99- def __init__ (self , unique_name , total_token_num , tp_id , mem_manager : MemoryManager = None ):
99+ def __init__ (self , unique_name , total_token_num , rank_in_node , mem_manager : MemoryManager = None ):
100100 self .mem_manager = mem_manager
101101 self ._key_dtype = torch .int64
102102 self ._value_dtype = torch .int64
@@ -109,9 +109,9 @@ def __init__(self, unique_name, total_token_num, tp_id, mem_manager: MemoryManag
109109 self .evict_tree_set : Set [TreeNode ] = SortedSet (key = lambda x : x .get_compare_key ()) # 自定义比较器
110110 self .evict_tree_set .add (self .root_node )
111111
112- self .refed_tokens_num = SharedArray (f"{ unique_name } _refed_tokens_num_{ tp_id } " , (1 ,), dtype = np .int64 )
112+ self .refed_tokens_num = SharedArray (f"{ unique_name } _refed_tokens_num_{ rank_in_node } " , (1 ,), dtype = np .int64 )
113113 self .refed_tokens_num .arr [0 ] = 0
114- self .tree_total_tokens_num = SharedArray (f"{ unique_name } _tree_total_tokens_num_{ tp_id } " , (1 ,), dtype = np .int64 )
114+ self .tree_total_tokens_num = SharedArray (f"{ unique_name } _tree_total_tokens_num_{ rank_in_node } " , (1 ,), dtype = np .int64 )
115115 self .tree_total_tokens_num .arr [0 ] = 0
116116
117117 def insert (self , key , value = None ):
@@ -345,9 +345,9 @@ class _RadixCacheReadOnlyClient:
345345 router 端只读用的客户端,用于从共享内存中读取树结构中的信息,用于进行prompt cache 的调度估计。
346346 """
347347
348- def __init__ (self , unique_name , total_token_num , tp_id ):
349- self .refed_tokens_num = SharedArray (f"{ unique_name } _refed_tokens_num_{ tp_id } " , (1 ,), dtype = np .int64 )
350- self .tree_total_tokens_num = SharedArray (f"{ unique_name } _tree_total_tokens_num_{ tp_id } " , (1 ,), dtype = np .int64 )
348+ def __init__ (self , unique_name , total_token_num , rank_in_node ):
349+ self .refed_tokens_num = SharedArray (f"{ unique_name } _refed_tokens_num_{ rank_in_node } " , (1 ,), dtype = np .int64 )
350+ self .tree_total_tokens_num = SharedArray (f"{ unique_name } _tree_total_tokens_num_{ rank_in_node } " , (1 ,), dtype = np .int64 )
351351
352352 def get_refed_tokens_num (self ):
353353 return self .refed_tokens_num .arr [0 ]
@@ -360,115 +360,16 @@ def get_unrefed_tokens_num(self):
360360
361361
362362class RadixCacheReadOnlyClient :
363- def __init__ (self , unique_name , total_token_num , tp_size ):
364- self .tp_clients : List [_RadixCacheReadOnlyClient ] = [
365- _RadixCacheReadOnlyClient (unique_name , total_token_num , tp_id ) for tp_id in range (tp_size )
363+ def __init__ (self , unique_name , total_token_num , node_world_size , dp_world_size ):
364+ self .dp_rank_clients : List [_RadixCacheReadOnlyClient ] = [
365+ _RadixCacheReadOnlyClient (unique_name , total_token_num , rank_in_node ) for rank_in_node in range (0 , node_world_size , dp_world_size )
366366 ]
367367
368- def get_refed_tokens_num (self , index ):
369- return self .tp_clients [ index ].get_refed_tokens_num ()
368+ def get_refed_tokens_num (self , dp_rank_in_node ):
369+ return self .dp_rank_clients [ dp_rank_in_node ].get_refed_tokens_num ()
370370
371- def get_tree_total_tokens_num (self , index ):
372- return self .tp_clients [ index ].get_tree_total_tokens_num ()
371+ def get_tree_total_tokens_num (self , dp_rank_in_node ):
372+ return self .dp_rank_clients [ dp_rank_in_node ].get_tree_total_tokens_num ()
373373
374- def get_unrefed_tokens_num (self , index ):
375- return self .tp_clients [index ].get_unrefed_tokens_num ()
376-
377-
378- # ///////////////////////////////////////////////////////////////////////////////
379-
380- if __name__ == "__main__" :
381- # test 1
382- def test1 ():
383- tree = RadixCache ("unique_name" , 100 , 0 )
384- ans = tree .insert (torch .tensor ([0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], dtype = torch .int64 , device = "cpu" ))
385- assert ans == 0
386- tree .print_self ()
387- ans = tree .insert (torch .tensor ([0 , 1 , 2 , 3 , 4 , 7 , 8 , 9 ], dtype = torch .int64 , device = "cpu" ))
388- assert ans == 5
389- tree .print_self ()
390- ans = tree .insert (torch .tensor ([0 , 1 , 2 , 3 , 4 , 7 , 8 , 9 ], dtype = torch .int64 , device = "cpu" ))
391- assert ans == 8
392- tree .print_self ()
393-
394- assert tree .get_refed_tokens_num () == 0
395- assert tree .get_tree_total_tokens_num () == 13
396-
397- # print("evict")
398- tree .evict (9 , lambda x : x )
399- tree .print_self ()
400- assert tree .get_refed_tokens_num () == 0 and tree .get_tree_total_tokens_num () == 0
401-
402- test1 ()
403-
404- # test 2
405- def test2 ():
406- tree = RadixCache ("unique_name" , 100 , 1 )
407- ans = tree .insert (torch .tensor ([0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], dtype = torch .int64 , device = "cpu" ))
408- ans = tree .insert (torch .tensor ([0 , 1 , 2 , 3 , 4 , 7 , 8 , 9 ], dtype = torch .int64 , device = "cpu" ))
409- tree .print_self ()
410-
411- tree_node , size , values = tree .match_prefix (
412- torch .tensor ([0 , 1 , 2 , 3 , 4 ], dtype = torch .int64 , device = "cpu" ), update_refs = False
413- )
414- assert tree_node .node_prefix_total_len == 5 and size == 5 and len (values ) == 5
415- tree_node , size , values = tree .match_prefix (
416- torch .tensor ([0 , 1 , 2 , 3 , 4 , 9 ], dtype = torch .int64 , device = "cpu" ), update_refs = False
417- )
418- assert tree_node .node_prefix_total_len == 5 and size == 5 and len (values ) == 5
419- tree_node , size , values = tree .match_prefix (
420- torch .tensor ([0 , 1 , 2 , 3 , 4 , 7 , 8 ], dtype = torch .int64 , device = "cpu" ), update_refs = False
421- )
422- assert tree_node .node_prefix_total_len == 7 and size == 7 and len (values ) == 7
423- tree_node , size , values = tree .match_prefix (
424- torch .tensor ([0 , 1 , 2 , 3 , 4 , 7 , 9 ], dtype = torch .int64 , device = "cpu" ), update_refs = False
425- )
426- assert tree_node .node_prefix_total_len == 6 and size == 6 and len (values ) == 6
427- print (ans )
428- return
429-
430- # test2()
431-
432- # test 3
433- def test3 ():
434- tree = RadixCache ("unique_name" , 100 , 2 )
435- ans = tree .insert (torch .tensor ([0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], dtype = torch .int64 , device = "cpu" ))
436- ans = tree .insert (torch .tensor ([0 , 1 , 2 , 3 , 4 , 7 , 8 , 9 ], dtype = torch .int64 , device = "cpu" ))
437- tree .print_self ()
438-
439- tree_node , size , values = tree .match_prefix (
440- torch .tensor ([0 , 1 , 2 , 3 , 4 ], dtype = torch .int64 , device = "cpu" ), update_refs = True
441- )
442- assert tree_node .node_prefix_total_len == 5 and size == 5 and len (values ) == 5
443- assert tree .get_refed_tokens_num () == 5 and tree .get_tree_total_tokens_num () == 13
444-
445- tree_node , size , values = tree .match_prefix (
446- torch .tensor ([0 , 1 , 2 , 3 , 4 , 7 , 9 ], dtype = torch .int64 , device = "cpu" ), update_refs = True
447- )
448- assert tree_node .node_prefix_total_len == 6 and size == 6 and len (values ) == 6
449- assert tree .get_refed_tokens_num () == 6 and tree .get_tree_total_tokens_num () == 13
450-
451- tree .print_self ()
452- tree .evict (2 , lambda x : x )
453- assert tree .get_refed_tokens_num () == 6 and tree .get_tree_total_tokens_num () == 8
454- tree .print_self ()
455-
456- tree .dec_node_ref_counter (tree_node )
457- tree .print_self ()
458- print (ans )
459- return
460-
461- test3 ()
462-
463- def test4 ():
464-
465- tree = RadixCache ("unique_name" , 100 , 2 )
466- ans = tree .insert (torch .tensor ([0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], dtype = torch .int64 , device = "cpu" ))
467- ans = tree .insert (torch .tensor ([0 , 1 , 2 , 3 , 4 , 7 , 8 , 9 ], dtype = torch .int64 , device = "cpu" ))
468- tree .print_self ()
469-
470- tree .clear_tree_nodes ()
471- print (ans )
472- return
473-
474- test4 ()
374+ def get_unrefed_tokens_num (self , dp_rank_in_node ):
375+ return self .dp_rank_clients [dp_rank_in_node ].get_unrefed_tokens_num ()
0 commit comments