|
1 | 1 | import torch |
2 | | -import ctypes |
3 | 2 | import dataclasses |
4 | | -import os |
5 | | -import threading |
6 | | -import time |
7 | | -import numpy as np |
8 | | -import triton |
9 | 3 | from functools import lru_cache |
10 | | -from lightllm.utils.envs_utils import get_env_start_args, enable_huge_page, get_llm_data_type |
| 4 | +from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type |
11 | 5 | from lightllm.utils.log_utils import init_logger |
12 | 6 | from lightllm.utils.config_utils import get_hidden_size |
13 | | -from typing import List, Tuple, Optional |
14 | | -from tqdm import tqdm |
15 | | -from lightllm.utils.auto_shm_cleanup import register_sysv_shm_for_cleanup |
16 | | -from lightllm.utils.dist_utils import get_current_device_id |
17 | 7 |
|
18 | 8 | logger = init_logger(__name__) |
19 | 9 |
|
@@ -64,190 +54,3 @@ def calcu_embed_cache_meta() -> "EmbedCacheMeta": |
64 | 54 | logger.info(f"embed cache token num: {embed_cache_meta_data.token_num}") |
65 | 55 |
|
66 | 56 | return embed_cache_meta_data |
67 | | - |
68 | | - |
69 | | -@lru_cache(maxsize=None) |
70 | | -def create_shm_embed_cache_ptr() -> int: |
71 | | - libc = ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libc.so.6", use_errno=True) |
72 | | - libc.shmget.argtypes = (ctypes.c_long, ctypes.c_size_t, ctypes.c_int) |
73 | | - libc.shmget.restype = ctypes.c_int |
74 | | - libc.shmat.argtypes = (ctypes.c_int, ctypes.c_void_p, ctypes.c_int) |
75 | | - libc.shmat.restype = ctypes.c_void_p |
76 | | - |
77 | | - args = get_env_start_args() |
78 | | - key = args.multi_modal_cache_shm_id |
79 | | - requested_size = calcu_embed_cache_meta().calcu_size() |
80 | | - use_hugetlb = enable_huge_page() |
81 | | - |
82 | | - # 计算大页大小(默认从 /proc/meminfo 读取 Hugepagesize) |
83 | | - def _get_default_hugepage_size() -> int: |
84 | | - try: |
85 | | - with open("/proc/meminfo", "r") as f: |
86 | | - for line in f: |
87 | | - if line.startswith("Hugepagesize:"): |
88 | | - parts = line.split() |
89 | | - if len(parts) >= 2: |
90 | | - kb = int(parts[1]) |
91 | | - return kb * 1024 |
92 | | - except Exception: |
93 | | - pass |
94 | | - return 2 * 1024 * 1024 # fallback 2MB |
95 | | - |
96 | | - shmflg = 0o666 | 0o1000 # 权限和 IPC_CREAT 标志 |
97 | | - if use_hugetlb: |
98 | | - # 向上对齐到大页大小 |
99 | | - huge_sz = _get_default_hugepage_size() |
100 | | - size_to_alloc = triton.cdiv(requested_size, huge_sz) * huge_sz |
101 | | - SHM_HUGETLB = 0o4000 |
102 | | - shmflg |= SHM_HUGETLB |
103 | | - logger.info( |
104 | | - f"Using SHM_HUGETLB, hugepage_size={huge_sz} bytes, requested={requested_size}, alloc={size_to_alloc}" |
105 | | - ) |
106 | | - else: |
107 | | - size_to_alloc = requested_size |
108 | | - logger.info(f"Using regular pages, requested={requested_size}, alloc={size_to_alloc}") |
109 | | - |
110 | | - shmid = libc.shmget(key, size_to_alloc, shmflg) |
111 | | - hugepages_num = (size_to_alloc + 1024 * 1024 * 1024 - 1) // (1024 * 1024 * 1024) |
112 | | - if shmid < 0: |
113 | | - err = ctypes.get_errno() |
114 | | - if use_hugetlb: |
115 | | - raise Exception( |
116 | | - f"shmget with SHM_HUGETLB failed (errno={err}). Falling back to regular pages." |
117 | | - f"You may need to configure hugepages manually, e.g.," |
118 | | - f"sudo sed -i 's/^GRUB_CMDLINE_LINUX=\"/& default_hugepagesz=1G \ |
119 | | - hugepagesz=1G hugepages={hugepages_num}/' /etc/default/grub" |
120 | | - f"sudo update-grub" |
121 | | - f"sudo reboot" |
122 | | - ) |
123 | | - else: |
124 | | - raise Exception(f"Error creating regular shared memory (errno={err})") |
125 | | - |
126 | | - register_sysv_shm_for_cleanup(key, shmid) |
127 | | - logger.info(f"Shared memory ID: {shmid}") |
128 | | - |
129 | | - # 附加共享内存 |
130 | | - shm_addr = libc.shmat(shmid, ctypes.c_void_p(0), 0) |
131 | | - if shm_addr == ctypes.c_void_p(-1).value: |
132 | | - raise Exception("Error attaching shared memory") |
133 | | - logger.info(f"Shared cpu kv cache tensor memory at address: {shm_addr}") |
134 | | - |
135 | | - # Best-effort memory prefaulting in background to speed up subsequent cudaHostRegister |
136 | | - def _pre_warm_memory(): |
137 | | - page_size = _get_default_hugepage_size() if use_hugetlb else 4096 |
138 | | - arr = np.ctypeslib.as_array(ctypes.cast(shm_addr, ctypes.POINTER(ctypes.c_uint8)), shape=(size_to_alloc,)) |
139 | | - volatile_sum = int(arr[::page_size].sum()) |
140 | | - logger.info(f"pre warmed shared memory pages successfully, checksum={volatile_sum})") |
141 | | - |
142 | | - th = threading.Thread(target=_pre_warm_memory, name="cpu_cache_pre_warm", daemon=True) |
143 | | - th.start() |
144 | | - |
145 | | - return shm_addr |
146 | | - |
147 | | - |
148 | | -@lru_cache(maxsize=None) |
149 | | -def register_shm_ptr_to_pin(shm_ptr: int, size: int) -> "AsyncRegistrationHandle": |
150 | | - """Start async cudaHostRegister on the given [shm_ptr, shm_ptr+size) and return a handle.""" |
151 | | - chunk_bytes = 128 * 1024 * 1024 # 128M性能最好 |
152 | | - tasks: list[tuple[int, int]] = [] |
153 | | - offset = 0 |
154 | | - while offset < size: |
155 | | - seg_len = min(chunk_bytes, size - offset) |
156 | | - tasks.append((offset, seg_len)) |
157 | | - offset += seg_len |
158 | | - |
159 | | - handle = AsyncRegistrationHandle(total_tasks=len(tasks)) |
160 | | - |
161 | | - def _worker(): |
162 | | - cuda = ctypes.CDLL("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so") |
163 | | - cuda.cudaHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint] |
164 | | - cuda.cudaHostRegister.restype = ctypes.c_int |
165 | | - cuda.cudaHostGetDevicePointer.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p, ctypes.c_int] |
166 | | - cuda.cudaHostGetDevicePointer.restype = ctypes.c_int |
167 | | - |
168 | | - cudaHostRegisterFlag = 3 |
169 | | - |
170 | | - torch.cuda.set_device(get_current_device_id()) |
171 | | - # TODO 这个地方的分块注册是否具备合法性和合理性。 |
172 | | - for offset, seg_len in tasks: |
173 | | - ptr = ctypes.c_void_p(shm_ptr + offset) |
174 | | - r = cuda.cudaHostRegister(ptr, ctypes.c_size_t(seg_len), cudaHostRegisterFlag) |
175 | | - if r != 0: |
176 | | - raise Exception(f"cudaHostRegister failed with error code {r}, prefer to use hugetlb") |
177 | | - handle.task_count += 1 |
178 | | - |
179 | | - device_ptr = ctypes.c_void_p() |
180 | | - host_ptr = ctypes.c_void_p(shm_ptr) |
181 | | - res = cuda.cudaHostGetDevicePointer(ctypes.byref(device_ptr), host_ptr, 0) |
182 | | - if res != 0: |
183 | | - raise Exception(f"cudaHostGetDevicePointer failed with error code {res}") |
184 | | - assert host_ptr.value == device_ptr.value |
185 | | - handle.tasks_finished.set() |
186 | | - |
187 | | - th = threading.Thread(target=_worker, name="cpu_cache_register", daemon=True) |
188 | | - handle.thread = th |
189 | | - th.start() |
190 | | - return handle |
191 | | - |
192 | | - |
193 | | -class AsyncRegistrationHandle: |
194 | | - """A handle for async host memory registration. |
195 | | -
|
196 | | - - wait(): blocks until registration finishes, prints tqdm progress, and returns device pointer (int). |
197 | | - """ |
198 | | - |
199 | | - def __init__(self, total_tasks: int): |
200 | | - self.total_tasks = total_tasks |
201 | | - self.task_count = 0 |
202 | | - self.thread: Optional[threading.Thread] = None |
203 | | - self.tasks_finished = threading.Event() |
204 | | - |
205 | | - def wait(self): |
206 | | - """Block until the async registration completes. Only here we print tqdm progress.""" |
207 | | - last_count = 0 |
208 | | - desc = f"pid {os.getpid()} Registering pinned host memory (async)" |
209 | | - with tqdm(total=self.total_tasks, desc=desc) as pbar: |
210 | | - while not self.tasks_finished.is_set(): |
211 | | - cur = self.task_count |
212 | | - if cur > last_count: |
213 | | - pbar.update(cur - last_count) |
214 | | - last_count = cur |
215 | | - time.sleep(0.01) |
216 | | - # final update |
217 | | - cur = self.task_count |
218 | | - if cur > last_count: |
219 | | - pbar.update(cur - last_count) |
220 | | - last_count = cur |
221 | | - |
222 | | - if self.thread is not None and self.thread.is_alive(): |
223 | | - self.thread.join() |
224 | | - |
225 | | - return |
226 | | - |
227 | | - |
228 | | -@lru_cache(maxsize=None) |
229 | | -def attach_shm_kv_cache_ptr() -> int: |
230 | | - libc = ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libc.so.6", use_errno=True) |
231 | | - libc.shmget.argtypes = (ctypes.c_long, ctypes.c_size_t, ctypes.c_int) |
232 | | - libc.shmget.restype = ctypes.c_int |
233 | | - libc.shmat.argtypes = (ctypes.c_int, ctypes.c_void_p, ctypes.c_int) |
234 | | - libc.shmat.restype = ctypes.c_void_p |
235 | | - |
236 | | - # Try to locate an existing SHM without creating a new one |
237 | | - args = get_env_start_args() |
238 | | - key = args.multi_modal_cache_shm_id |
239 | | - shmid = libc.shmget(key, 0, 0) |
240 | | - if shmid < 0: |
241 | | - size = calcu_embed_cache_meta().calcu_size() |
242 | | - shmid = libc.shmget(key, size, 0) |
243 | | - if shmid < 0: |
244 | | - err = ctypes.get_errno() |
245 | | - raise Exception(f"Error locating existing shared memory (errno={err})") |
246 | | - |
247 | | - shm_addr = libc.shmat(shmid, ctypes.c_void_p(0), 0) |
248 | | - if shm_addr == ctypes.c_void_p(-1).value: |
249 | | - err = ctypes.get_errno() |
250 | | - raise Exception(f"Error attaching shared memory (errno={err})") |
251 | | - |
252 | | - logger.info(f"Attached to SHM key={key}, shmid={shmid}, addr={shm_addr}") |
253 | | - return shm_addr |
0 commit comments