Skip to content

Commit 053c922

Browse files
author
none
committed
Fix: bind the CUDA device pointer (from cudaHostGetDevicePointer) to the pinned CPU KV-cache tensor instead of the raw shared-memory host pointer; initialize the multi-level CPU cache module before starting the infer-loop threads; add barriers after NCCL sync-group creation.
1 parent f4cdbed commit 053c922

File tree

4 files changed

+25
-7
lines changed

4 files changed

+25
-7
lines changed

lightllm/server/multi_level_kv_cache/cpu_cache_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def _create_shm_cpu_kv_cache(self):
213213

214214
def _attach_shm_cpu_kv_cache(self):
215215
shm_ptr = attach_shm_kv_cache_ptr()
216-
register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.kv_cache_tensor_meta.calcu_size())
216+
device_ptr = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.kv_cache_tensor_meta.calcu_size())
217217
shape = (
218218
self.kv_cache_tensor_meta.page_num,
219219
self.kv_cache_tensor_meta.layer_num,
@@ -223,7 +223,7 @@ def _attach_shm_cpu_kv_cache(self):
223223
)
224224
self.cpu_kv_cache_tensor = torch.empty(size=shape, dtype=torch.bfloat16, device="meta")
225225
# 将指针绑定到 tensor上,方便triton获取真实的地址。
226-
self.cpu_kv_cache_tensor.data_ptr = lambda: shm_ptr
226+
self.cpu_kv_cache_tensor.data_ptr = lambda: device_ptr
227227
return
228228

229229

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,16 +192,16 @@ def init_model(self, kvargs):
192192
# 开启 mtp 模式,需要完成mtp model的初始化
193193
if self.args.mtp_mode:
194194
self.init_mtp_draft_model(kvargs)
195+
196+
if self.args.enable_cpu_cache:
197+
self.multi_level_cache_module = MultiLevelKvCacheModule(self)
195198

196199
# 启动infer_loop_thread, 启动两个线程进行推理,对于具备双batch推理折叠得场景
197200
# 可以降低 cpu overhead,大幅提升gpu得使用率。
198201
self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True)
199202
self.infer_loop_thread.start()
200203
self.infer_loop_thread1 = threading.Thread(target=self.infer_loop, daemon=True)
201204
self.infer_loop_thread1.start()
202-
203-
if self.args.enable_cpu_cache:
204-
self.multi_level_cache_module = MultiLevelKvCacheModule(self)
205205
return
206206

207207
def init_custom(self):

lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ def __init__(self, backend):
2121
self.gloo_group = create_new_group_for_current_dp("gloo")
2222
self.filter_group = create_new_group_for_current_dp("gloo")
2323
self.sync_group = create_new_group_for_current_dp("nccl")
24+
dist.barrier(group=self.sync_group)
2425
self.init_sync_group = create_new_group_for_current_dp("nccl")
26+
dist.barrier(group=self.init_sync_group)
27+
2528

2629
self.cpu_cache_handle_queue: Deque[TransTask] = deque()
2730
self.cpu_cache_client = CpuKvCacheClient(init_shm_data=False)

lightllm/utils/kv_cache_utils.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,18 @@ def calcu_size(self):
148148
return self.page_num * self.layer_num * self.token_page_size * self.num_heads * self.head_dim * self.item_size
149149

150150

151-
def register_shm_ptr_to_pin(shm_ptr: int, size: int):
151+
def register_shm_ptr_to_pin(shm_ptr: int, size: int) -> int:
152152
# 加载 CUDA 库
153153
cuda = ctypes.CDLL("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so") # Linux 下的 CUDA 库路径
154154

155155
# 定义 cudaHostRegister 函数的参数和返回类型
156156
cuda.cudaHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint]
157157
cuda.cudaHostRegister.restype = ctypes.c_int
158158

159+
# 定义 cudaHostGetDevicePointer 函数原型
160+
cuda.cudaHostGetDevicePointer.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p, ctypes.c_int]
161+
cuda.cudaHostGetDevicePointer.restype = ctypes.c_int
162+
159163
# 定义常量
160164
cudaHostRegisterDefault = 0 # 默认注册标志
161165

@@ -166,4 +170,15 @@ def register_shm_ptr_to_pin(shm_ptr: int, size: int):
166170
raise Exception(f"Error registering host memory: {result}")
167171
else:
168172
logger.info("Host memory registered successfully.")
169-
return
173+
174+
device_ptr = ctypes.c_void_p() # 输出设备指针
175+
host_ptr = ctypes.c_void_p(shm_ptr) # 输入主机指针
176+
177+
result = cuda.cudaHostGetDevicePointer(ctypes.byref(device_ptr), host_ptr, 0)
178+
179+
if result != 0:
180+
raise RuntimeError(f"cudaHostGetDevicePointer failed with error code {result}")
181+
182+
logger.info(f"get Host memory registered Device ptr {device_ptr.value}")
183+
184+
return device_ptr.value

0 commit comments

Comments
 (0)