Commit e12f8d6

Author: wangzaijun (committed)

remove dep code

1 parent 8b528a9 · commit e12f8d6

File tree

4 files changed: +56, -252 lines changed


lightllm/server/embed_cache/embed_cache_client.py

Lines changed: 10 additions & 8 deletions
@@ -5,12 +5,8 @@
 from typing import Optional, List
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.log_utils import init_logger
-from lightllm.utils.embed_utils import (
-    calcu_embed_cache_meta,
-    create_shm_embed_cache_ptr,
-    attach_shm_kv_cache_ptr,
-    register_shm_ptr_to_pin,
-)
+from lightllm.utils.embed_utils import calcu_embed_cache_meta
+from lightllm.utils.kv_cache_utils import create_shm_kv_cache_ptr, attach_shm_kv_cache_ptr, register_shm_ptr_to_pin
 
 logger = init_logger(__name__)
 
@@ -52,7 +48,11 @@ def copy_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: int):
         )
 
     def _create_shm_embed_kv_cache(self):
-        shm_ptr = create_shm_embed_cache_ptr()
+        shm_ptr = create_shm_kv_cache_ptr(
+            key=self.args.multi_modal_cache_shm_id, size=self.embed_cache_tensor_meta.calcu_size()
+        )
+        handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.embed_cache_tensor_meta.calcu_size())
+        handle.wait()
         numpy_array = np.frombuffer(
             memoryview((ctypes.c_uint8 * self.embed_cache_tensor_meta.calcu_size()).from_address(shm_ptr)),
             dtype=np.uint8,
@@ -69,7 +69,9 @@ def _create_shm_embed_kv_cache(self):
         return
 
     def _attach_shm_cpu_embed_cache(self):
-        shm_ptr = attach_shm_kv_cache_ptr()
+        shm_ptr = attach_shm_kv_cache_ptr(
+            key=self.args.multi_modal_cache_shm_id, size=self.embed_cache_tensor_meta.calcu_size()
+        )
         handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.embed_cache_tensor_meta.calcu_size())
         handle.wait()
         numpy_array = np.frombuffer(
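For reference, the zero-copy wrapping that the unchanged context lines perform can be sketched stand-alone. This is only a minimal illustration: a locally allocated ctypes buffer stands in for the real SysV segment, and the size is a made-up value in place of embed_cache_tensor_meta.calcu_size().

import ctypes

import numpy as np

size = 1024  # placeholder; the real code uses self.embed_cache_tensor_meta.calcu_size()

# Stand-in for the shared-memory mapping returned by create_shm_kv_cache_ptr().
buf = (ctypes.c_uint8 * size)()
shm_ptr = ctypes.addressof(buf)

# Zero-copy uint8 view over the raw address, as in the context lines above.
numpy_array = np.frombuffer(
    memoryview((ctypes.c_uint8 * size).from_address(shm_ptr)),
    dtype=np.uint8,
)
numpy_array[0] = 42
assert buf[0] == 42  # the view aliases the same memory, no copy is made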

lightllm/server/multi_level_kv_cache/cpu_cache_client.py

Lines changed: 6 additions & 2 deletions
@@ -275,7 +275,9 @@ def _create_cpu_status_list(self, init_shm_data: bool):
         return
 
     def _create_shm_cpu_kv_cache(self):
-        shm_ptr = create_shm_kv_cache_ptr()
+        shm_ptr = create_shm_kv_cache_ptr(
+            key=self.args.cpu_kv_cache_shm_id, size=self.kv_cache_tensor_meta.calcu_size()
+        )
         numpy_array = np.frombuffer(
             memoryview((ctypes.c_uint8 * self.kv_cache_tensor_meta.calcu_size()).from_address(shm_ptr)), dtype=np.uint8
         )
@@ -293,7 +295,9 @@ def _create_shm_cpu_kv_cache(self):
         return
 
     def _attach_shm_cpu_kv_cache(self):
-        shm_ptr = attach_shm_kv_cache_ptr()
+        shm_ptr = attach_shm_kv_cache_ptr(
+            key=self.args.cpu_kv_cache_shm_id, size=self.kv_cache_tensor_meta.calcu_size()
+        )
         handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.kv_cache_tensor_meta.calcu_size())
         numpy_array = np.frombuffer(
             memoryview((ctypes.c_uint8 * self.kv_cache_tensor_meta.calcu_size()).from_address(shm_ptr)), dtype=np.uint8
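Both call sites now follow the same create/attach-then-pin pattern against lightllm.utils.kv_cache_utils. A hedged usage sketch of that pattern follows; the key and size are placeholders, and a fully initialized lightllm server environment is assumed.

from lightllm.utils.kv_cache_utils import (
    attach_shm_kv_cache_ptr,
    create_shm_kv_cache_ptr,
    register_shm_ptr_to_pin,
)

shm_key = 0x5A01            # placeholder; the real code passes args.cpu_kv_cache_shm_id
nbytes = 256 * 1024 * 1024  # placeholder; the real code uses kv_cache_tensor_meta.calcu_size()

# Creator side: allocate (or reuse) the SysV segment identified by the key.
shm_ptr = create_shm_kv_cache_ptr(key=shm_key, size=nbytes)

# Attacher side: map the already-created segment, then pin it for CUDA access.
shm_ptr = attach_shm_kv_cache_ptr(key=shm_key, size=nbytes)
handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=nbytes)
handle.wait()  # blocks until the asynchronous chunked cudaHostRegister completes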

lightllm/utils/embed_utils.py

Lines changed: 1 addition & 198 deletions
@@ -1,19 +1,9 @@
 import torch
-import ctypes
 import dataclasses
-import os
-import threading
-import time
-import numpy as np
-import triton
 from functools import lru_cache
-from lightllm.utils.envs_utils import get_env_start_args, enable_huge_page, get_llm_data_type
+from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.config_utils import get_hidden_size
-from typing import List, Tuple, Optional
-from tqdm import tqdm
-from lightllm.utils.auto_shm_cleanup import register_sysv_shm_for_cleanup
-from lightllm.utils.dist_utils import get_current_device_id
 
 logger = init_logger(__name__)
 
@@ -64,190 +54,3 @@ def calcu_embed_cache_meta() -> "EmbedCacheMeta":
     logger.info(f"embed cache token num: {embed_cache_meta_data.token_num}")
 
     return embed_cache_meta_data
-
-
-@lru_cache(maxsize=None)
-def create_shm_embed_cache_ptr() -> int:
-    libc = ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libc.so.6", use_errno=True)
-    libc.shmget.argtypes = (ctypes.c_long, ctypes.c_size_t, ctypes.c_int)
-    libc.shmget.restype = ctypes.c_int
-    libc.shmat.argtypes = (ctypes.c_int, ctypes.c_void_p, ctypes.c_int)
-    libc.shmat.restype = ctypes.c_void_p
-
-    args = get_env_start_args()
-    key = args.multi_modal_cache_shm_id
-    requested_size = calcu_embed_cache_meta().calcu_size()
-    use_hugetlb = enable_huge_page()
-
-    # compute the hugepage size (read Hugepagesize from /proc/meminfo by default)
-    def _get_default_hugepage_size() -> int:
-        try:
-            with open("/proc/meminfo", "r") as f:
-                for line in f:
-                    if line.startswith("Hugepagesize:"):
-                        parts = line.split()
-                        if len(parts) >= 2:
-                            kb = int(parts[1])
-                            return kb * 1024
-        except Exception:
-            pass
-        return 2 * 1024 * 1024  # fallback 2MB
-
-    shmflg = 0o666 | 0o1000  # permission bits and the IPC_CREAT flag
-    if use_hugetlb:
-        # round the size up to a multiple of the hugepage size
-        huge_sz = _get_default_hugepage_size()
-        size_to_alloc = triton.cdiv(requested_size, huge_sz) * huge_sz
-        SHM_HUGETLB = 0o4000
-        shmflg |= SHM_HUGETLB
-        logger.info(
-            f"Using SHM_HUGETLB, hugepage_size={huge_sz} bytes, requested={requested_size}, alloc={size_to_alloc}"
-        )
-    else:
-        size_to_alloc = requested_size
-        logger.info(f"Using regular pages, requested={requested_size}, alloc={size_to_alloc}")
-
-    shmid = libc.shmget(key, size_to_alloc, shmflg)
-    hugepages_num = (size_to_alloc + 1024 * 1024 * 1024 - 1) // (1024 * 1024 * 1024)
-    if shmid < 0:
-        err = ctypes.get_errno()
-        if use_hugetlb:
-            raise Exception(
-                f"shmget with SHM_HUGETLB failed (errno={err}). Falling back to regular pages."
-                f"You may need to configure hugepages manually, e.g.,"
-                f"sudo sed -i 's/^GRUB_CMDLINE_LINUX=\"/& default_hugepagesz=1G \
-hugepagesz=1G hugepages={hugepages_num}/' /etc/default/grub"
-                f"sudo update-grub"
-                f"sudo reboot"
-            )
-        else:
-            raise Exception(f"Error creating regular shared memory (errno={err})")
-
-    register_sysv_shm_for_cleanup(key, shmid)
-    logger.info(f"Shared memory ID: {shmid}")
-
-    # attach the shared memory
-    shm_addr = libc.shmat(shmid, ctypes.c_void_p(0), 0)
-    if shm_addr == ctypes.c_void_p(-1).value:
-        raise Exception("Error attaching shared memory")
-    logger.info(f"Shared cpu kv cache tensor memory at address: {shm_addr}")
-
-    # Best-effort memory prefaulting in background to speed up subsequent cudaHostRegister
-    def _pre_warm_memory():
-        page_size = _get_default_hugepage_size() if use_hugetlb else 4096
-        arr = np.ctypeslib.as_array(ctypes.cast(shm_addr, ctypes.POINTER(ctypes.c_uint8)), shape=(size_to_alloc,))
-        volatile_sum = int(arr[::page_size].sum())
-        logger.info(f"pre warmed shared memory pages successfully, checksum={volatile_sum})")
-
-    th = threading.Thread(target=_pre_warm_memory, name="cpu_cache_pre_warm", daemon=True)
-    th.start()
-
-    return shm_addr
-
-
-@lru_cache(maxsize=None)
-def register_shm_ptr_to_pin(shm_ptr: int, size: int) -> "AsyncRegistrationHandle":
-    """Start async cudaHostRegister on the given [shm_ptr, shm_ptr+size) and return a handle."""
-    chunk_bytes = 128 * 1024 * 1024  # 128 MB chunks perform best
-    tasks: list[tuple[int, int]] = []
-    offset = 0
-    while offset < size:
-        seg_len = min(chunk_bytes, size - offset)
-        tasks.append((offset, seg_len))
-        offset += seg_len
-
-    handle = AsyncRegistrationHandle(total_tasks=len(tasks))
-
-    def _worker():
-        cuda = ctypes.CDLL("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so")
-        cuda.cudaHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint]
-        cuda.cudaHostRegister.restype = ctypes.c_int
-        cuda.cudaHostGetDevicePointer.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p, ctypes.c_int]
-        cuda.cudaHostGetDevicePointer.restype = ctypes.c_int
-
-        cudaHostRegisterFlag = 3
-
-        torch.cuda.set_device(get_current_device_id())
-        # TODO: check whether chunked registration here is valid and reasonable.
-        for offset, seg_len in tasks:
-            ptr = ctypes.c_void_p(shm_ptr + offset)
-            r = cuda.cudaHostRegister(ptr, ctypes.c_size_t(seg_len), cudaHostRegisterFlag)
-            if r != 0:
-                raise Exception(f"cudaHostRegister failed with error code {r}, prefer to use hugetlb")
-            handle.task_count += 1
-
-        device_ptr = ctypes.c_void_p()
-        host_ptr = ctypes.c_void_p(shm_ptr)
-        res = cuda.cudaHostGetDevicePointer(ctypes.byref(device_ptr), host_ptr, 0)
-        if res != 0:
-            raise Exception(f"cudaHostGetDevicePointer failed with error code {res}")
-        assert host_ptr.value == device_ptr.value
-        handle.tasks_finished.set()
-
-    th = threading.Thread(target=_worker, name="cpu_cache_register", daemon=True)
-    handle.thread = th
-    th.start()
-    return handle
-
-
-class AsyncRegistrationHandle:
-    """A handle for async host memory registration.
-
-    - wait(): blocks until registration finishes, prints tqdm progress, and returns device pointer (int).
-    """
-
-    def __init__(self, total_tasks: int):
-        self.total_tasks = total_tasks
-        self.task_count = 0
-        self.thread: Optional[threading.Thread] = None
-        self.tasks_finished = threading.Event()
-
-    def wait(self):
-        """Block until the async registration completes. Only here we print tqdm progress."""
-        last_count = 0
-        desc = f"pid {os.getpid()} Registering pinned host memory (async)"
-        with tqdm(total=self.total_tasks, desc=desc) as pbar:
-            while not self.tasks_finished.is_set():
-                cur = self.task_count
-                if cur > last_count:
-                    pbar.update(cur - last_count)
-                    last_count = cur
-                time.sleep(0.01)
-            # final update
-            cur = self.task_count
-            if cur > last_count:
-                pbar.update(cur - last_count)
-                last_count = cur
-
-        if self.thread is not None and self.thread.is_alive():
-            self.thread.join()
-
-        return
-
-
-@lru_cache(maxsize=None)
-def attach_shm_kv_cache_ptr() -> int:
-    libc = ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libc.so.6", use_errno=True)
-    libc.shmget.argtypes = (ctypes.c_long, ctypes.c_size_t, ctypes.c_int)
-    libc.shmget.restype = ctypes.c_int
-    libc.shmat.argtypes = (ctypes.c_int, ctypes.c_void_p, ctypes.c_int)
-    libc.shmat.restype = ctypes.c_void_p
-
-    # Try to locate an existing SHM without creating a new one
-    args = get_env_start_args()
-    key = args.multi_modal_cache_shm_id
-    shmid = libc.shmget(key, 0, 0)
-    if shmid < 0:
-        size = calcu_embed_cache_meta().calcu_size()
-        shmid = libc.shmget(key, size, 0)
-        if shmid < 0:
-            err = ctypes.get_errno()
-            raise Exception(f"Error locating existing shared memory (errno={err})")
-
-    shm_addr = libc.shmat(shmid, ctypes.c_void_p(0), 0)
-    if shm_addr == ctypes.c_void_p(-1).value:
-        err = ctypes.get_errno()
-        raise Exception(f"Error attaching shared memory (errno={err})")
-
-    logger.info(f"Attached to SHM key={key}, shmid={shmid}, addr={shm_addr}")
-    return shm_addr
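The removed helpers were embed-cache specific only in how they obtained the key and size; the kv_cache_utils replacements take both as parameters. The sketch below parameterizes the deleted attach logic to illustrate that idea (attach_sysv_shm is a hypothetical name, not the actual kv_cache_utils implementation).

import ctypes
import ctypes.util


def attach_sysv_shm(key: int, size: int) -> int:
    """Attach to an existing SysV shared-memory segment and return its address."""
    libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
    libc.shmget.argtypes = (ctypes.c_long, ctypes.c_size_t, ctypes.c_int)
    libc.shmget.restype = ctypes.c_int
    libc.shmat.argtypes = (ctypes.c_int, ctypes.c_void_p, ctypes.c_int)
    libc.shmat.restype = ctypes.c_void_p

    # Locate the existing segment without creating a new one.
    shmid = libc.shmget(key, 0, 0)
    if shmid < 0:
        shmid = libc.shmget(key, size, 0)
        if shmid < 0:
            raise Exception(f"Error locating existing shared memory (errno={ctypes.get_errno()})")

    shm_addr = libc.shmat(shmid, ctypes.c_void_p(0), 0)
    if shm_addr == ctypes.c_void_p(-1).value:
        raise Exception(f"Error attaching shared memory (errno={ctypes.get_errno()})")
    return shm_addr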
