13 changes: 10 additions & 3 deletions examples/ucm_config_example.yaml
@@ -16,7 +16,15 @@ ucm_connectors:

load_only_first_rank: false

-metrics_config_path: "/vllm-workspace/metrics_config.yaml"
+# Enable UCM metrics; metrics can be viewed via Grafana and Prometheus
+# metrics_config_path: "/workspace/unified-cache-management/examples/metrics/metrics_configs.yaml"

+# UCM operation recording: whether to write UCM dump/load logs to a file
+record_config:
+  enable: false
+  log_path: "/workspace/ucm_ops.log"
+  flush_size: 10
+  flush_interval: 5.0

# Sparse attention configuration
# Format 1: Dictionary format (for methods like ESA, KvComp)
@@ -33,5 +41,4 @@ metrics_config_path: "/vllm-workspace/metrics_config.yaml"

# Whether to use layerwise loading/saving (optional, default: True for UnifiedCacheConnectorV1)
# use_layerwise: true
-# hit_ratio: 0.9
-
+# hit_ratio: 0.9
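For reference: once record_config.enable is set to true, every UCM dump/load is appended to log_path as one JSON object per line (see the connector changes below). A minimal reader sketch, assuming the example log path above and the field names this PR writes (timestamp, op_type, block_size, blocks):

import json

# Read back the UCM operation log written by the connector's record thread.
# The path matches the example config above; field names (timestamp, op_type,
# block_size, blocks) follow the log_operation defaults in this PR.
with open("/workspace/ucm_ops.log", encoding="utf-8") as f:
    for line in f:
        entry = json.loads(line)
        print(f"{entry['timestamp']:.3f} {entry['op_type']:>4} "
              f"{len(entry.get('blocks', []))} blocks")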
80 changes: 79 additions & 1 deletion ucm/integration/vllm/ucm_connector.py
@@ -1,10 +1,13 @@
import hashlib
import itertools
import json
import os
import pickle
import queue
import threading
import time
from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, List, Optional

import torch
from vllm.config import VllmConfig
@@ -166,6 +169,67 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
config["kv_block_size"] / 1024 / 1024,
config["io_size"] / 1024,
)
self.record_config = self.launch_config.get("record_config", {})
self.record_oper: bool = self.record_config.get("enable", False)
if self.record_oper:
self.write_thread = threading.Thread(
target=self._async_record_loop, daemon=True
)
self.write_thread.start()

    def log_operation(self, operation_data: dict[str, Any]) -> None:
        """Record an operation log entry (non-blocking)."""
        default_data = {
            "timestamp": time.time(),
            "op_type": "None",
            "block_size": self.block_size,
        }
        # Caller-supplied fields override/extend the defaults.
        log_entry = {**default_data, **operation_data}

        try:
            self.log_queue.put_nowait(log_entry)
        except queue.Full:
            logger.error(
                f"Log queue is full, dropping one log: {log_entry.get('request_id')}"
            )

    def _async_record_loop(self):
        """Worker loop: drain the log queue and flush batches to log_path."""
        log_path = self.record_config.get(
            "log_path", "/vllm-workspace/ucm_logs/ucm_ops.log"
        )
        flush_size = self.record_config.get("flush_size", 100)
        flush_interval = self.record_config.get("flush_interval", 5.0)
        batch_buffer = []
        last_flush_time = time.time()
        while True:
            is_flush = False
            try:
                # Wait up to 1 second for the next log entry.
                log_entry = self.log_queue.get(timeout=1.0)
                batch_buffer.append(log_entry)
                current_time = time.time()

                # Flush when the batch is full or the interval has elapsed.
                if (
                    len(batch_buffer) >= flush_size
                    or (current_time - last_flush_time) >= flush_interval
                ):
                    is_flush = True
                    last_flush_time = current_time
                self.log_queue.task_done()
            except queue.Empty:
                # No new entries: still flush a pending batch once the
                # interval has elapsed, so records are not held forever.
                current_time = time.time()
                if batch_buffer and (current_time - last_flush_time) >= flush_interval:
                    is_flush = True
                    last_flush_time = current_time
            except Exception as e:
                logger.error(f"Log thread exception: {str(e)}")

            if is_flush:
                with open(log_path, "a", encoding="utf-8") as f:
                    for log_entry in batch_buffer:
                        f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
                batch_buffer.clear()

        self.metrics_config = self.launch_config.get("metrics_config_path", "")
        if self.metrics_config:
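Because log_operation merges caller data over the defaults, call sites only pass the per-operation fields, and any extra key (a request_id, say) flows through to the JSON line unchanged. A hypothetical direct call, with connector standing in for an initialized connector instance:

# Illustrative only: "request_id" is not written by this PR's call sites,
# but the {**default_data, **operation_data} merge lets it pass through.
connector.log_operation(
    {
        "op_type": "dump",
        "request_id": "req-42",
        "blocks": ["3f9a...", "b7c1..."],
    }
)
# The worker thread later appends one line to the log, roughly
# (values illustrative):
# {"timestamp": 1700000000.0, "op_type": "dump", "block_size": 16,
#  "request_id": "req-42", "blocks": ["3f9a...", "b7c1..."]}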
@@ -507,6 +571,13 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                request_to_task[request_id] = self.store.load(
                    ucm_total_block_ids, ucm_offsets, dst_tensor_addr
                )
                if self.record_oper:
                    self.log_operation(
                        {
                            "op_type": "load",
                            "blocks": ucm_block_ids,
                        }
                    )
            else:
                request_to_task[request_id] = None
            req_broadcast_addr[request_id] = dst_tensor_addr
@@ -597,6 +668,13 @@ def wait_for_save(self) -> None:
            request_to_task[request_id] = self.store.dump(
                ucm_total_block_ids, ucm_offsets, dst_tensor_addr
            )
            if self.record_oper:
                self.log_operation(
                    {
                        "op_type": "dump",
                        "blocks": ucm_block_ids,
                    }
                )
            request_to_blocks[request_id] = ucm_block_ids

        for request_id, task in request_to_task.items():
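With both call sites in place (load records from start_load_kv, dump records from wait_for_save), the log can be aggregated offline to see how much KV traffic goes each way. A small sketch under the same path assumption as above:

import json
from collections import Counter

op_counts: Counter = Counter()
block_counts: Counter = Counter()
with open("/workspace/ucm_ops.log", encoding="utf-8") as f:
    for line in f:
        entry = json.loads(line)
        op_counts[entry["op_type"]] += 1
        block_counts[entry["op_type"]] += len(entry.get("blocks", []))

print("operations:", dict(op_counts))      # e.g. {'dump': 120, 'load': 87}
print("blocks moved:", dict(block_counts))  # total blocks per op type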