@@ -113,28 +113,16 @@ def _init_mem_manager(self):
113113 mem_fraction = self .mem_fraction ,
114114 )
115115 if self .enable_hiradix_cache :
116- from lightllm .common .radixmem_buffer import RadixMemoryBuffer , init_shared_data , get_shared_data , MemPropties
117- from lightllm .common .radixmem_manager import RadixBufferManager
116+ from lightllm .common .radixmem_buffer import get_shared_data , MemPropties
117+ from lightllm .common .radixmem_manager import build_radix_manager
118118 mem_propties = MemPropties (
119119 self .hiradix_cache_token_num ,
120120 dtype = self .data_type ,
121121 head_num = 1 ,
122122 head_dim = self .config ["kv_lora_rank" ] + self .config ["qk_rope_head_dim" ],
123123 layer_num = self .config ["num_hidden_layers" ] + added_mtp_layer_num ,
124124 )
125- init_shared_data (
126- mem_propties = mem_propties ,
127- device = "cpu" if not self .hiradix_cache_gpu else "cuda"
128- )
129- radix_mem_buffer = RadixMemoryBuffer (
130- mem_propties ,
131- shared_data = get_shared_data (),
132- lock = self .radix_lock ,
133- device = "cpu" if not self .hiradix_cache_gpu else "cuda"
134- )
135- self .radix_manager = RadixBufferManager (radix_buffer = radix_mem_buffer ,
136- radix_mem_data = get_shared_data (),
137- lock = self .radix_lock )
125+ self .radix_manager = build_radix_manager (mem_propties , self .hiradix_cache_gpu , self .radix_lock )
138126 self .mem_propties = mem_propties
139127 self .shared_mem_data = get_shared_data ()
140128 return
0 commit comments