5757 UCMConnectorMetadata ,
5858)
5959from ucm .logger import init_logger
60+ from ucm .store .factory_v1 import UcmConnectorFactoryV1
61+ from ucm .store .ucmstore_v1 import UcmKVStoreBaseV1
6062
6163logger = init_logger (__name__ )
6264
@@ -91,7 +93,7 @@ def make_buffers(
9193 is_mla : bool ,
9294) -> Tuple [List [str ], Dict [str , torch .Tensor ]]:
9395 logger .info (f"Allocating buffers: blocks={ block_number } , batch_size={ batch_size } " )
94- hashes = [secrets .token_hex (16 ) for _ in range (block_number )]
96+ hashes = [secrets .token_bytes (16 ) for _ in range (block_number )]
9597 device = f"cuda:{ device_id } "
9698 kv_caches : Dict [str , torch .Tensor ] = {}
9799
@@ -123,8 +125,8 @@ def build_vllm_config(
123125 tp_size : int ,
124126 connector_name : str ,
125127 storage_backends : str ,
126- transfer_stream_number : int ,
127- use_direct : bool ,
128+ stream_number : int ,
129+ io_direct : bool ,
128130) -> VllmConfig :
129131 cache_config = CacheConfig (
130132 block_size = block_size ,
@@ -189,8 +191,8 @@ def build_vllm_config(
189191 "ucm_connector_name" : connector_name ,
190192 "ucm_connector_config" : {
191193 "storage_backends" : storage_backends ,
192- "use_direct " : use_direct ,
193- "stream_number" : transfer_stream_number ,
194+ "io_direct " : io_direct ,
195+ "stream_number" : stream_number ,
194196 "local_rank_size" : 1 ,
195197 },
196198 }
@@ -241,6 +243,7 @@ def compute_total_bytes(
241243
242244def run_once (
243245 connector : UCMConnector ,
246+ scheduler : UcmKVStoreBaseV1 ,
244247 kv_caches : Dict [str , torch .Tensor ],
245248 hashes : List [str ],
246249 batch_size : int ,
@@ -254,7 +257,9 @@ def run_once(
254257 load_block_ids = ([], []),
255258 dump_block_ids = (dump_hashes , dump_vllm_block_ids ),
256259 )
257- connector .connector .kv_caches = kv_caches
260+
261+ if not hasattr (connector .connector , "store" ) or connector .connector .store is None :
262+ connector .connector .register_kv_caches (kv_caches )
258263 connector .bind_connector_metadata (metadata )
259264
260265 total_bytes = compute_total_bytes (kv_caches , batch_size , is_mla )
@@ -267,7 +272,7 @@ def run_once(
267272
268273 write_bw = (total_bytes / (1024 ** 3 )) / write_time if write_time > 0 else 0.0
269274
270- lookup = connector . connector . store .lookup (dump_hashes )
275+ lookup = scheduler .lookup (dump_hashes )
271276 if not all (lookup ):
272277 raise RuntimeError ("Found missing cache blocks before load test." )
273278
@@ -277,7 +282,7 @@ def run_once(
277282 load_block_ids = (dump_hashes , load_vllm_block_ids ),
278283 dump_block_ids = ([], []),
279284 )
280- connector . connector . kv_caches = kv_caches
285+
281286 connector .bind_connector_metadata (load_metadata )
282287
283288 forward_context = build_forward_context (kv_caches , is_mla )
@@ -316,8 +321,8 @@ def run_test(
316321 ucm_connector_name : str ,
317322 total_tp_size : int ,
318323 model_path : str ,
319- transfer_stream_number : int ,
320- use_direct : bool ,
324+ stream_number : int ,
325+ io_direct : bool ,
321326) -> Tuple [float , float , float , float , float , float ]:
322327 block_dim = head_size * num_head
323328 io_size = block_dim * block_len * block_elem_size
@@ -335,8 +340,8 @@ def run_test(
335340 tp_size = total_tp_size ,
336341 connector_name = ucm_connector_name ,
337342 storage_backends = storage_backends ,
338- transfer_stream_number = transfer_stream_number ,
339- use_direct = use_direct ,
343+ stream_number = stream_number ,
344+ io_direct = io_direct ,
340345 )
341346
342347 dummy_world_group = type ("DummyWorldGroup" , (), {"local_rank" : 0 })()
@@ -375,6 +380,25 @@ def broadcast(self, tensor, src):
375380 mla ,
376381 )
377382
383+ connector .connector .register_kv_caches (kv_caches )
384+
385+ storage_backends_list = [
386+ os .path .join (path , "kv" ) for path in storage_backends .split (":" ) if path
387+ ]
388+
389+ scheduler_config = {
390+ "storage_backends" : storage_backends_list ,
391+ "block_size" : block_size ,
392+ "device_id" : - 1 , # device_id=-1 means transferEnable=false
393+ "tensor_size" : io_size ,
394+ "stream_number" : stream_number ,
395+ "io_direct" : io_direct ,
396+ "unique_id" : secrets .token_hex (8 ),
397+ }
398+ scheduler = UcmConnectorFactoryV1 .create_connector (
399+ ucm_connector_name , scheduler_config
400+ )
401+
378402 w_sizes , w_times , w_bws = [], [], []
379403 r_sizes , r_times , r_bws = [], [], []
380404
@@ -385,10 +409,10 @@ def broadcast(self, tensor, src):
385409 round_hashes = hashes [start_hash_idx :end_hash_idx ]
386410
387411 if len (round_hashes ) < batch_size :
388- round_hashes = [secrets .token_hex (16 ) for _ in range (batch_size )]
412+ round_hashes = [secrets .token_bytes (16 ) for _ in range (batch_size )]
389413
390414 (w_size , w_time , w_bw ), (r_size , r_time , r_bw ) = run_once (
391- connector , kv_caches , round_hashes , batch_size , mla
415+ connector , scheduler , kv_caches , round_hashes , batch_size , mla
392416 )
393417
394418 if round_idx != 0 :
@@ -451,7 +475,7 @@ def main():
451475 num_tokens_list = [2048 , 4096 , 8192 , 16384 , 32768 ]
452476 ucm_connector_name = "UcmNfsStore"
453477 model_path = "/home/models/QwQ-32B"
454- transfer_stream_numbers = [32 , 64 , 128 ]
478+ stream_numbers = [32 , 64 , 128 ]
455479 os .environ ["UC_LOGGER_LEVEL" ] = "debug"
456480
457481 print ("1. Model Selection:" )
@@ -462,8 +486,8 @@ def main():
462486 print ("\n 2. IoDirect Transfer:" )
463487 print (" 1 - Disable IoDirect (default)" )
464488 print (" 2 - Enable IoDirect" )
465- use_direct = get_user_input ("Please select Direct IO mode" , "1" )
466- use_direct = False if use_direct == "1" else True
489+ io_direct = get_user_input ("Please select Direct IO mode" , "1" )
490+ io_direct = False if io_direct == "1" else True
467491
468492 if mla :
469493 block_lens = [64 ]
@@ -515,7 +539,7 @@ def main():
515539
516540 for num_head in num_head_list :
517541 for block_len in block_lens :
518- for transfer_stream_number in transfer_stream_numbers :
542+ for stream_number in stream_numbers :
519543 block_dim = head_size * num_head
520544 io_size = block_dim * block_len * block_elem_size
521545
@@ -548,8 +572,8 @@ def main():
548572 ucm_connector_name ,
549573 total_tp_size ,
550574 model_path ,
551- transfer_stream_number ,
552- use_direct ,
575+ stream_number ,
576+ io_direct ,
553577 ),
554578 )
555579
@@ -579,7 +603,7 @@ def main():
579603 kv ,
580604 num_head ,
581605 block_len ,
582- transfer_stream_number ,
606+ stream_number ,
583607 io_count ,
584608 io_size ,
585609 f"{ avg_w_size :.4f} " ,
0 commit comments