@@ -213,6 +213,9 @@ def init_model(self, kvargs):
             else None
         )
 
+        if "prompt_cache_kv_buffer" in model_cfg:
+            self.preload_prompt_cache_kv_buffer(model_cfg)
+
         self.logger.info(f"loaded model class {self.model.__class__}")
         self.init_custom()
 
@@ -313,3 +316,21 @@ def remove_batch(self, batch_id):
         del batch
         g_infer_state_lock.release()
         return
+
+    def preload_prompt_cache_kv_buffer(self, model_cfg):
+        self.logger.info("Preload prompt cache kv buffer.")
+        cur_rank = dist.get_rank()
+        prompt_cache_kv_buffer_path = os.path.join(
+            self.weight_dir, model_cfg["prompt_cache_kv_buffer"][f"rank_{cur_rank}"]
+        )
+        prompt_cache_kv_buffer = torch.load(prompt_cache_kv_buffer_path, weights_only=True, map_location="cpu")
+        if isinstance(self.radix_cache.mem_manager.kv_buffer, list):
+            for i in range(len(self.radix_cache.mem_manager.kv_buffer)):
+                self.radix_cache.mem_manager.kv_buffer[i].copy_(prompt_cache_kv_buffer[i])
+        else:
+            self.radix_cache.mem_manager.kv_buffer.copy_(prompt_cache_kv_buffer)
+        self.radix_cache.insert(
+            torch.tensor(model_cfg["prompt_cache_token_ids"], dtype=torch.int64, device="cpu"),
+            torch.tensor(range(len(model_cfg["prompt_cache_token_ids"])), dtype=torch.int32, device="cpu"),
+        )
+        self.radix_cache.mem_manager.mem_state[: len(model_cfg["prompt_cache_token_ids"])] = 1
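For reference, the new method expects two extra entries in the model config. A hypothetical sketch of what that excerpt could look like follows; the key names are taken from the diff, while the file names and token ids are illustrative and not part of this change:

```python
# Hypothetical excerpt of model_cfg (e.g. the model's config, loaded as a dict).
# Key names match what preload_prompt_cache_kv_buffer reads; the per-rank paths
# are resolved relative to self.weight_dir, one saved buffer per rank.
model_cfg = {
    "prompt_cache_kv_buffer": {
        "rank_0": "prompt_cache_kv_buffer_rank_0.pt",  # illustrative file name
        "rank_1": "prompt_cache_kv_buffer_rank_1.pt",
    },
    # Token ids of the cached prompt prefix; the diff inserts them into the
    # radix cache and marks the first len(...) KV slots as used.
    "prompt_cache_token_ids": [1, 3087, 893, 338],  # illustrative values
}
```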
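How the per-rank `.pt` files get produced is outside this diff. A minimal sketch, assuming `mem_manager.kv_buffer` already holds the prefix's KV states after a warm-up forward pass; the helper name `dump_prompt_cache_kv_buffer` is hypothetical:

```python
import os
import torch

def dump_prompt_cache_kv_buffer(mem_manager, rank: int, out_dir: str) -> str:
    """Save mem_manager.kv_buffer to a per-rank file that torch.load can read
    back with weights_only=True, mirroring the list-vs-tensor branches handled
    in preload_prompt_cache_kv_buffer."""
    kv = mem_manager.kv_buffer
    if isinstance(kv, list):  # one KV tensor per layer
        payload = [t.detach().cpu() for t in kv]
    else:  # single packed KV tensor
        payload = kv.detach().cpu()
    path = os.path.join(out_dir, f"prompt_cache_kv_buffer_rank_{rank}.pt")
    torch.save(payload, path)
    return path
```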