
Commit 2af966c

Modify the test tool code to adapt to the latest code.
1 parent 4907197 commit 2af966c

2 files changed: +125 additions, -92 deletions


test/test_ucm_connector_save_load.py

Lines changed: 45 additions & 21 deletions
@@ -57,6 +57,8 @@
     UCMConnectorMetadata,
 )
 from ucm.logger import init_logger
+from ucm.store.factory_v1 import UcmConnectorFactoryV1
+from ucm.store.ucmstore_v1 import UcmKVStoreBaseV1
 
 logger = init_logger(__name__)
 
@@ -91,7 +93,7 @@ def make_buffers(
     is_mla: bool,
 ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
     logger.info(f"Allocating buffers: blocks={block_number}, batch_size={batch_size}")
-    hashes = [secrets.token_hex(16) for _ in range(block_number)]
+    hashes = [secrets.token_bytes(16) for _ in range(block_number)]
     device = f"cuda:{device_id}"
     kv_caches: Dict[str, torch.Tensor] = {}
 
@@ -123,8 +125,8 @@ def build_vllm_config(
     tp_size: int,
     connector_name: str,
     storage_backends: str,
-    transfer_stream_number: int,
-    use_direct: bool,
+    stream_number: int,
+    io_direct: bool,
 ) -> VllmConfig:
     cache_config = CacheConfig(
         block_size=block_size,
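
Note on the hash change in make_buffers above: block hashes are now generated with secrets.token_bytes(16) instead of secrets.token_hex(16), i.e. 16 raw bytes rather than a 32-character hex string, presumably to match what the updated store interface expects for block keys. A standalone illustration of the difference:

import secrets

hex_hash = secrets.token_hex(16)     # str, 32 hex characters, e.g. 'a3f1...'
byte_hash = secrets.token_bytes(16)  # bytes, 16 raw bytes, e.g. b'\x9c\x02...'

assert isinstance(hex_hash, str) and len(hex_hash) == 32
assert isinstance(byte_hash, bytes) and len(byte_hash) == 16
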
@@ -189,8 +191,8 @@ def build_vllm_config(
         "ucm_connector_name": connector_name,
         "ucm_connector_config": {
             "storage_backends": storage_backends,
-            "use_direct": use_direct,
-            "stream_number": transfer_stream_number,
+            "io_direct": io_direct,
+            "stream_number": stream_number,
             "local_rank_size": 1,
         },
     }
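
The two build_vllm_config hunks above rename the tuning knobs end to end: use_direct becomes io_direct, and the stream count is passed through unchanged as stream_number. A hedged sketch of the resulting connector config; the variable name and the concrete values are illustrative, only the key names come from this diff ("UcmNfsStore" appears later in main()):

# Illustrative values only; the key names follow the updated ucm_connector_config.
example_extra_config = {
    "ucm_connector_name": "UcmNfsStore",
    "ucm_connector_config": {
        "storage_backends": "/mnt/nfs/ucm",  # placeholder path
        "io_direct": False,                  # formerly "use_direct"
        "stream_number": 32,                 # formerly filled from transfer_stream_number
        "local_rank_size": 1,
    },
}
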
@@ -241,6 +243,7 @@ def compute_total_bytes(
 
 def run_once(
     connector: UCMConnector,
+    scheduler: UcmKVStoreBaseV1,
     kv_caches: Dict[str, torch.Tensor],
     hashes: List[str],
     batch_size: int,
@@ -254,7 +257,9 @@ def run_once(
         load_block_ids=([], []),
         dump_block_ids=(dump_hashes, dump_vllm_block_ids),
     )
-    connector.connector.kv_caches = kv_caches
+
+    if not hasattr(connector.connector, "store") or connector.connector.store is None:
+        connector.connector.register_kv_caches(kv_caches)
     connector.bind_connector_metadata(metadata)
 
     total_bytes = compute_total_bytes(kv_caches, batch_size, is_mla)
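
run_once no longer pokes kv_caches onto the inner connector; it calls the public register_kv_caches() and skips the call once the store exists, so repeated rounds reuse a single registration. The guard can be read as a small helper (sketch; ensure_registered is a hypothetical name, the check itself is taken from the hunk above):

def ensure_registered(connector, kv_caches):
    # Register KV caches only if the inner connector has not built its store yet.
    if not hasattr(connector.connector, "store") or connector.connector.store is None:
        connector.connector.register_kv_caches(kv_caches)
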
@@ -267,7 +272,7 @@ def run_once(
 
     write_bw = (total_bytes / (1024**3)) / write_time if write_time > 0 else 0.0
 
-    lookup = connector.connector.store.lookup(dump_hashes)
+    lookup = scheduler.lookup(dump_hashes)
     if not all(lookup):
         raise RuntimeError("Found missing cache blocks before load test.")
 
@@ -277,7 +282,7 @@ def run_once(
         load_block_ids=(dump_hashes, load_vllm_block_ids),
         dump_block_ids=([], []),
     )
-    connector.connector.kv_caches = kv_caches
+
     connector.bind_connector_metadata(load_metadata)
 
     forward_context = build_forward_context(kv_caches, is_mla)
@@ -316,8 +321,8 @@ def run_test(
     ucm_connector_name: str,
     total_tp_size: int,
     model_path: str,
-    transfer_stream_number: int,
-    use_direct: bool,
+    stream_number: int,
+    io_direct: bool,
 ) -> Tuple[float, float, float, float, float, float]:
     block_dim = head_size * num_head
     io_size = block_dim * block_len * block_elem_size
@@ -335,8 +340,8 @@ def run_test(
         tp_size=total_tp_size,
         connector_name=ucm_connector_name,
         storage_backends=storage_backends,
-        transfer_stream_number=transfer_stream_number,
-        use_direct=use_direct,
+        stream_number=stream_number,
+        io_direct=io_direct,
     )
 
     dummy_world_group = type("DummyWorldGroup", (), {"local_rank": 0})()
@@ -375,6 +380,25 @@ def broadcast(self, tensor, src):
         mla,
     )
 
+    connector.connector.register_kv_caches(kv_caches)
+
+    storage_backends_list = [
+        os.path.join(path, "kv") for path in storage_backends.split(":") if path
+    ]
+
+    scheduler_config = {
+        "storage_backends": storage_backends_list,
+        "block_size": block_size,
+        "device_id": -1,  # device_id=-1 means transferEnable=false
+        "tensor_size": io_size,
+        "stream_number": stream_number,
+        "io_direct": io_direct,
+        "unique_id": secrets.token_hex(8),
+    }
+    scheduler = UcmConnectorFactoryV1.create_connector(
+        ucm_connector_name, scheduler_config
+    )
+
     w_sizes, w_times, w_bws = [], [], []
     r_sizes, r_times, r_bws = [], [], []
 
@@ -385,10 +409,10 @@ def broadcast(self, tensor, src):
         round_hashes = hashes[start_hash_idx:end_hash_idx]
 
         if len(round_hashes) < batch_size:
-            round_hashes = [secrets.token_hex(16) for _ in range(batch_size)]
+            round_hashes = [secrets.token_bytes(16) for _ in range(batch_size)]
 
         (w_size, w_time, w_bw), (r_size, r_time, r_bw) = run_once(
-            connector, kv_caches, round_hashes, batch_size, mla
+            connector, scheduler, kv_caches, round_hashes, batch_size, mla
         )
 
         if round_idx != 0:
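
The hunks above introduce a second, standalone store connector (scheduler) built through UcmConnectorFactoryV1.create_connector; run_once receives it and queries scheduler.lookup() instead of reaching into connector.connector.store. A minimal sketch of that flow, assuming the factory and lookup() behave exactly as the diff uses them; the backend path and sizes are placeholders:

import os
import secrets

from ucm.store.factory_v1 import UcmConnectorFactoryV1

# Placeholder values; the key names mirror scheduler_config in the diff above.
config = {
    "storage_backends": [os.path.join("/mnt/nfs/ucm", "kv")],
    "block_size": 128,
    "device_id": -1,   # -1 means transferEnable=false, per the inline comment above
    "tensor_size": 4096,
    "stream_number": 32,
    "io_direct": False,
    "unique_id": secrets.token_hex(8),
}
scheduler = UcmConnectorFactoryV1.create_connector("UcmNfsStore", config)

# lookup() returns one hit flag per block hash; a freshly generated hash has
# never been dumped, so a miss is the expected result here.
print(scheduler.lookup([secrets.token_bytes(16)]))
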
@@ -451,7 +475,7 @@ def main():
     num_tokens_list = [2048, 4096, 8192, 16384, 32768]
     ucm_connector_name = "UcmNfsStore"
     model_path = "/home/models/QwQ-32B"
-    transfer_stream_numbers = [32, 64, 128]
+    stream_numbers = [32, 64, 128]
     os.environ["UC_LOGGER_LEVEL"] = "debug"
 
     print("1. Model Selection:")
@@ -462,8 +486,8 @@
     print("\n2. IoDirect Transfer:")
     print(" 1 - Disable IoDirect (default)")
     print(" 2 - Enable IoDirect")
-    use_direct = get_user_input("Please select Direct IO mode", "1")
-    use_direct = False if use_direct == "1" else True
+    io_direct = get_user_input("Please select Direct IO mode", "1")
+    io_direct = False if io_direct == "1" else True
 
     if mla:
         block_lens = [64]
@@ -515,7 +539,7 @@ def main():
 
     for num_head in num_head_list:
         for block_len in block_lens:
-            for transfer_stream_number in transfer_stream_numbers:
+            for stream_number in stream_numbers:
                 block_dim = head_size * num_head
                 io_size = block_dim * block_len * block_elem_size
 
@@ -548,8 +572,8 @@
                         ucm_connector_name,
                         total_tp_size,
                         model_path,
-                        transfer_stream_number,
-                        use_direct,
+                        stream_number,
+                        io_direct,
                     ),
                 )
 
@@ -579,7 +603,7 @@
                         kv,
                         num_head,
                         block_len,
-                        transfer_stream_number,
+                        stream_number,
                         io_count,
                         io_size,
                         f"{avg_w_size:.4f}",
