
Commit 789149f

add cpu embed to llm

1 parent b298257 commit 789149f

2 files changed: 15 additions, 1 deletion

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 10 additions & 1 deletion

@@ -18,6 +18,7 @@
 from lightllm.utils.custom_kernel_utis import custom_cat
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo
+from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
 
 logger = init_logger(__name__)
 
@@ -30,12 +31,19 @@ class InferenceContext:
     requests_mapping: Dict[int, "InferReq"] = None
     infer_req_ids = None
     vocab_size = None
+    cpu_embed_cache_client: Optional[CpuEmbedCacheClient] = None
 
     overlap_stream: torch.cuda.Stream = None  # async stream used by the inference process for asynchronous overlap operations in some cases
     cpu_kv_cache_stream: torch.cuda.Stream = None  # stream used for cpu kv cache operations
 
     def register(
-        self, backend, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int
+        self,
+        backend,
+        req_manager: ReqManager,
+        radix_cache: RadixCache,
+        shm_req_manager: ShmReqManager,
+        vocab_size: int,
+        cpu_embed_cache_client: Optional[CpuEmbedCacheClient] = None,
     ):
         self.args = get_env_start_args()
         from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend
@@ -50,6 +58,7 @@ def register(
         self.infer_req_ids = []
 
         self.vocab_size = vocab_size
+        self.cpu_embed_cache_client = cpu_embed_cache_client
         return
 
     def get_overlap_stream(self) -> torch.cuda.Stream:
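With this change, code running in the inference process can reach the cache through the registered context, but the field stays None unless the backend passes a client in, so callers need a None guard. A minimal sketch of that pattern, assuming g_infer_context is the module-level InferenceContext instance in infer_batch.py (the helper name is hypothetical; the client's read API is not shown in this diff):

from typing import Optional

from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
from lightllm.server.router.model_infer.infer_batch import g_infer_context


def require_cpu_embed_cache() -> CpuEmbedCacheClient:
    # Hypothetical helper: per this commit, the client is only registered
    # when the backend runs with multimodal support enabled, so callers
    # must handle the None case.
    client: Optional[CpuEmbedCacheClient] = g_infer_context.cpu_embed_cache_client
    if client is None:
        raise RuntimeError("cpu embed cache is only available in multimodal mode")
    return client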

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 5 additions & 0 deletions

@@ -39,6 +39,7 @@
 from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
 from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
 from .multi_level_kv_cache import MultiLevelKvCacheModule
+from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
 
 
 class ModeBackend:
@@ -179,12 +180,16 @@ def init_model(self, kvargs):
         self.preload_prompt_cache_kv_buffer(model_cfg)
 
         self.logger.info(f"loaded model class {self.model.__class__}")
+
         g_infer_context.register(
             backend=self,
             req_manager=self.model.req_manager,
             radix_cache=self.radix_cache,
             shm_req_manager=self.shm_req_manager,
             vocab_size=self.model.vocab_size,
+            cpu_embed_cache_client=CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False)
+            if self.args.enable_multimodal
+            else None,
         )
 
         # Initialize the communication tensors used in dp mode; not used when dp mode is disabled.
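The constructor flags suggest a shared-memory handshake: create_meta_data=False, init_shm_data=False would have the inference worker attach to cache metadata that some other process already created, rather than allocating it itself. A sketch of that assumed split (only the worker-side call appears in this diff; the owner-side flag values are an assumption):

from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient

# Assumed owner side: some server process creates the shared-memory
# metadata and initializes the backing data once at startup.
owner_client = CpuEmbedCacheClient(create_meta_data=True, init_shm_data=True)

# Worker side, as in ModeBackend.init_model: attach only, and only when
# multimodal inference is enabled.
enable_multimodal = True  # stands in for self.args.enable_multimodal
worker_client = (
    CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False)
    if enable_multimodal
    else None
)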
