@@ -212,6 +212,9 @@ def init_model(self, kvargs):
             else None
         )

+        if "prompt_cache_kv_buffer" in model_cfg:
+            self.preload_prompt_cache_kv_buffer(model_cfg)
+
         self.logger.info(f"loaded model class {self.model.__class__}")
         self.init_custom()

@@ -256,3 +259,28 @@ def _init_reqs(self, reqs: List[Tuple], init_req_obj=True):
         g_infer_state_lock.release()
         req_ids = [e[0] for e in reqs]
         return req_ids
+
+    def preload_prompt_cache_kv_buffer(self, model_cfg):
+        self.logger.info("Preload prompt cache kv buffer.")
+        cur_rank = dist.get_rank()
+        prompt_cache_kv_buffer_path = os.path.join(
+            self.weight_dir, model_cfg["prompt_cache_kv_buffer"][f"rank_{cur_rank}"]
+        )
+        prompt_cache_kv_buffer = torch.load(prompt_cache_kv_buffer_path, weights_only=True, map_location="cpu")
+        if isinstance(self.radix_cache.mem_manager.kv_buffer, list):
+            for i in range(len(self.radix_cache.mem_manager.kv_buffer)):
+                self.radix_cache.mem_manager.kv_buffer[i][: len(model_cfg["prompt_cache_token_ids"])].copy_(
+                    prompt_cache_kv_buffer[i]
+                )
+        else:
+            self.radix_cache.mem_manager.kv_buffer[:, : len(model_cfg["prompt_cache_token_ids"])].copy_(
+                prompt_cache_kv_buffer
+            )
+        self.radix_cache.insert(
+            torch.tensor(model_cfg["prompt_cache_token_ids"], dtype=torch.int64, device="cpu"),
+            torch.tensor(range(len(model_cfg["prompt_cache_token_ids"])), dtype=torch.int32, device="cpu"),
+        )
+        self.radix_cache.mem_manager.mem_state[: len(model_cfg["prompt_cache_token_ids"])] = 1
+        self.radix_cache.match_prefix(
+            torch.tensor(model_cfg["prompt_cache_token_ids"], dtype=torch.int64, device="cpu"), update_refs=True
+        )
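For context, the new method expects two extra entries in `model_cfg`: `prompt_cache_token_ids` (the token ids of the shared prompt prefix) and `prompt_cache_kv_buffer` (a per-rank mapping to KV tensor files under `weight_dir`). The PR does not include the export side, so the config keys' values, file names, and the `export_prompt_cache` helper below are illustrative assumptions only; this is a minimal sketch of how such per-rank files might be produced so that `torch.load` in `preload_prompt_cache_kv_buffer` can copy them back into the mem manager's buffer.

```python
# Hypothetical entries added to the model's config.json (key names match what
# preload_prompt_cache_kv_buffer reads; file names here are made up):
#   "prompt_cache_token_ids": [1, 3233, 751],
#   "prompt_cache_kv_buffer": {
#       "rank_0": "prompt_cache_rank_0.pt",
#       "rank_1": "prompt_cache_rank_1.pt"
#   }

import torch


def export_prompt_cache(kv_buffer, prefix_len, path):
    """Sketch: dump the first `prefix_len` KV slots of this rank's buffer so the
    loader above can restore them with `.copy_()`. Mirrors the list/tensor
    branches in preload_prompt_cache_kv_buffer."""
    if isinstance(kv_buffer, list):
        # Per-layer list of tensors shaped [num_tokens, num_heads, head_dim].
        saved = [layer[:prefix_len].detach().cpu().clone() for layer in kv_buffer]
    else:
        # Single packed tensor shaped [num_layers, num_tokens, num_heads, head_dim].
        saved = kv_buffer[:, :prefix_len].detach().cpu().clone()
    torch.save(saved, path)
```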