Skip to content

Commit ea16d49

Browse files
author
wangzaijun
committed
reformat
1 parent 2cf5c85 commit ea16d49

File tree

1 file changed

+37
-36
lines changed

1 file changed

+37
-36
lines changed

lightllm/utils/kv_cache_utils.py

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -44,41 +44,6 @@ def compute_token_list_hash(tokens: List[int], cpu_cache_token_page_size: int) -
4444
return chunks_hash_value
4545

4646

47-
class AsyncRegistrationHandle:
48-
"""A handle for async host memory registration.
49-
50-
- wait(): blocks until registration finishes, prints tqdm progress, and returns device pointer (int).
51-
"""
52-
53-
def __init__(self, total_tasks: int):
54-
self.total_tasks = total_tasks
55-
self.task_count = 0
56-
self.thread: Optional[threading.Thread] = None
57-
self.tasks_finished = threading.Event()
58-
59-
def wait(self):
60-
"""Block until the async registration completes. Only here we print tqdm progress."""
61-
last_count = 0
62-
desc = f"pid {os.getpid()} Registering pinned host memory (async)"
63-
with tqdm(total=self.total_tasks, desc=desc) as pbar:
64-
while not self.tasks_finished.is_set():
65-
cur = self.task_count
66-
if cur > last_count:
67-
pbar.update(cur - last_count)
68-
last_count = cur
69-
time.sleep(0.01)
70-
# final update
71-
cur = self.task_count
72-
if cur > last_count:
73-
pbar.update(cur - last_count)
74-
last_count = cur
75-
76-
if self.thread is not None and self.thread.is_alive():
77-
self.thread.join()
78-
79-
return
80-
81-
8247
@lru_cache(maxsize=None)
8348
def calcu_cpu_cache_meta() -> "CpuKVCacheMeta":
8449
args = get_env_start_args()
@@ -202,7 +167,7 @@ def calcu_size(self):
202167

203168

204169
@lru_cache(maxsize=None)
205-
def register_shm_ptr_to_pin(shm_ptr: int, size: int) -> AsyncRegistrationHandle:
170+
def register_shm_ptr_to_pin(shm_ptr: int, size: int) -> "AsyncRegistrationHandle":
206171
"""Start async cudaHostRegister on the given [shm_ptr, shm_ptr+size) and return a handle."""
207172
chunk_bytes = 128 * 1024 * 1024 # 128M性能最好
208173
tasks: list[tuple[int, int]] = []
@@ -224,6 +189,7 @@ def _worker():
224189
cudaHostRegisterDefault = 0
225190

226191
torch.cuda.set_device(get_current_device_id())
192+
# TODO 这个地方的分块注册是否具备合法性和合理性。
227193
for offset, seg_len in tasks:
228194
ptr = ctypes.c_void_p(shm_ptr + offset)
229195
r = cuda.cudaHostRegister(ptr, ctypes.c_size_t(seg_len), cudaHostRegisterDefault)
@@ -244,6 +210,41 @@ def _worker():
244210
return handle
245211

246212

213+
class AsyncRegistrationHandle:
214+
"""A handle for async host memory registration.
215+
216+
- wait(): blocks until registration finishes, prints tqdm progress, and returns device pointer (int).
217+
"""
218+
219+
def __init__(self, total_tasks: int):
220+
self.total_tasks = total_tasks
221+
self.task_count = 0
222+
self.thread: Optional[threading.Thread] = None
223+
self.tasks_finished = threading.Event()
224+
225+
def wait(self):
226+
"""Block until the async registration completes. Only here we print tqdm progress."""
227+
last_count = 0
228+
desc = f"pid {os.getpid()} Registering pinned host memory (async)"
229+
with tqdm(total=self.total_tasks, desc=desc) as pbar:
230+
while not self.tasks_finished.is_set():
231+
cur = self.task_count
232+
if cur > last_count:
233+
pbar.update(cur - last_count)
234+
last_count = cur
235+
time.sleep(0.01)
236+
# final update
237+
cur = self.task_count
238+
if cur > last_count:
239+
pbar.update(cur - last_count)
240+
last_count = cur
241+
242+
if self.thread is not None and self.thread.is_alive():
243+
self.thread.join()
244+
245+
return
246+
247+
247248
@lru_cache(maxsize=None)
248249
def attach_shm_kv_cache_ptr() -> int:
249250
libc = ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libc.so.6", use_errno=True)

0 commit comments

Comments
 (0)