@@ -18,6 +18,7 @@
 from lightllm.utils.custom_kernel_utis import custom_cat
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo
+from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
 
 logger = init_logger(__name__)
 
@@ -30,12 +31,19 @@ class InferenceContext:
     requests_mapping: Dict[int, "InferReq"] = None
     infer_req_ids = None
     vocab_size = None
+    cpu_embed_cache_client: Optional[CpuEmbedCacheClient] = None
 
     overlap_stream: torch.cuda.Stream = None  # async stream used by the inference process for overlapped operations in some cases
     cpu_kv_cache_stream: torch.cuda.Stream = None  # stream used for CPU KV cache operations
 
     def register(
-        self, backend, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int
+        self,
+        backend,
+        req_manager: ReqManager,
+        radix_cache: RadixCache,
+        shm_req_manager: ShmReqManager,
+        vocab_size: int,
+        cpu_embed_cache_client: Optional[CpuEmbedCacheClient] = None,
     ):
         self.args = get_env_start_args()
         from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend
@@ -50,6 +58,7 @@ def register(
         self.infer_req_ids = []
 
         self.vocab_size = vocab_size
+        self.cpu_embed_cache_client = cpu_embed_cache_client
         return
 
     def get_overlap_stream(self) -> torch.cuda.Stream:
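
For context, a hedged sketch of how a caller might thread the new optional `cpu_embed_cache_client` argument into `register()`. Only `CpuEmbedCacheClient` and the `register()` signature come from this diff; everything else here (the `enable_cpu_embed_cache` flag, the constructor arguments, the `init_infer_context` helper) is an assumption for illustration:

```python
from typing import Optional

# Hedged wiring sketch, not code from this PR. CpuEmbedCacheClient is the
# class the diff imports; the flag name and constructor call are assumed.
from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient


def init_infer_context(g_infer_context, backend, req_manager, radix_cache,
                       shm_req_manager, vocab_size, args) -> None:
    # Build the CPU embed cache client only when the feature is enabled;
    # otherwise pass None, which the new default parameter accepts, so
    # existing call sites that omit the argument keep working unchanged.
    cpu_embed_cache_client: Optional[CpuEmbedCacheClient] = None
    if getattr(args, "enable_cpu_embed_cache", False):  # assumed flag name
        cpu_embed_cache_client = CpuEmbedCacheClient()  # assumed constructor

    g_infer_context.register(
        backend=backend,
        req_manager=req_manager,
        radix_cache=radix_cache,
        shm_req_manager=shm_req_manager,
        vocab_size=vocab_size,
        cpu_embed_cache_client=cpu_embed_cache_client,  # may be None
    )
```

The `Optional[...] = None` default keeps the change backward compatible: the embed-cache path is opt-in, and `InferenceContext` simply records whatever client it is given.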