|
56 | 56 | UCMConnector, |
57 | 57 | UCMConnectorMetadata, |
58 | 58 | ) |
| 59 | +from ucm.logger import init_logger |
59 | 60 |
|
60 | | - |
61 | | -def log(msg: str): |
62 | | - timestamp = time.strftime("%Y-%m-%d %H:%M:%S") |
63 | | - print(f"[{timestamp}] {msg}", flush=True) |
| 61 | +logger = init_logger(__name__) |
64 | 62 |
|
65 | 63 |
|
66 | 64 | def make_aligned_tensor(shape, dtype, device, alignment=4096): |
@@ -92,7 +90,7 @@ def make_buffers( |
92 | 90 | kv: int, |
93 | 91 | is_mla: bool, |
94 | 92 | ) -> Tuple[List[str], Dict[str, torch.Tensor]]: |
95 | | - log(f"Allocating buffers: blocks={block_number}, batch_size={batch_size}") |
| 93 | + logger.info(f"Allocating buffers: blocks={block_number}, batch_size={batch_size}") |
96 | 94 | hashes = [secrets.token_hex(16) for _ in range(block_number)] |
97 | 95 | device = f"cuda:{device_id}" |
98 | 96 | kv_caches: Dict[str, torch.Tensor] = {} |
@@ -373,7 +371,7 @@ def broadcast(self, tensor, src): |
373 | 371 | r_sizes, r_times, r_bws = [], [], [] |
374 | 372 |
|
375 | 373 | for round_idx in range(repeat): |
376 | | - log(f"Round {round_idx + 1}: start write test") |
| 374 | + logger.info(f"Round {round_idx + 1}: start write test") |
377 | 375 | start_hash_idx = round_idx * batch_size |
378 | 376 | end_hash_idx = start_hash_idx + batch_size |
379 | 377 | round_hashes = hashes[start_hash_idx:end_hash_idx] |
@@ -401,7 +399,7 @@ def avg(values: List[float]) -> float: |
401 | 399 | avg_r_time = avg(r_times) |
402 | 400 | avg_r_bw = avg(r_bws) |
403 | 401 |
|
404 | | - log( |
| 402 | + logger.info( |
405 | 403 | "\n=== Summary ===\n" |
406 | 404 | f"Write : size={avg_w_size:.4f} GB | time={avg_w_time:.4f} s | bw={avg_w_bw:.4f} GB/s\n" |
407 | 405 | f"Read : size={avg_r_size:.4f} GB | time={avg_r_time:.4f} s | bw={avg_r_bw:.4f} GB/s\n" |
|
0 commit comments