@@ -44,41 +44,6 @@ def compute_token_list_hash(tokens: List[int], cpu_cache_token_page_size: int) -
4444 return chunks_hash_value
4545
4646
47- class AsyncRegistrationHandle :
48- """A handle for async host memory registration.
49-
50- - wait(): blocks until registration finishes, prints tqdm progress, and returns device pointer (int).
51- """
52-
53- def __init__ (self , total_tasks : int ):
54- self .total_tasks = total_tasks
55- self .task_count = 0
56- self .thread : Optional [threading .Thread ] = None
57- self .tasks_finished = threading .Event ()
58-
59- def wait (self ):
60- """Block until the async registration completes. Only here we print tqdm progress."""
61- last_count = 0
62- desc = f"pid { os .getpid ()} Registering pinned host memory (async)"
63- with tqdm (total = self .total_tasks , desc = desc ) as pbar :
64- while not self .tasks_finished .is_set ():
65- cur = self .task_count
66- if cur > last_count :
67- pbar .update (cur - last_count )
68- last_count = cur
69- time .sleep (0.01 )
70- # final update
71- cur = self .task_count
72- if cur > last_count :
73- pbar .update (cur - last_count )
74- last_count = cur
75-
76- if self .thread is not None and self .thread .is_alive ():
77- self .thread .join ()
78-
79- return
80-
81-
8247@lru_cache (maxsize = None )
8348def calcu_cpu_cache_meta () -> "CpuKVCacheMeta" :
8449 args = get_env_start_args ()
@@ -202,7 +167,7 @@ def calcu_size(self):
202167
203168
204169@lru_cache (maxsize = None )
205- def register_shm_ptr_to_pin (shm_ptr : int , size : int ) -> AsyncRegistrationHandle :
170+ def register_shm_ptr_to_pin (shm_ptr : int , size : int ) -> " AsyncRegistrationHandle" :
206171 """Start async cudaHostRegister on the given [shm_ptr, shm_ptr+size) and return a handle."""
207172 chunk_bytes = 128 * 1024 * 1024 # 128M性能最好
208173 tasks : list [tuple [int , int ]] = []
@@ -224,6 +189,7 @@ def _worker():
224189 cudaHostRegisterDefault = 0
225190
226191 torch .cuda .set_device (get_current_device_id ())
192+ # TODO 这个地方的分块注册是否具备合法性和合理性。
227193 for offset , seg_len in tasks :
228194 ptr = ctypes .c_void_p (shm_ptr + offset )
229195 r = cuda .cudaHostRegister (ptr , ctypes .c_size_t (seg_len ), cudaHostRegisterDefault )
@@ -244,6 +210,41 @@ def _worker():
244210 return handle
245211
246212
213+ class AsyncRegistrationHandle :
214+ """A handle for async host memory registration.
215+
216+ - wait(): blocks until registration finishes, prints tqdm progress, and returns device pointer (int).
217+ """
218+
219+ def __init__ (self , total_tasks : int ):
220+ self .total_tasks = total_tasks
221+ self .task_count = 0
222+ self .thread : Optional [threading .Thread ] = None
223+ self .tasks_finished = threading .Event ()
224+
225+ def wait (self ):
226+ """Block until the async registration completes. Only here we print tqdm progress."""
227+ last_count = 0
228+ desc = f"pid { os .getpid ()} Registering pinned host memory (async)"
229+ with tqdm (total = self .total_tasks , desc = desc ) as pbar :
230+ while not self .tasks_finished .is_set ():
231+ cur = self .task_count
232+ if cur > last_count :
233+ pbar .update (cur - last_count )
234+ last_count = cur
235+ time .sleep (0.01 )
236+ # final update
237+ cur = self .task_count
238+ if cur > last_count :
239+ pbar .update (cur - last_count )
240+ last_count = cur
241+
242+ if self .thread is not None and self .thread .is_alive ():
243+ self .thread .join ()
244+
245+ return
246+
247+
247248@lru_cache (maxsize = None )
248249def attach_shm_kv_cache_ptr () -> int :
249250 libc = ctypes .CDLL ("/usr/lib/x86_64-linux-gnu/libc.so.6" , use_errno = True )
0 commit comments