diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
new file mode 100644
index 000000000..0236df923
--- /dev/null
+++ b/.github/CODEOWNERS
@@ -0,0 +1,38 @@
+# See https://help.github.com/articles/about-codeowners/
+# for more info about CODEOWNERS file
+
+* @mag1c-h @ygwpz @FangRun2 @Tarrei
+/.github @Wwwzff @hek14 @ygwpz @mag1c-h @FangRun2 @Tarrei
+
+/ucm/sparse @wuhuxiao @wangwenxin0312 @hek14 @ygwpz @mag1c-h
+/ucm/sparse/cache_blend @wuhuxiao @hek14 @ygwpz @mag1c-h
+/ucm/sparse/esa @wangwenxin0312 @hek14 @ygwpz @mag1c-h
+/ucm/sparse/gsa @Zbm1996 @zbb200819 @yxkyong @HaoLi980405 @wuhuxiao @hek14 @ygwpz @mag1c-h
+/ucm/sparse/kvcomp @leideng @pengwwang @wuhuxiao @hek14 @ygwpz @mag1c-h
+/ucm/sparse/kvstar @saki-daisuki @summer-ai007 @xwLearnsLLM @wuhuxiao @hek14 @ygwpz @mag1c-h
+
+/ucm/store @mag1c-h @ygwpz
+/ucm/store/dramstore @harrisonyhq @mag1c-h @ygwpz
+/ucm/store/localstore @mag1c-h @ygwpz
+/ucm/store/mooncakestore @chinesezyc @mag1c-h @ygwpz
+/ucm/store/nfsstore @mag1c-h @ygwpz
+
+/ucm/integration @qyh111 @harrisonyhq @ygwpz @mag1c-h @hek14
+
+/ucm/pd @flesher0813 @ygwpz @mag1c-h
+
+/ucm/sandbox @Wwwzff @hek14 @ygwpz @mag1c-h @FangRun2 @Tarrei
+
+/benchmarks @flesher0813 @ygwpz @mag1c-h
+
+/docker @harrisonyhq @ygwpz @mag1c-h
+
+/docs @flesher0813 @ygwpz @mag1c-h @FangRun2 @Tarrei @hek14
+/docs/source/user-guide/sparse-attention/esa.md @wangwenxin0312 @hek14 @flesher0813 @ygwpz @mag1c-h @FangRun2 @Tarrei
+/docs/source/user-guide/sparse-attention/gsa.md @Zbm1996 @zbb200819 @yxkyong @HaoLi980405 @flesher0813 @ygwpz @mag1c-h @FangRun2 @Tarrei
+/docs/source/user-guide/sparse-attention/kvcomp.md @leideng @pengwwang @flesher0813 @ygwpz @mag1c-h @FangRun2 @Tarrei
+/docs/source/user-guide/sparse-attention/kvstar.md @saki-daisuki @summer-ai007 @flesher0813 @ygwpz @mag1c-h @FangRun2 @Tarrei
+
+/examples @harrisonyhq @ygwpz @mag1c-h @hek14
+
+/test @Wwwzff @ygwpz @mag1c-h
diff --git a/.github/workflows/cpp-linter.yml b/.github/workflows/cpp-linter.yml
new file mode 100644
index 000000000..e013a62c7
--- /dev/null
+++ b/.github/workflows/cpp-linter.yml
@@ -0,0 +1,34 @@
+name: cpp-linter
+
+on:
+ push:
+ branches: [ "*" ]
+ pull_request:
+ branches: [ "dev*", "main", "*release" ]
+
+
+jobs:
+ cpp-linter:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
+ with:
+ persist-credentials: false
+ - uses: cpp-linter/cpp-linter-action@main
+ id: linter
+ continue-on-error: true
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ style: file
+ tidy-checks: '-*'
+ files-changed-only: true
+ lines-changed-only: diff
+ format-review: true
+ thread-comments: ${{ github.event_name == 'pull_request' && 'update' }}
+
+ - name: Fail fast?!
+ if: steps.linter.outputs.checks-failed != 0
+ run: |
+ echo "some linter checks failed. ${{ steps.linter.outputs.checks-failed }}"
+ exit 1
diff --git a/.github/workflows/unifiedcache_test.yml b/.github/workflows/unifiedcache_test.yml
index 91bb0db15..aa0dee2ad 100644
--- a/.github/workflows/unifiedcache_test.yml
+++ b/.github/workflows/unifiedcache_test.yml
@@ -49,7 +49,9 @@ jobs:
set -euo pipefail
pip install -v -e . --no-build-isolation
cd \$(pip show vllm | grep Location | awk '{print \$2}') &&
- git apply /workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch
+ git apply /workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt-pc.patch
+ git apply /workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt-aggre.patch
+ git apply /workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch
cd /workspace/unified-cache-management
python3 -m unittest discover -s test
"
diff --git a/.gitignore b/.gitignore
index 8ff7d5c29..734cf4bf1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -49,4 +49,15 @@
**/output/**
.venv/**
**/__pycache__/**
-*.egg-info/**
\ No newline at end of file
+*.egg-info/**
+reports/
+dataset/
+logs/
+.*
+*.log
+result_outputs/
+results/
+.cache/
+backup/
+$null
+*__pycache__/
\ No newline at end of file
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 788dc15ce..b72407387 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -9,7 +9,10 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
option(BUILD_UCM_STORE "build ucm store module." ON)
option(BUILD_UCM_SPARSE "build ucm sparse module." ON)
option(BUILD_UNIT_TESTS "build all unit test suits." OFF)
-set(RUNTIME_ENVIRONMENT "simu" CACHE STRING "runtime: simu, ascend or cuda.")
+option(BUILD_NUMA "build numactl library." OFF)
+option(DOWNLOAD_DEPENDENCE "download dependencies via cmake." ON)
+set(RUNTIME_ENVIRONMENT "simu" CACHE STRING "runtime: simu, ascend, musa or cuda.")
+set(LOGGER_BACKEND "spdlog" CACHE STRING "backend: spdlog or flux.")
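+# Example configure step (illustrative; adjust options to your environment):
+#   cmake -S . -B build -DRUNTIME_ENVIRONMENT=cuda -DLOGGER_BACKEND=spdlog -DBUILD_NUMA=ON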
execute_process(COMMAND git rev-parse HEAD OUTPUT_VARIABLE UCM_COMMIT_ID OUTPUT_STRIP_TRAILING_WHITESPACE)
add_compile_definitions(UCM_PROJECT_NAME="${PROJECT_NAME}")
diff --git a/README.md b/README.md
index 421964e65..42ab3a257 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@
-| Documentation | Website | RoadMap | 中文 |
+| Documentation | Website | RoadMap | 中文 |
---
@@ -82,9 +82,10 @@ please refer to [Quick Start](./docs/source/getting-started/quick_start.md).
---
## Contact Us
+1. For technical questions and feature requests, please use GitHub [Issues](https://github.com/ModelEngine-Group/unified-cache-management/issues).
+2. WeChat technical discussion group: Scan the QR code below.
-For technical questions and feature requests, please use
-GitHub [Issues](https://github.com/ModelEngine-Group/unified-cache-management/issues).
+
## License
diff --git a/README_zh.md b/README_zh.md
index 9eba04e34..6a5bee04d 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -6,7 +6,7 @@
-| 文档 | 网站 | 发展路线图 | EN |
+| 文档 | 网站 | 发展路线图 | EN |
---
@@ -62,7 +62,10 @@ KVStoreBase有助于实现稀疏算法与外部存储的解耦。它定义了与
---
## 联系我们
-如需技术咨询或功能请求,请提交 GitHub [Issues](https://github.com/ModelEngine-Group/unified-cache-management/issues).
+1. 如需技术咨询或功能请求,请提交 GitHub [Issues](https://github.com/ModelEngine-Group/unified-cache-management/issues)。
+2. 微信技术交流群:扫描下方二维码。
+
+
## 许可协议
diff --git a/benchmarks/README.md b/benchmarks/README.md
new file mode 100644
index 000000000..8c5f5f70b
--- /dev/null
+++ b/benchmarks/README.md
@@ -0,0 +1,128 @@
+# TraceReplay Benchmark Tool
+
+The TraceReplay benchmark tool accurately replays real-world request traces with their original timing, or dynamically generates requests from popular datasets. It reports comprehensive performance metrics, including Time to First Token (TTFT), Time Per Output Token (TPOT), Inter-Token Latency (ITL), End-to-End Latency, and Goodput.
+
+---
+
+## 1. Overview
+
+The Trace Replay feature covers request generation, request sending and response receiving, and result calculation and saving. It reproduces historical requests from a Mooncake trace file and sends them strictly according to the timestamps recorded in the trace. After execution, Trace Replay calculates key performance metrics such as Time to First Token (TTFT) and Time Per Output Token (TPOT), prints the results to the terminal, and saves them to an Excel file.
+
+[Mooncake traces](https://github.com/kvcache-ai/Mooncake/tree/main/FAST25-release/traces) consist of two types of trace data:
+* Conversation and Tool&Agent trace: Sampled from one hour of online request data from different clusters.
+* Synthetic trace: Generated synthetically from other publicly available datasets.
+
+For more information, please refer to the Mooncake paper: [Mooncake-FAST25.pdf](https://github.com/kvcache-ai/Mooncake/blob/main/FAST25-release/Mooncake-FAST25.pdf).
+
+Trace Replay supports two request-generation methods:
+* Hash ID-based: Input tokens are generated from the input_length and hash_ids recorded in the trace file. Each hash_id corresponds to a block of 512 tokens, and the same hash_id always maps to the identical token sequence (a simplified sketch is shown below).
+* Dataset-based: Prompts are generated by invoking vLLM's benchmark module using the input_length from the trace file and the user-specified dataset name. This approach does not rely on the hash_ids present in the trace file.
+
+Depending on the request generation method, Trace Replay offers two modes: Trace Mode and Benchmark Mode, which can be configured by the user via the --trace-mode parameter.
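+
+The determinism of hash-ID-based generation can be illustrated with a short sketch. This is not the real implementation (which lives in `benchmarks/trace_replay.py`); the block size and token derivation below are illustrative only:
+
+```python
+# Simplified sketch: the same hash_id always expands to the same 512-token block,
+# so requests that share hash_ids share prompt prefixes.
+BLOCK_SIZE = 512
+_block_cache: dict[int, list[int]] = {}
+
+
+def tokens_for_request(hash_ids: list[int], input_length: int, vocab_size: int) -> list[int]:
+    token_ids: list[int] = []
+    for hid in hash_ids:
+        if hid not in _block_cache:
+            # deterministic pseudo-tokens derived from the hash_id alone
+            _block_cache[hid] = [(hid * BLOCK_SIZE + j) % vocab_size for j in range(BLOCK_SIZE)]
+        token_ids.extend(_block_cache[hid])
+    return token_ids[:input_length]
+```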
+
+---
+
+## 2. Parameters
+
+| Argument | Default | Help |
+|-----------|---------|---------|
+| --backend | None | backend framework type |
+| --model | None | model path |
+| --host | localhost | IP address of the inference server |
+| --port | None | port number of the inference server |
+| --trace-path | None | path to the trace JSONL file |
+| --trace-mode | trace | 'trace' to replay requests from cached trace files, 'benchmark' to generate requests dynamically using the benchmark module |
+| --dataset-name | sharegpt | in benchmark mode, you must specify a dataset; refer to the [vLLM benchmark documentation](https://github.com/vllm-project/vllm/blob/releases/v0.9.1/benchmarks/README.md) |
+| --save-prompts | False | save generated prompts with a timestamp for reuse |
+| --save-result | False | save the benchmark metrics to an Excel file |
+| --result-dir | None | path to save results |
+
+---
+
+## 3. Example
+
+### 1. Download example trace
+
+Download the trace JSONL file from [Mooncake traces](https://github.com/kvcache-ai/Mooncake/tree/main/FAST25-release/traces). In the trace, each line is a JSON object representing a single request:
+
+```
+{
+ "timestamp": 1696000000123, // ms since epoch
+ "input_length": 512, // number of input tokens
+ "output_length": 128, // expected output tokens
+ "hash_ids": [123, 456, 789] // seed list for deterministic prompt generation
+}
+```
+
+### 2. Set environment variable
+
+Trace Replay depends on [vLLM's benchmark](https://github.com/vllm-project/vllm/tree/main/benchmarks) module, which you need to download separately. Before running Trace Replay, you must set the path to the benchmark module via an environment variable:
+
+```bash
+export BENCHMARK_PATH="/vllm-workspace/benchmarks"
+```
+
+### 3. Basic Usage
+
+Execute the Python script to replay a trace against a local vLLM instance:
+
+```bash
+python3 /trace_replay.py \
+ --model /home/models/dsv2-lite \
+ --backend vllm \
+ --trace-path /conversation_trace.jsonl \
+ --trace-mode trace \
+ --host 127.0.0.1 \
+ --port 8000 \
+ --save-result \
+ --save-prompts
+```
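+
+If you run in benchmark mode instead, prompts are generated from a dataset rather than from the trace's hash_ids. A sketch of such an invocation is shown below (paths are placeholders; depending on the dataset, vLLM's benchmark module may require additional dataset arguments):
+
+```bash
+python3 /trace_replay.py \
+    --model /home/models/dsv2-lite \
+    --backend vllm \
+    --trace-path /conversation_trace.jsonl \
+    --trace-mode benchmark \
+    --dataset-name sharegpt \
+    --host 127.0.0.1 \
+    --port 8000 \
+    --save-result
+```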
+
+### 4. Results
+
+Successful execution results in output similar to the following:
+
+```
+============ Serving Benchmark Result ============
+Successful requests: 510
+Benchmark duration (s): 301.46
+Total input tokens: 7201515
+Total generated tokens: 185502
+Request throughput (req/s): 1.69
+Output token throughput (tok/s): 615.34
+Total Token throughput (tok/s): 24504.02
+---------------Time to First Token----------------
+Mean TTFT (ms): 20931.33
+Median TTFT (ms): 19119.63
+Std TTFT (ms): 17324.45
+P25 TTFT (ms): 4057.98
+P50 TTFT (ms): 19119.63
+P75 TTFT (ms): 33284.55
+P99 TTFT (ms): 64592.68
+-----Time per Output Token (excl. 1st token)------
+Mean TPOT (ms): 187.71
+Median TPOT (ms): 200.69
+Std TPOT (ms): 63.08
+P25 TPOT (ms): 144.17
+P50 TPOT (ms): 200.69
+P75 TPOT (ms): 234.55
+P99 TPOT (ms): 312.87
+---------------Inter-token Latency----------------
+Mean ITL (ms): 181.20
+Median ITL (ms): 169.18
+Std ITL (ms): 133.70
+P25 ITL (ms): 86.63
+P50 ITL (ms): 169.18
+P75 ITL (ms): 230.91
+P99 ITL (ms): 647.04
+----------------End-to-end Latency----------------
+Mean E2EL (ms): 86656.79
+Median E2EL (ms): 89218.82
+Std E2EL (ms): 43454.94
+P25 E2EL (ms): 53935.13
+P50 E2EL (ms): 89218.82
+P75 E2EL (ms): 120761.34
+P99 E2EL (ms): 171262.27
+==================================================
+```
+
diff --git a/benchmarks/trace_replay.py b/benchmarks/trace_replay.py
index 21567d454..820216e3b 100644
--- a/benchmarks/trace_replay.py
+++ b/benchmarks/trace_replay.py
@@ -9,26 +9,14 @@
import time
from collections import defaultdict
from datetime import datetime
+from typing import Optional
+import aiohttp
import pandas
+import vllm
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
-
-benchmark_path = os.environ.get("BENCHMARK_PATH")
-if benchmark_path:
- sys.path.append(benchmark_path)
- print(f"Added benchmark path: {benchmark_path}")
-else:
- raise EnvironmentError(
- "Missing BENCHMARK_PATH environment variable.\n"
- "Usage: export BENCHMARK_PATH="
- )
-
-from backend_request_func import (
- ASYNC_REQUEST_FUNCS,
- RequestFuncInput,
-)
-from benchmark_dataset import (
+from vllm.benchmarks.datasets import (
AIMODataset,
ASRDataset,
BenchmarkDataset,
@@ -45,13 +33,19 @@
SonnetDataset,
VisionArenaDataset,
)
-from benchmark_serving import (
+from vllm.benchmarks.endpoint_request_func import (
+ ASYNC_REQUEST_FUNCS,
+ RequestFuncInput,
+)
+from vllm.benchmarks.serve import (
+ BenchmarkMetrics,
+ add_cli_args,
calculate_metrics,
- create_argument_parser,
get_tokenizer,
)
logger = logging.getLogger(__name__)
+REQUEST_FUC = None
class TraceReplayDataset(BenchmarkDataset):
@@ -89,19 +83,20 @@ def __init__(
def generate_prompt(
self, hash_ids: list[int], target_length: int, tokenizer
) -> str:
-
+ DEFAULT_BLOCK_SIZE = 512
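+ # each hash_id maps to one DEFAULT_BLOCK_SIZE-token block; the same hash_id always yields the same tokens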
vocab_size = tokenizer.vocab_size
# Use hash_ids to influence token generation
base_offset = hash_ids[0] if hash_ids else 0
+
token_ids = []
for i, value in enumerate(hash_ids):
if value in self.hash_to_tokens:
token_ids.extend(self.hash_to_tokens[value])
- elif (i + 1) * 512 <= target_length:
- for j in range(512):
- token_idx = i * 512 + j
+ elif (i + 1) * DEFAULT_BLOCK_SIZE <= target_length:
+ for j in range(DEFAULT_BLOCK_SIZE):
+ token_idx = i * DEFAULT_BLOCK_SIZE + j
token_id = (
base_offset
+ token_idx
@@ -112,10 +107,11 @@ def generate_prompt(
]
)
) % vocab_size
+
self.hash_to_tokens.setdefault(value, []).append(token_id)
token_ids.extend(self.hash_to_tokens[value])
else:
- needed = target_length - i * 512
+ needed = target_length - i * DEFAULT_BLOCK_SIZE
padding = [
(base_offset + len(token_ids) + j) % vocab_size
for j in range(needed)
@@ -357,7 +353,9 @@ def gene_prompts_by_dataset_name(
return input_requests
-def save_metrics_to_file(metrics, output_dir="./"):
+def save_metrics_to_file(
+ metrics: BenchmarkMetrics, metric_percentiles: list[float], output_dir: str = "./"
+):
output_path = output_dir
if not os.path.exists(output_path):
os.makedirs(output_path, exist_ok=True)
@@ -365,33 +363,32 @@ def save_metrics_to_file(metrics, output_dir="./"):
outputs = {}
outputs["time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- outputs["mean_ttft(ms)"] = round(metrics.mean_ttft_ms, 2)
- outputs["p99_ttft(ms)"] = round(metrics.percentiles_ttft_ms[3][1], 2)
- outputs["mean_tpot(ms)"] = round(metrics.mean_tpot_ms, 2)
- outputs["p99_tpot(ms)"] = round(metrics.percentiles_tpot_ms[3][1], 2)
- outputs["total_input_tokens"] = round(metrics.total_input, 2)
- outputs["total_output_tokens"] = round(metrics.total_output, 2)
- outputs["total_token_throughput(tok/s)"] = round(metrics.total_token_throughput, 2)
- outputs["output_throughput(tok/s)"] = round(metrics.output_throughput, 2)
- outputs["request_throughput(req/s)"] = round(metrics.request_throughput, 2)
- outputs["request_goodput(req/s)"] = metrics.request_goodput
+ outputs["total_input_tokens"] = metrics.total_input
+ outputs["total_output_tokens"] = metrics.total_output
outputs["completed requests"] = metrics.completed
+ outputs["request_throughput"] = metrics.request_throughput
+ outputs["request_goodput"] = metrics.request_goodput
+ outputs["output_throughput"] = metrics.output_throughput
+ outputs["total_token_throughput"] = metrics.total_token_throughput
+ outputs["mean_ttft_ms"] = metrics.mean_ttft_ms
+ for p, value in metrics.percentiles_ttft_ms:
+ outputs[f"p{int(p)}_ttft_ms"] = value
+ outputs["mean_tpot_ms"] = metrics.mean_tpot_ms
+ for p, value in metrics.percentiles_tpot_ms:
+ outputs[f"p{int(p)}_tpot_ms"] = value
+ outputs["mean_itl_ms"] = metrics.mean_itl_ms
+ for p, value in metrics.percentiles_itl_ms:
+ outputs[f"p{int(p)}_itl_ms"] = value
+ outputs["mean_e2el_ms"] = metrics.mean_e2el_ms
+ for p, value in metrics.percentiles_e2el_ms:
+ outputs[f"p{int(p)}_e2el_ms"] = value
df = pandas.DataFrame([outputs])
- if os.path.isfile(excel_file):
- try:
- existing_df = pandas.read_excel(excel_file)
- updated_df = pandas.concat([existing_df, df], ignore_index=True)
- except Exception as e:
- print(
- f"Warning: Failed to read {excel_file}, it will be overwritten. Error: {e}"
- )
- updated_df = df
- else:
- updated_df = df
- # Save back to Excel (automatically create or overwrite)
- with pandas.ExcelWriter(excel_file, engine="openpyxl", mode="w") as writer:
- updated_df.to_excel(writer, index=False, sheet_name="Performance Metrics")
+ os.makedirs(os.path.dirname(excel_file), exist_ok=True)
+ with pandas.ExcelWriter(
+ excel_file, engine="openpyxl", mode="a", if_sheet_exists="replace"
+ ) as writer:
+ df.to_excel(writer, index=False, sheet_name="Metrics")
print(f"Successfully saved performance metrics to {excel_file}")
@@ -399,8 +396,9 @@ def save_req_results_to_file(outputs, output_dir="./"):
output_path = output_dir
if not os.path.exists(output_path):
os.makedirs(output_path, exist_ok=True)
- excel_file = os.path.join(output_path, "req_results.xlsx")
+ excel_file = os.path.join(output_path, "metrics.xlsx")
rows = []
+ # print(f"outputs: {outputs}")
for output in outputs:
ttft = output.ttft * 1000 if output.ttft is not None else None
output_len = output.output_tokens if output.output_tokens is not None else 0
@@ -410,29 +408,54 @@ def save_req_results_to_file(outputs, output_dir="./"):
if output_len > 1 and output.ttft is not None and output.latency is not None:
tpot = (output.latency - output.ttft) / (output_len - 1) * 1000
row = {
- "ttft(ms)": ttft,
- "tpot(ms)": tpot,
- "e2el(ms)": latency,
- "input_tokens": input_len,
- "output_tokens": output_len,
- "success": output.success,
+ "input_lens": input_len,
+ "output_lens": output_len,
+ "ttfts_ms": ttft,
+ "tpot_ms": tpot,
}
+ if output.send_time and output.running_time:
+ row["send_to_funning"] = output.running_time - output.send_time
+ if output.running_time and output.worker_time:
+ row["running_to_worker"] = output.worker_time - output.running_time
+ if output.worker_time and output.start_loadkv_time:
+ row["worker_to_loadkv"] = output.start_loadkv_time - output.worker_time
+ if output.start_loadkv_time and output.start_forward_time:
+ row["loadkv_duration"] = (
+ output.start_forward_time - output.start_loadkv_time
+ )
+ if output.start_forward_time and output.finish_forward_time:
+ row["forward_duration"] = (
+ output.finish_forward_time - output.start_forward_time
+ )
+ if output.finish_forward_time and output.finish_savekv_time:
+ row["savekv_duration"] = (
+ output.finish_savekv_time - output.finish_forward_time
+ )
+ if output.first_token_time and output.running_time:
+ row["running_to_first_token"] = (
+ output.first_token_time - output.running_time
+ )
+ row["success"] = output.success
rows.append(row)
+
df = pandas.DataFrame(rows)
- file_exists = os.path.isfile(excel_file)
- if file_exists:
- try:
- existing_df = pandas.read_excel(excel_file)
- updated_df = pandas.concat([existing_df, df], ignore_index=True)
- except Exception as e:
- print(
- f"Warning: Failed to read {excel_file}, it will be overwritten. Error: {e}"
- )
- updated_df = df
- else:
- updated_df = df
- with pandas.ExcelWriter(excel_file, engine="openpyxl", mode="w") as writer:
- updated_df.to_excel(writer, index=False, sheet_name="Performance Metrics")
+ os.makedirs(os.path.dirname(excel_file), exist_ok=True)
+ with pandas.ExcelWriter(
+ excel_file, engine="openpyxl", mode="a", if_sheet_exists="replace"
+ ) as writer:
+ df.to_excel(writer, index=False, sheet_name="details")
+
+
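+# Thin wrapper over the selected backend request function (REQUEST_FUC): newer vLLM
+# request functions accept a shared aiohttp session, while older ones do not, so the
+# session is only forwarded when one was created.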
+async def request_func(
+ request_func_input: RequestFuncInput,
+ pbar: Optional[tqdm] = None,
+ session: Optional[aiohttp.ClientSession] = None,
+):
+ if session:
+ return await REQUEST_FUC(
+ request_func_input=request_func_input, pbar=pbar, session=session
+ )
+ return await REQUEST_FUC(request_func_input=request_func_input, pbar=pbar)
# Send requests by timestamp
@@ -443,6 +466,8 @@ async def replay_trace_by_time(
model_id = args.model
model_name = args.served_model_name
disable_tqdm = args.disable_tqdm
+ if backend not in ASYNC_REQUEST_FUNCS:
+ raise ValueError(f"Unknown backend: {backend}")
if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}"
@@ -451,10 +476,6 @@ async def replay_trace_by_time(
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
base_url = f"http://{args.host}:{args.port}"
- if backend not in ASYNC_REQUEST_FUNCS:
- raise ValueError(f"Unknown backend: {backend}")
- request_func = ASYNC_REQUEST_FUNCS[backend]
-
print("Starting initial single prompt test run...")
test_request = None
for _, reqs in sorted(req_groups.items()):
@@ -463,6 +484,18 @@ async def replay_trace_by_time(
break
if test_request is None:
raise ValueError("No request found for initial test run.")
+
+ use_session = tuple(map(int, vllm.__version__.split(".")[:3])) >= (0, 10, 1)
+ session = None
+ if use_session:
+ session = aiohttp.ClientSession(
+ base_url=base_url,
+ trust_env=True,
+ timeout=aiohttp.ClientTimeout(total=6 * 60 * 60),
+ )
+ global REQUEST_FUC
+ REQUEST_FUC = ASYNC_REQUEST_FUNCS[backend]
+
test_input = RequestFuncInput(
model=model_id,
model_name=model_name,
@@ -476,7 +509,7 @@ async def replay_trace_by_time(
extra_body={"temperature": 0.9},
)
- test_output = await request_func(request_func_input=test_input)
+ test_output = await request_func(request_func_input=test_input, session=session)
if not getattr(test_output, "success", False):
raise ValueError(
@@ -513,9 +546,13 @@ async def _run_one_request(sample_req):
)
if semaphore is not None:
async with semaphore:
- return await request_func(request_func_input=req_input, pbar=pbar)
+ return await request_func(
+ request_func_input=req_input, session=session, pbar=pbar
+ )
else:
- return await request_func(request_func_input=req_input, pbar=pbar)
+ return await request_func(
+ request_func_input=req_input, session=session, pbar=pbar
+ )
for sec, reqs in sorted(req_groups.items()):
delay = sec - (time.perf_counter() - start_time)
@@ -540,13 +577,28 @@ async def send_group(r=reqs, d=delay):
flat_requests.extend(reqs)
group_results = await asyncio.gather(*tasks)
+
+ if pbar is not None:
+ pbar.close()
+ if use_session:
+ await session.close()
+
outputs = []
for res in group_results:
if isinstance(res, list):
outputs.extend(res)
- if pbar is not None:
- pbar.close()
+ percentile_metrics = (
+ args.metric_percentiles.split(",")
+ if args.metric_percentiles
+ else ["ttft", "tpot", "itl", "e2el"]
+ )
+ metric_percentiles = (
+ [int(x) for x in args.metric_percentiles.split(",")]
+ if args.metric_percentiles
+ else [50, 90]
+ )
+ goodput = {k: float(v) for item in args.goodput for k, v in [item.split(":", 1)]}
benchmark_duration = time.perf_counter() - start_time
metrics, actual_output_lens = calculate_metrics(
@@ -554,9 +606,8 @@ async def send_group(r=reqs, d=delay):
outputs=outputs,
dur_s=benchmark_duration,
tokenizer=tokenizer,
- selected_percentile_metrics=["ttft", "tpot", "itl", "e2el"],
- selected_percentiles=[25.0, 50.0, 75.0, 99.0],
- goodput_config_dict={"ttft": 2000, "tpot": 50},
+ selected_percentiles=metric_percentiles,
+ goodput_config_dict=goodput,
)
print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
@@ -586,7 +637,7 @@ def process_one_metric(
metric_name: str,
metric_header: str,
):
- selected_percentile_metrics = ["ttft", "tpot", "itl", "e2el"]
+ selected_percentile_metrics = percentile_metrics
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
@@ -621,13 +672,19 @@ def process_one_metric(
output_dir = args.result_dir if args.result_dir is not None else "./"
if args.save_result:
- save_metrics_to_file(metrics=metrics, output_dir=output_dir)
+ save_metrics_to_file(
+ metrics=metrics,
+ metric_percentiles=metric_percentiles,
+ output_dir=output_dir,
+ )
save_req_results_to_file(outputs=outputs, output_dir=output_dir)
return
def create_argument_trace():
- parser = create_argument_parser()
+ parser = argparse.ArgumentParser(description="Benchmark LLM serving performance")
+ add_cli_args(parser)
+
trace_group = parser.add_argument_group("tracing parameters")
trace_group.add_argument(
"--trace-path",
@@ -677,6 +734,14 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
+ # Check openpyxl for Excel export
+ try:
+ import openpyxl
+ except ImportError:
+ print("\nMissing package: openpyxl")
+ print("Please install openpyxl via pip install.\n")
+ sys.exit(1)
+
parser = create_argument_trace()
args = parser.parse_args()
main(args)
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 622d8fc1b..fdebcb997 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -6,15 +6,12 @@ ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple"
WORKDIR /workspace
# Install unified-cache-management
-COPY . /vllm-workspace/unified-cache-management
+COPY . /workspace/unified-cache-management
RUN pip config set global.index-url ${PIP_INDEX_URL}
RUN export PLATFORM="cuda" && \
- pip install -v -e /vllm-workspace/unified-cache-management --no-build-isolation
+ pip install -v -e /workspace/unified-cache-management --no-build-isolation
-# Apply patch for vLLM
-RUN cd $(pip show vllm | grep Location | awk '{print $2}') \
- && git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch
ENTRYPOINT ["/bin/bash"]
\ No newline at end of file
diff --git a/docker/Dockerfile-NPU b/docker/Dockerfile-NPU
index e03c45dc3..270a63c95 100644
--- a/docker/Dockerfile-NPU
+++ b/docker/Dockerfile-NPU
@@ -6,22 +6,12 @@ ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple"
WORKDIR /workspace
# Install unified-cache-management
-COPY . /vllm-workspace/unified-cache-management
+COPY . /workspace/unified-cache-management
RUN pip config set global.index-url ${PIP_INDEX_URL}
RUN export PLATFORM="ascend" && \
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \
- pip install -v -e /vllm-workspace/unified-cache-management --no-build-isolation
-
-# Apply patch for vLLM
-RUN cd /vllm-workspace/vllm \
- && git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch \
- && git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch
-
-# Apply patch for vLLM-Ascend
-RUN cd /vllm-workspace/vllm-ascend \
- && git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch \
- && git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt-sparse.patch
+ pip install -v -e /workspace/unified-cache-management --no-build-isolation
CMD ["/bin/bash"]
\ No newline at end of file
diff --git a/docs/source/_static/images/qrcode_for_wechat.png b/docs/source/_static/images/qrcode_for_wechat.png
new file mode 100644
index 000000000..b456bdabd
Binary files /dev/null and b/docs/source/_static/images/qrcode_for_wechat.png differ
diff --git a/docs/source/getting-started/installation_gpu.md b/docs/source/getting-started/installation_gpu.md
index 2ab51d705..eaf1d3d05 100644
--- a/docs/source/getting-started/installation_gpu.md
+++ b/docs/source/getting-started/installation_gpu.md
@@ -33,6 +33,13 @@ docker run \
```
Refer to [Set up using docker](https://docs.vllm.ai/en/latest/getting_started/installation/gpu.html#set-up-using-docker) for more information to run your own vLLM container.
+### Install by pip
+Install with pip, or find the pre-built wheels on [PyPI](https://pypi.org/project/uc-manager/).
+```
+pip install uc-manager
+```
+
+
### Build from source code
Follow commands below to install unified-cache-management:
@@ -44,17 +51,12 @@ export PLATFORM=cuda
pip install -v -e . --no-build-isolation
```
-After installation, please apply patch to ensure uc_connector can be used:
+**Note:** Patches are now applied automatically via dynamic patching when the unified-cache-management package is imported and the `UnifiedCacheConnectorV1` connector is used; you no longer need to apply patches manually with `git apply`.
-```bash
-cd $(pip show vllm | grep Location | awk '{print $2}')
-git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch
-git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch
-```
-
-Refer to this [issue](https://github.com/vllm-project/vllm/issues/21702) to see details of this patch's changes.
## Setup from docker
+
+### Build image from source
Download the pre-built `vllm/vllm-openai:v0.9.2` docker image and build unified-cache-management docker image by commands below:
```bash
# Build docker image using source code, replace with the branch or tag name needed
@@ -62,6 +64,14 @@ Download the pre-built `vllm/vllm-openai:v0.9.2` docker image and build unified-
cd unified-cache-management
docker build -t ucm-vllm:latest -f ./docker/Dockerfile ./
```
+
+
+### Pre-built images
+
+Alternatively, pull the pre-built image:
+
+```bash
+docker pull unifiedcachemanager/ucm:latest
+```
+
Then run your container using following command. You can add or remove Docker parameters as needed.
```bash
# Use `--ipc=host` to make sure the shared memory is large enough.
diff --git a/docs/source/getting-started/installation_npu.md b/docs/source/getting-started/installation_npu.md
index 0824de0c3..571e96e15 100644
--- a/docs/source/getting-started/installation_npu.md
+++ b/docs/source/getting-started/installation_npu.md
@@ -39,13 +39,10 @@ docker run --rm \
-v /root/.cache:/root/.cache \
-it $IMAGE bash
```
-Codes of vLLM and vLLM Ascend are placed in /vllm-workspace, you can refer to [vLLM-Ascend Installation](https://vllm-ascend.readthedocs.io/en/latest/installation.html) for more information. After installation, please apply patches to ensure uc_connector can be used:
-```bash
-cd /vllm-workspace/vllm
-git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch
-cd /vllm-workspace/vllm-ascend
-git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch
-```
+The vLLM and vLLM Ascend code is located in /vllm-workspace; refer to [vLLM-Ascend Installation](https://vllm-ascend.readthedocs.io/en/latest/installation.html) for more information.
+
+**Note:** The vLLM and vLLM Ascend patches are now applied automatically via dynamic patching when the unified-cache-management package is imported.
+
Refer to these issues [vllm-issue](https://github.com/vllm-project/vllm/issues/21702) and [vllm-ascend-issue](https://github.com/vllm-project/vllm-ascend/issues/2057) to see details of patches' changes.
### Build from source code
@@ -60,7 +57,7 @@ cd ..
```
## Setup from docker
-Download the pre-built docker image provided or build unified-cache-management docker image by commands below:
+Download the pre-built `vllm-ascend` docker image, or build the unified-cache-management docker image with the commands below:
```bash
# Build docker image using source code, replace with the branch or tag name needed
git clone --depth 1 --branch https://github.com/ModelEngine-Group/unified-cache-management.git
diff --git a/docs/source/getting-started/quick_start.md b/docs/source/getting-started/quick_start.md
index 5ad56b289..098c2eeb7 100644
--- a/docs/source/getting-started/quick_start.md
+++ b/docs/source/getting-started/quick_start.md
@@ -21,14 +21,16 @@ Before you start with UCM, please make sure that you have installed UCM correctl
## Features Overview
-UCM supports two key features: **Prefix Cache** and **GSA Sparsity**.
+UCM supports two key features: **Prefix Cache** and **Sparse Attention**.
Each feature supports both **Offline Inference** and **Online API** modes.
For quick start, just follow the [usage](./quick_start.md) guide below to launch your own inference experience;
-For further research, click on the links blow to see more details of each feature:
+For further details on Prefix Cache, see the link below:
- [Prefix Cache](../user-guide/prefix-cache/index.md)
+
+Several Sparse Attention features are now available; try GSA Sparsity via the link below:
- [GSA Sparsity](../user-guide/sparse-attention/gsa.md)
## Usage
@@ -40,12 +42,14 @@ You can use our official offline example script to run offline inference as foll
```bash
cd examples/
+# Change the model path to your own model path
+export MODEL_PATH=/home/models/Qwen2.5-14B-Instruct
python offline_inference.py
```
-
+
OpenAI-Compatible Online API
For online inference , vLLM with our connector can also be deployed as a server that implements the OpenAI API protocol.
@@ -55,10 +59,23 @@ First, specify the python hash seed by:
export PYTHONHASHSEED=123456
```
-Run the following command to start the vLLM server with the Qwen/Qwen2.5-14B-Instruct model:
+Create a config YAML like the following and save it to your own directory:
+```yaml
+# UCM Configuration File Example
+# Refer to file unified-cache-management/examples/ucm_config_example.yaml for more details
+ucm_connector_name: "UcmNfsStore"
+
+ucm_connector_config:
+ storage_backends: "/mnt/test"
+```
+
+Run the following command to start the vLLM server with the Qwen/Qwen2.5-14B-Instruct model and your config file path:
```bash
-vllm serve /home/models/Qwen2.5-14B-Instruct \
+# Change the model path to your own model path
+export MODEL_PATH=/home/models/Qwen2.5-14B-Instruct
+vllm serve ${MODEL_PATH} \
+--served-model-name vllm_cpu_offload \
--max-model-len 20000 \
--tensor-parallel-size 2 \
--gpu_memory_utilization 0.87 \
@@ -66,15 +83,11 @@ vllm serve /home/models/Qwen2.5-14B-Instruct \
--port 7800 \
--kv-transfer-config \
'{
- "kv_connector": "UnifiedCacheConnectorV1",
- "kv_connector_module_path": "ucm.integration.vllm.uc_connector",
+ "kv_connector": "UCMConnector",
+ "kv_connector_module_path": "ucm.integration.vllm.ucm_connector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
- "ucm_connector_name": "UcmDramStore",
- "ucm_connector_config": {
- "max_cache_size": 5368709120,
- "kv_block_size": 262144
- }
+ "UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"
}
}'
```
@@ -95,7 +108,7 @@ After successfully started the vLLM server,You can interact with the API as fo
curl http://localhost:7800/v1/completions \
-H "Content-Type: application/json" \
-d '{
- "model": "/home/models/Qwen2.5-14B-Instruct",
+ "model": "vllm_cpu_offload",
"prompt": "Shanghai is a",
"max_tokens": 7,
"temperature": 0
@@ -103,3 +116,4 @@ curl http://localhost:7800/v1/completions \
```
+Note: To disable vLLM's built-in prefix caching and test UCM's cache ability, add `--no-enable-prefix-caching` to the command line.
\ No newline at end of file
diff --git a/docs/source/index.md b/docs/source/index.md
index 2352d3996..69be815e6 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -57,6 +57,7 @@ getting-started/installation_npu
user-guide/prefix-cache/index
user-guide/sparse-attention/index
user-guide/pd-disaggregation/index
+user-guide/metrics/metrics
:::
:::{toctree}
diff --git a/docs/source/user-guide/metrics/metrics.md b/docs/source/user-guide/metrics/metrics.md
new file mode 100644
index 000000000..22b532681
--- /dev/null
+++ b/docs/source/user-guide/metrics/metrics.md
@@ -0,0 +1,193 @@
+# Observability
+
+UCM (Unified Cache Management) provides detailed metrics monitoring through Prometheus endpoints, allowing in-depth monitoring of cache performance and behavior. This document describes how to enable and configure observability from the embedded vLLM `/metrics` API endpoint.
+
+---
+
+## Quick Start Guide
+
+### 1) On UCM Side
+
+First, set the `PROMETHEUS_MULTIPROC_DIR` environment variable.
+
+```bash
+export PROMETHEUS_MULTIPROC_DIR=/vllm-workspace
+```
+
+Then, start the UCM service.
+
+```bash
+export CUDA_VISIBLE_DEVICES=0
+vllm serve /home/models/Qwen2.5-14B-Instruct \
+ --max-model-len 5000 \
+ --tensor-parallel-size 1 \
+ --gpu_memory_utilization 0.87 \
+ --trust-remote-code \
+ --disable-log-requests \
+ --no-enable-prefix-caching \
+ --enforce-eager \
+ --max-num-batched-tokens 40000 \
+ --max-num-seqs 10 \
+ --host 0.0.0.0 \
+ --port 8000 \
+ --kv-transfer-config \
+ '{
+ "kv_connector": "UCMConnector",
+ "kv_connector_module_path": "ucm.integration.vllm.ucm_connector",
+ "kv_role": "kv_both",
+ "kv_connector_extra_config": {
+ "UCM_CONFIG_FILE": "/vllm-workspace/unified-cache-management/examples/ucm_config.yaml"
+ }
+ }'
+```
+**Note**: You can refer to the `ucm_config.yaml` file at https://github.com/ModelEngine-Group/unified-cache-management/tree/develop/examples to configure the `metrics_config_path` parameter.
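+
+For reference, a minimal sketch of such a config is shown below; the exact layout is defined by the example file linked above, and the top-level placement of `metrics_config_path` here is an assumption for illustration:
+
+```yaml
+# Minimal ucm_config.yaml sketch (illustrative; see the examples/ directory in the repository)
+ucm_connector_name: "UcmNfsStore"
+
+ucm_connector_config:
+  storage_backends: "/mnt/test"
+
+# assumed key per the note above; points at the metrics configuration file
+metrics_config_path: "/vllm-workspace/unified-cache-management/examples/metrics/metrics_configs.yaml"
+```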
+
+You can use the `vllm bench serve` command to run benchmarks:
+
+```bash
+vllm bench serve \
+ --backend vllm \
+ --model /home/models/Qwen2.5-14B-Instruct \
+ --host 127.0.0.1 \
+ --port 8000 \
+ --dataset-name random \
+ --num-prompts 20 \
+ --random-input-len 200 \
+ --random-output-len 10 \
+ --request-rate 1 \
+ --ignore-eos
+```
+
+Once the HTTP server is running, you can access the UCM metrics at the `/metrics` endpoint.
+
+```bash
+curl http://localhost:8000/metrics | grep ucm:
+```
+
+You will also find some `.db` files in the `$PROMETHEUS_MULTIPROC_DIR` directory, which are temporary files used by Prometheus.
+
+### 2) Start Prometheus and Grafana with Docker Compose
+
+#### Create Docker Compose Configuration Files
+
+First, create the `docker-compose.yaml` file:
+
+```yaml
+# docker-compose.yaml
+version: "3"
+
+services:
+ prometheus:
+ image: prom/prometheus:latest
+ extra_hosts:
+ - "host.docker.internal:host-gateway"
+ ports:
+ - "9090:9090"
+ volumes:
+ - ${PWD}/prometheus.yaml:/etc/prometheus/prometheus.yml
+
+ grafana:
+ image: grafana/grafana:latest
+ depends_on:
+ - prometheus
+ ports:
+ - "3000:3000"
+```
+
+Then, create the `prometheus.yaml` configuration file:
+
+```yaml
+# prometheus.yaml
+global:
+ scrape_interval: 5s
+ evaluation_interval: 30s
+
+scrape_configs:
+ - job_name: vllm
+ static_configs:
+ - targets:
+ - 'host.docker.internal:8000'
+```
+
+**Note**: Make sure the port number in `prometheus.yaml` matches the port number used when starting the vLLM service.
+
+#### Start Services
+
+Run the following command in the directory containing `docker-compose.yaml` and `prometheus.yaml`:
+
+```bash
+docker compose up
+```
+
+This will start Prometheus and Grafana services.
+
+### 3) Configure Grafana Dashboard
+
+#### Access Grafana
+
+Navigate to `http://<your-server-ip>:3000`. Log in with the default username (`admin`) and password (`admin`). You will be prompted to change the password on first login.
+
+#### Add Prometheus Data Source
+
+1. Navigate to `http://<your-server-ip>:3000/connections/datasources/new` and select **Prometheus**.
+
+2. On the Prometheus configuration page, add the Prometheus server URL in the **Connection** section. For this Docker Compose setup, Grafana and Prometheus run in separate containers, but Docker creates DNS names for each container. You can directly use `http://prometheus:9090`.
+
+3. Click **Save & Test**. You should see a green checkmark showing "Successfully queried the Prometheus API."
+
+#### Import Dashboard
+
+1. Navigate to `http://<your-server-ip>:3000/dashboard/import`.
+
+2. Click **Upload JSON file**, then upload the `unified-cache-management/examples/metrics/grafana.json` file.
+
+3. Select the Prometheus data source configured earlier.
+
+4. Click **Import** to complete the import.
+
+You should now be able to see the UCM monitoring dashboard with real-time visualization of all 9 metrics.
+
+## Available Metrics
+
+UCM exposes various metrics to monitor its performance. The following table lists all available metrics organized by category:
+
+| Metric Name | Type | Description |
+|------------|------|-------------|
+| **Load Operation Metrics** | | |
+| `ucm:load_requests_num` | Histogram | Number of requests loaded per `start_load_kv` call |
+| `ucm:load_blocks_num` | Histogram | Number of blocks loaded per `start_load_kv` call |
+| `ucm:load_duration` | Histogram | Time to load KV cache from UCM (milliseconds) |
+| `ucm:load_speed` | Histogram | Speed of loading from UCM (GB/s) |
+| **Save Operation Metrics** | | |
+| `ucm:save_requests_num` | Histogram | Number of requests saved per `wait_for_save` call |
+| `ucm:save_blocks_num` | Histogram | Number of blocks saved per `wait_for_save` call |
+| `ucm:save_duration` | Histogram | Time to save to UCM (milliseconds) |
+| `ucm:save_speed` | Histogram | Speed of saving to UCM (GB/s) |
+| **Lookup Hit Rate Metrics** | | |
+| `ucm:interval_lookup_hit_rates` | Histogram | Hit rate of UCM lookup requests |
+
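+Because all of these metrics are Prometheus histograms, percentiles can be derived with `histogram_quantile` over the exported `_bucket` series. A minimal sketch, assuming the default `ucm:` prefix, a Prometheus instance reachable at `localhost:9090`, and that the histogram is exported as `ucm:load_duration` with the standard `_bucket` suffix:
+
+```bash
+# p95 KV-cache load duration (ms) over the last 5 minutes
+curl -G 'http://localhost:9090/api/v1/query' \
+  --data-urlencode 'query=histogram_quantile(0.95, sum(rate(ucm:load_duration_bucket[5m])) by (le))'
+```
+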
+## Prometheus Configuration
+
+Metrics configuration is defined in the `unified-cache-management/examples/metrics/metrics_configs.yaml` file:
+
+```yaml
+log_interval: 5 # Interval in seconds for logging metrics
+
+prometheus:
+ multiproc_dir: "/vllm-workspace" # Prometheus directory
+ metric_prefix: "ucm:" # Metric name prefix
+
+ enabled_metrics:
+ counters: true
+ gauges: true
+ histograms: true
+
+ histograms:
+ - name: "load_requests_num"
+ documentation: "Number of requests loaded from ucm"
+ buckets: [1, 5, 10, 20, 50, 100, 200, 500, 1000]
+ # ... other metric configurations
+```
+
+---
+
diff --git a/docs/source/user-guide/pd-disaggregation/1p1d.md b/docs/source/user-guide/pd-disaggregation/1p1d.md
index a6e2547f4..fb3f4d056 100644
--- a/docs/source/user-guide/pd-disaggregation/1p1d.md
+++ b/docs/source/user-guide/pd-disaggregation/1p1d.md
@@ -5,16 +5,17 @@ This example demonstrates how to run unified-cache-management with disaggregated
## Prerequisites
- UCM: Installed with reference to the Installation documentation.
-- Hardware: At least 2 GPUs
+- Hardware: At least 2 GPUs or 2 NPUs
## Start disaggregated service
-For illustration purposes, let us assume that the model used is Qwen2.5-7B-Instruct.
+For illustration purposes, let us take GPU as an example and assume the model used is Qwen2.5-7B-Instruct. On the Ascend platform, use ASCEND_RT_VISIBLE_DEVICES instead of CUDA_VISIBLE_DEVICES to specify visible devices when starting the service.
### Run prefill server
Prefiller Launch Command:
```bash
export PYTHONHASHSEED=123456
-CUDA_VISIBLE_DEVICES=0 vllm serve /home/models/Qwen2.5-7B-Instruct \
+export CUDA_VISIBLE_DEVICES=0
+vllm serve /home/models/Qwen2.5-7B-Instruct \
--max-model-len 20000 \
--tensor-parallel-size 1 \
--gpu_memory_utilization 0.87 \
@@ -41,8 +42,9 @@ CUDA_VISIBLE_DEVICES=0 vllm serve /home/models/Qwen2.5-7B-Instruct \
### Run decode server
Decoder Launch Command:
```bash
-export PYTHONHASHSEED=123456
-CUDA_VISIBLE_DEVICES=1 vllm serve /home/models/Qwen2.5-7B-Instruct \
+export PYTHONHASHSEED=123456
+export CUDA_VISIBLE_DEVICES=1
+vllm serve /home/models/Qwen2.5-7B-Instruct \
--max-model-len 20000 \
--tensor-parallel-size 1 \
--gpu_memory_utilization 0.87 \
@@ -68,8 +70,8 @@ CUDA_VISIBLE_DEVICES=1 vllm serve /home/models/Qwen2.5-7B-Instruct \
### Run proxy server
Make sure prefill nodes and decode nodes can connect to each other.
```bash
-cd vllm-workspace/unified-cache-management/test/
-python3 toy_proxy_server.py --host localhost --port 7802 --prefiller-host --prefiller-port 7800 --decoder-host --decoder-port 7801
+cd /vllm-workspace/unified-cache-management/ucm/pd
+python3 toy_proxy_server.py --pd-disaggregation --host localhost --port 7802 --prefiller-host <prefiller-ip> --prefiller-port 7800 --decoder-host <decoder-ip> --decoder-port 7801
```
## Testing and Benchmarking
@@ -88,8 +90,7 @@ curl http://localhost:7802/v1/completions \
### Benchmark Test
Use the benchmark scripts provided by vLLM.
```bash
-cd /vllm-workspace/vllm/benchmarks
-python3 benchmark_serving.py \
+vllm bench serve \
--backend vllm \
--dataset-name random \
--random-input-len 4096 \
diff --git a/docs/source/user-guide/pd-disaggregation/npgd.md b/docs/source/user-guide/pd-disaggregation/npgd.md
index e7dd10bf7..c4919779a 100644
--- a/docs/source/user-guide/pd-disaggregation/npgd.md
+++ b/docs/source/user-guide/pd-disaggregation/npgd.md
@@ -50,7 +50,8 @@ vllm serve /home/models/Qwen2.5-7B-Instruct \
Decoder Launch Command:
```bash
export PYTHONHASHSEED=123456
-CUDA_VISIBLE_DEVICES=0 vllm serve /home/models/Qwen2.5-7B-Instruct \
+export CUDA_VISIBLE_DEVICES=0
+vllm serve /home/models/Qwen2.5-7B-Instruct \
--max-model-len 20000 \
--tensor-parallel-size 1 \
--gpu_memory_utilization 0.87 \
@@ -77,7 +78,7 @@ CUDA_VISIBLE_DEVICES=0 vllm serve /home/models/Qwen2.5-7B-Instruct \
### Run proxy server
Make sure prefill nodes and decode nodes can connect to each other.
```bash
-cd vllm-workspace/unified-cache-management/test/
+cd /vllm-workspace/unified-cache-management/ucm/pd
python3 toy_proxy_server.py --host localhost --port 7802 --prefiller-host --prefiller-port 7800 --decoder-host --decoder-port 7801
```
@@ -97,8 +98,7 @@ curl http://localhost:7802/v1/completions \
### Benchmark Test
Use the benchmark scripts provided by vLLM.
```bash
-cd /vllm-workspace/vllm/benchmarks
-python3 benchmark_serving.py \
+vllm bench serve \
--backend vllm \
--dataset-name random \
--random-input-len 4096 \
diff --git a/docs/source/user-guide/pd-disaggregation/xpyd.md b/docs/source/user-guide/pd-disaggregation/xpyd.md
index c7fde1e9f..a57ab5d2f 100644
--- a/docs/source/user-guide/pd-disaggregation/xpyd.md
+++ b/docs/source/user-guide/pd-disaggregation/xpyd.md
@@ -5,15 +5,17 @@ This example demonstrates how to run unified-cache-management with disaggregated
## Prerequisites
- UCM: Installed with reference to the Installation documentation.
-- Hardware: At least 4 GPUs (At least 2 GPUs for prefiller + 2 for decoder in 2d2p setup)
+- Hardware: At least 4 GPUs or 4 NPUs (at least 2 for the prefillers + 2 for the decoders in the 2d2p setup)
## Start disaggregated service
-For illustration purposes, let us assume that the model used is Qwen2.5-7B-Instruct.
+For illustration purposes, let us take GPU as an example and assume the model used is Qwen2.5-7B-Instruct. On the Ascend platform, use ASCEND_RT_VISIBLE_DEVICES instead of CUDA_VISIBLE_DEVICES to specify visible devices when starting the service.
+
### Run prefill servers
Prefiller1 Launch Command:
```bash
export PYTHONHASHSEED=123456
-CUDA_VISIBLE_DEVICES=0 vllm serve /home/models/Qwen2.5-7B-Instruct \
+export CUDA_VISIBLE_DEVICES=0
+vllm serve /home/models/Qwen2.5-7B-Instruct \
--max-model-len 20000 \
--tensor-parallel-size 1 \
--gpu_memory_utilization 0.87 \
@@ -40,7 +42,8 @@ CUDA_VISIBLE_DEVICES=0 vllm serve /home/models/Qwen2.5-7B-Instruct \
Prefiller2 Launch Command:
```bash
export PYTHONHASHSEED=123456
-CUDA_VISIBLE_DEVICES=1 vllm serve /home/models/Qwen2.5-7B-Instruct \
+export CUDA_VISIBLE_DEVICES=1
+vllm serve /home/models/Qwen2.5-7B-Instruct \
--max-model-len 20000 \
--tensor-parallel-size 1 \
--gpu_memory_utilization 0.87 \
@@ -68,7 +71,8 @@ CUDA_VISIBLE_DEVICES=1 vllm serve /home/models/Qwen2.5-7B-Instruct \
Decoder1 Launch Command:
```bash
export PYTHONHASHSEED=123456
-CUDA_VISIBLE_DEVICES=2 vllm serve /home/models/Qwen2.5-7B-Instruct \
+export CUDA_VISIBLE_DEVICES=2
+vllm serve /home/models/Qwen2.5-7B-Instruct \
--max-model-len 20000 \
--tensor-parallel-size 1 \
--gpu_memory_utilization 0.87 \
@@ -94,7 +98,8 @@ CUDA_VISIBLE_DEVICES=2 vllm serve /home/models/Qwen2.5-7B-Instruct \
Decoder2 Launch Command:
```bash
export PYTHONHASHSEED=123456
-CUDA_VISIBLE_DEVICES=3 vllm serve /home/models/Qwen2.5-7B-Instruct \
+export CUDA_VISIBLE_DEVICES=3
+vllm serve /home/models/Qwen2.5-7B-Instruct \
--max-model-len 20000 \
--tensor-parallel-size 1 \
--gpu_memory_utilization 0.87 \
@@ -121,8 +126,8 @@ CUDA_VISIBLE_DEVICES=3 vllm serve /home/models/Qwen2.5-7B-Instruct \
### Run proxy server
Make sure prefill nodes and decode nodes can connect to each other. the number of prefill/decode hosts should be equal to the number of prefill/decode ports.
```bash
-cd vllm-workspace/unified-cache-management/test/
-python3 toy_proxy_server.py --host localhost --port 7805 --prefiller-hosts --prefiller-port 7800 7801 --decoder-hosts --decoder-ports 7802 7803
+cd /vllm-workspace/unified-cache-management/ucm/pd
+python3 toy_proxy_server.py --pd-disaggregation --host localhost --port 7805 --prefiller-hosts <prefiller1-ip> <prefiller2-ip> --prefiller-port 7800 7801 --decoder-hosts <decoder1-ip> <decoder2-ip> --decoder-ports 7802 7803
```
## Testing and Benchmarking
@@ -141,8 +146,7 @@ curl http://localhost:7805/v1/completions \
### Benchmark Test
Use the benchmark scripts provided by vLLM.
```bash
-cd /vllm-workspace/vllm/benchmarks
-python3 benchmark_serving.py \
+vllm bench serve \
--backend vllm \
--dataset-name random \
--random-input-len 4096 \
diff --git a/docs/source/user-guide/prefix-cache/dram_store.md b/docs/source/user-guide/prefix-cache/dram_store.md
deleted file mode 100644
index 1be2f30a2..000000000
--- a/docs/source/user-guide/prefix-cache/dram_store.md
+++ /dev/null
@@ -1,133 +0,0 @@
-# DRAM Store
-
-This document provides a usage example and configuration guide for the **DRAM Connector**. This connector enables offloading of KV cache from GPU HBM to CPU DRAM, helping reduce memory pressure and supporting larger models or batch sizes.
-
-## Performance
-
-### Overview
-The following are the multi-concurrency performance test results of UCM in the Prefix Cache scenario under a CUDA environment, showing the performance improvements of UCM on two different models.
-During the tests, HBM cache was disabled, and KV Cache was retrieved and matched only from DRAM.
-
-In the QwQ-32B model, the test used one H20 server with 2 GPUs.
-
-Here, Full Compute refers to pure VLLM inference, while DRAM80% indicates that after UCM pooling, the DRAM hit rate of the KV cache is 80%.
-
-The following table shows the results on the QwQ-32B model:
-| **QwQ-32B** | | | | |
-| ---------------: | -------------: | ------------------: | -------------: | :----------- |
-| **Input length** | **Concurrent** | **Full Compute(s)** | **DRAM80%(s)** | **Speedup** |
-| 4 000 | 1 | 1.0269 | 0.3102 | **+230.9 %** |
-| 8 000 | 1 | 2.0902 | 0.5718 | **+265.5 %** |
-| 16 000 | 1 | 4.4852 | 1.1914 | **+276.4 %** |
-| 4 000 | 2 | 1.5383 | 0.4209 | **+265.4 %** |
-| 8 000 | 2 | 3.1323 | 0.8231 | **+280.5 %** |
-| 16 000 | 2 | 6.7984 | 1.7420 | **+290.2 %** |
-| 4 000 | 4 | 2.8173 | 0.9444 | **+198.2 %** |
-| 8 000 | 4 | 5.2643 | 1.8290 | **+187.8 %** |
-| 16 000 | 4 | 11.3651 | 3.6706 | **+209.6 %** |
-## Features
-
-The DRAM connector supports the following functionalities:
-
-- `dump`: Offload KV cache blocks from HBM to DRAM.
-- `load`: Load KV cache blocks from DRAM back to HBM.
-- `lookup`: Look up KV blocks stored in DRAM by block hash.
-- `wait`: Ensure that all copy streams between CPU and GPU have completed.
-- `commit`: Mark cache operations as complete and ready for reuse.
-
-## Configuration
-
-To use the DRAM connector, you need to configure the `connector_config` dictionary in your model's launch configuration.
-
-### Required Parameters
-
-- `max_cache_size` *(optional)*:
- Specifies the maximum allowed DRAM memory usage (in **bytes**) for caching in `kv_connector_extra_config["ucm_connector_config"]`.
- If not provided, it defaults to **5 GB**.
-- `kv_block_size` *(optional)*:
- Specifies the memory size (in **bytes**) of a single key or value cache block used in vLLM’s paged attention mechanism, which is calculated as : `block_size * head_size * total_num_kv_heads * element_size`.
-
-### Example:
-
-```python
-# Allocate up to 8GB DRAM for KV cache
-# KV Block size (in byte) is 262144
-kv_connector_extra_config={"ucm_connector_name": "UcmDramStore", "ucm_connector_config":{"max_cache_size": 5368709120, "kv_block_size": 262144}}
-```
-
-## Launching Inference
-
-### Offline Inference
-
-To start **offline inference** with the DRAM connector,modify the script `examples/offline_inference.py` to include the `kv_connector_extra_config` for DRAM connector usage:
-
-```python
-# In examples/offline_inference.py
-ktc = KVTransferConfig(
- ...
- kv_connector_extra_config={"ucm_connector_name": "UcmDramStore", "ucm_connector_config":{"max_cache_size": 5368709120, "kv_block_size": 262144}}
-)
-```
-
-Then run the script as follows:
-
-```bash
-cd examples/
-python offline_inference.py
-```
-
-### Online Inference
-
-For **online inference** , vLLM with our connector can also be deployed as a server that implements the OpenAI API protocol.
-
-First, specify the python hash seed by:
-```bash
-export PYTHONHASHSEED=123456
-```
-
-Run the following command to start the vLLM server with the Qwen/Qwen2.5-14B-Instruct model:
-
-```bash
-vllm serve /home/models/Qwen2.5-14B-Instruct \
---max-model-len 20000 \
---tensor-parallel-size 2 \
---gpu_memory_utilization 0.87 \
---trust-remote-code \
---port 7800 \
---kv-transfer-config \
-'{
- "kv_connector": "UnifiedCacheConnectorV1",
- "kv_connector_module_path": "ucm.integration.vllm.uc_connector",
- "kv_role": "kv_both",
- "kv_connector_extra_config": {
- "ucm_connector_name": "UcmDramStore",
- "ucm_connector_config": {
- "max_cache_size": 5368709120,
- "kv_block_size": 262144
- }
- }
-}'
-```
-
-If you see log as below:
-
-```bash
-INFO: Started server process [32890]
-INFO: Waiting for application startup.
-INFO: Application startup complete.
-```
-
-Congratulations, you have successfully started the vLLM server with DRAM Connector!
-
-After successfully started the vLLM server,You can interact with the API as following:
-
-```bash
-curl http://localhost:7800/v1/completions \
- -H "Content-Type: application/json" \
- -d '{
- "model": "/home/models/Qwen2.5-14B-Instruct",
- "prompt": "Shanghai is a",
- "max_tokens": 7,
- "temperature": 0
- }'
-```
diff --git a/docs/source/user-guide/prefix-cache/index.md b/docs/source/user-guide/prefix-cache/index.md
index defe27d38..ba3d16bef 100644
--- a/docs/source/user-guide/prefix-cache/index.md
+++ b/docs/source/user-guide/prefix-cache/index.md
@@ -79,6 +79,5 @@ performance.
:::{toctree}
:maxdepth: 1
-dram_store
nfs_store
:::
\ No newline at end of file
diff --git a/docs/source/user-guide/prefix-cache/nfs_store.md b/docs/source/user-guide/prefix-cache/nfs_store.md
index b581acf56..741fcedf7 100644
--- a/docs/source/user-guide/prefix-cache/nfs_store.md
+++ b/docs/source/user-guide/prefix-cache/nfs_store.md
@@ -87,8 +87,15 @@ To use the NFS connector, you need to configure the `connector_config` dictionar
### Example:
-```python
-kv_connector_extra_config={"ucm_connector_name": "UcmNfsStore", "ucm_connector_config":{"storage_backends": "/mnt/test1", "transferStreamNumber": 32}}
+Create a config YAML like the following and save it to your own directory:
+```yaml
+# UCM Configuration File Example
+# Refer to file unified-cache-management/examples/ucm_config_example.yaml for more details
+ucm_connector_name: "UcmNfsStore"
+
+ucm_connector_config:
+ storage_backends: "/mnt/test"
+ transferStreamNumber: 32
```
## Launching Inference
@@ -101,7 +108,7 @@ To start **offline inference** with the NFS connector,modify the script `examp
# In examples/offline_inference.py
ktc = KVTransferConfig(
...
- kv_connector_extra_config={"ucm_connector_name": "UcmNfsStore", "ucm_connector_config":{"storage_backends": "/mnt/test1", "transferStreamNumber": 32}}
+ kv_connector_extra_config={"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"}
)
```
@@ -131,13 +138,7 @@ vllm serve /home/models/Qwen2.5-14B-Instruct \
"kv_connector": "UnifiedCacheConnectorV1",
"kv_connector_module_path": "ucm.integration.vllm.uc_connector",
"kv_role": "kv_both",
- "kv_connector_extra_config": {
- "ucm_connector_name": "UcmNfsStore",
- "ucm_connector_config": {
- "storage_backends": "/mnt/test",
- "transferStreamNumber":32
- }
- }
+ "kv_connector_extra_config": {"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"}
}'
```
diff --git a/docs/source/user-guide/sparse-attention/esa.md b/docs/source/user-guide/sparse-attention/esa.md
index 3d42ba52e..53beadf10 100644
--- a/docs/source/user-guide/sparse-attention/esa.md
+++ b/docs/source/user-guide/sparse-attention/esa.md
@@ -9,6 +9,7 @@ ESA provides developers with an intuitive example of how to implement their own
### Basic Usage
ESA can be launched using the following command:
```shell
+export ENABLE_SPARSE=TRUE
export MODEL_PATH="/path/to/model" # For example: /home/models/Qwen2.5-14B-Instruct
export DATASET_PATH="/path/to/longbench/multifieldqa_zh.jsonl" # For example: /home/data/Longbench/data/multifieldqa_zh.jsonl
python examples/offline_inference_esa.py
@@ -31,8 +32,8 @@ ktc = KVTransferConfig(
"init_window_sz": 1,
"local_window_sz": 2,
"min_blocks": 4,
- "sparse_ratio": 0.3,
- "retrieval_stride": 5,
+ "sparse_ratio": 0.2,
+ "retrieval_stride": 10,
}
},
},
@@ -80,8 +81,8 @@ The following results were obtained using `Qwen2.5-14B-Instruct` under the speci
"init_window_sz": 1,
"local_window_sz": 2,
"min_blocks": 4,
- "sparse_ratio": 0.3,
- "retrieval_stride": 5
+ "sparse_ratio": 0.2,
+ "retrieval_stride": 10
}
},
```
@@ -92,5 +93,5 @@ The following results were obtained using `Qwen2.5-14B-Instruct` under the speci
We use [LongBench](https://huggingface.co/datasets/zai-org/LongBench) to evaluate the accuracy of the ESA algorithm.
| Dataset | F1-Score |
|-------|-----------|
-| multifieldqa_zh | 59.4 |
-| dureader | 26.4 |
\ No newline at end of file
+| multifieldqa_zh | 64.28 |
+| dureader | 28.73 |
diff --git a/docs/source/user-guide/sparse-attention/gsa.md b/docs/source/user-guide/sparse-attention/gsa.md
index 327fe0769..5a96287a3 100644
--- a/docs/source/user-guide/sparse-attention/gsa.md
+++ b/docs/source/user-guide/sparse-attention/gsa.md
@@ -107,6 +107,8 @@ ktc = KVTransferConfig(
Thus, an example command for launching the online LLM service is as follows:
```shell
+export ENABLE_SPARSE=TRUE
+
vllm serve /home/models/DeepSeek-R1-Distill-Qwen-32B \
--served-model-name DeepSeek-R1-Distill-Qwen-32B \
--max-model-len 131000 \
diff --git a/docs/source/user-guide/sparse-attention/kvcomp.md b/docs/source/user-guide/sparse-attention/kvcomp.md
index 3c1d0b238..4e0cbc715 100644
--- a/docs/source/user-guide/sparse-attention/kvcomp.md
+++ b/docs/source/user-guide/sparse-attention/kvcomp.md
@@ -97,6 +97,7 @@ This design ensures both **efficiency** and **accuracy** by preserving essential
KVComp is part of the UCM Sparse Attention module. For installation instructions, please refer to the [UCM's top-level README](https://github.com/ModelEngine-Group/unified-cache-management). Once UCM is installed, KVComp is naturally supported by running the following example python scripts.
```bash
+export ENABLE_SPARSE=TRUE
python ucm/sandbox/sparse/kvcomp/offline_inference_kvcomp.py
```
diff --git a/docs/source/user-guide/sparse-attention/kvstar.md b/docs/source/user-guide/sparse-attention/kvstar.md
index cf6222158..41d13358b 100644
--- a/docs/source/user-guide/sparse-attention/kvstar.md
+++ b/docs/source/user-guide/sparse-attention/kvstar.md
@@ -32,6 +32,7 @@ For long-sequence inference, KVstar achieves the following with minimal accuracy
### Basic Usage
KVstar can be launched using the following command:
```shell
+export ENABLE_SPARSE=TRUE
export MODEL_PATH="/path/to/model" # For example: /home/models/Qwen2.5-14B-Instruct
export DATASET_PATH="/path/to/longbench/multifieldqa_zh.jsonl" # For example: /home/data/Longbench/data/multifieldqa_zh.jsonl
export DATA_DIR="/path/to/data"
diff --git a/examples/metrics/grafana.json b/examples/metrics/grafana.json
new file mode 100644
index 000000000..72d175971
--- /dev/null
+++ b/examples/metrics/grafana.json
@@ -0,0 +1,1025 @@
+{
+ "annotations": {
+ "list": [
+ {
+ "builtIn": 1,
+ "datasource": {
+ "type": "grafana",
+ "uid": "-- Grafana --"
+ },
+ "enable": true,
+ "hide": true,
+ "iconColor": "rgba(0, 211, 255, 1)",
+ "name": "Annotations & Alerts",
+ "target": {
+ "limit": 100,
+ "matchAny": false,
+ "tags": [],
+ "type": "dashboard"
+ },
+ "type": "dashboard"
+ }
+ ]
+ },
+ "description": "Monitoring UnifiedCache Connector Service (load/save)",
+ "editable": true,
+ "fiscalYearStartMonth": 0,
+ "graphTooltip": 0,
+ "id": 1,
+ "links": [],
+ "liveNow": false,
+ "panels": [
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "description": "Hit rates of ucm lookup requests",
+ "fieldConfig": {
+ "defaults": {
+ "color": {
+ "mode": "palette-classic"
+ },
+ "custom": {
+ "axisBorderShow": false,
+ "axisCenteredZero": false,
+ "axisColorMode": "text",
+ "axisLabel": "",
+ "axisPlacement": "auto",
+ "barAlignment": 0,
+ "barWidthFactor": 0.6,
+ "drawStyle": "line",
+ "fillOpacity": 0,
+ "gradientMode": "none",
+ "hideFrom": {
+ "legend": false,
+ "tooltip": false,
+ "viz": false
+ },
+ "insertNulls": false,
+ "lineInterpolation": "linear",
+ "lineWidth": 1,
+ "pointSize": 5,
+ "scaleDistribution": {
+ "type": "linear"
+ },
+ "showPoints": "auto",
+ "spanNulls": false,
+ "stacking": {
+ "group": "A",
+ "mode": "none"
+ },
+ "thresholdsStyle": {
+ "mode": "off"
+ }
+ },
+ "mappings": [],
+ "thresholds": {
+ "mode": "absolute",
+ "steps": [
+ {
+ "color": "green",
+ "value": null
+ },
+ {
+ "color": "yellow",
+ "value": 5000
+ },
+ {
+ "color": "red",
+ "value": 10000
+ }
+ ]
+ },
+ "unit": ""
+ },
+ "overrides": []
+ },
+ "gridPos": {
+ "h": 8,
+ "w": 24,
+ "x": 0,
+ "y": 0
+ },
+ "id": 14,
+ "options": {
+ "legend": {
+ "calcs": [],
+ "displayMode": "list",
+ "placement": "bottom",
+ "showLegend": true
+ },
+ "tooltip": {
+ "mode": "single",
+ "sort": "none"
+ }
+ },
+ "targets": [
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "editorMode": "code",
+ "expr": "rate(ucm:interval_lookup_hit_rates_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:interval_lookup_hit_rates_count{model_name=\"$model_name\"}[$__rate_interval])",
+ "hide": false,
+ "instant": false,
+ "legendFormat": "Average",
+ "range": true,
+ "refId": "A"
+ }
+ ],
+ "title": "Connector Interval Lookup Hit Rates",
+ "type": "timeseries"
+ },
+
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "description": "Number of load requests each start_load_kv.",
+ "fieldConfig": {
+ "defaults": {
+ "color": {
+ "mode": "palette-classic"
+ },
+ "custom": {
+ "axisBorderShow": false,
+ "axisCenteredZero": false,
+ "axisColorMode": "text",
+ "axisLabel": "",
+ "axisPlacement": "auto",
+ "barAlignment": 0,
+ "barWidthFactor": 0.6,
+ "drawStyle": "line",
+ "fillOpacity": 0,
+ "gradientMode": "none",
+ "hideFrom": {
+ "legend": false,
+ "tooltip": false,
+ "viz": false
+ },
+ "insertNulls": false,
+ "lineInterpolation": "linear",
+ "lineWidth": 1,
+ "pointSize": 5,
+ "scaleDistribution": {
+ "type": "linear"
+ },
+ "showPoints": "auto",
+ "spanNulls": false,
+ "stacking": {
+ "group": "A",
+ "mode": "none"
+ },
+ "thresholdsStyle": {
+ "mode": "off"
+ }
+ },
+ "mappings": [],
+ "thresholds": {
+ "mode": "absolute",
+ "steps": [
+ {
+ "color": "green",
+ "value": null
+ },
+ {
+ "color": "yellow",
+ "value": 5000
+ },
+ {
+ "color": "red",
+ "value": 10000
+ }
+ ]
+ },
+ "unit": ""
+ },
+ "overrides": []
+ },
+ "gridPos": {
+ "h": 8,
+ "w": 12,
+ "x": 0,
+ "y": 0
+ },
+ "id": 15,
+ "options": {
+ "legend": {
+ "calcs": [],
+ "displayMode": "list",
+ "placement": "bottom",
+ "showLegend": true
+ },
+ "tooltip": {
+ "mode": "single",
+ "sort": "none"
+ }
+ },
+ "targets": [
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "editorMode": "code",
+ "expr": "rate(ucm:load_requests_num_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:load_requests_num_count{model_name=\"$model_name\"}[$__rate_interval])",
+ "hide": false,
+ "instant": false,
+ "legendFormat": "worker-{{worker_id}}",
+ "range": true,
+ "refId": "A"
+ }
+ ],
+ "title": "Connector Load Requests Num",
+ "type": "timeseries"
+ },
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "description": "Number of load blocks each start_load_kv.",
+ "fieldConfig": {
+ "defaults": {
+ "color": {
+ "mode": "palette-classic"
+ },
+ "custom": {
+ "axisBorderShow": false,
+ "axisCenteredZero": false,
+ "axisColorMode": "text",
+ "axisLabel": "",
+ "axisPlacement": "auto",
+ "barAlignment": 0,
+ "barWidthFactor": 0.6,
+ "drawStyle": "line",
+ "fillOpacity": 0,
+ "gradientMode": "none",
+ "hideFrom": {
+ "legend": false,
+ "tooltip": false,
+ "viz": false
+ },
+ "insertNulls": false,
+ "lineInterpolation": "linear",
+ "lineWidth": 1,
+ "pointSize": 5,
+ "scaleDistribution": {
+ "type": "linear"
+ },
+ "showPoints": "auto",
+ "spanNulls": false,
+ "stacking": {
+ "group": "A",
+ "mode": "none"
+ },
+ "thresholdsStyle": {
+ "mode": "off"
+ }
+ },
+ "mappings": [],
+ "thresholds": {
+ "mode": "absolute",
+ "steps": [
+ {
+ "color": "green",
+ "value": null
+ },
+ {
+ "color": "yellow",
+ "value": 5000
+ },
+ {
+ "color": "red",
+ "value": 10000
+ }
+ ]
+ },
+ "unit": ""
+ },
+ "overrides": []
+ },
+ "gridPos": {
+ "h": 8,
+ "w": 12,
+ "x": 12,
+ "y": 0
+ },
+ "id": 16,
+ "options": {
+ "legend": {
+ "calcs": [],
+ "displayMode": "list",
+ "placement": "bottom",
+ "showLegend": true
+ },
+ "tooltip": {
+ "mode": "single",
+ "sort": "none"
+ }
+ },
+ "targets": [
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "editorMode": "code",
+ "expr": "rate(ucm:load_blocks_num_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:load_blocks_num_count{model_name=\"$model_name\"}[$__rate_interval])",
+ "hide": false,
+ "instant": false,
+ "legendFormat": "worker-{{worker_id}}",
+ "range": true,
+ "refId": "A"
+ }
+ ],
+ "title": "Connector Load Blocks Num",
+ "type": "timeseries"
+ },
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "description": "P50, P90, P95, P99 and Average load duration in milliseconds for each start_load_kv.",
+ "fieldConfig": {
+ "defaults": {
+ "color": {
+ "mode": "palette-classic"
+ },
+ "custom": {
+ "axisBorderShow": false,
+ "axisCenteredZero": false,
+ "axisColorMode": "text",
+ "axisLabel": "",
+ "axisPlacement": "auto",
+ "barAlignment": 0,
+ "barWidthFactor": 0.6,
+ "drawStyle": "line",
+ "fillOpacity": 0,
+ "gradientMode": "none",
+ "hideFrom": {
+ "legend": false,
+ "tooltip": false,
+ "viz": false
+ },
+ "insertNulls": false,
+ "lineInterpolation": "linear",
+ "lineWidth": 1,
+ "pointSize": 5,
+ "scaleDistribution": {
+ "type": "linear"
+ },
+ "showPoints": "auto",
+ "spanNulls": false,
+ "stacking": {
+ "group": "A",
+ "mode": "none"
+ },
+ "thresholdsStyle": {
+ "mode": "off"
+ }
+ },
+ "mappings": [],
+ "thresholds": {
+ "mode": "absolute",
+ "steps": [
+ {
+ "color": "green",
+ "value": null
+ },
+ {
+ "color": "yellow",
+ "value": 5000
+ },
+ {
+ "color": "red",
+ "value": 10000
+ }
+ ]
+ },
+ "unit": "ms"
+ },
+ "overrides": []
+ },
+ "gridPos": {
+ "h": 8,
+ "w": 12,
+ "x": 0,
+ "y": 8
+ },
+ "id": 17,
+ "options": {
+ "legend": {
+ "calcs": [],
+ "displayMode": "list",
+ "placement": "bottom",
+ "showLegend": true
+ },
+ "tooltip": {
+ "mode": "single",
+ "sort": "none"
+ }
+ },
+ "targets": [
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "editorMode": "code",
+ "expr": "rate(ucm:load_duration_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:load_duration_count{model_name=\"$model_name\"}[$__rate_interval])",
+ "hide": false,
+ "instant": false,
+ "legendFormat": "worker-{{worker_id}}",
+ "range": true,
+ "refId": "A"
+ }
+ ],
+ "title": "Connector Load Duration",
+ "type": "timeseries"
+ },
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "description": "P50, P90, P95, P99 and Average load speed in GB/s for each start_load_kv.",
+ "fieldConfig": {
+ "defaults": {
+ "color": {
+ "mode": "palette-classic"
+ },
+ "custom": {
+ "axisBorderShow": false,
+ "axisCenteredZero": false,
+ "axisColorMode": "text",
+ "axisLabel": "",
+ "axisPlacement": "auto",
+ "barAlignment": 0,
+ "barWidthFactor": 0.6,
+ "drawStyle": "line",
+ "fillOpacity": 0,
+ "gradientMode": "none",
+ "hideFrom": {
+ "legend": false,
+ "tooltip": false,
+ "viz": false
+ },
+ "insertNulls": false,
+ "lineInterpolation": "linear",
+ "lineWidth": 1,
+ "pointSize": 5,
+ "scaleDistribution": {
+ "type": "linear"
+ },
+ "showPoints": "auto",
+ "spanNulls": false,
+ "stacking": {
+ "group": "A",
+ "mode": "none"
+ },
+ "thresholdsStyle": {
+ "mode": "off"
+ }
+ },
+ "mappings": [],
+ "thresholds": {
+ "mode": "absolute",
+ "steps": [
+ {
+ "color": "red",
+ "value": null
+ },
+ {
+ "color": "yellow",
+ "value": 0.005
+ },
+ {
+ "color": "green",
+ "value": 0.01
+ }
+ ]
+ },
+ "unit": "gb/s"
+ },
+ "overrides": []
+ },
+ "gridPos": {
+ "h": 8,
+ "w": 12,
+ "x": 12,
+ "y": 8
+ },
+ "id": 21,
+ "options": {
+ "legend": {
+ "calcs": [],
+ "displayMode": "list",
+ "placement": "bottom",
+ "showLegend": true
+ },
+ "tooltip": {
+ "mode": "single",
+ "sort": "none"
+ }
+ },
+ "targets": [
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "editorMode": "code",
+ "expr": "rate(ucm:load_speed_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:load_speed_count{model_name=\"$model_name\"}[$__rate_interval])",
+ "hide": false,
+ "instant": false,
+ "legendFormat": "worker-{{worker_id}}",
+ "range": true,
+ "refId": "A"
+ }
+ ],
+ "title": "Connector Load Speed",
+ "type": "timeseries"
+ },
+
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "description": "Number of save requests each wait_for_save.",
+ "fieldConfig": {
+ "defaults": {
+ "color": {
+ "mode": "palette-classic"
+ },
+ "custom": {
+ "axisBorderShow": false,
+ "axisCenteredZero": false,
+ "axisColorMode": "text",
+ "axisLabel": "",
+ "axisPlacement": "auto",
+ "barAlignment": 0,
+ "barWidthFactor": 0.6,
+ "drawStyle": "line",
+ "fillOpacity": 0,
+ "gradientMode": "none",
+ "hideFrom": {
+ "legend": false,
+ "tooltip": false,
+ "viz": false
+ },
+ "insertNulls": false,
+ "lineInterpolation": "linear",
+ "lineWidth": 1,
+ "pointSize": 5,
+ "scaleDistribution": {
+ "type": "linear"
+ },
+ "showPoints": "auto",
+ "spanNulls": false,
+ "stacking": {
+ "group": "A",
+ "mode": "none"
+ },
+ "thresholdsStyle": {
+ "mode": "off"
+ }
+ },
+ "mappings": [],
+ "thresholds": {
+ "mode": "absolute",
+ "steps": [
+ {
+ "color": "green",
+ "value": null
+ },
+ {
+ "color": "yellow",
+ "value": 5000
+ },
+ {
+ "color": "red",
+ "value": 10000
+ }
+ ]
+ },
+ "unit": ""
+ },
+ "overrides": []
+ },
+ "gridPos": {
+ "h": 8,
+ "w": 12,
+ "x": 0,
+ "y": 16
+ },
+ "id": 19,
+ "options": {
+ "legend": {
+ "calcs": [],
+ "displayMode": "list",
+ "placement": "bottom",
+ "showLegend": true
+ },
+ "tooltip": {
+ "mode": "single",
+ "sort": "none"
+ }
+ },
+ "targets": [
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "editorMode": "code",
+ "expr": "rate(ucm:save_requests_num_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:save_requests_num_count{model_name=\"$model_name\"}[$__rate_interval])",
+ "hide": false,
+ "instant": false,
+ "legendFormat": "worker-{{worker_id}}",
+ "range": true,
+ "refId": "A"
+ }
+ ],
+ "title": "Connector Save Requests Num",
+ "type": "timeseries"
+ },
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "description": "Number of save blocks each wait_for_save.",
+ "fieldConfig": {
+ "defaults": {
+ "color": {
+ "mode": "palette-classic"
+ },
+ "custom": {
+ "axisBorderShow": false,
+ "axisCenteredZero": false,
+ "axisColorMode": "text",
+ "axisLabel": "",
+ "axisPlacement": "auto",
+ "barAlignment": 0,
+ "barWidthFactor": 0.6,
+ "drawStyle": "line",
+ "fillOpacity": 0,
+ "gradientMode": "none",
+ "hideFrom": {
+ "legend": false,
+ "tooltip": false,
+ "viz": false
+ },
+ "insertNulls": false,
+ "lineInterpolation": "linear",
+ "lineWidth": 1,
+ "pointSize": 5,
+ "scaleDistribution": {
+ "type": "linear"
+ },
+ "showPoints": "auto",
+ "spanNulls": false,
+ "stacking": {
+ "group": "A",
+ "mode": "none"
+ },
+ "thresholdsStyle": {
+ "mode": "off"
+ }
+ },
+ "mappings": [],
+ "thresholds": {
+ "mode": "absolute",
+ "steps": [
+ {
+ "color": "green",
+ "value": null
+ },
+ {
+ "color": "yellow",
+ "value": 5000
+ },
+ {
+ "color": "red",
+ "value": 10000
+ }
+ ]
+ },
+ "unit": ""
+ },
+ "overrides": []
+ },
+ "gridPos": {
+ "h": 8,
+ "w": 12,
+ "x": 12,
+ "y": 16
+ },
+ "id": 20,
+ "options": {
+ "legend": {
+ "calcs": [],
+ "displayMode": "list",
+ "placement": "bottom",
+ "showLegend": true
+ },
+ "tooltip": {
+ "mode": "single",
+ "sort": "none"
+ }
+ },
+ "targets": [
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "editorMode": "code",
+ "expr": "rate(ucm:save_blocks_num_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:save_blocks_num_count{model_name=\"$model_name\"}[$__rate_interval])",
+ "hide": false,
+ "instant": false,
+ "legendFormat": "worker-{{worker_id}}",
+ "range": true,
+ "refId": "A"
+ }
+ ],
+ "title": "Connector Save Blocks Num",
+ "type": "timeseries"
+ },
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "description": "P50, P90, P95, P99 and Average save duration in milliseconds for each save_kv.",
+ "fieldConfig": {
+ "defaults": {
+ "color": {
+ "mode": "palette-classic"
+ },
+ "custom": {
+ "axisBorderShow": false,
+ "axisCenteredZero": false,
+ "axisColorMode": "text",
+ "axisLabel": "",
+ "axisPlacement": "auto",
+ "barAlignment": 0,
+ "barWidthFactor": 0.6,
+ "drawStyle": "line",
+ "fillOpacity": 0,
+ "gradientMode": "none",
+ "hideFrom": {
+ "legend": false,
+ "tooltip": false,
+ "viz": false
+ },
+ "insertNulls": false,
+ "lineInterpolation": "linear",
+ "lineWidth": 1,
+ "pointSize": 5,
+ "scaleDistribution": {
+ "type": "linear"
+ },
+ "showPoints": "auto",
+ "spanNulls": false,
+ "stacking": {
+ "group": "A",
+ "mode": "none"
+ },
+ "thresholdsStyle": {
+ "mode": "off"
+ }
+ },
+ "mappings": [],
+ "thresholds": {
+ "mode": "absolute",
+ "steps": [
+ {
+ "color": "green",
+ "value": null
+ },
+ {
+ "color": "yellow",
+ "value": 8000
+ },
+ {
+ "color": "red",
+ "value": 15000
+ }
+ ]
+ },
+ "unit": "ms"
+ },
+ "overrides": []
+ },
+ "gridPos": {
+ "h": 8,
+ "w": 12,
+ "x": 0,
+ "y": 24
+ },
+ "id": 18,
+ "options": {
+ "legend": {
+ "calcs": [],
+ "displayMode": "list",
+ "placement": "bottom",
+ "showLegend": true
+ },
+ "tooltip": {
+ "mode": "single",
+ "sort": "none"
+ }
+ },
+ "targets": [
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "editorMode": "code",
+ "expr": "rate(ucm:save_duration_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:save_duration_count{model_name=\"$model_name\"}[$__rate_interval])",
+ "hide": false,
+ "instant": false,
+ "legendFormat": "worker-{{worker_id}}",
+ "range": true,
+ "refId": "A"
+ }
+ ],
+ "title": "Connector Save Duration",
+ "type": "timeseries"
+ },
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "description": "P50, P90, P95, P99 and Average save speed in GB/s for each save_kv.",
+ "fieldConfig": {
+ "defaults": {
+ "color": {
+ "mode": "palette-classic"
+ },
+ "custom": {
+ "axisBorderShow": false,
+ "axisCenteredZero": false,
+ "axisColorMode": "text",
+ "axisLabel": "",
+ "axisPlacement": "auto",
+ "barAlignment": 0,
+ "barWidthFactor": 0.6,
+ "drawStyle": "line",
+ "fillOpacity": 0,
+ "gradientMode": "none",
+ "hideFrom": {
+ "legend": false,
+ "tooltip": false,
+ "viz": false
+ },
+ "insertNulls": false,
+ "lineInterpolation": "linear",
+ "lineWidth": 1,
+ "pointSize": 5,
+ "scaleDistribution": {
+ "type": "linear"
+ },
+ "showPoints": "auto",
+ "spanNulls": false,
+ "stacking": {
+ "group": "A",
+ "mode": "none"
+ },
+ "thresholdsStyle": {
+ "mode": "off"
+ }
+ },
+ "mappings": [],
+ "thresholds": {
+ "mode": "absolute",
+ "steps": [
+ {
+ "color": "red",
+ "value": null
+ },
+ {
+ "color": "yellow",
+ "value": 0.004
+ },
+ {
+ "color": "green",
+ "value": 0.008
+ }
+ ]
+ },
+ "unit": "gb/s"
+ },
+ "overrides": []
+ },
+ "gridPos": {
+ "h": 8,
+ "w": 12,
+ "x": 12,
+ "y": 24
+ },
+ "id": 22,
+ "options": {
+ "legend": {
+ "calcs": [],
+ "displayMode": "list",
+ "placement": "bottom",
+ "showLegend": true
+ },
+ "tooltip": {
+ "mode": "single",
+ "sort": "none"
+ }
+ },
+ "targets": [
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "editorMode": "code",
+ "expr": "rate(ucm:save_speed_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:save_speed_count{model_name=\"$model_name\"}[$__rate_interval])",
+ "hide": false,
+ "instant": false,
+ "legendFormat": "worker-{{worker_id}}",
+ "range": true,
+ "refId": "A"
+ }
+ ],
+ "title": "Connector Save Speed",
+ "type": "timeseries"
+ }
+ ],
+ "refresh": "",
+ "schemaVersion": 39,
+ "tags": [],
+ "templating": {
+ "list": [
+ {
+ "current": {
+ "selected": false,
+ "text": "prometheus",
+ "value": "edx8memhpd9tsa"
+ },
+ "hide": 0,
+ "includeAll": false,
+ "label": "datasource",
+ "multi": false,
+ "name": "DS_PROMETHEUS",
+ "options": [],
+ "query": "prometheus",
+ "queryValue": "",
+ "refresh": 1,
+ "regex": "",
+ "skipUrlSync": false,
+ "type": "datasource"
+ },
+ {
+ "current": {
+ "selected": false,
+ "text": "/share/datasets/public_models/Meta-Llama-3-8B-Instruct",
+ "value": "/share/datasets/public_models/Meta-Llama-3-8B-Instruct"
+ },
+ "datasource": {
+ "type": "prometheus",
+ "uid": "${DS_PROMETHEUS}"
+ },
+ "definition": "label_values(model_name)",
+ "hide": 0,
+ "includeAll": false,
+ "label": "model_name",
+ "multi": false,
+ "name": "model_name",
+ "options": [],
+ "query": {
+ "query": "label_values(model_name)",
+ "refId": "StandardVariableQuery"
+ },
+ "refresh": 1,
+ "regex": "",
+ "skipUrlSync": false,
+ "sort": 0,
+ "type": "query"
+ }
+ ]
+ },
+ "time": {
+ "from": "now-5m",
+ "to": "now"
+ },
+ "timepicker": {},
+ "timezone": "",
+ "title": "vLLM - UnifiedCache Connector Monitoring",
+ "uid": "b281712d-8bff-41ef-9f3f-71ad43c05e9b",
+ "version": 9,
+ "weekStart": ""
+}
\ No newline at end of file
diff --git a/examples/metrics/metrics_configs.yaml b/examples/metrics/metrics_configs.yaml
new file mode 100644
index 000000000..5ed07baa9
--- /dev/null
+++ b/examples/metrics/metrics_configs.yaml
@@ -0,0 +1,56 @@
+# Prometheus Metrics Configuration
+# This file defines which metrics should be enabled and their configurations
+log_interval: 5 # Interval in seconds for logging metrics
+
+prometheus:
+ multiproc_dir: "/vllm-workspace" # Directory for Prometheus multiprocess mode
+
+ metric_prefix: "ucm:"
+
+ # Enable/disable metrics by category
+ enabled_metrics:
+ counters: true
+ gauges: true
+ histograms: true
+
+ # Counter metrics configuration
+ # counters:
+ # - name: "received_requests"
+ # documentation: "Total number of requests sent to ucm"
+
+ # Gauge metrics configuration
+ # gauges:
+ # - name: "lookup_hit_rate"
+ # documentation: "Hit rate of ucm lookup requests since last log"
+ # multiprocess_mode: "livemostrecent"
+
+ # Histogram metrics configuration
+ histograms:
+ - name: "load_requests_num"
+ documentation: "Number of requests loaded from ucm"
+ buckets: [1, 5, 10, 20, 50, 100, 200, 500, 1000]
+ - name: "load_blocks_num"
+ documentation: "Number of blocks loaded from ucm"
+ buckets: [0, 50, 100, 150, 200, 250, 300, 350, 400, 550, 600, 750, 800, 850, 900, 950, 1000]
+ - name: "load_duration"
+ documentation: "Time to load from ucm (ms)"
+ buckets: [0, 50, 100, 150, 200, 250, 300, 350, 400, 550, 600, 750, 800, 850, 900, 950, 1000]
+ - name: "load_speed"
+ documentation: "Speed of loading from ucm (GB/s)"
+ buckets: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 50, 60, 70, 80, 90, 100]
+ - name: "save_requests_num"
+ documentation: "Number of requests saved to ucm"
+ buckets: [1, 5, 10, 20, 50, 100, 200, 500, 1000]
+ - name: "save_blocks_num"
+ documentation: "Number of blocks saved to ucm"
+ buckets: [0, 50, 100, 150, 200, 250, 300, 350, 400, 550, 600, 750, 800, 850, 900, 950, 1000]
+ - name: "save_duration"
+ documentation: "Time to save to ucm (ms)"
+ buckets: [0, 50, 100, 150, 200, 250, 300, 350, 400, 550, 600, 750, 800, 850, 900, 950, 1000]
+ - name: "save_speed"
+ documentation: "Speed of saving to ucm (GB/s)"
+ buckets: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 50, 60, 70, 80, 90, 100]
+ - name: "interval_lookup_hit_rates"
+ documentation: "Hit rates of ucm lookup requests"
+ buckets: [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
+
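+# Note: each histogram above is exported as `<metric_prefix><name>_sum`,
+# `<metric_prefix><name>_count` and `<metric_prefix><name>_bucket` series
+# (e.g. `ucm:load_duration_sum`), which are the series queried by the Grafana
+# dashboard in examples/metrics/grafana.json.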
diff --git a/examples/offline_inference.py b/examples/offline_inference.py
index f50682464..5a2fea372 100644
--- a/examples/offline_inference.py
+++ b/examples/offline_inference.py
@@ -1,5 +1,4 @@
import contextlib
-import json
import os
import time
from dataclasses import asdict
@@ -16,11 +15,6 @@
logger = init_logger(__name__)
-def setup_environment_variables():
- os.environ["VLLM_USE_V1"] = "1"
- os.environ["PYTHONHASHSEED"] = "123456"
-
-
@contextlib.contextmanager
def build_llm_with_uc(module_path: str, name: str, model: str):
ktc = KVTransferConfig(
@@ -28,20 +22,7 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
kv_connector_module_path=module_path,
kv_role="kv_both",
kv_connector_extra_config={
- "ucm_connector_name": "UcmDramStore",
- "ucm_connector_config": {
- "max_cache_size": 5368709120,
- "kv_block_size": 262144,
- },
- "ucm_sparse_config": {
- "ESA": {
- "init_window_sz": 1,
- "local_window_sz": 2,
- "min_blocks": 4,
- "sparse_ratio": 0.3,
- "retrieval_stride": 5,
- }
- },
+ "UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"
},
)
@@ -53,6 +34,8 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
max_num_batched_tokens=30000,
block_size=128,
enforce_eager=True,
+ trust_remote_code=True,
+ enable_prefix_caching=False,
)
llm = LLM(**asdict(llm_args))
@@ -79,22 +62,41 @@ def print_output(
def main():
- module_path = "ucm.integration.vllm.uc_connector"
- name = "UnifiedCacheConnectorV1"
- model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct")
+ module_path = "ucm.integration.vllm.ucm_connector"
+ name = "UCMConnector"
+ model = os.getenv("MODEL_PATH", "/home/models/DeepSeek-V2-Lite")
tokenizer = AutoTokenizer.from_pretrained(model, use_chat_template=True)
- setup_environment_variables()
with build_llm_with_uc(module_path, name, model) as llm:
messages = [
{
"role": "system",
- "content": "You are a highly specialized assistant whose mission is to faithfully reproduce English literary texts verbatim, without any deviation, paraphrasing, or omission. Your primary responsibility is accuracy: every word, every punctuation mark, and every line must appear exactly as in the original source. Core Principles: Verbatim Reproduction: If the user asks for a passage, you must output the text word-for-word. Do not alter spelling, punctuation, capitalization, or line breaks. Do not paraphrase, summarize, modernize, or “improve” the language. Consistency: The same input must always yield the same output. Do not generate alternative versions or interpretations. Clarity of Scope: Your role is not to explain, interpret, or critique. You are not a storyteller or commentator, but a faithful copyist of English literary and cultural texts. Recognizability: Because texts must be reproduced exactly, they will carry their own cultural recognition. You should not add labels, introductions, or explanations before or after the text. Coverage: You must handle passages from classic literature, poetry, speeches, or cultural texts. Regardless of tone—solemn, visionary, poetic, persuasive—you must preserve the original form, structure, and rhythm by reproducing it precisely. Success Criteria: A human reader should be able to compare your output directly with the original and find zero differences. The measure of success is absolute textual fidelity. Your function can be summarized as follows: verbatim reproduction only, no paraphrase, no commentary, no embellishment, no omission.",
+ "content": "You are a highly specialized assistant whose mission is to faithfully reproduce English "
+ "literary texts verbatim, without any deviation, paraphrasing, or omission. Your primary "
+ "responsibility is accuracy: every word, every punctuation mark, and every line must "
+ "appear exactly as in the original source. Core Principles: Verbatim Reproduction: If the "
+ "user asks for a passage, you must output the text word-for-word. Do not alter spelling, "
+ "punctuation, capitalization, or line breaks. Do not paraphrase, summarize, modernize, "
+ "or “improve” the language. Consistency: The same input must always yield the same output. "
+ "Do not generate alternative versions or interpretations. Clarity of Scope: Your role is "
+ "not to explain, interpret, or critique. You are not a storyteller or commentator, "
+ "but a faithful copyist of English literary and cultural texts. Recognizability: Because "
+ "texts must be reproduced exactly, they will carry their own cultural recognition. You "
+ "should not add labels, introductions, or explanations before or after the text. Coverage: "
+ "You must handle passages from classic literature, poetry, speeches, or cultural texts. "
+ "Regardless of tone—solemn, visionary, poetic, persuasive—you must preserve the original "
+ "form, structure, and rhythm by reproducing it precisely. Success Criteria: A human reader "
+ "should be able to compare your output directly with the original and find zero "
+ "differences. The measure of success is absolute textual fidelity. Your function can be "
+ "summarized as follows: verbatim reproduction only, no paraphrase, no commentary, "
+ "no embellishment, no omission.",
},
{
"role": "user",
- "content": "Please reproduce verbatim the opening sentence of the United States Declaration of Independence (1776), starting with 'When in the Course of human events' and continuing word-for-word without paraphrasing.",
+ "content": "Please reproduce verbatim the opening sentence of the United States Declaration of "
+ "Independence (1776), starting with 'When in the Course of human events' and continuing "
+ "word-for-word without paraphrasing.",
},
]
diff --git a/examples/offline_inference_esa.py b/examples/offline_inference_esa.py
index caae0970d..852a8ca02 100644
--- a/examples/offline_inference_esa.py
+++ b/examples/offline_inference_esa.py
@@ -24,6 +24,7 @@
def setup_environment_variables():
os.environ["VLLM_USE_V1"] = "1"
os.environ["PYTHONHASHSEED"] = "123456"
+ os.environ["ENABLE_SPARSE"] = "true"
global model, path_to_dataset, data_dir, tokenizer
model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct")
@@ -45,10 +46,10 @@ def setup_environment_variables():
sys.exit(1)
data_dir = os.getenv("DATA_DIR", "/home/data/kv_cache")
- data_dir = input(
- "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: "
- )
if not os.path.isdir(data_dir):
+ data_dir = input(
+ "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: "
+ )
create = input(f"Directory {data_dir} dose not exist. Create it? (Y/n): ")
if create.lower() == "y":
os.makedirs(data_dir, exist_ok=True)
@@ -66,11 +67,15 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
kv_connector_module_path=module_path,
kv_role="kv_both",
kv_connector_extra_config={
- "ucm_connector_name": "UcmNfsStore",
- "ucm_connector_config": {
- "storage_backends": data_dir,
- "kv_block_size": 33554432,
- },
+ "ucm_connectors": [
+ {
+ "ucm_connector_name": "UcmNfsStore",
+ "ucm_connector_config": {
+ "storage_backends": data_dir,
+ "use_direct": False,
+ },
+ }
+ ],
"ucm_sparse_config": {
"ESA": {
"init_window_sz": 1,
@@ -87,12 +92,13 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
model=model,
kv_transfer_config=ktc,
max_model_len=32768,
- gpu_memory_utilization=0.6,
+ gpu_memory_utilization=0.8,
max_num_batched_tokens=30000,
block_size=128,
enforce_eager=True,
distributed_executor_backend="mp",
tensor_parallel_size=1,
+ trust_remote_code=True,
)
llm = LLM(**asdict(llm_args))
@@ -111,16 +117,20 @@ def print_output(
start = time.time()
outputs = llm.generate(prompt, sampling_params)
print("-" * 50)
+ lines = []
for output in outputs:
generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}")
+ lines.append(generated_text + "\n")
print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.")
+ with open("./newest_out.txt", "w") as f:
+ f.writelines(lines)
print("-" * 50)
def main():
- module_path = "ucm.integration.vllm.uc_connector"
- name = "UnifiedCacheConnectorV1"
+ module_path = "ucm.integration.vllm.ucm_connector"
+ name = "UCMConnector"
setup_environment_variables()
def get_prompt(prompt):
@@ -140,24 +150,23 @@ def get_prompt(prompt):
with build_llm_with_uc(module_path, name, model) as llm:
prompts = []
- batch_size = 5
+ batch_size = 20
assert os.path.isfile(
path_to_dataset
), f"Incorrect dataset path. Please specify the dataset path by `export DATASET_PATH=/path/to/longbench/multifieldqa_zh.jsonl`"
with open(path_to_dataset, "r") as f:
- for _ in range(batch_size):
- line = f.readline()
- if not line:
- break
- data = json.loads(line)
- context = data["context"]
- question = data["input"]
- prompts.append(get_prompt(f"{context}\n\n{question}"))
-
- sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=100)
+ lines = f.readlines()
+ for i in range(batch_size):
+ line = lines[i]
+ data = json.loads(line)
+ prompt = f"""阅读以下文字并用中文简短回答:\n\n{data["context"]}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{data["input"]}\n回答:"""
+ prompts.append(get_prompt(prompt))
+
+ sampling_params = SamplingParams(
+ temperature=0, top_p=0.95, max_tokens=256, ignore_eos=False
+ )
print_output(llm, prompts, sampling_params, "first")
- print_output(llm, prompts, sampling_params, "second")
if __name__ == "__main__":
diff --git a/examples/offline_inference_kvcomp.py b/examples/offline_inference_kvcomp.py
new file mode 100644
index 000000000..f1cd6c196
--- /dev/null
+++ b/examples/offline_inference_kvcomp.py
@@ -0,0 +1,172 @@
+import contextlib
+import json
+import os
+import sys
+import time
+from dataclasses import asdict
+
+from transformers import AutoTokenizer
+
+# Third Party
+from vllm import LLM, SamplingParams
+from vllm.config import KVTransferConfig
+from vllm.engine.arg_utils import EngineArgs
+
+from ucm.logger import init_logger
+
+logger = init_logger(__name__)
+model = ""
+path_to_dataset = ""
+data_dir = ""
+tokenizer = None
+
+
+def setup_environment_variables():
+ os.environ["VLLM_USE_V1"] = "1"
+ os.environ["PYTHONHASHSEED"] = "123456"
+ os.environ["ENABLE_SPARSE"] = "true"
+
+ global model, path_to_dataset, data_dir, tokenizer
+ model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct")
+ if not os.path.isdir(model):
+ model = input("Enter path to model, e.g. /home/models/Qwen2.5-14B-Instruct: ")
+ if not os.path.isdir(model):
+ print("Exiting. Incorrect model_path")
+ sys.exit(1)
+
+ path_to_dataset = os.getenv(
+ "DATASET_PATH", "/home/data/Longbench/data/multifieldqa_zh.jsonl"
+ )
+ if not os.path.isfile(path_to_dataset):
+ path_to_dataset = input(
+ "Enter path to one of the longbench dataset, e.g. /home/data/Longbench/data/multifieldqa_zh.jsonl: "
+ )
+ if not os.path.isfile(path_to_dataset):
+ print("Exiting. Incorrect dataset path")
+ sys.exit(1)
+
+ data_dir = os.getenv("DATA_DIR", "/home/data/kv_cache")
+    if not os.path.isdir(data_dir):
+        data_dir = input(
+            "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: "
+        )
+        create = input(f"Directory {data_dir} does not exist. Create it? (Y/n): ")
+ if create.lower() == "y":
+ os.makedirs(data_dir, exist_ok=True)
+ else:
+ print("Exiting. Directory not created.")
+ sys.exit(1)
+
+ tokenizer = AutoTokenizer.from_pretrained(model, use_chat_template=True)
+
+
+@contextlib.contextmanager
+def build_llm_with_uc(module_path: str, name: str, model: str):
+ ktc = KVTransferConfig(
+ kv_connector=name,
+ kv_connector_module_path=module_path,
+ kv_role="kv_both",
+ kv_connector_extra_config={
+ "ucm_connectors": [
+ {
+ "ucm_connector_name": "UcmNfsStore",
+ "ucm_connector_config": {
+ "storage_backends": data_dir,
+ "use_direct": False,
+ },
+ }
+ ],
+ "ucm_sparse_config": {
+ "KvComp": {
+ "init_window_sz": 1,
+ "local_window_sz": 2,
+ "min_blocks": 4,
+ "sparse_ratio": 0.3,
+ "retrieval_stride": 5,
+ }
+ },
+ # "kvcomp_config_path": "unified-cache-management/ucm/sparse/kvcomp/configs/kvcomp_deepseek_v2_lite_config.json",
+ "kvcomp_config_path": "unified-cache-management/ucm/sparse/kvcomp/configs/kvcomp_qwq_32B_config.json",
+ },
+ )
+
+ llm_args = EngineArgs(
+ model=model,
+ kv_transfer_config=ktc,
+ max_model_len=32768,
+ gpu_memory_utilization=0.8,
+ max_num_batched_tokens=30000,
+ block_size=128,
+ enforce_eager=True,
+ distributed_executor_backend="mp",
+ tensor_parallel_size=2,
+ trust_remote_code=True,
+ )
+
+ llm = LLM(**asdict(llm_args))
+ try:
+ yield llm
+ finally:
+ logger.info("LLM engine is exiting.")
+
+
+def print_output(
+ llm: LLM,
+ prompt: list[str],
+ sampling_params: SamplingParams,
+ req_str: str,
+):
+ start = time.time()
+ outputs = llm.generate(prompt, sampling_params)
+ print("-" * 50)
+ for output in outputs:
+ generated_text = output.outputs[0].text
+ print(f"Generated text: {generated_text!r}")
+ print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.")
+ print("-" * 50)
+
+
+def main():
+ module_path = "ucm.integration.vllm.ucm_connector"
+ name = "UCMConnector"
+ setup_environment_variables()
+
+ def get_prompt(prompt):
+ messages = [
+ {
+ "role": "system",
+ "content": "先读问题,再根据下面的文章内容回答问题,不要进行分析,不要重复问题,用简短的语句给出答案。\n\n例如:“全国美国文学研究会的第十八届年会在哪所大学举办的?”\n回答应该为:“xx大学”。\n\n",
+ },
+ {"role": "user", "content": prompt},
+ ]
+ return tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ add_special_tokens=True,
+ )
+
+ with build_llm_with_uc(module_path, name, model) as llm:
+ prompts = []
+ batch_size = 10
+ assert os.path.isfile(
+ path_to_dataset
+ ), f"Incorrect dataset path. Please specify the dataset path by `export DATASET_PATH=/path/to/longbench/multifieldqa_zh.jsonl`"
+ with open(path_to_dataset, "r") as f:
+ lines = f.readlines()
+ for i in range(batch_size):
+ line = lines[i]
+ data = json.loads(line)
+ prompt = f"""阅读以下文字并用中文简短回答:\n\n{data["context"]}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{data["input"]}\n回答:"""
+ prompts.append(get_prompt(prompt))
+
+ sampling_params = SamplingParams(
+ temperature=0, top_p=0.95, max_tokens=256, ignore_eos=False
+ )
+
+ print_output(llm, prompts, sampling_params, "first")
+ print_output(llm, prompts, sampling_params, "second")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/offline_inference_kvstar.py b/examples/offline_inference_kvstar.py
index e26113993..c8dfd1eec 100644
--- a/examples/offline_inference_kvstar.py
+++ b/examples/offline_inference_kvstar.py
@@ -24,6 +24,7 @@
def setup_environment_variables():
os.environ["VLLM_USE_V1"] = "1"
os.environ["PYTHONHASHSEED"] = "123456"
+ os.environ["ENABLE_SPARSE"] = "true"
global model, path_to_dataset, data_dir, tokenizer
model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct")
@@ -67,11 +68,15 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
kv_connector_module_path=module_path,
kv_role="kv_both",
kv_connector_extra_config={
- "ucm_connector_name": "UcmNfsStore",
- "ucm_connector_config": {
- "storage_backends": data_dir,
- "kv_block_size": 33554432,
- },
+ "ucm_connectors": [
+ {
+ "ucm_connector_name": "UcmNfsStore",
+ "ucm_connector_config": {
+ "storage_backends": data_dir,
+ "use_direct": False,
+ },
+ }
+ ],
"ucm_sparse_config": {
"KVStarMultiStep": {
"init_window_sz": 1,
@@ -122,8 +127,8 @@ def print_output(
def main():
- module_path = "ucm.integration.vllm.uc_connector"
- name = "UnifiedCacheConnectorV1"
+ module_path = "ucm.integration.vllm.ucm_connector"
+ name = "UCMConnector"
setup_environment_variables()
def get_prompt(prompt):
diff --git a/examples/ucm_config_example.yaml b/examples/ucm_config_example.yaml
new file mode 100644
index 000000000..32c357e1b
--- /dev/null
+++ b/examples/ucm_config_example.yaml
@@ -0,0 +1,38 @@
+# UCM Configuration File Example
+#
+# This file demonstrates how to configure UCM using YAML.
+# To use it, set the path to this file in kv_connector_extra_config in your launch script or on the command line, like this:
+# kv_connector_extra_config={"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"}
+#
+# Alternatively, you can still use kv_connector_extra_config in KVTransferConfig
+# for backward compatibility.
+
+# Storage connectors (e.g., "UcmNfsStore", "UcmDramStore")
+ucm_connectors:
+ - ucm_connector_name: "UcmNfsStore"
+ ucm_connector_config:
+ storage_backends: "/mnt/test"
+ use_direct: false
+
+load_only_first_rank: false
+
+# Enable UCM metrics so they can be monitored online via Grafana and Prometheus.
+# metrics_config_path: "/workspace/unified-cache-management/examples/metrics/metrics_configs.yaml"
+
+# Sparse attention configuration
+# Format 1: Dictionary format (for methods like ESA, KvComp)
+# ucm_sparse_config:
+# ESA:
+# init_window_sz: 1
+# local_window_sz: 2
+# min_blocks: 4
+# sparse_ratio: 0.3
+# retrieval_stride: 5
+ # Or for GSA:
+ # GSA: {}
+
+
+# Whether to use layerwise loading/saving (optional, default: True for UnifiedCacheConnectorV1)
+# use_layerwise: true
+# hit_ratio: 0.9
+
diff --git a/setup.py b/setup.py
index a86d6d513..8b4c8f660 100644
--- a/setup.py
+++ b/setup.py
@@ -26,6 +26,7 @@
import subprocess
import sys
import sysconfig
+from glob import glob
import pybind11
import torch
@@ -45,6 +46,10 @@ def _is_npu() -> bool:
return PLATFORM == "ascend"
+def _is_musa() -> bool:
+ return PLATFORM == "musa"
+
+
class CMakeExtension(Extension):
def __init__(self, name: str, sourcedir: str = ""):
super().__init__(name, sources=[])
@@ -84,10 +89,12 @@ def build_cmake(self, ext: CMakeExtension):
cmake_args.append("-DRUNTIME_ENVIRONMENT=cuda")
elif _is_npu():
cmake_args.append("-DRUNTIME_ENVIRONMENT=ascend")
+ elif _is_musa():
+ cmake_args.append("-DRUNTIME_ENVIRONMENT=musa")
else:
raise RuntimeError(
"No supported accelerator found. "
- "Please ensure either CUDA or NPU is available."
+ "Please ensure either CUDA/MUSA or NPU is available."
)
cmake_args.append(ext.sourcedir)
@@ -104,17 +111,37 @@ def build_cmake(self, ext: CMakeExtension):
)
+def _get_package_data_with_so():
+ """Automatically discover all packages and include .so files."""
+ packages = find_packages()
+ package_data = {}
+
+ for package in packages:
+ # Convert package name to directory path
+ package_dir = os.path.join(ROOT_DIR, package.replace(".", os.sep))
+
+ # Check if this package directory contains .so files
+ so_files = glob(os.path.join(package_dir, "*.so"))
+ if so_files:
+ package_data[package] = ["*.so"]
+ print(f"[INFO] Including .so files for package: {package}")
+
+ print(f"[INFO] Package data: {package_data}")
+ return package_data
+
+
ext_modules = []
ext_modules.append(CMakeExtension(name="ucm", sourcedir=ROOT_DIR))
setup(
- name="ucm",
- version="0.0.2",
+ name="uc-manager",
+ version="0.1.0",
description="Unified Cache Management",
author="Unified Cache Team",
packages=find_packages(),
python_requires=">=3.10",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuild},
+ package_data=_get_package_data_with_so(),
zip_safe=False,
)
diff --git a/test/.gitignore b/test/.gitignore
new file mode 100644
index 000000000..220d21ac1
--- /dev/null
+++ b/test/.gitignore
@@ -0,0 +1,13 @@
+reports/
+dataset/
+logs/
+result_outputs/
+results/
+.cache/
+backup/
+$null
+*__pycache__/
+.*
+*.log
+start.bat
+!.gitignore
\ No newline at end of file
diff --git a/test/README.md b/test/README.md
new file mode 100644
index 000000000..1e11da7e7
--- /dev/null
+++ b/test/README.md
@@ -0,0 +1,179 @@
+# Pytest
+[简体中文](README_zh.md)
+
+A comprehensive Pytest testing framework featuring configuration management, database integration, performance testing, and HTML report generation.
+
+## 📋 Features
+
+- **Modern Testing Framework**: Complete test solution built on Pytest 7.0+
+- **Configuration Management**: YAML-based config with thread-safe singleton pattern
+- **Database Integration**: Built-in MySQL support with automatic result storage
+- **HTML Reports**: Auto-generated pytest HTML test reports
+- **Tagging System**: Multi-dimensional test tags (stage, feature, platform, etc.)
+
+## 🗂️ Project Structure
+
+```
+pytest_demo/
+├── common/ # Common modules
+│ ├── __init__.py
+│ ├── config_utils.py # Configuration utilities
+│ ├── db_utils.py # Database utilities
+│   └── capture_utils.py     # Return-value capture utilities
+├── results/ # Result storage folder
+├── suites/ # Test suites
+│ ├── UnitTest # Unit tests
+│ ├── Feature # Feature tests
+│ └── E2E/ # End-to-end tests
+│ └── test_demo_performance.py # Sample test file
+├── config.yaml # Main config file
+├── conftest.py # Pytest config
+├── pytest.ini # Pytest settings
+├── requirements.txt # Dependencies
+└── README.md               # This document
+```
+
+## 🚀 Quick Start
+
+### Prerequisites
+
+- Python 3.8+
+- MySQL 5.7+ (optional, for DB features)
+- Git
+
+### Installation
+
+1. **Install dependencies**
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+2. **Configure database** (optional)
+
+ Edit `config.yaml`:
+ ```yaml
+ database:
+ backup: "results/"
+ host: "127.0.0.1"
+ port: 3306
+ name: "ucm_pytest"
+ user: "root"
+ password: "123456"
+ charset: "utf8mb4"
+ ```
+
+3. **Run tests**
+ ```bash
+ # Run all tests
+ pytest
+
+ # Run tests by tag
+ pytest --stage=1
+ pytest --feature=performance
+ ```
+
+## ⚙️ Configuration
+
+### config.yaml
+
+Full YAML-based config. Key sections:
+
+- **reports**: Report settings (HTML, timestamp, etc.)
+- **database**: MySQL connection details
+
+## 🧪 Test Examples
+
+### Basic functional test
+
+```python
+# suites/E2E/test_demo_performance.py
+import pytest
+
+@pytest.fixture(scope="module", name="calc")
+def calculator():
+ return Calculator()
+
+@pytest.mark.feature("mark")
+class TestCalculator:
+ def test_add(self, calc):
+ assert calc.add(1, 2) == 3
+
+ def test_divide_by_zero(self, calc):
+ with pytest.raises(ZeroDivisionError):
+ calc.divide(6, 0)
+```
+
+## 🏷️ Tagging System
+
+Multi-dimensional tags supported:
+
+### Stage tags
+- `stage(0)`: Unit tests
+- `stage(1)`: Smoke tests
+- `stage(2)`: Regression tests
+- `stage(3)`: Release tests
+
+### Functional tags
+- `feature`: Module tag
+- `platform`: Platform tag (GPU/NPU)
+
+### Usage
+
+```bash
+# Run smoke tests and above
+pytest --stage=1+
+
+# Run by feature
+pytest --feature=performance
+pytest --feature=performance,reliability
+
+# Run by platform
+pytest --platform=gpu
+```
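+
+In test code, these tags are plain Pytest marks. A minimal sketch (assuming `stage`, `feature` and `platform` are the registered mark names used by the options above) looks like:
+
+```python
+import pytest
+
+
+@pytest.mark.stage(2)                # selected by --stage=2 (regression)
+@pytest.mark.feature("performance")  # selected by --feature=performance
+@pytest.mark.platform("gpu")         # selected by --platform=gpu
+def test_example_performance():
+    assert 1 + 1 == 2
+```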
+
+### HTML Reports
+
+Auto-generated timestamped HTML reports:
+- Location: `reports/pytest_YYYYMMDD_HHMMSS/report.html`
+- Detailed results, errors, timing
+- Customizable title & style
+
+### Database Storage
+
+If enabled, results are auto-saved to MySQL.
+To add new record types, ask DB admin to create tables; otherwise only local files are used.
+
+Example:
+```python
+@pytest.mark.feature("capture") # Must be top decorator
+@export_vars
+def test_capture_mix():
+ assert 1 == 1
+ return {
+ '_name': 'demo',
+ '_data': {
+ 'length': 10086, # single value
+ 'accuracy': [0.1, 0.2, 0.3], # list
+ 'loss': [0.1, 0.2, 0.3], # list
+ }
+ }
+```
+
+### Config Access
+
+Read settings easily:
+```python
+from common.config_utils import config_utils
+# Get config
+db_config = config_utils.get_config("database")
+api_config = config_utils.get_nested_config("easyPerf.api")
+```
+
+## 🛠️ Development Guide
+
+### Adding New Tests
+
+1. Create test files under `suites/` categories
+2. Apply appropriate tags
+3. Naming: `test_*.py`
+4. Use fixtures & marks for data management
+5. Keep custom marks concise and aligned with overall goals
\ No newline at end of file
diff --git a/test/README_zh.md b/test/README_zh.md
new file mode 100644
index 000000000..26b0f393a
--- /dev/null
+++ b/test/README_zh.md
@@ -0,0 +1,182 @@
+# Pytest 项目
+ Pytest 测试框架,包括配置管理、数据库集成、性能测试和 HTML 报告生成。
+
+## 📋 项目特性
+
+- **现代化测试框架**: 基于 Pytest 7.0+ 的完整测试解决方案
+- **配置管理**: 支持 YAML 配置文件,线程安全的单例模式配置管理
+- **数据库集成**: 内置 MySQL 数据库支持,自动结果存储
+- **HTML 报告**: 自动生成pytest HTML 测试报告
+- **标记系统**: 支持多维度测试标记(阶段、功能、平台等)
+
+## 🗂️ 项目结构
+
+```
+pytest_demo/
+├── common/ # 公共模块
+│ ├── __init__.py
+│ ├── config_utils.py # 配置管理工具
+│ ├── db_utils.py # 数据库工具
+│   └── capture_utils.py     # 返回值捕获工具
+├── results/ # 结果存储目录
+├── suites/ # 测试套件
+│ ├── UnitTest # 单元测试
+│ ├── Feature # 功能测试
+│ └── E2E/ # 端到端测试
+│ └── test_demo_performance.py # 示例测试文件
+├── config.yaml # 主配置文件
+├── conftest.py # Pytest 配置文件
+├── pytest.ini # Pytest 配置
+├── requirements.txt # 项目依赖
+└── README.md # 本文档
+```
+
+## 🚀 快速开始
+
+### 环境要求
+
+- Python 3.8+
+- MySQL 5.7+ (可选,用于数据库功能)
+- Git
+
+### 安装步骤
+
+1. **安装依赖**
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+2. **配置数据库**(可选)
+
+ 编辑 `config.yaml` 文件中的数据库配置:
+ ```yaml
+ database:
+ backup: "results/"
+ host: "127.0.0.1"
+ port: 3306
+ name: "ucm_pytest"
+ user: "root"
+ password: "123456"
+ charset: "utf8mb4"
+ ```
+
+3. **运行测试**
+ ```bash
+ # 运行所有测试
+ pytest
+
+ # 运行特定标记的测试
+ pytest --stage=1
+ pytest --feature=performance
+ ```
+
+## ⚙️ 配置说明
+
+
+### config.yaml 配置
+
+项目支持完整的 YAML 配置管理,主要配置项包括:
+
+- **reports**: 报告配置(HTML 报告、时间戳等)
+- **database**: 数据库连接配置
+
+## 🧪 测试示例
+
+### 基础功能测试
+
+```python
+# suites/E2E/test_demo_performance.py
+import pytest
+
+@pytest.fixture(scope="module", name="calc")
+def calculator():
+ return Calculator()
+
+@pytest.mark.feature("mark")
+class TestCalculator:
+ def test_add(self, calc):
+ assert calc.add(1, 2) == 3
+
+ def test_divide_by_zero(self, calc):
+ with pytest.raises(ZeroDivisionError):
+ calc.divide(6, 0)
+```
+
+## 🏷️ 测试标记系统
+
+项目支持多维度的测试标记:
+
+### 测试阶段标记
+- `stage(0)`: 单元测试
+- `stage(1)`: 冒烟测试
+- `stage(2)`: 回归测试
+- `stage(3)`: 发布测试
+
+### 功能标记
+- `feature`: 功能模块标记
+- `platform`: 平台标记(GPU/NPU)
+
+### 使用示例
+
+```bash
+# 运行冒烟测试及以上的所有测试
+pytest --stage=1+
+
+# 运行特定功能的测试
+pytest --feature=performance
+pytest --feature=performance,reliability
+
+# 运行特定平台的测试
+pytest --platform=gpu
+```
+
+
+### HTML 报告
+
+项目自动生成带时间戳的 HTML 测试报告:
+- 报告位置:`reports/pytest_YYYYMMDD_HHMMSS/report.html`
+- 包含详细的测试结果、错误信息和执行时间
+- 支持自定义报告标题和样式
+
+### 数据库存储
+
+如果启用数据库功能,测试结果会自动存储到 MySQL 数据库。
+若需要新增记录,请联系管理人员在数据库新增对应表;否则只能保存至本地文件。
+使用方式示例:
+```python
+@pytest.mark.feature("capture") # pytest 的标签必须在上面,否则无法正常使用标记功能
+@export_vars
+def test_capture_mix():
+ assert 1 == 1
+ return {
+ '_name': 'demo',
+ '_data': {
+ 'length': 10086, # single value
+ 'accuracy': [0.1, 0.2, 0.3], # list
+ 'loss': [0.1, 0.2, 0.3], # list
+ }
+ }
+
+```
+
+
+### 配置管理
+
+可以通过配置工具便捷读取参数:
+```python
+from common.config_utils import config_utils
+# 获取配置
+db_config = config_utils.get_config("database")
+api_config = config_utils.get_nested_config("easyPerf.api")
+```
+
+
+
+## 🛠️ 开发指南
+
+### 添加新测试
+
+1. 在 `suites/` 目录下的各个分类下创建新的测试文件
+2. 使用适当的测试标记
+3. 遵循命名规范:`test_*.py`
+4. 使用 fixture 及 mark 进行测试数据管理
+5. 自定义 mark 标签不宜过细,应当与整体功能目标相符合
\ No newline at end of file
diff --git a/test/common/__init__.py b/test/common/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/common/capture_utils.py b/test/common/capture_utils.py
new file mode 100644
index 000000000..b12b76637
--- /dev/null
+++ b/test/common/capture_utils.py
@@ -0,0 +1,97 @@
+import functools
+from typing import Any, Dict, List
+
+from common.db_utils import write_to_db
+
+
+def _align_and_split(name: str, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+ """
+    Align a mixed data package (single values and/or lists) and split it into
+    a list of per-row dictionaries; shorter lists are padded by repeating
+    their last value.
+    """
+ if not data:
+ return []
+
+ aligned: Dict[str, List[Any]] = {}
+ lengths: Dict[str, int] = {}
+ for k, v in data.items():
+ if isinstance(v, (list, tuple)):
+ aligned[k] = list(v)
+ else:
+ aligned[k] = [v]
+ lengths[k] = len(aligned[k])
+
+ max_len = max(lengths.values())
+
+ for k, lst in aligned.items():
+ if len(lst) < max_len:
+ lst.extend([lst[-1]] * (max_len - len(lst)))
+
+ return [{k: aligned[k][i] for k in aligned} for i in range(max_len)]
+
+
+def post_process(table_name: str, **kwargs) -> List[Dict[str, Any]]:
+ """
+    Unified post-processing entry point. Expects an '_data' payload (with an
+    optional '_name' overriding the target table name); each aligned row is
+    written to the database and the list of rows is returned.
+    """
+ results = []
+ if "_data" in kwargs:
+ name = kwargs.get("_name", table_name)
+ results = _align_and_split(name, kwargs["_data"])
+ for result in results:
+ write_to_db(name, result)
+ return results
+ return []
+
+
+# ---------------- decorator ----------------
+def export_vars(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ result = func(*args, **kwargs)
+ # If the function returns a dict containing '_data' or 'data', post-process it
+ if isinstance(result, dict):
+ if "_data" in result or "data" in result:
+ return post_process(func.__name__, **result)
+ # Otherwise return unchanged
+ return result
+
+ return wrapper
+
+
+# ---------------- usage examples ----------------
+@export_vars
+def capture():
+ """All single values via 'name' + 'data'"""
+ return {"name": "demo", "_data": {"accuracy": 0.1, "loss": 0.3}}
+
+
+@export_vars
+def capture_list():
+ """All lists via '_name' + '_data'"""
+ return {
+ "_name": "demo",
+ "_data": {
+ "accuracy": [0.1, 0.2, 0.3],
+ "loss": [0.1, 0.2, 0.3],
+ },
+ }
+
+
+@export_vars
+def capture_mix():
+ """Mixed single + lists via '_name' + '_data'"""
+ return {
+ "_name": "demo",
+ "_data": {
+ "length": 10086, # single value
+ "accuracy": [0.1, 0.2, 0.3], # list
+ "loss": [0.1, 0.2, 0.3], # list
+ },
+ }
+
+
+# quick test
+if __name__ == "__main__":
+ print("capture(): ", capture())
+ print("capture_list(): ", capture_list())
+ print("capture_mix(): ", capture_mix())
diff --git a/test/common/config_utils.py b/test/common/config_utils.py
new file mode 100644
index 000000000..106f783ee
--- /dev/null
+++ b/test/common/config_utils.py
@@ -0,0 +1,86 @@
+import os
+import threading
+from typing import Any, Dict
+
+import yaml
+
+
+class ConfigUtils:
+ """
+ Singleton Configuration Utility
+ Provides methods to read and access YAML configuration files.
+ """
+
+ _instance = None
+ _lock = threading.Lock() # Ensure thread-safe singleton creation
+
+    def __init__(self, config_file: str = None):
+        # All initialisation happens in __new__/_init_config; keeping __init__
+        # a no-op avoids resetting the cached config (and a TypeError when a
+        # config_file argument is passed) on repeated ConfigUtils() calls.
+        pass
+
+ def __new__(cls, config_file: str = None):
+ # Double-checked locking
+ if cls._instance is None:
+ with cls._lock:
+ if cls._instance is None:
+ instance = super().__new__(cls)
+ instance._init_config(config_file)
+ cls._instance = instance
+ return cls._instance
+
+ def _init_config(self, config_file: str = None):
+ """Initialize configuration file path and load config"""
+ if config_file is None:
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ config_file = os.path.join(current_dir, "..", "config.yaml")
+
+ self.config_file = os.path.abspath(config_file)
+ self._config = None # Lazy load
+
+ def _load_config(self) -> Dict[str, Any]:
+ """Internal method to read configuration from file"""
+ try:
+ with open(self.config_file, "r", encoding="utf-8") as f:
+ return yaml.safe_load(f) or {}
+ except FileNotFoundError:
+ print(f"[WARN] Config file not found: {self.config_file}")
+ return {}
+ except yaml.YAMLError as e:
+ print(f"[ERROR] Failed to parse YAML config: {e}")
+ return {}
+
+ def read_config(self) -> Dict[str, Any]:
+ """Read configuration file (lazy load)"""
+ if self._config is None:
+ self._config = self._load_config()
+ return self._config
+
+ def reload_config(self):
+ """Force reload configuration file"""
+ self._config = self._load_config()
+
+ def get_config(self, key: str, default: Any = None) -> Any:
+ """Get top-level configuration item"""
+ config = self.read_config()
+ return config.get(key, default)
+
+ def get_nested_config(self, key_path: str, default: Any = None) -> Any:
+ """Get nested configuration, e.g., 'influxdb.host'"""
+ config = self.read_config()
+ keys = key_path.split(".")
+ value = config
+ try:
+ for k in keys:
+ value = value[k]
+ return value
+ except (KeyError, TypeError):
+ return default
+
+
+# Global instance
+config_utils = ConfigUtils()
+
+if __name__ == "__main__":
+ print("DataBase config:", config_utils.get_config("database"))
+ print(
+ "DataBase host:", config_utils.get_nested_config("database.host", "localhost")
+ )
diff --git a/test/common/db_utils.py b/test/common/db_utils.py
new file mode 100644
index 000000000..089af43b2
--- /dev/null
+++ b/test/common/db_utils.py
@@ -0,0 +1,183 @@
+import json
+import logging
+import threading
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import peewee
+from common.config_utils import config_utils as config_instance
+from peewee import AutoField, Model, MySQLDatabase, TextField
+
+logger = logging.getLogger("db_handler")
+logger.setLevel(logging.DEBUG)
+
+# Attach a default handler only once so repeated imports do not duplicate output
+if not logger.handlers:
+    logger.addHandler(logging.StreamHandler())
+
+# Global DB instance and lock for thread-safe singleton
+_db_instance: Optional[MySQLDatabase] = None
+_db_lock = threading.Lock()
+_test_build_id: Optional[str] = None
+_backup_path: Optional[Path] = None
+_db_enabled: bool = False # from config
+
+
+def _get_db() -> Optional[MySQLDatabase]:
+ """Return a singleton MySQLDatabase instance based on YAML configuration."""
+ global _db_instance, _backup_path, _db_enabled
+
+ if _db_instance is None:
+ with _db_lock:
+ if _db_instance is None:
+ db_config = config_instance.get_config("database", {})
+ _db_enabled = db_config.get("enabled", False)
+
+ backup_str = db_config.get("backup", "results/")
+ _backup_path = Path(backup_str).resolve()
+ _backup_path.mkdir(parents=True, exist_ok=True)
+ logger.info(f"Backup directory set to: {_backup_path}")
+
+ if not _db_enabled:
+ return None
+
+ try:
+ _db_instance = MySQLDatabase(
+ db_config.get("name", "test_db"),
+ user=db_config.get("user", "root"),
+ password=db_config.get("password", ""),
+ host=db_config.get("host", "localhost"),
+ port=db_config.get("port", 3306),
+ charset=db_config.get("charset", "utf8mb4"),
+ )
+ logger.info(
+ f"Database instance created for: {_db_instance.database}"
+ )
+ except Exception as e:
+ logger.error(f"Failed to create database instance: {e}")
+ _db_instance = None
+
+ return _db_instance
+
+
+def _set_test_build_id(build_id: Optional[str] = None) -> None:
+ """Set or generate a unique test build ID."""
+ global _test_build_id
+ _test_build_id = build_id or "default_build_id"
+ logger.debug(f"Test build ID set to: {_test_build_id}")
+
+
+def _get_test_build_id() -> str:
+ """Return the current test build ID, generating one if necessary."""
+ global _test_build_id
+ if _test_build_id is None:
+ _set_test_build_id()
+ return _test_build_id
+
+
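+# NOTE: _get_db() runs once when this class body is executed at import time; the
+# dynamic models created in write_to_db() re-bind the database explicitly.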
+class BaseEntity(Model):
+ """Base PeeWee model class using the singleton database."""
+
+ class Meta:
+ database = _get_db()
+
+
+def _backup_to_file(table_name: str, data: Dict[str, Any]) -> None:
+ """Write data to a JSON Lines (.jsonl) file in the backup directory."""
+ if not _backup_path:
+ logger.warning("Backup path is not set. Skipping backup.")
+ return
+
+ file_path = _backup_path / f"{table_name}.jsonl"
+ try:
+ file_path.parent.mkdir(parents=True, exist_ok=True)
+ with file_path.open("a", encoding="utf-8") as f:
+ json.dump(data, f, ensure_ascii=False)
+ f.write("\n")
+ logger.info(f"Data backed up to {file_path}")
+ except Exception as e:
+ logger.error(f"Failed to write backup file {file_path}: {e}")
+
+
+def write_to_db(table_name: str, data: Dict[str, Any]) -> bool:
+ """
+ Attempt to insert data into the specified database table.
+ If the table doesn't exist or an error occurs, back up to a JSONL file.
+ """
+ db = _get_db()
+ data["test_build_id"] = _get_test_build_id()
+
+ # Skip DB entirely if disabled
+ if not _db_enabled or db is None:
+ _backup_to_file(table_name, data)
+ return False
+
+ try:
+ if not db.table_exists(table_name):
+ logger.warning(f"Table '{table_name}' does not exist. Writing to backup.")
+ _backup_to_file(table_name, data)
+ return False
+
+ # Get existing columns and filter data
+ columns = db.get_columns(table_name)
+ col_names = {col.name for col in columns}
+ filtered_data = {k: v for k, v in data.items() if k in col_names}
+
+ # Build dynamic model for insertion
+ fields = {"id": AutoField()}
+ for col in columns:
+ if col.name != "id":
+ fields[col.name] = TextField(null=True)
+
+ DynamicEntity = type(
+ f"{table_name.capitalize()}DynamicModel",
+ (BaseEntity,),
+ {
+ "Meta": type("Meta", (), {"database": db, "table_name": table_name}),
+ **fields,
+ },
+ )
+
+ with db.atomic():
+ DynamicEntity.insert(filtered_data).execute()
+ logger.info(f"Successfully inserted data into table '{table_name}'.")
+ return True
+
+ except peewee.PeeweeException as e:
+ logger.error(
+ f"Database write error for table '{table_name}': {e}", exc_info=True
+ )
+ except Exception as e:
+ logger.critical(
+ f"Unexpected error during DB write for '{table_name}': {e}", exc_info=True
+ )
+
+ # Fallback to backup on any failure
+ _backup_to_file(table_name, data)
+ return False
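+
+# Example (mirrors the payload written from test/conftest.py):
+#   write_to_db("test_case_info", {"test_case": "suites/E2E/test_uc_performance.py::test_performance", "status": "PASSED"})
+# inserts a row when the DB is enabled and reachable, otherwise the record is
+# appended to <backup>/test_case_info.jsonl.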
+
+
+def database_connection(build_id: str) -> None:
+ """Test database connection and set the build ID."""
+ logger.info(f"Setting test build ID: {build_id}")
+ _set_test_build_id(build_id)
+
+ db = _get_db()
+ if not _db_enabled:
+ logger.info("Database connection skipped because enabled=false.")
+ return
+
+ if db is None:
+ logger.error("No database instance available.")
+ return
+
+ logger.info(f"Attempting connection to database: {db.database}")
+ try:
+ db.connect(reuse_if_open=True)
+ logger.info("Database connection successful.")
+ except Exception as e:
+ logger.error(f"Database connection failed: {e}", exc_info=True)
+ finally:
+ if not db.is_closed():
+ db.close()
+ logger.debug("Database connection closed.")
diff --git a/test/common/llmperf/__init__.py b/test/common/llmperf/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/common/llmperf/run_inference.py b/test/common/llmperf/run_inference.py
new file mode 100644
index 000000000..b04deb1ea
--- /dev/null
+++ b/test/common/llmperf/run_inference.py
@@ -0,0 +1,185 @@
+import os
+import random
+from pathlib import Path
+
+import yaml
+from common.llmperf.utils.token_benchmark import run_token_benchmark
+from common.llmperf.utils.utils import reset_prefill_cache
+
+
+def run_test_cases(
+ llm_api,
+ model,
+ timeout,
+ max_num_completed_requests,
+ concurrent_requests,
+ mean_input_tokens,
+ stddev_input,
+ mean_output_tokens,
+ stddev_output,
+ additional_sampling_params,
+ timestamp_dir,
+ server_url,
+ tokenizer_path,
+ hit_rate,
+):
+ print(f"[INFO] Total {len(mean_input_tokens)} test cases to be executed")
+ all_summaries = []
+ failed_case = []
+
+ # Clear proxy environment variables
+ env = os.environ.copy()
+ env.pop("http_proxy", None)
+ env.pop("https_proxy", None)
+
+ for i, (
+ mean_input,
+ mean_output,
+ max_completed,
+ concurrent,
+ additional_sampling_params,
+ hit_rate_val,
+ ) in enumerate(
+ zip(
+ mean_input_tokens,
+ mean_output_tokens,
+ max_num_completed_requests,
+ concurrent_requests,
+ additional_sampling_params,
+ hit_rate,
+ ),
+ start=1,
+ ):
+ print(f"\n>>> Executing test case {i} <<<")
+ reset_prefill_cache(env, server_url)
+ # Use a fixed random_seed for each test to control PC hit_rate
+ random_seed = random.randint(1, 100000)
+
+ try:
+ # Determine if two runs are needed (PC hit_rate test)
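+ # When hit_rate > 0 the case runs twice with the same random_seed: a prefill
+ # pass sends only the first hit_rate% of the prompt (with mean_output_tokens=2)
+ # to warm the prefix cache, then the normal pass replays the full prompt so
+ # roughly hit_rate% of its tokens are served from the cache.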
+ if hit_rate_val == 0:
+ summary = run_token_benchmark(
+ llm_api=llm_api,
+ model=model,
+ test_timeout_s=timeout,
+ max_num_completed_requests=max_completed,
+ concurrent_requests=concurrent,
+ mean_input_tokens=mean_input,
+ stddev_input_tokens=stddev_input,
+ mean_output_tokens=mean_output,
+ stddev_output_tokens=stddev_output,
+ additional_sampling_params=additional_sampling_params,
+ results_dir=str(timestamp_dir),
+ random_seed=random_seed,
+ openai_api_base=server_url + "/v1",
+ tokenizer_path=tokenizer_path,
+ user_metadata={"case_idx": i, "phase": "normal"},
+ )
+ else:
+ print(
+ f"[INFO] hit_rate > 0 detected, entering prefill mode, PC hit rate: {hit_rate_val} %"
+ )
+ # hit_rate > 0: first prefill mode
+ prefill_mean_input = int(mean_input * hit_rate_val / 100)
+ print(
+ f"[INFO] Prefill execution: mean_input_tokens={prefill_mean_input}"
+ )
+ run_token_benchmark(
+ llm_api=llm_api,
+ model=model,
+ test_timeout_s=timeout,
+ max_num_completed_requests=max_completed,
+ concurrent_requests=concurrent,
+ mean_input_tokens=prefill_mean_input,
+ stddev_input_tokens=stddev_input,
+ mean_output_tokens=2,
+ stddev_output_tokens=stddev_output,
+ additional_sampling_params=additional_sampling_params,
+ results_dir=str(timestamp_dir),
+ random_seed=random_seed,
+ openai_api_base=server_url + "/v1",
+ tokenizer_path=tokenizer_path,
+ user_metadata={"case_idx": i, "phase": "prefill"},
+ )
+ reset_prefill_cache(env, server_url)
+ # Then run normal mode
+ print("[INFO] Prefill completed, switching to normal mode execution")
+ summary = run_token_benchmark(
+ llm_api=llm_api,
+ model=model,
+ test_timeout_s=timeout,
+ max_num_completed_requests=max_completed,
+ concurrent_requests=concurrent,
+ mean_input_tokens=mean_input,
+ stddev_input_tokens=stddev_input,
+ mean_output_tokens=mean_output,
+ stddev_output_tokens=stddev_output,
+ additional_sampling_params=additional_sampling_params,
+ results_dir=str(timestamp_dir),
+ random_seed=random_seed,
+ openai_api_base=server_url + "/v1",
+ tokenizer_path=tokenizer_path,
+ user_metadata={"case_idx": i, "phase": "normal"},
+ )
+ all_summaries.append(summary)
+ except Exception as e:
+ print(f"[Warning] {e}")
+ failed_case.append(i)
+
+ return all_summaries, failed_case
+
+
+def inference_results(
+ mean_input_tokens,
+ mean_output_tokens,
+ max_num_completed_requests,
+ concurrent_requests,
+ additional_sampling_params,
+ hit_rate,
+):
+ config_file = Path(__file__).parent.parent.parent / "config.yaml"
+ print("[INFO] Initialization complete, starting main process")
+ print(f"[INFO] Reading configuration file: {config_file}")
+ with open(config_file, "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+ llm_api = config.get("llm_connection", {}).get("llm_api", "openai")
+ model = config.get("llm_connection", {}).get("model", "")
+ test_timeout_s = config.get("llm_connection", {}).get("test_timeout_s", 60000)
+ stddev_input_tokens = config.get("llm_connection", {}).get(
+ "stddev_input_tokens", 0
+ )
+ stddev_output_tokens = config.get("llm_connection", {}).get(
+ "stddev_output_tokens", 0
+ )
+ timestamp_dir = Path("results")
+ timestamp_dir.mkdir(parents=True, exist_ok=True)
+ server_url = config.get("llm_connection", {}).get("server_url", "")
+ tokenizer_path = config.get("llm_connection", {}).get("tokenizer_path", "")
+ print(f"[INFO] Created results directory: {timestamp_dir}")
+
+ all_summaries, failed_cases = run_test_cases(
+ llm_api,
+ model,
+ test_timeout_s,
+ max_num_completed_requests,
+ concurrent_requests,
+ mean_input_tokens,
+ stddev_input_tokens,
+ mean_output_tokens,
+ stddev_output_tokens,
+ additional_sampling_params,
+ timestamp_dir,
+ server_url,
+ tokenizer_path,
+ hit_rate,
+ )
+ total = len(mean_input_tokens)
+ print(
+ f"\n[INFO] All tests completed! Success: {total - len(failed_cases)}/{total}"
+ )
+ if failed_cases:
+ print(f"[WARN] Failed case indices: {failed_cases}")
+ return all_summaries
diff --git a/test/common/llmperf/utils/__init__.py b/test/common/llmperf/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/common/llmperf/utils/common_metrics.py b/test/common/llmperf/utils/common_metrics.py
new file mode 100644
index 000000000..40e21124e
--- /dev/null
+++ b/test/common/llmperf/utils/common_metrics.py
@@ -0,0 +1,17 @@
+# TODO (Avnishn): compute metrics in class
+INTER_TOKEN_LAT = "inter_token_latency_s"
+TTFT = "ttft_s"
+E2E_LAT = "end_to_end_latency_s"
+NUM_INPUT_TOKENS = "number_input_tokens"
+NUM_OUTPUT_TOKENS = "number_output_tokens"
+NUM_TOTAL_TOKENS = "number_total_tokens"
+REQ_OUTPUT_THROUGHPUT = "request_output_throughput_token_per_s"
+ERROR_MSG = "error_msg"
+ERROR_CODE = "error_code"
+ERROR_CODE_FREQ = "error_code_frequency"
+NUM_ERRORS = "number_errors"
+OUTPUT_THROUGHPUT = "mean_output_throughput_token_per_s"
+NUM_COMPLETED_REQUESTS = "num_completed_requests"
+COMPLETED_REQUESTS_PER_MIN = "num_completed_requests_per_min"
+ERROR_RATE = "error_rate"
+NUM_REQ_STARTED = "num_requests_started"
diff --git a/test/common/llmperf/utils/models.py b/test/common/llmperf/utils/models.py
new file mode 100644
index 000000000..1cbab6281
--- /dev/null
+++ b/test/common/llmperf/utils/models.py
@@ -0,0 +1,23 @@
+from typing import Any, Dict, Optional, Tuple
+
+from pydantic import BaseModel
+
+
+class RequestConfig(BaseModel):
+ """The configuration for a request to the LLM API.
+
+ Args:
+ model: The model to use.
+ prompt: A tuple of (prompt text, prompt length in tokens).
+ sampling_params: Additional sampling parameters to send with the request.
+ For more information see the LLM API's documentation for the completions endpoint.
+ llm_api: The name of the LLM API to send the request to.
+ metadata: Additional metadata to attach to the request for logging or validation purposes.
+ openai_api_base: Base URL of the OpenAI-compatible endpoint (e.g. "<server_url>/v1").
+ """
+
+ model: str
+ prompt: Tuple[str, int]
+ sampling_params: Optional[Dict[str, Any]] = None
+ llm_api: Optional[str] = None
+ metadata: Optional[Dict[str, Any]] = None
+ openai_api_base: Optional[str] = ""
diff --git a/test/common/llmperf/utils/openai_chat_completions_client.py b/test/common/llmperf/utils/openai_chat_completions_client.py
new file mode 100644
index 000000000..5023bfa1f
--- /dev/null
+++ b/test/common/llmperf/utils/openai_chat_completions_client.py
@@ -0,0 +1,136 @@
+import json
+import os
+import time
+from pathlib import Path
+from typing import Any, Dict, Tuple
+
+import requests
+import yaml
+from common.llmperf.utils import common_metrics
+from common.llmperf.utils.models import RequestConfig
+
+config_file = Path(__file__).parent.parent.parent.parent / "config.yaml"
+with open(config_file, "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+stream = config.get("llm_connection", {}).get("stream", True)
+ignore_eos = config.get("llm_connection", {}).get("ignore_eos", True)
+timeout = config.get("llm_connection", {}).get("timeout", 180)
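+# stream / ignore_eos / timeout are read from the llm_connection section of
+# test/config.yaml; the defaults above apply when a key is absent.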
+
+
+class OpenAIChatCompletionsClient:
+ """
+ Sends HTTP requests to an OpenAI-compatible /chat/completions endpoint,
+ consumes the token stream, and measures latency.
+ """
+
+ def llm_request(
+ self, request_config: RequestConfig
+ ) -> Tuple[Dict[str, Any], str, RequestConfig]:
+ prompt, prompt_len = request_config.prompt
+
+ message = [
+ {"role": "user", "content": prompt},
+ ]
+ model = request_config.model
+ body = {
+ "model": model,
+ "messages": message,
+ "stream": stream,
+ "ignore_eos": ignore_eos,
+ }
+ sampling_params = request_config.sampling_params
+ body.update(sampling_params or {})
+
+ time_to_next_token = []
+ tokens_received = 0
+ ttft = 0.0
+ error_response_code = None
+ generated_text = ""
+ error_msg = ""
+ output_throughput = 0.0
+ total_request_time = 0.0
+ flag = False
+
+ metrics: Dict[str, Any] = {}
+
+ metrics[common_metrics.ERROR_CODE] = None
+ metrics[common_metrics.ERROR_MSG] = ""
+
+ start_time = time.monotonic()
+ most_recent_received_token_time = start_time
+
+ address = request_config.openai_api_base
+
+ if not address:
+ raise ValueError("the environment variable OPENAI_API_BASE must be set.")
+ key = os.environ.get("OPENAI_API_KEY", "secret_abcdefg")
+ if not key:
+ raise ValueError("the environment variable OPENAI_API_KEY must be set.")
+ headers = {"Authorization": f"Bearer {key}"}
+ if not address.endswith("/"):
+ address = address + "/"
+ address += "chat/completions"
+ try:
+ with requests.post(
+ address,
+ json=body,
+ stream=stream,
+ timeout=timeout,
+ headers=headers,
+ ) as response:
+ if response.status_code != 200:
+ error_msg = response.text
+ error_response_code = response.status_code
+ response.raise_for_status()
+
+ for chunk in response.iter_lines(chunk_size=None):
+ if not chunk:
+ continue
+ stem = b"data: "
+ if chunk.startswith(stem):
+ chunk = chunk[len(stem) :]
+ # Data might already be bytes or str
+ if isinstance(chunk, bytes):
+ chunk = chunk.decode("utf-8", errors="ignore")
+ if chunk.strip() == "[DONE]":
+ continue
+ tokens_received += 1
+ data = json.loads(chunk)
+ if "error" in data:
+ error_msg = data["error"]["message"]
+ error_response_code = data["error"]["code"]
+ raise RuntimeError(error_msg)
+ delta = data["choices"][0]["delta"]
+ content = delta.get("content", None) or delta.get(
+ "reasoning_content", ""
+ )
+ if content:
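+ # The first content-bearing chunk defines TTFT; each later chunk adds the
+ # gap since the previous chunk to the inter-token latency samples.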
+ if not flag:
+ ttft = time.monotonic() - start_time
+ flag = True
+ else:
+ time_to_next_token.append(
+ time.monotonic() - most_recent_received_token_time
+ )
+ most_recent_received_token_time = time.monotonic()
+ generated_text += content
+
+ total_request_time = time.monotonic() - start_time
+ if total_request_time > 0:
+ output_throughput = tokens_received / total_request_time
+
+ except Exception as e:
+ metrics[common_metrics.ERROR_MSG] = error_msg
+ metrics[common_metrics.ERROR_CODE] = error_response_code
+ print(f"Warning Or Error: {e}")
+ print(f"Error response code: {error_response_code}")
+
+ metrics[common_metrics.INTER_TOKEN_LAT] = sum(time_to_next_token)
+ metrics[common_metrics.TTFT] = ttft
+ metrics[common_metrics.E2E_LAT] = total_request_time
+ metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = output_throughput
+ metrics[common_metrics.NUM_TOTAL_TOKENS] = tokens_received + prompt_len
+ metrics[common_metrics.NUM_OUTPUT_TOKENS] = tokens_received
+ metrics[common_metrics.NUM_INPUT_TOKENS] = prompt_len
+
+ return metrics, generated_text, request_config
diff --git a/test/common/llmperf/utils/sonnet.txt b/test/common/llmperf/utils/sonnet.txt
new file mode 100644
index 000000000..9f13ead47
--- /dev/null
+++ b/test/common/llmperf/utils/sonnet.txt
@@ -0,0 +1,84 @@
+Shall I compare thee to a summer's day?
+Thou art more lovely and more temperate:
+Rough winds do shake the darling buds of May,
+And summer's lease hath all too short a date:
+Sometime too hot the eye of heaven shines,
+And often is his gold complexion dimm'd;
+And every fair from fair sometime declines,
+By chance or nature's changing course untrimm'd;
+But thy eternal summer shall not fade
+Nor lose possession of that fair thou owest;
+Nor shall Death brag thou wander'st in his shade,
+When in eternal lines to time thou growest:
+So long as men can breathe or eyes can see,
+So long lives this and this gives life to thee.
+Then let not winter's ragged hand deface
+In thee thy summer, ere thou be distill'd:
+Make sweet some vial; treasure thou some place
+With beauty's treasure, ere it be self-kill'd.
+That use is not forbidden usury,
+Which happies those that pay the willing loan;
+That's for thyself to breed another thee,
+Or ten times happier, be it ten for one;
+Ten times thyself were happier than thou art,
+If ten of thine ten times refigured thee:
+Then what could death do, if thou shouldst depart,
+Leaving thee living in posterity?
+Be not self-will'd, for thou art much too fair
+To be death's conquest and make worms thine heir.
+Where art thou, Muse, that thou forget'st so long
+To speak of that which gives thee all thy might?
+Spend'st thou thy fury on some worthless song,
+Darkening thy power to lend base subjects light?
+Return, forgetful Muse, and straight redeem
+In gentle numbers time so idly spent;
+Sing to the ear that doth thy lays esteem
+And gives thy pen both skill and argument.
+Rise, resty Muse, my love's sweet face survey,
+If Time have any wrinkle graven there;
+If any, be a satire to decay,
+And make Time's spoils despised every where.
+Give my love fame faster than Time wastes life;
+So thou prevent'st his scythe and crooked knife.
+My glass shall not persuade me I am old,
+So long as youth and thou are of one date;
+But when in thee time's furrows I behold,
+Then look I death my days should expiate.
+For all that beauty that doth cover thee
+Is but the seemly raiment of my heart,
+Which in thy breast doth live, as thine in me:
+How can I then be elder than thou art?
+O, therefore, love, be of thyself so wary
+As I, not for myself, but for thee will;
+Bearing thy heart, which I will keep so chary
+As tender nurse her babe from faring ill.
+Presume not on thy heart when mine is slain;
+Thou gavest me thine, not to give back again.
+So am I as the rich, whose blessed key
+Can bring him to his sweet up-locked treasure,
+The which he will not every hour survey,
+For blunting the fine point of seldom pleasure.
+Therefore are feasts so solemn and so rare,
+Since, seldom coming, in the long year set,
+Like stones of worth they thinly placed are,
+Or captain jewels in the carcanet.
+So is the time that keeps you as my chest,
+Or as the wardrobe which the robe doth hide,
+To make some special instant special blest,
+By new unfolding his imprison'd pride.
+Blessed are you, whose worthiness gives scope,
+Being had, to triumph, being lack'd, to hope.
+If there be nothing new, but that which is
+Hath been before, how are our brains beguiled,
+Which, labouring for invention, bear amiss
+The second burden of a former child!
+O, that record could with a backward look,
+Even of five hundred courses of the sun,
+Show me your image in some antique book,
+Since mind at first in character was done!
+That I might see what the old world could say
+To this composed wonder of your frame;
+Whether we are mended, or whether better they,
+Or whether revolution be the same.
+O, sure I am, the wits of former days
+To subjects worse have given admiring praise.
\ No newline at end of file
diff --git a/test/common/llmperf/utils/token_benchmark.py b/test/common/llmperf/utils/token_benchmark.py
new file mode 100644
index 000000000..67553cf1b
--- /dev/null
+++ b/test/common/llmperf/utils/token_benchmark.py
@@ -0,0 +1,386 @@
+import json
+import logging
+import random
+import re
+import time
+from collections.abc import Iterable
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import pandas as pd
+from common.llmperf.utils import common_metrics
+from common.llmperf.utils.models import RequestConfig
+from common.llmperf.utils.openai_chat_completions_client import (
+ OpenAIChatCompletionsClient,
+)
+from common.llmperf.utils.utils import (
+ LLMPerfResults,
+ randomly_sample_sonnet_lines_prompt,
+ sample_random_positive_int,
+)
+from transformers import AutoTokenizer
+
+
+def get_token_throughput_latencies(
+ model: str,
+ mean_input_tokens: int,
+ stddev_input_tokens: int,
+ mean_output_tokens: int,
+ stddev_output_tokens: int,
+ additional_sampling_params: Optional[Dict[str, Any]] = None,
+ concurrent_requests: int = 1,
+ max_num_completed_requests: int = 500,
+ test_timeout_s=90,
+ llm_api="openai",
+ random_seed: int = None,
+ openai_api_base: str = "",
+ tokenizer_path: str = None,
+) -> Tuple[Dict[str, Any], List[Dict[str, Any]], float, float]:
+ """Get the token throughput and latencies for the given model.
+
+ Args:
+ model: The name of the model to query.
+ mean_input_tokens: The mean number of tokens to send in the prompt for the request.
+ stddev_input_tokens: The standard deviation of the number of tokens to send in the prompt for the request.
+ mean_output_tokens: The mean number of tokens to generate per request.
+ stddev_output_tokens: The standard deviation of the number of tokens to generate per request.
+ additional_sampling_params: Additional sampling parameters to send with the request.
+ For more information see the LLM APIs documentation for the completions
+ concurrent_requests: The number of concurrent requests to make. Increase
+ this to increase the amount of load and vice versa.
+ test_timeout_s: The amount of time to run the test for before reporting results.
+ llm_api: The name of the llm api to use (requests are always sent through the
+ OpenAI-compatible chat completions client).
+
+ Returns:
+ A summary of the performance metrics collected across all completed requests
+ (e.g. throughput, latencies, etc.), the individual metrics for each request,
+ the wall-clock elapsed time, and the accumulated inter-token delay.
+ """
+ random.seed(random_seed)
+
+ print(f"Using tokenizer:{tokenizer_path}")
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ get_token_length = lambda text: len(tokenizer.encode(text))
+
+ if not additional_sampling_params:
+ additional_sampling_params = {}
+
+ # 1. create prompts
+ prompts: List[Tuple[str, int]] = []
+ num_output_tokens_list: List[int] = []
+ for i in range(max_num_completed_requests):
+ num_output = sample_random_positive_int(
+ mean_output_tokens, stddev_output_tokens
+ )
+ num_output_tokens_list.append(num_output)
+ prompts.append(
+ randomly_sample_sonnet_lines_prompt(
+ prompt_tokens_mean=mean_input_tokens,
+ prompt_tokens_stddev=stddev_input_tokens,
+ tokenizer=tokenizer,
+ )
+ )
+ start_time = time.monotonic()
+ completed_requests: List[Dict[str, Any]] = []
+ incremental_time_delay = 0.0
+ client = OpenAIChatCompletionsClient()
+ futures = []
+
+ # 2. Submitting tasks using a thread pool
+ with ThreadPoolExecutor(max_workers=concurrent_requests) as executor:
+ for idx in range(max_num_completed_requests):
+ sampling = {"max_tokens": num_output_tokens_list[idx]}
+ sampling.update(additional_sampling_params)
+ cfg = RequestConfig(
+ model=model,
+ prompt=prompts[idx],
+ sampling_params=sampling,
+ llm_api=llm_api,
+ openai_api_base=openai_api_base,
+ )
+ futures.append(executor.submit(client.llm_request, cfg))
+ # 3. Waiting for completion or timeout
+ for future in as_completed(futures, timeout=test_timeout_s):
+ try:
+ metrics, gen_text, req_cfg = future.result()
+ except Exception as e:
+ logging.warning(f"[WARN] Future raised exception: {e}")
+ continue
+ num_output_tokens = get_token_length(gen_text)
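+ # The client reports INTER_TOKEN_LAT as a sum of inter-chunk gaps; divide by
+ # (streamed tokens - 1) for a per-token mean, and re-count the output with the
+ # benchmark tokenizer so every model is measured with the same tokenizer.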
+ if num_output_tokens:
+ metrics[common_metrics.INTER_TOKEN_LAT] /= (
+ (metrics[common_metrics.NUM_OUTPUT_TOKENS] - 1)
+ if (metrics[common_metrics.NUM_OUTPUT_TOKENS] - 1)
+ else 1
+ )
+ metrics[common_metrics.NUM_OUTPUT_TOKENS] = num_output_tokens
+ metrics[common_metrics.NUM_TOTAL_TOKENS] = (
+ metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens
+ )
+ try:
+ metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = (
+ num_output_tokens / metrics[common_metrics.E2E_LAT]
+ )
+ except ZeroDivisionError:
+ logging.error("Division by zero in throughput calculation.")
+
+ completed_requests.append(metrics)
+
+ incremental_time_delay += metrics.get(
+ common_metrics.INTER_TOKEN_LAT, 0.0
+ )
+
+ end_time = time.monotonic()
+
+ print(f"Results for token benchmark for {model} queried with the {llm_api} api.\n")
+ if mean_output_tokens == 2:
+ print(f"[INFO] First token sending pre-embedding completed\n")
+ return {}, [], 0.0, 0.0
+
+ ret = metrics_summary(completed_requests, start_time, end_time)
+
+ metadata = {
+ "model": model,
+ "mean_input_tokens": mean_input_tokens,
+ "stddev_input_tokens": stddev_input_tokens,
+ "mean_output_tokens": mean_output_tokens,
+ "stddev_output_tokens": stddev_output_tokens,
+ "concurrent_requests": concurrent_requests,
+ "additional_sampling_params": additional_sampling_params,
+ }
+
+ metadata["results"] = ret
+ elapsed_time = end_time - start_time
+ return metadata, completed_requests, elapsed_time, incremental_time_delay
+
+
+def compute_throughput(
+ summary: Dict[str, Any],
+ completed_requests: List[Dict[str, Any]],
+ elapsed_time: float,
+ incremental_time_delay: float,
+) -> Tuple[float, float]:
+ """
+ Compute total_throughput (token/s) based on the metrics in summary.
+
+ Formula: (mean_output_tokens * num_completed_requests) / total_e2e_latency_s
+
+ Args:
+ summary (Dict[str, Any]): A dictionary containing performance metrics.
+
+ Returns:
+ float: The computed total throughput in tokens per second. Returns 0.0 if latency is zero.
+ """
+ mean_output_tokens = summary.get("mean_output_tokens", 0)
+
+ total_throughput = (
+ (mean_output_tokens * len(completed_requests)) / elapsed_time
+ if elapsed_time > 0
+ else 0.0
+ )
+ incremental_throughput = (
+ (mean_output_tokens * len(completed_requests)) / incremental_time_delay
+ if incremental_time_delay > 0
+ else 0.0
+ )
+ return round(total_throughput, 4), round(incremental_throughput, 4)
+
+
+def metrics_summary(
+ metrics: List[Dict[str, Any]], start_time: int, end_time: int
+) -> Dict[str, Any]:
+ """Generate a summary over metrics generated from potentially multiple instances of this client.
+
+ Args:
+ metrics: The metrics to summarize.
+ start_time: The time the test started.
+ end_time: The time the test ended.
+
+ Returns:
+ A summary with the following information:
+ - Overall throughput (generated tokens / total test time)
+ - Number of completed requests
+ - Error rate
+ - Error code frequency
+ - Quantiles (p25-p99) for the following metrics:
+ - Inter token latency
+ - Time to first token
+ - User total request time
+ - Number of tokens processed per request
+ - Number of tokens generated per request
+ - User throughput (tokens / s)
+ """
+ ret = {}
+
+ def flatten(item):
+ for sub_item in item:
+ if isinstance(sub_item, Iterable) and not isinstance(sub_item, str):
+ yield from flatten(sub_item)
+ else:
+ yield sub_item
+
+ df = pd.DataFrame(metrics)
+ df_without_errored_req = df[df[common_metrics.ERROR_CODE].isna()]
+
+ for key in [
+ common_metrics.INTER_TOKEN_LAT,
+ common_metrics.TTFT,
+ common_metrics.E2E_LAT,
+ common_metrics.REQ_OUTPUT_THROUGHPUT,
+ common_metrics.NUM_INPUT_TOKENS,
+ common_metrics.NUM_OUTPUT_TOKENS,
+ ]:
+ print(key)
+ ret[key] = {}
+ series = pd.Series(list(flatten(df_without_errored_req[key]))).dropna()
+ series = series[series > 0] # Calculate non-zero values
+ quantiles = series.quantile([0.25, 0.5, 0.75, 0.9, 0.95, 0.99]).to_dict()
+ quantiles_reformatted_keys = {}
+ for quantile, value in quantiles.items():
+ reformatted_key = f"p{int(quantile * 100)}"
+ print(f" {reformatted_key} = {value}")
+ quantiles_reformatted_keys[reformatted_key] = value
+ ret[key]["quantiles"] = quantiles_reformatted_keys
+ mean = series.mean()
+ print(f" mean = {mean}")
+ ret[key]["mean"] = mean
+ print(f" min = {series.min()}")
+ ret[key]["min"] = series.min()
+ print(f" max = {series.max()}")
+ ret[key]["max"] = series.max()
+ print(f" stddev = {series.std()}")
+ ret[key]["stddev"] = series.std()
+
+ ret[common_metrics.NUM_REQ_STARTED] = len(metrics)
+
+ error_codes = df[common_metrics.ERROR_CODE].dropna()
+ num_errors = len(error_codes)
+ ret[common_metrics.ERROR_RATE] = num_errors / len(metrics) if len(metrics) else 0
+ ret[common_metrics.NUM_ERRORS] = num_errors
+ print(f"Number Of Errored Requests: {num_errors}")
+ if num_errors:
+ error_code_frequency = dict(error_codes.value_counts())
+ print("Error Code Frequency")
+ print(error_code_frequency)
+ ret[common_metrics.ERROR_CODE_FREQ] = str(error_code_frequency)
+
+ overall_output_throughput = df_without_errored_req[
+ common_metrics.NUM_OUTPUT_TOKENS
+ ].sum() / (end_time - start_time)
+
+ print(f"Overall Output Throughput: {overall_output_throughput}")
+ ret[common_metrics.OUTPUT_THROUGHPUT] = overall_output_throughput
+
+ num_completed_requests = len(df_without_errored_req)
+ num_completed_requests_per_min = (
+ num_completed_requests / (end_time - start_time) * 60
+ )
+ print(f"Number Of Completed Requests: {num_completed_requests}")
+ print(f"Completed Requests Per Minute: {num_completed_requests_per_min}")
+
+ ret[common_metrics.NUM_COMPLETED_REQUESTS] = num_completed_requests
+ ret[common_metrics.COMPLETED_REQUESTS_PER_MIN] = num_completed_requests_per_min
+
+ return ret
+
+
+def run_token_benchmark(
+ llm_api: str,
+ model: str,
+ test_timeout_s: int,
+ max_num_completed_requests: int,
+ concurrent_requests: int,
+ mean_input_tokens: int,
+ stddev_input_tokens: int,
+ mean_output_tokens: int,
+ stddev_output_tokens: int,
+ additional_sampling_params: str,
+ results_dir: str,
+ random_seed: int,
+ openai_api_base: str,
+ tokenizer_path: str,
+ user_metadata: Dict[str, Any],
+):
+ """
+ Args:
+ llm_api: The name of the llm api to use.
+ model: The name of the model to query.
+ max_num_completed_requests: The number of requests to complete before finishing the test.
+ test_timeout_s: The amount of time to run the test for before reporting results.
+ concurrent_requests: The number of concurrent requests to make. Increase
+ this to increase the amount of load and vice versa.
+ mean_input_tokens: The mean number of tokens to send in the prompt for the request.
+ stddev_input_tokens: The standard deviation of the number of tokens to send in the prompt for the request.
+ mean_output_tokens: The mean number of tokens to generate per request.
+ stddev_output_tokens: The standard deviation of the number of tokens to generate per request.
+ additional_sampling_params: Additional sampling parameters to send with the request.
+ For more information see the LLM APIs documentation for the completions.
+ results_dir: The directory to save the results to.
+ random_seed: Seed reused across the prefill and normal passes so both sample identical prompts.
+ openai_api_base: Base URL of the OpenAI-compatible endpoint.
+ tokenizer_path: Path of the tokenizer used to build prompts and count tokens.
+ user_metadata: Additional metadata to include in the results.
+ """
+ if mean_input_tokens < 40:
+ print(
+ "[WARN] Because of the current prompt-building logic, the minimum number"
+ " of input tokens that will actually be sent is 41"
+ )
+
+ summary, completed_requests, elapsed_time, incremental_time_delay = (
+ get_token_throughput_latencies(
+ model=model,
+ llm_api=llm_api,
+ test_timeout_s=test_timeout_s,
+ max_num_completed_requests=max_num_completed_requests,
+ mean_input_tokens=mean_input_tokens,
+ stddev_input_tokens=stddev_input_tokens,
+ mean_output_tokens=mean_output_tokens,
+ stddev_output_tokens=stddev_output_tokens,
+ concurrent_requests=concurrent_requests,
+ additional_sampling_params=json.loads(additional_sampling_params),
+ random_seed=random_seed,
+ openai_api_base=openai_api_base,
+ tokenizer_path=tokenizer_path,
+ )
+ )
+ if mean_output_tokens == 2:
+ return summary, completed_requests, elapsed_time, incremental_time_delay
+
+ timestamp = int(time.time() * 1000)
+ if results_dir:
+ filename = f"{model}_{mean_input_tokens}_{mean_output_tokens}_{timestamp}"
+ filename = re.sub(r"[^\w\d-]+", "-", filename)
+ filename = re.sub(r"-{2,}", "-", filename)
+ summary_filename = f"{filename}_summary"
+
+ # Update to metadata.
+ summary.update(user_metadata)
+ total_tp, req_tp = compute_throughput(
+ summary, completed_requests, elapsed_time, incremental_time_delay
+ )
+ summary["num_completed_requests"] = len(completed_requests)
+ summary["elapsed_time"] = elapsed_time
+ summary["incremental_time_delay"] = incremental_time_delay
+ summary["total_throughput"] = total_tp
+ summary["incremental_throughput"] = req_tp
+
+ results = LLMPerfResults(name=summary_filename, metadata=summary)
+ results_dir = Path(results_dir)
+ if not results_dir.exists():
+ results_dir.mkdir(parents=True)
+ elif not results_dir.is_dir():
+ raise ValueError(f"{results_dir} is not a directory")
+
+ llmperf_dir = results_dir / "llmperf"
+ if not llmperf_dir.exists():
+ llmperf_dir.mkdir(parents=True)
+ elif not llmperf_dir.is_dir():
+ raise ValueError(f"{llmperf_dir} is not a directory")
+
+ try:
+ with open(llmperf_dir / f"{summary_filename}.json", "w") as f:
+ json.dump(results.to_dict(), f, indent=4, default=str)
+ except Exception as e:
+ print(results.to_dict())
+ raise e
+ return summary
diff --git a/test/common/llmperf/utils/utils.py b/test/common/llmperf/utils/utils.py
new file mode 100644
index 000000000..e2c270871
--- /dev/null
+++ b/test/common/llmperf/utils/utils.py
@@ -0,0 +1,171 @@
+import hashlib
+import json
+import math
+import os
+import pathlib
+import random
+import subprocess
+import time
+from typing import Any, Dict, Tuple
+
+from transformers import LlamaTokenizerFast
+
+RESULTS_VERSION = "2025-10-30"
+
+
+class LLMPerfResults:
+ def __init__(
+ self,
+ name: str,
+ metadata: Dict[str, Any] = None,
+ ):
+ self.name = name
+ self.metadata = metadata or {}
+ self.timestamp = int(time.time())
+ self.metadata["timestamp"] = self.timestamp
+ self.version = RESULTS_VERSION
+
+ def to_dict(self):
+ data = {
+ "version": self.version,
+ "name": self.name,
+ }
+ data.update(self.metadata)
+ data = flatten_dict(data)
+ return data
+
+ def json(self):
+ data = self.to_dict()
+ return json.dumps(data)
+
+
+def upload_to_s3(results_path: str, s3_path: str) -> None:
+ """Upload the results to s3.
+
+ Args:
+ results_path: The path to the results file.
+ s3_path: The s3 path to upload the results to.
+
+ """
+
+ command = ["aws", "s3", "sync", results_path, f"{s3_path}/"]
+ result = subprocess.run(command, capture_output=True, text=True)
+ if result.returncode == 0:
+ print("Files uploaded successfully!")
+ else:
+ print("An error occurred:")
+ print(result.stderr)
+
+
+def randomly_sample_sonnet_lines_prompt(
+ prompt_tokens_mean: int = 550,
+ prompt_tokens_stddev: int = 250,
+ tokenizer: LlamaTokenizerFast = None,
+) -> Tuple[str, int]:
+ """Generate a prompt that randomly samples lines from a the shakespeare sonnet at sonnet.txt.
+
+ Args:
+ prompt_length_mean: The mean length of the prompt to generate.
+ prompt_len_stddev: The standard deviation of the length of the prompt to generate.
+ expect_output_tokens: The number of tokens to expect in the output. This is used to
+ determine the length of the prompt. The prompt will be generated such that the output
+ will be approximately this many tokens.
+
+ Note:
+ tokens will be counted from the sonnet using the Llama tokenizer. Using one tokenizer
+ ensures a fairer comparison across different LLMs. For example, if gpt 3.5 tokenizes
+ a prompt in less tokens than Llama2, then this will be reflected in the results since
+ they will be fed identical prompts.
+
+ Returns:
+ A tuple of the prompt and the length of the prompt.
+ """
+ get_token_length = lambda text: len(tokenizer.encode(text))
+
+ prompt = (
+ "Randomly stream lines from the following text "
+ "Don't generate eos tokens:\n\n"
+ )
+ # get a prompt length that is at least as long as the base
+ num_prompt_tokens = sample_random_positive_int(
+ prompt_tokens_mean, prompt_tokens_stddev
+ )
+ while num_prompt_tokens < get_token_length(prompt):
+ num_prompt_tokens = sample_random_positive_int(
+ prompt_tokens_mean, prompt_tokens_stddev
+ )
+ remaining_prompt_tokens = num_prompt_tokens - get_token_length(prompt)
+ sonnet_path = pathlib.Path(__file__).parent.resolve() / "sonnet.txt"
+ with open(sonnet_path, "r") as f:
+ sonnet_lines = f.readlines()
+ random.shuffle(sonnet_lines)
+ sampling_lines = True
+ while sampling_lines:
+ for line in sonnet_lines:
+ line_to_add = line
+ if remaining_prompt_tokens - get_token_length(line_to_add) < 0:
+ # This will cut off a line in the middle of a word, but that's ok since an
+ # llm should be able to handle that.
+ line_to_add = line_to_add[: int(math.ceil(remaining_prompt_tokens))]
+ sampling_lines = False
+ prompt += line_to_add
+ break
+ prompt += line_to_add
+ remaining_prompt_tokens -= get_token_length(line_to_add)
+ print(f"[DEBUG] prompt sha256: {hashlib.sha256(prompt.encode('utf-8')).hexdigest()}")
+ return (prompt, num_prompt_tokens)
+
+
+def sample_random_positive_int(mean: int, stddev: int) -> int:
+ """Sample random numbers from a gaussian distribution until a positive number is sampled.
+
+ Args:
+ mean: The mean of the gaussian distribution to sample from.
+ stddev: The standard deviation of the gaussian distribution to sample from.
+
+ Returns:
+ A random positive integer sampled from the gaussian distribution.
+ """
+ ret = -1
+ while ret <= 0:
+ ret = int(random.gauss(mean, stddev))
+ return ret
+
+
+def flatten_dict(d, parent_key="", sep="_"):
+ items = []
+ for k, v in d.items():
+ new_key = f"{parent_key}{sep}{k}" if parent_key else k
+ if isinstance(v, dict):
+ items.extend(flatten_dict(v, new_key, sep=sep).items())
+ else:
+ items.append((new_key, v))
+ return dict(items)
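+
+# Example: flatten_dict({"results": {"ttft_s": {"mean": 0.2}}}) yields
+# {"results_ttft_s_mean": 0.2}, which is how nested summaries become flat keys
+# in the saved JSON results.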
+
+
+def reset_prefill_cache(env, server_url):
+ """
+ prefix cache / HBM
+ Param:
+ env
+ server_url
+ """
+ reset_url = f"{server_url}/reset_prefix_cache"
+ print(f"[INFO] Resetting prefix cache: {reset_url}")
+ try:
+ result = subprocess.run(
+ ["curl", "-X", "POST", reset_url, "-s", "-f"],
+ env=env,
+ check=False,
+ capture_output=True,
+ text=True,
+ timeout=10,
+ )
+ if result.returncode == 0:
+ print("[INFO] Prefix cache successfully reset")
+ else:
+ print(
+ f"[ERROR] Failed to reset prefix cache, curl exit code: {result.returncode}"
+ )
+ except Exception as e:
+ print(f"[ERROR] Exception in resetting prefix cache: {e}")
diff --git a/test/config.yaml b/test/config.yaml
new file mode 100644
index 000000000..7ac32f484
--- /dev/null
+++ b/test/config.yaml
@@ -0,0 +1,27 @@
+reports:
+ base_dir: "results/reports"
+ use_timestamp: true
+ directory_prefix: "pytest"
+ html: # pytest-html
+ enabled: true
+ filename: "report.html"
+ title: "UCM Pytest Test Report"
+
+database:
+ backup: "results/"
+ enabled: true
+ host: "127.0.0.1"
+ port: 3306
+ name: "ucm_pytest"
+ user: "root"
+ password: "123456"
+ charset: "utf8mb4"
+
+# LLM Connection Configuration
+llm_connection:
+ model: "qwen3"
+ server_url: "http://141.111.32.70:9382"
+ tokenizer_path: "/home/models/QwQ-32B"
+ stream: true # stream responses token by token
+ ignore_eos: true # ignore the EOS token so responses reach the requested output length
+ timeout: 180 # per-request timeout in seconds
\ No newline at end of file
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 000000000..150257952
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,159 @@
+from __future__ import annotations
+
+import datetime as dt
+import platform as pf
+import sys
+from functools import wraps
+from pathlib import Path
+
+import pytest
+from common.config_utils import config_utils as config_instance
+from common.db_utils import database_connection, write_to_db
+
+# ---------------- Constants ----------------
+PRJ_ROOT = Path(__file__).resolve().parent
+sys.path.insert(0, str(PRJ_ROOT))
+
+
+# ---------------- CLI Options ----------------
+def pytest_addoption(parser):
+ parser.addoption(
+ "--stage", action="store", default="", help="Filter by stage marker (1,2,3,+)"
+ )
+ parser.addoption(
+ "--feature", action="store", default="", help="Filter by feature marker"
+ )
+ parser.addoption(
+ "--platform", action="store", default="", help="Filter by platform marker"
+ )
+
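+# Example invocations (option names mirror the marker names declared in pytest.ini,
+# with the "(...)" signature and the description stripped):
+#   pytest --stage 2                                  # only tests marked stage(2)
+#   pytest --stage 1+                                 # tests whose stage is >= 1
+#   pytest --feature uc_performance_test --platform gpu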
+
+# ---------------- Test Filtering ----------------
+def pytest_collection_modifyitems(config, items):
+ kept = items[:]
+
+ markers = [m.split(":", 1)[0].strip() for m in config.getini("markers")]
+ for name in markers:
+ opt = config.getoption(f"--{name}", "").strip()
+ if not opt:
+ continue
+
+ if name == "stage" and opt.endswith("+"):
+ min_stage = int(opt[:-1])
+ kept = [
+ it
+ for it in kept
+ if any(int(v) >= min_stage for v in _get_marker_args(it, "stage"))
+ ]
+ else:
+ wanted = {x.strip() for x in opt.split(",") if x.strip()}
+ kept = [
+ it
+ for it in kept
+ if any(v in wanted for v in _get_marker_args(it, name))
+ ]
+
+ config.hook.pytest_deselected(items=[i for i in items if i not in kept])
+ items[:] = kept
+
+
+def _get_marker_args(item, marker_name):
+ """Extract only args (not kwargs) from markers, as strings."""
+ return [
+ str(arg) for mark in item.iter_markers(name=marker_name) for arg in mark.args
+ ]
+
+
+# ---------------- Report Setup ----------------
+def _prepare_report_dir(config: pytest.Config) -> Path:
+ cfg = config_instance.get_config("reports", {})
+ base_dir = Path(cfg.get("base_dir", "reports"))
+ prefix = cfg.get("directory_prefix", "pytest")
+ if cfg.get("use_timestamp", False):
+ ts = dt.datetime.now().strftime("%Y%m%d_%H%M%S")
+ report_dir = base_dir / f"{prefix}_{ts}"
+ else:
+ report_dir = base_dir
+ report_dir.mkdir(parents=True, exist_ok=True)
+ return report_dir
+
+
+def _setup_html_report(config: pytest.Config, report_dir: Path) -> None:
+ reports_config = config_instance.get_config("reports", {})
+ html_cfg = reports_config.get("html", {})
+ if not html_cfg.get("enabled", True):
+ if hasattr(config.option, "htmlpath"):
+ config.option.htmlpath = None
+ print("HTML report disabled according to config.yaml")
+ return
+
+ html_filename = html_cfg.get("filename", "report.html")
+ config.option.htmlpath = str(report_dir / html_filename)
+ config.option.self_contained_html = True
+ print("HTML report enabled")
+
+
+# ---------------- Build ID & Session Init ----------------
+def _generate_build_id(config: pytest.Config) -> str:
+ ts = dt.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
+ cli_parts = []
+ markers = [m.split(":", 1)[0].strip() for m in config.getini("markers")]
+ for opt in markers:
+ val = config.getoption(opt, "")
+ if val:
+ cli_parts.append(f"{opt}={val}")
+ args_part = "_".join(cli_parts) if cli_parts else "all_cases"
+ return f"pytest_{ts}_{args_part}"
+
+
+# ---------------- Pytest Hooks ----------------
+def pytest_configure(config: pytest.Config) -> None:
+ """The global configuration will be executed directly upon entering pytest."""
+ print(f"Starting Test Session: {dt.datetime.now():%Y-%m-%d %H:%M:%S}")
+
+ # Set up report directory
+ report_dir = _prepare_report_dir(config)
+ config._report_dir = report_dir # Attach to config for later use
+ _setup_html_report(config, report_dir)
+
+ # Generate and register build ID into DB
+ build_id = _generate_build_id(config)
+ config._build_id = build_id
+ database_connection(build_id)
+
+
+def pytest_sessionstart(session):
+ print("")
+ print("-" * 60)
+ print(f"{'Python':<10} │ {pf.python_version()}")
+ print(f"{'Platform':<10} │ {pf.system()} {pf.release()}")
+ print("-" * 60)
+
+
+def pytest_sessionfinish(session, exitstatus):
+ report_dir = getattr(session.config, "_report_dir", "reports")
+ print("")
+ print("-" * 60)
+ print(f"{'Reports at':<10} │ {report_dir}")
+ print("Test session ended")
+ print("-" * 60)
+
+
+# ---------------- Result Reporting ----------------
+
+
+def pytest_runtest_logreport(report):
+ """
+ Called after each test phase. We only care about 'call' (the actual test).
+ """
+ if report.when != "call":
+ return
+
+ status = report.outcome.upper() # 'passed', 'failed', 'skipped' → 'PASSED', etc.
+ test_result = {
+ "test_case": report.nodeid,
+ "status": status,
+ # "duration": report.duration,
+ "error": str(report.longrepr) if report.failed else None,
+ }
+ write_to_db("test_case_info", test_result)
diff --git a/test/pytest.ini b/test/pytest.ini
new file mode 100644
index 000000000..4be3cf477
--- /dev/null
+++ b/test/pytest.ini
@@ -0,0 +1,25 @@
+[pytest]
+testpaths = suites
+python_files = test_*.py
+python_classes = Test*
+python_functions = test_*
+
+addopts =
+ -ra
+ --strict-markers
+ --capture=no
+filterwarnings =
+ ignore::pytest.PytestReturnNotNoneWarning
+
+log_cli = 1
+log_cli_level = INFO
+log_cli_format = [%(levelname)s] %(name)s: %(message)s
+norecursedirs = .git venv env __pycache__ *.egg
+
+markers =
+ # -------- Levels (Required) --------
+ stage(n): Unit/Smoke/Regression/Release (0=Unit 1=Smoke 2=Regression 3=Release)
+ # -------- Features (Recommended) --------
+ feature(name): Feature tag
+ platform(name): Platform tag(gpu/npu)
+# end of markers
\ No newline at end of file
diff --git a/test/requirements.txt b/test/requirements.txt
new file mode 100644
index 000000000..07635b247
--- /dev/null
+++ b/test/requirements.txt
@@ -0,0 +1,6 @@
+pytest>=7.0.0
+pytest-html>=3.1.1
+PyYAML>=6.0
+# MySQL
+peewee>=3.14.5
+pymysql>=1.0.2
+# required by the llmperf-based performance tests (test/common/llmperf)
+requests
+pandas
+pydantic
+transformers
diff --git a/test/suites/E2E/test_uc_performance.py b/test/suites/E2E/test_uc_performance.py
new file mode 100644
index 000000000..dbec0318b
--- /dev/null
+++ b/test/suites/E2E/test_uc_performance.py
@@ -0,0 +1,158 @@
+import pytest
+from common.capture_utils import export_vars
+from common.llmperf.run_inference import inference_results
+
+
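+# The parametrize lists are consumed positionally: index i of every list defines one
+# benchmark case (case 1: 2000 input / 200 output tokens, 8 requests, 8 concurrent,
+# 0% hit rate; case 2: 3000 / 500, 4 requests, 4 concurrent, 50% prefix-cache hit rate).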
+@pytest.mark.parametrize("mean_input_tokens", [[2000, 3000]])
+@pytest.mark.parametrize("mean_output_tokens", [[200, 500]])
+@pytest.mark.parametrize("max_num_completed_requests", [[8, 4]])
+@pytest.mark.parametrize("concurrent_requests", [[8, 4]])
+@pytest.mark.parametrize("additional_sampling_params", [["{}", "{}"]])
+@pytest.mark.parametrize("hit_rate", [[0, 50]])
+@pytest.mark.feature("uc_performance_test")
+@export_vars
+def test_performance(
+ mean_input_tokens,
+ mean_output_tokens,
+ max_num_completed_requests,
+ concurrent_requests,
+ additional_sampling_params,
+ hit_rate,
+):
+ all_summaries = inference_results(
+ mean_input_tokens,
+ mean_output_tokens,
+ max_num_completed_requests,
+ concurrent_requests,
+ additional_sampling_params,
+ hit_rate,
+ )
+ failed_cases = []
+
+ value_lists = {
+ "mean_input_tokens": [],
+ "mean_output_tokens": [],
+ "results_inter_token_latency_s_quantiles_p50": [],
+ "results_inter_token_latency_s_quantiles_p90": [],
+ "results_inter_token_latency_s_quantiles_p99": [],
+ "results_inter_token_latency_s_mean": [],
+ "results_ttft_s_quantiles_p50": [],
+ "results_ttft_s_quantiles_p90": [],
+ "results_ttft_s_quantiles_p99": [],
+ "results_ttft_s_mean": [],
+ "results_end_to_end_latency_s_quantiles_p50": [],
+ "results_end_to_end_latency_s_quantiles_p90": [],
+ "results_end_to_end_latency_s_quantiles_p99": [],
+ "results_end_to_end_latency_s_mean": [],
+ "num_completed_requests": [],
+ "elapsed_time": [],
+ "incremental_time_delay": [],
+ "total_throughput": [],
+ "incremental_throughput": [],
+ }
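+ # Per-metric series collected across all cases; returned at the end of the test
+ # so the @export_vars decorator can pick them up under the "llmperf" name.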
+
+ for i, summary in enumerate(all_summaries):
+ mean_input_tokens = summary["mean_input_tokens"]
+ mean_output_tokens = summary["mean_output_tokens"]
+
+ results_inter_token_latency_s_quantiles_p50 = summary["results"][
+ "inter_token_latency_s"
+ ]["quantiles"]["p50"]
+ results_inter_token_latency_s_quantiles_p90 = summary["results"][
+ "inter_token_latency_s"
+ ]["quantiles"]["p90"]
+ results_inter_token_latency_s_quantiles_p99 = summary["results"][
+ "inter_token_latency_s"
+ ]["quantiles"]["p99"]
+ results_inter_token_latency_s_mean = summary["results"][
+ "inter_token_latency_s"
+ ]["mean"]
+
+ results_ttft_s_quantiles_p50 = summary["results"]["ttft_s"]["quantiles"]["p50"]
+ results_ttft_s_quantiles_p90 = summary["results"]["ttft_s"]["quantiles"]["p90"]
+ results_ttft_s_quantiles_p99 = summary["results"]["ttft_s"]["quantiles"]["p99"]
+ results_ttft_s_mean = summary["results"]["ttft_s"]["mean"]
+
+ results_end_to_end_latency_s_quantiles_p50 = summary["results"][
+ "end_to_end_latency_s"
+ ]["quantiles"]["p50"]
+ results_end_to_end_latency_s_quantiles_p90 = summary["results"][
+ "end_to_end_latency_s"
+ ]["quantiles"]["p90"]
+ results_end_to_end_latency_s_quantiles_p99 = summary["results"][
+ "end_to_end_latency_s"
+ ]["quantiles"]["p99"]
+ results_end_to_end_latency_s_mean = summary["results"]["end_to_end_latency_s"][
+ "mean"
+ ]
+
+ num_completed_requests = summary["num_completed_requests"]
+ elapsed_time = summary["elapsed_time"]
+ incremental_time_delay = summary["incremental_time_delay"]
+ total_throughput = summary["total_throughput"]
+ incremental_throughput = summary["incremental_throughput"]
+
+ values = [
+ mean_input_tokens,
+ mean_output_tokens,
+ results_inter_token_latency_s_quantiles_p50,
+ results_inter_token_latency_s_quantiles_p90,
+ results_inter_token_latency_s_quantiles_p99,
+ results_inter_token_latency_s_mean,
+ results_ttft_s_quantiles_p50,
+ results_ttft_s_quantiles_p90,
+ results_ttft_s_quantiles_p99,
+ results_ttft_s_mean,
+ results_end_to_end_latency_s_quantiles_p50,
+ results_end_to_end_latency_s_quantiles_p90,
+ results_end_to_end_latency_s_quantiles_p99,
+ results_end_to_end_latency_s_mean,
+ num_completed_requests,
+ elapsed_time,
+ incremental_time_delay,
+ total_throughput,
+ incremental_throughput,
+ ]
+
+ for var_name, val in zip(value_lists.keys(), values):
+ value_lists[var_name].append(val)
+ if val is None:
+ failed_cases.append((i, var_name, "missing"))
+
+ try:
+ assert val > 0, f"value {val} <= 0"
+ except AssertionError as e:
+ failed_cases.append((i, var_name, str(e)))
+
+ # Output final result
+ if failed_cases:
+ print(f"\n[WARNING] Assertion failed: {len(failed_cases)} abnormal cases found")
+ for i, key, reason in failed_cases:
+ print(f" Iteration={i + 1}, key='{key}' -> {reason}")
+ else:
+ print("\n[INFO] All values are greater than 0. Assertion passed!")
+
+ return {"_name": "llmperf", "_data": value_lists}
diff --git a/test/test_uc_connector.py b/test/test_uc_connector.py
index d4a0caeb7..0c2261d87 100644
--- a/test/test_uc_connector.py
+++ b/test/test_uc_connector.py
@@ -25,6 +25,7 @@
import random
import secrets
import unittest
+from collections import defaultdict
from typing import List, Union
from unittest.mock import MagicMock, Mock, patch
@@ -106,12 +107,14 @@ def init_uc(
ucconnector.dump_tasks: dict[str, dict[str, List[Task]]] = {}
ucconnector.total_tp_size = self.total_tp_size
ucconnector._connector_metadata = metadata
- ucconnector.layerwise_load_tasks: dict[
- str, dict[str, tuple[Task, Task]]
- ] = {}
+ ucconnector.layerwise_load_tasks: dict[str, dict[str, Task]] = defaultdict(
+ dict
+ )
ucconnector._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {}
ucconnector._load_failed_reqs: set[str] = set()
ucconnector._load_req_to_blocks: dict[str, set[int]] = {}
+ ucconnector.num_layers = 48
+ ucconnector.is_mla = False
return ucconnector
def test_get_num_new_matched_tokens_hit_all_on_storage(self):
@@ -508,6 +511,7 @@ def test_wait_for_save_not_layerwise_invalid_para(self):
ucconnector.block_size = self.block_size
ucconnector.use_layerwise = False
ucconnector._connector_metadata = Mock()
+ ucconnector.is_mla = False
with self.assertRaises(AssertionError):
ucconnector.wait_for_save()
@@ -542,6 +546,7 @@ def mock_wait(task: Task) -> int:
)
forward_context = Mock()
ucconnector.start_load_kv(forward_context)
+ assert mock_connector.load.call_count == 1
def test_start_load_kv_invalid_para(self):
with patch.object(UnifiedCacheConnectorV1, "__init__", return_value=None):
@@ -559,6 +564,7 @@ def test_start_load_kv_layerwise_success(self):
req_meta1.load_blocks = [
(secrets.token_hex(8), i) for i in range(self.block_number)
]
+ req_meta1.load_async = False
metadata = UCConnectorV1Metadata()
metadata.requests = [req_meta1]
@@ -575,7 +581,7 @@ def mock_load(
ucconnector = self.init_uc(mock_connector, metadata=metadata)
forward_context = Mock()
ucconnector.start_load_kv(forward_context)
- assert mock_connector.load.call_count == 2 * self.num_layers
+ assert mock_connector.load.call_count == self.num_layers
if __name__ == "__main__":
diff --git a/test/test_ucm_connector_save_load.py b/test/test_ucm_connector_save_load.py
new file mode 100644
index 000000000..c0def663f
--- /dev/null
+++ b/test/test_ucm_connector_save_load.py
@@ -0,0 +1,599 @@
+# -*- coding: utf-8 -*-
+#
+# MIT License
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+"""
+Standalone bandwidth test for `UCMConnector`.
+
+This script instantiates the connector exactly like runtime code (no mocks) and
+benchmarks `wait_for_save` (dump) and `start_load_kv` (load).
+"""
+
+import csv
+import math
+import multiprocessing
+import os
+import secrets
+import time
+import traceback
+from dataclasses import dataclass
+from typing import Dict, List, Tuple, Union
+from unittest.mock import patch
+
+import torch
+from vllm.config import (
+ CacheConfig,
+ DeviceConfig,
+ KVTransferConfig,
+ ModelConfig,
+ ParallelConfig,
+ VllmConfig,
+)
+from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
+
+from ucm.integration.vllm.ucm_connector import (
+ RequestDispatchMeta,
+ UCMConnector,
+ UCMConnectorMetadata,
+)
+from ucm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def make_aligned_tensor(shape, dtype, device, alignment=4096):
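+ """Return a tensor view whose data pointer is aligned to `alignment` bytes.
+
+ A byte buffer with extra headroom is allocated and sliced so that the returned
+ view starts on an aligned address (assumption: direct I/O paths enabled via
+ use_direct benefit from 4096-byte alignment).
+ """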
+ numel = math.prod(shape)
+ dtype_size = torch.tensor(1, dtype=dtype).element_size()
+ total_bytes = numel * dtype_size
+
+ padded_bytes = total_bytes + alignment
+ storage = torch.ByteTensor(padded_bytes).to(device)
+
+ ptr = storage.data_ptr()
+ offset = ptr % alignment
+ aligned_ptr = ptr + (alignment - offset) if offset != 0 else ptr
+
+ aligned_storage = storage[(aligned_ptr - ptr) :].view(dtype)
+ tensor = aligned_storage[:numel].view(shape)
+ tensor.storage_ref = storage
+ return tensor
+
+
+def make_buffers(
+ block_number: int,
+ device_id: int,
+ batch_size: int,
+ head_dim: int,
+ block_len: int,
+ block_layer: int,
+ num_head: int,
+ kv: int,
+ is_mla: bool,
+) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+ logger.info(f"Allocating buffers: blocks={block_number}, batch_size={batch_size}")
+ hashes = [secrets.token_hex(16) for _ in range(block_number)]
+ device = f"cuda:{device_id}"
+ kv_caches: Dict[str, torch.Tensor] = {}
+
+ for layer in range(block_layer):
+ layer_name = f"layer.{layer}"
+ if is_mla:
+ kv_caches[layer_name] = make_aligned_tensor(
+ [block_number, block_len, head_dim],
+ dtype=torch.float16,
+ device=device,
+ )
+ else:
+ kv_caches[layer_name] = make_aligned_tensor(
+ [kv, block_number, block_len, num_head, head_dim],
+ dtype=torch.float16,
+ device=device,
+ )
+ return hashes, kv_caches
+
+
+def build_vllm_config(
+ *,
+ model_path: str,
+ block_size: int,
+ num_layers: int,
+ num_head: int,
+ head_size: int,
+ is_mla: bool,
+ tp_size: int,
+ connector_name: str,
+ storage_backends: str,
+ transfer_stream_number: int,
+ use_direct: bool,
+) -> VllmConfig:
+ cache_config = CacheConfig(
+ block_size=block_size,
+ gpu_memory_utilization=0.9,
+ swap_space=4,
+ cache_dtype="auto",
+ )
+
+    # Ensure the connector uses the test parameters for head_size, num_head, and num_layers
+ hf_overrides = {
+ "head_dim": head_size, # Override head_size for get_head_size()
+ "num_key_value_heads": num_head, # Override num_head for get_num_kv_heads()
+ "num_hidden_layers": num_layers, # Override num_layers for get_num_layers()
+ }
+ if is_mla:
+ # head_dim = kv_lora_rank + qk_rope_head_dim (typically 512 + 64 = 576)
+ # For testing purposes, we set kv_lora_rank = head_size - 64
+ kv_lora_rank = head_size - 64 # qk_rope_head_dim = 64
+ hf_overrides.update(
+ {
+ "model_type": "deepseek_v3",
+ "kv_lora_rank": kv_lora_rank,
+ "qk_rope_head_dim": 64,
+ }
+ )
+
+ model_config = ModelConfig(
+ model=model_path,
+ tokenizer=None,
+ tokenizer_mode="auto",
+ trust_remote_code=False,
+ dtype="float16",
+ seed=0,
+ max_model_len=8192,
+ max_context_len_to_capture=8192,
+ max_logprobs=20,
+ disable_sliding_window=False,
+ skip_tokenizer_init=True,
+ limit_mm_per_prompt={},
+ use_async_output_proc=True,
+ override_neuron_config={},
+ config_format="auto",
+ is_deepseek_mla=is_mla,
+ hf_overrides=hf_overrides,
+ )
+
+ parallel_config = ParallelConfig(
+ pipeline_parallel_size=1,
+ tensor_parallel_size=tp_size,
+ worker_use_ray=False,
+ )
+
+ device = "cuda" if torch.cuda.is_available() else "npu"
+ device_config = DeviceConfig(device=device)
+
+ kv_transfer_config = KVTransferConfig(
+ kv_connector="UCMConnector",
+ kv_role="kv_both",
+ kv_connector_extra_config={
+ "ucm_connectors": [
+ {
+ "ucm_connector_name": connector_name,
+ "ucm_connector_config": {
+ "storage_backends": storage_backends,
+ "use_direct": use_direct,
+ "stream_number": transfer_stream_number,
+ "local_rank_size": 1,
+ },
+ }
+ ]
+ },
+ )
+
+ return VllmConfig(
+ model_config=model_config,
+ cache_config=cache_config,
+ parallel_config=parallel_config,
+ device_config=device_config,
+ kv_transfer_config=kv_transfer_config,
+ )
+
+
+@dataclass
+class DummyLayer:
+ kv_cache: Union[Dict[int, torch.Tensor], List[torch.Tensor]]
+
+
+@dataclass
+class DummyForwardContext:
+ no_compile_layers: Dict[str, DummyLayer]
+ virtual_engine: int = 0
+
+
+def build_forward_context(
+ kv_caches: Dict[str, torch.Tensor], is_mla: bool
+) -> DummyForwardContext:
+ layers = {}
+ for layer_name, tensor in kv_caches.items():
+ layers[layer_name] = DummyLayer(kv_cache={0: tensor})
+ return DummyForwardContext(no_compile_layers=layers, virtual_engine=0)
+
+
+def compute_total_bytes(
+ kv_caches: Dict[str, torch.Tensor], batch_size: int, is_mla: bool
+) -> int:
+ total = 0
+ for tensor in kv_caches.values():
+ if is_mla:
+ total += tensor[:batch_size].numel() * tensor.element_size()
+ else:
+ total += tensor[:, :batch_size].numel() * tensor.element_size()
+ return total
+
+
+def run_once(
+ connector: UCMConnector,
+ kv_caches: Dict[str, torch.Tensor],
+ hashes: List[str],
+ batch_size: int,
+ is_mla: bool,
+) -> Tuple[Tuple[float, float, float], Tuple[float, float, float]]:
+ dump_hashes = hashes[:batch_size]
+
+ metadata = UCMConnectorMetadata()
+ dump_vllm_block_ids = list(range(batch_size))
+ metadata.request_meta["uc_test_write"] = RequestDispatchMeta(
+ load_block_ids=([], []),
+ dump_block_ids=(dump_hashes, dump_vllm_block_ids),
+ )
+ connector.connector.kv_caches = kv_caches
+ connector.bind_connector_metadata(metadata)
+
+ total_bytes = compute_total_bytes(kv_caches, batch_size, is_mla)
+
+ start = time.perf_counter()
+ connector.wait_for_save()
+ write_time = time.perf_counter() - start
+
+ time.sleep(1)
+
+ write_bw = (total_bytes / (1024**3)) / write_time if write_time > 0 else 0.0
+
+ lookup = connector.connector.store.lookup(dump_hashes)
+ if not all(lookup):
+ raise RuntimeError("Found missing cache blocks before load test.")
+
+ load_metadata = UCMConnectorMetadata()
+ load_vllm_block_ids = list(range(batch_size))
+ load_metadata.request_meta["uc_test_read"] = RequestDispatchMeta(
+ load_block_ids=(dump_hashes, load_vllm_block_ids),
+ dump_block_ids=([], []),
+ )
+ connector.connector.kv_caches = kv_caches
+ connector.bind_connector_metadata(load_metadata)
+
+ forward_context = build_forward_context(kv_caches, is_mla)
+
+ start = time.perf_counter()
+ connector.start_load_kv(forward_context)
+ read_time = time.perf_counter() - start
+
+ read_bw = (total_bytes / (1024**3)) / read_time if read_time > 0 else 0.0
+
+ logger.info(
+ f"Size: {total_bytes / (1024**3):.4f} GB, Time: {write_time:.4f}s, WRITE SPEED: {write_bw:.4f} GB/s "
+ )
+ logger.info(
+ f"Size: {total_bytes / (1024**3):.4f} GB, Time: {read_time:.4f}s, READ SPEED: {read_bw:.4f} GB/s"
+ )
+
+ return (
+ (total_bytes / (1024**3), write_time, write_bw),
+ (total_bytes / (1024**3), read_time, read_bw),
+ )
+
+
+def run_test(
+ storage_backends: str,
+ device_id: int,
+ repeat: int,
+ num_head: int,
+ block_len: int,
+ num_tokens: int,
+ block_layer: int,
+ head_size: int,
+ block_elem_size: int,
+ kv: int,
+ mla: bool,
+ ucm_connector_name: str,
+ total_tp_size: int,
+ model_path: str,
+ transfer_stream_number: int,
+ use_direct: bool,
+) -> Tuple[float, float, float, float, float, float]:
+ block_dim = head_size * num_head
+ io_size = block_dim * block_len * block_elem_size
+ block_size = io_size * block_layer
+ batch_size = int(num_tokens / block_len)
+ real_blocks = batch_size * repeat + 10
+
+ vllm_config = build_vllm_config(
+ model_path=model_path,
+ block_size=block_len,
+ num_layers=block_layer,
+ num_head=num_head,
+ head_size=head_size,
+ is_mla=mla,
+ tp_size=total_tp_size,
+ connector_name=ucm_connector_name,
+ storage_backends=storage_backends,
+ transfer_stream_number=transfer_stream_number,
+ use_direct=use_direct,
+ )
+
+ dummy_world_group = type("DummyWorldGroup", (), {"local_rank": 0})()
+
+ class DummyTPGroup:
+ def broadcast(self, tensor, src):
+ pass
+
+ dummy_tp_group = DummyTPGroup()
+
+ patches = [
+ patch(
+ "ucm.integration.vllm.ucm_connector.get_world_group",
+ return_value=dummy_world_group,
+ ),
+ patch(
+ "ucm.integration.vllm.ucm_connector.get_tp_group",
+ return_value=dummy_tp_group,
+ ),
+ ]
+
+ with patches[0], patches[1]:
+ connector = UCMConnector(vllm_config, KVConnectorRole.WORKER)
+ connector.connector.rank = device_id if device_id >= 0 else 0
+ connector.connector.kv_caches = {}
+
+ hashes, kv_caches = make_buffers(
+ real_blocks,
+ device_id,
+ batch_size,
+ head_size,
+ block_len,
+ block_layer,
+ num_head,
+ kv,
+ mla,
+ )
+
+ w_sizes, w_times, w_bws = [], [], []
+ r_sizes, r_times, r_bws = [], [], []
+
+ for round_idx in range(repeat):
+ logger.info(f"Round {round_idx + 1}: start write test")
+ start_hash_idx = round_idx * batch_size
+ end_hash_idx = start_hash_idx + batch_size
+ round_hashes = hashes[start_hash_idx:end_hash_idx]
+
+ if len(round_hashes) < batch_size:
+ round_hashes = [secrets.token_hex(16) for _ in range(batch_size)]
+
+ (w_size, w_time, w_bw), (r_size, r_time, r_bw) = run_once(
+ connector, kv_caches, round_hashes, batch_size, mla
+ )
+
+ if round_idx != 0:
+ w_sizes.append(w_size)
+ w_times.append(w_time)
+ w_bws.append(w_bw)
+ r_sizes.append(r_size)
+ r_times.append(r_time)
+ r_bws.append(r_bw)
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ elif hasattr(torch, "npu") and torch.npu.is_available():
+ torch.npu.empty_cache()
+
+ def avg(values: List[float]) -> float:
+ return sum(values) / len(values) if values else 0.0
+
+ avg_w_size = avg(w_sizes)
+ avg_w_time = avg(w_times)
+ avg_w_bw = avg(w_bws)
+ avg_r_size = avg(r_sizes)
+ avg_r_time = avg(r_times)
+ avg_r_bw = avg(r_bws)
+
+ logger.info(
+ "\n=== Summary ===\n"
+ f"Write : size={avg_w_size:.4f} GB | time={avg_w_time:.4f} s | bw={avg_w_bw:.4f} GB/s\n"
+ f"Read : size={avg_r_size:.4f} GB | time={avg_r_time:.4f} s | bw={avg_r_bw:.4f} GB/s\n"
+ )
+
+ return avg_w_size, avg_w_time, avg_w_bw, avg_r_time, avg_r_bw, avg_r_size
+
+
+def run_wrapper(result_queue, *args):
+ try:
+ result = run_test(*args)
+ result_queue.put(("success", result))
+ except Exception as e:
+ result_queue.put(("error", traceback.format_exc()))
+
+
+def get_user_input(prompt, default=None):
+ if default is not None:
+ user_input = input(f"{prompt} (default: {default}): ").strip()
+ return user_input if user_input else default
+ else:
+ return input(f"{prompt}: ").strip()
+
+
+def main():
+
+ try:
+ multiprocessing.set_start_method("spawn", force=True)
+ except RuntimeError:
+ pass
+
+ storage_backends = "."
+ device_id = 0
+ repeat = 3
+ num_tokens_list = [2048, 4096, 8192, 16384, 32768]
+ ucm_connector_name = "UcmNfsStore"
+ model_path = "/home/models/QwQ-32B"
+ transfer_stream_numbers = [32, 64, 128]
+ os.environ["UC_LOGGER_LEVEL"] = "debug"
+
+ print("1. Model Selection:")
+ print(" 1 - QwQ-32B")
+ print(" 2 - deepseek-v3")
+ model_choice = get_user_input("Please select model", "1")
+    mla = model_choice == "2"
+ print("\n2. IoDirect Transfer:")
+ print(" 1 - Disable IoDirect (default)")
+ print(" 2 - Enable IoDirect")
+ use_direct = get_user_input("Please select Direct IO mode", "1")
+    use_direct = use_direct != "1"
+
+ if mla:
+ block_lens = [64]
+ block_layer = 61
+ head_size = 576
+ block_elem_size = 2
+ kv = 1
+ model_name = "deepseek-v3"
+ num_head_list = [1]
+ total_tp_size = 1
+ else:
+ block_lens = [128, 256]
+ block_layer = 64
+ head_size = 128
+ block_elem_size = 2
+ kv = 2
+ model_name = "QwQ-32B"
+ num_head_list = [1, 2, 4, 8]
+ total_tp_size = 1
+
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+ csv_file = os.path.join(SCRIPT_DIR, "save_load_result.csv")
+ need_header = not os.path.exists(csv_file)
+
+ with open(csv_file, "a", newline="", encoding="utf-8") as csv_fp:
+ writer = csv.writer(csv_fp)
+
+ if need_header:
+ writer.writerow(
+ [
+ "Model",
+ "Sequence Length",
+ "Batch Size",
+ "Layers",
+ "Element Size",
+ "KV",
+ "Num Head",
+ "Block Size",
+ "Stream Number",
+ "IO Count",
+ "IO Size(B)",
+ "Total Size(GB)",
+ "Write Avg Time(s)",
+ "Write Avg Bandwidth(GB/s)",
+ "Read Avg Time(s)",
+ "Read Avg Bandwidth(GB/s)",
+ ]
+ )
+
+ for num_head in num_head_list:
+ for block_len in block_lens:
+ for transfer_stream_number in transfer_stream_numbers:
+ block_dim = head_size * num_head
+ io_size = block_dim * block_len * block_elem_size
+
+ for num_tokens in num_tokens_list:
+ sep = "=" * 60
+ print(
+ f"\n{sep}\n= num_head={num_head} | num_tokens={num_tokens:>6} | Repeat {repeat} times =\n{sep}\n"
+ )
+
+ batch_size = int(num_tokens / block_len)
+ io_count = batch_size * block_layer
+
+ result_queue = multiprocessing.Queue()
+
+ process = multiprocessing.Process(
+ target=run_wrapper,
+ args=(
+ result_queue,
+ storage_backends,
+ device_id,
+ repeat,
+ num_head,
+ block_len,
+ num_tokens,
+ block_layer,
+ head_size,
+ block_elem_size,
+ kv,
+ mla,
+ ucm_connector_name,
+ total_tp_size,
+ model_path,
+ transfer_stream_number,
+ use_direct,
+ ),
+ )
+
+ process.start()
+ process.join()
+
+ status, result = result_queue.get()
+ if status == "error":
+ raise Exception(f"Error in subprocess: {result}")
+
+ (
+ avg_w_size,
+ avg_w_time,
+ avg_w_bw,
+ avg_r_time,
+ avg_r_bw,
+ avg_r_size,
+ ) = result
+
+ writer.writerow(
+ [
+ model_name,
+ num_tokens,
+ batch_size,
+ block_layer,
+ block_elem_size,
+ kv,
+ num_head,
+ block_len,
+ transfer_stream_number,
+ io_count,
+ io_size,
+ f"{avg_w_size:.4f}",
+ f"{avg_w_time:.4f}",
+ f"{avg_w_bw:.4f}",
+ f"{avg_r_time:.4f}",
+ f"{avg_r_bw:.4f}",
+ ]
+ )
+
+ csv_fp.flush()
+
+ print("\n" + "=" * 60 + "\n= All combinations tested =\n" + "=" * 60 + "\n")
+
+
+if __name__ == "__main__":
+ main()
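The benchmark above appends one CSV row per configuration. The snippet below is a
minimal post-processing sketch (not part of the patch): it assumes only the column
names written by main() and summarizes the best observed bandwidth per block size
and stream count.

    import csv

    # Read the rows the benchmark appends to save_load_result.csv (written next to
    # the script); column names match the header emitted in main().
    with open("save_load_result.csv", newline="", encoding="utf-8") as fp:
        rows = list(csv.DictReader(fp))

    # Track the best read/write bandwidth per (block size, stream number).
    best = {}
    for row in rows:
        key = (row["Block Size"], row["Stream Number"])
        entry = best.setdefault(key, {"write": 0.0, "read": 0.0})
        entry["write"] = max(entry["write"], float(row["Write Avg Bandwidth(GB/s)"]))
        entry["read"] = max(entry["read"], float(row["Read Avg Bandwidth(GB/s)"]))

    for (block_len, streams), bw in sorted(best.items()):
        print(f"block_len={block_len} streams={streams} "
              f"write={bw['write']:.2f} GB/s read={bw['read']:.2f} GB/s")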
diff --git a/test/test_ucm_dram.py b/test/test_ucm_dram.py
deleted file mode 100644
index 020405d13..000000000
--- a/test/test_ucm_dram.py
+++ /dev/null
@@ -1,250 +0,0 @@
-#
-# MIT License
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-#
-
-import random
-import unittest
-import unittest.mock as mock
-from contextlib import contextmanager
-from typing import List
-from unittest.mock import MagicMock
-
-import torch
-from vllm.multimodal.inputs import MultiModalKwargs
-from vllm.sampling_params import SamplingParams
-from vllm.utils import sha256
-from vllm.v1.core.kv_cache_utils import hash_request_tokens
-from vllm.v1.request import Request
-
-
-@contextmanager
-def mock_stream_context(stream=None):
- yield
-
-
-class MockStream:
- def __init__(self, device=None):
- self.device = device or torch.device("cpu")
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- pass
-
- def synchronize(self):
- pass
-
- def record_event(self, event=None):
- return event or MockEvent()
-
- def wait_stream(self, stream):
- pass
-
-
-class MockEvent:
- def __init__(self, enable_timing=False):
- self.enable_timing = enable_timing
-
- def record(self, stream=None):
- pass
-
- def wait(self, stream=None):
- pass
-
- def synchronize(self):
- pass
-
-
-def patch_cuda_for_cpu():
- mock.patch("torch.cuda.Stream", MockStream).start()
- mock.patch("torch.cuda.Event", MockEvent).start()
- mock.patch("torch.cuda.current_stream", return_value=MockStream()).start()
- mock.patch("torch.cuda.synchronize", side_effect=lambda *a, **k: None).start()
- mock.patch("torch.cuda.is_available", return_value=True).start()
- mock.patch("torch.cuda.stream", mock_stream_context).start()
-
-
-patch_cuda_for_cpu()
-from ucm.store.dramstore.dramstore_connector import ( # isort: skip
- DramTask,
- UcmDramStore,
-)
-
-
-def make_request(
- request_id, prompt_token_ids, mm_positions=None, mm_hashes=None, cache_salt=None
-):
- if mm_positions is None:
- multi_modal_inputs = None
- else:
- multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
-
- return Request(
- request_id=request_id,
- prompt_token_ids=prompt_token_ids,
- multi_modal_inputs=multi_modal_inputs,
- multi_modal_hashes=mm_hashes,
- multi_modal_placeholders=mm_positions,
- sampling_params=SamplingParams(max_tokens=17),
- pooling_params=None,
- eos_token_id=100,
- arrival_time=0,
- lora_request=None,
- cache_salt=cache_salt,
- )
-
-
-class TestUcmDram(unittest.TestCase):
-
- @classmethod
- def setUpClass(cls):
- print("===> Before all tests (setUpClass)")
-
- @classmethod
- def tearDownClass(cls):
- print("===> After all tests (setUpClass)")
-
- def setUp(self):
- self.config = {"block_size": 4}
- self.scheduler_config = {
- "role": "scheduler",
- "max_cache_size": 1073741824,
- "kv_block_size": 262144,
- }
- self.worker_config = {
- "role": "worker",
- "max_cache_size": 1073741824,
- "kv_block_size": 262144,
- }
-
- self.block_number = 4
- self.block_size = int(self.config["block_size"])
- self.scheduler_dram = UcmDramStore(self.scheduler_config)
- self.worker_dram = UcmDramStore(self.worker_config)
- random.seed(20250728)
- self.request = make_request(
- request_id=1,
- prompt_token_ids=random.sample(
- range(0, 10000), self.block_number * self.block_size
- ),
- mm_positions=None,
- mm_hashes=None,
- )
- block_hash_types = hash_request_tokens(sha256, self.block_size, self.request)
- self.block_hashes: List[str] = [str(x.hash_value) for x in block_hash_types]
-
- def test_look_up_all_hit(self):
- """
- Test for all blocks hitten in cache
- """
- expected = [True] * len(self.block_hashes)
- self.scheduler_dram.cached_blocks.update(self.block_hashes)
- actual = self.scheduler_dram.lookup(self.block_hashes)
-
- self.assertEqual(actual, expected)
-
- def test_lookup_partial_hit(self):
- """
- Test for part of the blocks hitten in cache
- """
- partial_index = random.randint(0, 4)
- partial_hashes = self.block_hashes[:partial_index]
- self.scheduler_dram.cached_blocks.update(partial_hashes)
- actual = self.scheduler_dram.lookup(self.block_hashes)
- expected = [True] * partial_index + [False] * (self.block_size - partial_index)
- self.assertEqual(actual, expected)
-
- def test_lookup_none_hit(self):
- """
- Test for none of the blocks hitten in cache
- """
- actual = self.scheduler_dram.lookup(self.block_hashes)
- expected = [False] * len(self.block_hashes)
- self.assertEqual(actual, expected)
-
- def test_load_success(self):
- """
- Test for load from cache successfully
- """
- src_tensors = [
- torch.randint(0, 100, (self.block_size,), dtype=torch.int8)
- for _ in range(len(self.block_hashes))
- ]
- offsets = [i for i in range(len(self.block_hashes))]
- dump_task = self.worker_dram.dump(self.block_hashes, offsets, src_tensors)
- self.worker_dram.wait(dump_task)
- dst_tensors = [
- torch.zeros(self.block_size, dtype=torch.int8)
- for _ in range(len(self.block_hashes))
- ]
- load_task = self.worker_dram.load(self.block_hashes, offsets, dst_tensors)
-
- self.assertIsInstance(load_task, DramTask)
- self.assertIsNotNone(load_task.event)
- for i, (src_tensor, dst_tensor) in enumerate(zip(src_tensors, dst_tensors)):
- self.assertEqual(dst_tensor.shape[0], self.block_size)
- self.assertTrue(
- torch.equal(src_tensor, dst_tensor),
- f"Block {i} loaded data is different",
- )
-
- def test_dump_success(self):
- """
- Test data dump successfully
- """
- src_tensors = [
- torch.randint(0, 100, (self.block_size,), dtype=torch.int8)
- for _ in range(len(self.block_hashes))
- ]
- offsets = [i for i in range(len(self.block_hashes))]
- original_data = [tensor.clone() for tensor in src_tensors]
- dump_task = self.worker_dram.dump(self.block_hashes, offsets, src_tensors)
- self.assertIsInstance(dump_task, DramTask)
- self.assertIsNotNone(dump_task.event)
- self.worker_dram.wait(dump_task)
- for i, block_id in enumerate(self.block_hashes):
- key = block_id + "_" + str(offsets[i])
- cached_data = self.worker_dram.dram_cache[key]
- self.assertEqual(cached_data.shape[0], self.block_size)
- self.assertTrue(torch.equal(cached_data, original_data[i]))
-
- def test_wait_success(self):
- """
- Test wait for task successfully
- """
- task = DramTask()
- task.event = MagicMock()
- result = self.worker_dram.wait(task)
- self.assertEqual(result, 0)
- task.event.synchronize.assert_called_once()
-
- def test_wait_failure(self):
- task = DramTask()
- task.event = None
- result = self.worker_dram.wait(task)
- self.assertEqual(result, -1)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/ucm/CMakeLists.txt b/ucm/CMakeLists.txt
index af0712f16..0d4579d5f 100644
--- a/ucm/CMakeLists.txt
+++ b/ucm/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(shared)
if(BUILD_UCM_STORE)
add_subdirectory(store)
endif()
diff --git a/ucm/__init__.py b/ucm/__init__.py
index e69de29bb..8052a3998 100644
--- a/ucm/__init__.py
+++ b/ucm/__init__.py
@@ -0,0 +1,17 @@
+from ucm.integration.vllm.uc_connector import UnifiedCacheConnectorV1
+from ucm.integration.vllm.ucm_connector import UCMConnector
+
+try:
+ from ucm.integration.vllm.patch.apply_patch import ensure_patches_applied
+
+ ensure_patches_applied()
+except Exception as e:
+    # Don't fail if patches can't be applied; we might be running in an environment without vLLM.
+ import warnings
+
+ warnings.warn(
+ f"Failed to apply vLLM patches: {e}. "
+ f"If you're using vLLM, ensure it's installed and patches are compatible."
+ )
+
+__all__ = ["UCMConnector"]
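With this package __init__, simply importing ucm applies the bundled vLLM patches
(or warns when vLLM is unavailable). A hedged usage sketch follows; the backend name
mirrors the one used by the tests above, while the storage path is a placeholder:

    import ucm  # applying the vLLM patches is a side effect of this import

    from vllm.config import KVTransferConfig

    # Reference the connector by name, as build_vllm_config() does in the benchmark
    # script; "/mnt/ucm" is an illustrative mount point, not a project default.
    kv_transfer_config = KVTransferConfig(
        kv_connector="UCMConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "ucm_connectors": [
                {
                    "ucm_connector_name": "UcmNfsStore",
                    "ucm_connector_config": {"storage_backends": "/mnt/ucm"},
                }
            ]
        },
    )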
diff --git a/ucm/integration/vllm/patch/0.9.1/vllm-adapt.patch b/ucm/integration/vllm/patch/0.9.1/vllm-adapt.patch
index f644f7276..3bffd1922 100644
--- a/ucm/integration/vllm/patch/0.9.1/vllm-adapt.patch
+++ b/ucm/integration/vllm/patch/0.9.1/vllm-adapt.patch
@@ -1,22 +1,22 @@
-From b837d3d46e593c946f5de70bdff178fa2bff882b Mon Sep 17 00:00:00 2001
-From: root
-Date: Mon, 15 Sep 2025 22:07:21 -0700
-Subject: [PATCH] 0.9.1-patch
+From 76751cae43498d693a7a6dd2c8ec4b2d40672385 Mon Sep 17 00:00:00 2001
+From: zhou-haitao <1300182097@qq.com>
+Date: Tue, 21 Oct 2025 03:31:16 -0700
+Subject: [PATCH] Add commit
---
.../kv_transfer/kv_connector/utils.py | 113 +++++++++++++++
.../kv_transfer/kv_connector/v1/base.py | 8 ++
.../v1/shared_storage_connector.py | 7 +-
vllm/v1/core/block_pool.py | 2 +-
- vllm/v1/core/sched/scheduler.py | 129 ++++++++++++++++++
+ vllm/v1/core/sched/scheduler.py | 132 ++++++++++++++++++
vllm/v1/core/single_type_kv_cache_manager.py | 2 +
vllm/v1/executor/multiproc_executor.py | 37 ++++-
vllm/v1/outputs.py | 5 +
vllm/v1/request.py | 1 +
vllm/v1/worker/gpu_input_batch.py | 9 ++
vllm/v1/worker/gpu_model_runner.py | 52 ++++++-
- vllm/v1/worker/gpu_worker.py | 23 +++-
- 12 files changed, 366 insertions(+), 22 deletions(-)
+ vllm/v1/worker/gpu_worker.py | 23 ++-
+ 12 files changed, 369 insertions(+), 22 deletions(-)
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index b9bed06d7..de062cfb3 100644
@@ -211,7 +211,7 @@ index d21f94727..1800665c7 100644
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
assert len(block_hashes) >= num_cached_blocks
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
-index 3d7bbe7e0..1ef81e960 100644
+index 3d7bbe7e0..b6d4a340a 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -707,16 +707,28 @@ class Scheduler(SchedulerInterface):
@@ -243,16 +243,19 @@ index 3d7bbe7e0..1ef81e960 100644
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
if num_tokens_scheduled == 0:
# The request was not scheduled in this step.
-@@ -761,6 +773,8 @@ class Scheduler(SchedulerInterface):
+@@ -761,6 +773,11 @@ class Scheduler(SchedulerInterface):
new_logprobs = None
new_token_ids = generated_token_ids
kv_transfer_params = None
+ if model_runner_output.finished_dumping is not None:
+ request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, []))
++ is_prefill = request.num_output_tokens == 0
++ if is_prefill:
++ self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True)
# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
-@@ -824,6 +838,8 @@ class Scheduler(SchedulerInterface):
+@@ -824,6 +841,8 @@ class Scheduler(SchedulerInterface):
if not stopped:
new_running.append(request)
@@ -261,7 +264,7 @@ index 3d7bbe7e0..1ef81e960 100644
# KV Connector: update state for finished KV Transfers.
self._update_from_kv_xfer_finished(model_runner_output)
-@@ -1042,3 +1058,116 @@ class Scheduler(SchedulerInterface):
+@@ -1042,3 +1061,116 @@ class Scheduler(SchedulerInterface):
for req_id in (model_runner_output.finished_sending or ()):
logger.debug("Finished sending KV transfer for request %s", req_id)
self._free_blocks(self.requests[req_id])
@@ -707,4 +710,5 @@ index b7d244f27..263a916d2 100644
def profile(self, is_start: bool = True):
if self.profiler is None:
--
-2.34.1
\ No newline at end of file
+2.34.1
+
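The updated 0.9.1 patch reports dumped blocks back to the scheduler and commits them
while the request is still prefilling. The condensed sketch below only restates the
patched scheduler lines above; the boolean passed to commit() is assumed to flag the
dump as successful:

    def record_finished_dumping(scheduler, request, req_id, finished_dumping):
        # Accumulate the block hashes this step managed to persist.
        blocks = finished_dumping.get(req_id, [])
        request.succeed_dumped_blocks.extend(blocks)
        # Only commit during prefill, i.e. before any output token exists.
        if request.num_output_tokens == 0:
            scheduler.connector.connector.commit(blocks, True)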
diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-aggre.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-aggre.patch
new file mode 100644
index 000000000..5f8df381c
--- /dev/null
+++ b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-aggre.patch
@@ -0,0 +1,753 @@
+From 6e2c814bb3b3a74ca56149b44d6a0b2017b91136 Mon Sep 17 00:00:00 2001
+From: harrisonyhq
+Date: Tue, 4 Nov 2025 23:32:10 -0800
+Subject: [PATCH 2/3] [Patch1] Patch for load failure and aggregate
+
+---
+ .../kv_transfer/kv_connector/utils.py | 113 +++++++++++
+ .../kv_transfer/kv_connector/v1/base.py | 9 +
+ .../kv_connector/v1/multi_connector.py | 6 +
+ vllm/v1/core/block_pool.py | 2 +-
+ vllm/v1/core/sched/output.py | 2 +
+ vllm/v1/core/sched/scheduler.py | 184 ++++++++++++++++--
+ vllm/v1/core/single_type_kv_cache_manager.py | 3 +
+ vllm/v1/executor/multiproc_executor.py | 30 ++-
+ vllm/v1/outputs.py | 6 +-
+ vllm/v1/worker/gpu_input_batch.py | 14 ++
+ vllm/v1/worker/gpu_model_runner.py | 28 ++-
+ vllm/v1/worker/gpu_worker.py | 23 ++-
+ 12 files changed, 397 insertions(+), 23 deletions(-)
+
+diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
+index 5cbc8ca31..b63bf5965 100644
+--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
++++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
+@@ -5,10 +5,16 @@ KV cache helper for store.
+ """
+ import torch
+
++from collections import defaultdict
++from collections.abc import Sequence
++from concurrent.futures import CancelledError, Future
++from typing import Optional, cast
++
+ import vllm.envs as envs
+ from vllm import _custom_ops as ops
+ from vllm.config import VllmConfig, get_current_vllm_config
+ from vllm.logger import init_logger
++from vllm.v1.outputs import ModelRunnerOutput
+
+ logger = init_logger(__name__)
+
+@@ -107,3 +113,110 @@ def get_kv_connector_cache_layout():
+ "layout to HND for better xfer performance.")
+ return "HND"
+ return "NHD"
++
++
++class KVOutputAggregator:
++ """Utility class to aggregate the output of all workers into a single
++ output corresponding to Rank 0 for scheduler."""
++
++ def __init__(self, world_size: int):
++ # Complete transfer tracker. Used by to track finished requests
++ # [req_id -> n_finished_workers]
++ self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
++ self._send_remaining_count = defaultdict[str, int](lambda: world_size)
++ self._dump_remaining_count = defaultdict[str, int](lambda: world_size)
++
++ def aggregate(self,
++ outputs: list[ModelRunnerOutput],
++ output_rank: int = 0) -> ModelRunnerOutput:
++ # aggregate finished_sending, finished_recving from all workers
++
++ def update_finished_set(req_ids: Optional[set[str]],
++ remaining_count_dict: dict[str, int],
++ finished_set: set[str]) -> None:
++ for req_id in req_ids or ():
++ new_count = remaining_count_dict[req_id] - 1
++ if new_count == 0:
++ finished_set.add(req_id)
++ del remaining_count_dict[req_id]
++ else:
++ remaining_count_dict[req_id] = new_count
++
++ def update_finished_list(req_ids: Optional[dict[str, list[str]]],
++ remaining_count_dict: dict[str, int],
++ finished_list: dict[str, list[str]]) -> None:
++ for req_id, succeed_dump_blocks in (req_ids or {}).items():
++ if req_id not in finished_list:
++ finished_list[req_id] = []
++ for blk_id in succeed_dump_blocks:
++ new_count = remaining_count_dict[blk_id] - 1
++ if new_count == 0:
++ finished_list[req_id].append(blk_id)
++ del remaining_count_dict[blk_id]
++ else:
++ remaining_count_dict[blk_id] = new_count
++
++ finished_sending = set[str]()
++ finished_recving = set[str]()
++ invalid_block_ids = set[int]()
++ finished_dumping: dict[str, list[str]] = {}
++ for output in outputs:
++ update_finished_set(output.finished_sending,
++ self._send_remaining_count, finished_sending)
++ update_finished_set(output.finished_recving,
++ self._recv_remaining_count, finished_recving)
++ update_finished_list(output.finished_dumping,
++ self._dump_remaining_count, finished_dumping)
++ if output.invalid_block_ids:
++ invalid_block_ids |= output.invalid_block_ids
++
++ # select output of the worker specified by output_rank
++ output = outputs[output_rank]
++
++ # set the aggregated finished_sending / finished_recving
++ # if output.finished_sending/recving is not empty, but the other ranks
++ # still have unfinished send/recv, we want to set the aggregated
++ # finished_sending/recving to None until all ranks have finished
++ # send/recv
++ output.finished_sending = finished_sending if finished_sending else None
++ output.finished_recving = finished_recving if finished_recving else None
++ output.finished_dumping = finished_dumping if finished_dumping else None
++ output.invalid_block_ids = invalid_block_ids or None
++
++ return output
++
++ def async_aggregate(self,
++ output_futures: Sequence[Future[ModelRunnerOutput]],
++ output_rank: int = 0) -> Future[ModelRunnerOutput]:
++ """Takes a list of futures and returns a single future which resolves
++ to the respective list of outputs."""
++ result_future: Future[ModelRunnerOutput] = Future()
++
++ outputs: list[Optional[ModelRunnerOutput]] = [None
++ ] * len(output_futures)
++
++ def make_callback(idx):
++
++ def callback(fut):
++ if result_future.done():
++ return
++
++ try:
++ outputs[idx] = fut.result()
++ except CancelledError:
++ result_future.cancel()
++ except Exception as e:
++ result_future.set_exception(e)
++
++ # this check assumes io_thread_pool uses a single thread
++ if all(outputs):
++ result_future.set_result(
++ self.aggregate(cast(list[ModelRunnerOutput], outputs),
++ output_rank))
++
++ return callback
++
++ for i, output_future in enumerate(output_futures):
++ output_future.add_done_callback(make_callback(i))
++
++ return result_future
+diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+index f80b5eba2..39d8fa389 100644
+--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
++++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+@@ -201,6 +201,15 @@ class KVConnectorBase_V1(ABC):
+ """
+ return None, None
+
++ def get_block_ids_with_load_errors(self) -> set[int]:
++ """
++ Get the set of block IDs that failed to load.
++ Returns:
++            set[int]: the block IDs that encountered load errors.
++            An empty set means no load errors occurred.
++ """
++ return set()
++
+ # ==============================
+ # Scheduler-side methods
+ # ==============================
+diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
+index 5f92d69bd..4e1f45e7a 100644
+--- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
++++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
+@@ -134,6 +134,12 @@ class MultiConnector(KVConnectorBase_V1):
+
+ return finished_sending or None, finished_recving or None
+
++ def get_block_ids_with_load_errors(self) -> set[int]:
++ agg_block_ids: set[int] = set()
++ for c in self._connectors:
++ agg_block_ids |= c.get_block_ids_with_load_errors()
++ return agg_block_ids
++
+ # ==============================
+ # Scheduler-side methods
+ # ==============================
+diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py
+index d21f94727..1800665c7 100644
+--- a/vllm/v1/core/block_pool.py
++++ b/vllm/v1/core/block_pool.py
+@@ -124,7 +124,7 @@ class BlockPool:
+ kv_cache_group_id: The id of the KV cache group.
+ hash_fn: The hash function to use for block hashes.
+ """
+- if num_cached_blocks == num_full_blocks:
++ if num_cached_blocks >= num_full_blocks:
+ return
+ new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
+ assert len(block_hashes) >= num_cached_blocks
+diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
+index d34f39327..c94e421c0 100644
+--- a/vllm/v1/core/sched/output.py
++++ b/vllm/v1/core/sched/output.py
+@@ -93,6 +93,7 @@ class CachedRequestData:
+ new_token_ids: list[list[int]]
+ new_block_ids: list[tuple[list[int], ...]]
+ num_computed_tokens: list[int]
++ num_output_tokens: list[int]
+
+ @property
+ def num_reqs(self) -> int:
+@@ -106,6 +107,7 @@ class CachedRequestData:
+ new_token_ids=[],
+ new_block_ids=[],
+ num_computed_tokens=[],
++ num_output_tokens=[],
+ )
+
+
+diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
+index cd80f92a1..2d4fd4d59 100644
+--- a/vllm/v1/core/sched/scheduler.py
++++ b/vllm/v1/core/sched/scheduler.py
+@@ -119,6 +119,7 @@ class Scheduler(SchedulerInterface):
+
+ # KV Connector: requests in process of async KV loading or recving
+ self.finished_recving_kv_req_ids: set[str] = set()
++ self.failed_recving_kv_req_ids: set[str] = set()
+
+ # Encoder-related.
+ # Calculate encoder cache size if applicable
+@@ -621,6 +622,7 @@ class Scheduler(SchedulerInterface):
+ new_token_ids: list[list[int]] = []
+ new_block_ids: list[tuple[list[int], ...]] = []
+ num_computed_tokens: list[int] = []
++ num_output_tokens: list[int] = []
+
+ for req in itertools.chain(running_reqs, resumed_reqs):
+ req_id = req.request_id
+@@ -638,6 +640,7 @@ class Scheduler(SchedulerInterface):
+ new_token_ids.append(token_ids)
+ new_block_ids.append(req_to_new_block_ids[req_id])
+ num_computed_tokens.append(req.num_computed_tokens)
++ num_output_tokens.append(len(req.output_token_ids))
+ # Because resumed_reqs is usually empty, it is more efficient to do
+ # in-place appending so that we don't need to allocate a new list.
+ resumed_from_preemption = [False] * len(running_reqs)
+@@ -649,6 +652,7 @@ class Scheduler(SchedulerInterface):
+ new_token_ids=new_token_ids,
+ new_block_ids=new_block_ids,
+ num_computed_tokens=num_computed_tokens,
++ num_output_tokens=num_output_tokens,
+ )
+
+ def _try_schedule_encoder_inputs(
+@@ -746,16 +750,29 @@ class Scheduler(SchedulerInterface):
+ num_scheduled_tokens = scheduler_output.num_scheduled_tokens
+ pooler_outputs = model_runner_output.pooler_output
+ num_nans_in_logits = model_runner_output.num_nans_in_logits
++ invalid_block_ids = model_runner_output.invalid_block_ids
+
+ new_running: list[Request] = []
+ outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
+ spec_decoding_stats: Optional[SpecDecodingStats] = None
+
++ failed_kv_load_req_ids = None
++ if invalid_block_ids:
++ # These blocks contain externally computed tokens that failed to
++ # load. Identify affected requests and adjust their computed token
++ # count to trigger recomputation of the invalid blocks.
++ failed_kv_load_req_ids = self._handle_invalid_blocks(invalid_block_ids)
++
+ # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
+ # loop can be a performance bottleneck. We should do our best to avoid
+ # expensive operations inside the loop.
+ for request in self.running:
+ req_id = request.request_id
++ # self.req_meta.stage == SequenceStage.PREFILL and self.req_meta.is_last_chunk
++ if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids:
++ # Skip requests that were recovered from KV load failure
++ new_running.append(request)
++ continue
+ num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
+ if num_tokens_scheduled == 0:
+ # The request was not scheduled in this step.
+@@ -1089,18 +1106,31 @@ class Scheduler(SchedulerInterface):
+ if request.request_id not in self.finished_recving_kv_req_ids:
+ return False
+
+- # Now that the blocks are ready, actually cache them.
+- (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
+- num_computed_tokens = len(block_ids) * self.block_size
+- # Handle the case where num request tokens less then one block.
+- num_computed_tokens = min(num_computed_tokens, request.num_tokens)
+- if num_computed_tokens == request.num_tokens:
+- num_computed_tokens -= 1
+- # This will cache the blocks iff caching is enabled.
+- self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
+-
+- # Update the request state for scheduling.
+- request.num_computed_tokens = num_computed_tokens
++ if request.request_id in self.failed_recving_kv_req_ids:
++ # Request had KV load failures; num_computed_tokens was already
++ # updated in _update_requests_with_invalid_blocks
++ if request.num_computed_tokens:
++ # Cache any valid computed tokens.
++ self.kv_cache_manager.cache_blocks(request,
++ request.num_computed_tokens)
++ else:
++ # No valid computed tokens, release allocated blocks.
++ # There may be a local cache hit on retry.
++ self.kv_cache_manager.free(request)
++ self.failed_recving_kv_req_ids.remove(request.request_id)
++ else:
++ # Now that the blocks are ready, actually cache them.
++ (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
++ num_computed_tokens = len(block_ids) * self.block_size
++            # Handle the case where the request has fewer tokens than one block.
++ num_computed_tokens = min(num_computed_tokens, request.num_tokens)
++ if num_computed_tokens == request.num_tokens:
++ num_computed_tokens -= 1
++ # This will cache the blocks iff caching is enabled.
++ self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
++
++ # Update the request state for scheduling.
++ request.num_computed_tokens = num_computed_tokens
+
+ # Return that we are ready.
+ self.finished_recving_kv_req_ids.remove(request.request_id)
+@@ -1124,3 +1154,133 @@ class Scheduler(SchedulerInterface):
+ for req_id in (model_runner_output.finished_sending or ()):
+ logger.debug("Finished sending KV transfer for request %s", req_id)
+ self._free_blocks(self.requests[req_id])
++
++
++ def _update_requests_with_invalid_blocks(
++ self, requests: Iterable[Request],
++ invalid_block_ids: set[int]) -> tuple[set[str], int]:
++ """
++ Identify and update requests affected by invalid KV cache blocks.
++ This method scans the given requests, detects those with invalid blocks
++ and adjusts their `num_computed_tokens` to the longest valid prefix.
++ For observability, it also accumulates the total number of tokens that
++ will need to be recomputed across all affected requests.
++ Args:
++ requests: The set of requests to scan for invalid blocks.
++ invalid_block_ids: IDs of invalid blocks.
++ Returns:
++ tuple:
++ - affected_req_ids (set[str]): IDs of requests impacted by
++ invalid blocks.
++ - total_affected_tokens (int): Total number of tokens that must
++ be recomputed across all affected requests (for observability).
++ """
++ affected_req_ids: set[str] = set()
++ total_affected_tokens = 0
++ # If a block is invalid and shared by multiple requests in the batch,
++ # these requests must be rescheduled, but only the first will recompute
++ # it. This set tracks blocks already marked for recomputation.
++ marked_invalid_block_ids: set[int] = set()
++ for request in requests:
++ is_affected = False
++ marked_invalid_block = False
++ req_id = request.request_id
++ # TODO (davidb): add support for hybrid memory allocator
++ (req_block_ids, ) = self.kv_cache_manager.get_block_ids(req_id)
++ # We iterate only over blocks that may contain externally computed
++ # tokens
++ if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
++ # Async loading. If num_computed_tokens is set it implies we
++ # already processed some block failures for it in a prior step
++ req_num_computed_tokens = (
++ request.num_computed_tokens if req_id
++ in self.failed_recving_kv_req_ids else len(req_block_ids) *
++ self.block_size)
++ else:
++ # Sync loading. num_computed_tokens includes new tokens
++ req_num_computed_tokens = request.num_cached_tokens
++
++ req_num_computed_blocks = (req_num_computed_tokens +
++ self.block_size - 1) // self.block_size
++ for idx, block_id in zip(range(req_num_computed_blocks),
++ req_block_ids):
++
++ if block_id not in invalid_block_ids:
++ continue
++
++ is_affected = True
++
++ if block_id in marked_invalid_block_ids:
++ # This invalid block is shared with a previous request
++ # and was already marked for recomputation.
++ # This means this request can still consider this block
++ # as computed when rescheduled.
++ # Currently this only applies to sync loading; Async
++ # loading does not yet support block sharing
++ continue
++
++ marked_invalid_block_ids.add(block_id)
++
++ if marked_invalid_block:
++ # This request has already marked an invalid block for
++ # recomputation and updated its num_computed_tokens.
++ continue
++
++ marked_invalid_block = True
++ # Truncate the computed tokens at the first failed block
++ request.num_computed_tokens = idx * self.block_size
++ total_affected_tokens += (req_num_computed_tokens -
++ request.num_computed_tokens)
++
++ if is_affected:
++ if not marked_invalid_block:
++ # All invalid blocks of this request are shared with
++ # previous requests and will be recomputed by them.
++ # Revert to considering only cached tokens as computed.
++ # Currently this only applies to sync loading; Async
++ # loading does not yet support block sharing
++ total_affected_tokens += (request.num_computed_tokens -
++ request.num_cached_tokens)
++ request.num_computed_tokens = request.num_cached_tokens
++
++ affected_req_ids.add(request.request_id)
++
++ return (affected_req_ids, total_affected_tokens)
++
++
++ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
++ total_requests_to_reschedule = 0
++ total_tokens_to_reschedule = 0
++
++ # --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) ---
++ async_load_reqs = (
++ req for req in self.waiting
++ if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
++ async_affected_req_ids, num_tokens_to_reschedule = (
++ self._update_requests_with_invalid_blocks(async_load_reqs,
++ invalid_block_ids))
++
++ total_requests_to_reschedule += len(async_affected_req_ids)
++ total_tokens_to_reschedule += num_tokens_to_reschedule
++
++ # Mark requests with async KV load failures; they will be rescheduled
++ # once loading completes
++ self.failed_recving_kv_req_ids |= async_affected_req_ids
++
++ # --- Handle sync KV loads (running requests) ---
++ sync_affected_req_ids, num_tokens_to_reschedule = (
++ self._update_requests_with_invalid_blocks(self.running,
++ invalid_block_ids))
++
++ total_requests_to_reschedule += len(sync_affected_req_ids)
++ total_tokens_to_reschedule += num_tokens_to_reschedule
++
++ if total_requests_to_reschedule:
++ logger.warning(
++ "Recovered from KV load failure: "
++ "%d request(s) rescheduled (%d tokens affected).",
++ total_requests_to_reschedule, total_tokens_to_reschedule)
++
++ # Return the IDs of affected running requests to skip in
++ # update_from_output.
++ return sync_affected_req_ids
+diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
+index 5b4718038..28bd4618a 100644
+--- a/vllm/v1/core/single_type_kv_cache_manager.py
++++ b/vllm/v1/core/single_type_kv_cache_manager.py
+@@ -142,6 +142,9 @@ class SingleTypeKVCacheManager(ABC):
+ num_cached_blocks = self.num_cached_block[request.request_id]
+ num_full_blocks = num_tokens // self.block_size
+
++ if num_cached_blocks >= num_full_blocks:
++ return
++
+ self.block_pool.cache_full_blocks(
+ request=request,
+ blocks=self.req_to_blocks[request.request_id],
+diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
+index b06b7cc80..61cd7110f 100644
+--- a/vllm/v1/executor/multiproc_executor.py
++++ b/vllm/v1/executor/multiproc_executor.py
+@@ -26,6 +26,7 @@ from vllm.distributed import (destroy_distributed_environment,
+ destroy_model_parallel)
+ from vllm.distributed.device_communicators.shm_broadcast import (Handle,
+ MessageQueue)
++from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
+ from vllm.executor.multiproc_worker_utils import (
+ _add_prefix, set_multiprocessing_worker_envs)
+ from vllm.logger import init_logger
+@@ -111,10 +112,14 @@ class MultiprocExecutor(Executor):
+ if self.max_concurrent_batches > 1:
+ # Note: must use only 1 IO thread to keep dequeue sequence
+ # from the response queue
++            # KVOutputAggregator.async_aggregate also assumes a single IO thread
+ self.io_thread_pool = ThreadPoolExecutor(
+ max_workers=1, thread_name_prefix="mp_exec_io")
+
+ self.output_rank = self._get_output_rank()
++ self.has_connector = self.vllm_config.kv_transfer_config is not None
++ self.kv_output_aggregator = KVOutputAggregator(
++ self.parallel_config.world_size)
+
+ def start_worker_monitor(self):
+ workers = self.workers
+@@ -155,13 +160,30 @@ class MultiprocExecutor(Executor):
+ self,
+ scheduler_output,
+ ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
+- (output, ) = self.collective_rpc(
++ non_block = self.max_concurrent_batches > 1
++
++ if not self.has_connector or self.vllm_config.model_config.use_mla:
++ # get output only from a single worker (output_rank)
++ (output, ) = self.collective_rpc(
++ "execute_model",
++ args=(scheduler_output, ),
++ unique_reply_rank=self.output_rank,
++ non_block=non_block,
++ timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
++ return output
++
++ # get output from all workers
++ outputs = self.collective_rpc(
+ "execute_model",
+ args=(scheduler_output, ),
+- unique_reply_rank=self.output_rank,
+- non_block=self.max_concurrent_batches > 1,
++ non_block=non_block,
+ timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
+- return output
++
++ # aggregate all workers output to a single output
++ if non_block:
++ return self.kv_output_aggregator.async_aggregate(
++ outputs, self.output_rank)
++ return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
+
+ def collective_rpc(self,
+ method: Union[str, Callable],
+diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
+index c8388baed..16af8dbce 100644
+--- a/vllm/v1/outputs.py
++++ b/vllm/v1/outputs.py
+@@ -1,7 +1,7 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+-from dataclasses import dataclass
++from dataclasses import dataclass, field
+ from typing import NamedTuple, Optional
+
+ import torch
+@@ -109,6 +109,10 @@ class ModelRunnerOutput:
+ finished_recving: Optional[set[str]] = None
+ finished_dumping: Optional[dict[str, list[str]]] = None
+
++ # IDs of externally computed KV blocks that failed to load.
++ # Requests referencing these blocks should be rescheduled to recompute them.
++ invalid_block_ids: set[int] = field(default_factory=set)
++
+ # req_id -> num_nans_in_logits
+ num_nans_in_logits: Optional[dict[str, int]] = None
+
+diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
+index 1a79d72be..8819d7629 100644
+--- a/vllm/v1/worker/gpu_input_batch.py
++++ b/vllm/v1/worker/gpu_input_batch.py
+@@ -96,6 +96,9 @@ class InputBatch:
+ pin_memory=False,
+ )
+ self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
++ self.is_token_ids = torch.zeros(
++ (max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False
++ )
+ self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
+ self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
+ self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
+@@ -286,8 +289,14 @@ class InputBatch:
+ req_index, :num_prompt_tokens] = request.prompt_token_ids
+ start_idx = num_prompt_tokens
+ end_idx = start_idx + len(request.output_token_ids)
++ if request.prompt_token_ids is not None:
++ self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
++ self.is_token_ids[req_index, :num_prompt_tokens] = True
++ else:
++ self.is_token_ids[req_index, :num_prompt_tokens] = False
+ self.token_ids_cpu[req_index,
+ start_idx:end_idx] = request.output_token_ids
++ self.is_token_ids[req_index, start_idx:end_idx] = True
+ # Number of token ids in token_ids_cpu.
+ # NOTE(woosuk): This may include spec decode tokens.
+ self.num_tokens[req_index] = request.num_tokens
+@@ -473,6 +482,8 @@ class InputBatch:
+ self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
+ self.token_ids_cpu[i2, ...] = tmp
+
++ self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
++
+ swap_dict_values(self.generators, i1, i2)
+ swap_dict_values(self.bad_words_token_ids, i1, i2)
+
+@@ -542,6 +553,9 @@ class InputBatch:
+ num_tokens = self.num_tokens[last_req_index]
+ self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
+ last_req_index, :num_tokens]
++ self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
++ last_req_index, :num_tokens
++ ]
+ self.num_tokens[empty_index] = num_tokens
+ self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
+ last_req_index]
+diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
+index 53ee8cfcd..c3df1d5d2 100644
+--- a/vllm/v1/worker/gpu_model_runner.py
++++ b/vllm/v1/worker/gpu_model_runner.py
+@@ -473,6 +473,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ num_computed_tokens = req_data.num_computed_tokens[i]
+ new_block_ids = req_data.new_block_ids[i]
+ resumed_from_preemption = req_data.resumed_from_preemption[i]
++ num_output_tokens = req_data.num_output_tokens[i]
+
+ # Update the cached states.
+ req_state.num_computed_tokens = num_computed_tokens
+@@ -492,6 +493,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ elif num_new_tokens > 0:
+ req_state.output_token_ids.extend(
+ new_token_ids[-num_new_tokens:])
++ elif num_output_tokens < len(req_state.output_token_ids):
++ # Some output tokens were discarded due to a sync-KV-load
++ # failure. Align the cached state.
++ del req_state.output_token_ids[num_output_tokens:]
++
++ req_index = self.input_batch.req_id_to_index.get(req_id)
++ if req_index is not None:
++ old_end_idx = self.input_batch.num_tokens_no_spec[
++ req_index]
++ end_idx = self.input_batch.num_prompt_tokens[
++ req_index] + num_output_tokens
++ self.input_batch.num_tokens[req_index] = end_idx
++ self.input_batch.num_tokens_no_spec[req_index] = end_idx
++ self.input_batch.is_token_ids[req_index,
++ end_idx:old_end_idx] = False
+
+ # Update the block IDs.
+ if not resumed_from_preemption:
+@@ -1381,6 +1397,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ finished_dumping = self.maybe_wait_for_kv_save()
+ finished_sending, finished_recving = (
+ self.get_finished_kv_transfers(scheduler_output))
++ invalid_block_ids = self.get_block_ids_with_load_errors()
+
+ if self.use_aux_hidden_state_outputs:
+ hidden_states, aux_hidden_states = model_output
+@@ -1564,6 +1581,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ finished_recving=finished_recving,
+ finished_dumping=finished_dumping,
+ num_nans_in_logits=num_nans_in_logits,
++ invalid_block_ids = invalid_block_ids
+ )
+
+ def propose_draft_token_ids(
+@@ -1694,13 +1712,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ self.maybe_setup_kv_connector(scheduler_output)
+ finished_sending, finished_recving = (
+ self.get_finished_kv_transfers(scheduler_output))
++ invalid_block_ids = self.get_block_ids_with_load_errors()
++ get_kv_transfer_group().clear_connector_metadata()
+
+- if not finished_sending and not finished_recving:
++ if not finished_sending and not finished_recving and not invalid_block_ids:
+ return EMPTY_MODEL_RUNNER_OUTPUT
+
+ output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+ output.finished_sending = finished_sending
+ output.finished_recving = finished_recving
++ output.invalid_block_ids = invalid_block_ids
+ return output
+
+ @staticmethod
+@@ -1733,6 +1754,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ scheduler_output.finished_req_ids)
+ return None, None
+
++ def get_block_ids_with_load_errors(self) -> Optional[set[int]]:
++ if has_kv_transfer_group():
++ return get_kv_transfer_group().get_block_ids_with_load_errors()
++ return None
++
+ def propose_ngram_draft_token_ids(
+ self,
+ sampled_token_ids: list[list[int]],
+diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
+index 9e7e44d06..1b816b25b 100644
+--- a/vllm/v1/worker/gpu_worker.py
++++ b/vllm/v1/worker/gpu_worker.py
+@@ -1,6 +1,7 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+ """A GPU worker class."""
++import copy
+ import gc
+ import os
+ from typing import TYPE_CHECKING, Optional
+@@ -15,7 +16,8 @@ from vllm.device_allocator.cumem import CuMemAllocator
+ from vllm.distributed import (ensure_model_parallel_initialized,
+ init_distributed_environment,
+ set_custom_all_reduce)
+-from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
++from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
++ has_kv_transfer_group)
+ from vllm.distributed.parallel_state import get_pp_group, get_tp_group
+ from vllm.logger import init_logger
+ from vllm.lora.request import LoRARequest
+@@ -24,7 +26,7 @@ from vllm.platforms import current_platform
+ from vllm.sequence import IntermediateTensors
+ from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
+ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
+-from vllm.v1.outputs import ModelRunnerOutput
++from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+ from vllm.v1.utils import report_usage_stats
+ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+ from vllm.v1.worker.worker_base import WorkerBase
+@@ -313,9 +315,22 @@ class Worker(WorkerBase):
+ assert isinstance(output, IntermediateTensors)
+ get_pp_group().send_tensor_dict(output.tensors,
+ all_gather_group=get_tp_group())
+- return None
++ if not has_kv_transfer_group():
++ return None
++
++ # In case of PP with kv transfer, we need to pass through the
++ # finished_sending and finished_recving buffers.
++ new_output = EMPTY_MODEL_RUNNER_OUTPUT
++ if output.finished_sending or output.finished_recving or output.finished_dumping or output.invalid_block_ids:
++ new_output = copy.copy(new_output)
++ new_output.finished_sending = output.finished_sending
++ new_output.finished_recving = output.finished_recving
++ new_output.finished_dumping = output.finished_dumping
++ new_output.invalid_block_ids = output.invalid_block_ids
++ output = new_output
++
+ assert isinstance(output, ModelRunnerOutput)
+- return output if self.is_driver_worker else None
++ return output
+
+ def profile(self, is_start: bool = True):
+ if self.profiler is None:
+--
+2.34.1
+
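The aggregation patch above only marks a request (or a dumped block) as finished once
every rank has reported it. The sketch below illustrates that countdown with a world
size of 2; it assumes a vLLM tree with the patch applied and uses the same imports as
the patched multiproc executor:

    import copy

    from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
    from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT

    aggregator = KVOutputAggregator(world_size=2)

    # Step 1: only rank 0 reports that req-0 finished sending its KV cache.
    rank0 = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
    rank0.finished_sending = {"req-0"}
    rank1 = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
    step1 = aggregator.aggregate([rank0, rank1])
    assert step1.finished_sending is None  # rank 1 has not reported yet

    # Step 2: rank 1 catches up, so the aggregated output now exposes req-0.
    rank0 = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
    rank1 = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
    rank1.finished_sending = {"req-0"}
    step2 = aggregator.aggregate([rank0, rank1])
    assert step2.finished_sending == {"req-0"}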
diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-pc.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-pc.patch
new file mode 100644
index 000000000..bf0b7e19a
--- /dev/null
+++ b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-pc.patch
@@ -0,0 +1,122 @@
+From 26fdd2026cc3d1ed7da894907ae244a155a16566 Mon Sep 17 00:00:00 2001
+From: harrisonyhq
+Date: Tue, 4 Nov 2025 19:36:36 -0800
+Subject: [PATCH 1/3] [Patch0] UCM PC adapt patch
+
+---
+ .../kv_transfer/kv_connector/v1/multi_connector.py | 7 ++++++-
+ vllm/v1/core/sched/scheduler.py | 11 +++++++++++
+ vllm/v1/outputs.py | 1 +
+ vllm/v1/request.py | 2 ++
+ vllm/v1/worker/gpu_model_runner.py | 7 ++++---
+ 5 files changed, 24 insertions(+), 4 deletions(-)
+
+diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
+index be3c23399..5f92d69bd 100644
+--- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
++++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
+@@ -99,8 +99,13 @@ class MultiConnector(KVConnectorBase_V1):
+ c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)
+
+ def wait_for_save(self):
++ success_dumped_blocks = None
+ for c in self._connectors:
+- c.wait_for_save()
++ uc_dump_blocks = c.wait_for_save()
++ if uc_dump_blocks:
++ success_dumped_blocks = uc_dump_blocks
++
++ return success_dumped_blocks if success_dumped_blocks else None
+
+ def get_finished(
+ self, finished_req_ids: set[str]
+diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
+index fe552db74..cd80f92a1 100644
+--- a/vllm/v1/core/sched/scheduler.py
++++ b/vllm/v1/core/sched/scheduler.py
+@@ -34,6 +34,7 @@ from vllm.v1.outputs import ModelRunnerOutput
+ from vllm.v1.request import Request, RequestStatus
+ from vllm.v1.spec_decode.metrics import SpecDecodingStats
+ from vllm.v1.structured_output import StructuredOutputManager
++from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import MultiConnector
+
+ logger = init_logger(__name__)
+
+@@ -791,6 +792,16 @@ class Scheduler(SchedulerInterface):
+ new_logprobs = None
+ new_token_ids = generated_token_ids
+ kv_transfer_params = None
++ if model_runner_output.finished_dumping is not None:
++ request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, []))
++ is_prefill = request.num_output_tokens == 0
++ if is_prefill:
++ if isinstance(self.connector, MultiConnector):
++ for c in self.connector._connectors:
++ if hasattr(c, 'connector') and hasattr(c.connector, 'commit'):
++ c.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True)
++ else:
++ self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True)
+
+ # Append generated tokens and check for stop. Note that if
+ # a request is still being prefilled, we expect the model runner
+diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
+index f78623f57..c8388baed 100644
+--- a/vllm/v1/outputs.py
++++ b/vllm/v1/outputs.py
+@@ -107,6 +107,7 @@ class ModelRunnerOutput:
+ # [req_ids]
+ finished_sending: Optional[set[str]] = None
+ finished_recving: Optional[set[str]] = None
++ finished_dumping: Optional[dict[str, list[str]]] = None
+
+ # req_id -> num_nans_in_logits
+ num_nans_in_logits: Optional[dict[str, int]] = None
+diff --git a/vllm/v1/request.py b/vllm/v1/request.py
+index 9b96f4599..e70d1695b 100644
+--- a/vllm/v1/request.py
++++ b/vllm/v1/request.py
+@@ -103,6 +103,8 @@ class Request:
+ # The number of tokens with prefix cache hits.
+ self.num_cached_tokens = -1
+
++ self.succeed_dumped_blocks: list[str] = []
++
+ # The number of NaNs in logits. A value greater than 0
+ # indicates that the output is corrupted
+ self.num_nans_in_logits = 0
+diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
+index 5a26e88db..53ee8cfcd 100644
+--- a/vllm/v1/worker/gpu_model_runner.py
++++ b/vllm/v1/worker/gpu_model_runner.py
+@@ -1378,7 +1378,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ inputs_embeds=inputs_embeds,
+ )
+
+- self.maybe_wait_for_kv_save()
++ finished_dumping = self.maybe_wait_for_kv_save()
+ finished_sending, finished_recving = (
+ self.get_finished_kv_transfers(scheduler_output))
+
+@@ -1562,6 +1562,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ pooler_output=[],
+ finished_sending=finished_sending,
+ finished_recving=finished_recving,
++ finished_dumping=finished_dumping,
+ num_nans_in_logits=num_nans_in_logits,
+ )
+
+@@ -1719,9 +1720,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ kv_connector.start_load_kv(get_forward_context())
+
+ @staticmethod
+- def maybe_wait_for_kv_save() -> None:
++ def maybe_wait_for_kv_save():
+ if has_kv_transfer_group():
+- get_kv_transfer_group().wait_for_save()
++ return get_kv_transfer_group().wait_for_save()
+
+ @staticmethod
+ def get_finished_kv_transfers(
+--
+2.34.1
+
diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch
new file mode 100644
index 000000000..eb9848756
--- /dev/null
+++ b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch
@@ -0,0 +1,628 @@
+From 0431022b90649f7115b89b61aaf2a0f83e896d5a Mon Sep 17 00:00:00 2001
+From: wenxinwang
+Date: Mon, 10 Nov 2025 20:35:47 +0800
+Subject: [PATCH] Adapt UCM sparse attention to DeepSeek (MLA)
+
+---
+ vllm/attention/layer.py | 49 ++++++++++++-
+ .../kv_transfer/kv_connector/utils.py | 5 ++
+ .../v1/shared_storage_connector.py | 7 +-
+ vllm/v1/attention/backends/mla/common.py | 10 ++-
+ vllm/v1/core/kv_cache_manager.py | 7 +-
+ vllm/v1/core/sched/output.py | 3 +
+ vllm/v1/core/sched/scheduler.py | 37 +++++++---
+ vllm/v1/worker/block_table.py | 13 ++++
+ vllm/v1/worker/gpu_model_runner.py | 71 +++++++++++++++----
+ vllm/v1/worker/gpu_worker.py | 2 +
+ 10 files changed, 171 insertions(+), 33 deletions(-)
+
+diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
+index f0ad68b16..728ab99fd 100644
+--- a/vllm/attention/layer.py
++++ b/vllm/attention/layer.py
+@@ -2,7 +2,6 @@
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+ """Attention layer."""
+ from typing import Any, Dict, List, Optional
+-
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+@@ -22,6 +21,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+ from vllm.platforms import _Backend, current_platform
+ from vllm.utils import direct_register_custom_op
+ from vllm.v1.attention.backends.utils import validate_kv_sharing_target
++from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
+
+
+ class Attention(nn.Module):
+@@ -409,9 +409,10 @@ def unified_attention(
+ attn_metadata = attn_metadata[layer_name]
+ self = forward_context.no_compile_layers[layer_name]
+ kv_cache = self.kv_cache[forward_context.virtual_engine]
++ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
+ output = self.impl.forward(self, query, key, value, kv_cache,
+ attn_metadata)
+-
++ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
+ maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+ return output
+
+@@ -449,6 +450,8 @@ def unified_attention_with_output(
+ attn_metadata = attn_metadata[layer_name]
+ self = forward_context.no_compile_layers[layer_name]
+ kv_cache = self.kv_cache[forward_context.virtual_engine]
++ if not self.use_mla:
++ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
+ self.impl.forward(self,
+ query,
+ key,
+@@ -457,7 +460,8 @@ def unified_attention_with_output(
+ attn_metadata,
+ output=output,
+ output_scale=output_scale)
+-
++ if not self.use_mla:
++ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
+ maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+
+
+@@ -479,3 +483,42 @@ direct_register_custom_op(
+ fake_impl=unified_attention_with_output_fake,
+ dispatch_key=current_platform.dispatch_key,
+ )
++
++def maybe_execute_sparse_attention_begin(
++ query: torch.Tensor,
++ key: torch.Tensor,
++ value: torch.Tensor,
++ layer_name: str,
++ forward_context: ForwardContext,
++ phase: Optional[str] = None,
++):
++ if not has_ucm_sparse():
++ return
++
++ ucm_sparse = get_ucm_sparse()
++
++ attn_metadata = forward_context.attn_metadata
++ if attn_metadata is None:
++ return
++
++ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context, phase)
++
++def maybe_execute_sparse_attention_finished(
++ query: torch.Tensor,
++ key: torch.Tensor,
++ value: torch.Tensor,
++ attn_output: torch.Tensor,
++ layer_name: str,
++ forward_context: ForwardContext,
++ phase: Optional[str] = None,
++):
++ if not has_ucm_sparse():
++ return
++
++ ucm_sparse = get_ucm_sparse()
++
++ attn_metadata = forward_context.attn_metadata
++ if attn_metadata is None:
++ return
++
++ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context, phase)
+diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
+index b63bf5965..155597c51 100644
+--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
++++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
+@@ -3,6 +3,11 @@
+ """
+ KV cache helper for store.
+ """
++from collections import defaultdict
++from collections.abc import Sequence
++from concurrent.futures import CancelledError, Future
++from typing import Optional, cast
++
+ import torch
+
+ from collections import defaultdict
+diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
+index 3c574d065..223106def 100644
+--- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
++++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
+@@ -2,7 +2,7 @@
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+ import hashlib
+ import os
+-from dataclasses import dataclass
++from dataclasses import dataclass, field
+ from typing import TYPE_CHECKING
+
+ import safetensors
+@@ -53,10 +53,7 @@ class ReqMeta:
+
+ @dataclass
+ class SharedStorageConnectorMetadata(KVConnectorMetadata):
+- requests: list[ReqMeta]
+-
+- def __init__(self):
+- self.requests = []
++ requests: list[ReqMeta] = field(default_factory=list)
+
+ def add_request(
+ self,
+diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
+index f2aaf59a4..b56f62b39 100644
+--- a/vllm/v1/attention/backends/mla/common.py
++++ b/vllm/v1/attention/backends/mla/common.py
+@@ -200,6 +200,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
+ MLAAttentionImpl)
+ from vllm.attention.backends.utils import get_mla_dims
+ from vllm.attention.ops.merge_attn_states import merge_attn_states
++from vllm.forward_context import ForwardContext, get_forward_context
+ from vllm.attention.utils.fa_utils import get_flash_attn_version
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+@@ -211,6 +212,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+ CommonAttentionMetadata)
+ from vllm.v1.kv_cache_interface import AttentionSpec
+ from vllm.v1.worker.block_table import BlockTable
++from vllm.attention.layer import (maybe_execute_sparse_attention_begin, maybe_execute_sparse_attention_finished)
+
+ try:
+ from vllm.vllm_flash_attn import flash_attn_varlen_func
+@@ -908,7 +910,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
+ output: Optional[torch.Tensor] = None,
+ output_scale: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+-
++ forward_context: ForwardContext = get_forward_context()
+ assert output is not None, "Output tensor must be provided."
+
+ if output_scale is not None:
+@@ -957,10 +959,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
+ )
+
+ if has_prefill:
++ maybe_execute_sparse_attention_begin(prefill_q, prefill_k_c_normed, prefill_k_pe, layer.layer_name, forward_context, "prefill")
+ output[num_decode_tokens:] = self._forward_prefill(
+ prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
+ attn_metadata)
+-
++ maybe_execute_sparse_attention_finished(prefill_q, prefill_k_c_normed, prefill_k_pe, output[num_decode_tokens:], layer.layer_name, forward_context, "prefill")
+ if has_decode:
+ assert attn_metadata.decode is not None
+ decode_q_nope, decode_q_pe = decode_q.split(
+@@ -971,8 +974,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
+ decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+ # Convert from (N, B, L) to (B, N, L)
+ decode_ql_nope = decode_ql_nope.transpose(0, 1)
+-
++ maybe_execute_sparse_attention_begin(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, layer.layer_name, forward_context, "decode")
+ output[:num_decode_tokens] = self._forward_decode(
+ decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
++ maybe_execute_sparse_attention_finished(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, output[:num_decode_tokens], layer.layer_name, forward_context, "decode")
+
+ return output_padded
+diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
+index 6937455e7..bf9aec864 100644
+--- a/vllm/v1/core/kv_cache_manager.py
++++ b/vllm/v1/core/kv_cache_manager.py
+@@ -3,7 +3,7 @@
+
+ from collections import defaultdict
+ from dataclasses import dataclass
+-from typing import Optional
++from typing import Optional, Union
+
+ from vllm.distributed.kv_events import KVCacheEvent
+ from vllm.logger import init_logger
+@@ -14,6 +14,8 @@ from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
+ from vllm.v1.kv_cache_interface import KVCacheConfig
+ from vllm.v1.metrics.stats import PrefixCacheStats
+ from vllm.v1.request import Request, RequestStatus
++from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
++from ucm.sparse.base import INVALID_SLOT
+
+ logger = init_logger(__name__)
+
+@@ -193,6 +195,7 @@ class KVCacheManager:
+ num_draft_tokens: int = 0,
+ num_lookahead_tokens: int = 0,
+ delay_cache_blocks: bool = False,
++ num_slots_sparsed: Union[None, int] = None
+ ) -> Optional[KVCacheBlocks]:
+ """Add slots for a request with new tokens to append.
+
+@@ -231,6 +234,8 @@ class KVCacheManager:
+ """
+ if num_new_tokens == 0:
+ raise ValueError("num_new_tokens must be greater than 0")
++ if num_slots_sparsed != INVALID_SLOT:
++ return get_ucm_sparse().allocate_slots(self, request, num_slots_sparsed)
+
+ if new_computed_blocks is not None:
+ new_computed_block_list = new_computed_blocks.blocks
+diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
+index c94e421c0..fff0eeb42 100644
+--- a/vllm/v1/core/sched/output.py
++++ b/vllm/v1/core/sched/output.py
+@@ -157,3 +157,6 @@ class SchedulerOutput:
+
+ # KV Cache Connector metadata.
+ kv_connector_metadata: Optional[KVConnectorMetadata] = None
++
++ # modified slots by sparse algorithm
++ req_sparsed_slots: dict[str, int] = None
+diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
+index 2d4fd4d59..e99a51788 100644
+--- a/vllm/v1/core/sched/scheduler.py
++++ b/vllm/v1/core/sched/scheduler.py
+@@ -35,6 +35,8 @@ from vllm.v1.request import Request, RequestStatus
+ from vllm.v1.spec_decode.metrics import SpecDecodingStats
+ from vllm.v1.structured_output import StructuredOutputManager
+ from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import MultiConnector
++from ucm.sparse.state import ensure_ucm_sparse_initialized, get_ucm_sparse, has_ucm_sparse
++from ucm.sparse.base import UcmSparseBase, UcmSparseRole, INVALID_SLOT
+
+ logger = init_logger(__name__)
+
+@@ -80,12 +82,18 @@ class Scheduler(SchedulerInterface):
+ # will have a corresponding KVConnector with Role=WORKER.
+ # KV Connector pushes/pull of remote KVs for P/D and offloading.
+ self.connector = None
++ self.ucm_sparse = None
+ if self.vllm_config.kv_transfer_config is not None:
+ assert len(self.kv_cache_config.kv_cache_groups) == 1, (
+ "Multiple KV cache groups are not currently supported "
+ "with KV connectors")
+ self.connector = KVConnectorFactory.create_connector_v1(
+ config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
++ # Initialize UCM Sparse if available
++ if "ucm_sparse_config" in vllm_config.kv_transfer_config.kv_connector_extra_config:
++ ensure_ucm_sparse_initialized(vllm_config, role=UcmSparseRole.SCHEDULER)
++ self.ucm_sparse = get_ucm_sparse()
++ logger.info("UCM Sparse initialized successfully: {}".format(self.ucm_sparse))
+
+ self.kv_event_publisher = EventPublisherFactory.create(
+ self.kv_events_config,
+@@ -203,8 +211,13 @@ class Scheduler(SchedulerInterface):
+
+ # First, schedule the RUNNING requests.
+ req_index = 0
++ req_sparsed_slots: dict[str, int] = {}
+ while req_index < len(self.running) and token_budget > 0:
+ request = self.running[req_index]
++ num_slots_sparsed = INVALID_SLOT
++ if self.ucm_sparse:
++ num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(request)
++ req_sparsed_slots.update({request.request_id: num_slots_sparsed})
+
+ num_new_tokens = (request.num_tokens_with_spec -
+ request.num_computed_tokens)
+@@ -252,7 +265,8 @@ class Scheduler(SchedulerInterface):
+ request,
+ num_new_tokens,
+ num_draft_tokens=num_draft_tokens,
+- num_lookahead_tokens=self.num_lookahead_tokens)
++ num_lookahead_tokens=self.num_lookahead_tokens,
++ num_slots_sparsed=num_slots_sparsed)
+ if new_blocks is None:
+ # The request cannot be scheduled.
+ # Preempt the lowest-priority request.
+@@ -339,6 +353,10 @@ class Scheduler(SchedulerInterface):
+ break
+
+ request = self.waiting.peek_request()
++ num_slots_sparsed = INVALID_SLOT
++ if self.ucm_sparse:
++ num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(request)
++ req_sparsed_slots.update({request.request_id: num_slots_sparsed})
+
+ # KVTransfer: skip request if still waiting for remote kvs.
+ if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
+@@ -448,6 +466,7 @@ class Scheduler(SchedulerInterface):
+ new_computed_blocks,
+ num_lookahead_tokens=self.num_lookahead_tokens,
+ delay_cache_blocks=load_kv_async,
++ num_slots_sparsed=num_slots_sparsed
+ )
+ if new_blocks is None:
+ # The request cannot be scheduled.
+@@ -561,6 +580,7 @@ class Scheduler(SchedulerInterface):
+ scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
+ scheduled_encoder_inputs=scheduled_encoder_inputs,
+ num_common_prefix_blocks=num_common_prefix_blocks,
++ req_sparsed_slots=req_sparsed_slots,
+ # finished_req_ids is an existing state in the scheduler,
+ # instead of being newly scheduled in this step.
+ # It contains the request IDs that are finished in between
+@@ -809,16 +829,12 @@ class Scheduler(SchedulerInterface):
+ new_logprobs = None
+ new_token_ids = generated_token_ids
+ kv_transfer_params = None
++
+ if model_runner_output.finished_dumping is not None:
+ request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, []))
+ is_prefill = request.num_output_tokens == 0
+ if is_prefill:
+- if isinstance(self.connector, MultiConnector):
+- for c in self.connector._connectors:
+- if hasattr(c, 'connector') and hasattr(c.connector, 'commit'):
+- c.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True)
+- else:
+- self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True)
++ self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True)
+
+ # Append generated tokens and check for stop. Note that if
+ # a request is still being prefilled, we expect the model runner
+@@ -870,7 +886,6 @@ class Scheduler(SchedulerInterface):
+ spec_token_ids[req_index])
+ else:
+ request.spec_token_ids = spec_token_ids[req_index]
+-
+ # Get prompt logprobs for this request.
+ prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
+ if new_token_ids or pooler_output is not None \
+@@ -897,6 +912,7 @@ class Scheduler(SchedulerInterface):
+
+ if not stopped:
+ new_running.append(request)
++
+ self.running = new_running
+
+ # KV Connector: update state for finished KV Transfers.
+@@ -955,6 +971,8 @@ class Scheduler(SchedulerInterface):
+ def add_request(self, request: Request) -> None:
+ self.waiting.add_request(request)
+ self.requests[request.request_id] = request
++ if self.ucm_sparse:
++ self.ucm_sparse.request_begin(request.request_id, request.prompt_token_ids)
+ if self.log_stats:
+ request.record_event(EngineCoreEventType.QUEUED)
+
+@@ -1004,6 +1022,8 @@ class Scheduler(SchedulerInterface):
+
+ def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
+ assert request.is_finished()
++ if self.ucm_sparse:
++ self.ucm_sparse.request_finished_in_scheduler(request.request_id)
+
+ delay_free_blocks, kv_xfer_params = self._connector_finished(request)
+ self.encoder_cache_manager.free(request)
+@@ -1155,7 +1175,6 @@ class Scheduler(SchedulerInterface):
+ logger.debug("Finished sending KV transfer for request %s", req_id)
+ self._free_blocks(self.requests[req_id])
+
+-
+ def _update_requests_with_invalid_blocks(
+ self, requests: Iterable[Request],
+ invalid_block_ids: set[int]) -> tuple[set[str], int]:
+diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
+index 8f4e8d64c..f45e39f5c 100644
+--- a/vllm/v1/worker/block_table.py
++++ b/vllm/v1/worker/block_table.py
+@@ -61,6 +61,15 @@ class BlockTable:
+ self.num_blocks_per_row[row_idx] += num_blocks
+ self.block_table_np[row_idx, start:start + num_blocks] = block_ids
+
++ def reset_row(
++ self,
++ row_idx: int,
++ ) -> None:
++ self.num_blocks_per_row[row_idx] = 0
++ self.block_table[row_idx].fill_(0)
++ self.block_table_cpu[row_idx].fill_(0)
++ self.block_table_np[row_idx].fill(0)
++
+ def add_row(self, block_ids: list[int], row_idx: int) -> None:
+ self.num_blocks_per_row[row_idx] = 0
+ self.append_row(block_ids, row_idx)
+@@ -117,6 +126,10 @@ class MultiGroupBlockTable:
+ for i, block_table in enumerate(self.block_tables):
+ block_table.append_row(block_ids[i], row_idx)
+
++ def reset_row(self, row_idx: int) -> None:
++ for i, block_table in enumerate(self.block_tables):
++ block_table.reset_row(row_idx)
++
+ def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
+ for i, block_table in enumerate(self.block_tables):
+ block_table.add_row(block_ids[i], row_idx)
+diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
+index c3df1d5d2..dbf1ea7d7 100644
+--- a/vllm/v1/worker/gpu_model_runner.py
++++ b/vllm/v1/worker/gpu_model_runner.py
+@@ -72,6 +72,9 @@ from ..sample.logits_processor import LogitsProcessorManager
+ from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
+ sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
+
++from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
++from ucm.sparse.base import UcmSparseMetadata, INVALID_SLOT
++
+ if TYPE_CHECKING:
+ import xgrammar as xgr
+ import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501
+@@ -365,6 +368,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ """
+ # Remove finished requests from the cached states.
+ for req_id in scheduler_output.finished_req_ids:
++ self.ucm_sparse_request_finished_in_worker(req_id)
+ self.requests.pop(req_id, None)
+ self.encoder_cache.pop(req_id, None)
+ # Remove the finished requests from the persistent batch.
+@@ -468,12 +472,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ # Update the states of the running/resumed requests.
+ is_last_rank = get_pp_group().is_last_rank
+ req_data = scheduler_output.scheduled_cached_reqs
++ req_sparsed_slots = scheduler_output.req_sparsed_slots
+ for i, req_id in enumerate(req_data.req_ids):
+ req_state = self.requests[req_id]
+ num_computed_tokens = req_data.num_computed_tokens[i]
+ new_block_ids = req_data.new_block_ids[i]
+ resumed_from_preemption = req_data.resumed_from_preemption[i]
+ num_output_tokens = req_data.num_output_tokens[i]
++ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT
+
+ # Update the cached states.
+ req_state.num_computed_tokens = num_computed_tokens
+@@ -510,15 +516,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ end_idx:old_end_idx] = False
+
+ # Update the block IDs.
+- if not resumed_from_preemption:
+- # Append the new blocks to the existing block IDs.
+- for block_ids, new_ids in zip(req_state.block_ids,
+- new_block_ids):
+- block_ids.extend(new_ids)
+- else:
++ if resumed_from_preemption or is_sparsed_request:
+ # The request is resumed from preemption.
+ # Replace the existing block IDs with the new ones.
+ req_state.block_ids = new_block_ids
++ else:
++ # Append the new blocks to the existing block IDs.
++ for block_ids, new_ids in zip(req_state.block_ids,
++ new_block_ids):
++ block_ids.extend(new_ids)
+
+ req_index = self.input_batch.req_id_to_index.get(req_id)
+ if req_index is None:
+@@ -531,6 +537,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ # Update the persistent batch.
+ self.input_batch.num_computed_tokens_cpu[req_index] = (
+ num_computed_tokens)
++ if is_sparsed_request:
++ self.input_batch.block_table.reset_row(req_index)
+ self.input_batch.block_table.append_row(new_block_ids, req_index)
+
+ # For the last rank, we don't need to update the token_ids_cpu
+@@ -639,6 +647,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ if self.uses_mrope:
+ self._calc_mrope_positions(scheduler_output)
+
++ self.seq_lens_np[:num_reqs] = (
++ self.input_batch.num_computed_tokens_cpu[:num_reqs] +
++ num_scheduled_tokens)
++
++ # TODO: improve performance, no `positions_np.copy()`
++ sparsed_positions = positions_np.copy()
++ req_sparsed_slots = scheduler_output.req_sparsed_slots
++ for req_id in self.input_batch.req_id_to_index:
++ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT
++ req_index = self.input_batch.req_id_to_index[req_id]
++ offset = 0 if req_index == 0 else cu_num_tokens[req_index - 1] # TODO: support MTP
++ if is_sparsed_request:
++ sparsed_positions[offset] = req_sparsed_slots[req_id] - 1
+ # Get token indices.
+ # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+ # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
+@@ -668,11 +689,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ # block_size.
+ block_table_indices = (
+ req_indices * block_table.max_num_blocks_per_req +
+- positions_np // block_size)
++ sparsed_positions // block_size)
+ block_table_cpu = block_table.get_cpu_tensor()
+ block_numbers = block_table_cpu.flatten(
+ )[block_table_indices].numpy()
+- block_offsets = positions_np % block_size
++ block_offsets = sparsed_positions % block_size
+ np.add(
+ block_numbers * block_size,
+ block_offsets,
+@@ -682,9 +703,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ self.query_start_loc_np[0] = 0
+ self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
+
+- self.seq_lens_np[:num_reqs] = (
+- self.input_batch.num_computed_tokens_cpu[:num_reqs] +
+- num_scheduled_tokens)
++ for req_id in self.input_batch.req_id_to_index:
++ req_index = self.input_batch.req_id_to_index[req_id]
++ is_sparsed_request = scheduler_output.req_sparsed_slots[req_id] != INVALID_SLOT
++ if is_sparsed_request:
++ self.seq_lens_np[req_index] = scheduler_output.req_sparsed_slots[req_id]
+
+ # Copy the tensors to the GPU.
+ self.input_ids[:total_num_scheduled_tokens].copy_(
+@@ -696,6 +719,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ non_blocking=True)
+ else:
+ # Common case (1D positions)
++ self.positions_cpu[:total_num_scheduled_tokens] = torch.from_numpy(
++ positions_np[:total_num_scheduled_tokens])
+ self.positions[:total_num_scheduled_tokens].copy_(
+ self.positions_cpu[:total_num_scheduled_tokens],
+ non_blocking=True)
+@@ -1386,6 +1411,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ skip_cuda_graphs=skip_cuda_graphs,
+ ):
+ self.maybe_setup_kv_connector(scheduler_output)
++ self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata)
+
+ model_output = self.model(
+ input_ids=input_ids,
+@@ -1395,6 +1421,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ )
+
+ finished_dumping = self.maybe_wait_for_kv_save()
++ self.maybe_execute_ucm_sparse_finished()
++
+ finished_sending, finished_recving = (
+ self.get_finished_kv_transfers(scheduler_output))
+ invalid_block_ids = self.get_block_ids_with_load_errors()
+@@ -1741,10 +1769,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
+ kv_connector.start_load_kv(get_forward_context())
+
+ @staticmethod
+- def maybe_wait_for_kv_save():
++ def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]:
+ if has_kv_transfer_group():
+ return get_kv_transfer_group().wait_for_save()
+
++ def maybe_execute_ucm_sparse_begin(self, scheduler_output: "SchedulerOutput", attn_metadata: CommonAttentionMetadata):
++ if not has_ucm_sparse():
++ return
++ ucm_sparse = get_ucm_sparse()
++ ucm_sparse.build_sparse_meta(scheduler_output, self.requests, self.input_batch, attn_metadata)
++ ucm_sparse.execute_begin(scheduler_output)
++
++ def maybe_execute_ucm_sparse_finished(self):
++ if not has_ucm_sparse():
++ return
++ ucm_sparse = get_ucm_sparse()
++ ucm_sparse.execute_finished()
++
++ def ucm_sparse_request_finished_in_worker(self, request_id: str | int):
++ if not has_ucm_sparse():
++ return
++ ucm_sparse = get_ucm_sparse()
++ ucm_sparse.request_finished_in_worker(request_id)
++
+ @staticmethod
+ def get_finished_kv_transfers(
+ scheduler_output: "SchedulerOutput",
+diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
+index 1b816b25b..d9666d102 100644
+--- a/vllm/v1/worker/gpu_worker.py
++++ b/vllm/v1/worker/gpu_worker.py
+@@ -30,6 +30,7 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+ from vllm.v1.utils import report_usage_stats
+ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+ from vllm.v1.worker.worker_base import WorkerBase
++from ucm.sparse.state import ensure_ucm_sparse_initialized
+
+ logger = init_logger(__name__)
+
+@@ -401,6 +402,7 @@ def init_worker_distributed_environment(
+ parallel_config.pipeline_parallel_size)
+
+ ensure_kv_transfer_initialized(vllm_config)
++ ensure_ucm_sparse_initialized(vllm_config)
+
+
+ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
+--
+2.34.1
+
diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch
deleted file mode 100644
index da340e7ef..000000000
--- a/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch
+++ /dev/null
@@ -1,1159 +0,0 @@
-From 555ba9e4920445381aecda262b9146342e92eeee Mon Sep 17 00:00:00 2001
-From: hek14 <1023129548@qq.com>
-Date: Fri, 26 Sep 2025 09:51:36 +0800
-Subject: [PATCH] UCM adaptor
-
----
- vllm/attention/layer.py | 45 ++++-
- .../kv_transfer/kv_connector/utils.py | 113 ++++++++++++
- .../kv_transfer/kv_connector/v1/base.py | 9 +
- .../v1/shared_storage_connector.py | 7 +-
- vllm/v1/core/block_pool.py | 2 +-
- vllm/v1/core/kv_cache_manager.py | 11 +-
- vllm/v1/core/sched/output.py | 3 +
- vllm/v1/core/sched/scheduler.py | 164 +++++++++++++++++-
- vllm/v1/core/single_type_kv_cache_manager.py | 3 +
- vllm/v1/executor/multiproc_executor.py | 30 +++-
- vllm/v1/outputs.py | 5 +
- vllm/v1/request.py | 2 +-
- vllm/v1/worker/block_table.py | 13 ++
- vllm/v1/worker/gpu_input_batch.py | 9 +
- vllm/v1/worker/gpu_model_runner.py | 120 +++++++++++--
- vllm/v1/worker/gpu_worker.py | 25 ++-
- 16 files changed, 524 insertions(+), 37 deletions(-)
-
-diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
-index f0ad68b16..2acde35d8 100644
---- a/vllm/attention/layer.py
-+++ b/vllm/attention/layer.py
-@@ -2,7 +2,6 @@
- # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
- """Attention layer."""
- from typing import Any, Dict, List, Optional
--
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-@@ -22,6 +21,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
- from vllm.platforms import _Backend, current_platform
- from vllm.utils import direct_register_custom_op
- from vllm.v1.attention.backends.utils import validate_kv_sharing_target
-+from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
-
-
- class Attention(nn.Module):
-@@ -409,9 +409,10 @@ def unified_attention(
- attn_metadata = attn_metadata[layer_name]
- self = forward_context.no_compile_layers[layer_name]
- kv_cache = self.kv_cache[forward_context.virtual_engine]
-+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
- output = self.impl.forward(self, query, key, value, kv_cache,
- attn_metadata)
--
-+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
- maybe_save_kv_layer_to_connector(layer_name, kv_cache)
- return output
-
-@@ -449,6 +450,7 @@ def unified_attention_with_output(
- attn_metadata = attn_metadata[layer_name]
- self = forward_context.no_compile_layers[layer_name]
- kv_cache = self.kv_cache[forward_context.virtual_engine]
-+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
- self.impl.forward(self,
- query,
- key,
-@@ -457,7 +459,7 @@ def unified_attention_with_output(
- attn_metadata,
- output=output,
- output_scale=output_scale)
--
-+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
- maybe_save_kv_layer_to_connector(layer_name, kv_cache)
-
-
-@@ -479,3 +481,40 @@ direct_register_custom_op(
- fake_impl=unified_attention_with_output_fake,
- dispatch_key=current_platform.dispatch_key,
- )
-+
-+def maybe_execute_sparse_attention_begin(
-+ query: torch.Tensor,
-+ key: torch.Tensor,
-+ value: torch.Tensor,
-+ layer_name: str,
-+ forward_context: ForwardContext,
-+):
-+ if not has_ucm_sparse():
-+ return
-+
-+ ucm_sparse = get_ucm_sparse()
-+
-+ attn_metadata = forward_context.attn_metadata
-+ if attn_metadata is None:
-+ return
-+
-+ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context)
-+
-+def maybe_execute_sparse_attention_finished(
-+ query: torch.Tensor,
-+ key: torch.Tensor,
-+ value: torch.Tensor,
-+ attn_output: torch.Tensor,
-+ layer_name: str,
-+ forward_context: ForwardContext,
-+):
-+ if not has_ucm_sparse():
-+ return
-+
-+ ucm_sparse = get_ucm_sparse()
-+
-+ attn_metadata = forward_context.attn_metadata
-+ if attn_metadata is None:
-+ return
-+
-+ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context)
-diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
-index 5cbc8ca31..8556a979e 100644
---- a/vllm/distributed/kv_transfer/kv_connector/utils.py
-+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
-@@ -3,12 +3,18 @@
- """
- KV cache helper for store.
- """
-+from collections import defaultdict
-+from collections.abc import Sequence
-+from concurrent.futures import CancelledError, Future
-+from typing import Optional, cast
-+
- import torch
-
- import vllm.envs as envs
- from vllm import _custom_ops as ops
- from vllm.config import VllmConfig, get_current_vllm_config
- from vllm.logger import init_logger
-+from vllm.v1.outputs import ModelRunnerOutput
-
- logger = init_logger(__name__)
-
-@@ -107,3 +113,110 @@ def get_kv_connector_cache_layout():
- "layout to HND for better xfer performance.")
- return "HND"
- return "NHD"
-+
-+
-+class KVOutputAggregator:
-+ """Utility class to aggregate the output of all workers into a single
-+ output corresponding to Rank 0 for scheduler."""
-+
-+ def __init__(self, world_size: int):
-+ # Complete transfer tracker. Used by to track finished requests
-+ # [req_id -> n_finished_workers]
-+ self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
-+ self._send_remaining_count = defaultdict[str, int](lambda: world_size)
-+ self._dump_remaining_count = defaultdict[str, int](lambda: world_size)
-+
-+ def aggregate(self,
-+ outputs: list[ModelRunnerOutput],
-+ output_rank: int = 0) -> ModelRunnerOutput:
-+ # aggregate finished_sending, finished_recving from all workers
-+
-+ def update_finished_set(req_ids: Optional[set[str]],
-+ remaining_count_dict: dict[str, int],
-+ finished_set: set[str]) -> None:
-+ for req_id in req_ids or ():
-+ new_count = remaining_count_dict[req_id] - 1
-+ if new_count == 0:
-+ finished_set.add(req_id)
-+ del remaining_count_dict[req_id]
-+ else:
-+ remaining_count_dict[req_id] = new_count
-+
-+ def update_finished_list(req_ids: Optional[dict[str, list[str]]],
-+ remaining_count_dict: dict[str, int],
-+ finished_list: dict[str, list[str]]) -> None:
-+ for req_id, succeed_dump_blocks in (req_ids or {}).items():
-+ if req_id not in finished_list:
-+ finished_list[req_id] = []
-+ for blk_id in succeed_dump_blocks:
-+ new_count = remaining_count_dict[blk_id] - 1
-+ if new_count == 0:
-+ finished_list[req_id].append(blk_id)
-+ del remaining_count_dict[blk_id]
-+ else:
-+ remaining_count_dict[blk_id] = new_count
-+
-+ finished_sending = set[str]()
-+ finished_recving = set[str]()
-+ invalid_block_ids = set[int]()
-+ finished_dumping: dict[str, list[str]] = {}
-+ for output in outputs:
-+ update_finished_set(output.finished_sending,
-+ self._send_remaining_count, finished_sending)
-+ update_finished_set(output.finished_recving,
-+ self._recv_remaining_count, finished_recving)
-+ update_finished_list(output.finished_dumping,
-+ self._dump_remaining_count, finished_dumping)
-+ if output.invalid_block_ids:
-+ invalid_block_ids |= output.invalid_block_ids
-+
-+ # select output of the worker specified by output_rank
-+ output = outputs[output_rank]
-+
-+ # set the aggregated finished_sending / finished_recving
-+ # if output.finished_sending/recving is not empty, but the other ranks
-+ # still have unfinished send/recv, we want to set the aggregated
-+ # finished_sending/recving to None until all ranks have finished
-+ # send/recv
-+ output.finished_sending = finished_sending if finished_sending else None
-+ output.finished_recving = finished_recving if finished_recving else None
-+ output.finished_dumping = finished_dumping if finished_dumping else None
-+ output.invalid_block_ids = invalid_block_ids or None
-+
-+ return output
-+
-+ def async_aggregate(self,
-+ output_futures: Sequence[Future[ModelRunnerOutput]],
-+ output_rank: int = 0) -> Future[ModelRunnerOutput]:
-+ """Takes a list of futures and returns a single future which resolves
-+ to the respective list of outputs."""
-+ result_future: Future[ModelRunnerOutput] = Future()
-+
-+ outputs: list[Optional[ModelRunnerOutput]] = [None
-+ ] * len(output_futures)
-+
-+ def make_callback(idx):
-+
-+ def callback(fut):
-+ if result_future.done():
-+ return
-+
-+ try:
-+ outputs[idx] = fut.result()
-+ except CancelledError:
-+ result_future.cancel()
-+ except Exception as e:
-+ result_future.set_exception(e)
-+
-+ # this check assumes io_thread_pool uses a single thread
-+ if all(outputs):
-+ result_future.set_result(
-+ self.aggregate(cast(list[ModelRunnerOutput], outputs),
-+ output_rank))
-+
-+ return callback
-+
-+ for i, output_future in enumerate(output_futures):
-+ output_future.add_done_callback(make_callback(i))
-+
-+ return result_future
-diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
-index f80b5eba2..8891246e6 100644
---- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
-+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
-@@ -201,6 +201,15 @@ class KVConnectorBase_V1(ABC):
- """
- return None, None
-
-+ def get_block_ids_with_load_errors(self) -> Optional[set[int]]:
-+ """
-+ Get the set of block IDs that failed to load.
-+ Returns:
-+ Optional[set[int]]: A set of block IDs that encountered load errors.
-+ Returns None if no errors occurred during load.
-+ """
-+ return None
-+
- # ==============================
- # Scheduler-side methods
- # ==============================
-diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
-index 3c574d065..223106def 100644
---- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
-+++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
-@@ -2,7 +2,7 @@
- # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
- import hashlib
- import os
--from dataclasses import dataclass
-+from dataclasses import dataclass, field
- from typing import TYPE_CHECKING
-
- import safetensors
-@@ -53,10 +53,7 @@ class ReqMeta:
-
- @dataclass
- class SharedStorageConnectorMetadata(KVConnectorMetadata):
-- requests: list[ReqMeta]
--
-- def __init__(self):
-- self.requests = []
-+ requests: list[ReqMeta] = field(default_factory=list)
-
- def add_request(
- self,
-diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py
-index d21f94727..1800665c7 100644
---- a/vllm/v1/core/block_pool.py
-+++ b/vllm/v1/core/block_pool.py
-@@ -124,7 +124,7 @@ class BlockPool:
- kv_cache_group_id: The id of the KV cache group.
- hash_fn: The hash function to use for block hashes.
- """
-- if num_cached_blocks == num_full_blocks:
-+ if num_cached_blocks >= num_full_blocks:
- return
- new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
- assert len(block_hashes) >= num_cached_blocks
-diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
-index 6937455e7..c36a25bc5 100644
---- a/vllm/v1/core/kv_cache_manager.py
-+++ b/vllm/v1/core/kv_cache_manager.py
-@@ -3,7 +3,7 @@
-
- from collections import defaultdict
- from dataclasses import dataclass
--from typing import Optional
-+from typing import Optional, Union
-
- from vllm.distributed.kv_events import KVCacheEvent
- from vllm.logger import init_logger
-@@ -14,6 +14,8 @@ from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
- from vllm.v1.kv_cache_interface import KVCacheConfig
- from vllm.v1.metrics.stats import PrefixCacheStats
- from vllm.v1.request import Request, RequestStatus
-+from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
-+from ucm.sparse.base import INVALID_SLOT
-
- logger = init_logger(__name__)
-
-@@ -193,6 +195,7 @@ class KVCacheManager:
- num_draft_tokens: int = 0,
- num_lookahead_tokens: int = 0,
- delay_cache_blocks: bool = False,
-+ num_slots_sparsed: Union[None, int] = None
- ) -> Optional[KVCacheBlocks]:
- """Add slots for a request with new tokens to append.
-
-@@ -231,6 +234,12 @@ class KVCacheManager:
- """
- if num_new_tokens == 0:
- raise ValueError("num_new_tokens must be greater than 0")
-+ if num_slots_sparsed != INVALID_SLOT:
-+ return get_ucm_sparse().allocate_slots(request,
-+ num_slots_sparsed,
-+ self.coordinator,
-+ self.block_pool,
-+ self.kv_cache_config.kv_cache_groups)
-
- if new_computed_blocks is not None:
- new_computed_block_list = new_computed_blocks.blocks
-diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
-index d34f39327..141d750b3 100644
---- a/vllm/v1/core/sched/output.py
-+++ b/vllm/v1/core/sched/output.py
-@@ -155,3 +155,6 @@ class SchedulerOutput:
-
- # KV Cache Connector metadata.
- kv_connector_metadata: Optional[KVConnectorMetadata] = None
-+
-+ # modified slots by sparse algorithm
-+ req_sparsed_slots: dict[str, int] = None
-diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
-index fe552db74..6a9d4b4b9 100644
---- a/vllm/v1/core/sched/scheduler.py
-+++ b/vllm/v1/core/sched/scheduler.py
-@@ -34,6 +34,8 @@ from vllm.v1.outputs import ModelRunnerOutput
- from vllm.v1.request import Request, RequestStatus
- from vllm.v1.spec_decode.metrics import SpecDecodingStats
- from vllm.v1.structured_output import StructuredOutputManager
-+from ucm.sparse.state import ensure_ucm_sparse_initialized, get_ucm_sparse, has_ucm_sparse
-+from ucm.sparse.base import UcmSparseBase, UcmSparseRole, INVALID_SLOT
-
- logger = init_logger(__name__)
-
-@@ -79,12 +81,18 @@ class Scheduler(SchedulerInterface):
- # will have a corresponding KVConnector with Role=WORKER.
- # KV Connector pushes/pull of remote KVs for P/D and offloading.
- self.connector = None
-+ self.ucm_sparse = None
- if self.vllm_config.kv_transfer_config is not None:
- assert len(self.kv_cache_config.kv_cache_groups) == 1, (
- "Multiple KV cache groups are not currently supported "
- "with KV connectors")
- self.connector = KVConnectorFactory.create_connector_v1(
- config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
-+ # Initialize UCM Sparse if available
-+ if "ucm_sparse_config" in vllm_config.kv_transfer_config.kv_connector_extra_config:
-+ ensure_ucm_sparse_initialized(vllm_config, role=UcmSparseRole.SCHEDULER)
-+ self.ucm_sparse = get_ucm_sparse()
-+ logger.info("UCM Sparse initialized successfully: {}".format(self.ucm_sparse))
-
- self.kv_event_publisher = EventPublisherFactory.create(
- self.kv_events_config,
-@@ -201,8 +209,13 @@ class Scheduler(SchedulerInterface):
-
- # First, schedule the RUNNING requests.
- req_index = 0
-+ req_sparsed_slots: dict[str, int] = {}
- while req_index < len(self.running) and token_budget > 0:
- request = self.running[req_index]
-+ num_slots_sparsed = INVALID_SLOT
-+ if self.ucm_sparse:
-+ num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(request)
-+ req_sparsed_slots.update({request.request_id: num_slots_sparsed})
-
- num_new_tokens = (request.num_tokens_with_spec -
- request.num_computed_tokens)
-@@ -250,7 +263,8 @@ class Scheduler(SchedulerInterface):
- request,
- num_new_tokens,
- num_draft_tokens=num_draft_tokens,
-- num_lookahead_tokens=self.num_lookahead_tokens)
-+ num_lookahead_tokens=self.num_lookahead_tokens,
-+ num_slots_sparsed=num_slots_sparsed)
- if new_blocks is None:
- # The request cannot be scheduled.
- # Preempt the lowest-priority request.
-@@ -337,6 +351,10 @@ class Scheduler(SchedulerInterface):
- break
-
- request = self.waiting.peek_request()
-+ num_slots_sparsed = INVALID_SLOT
-+ if self.ucm_sparse:
-+ num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(request)
-+ req_sparsed_slots.update({request.request_id: num_slots_sparsed})
-
- # KVTransfer: skip request if still waiting for remote kvs.
- if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
-@@ -446,6 +464,7 @@ class Scheduler(SchedulerInterface):
- new_computed_blocks,
- num_lookahead_tokens=self.num_lookahead_tokens,
- delay_cache_blocks=load_kv_async,
-+ num_slots_sparsed=num_slots_sparsed
- )
- if new_blocks is None:
- # The request cannot be scheduled.
-@@ -559,6 +578,7 @@ class Scheduler(SchedulerInterface):
- scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
- scheduled_encoder_inputs=scheduled_encoder_inputs,
- num_common_prefix_blocks=num_common_prefix_blocks,
-+ req_sparsed_slots=req_sparsed_slots,
- # finished_req_ids is an existing state in the scheduler,
- # instead of being newly scheduled in this step.
- # It contains the request IDs that are finished in between
-@@ -745,16 +765,31 @@ class Scheduler(SchedulerInterface):
- num_scheduled_tokens = scheduler_output.num_scheduled_tokens
- pooler_outputs = model_runner_output.pooler_output
- num_nans_in_logits = model_runner_output.num_nans_in_logits
-+ invalid_block_ids = model_runner_output.invalid_block_ids
-
- new_running: list[Request] = []
- outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
- spec_decoding_stats: Optional[SpecDecodingStats] = None
-
-+ recovered_req_ids = None
-+ if invalid_block_ids:
-+ # These blocks contain externally computed tokens that failed to
-+ # load. Identify affected requests and adjust their computed token
-+ # count to trigger recomputation of the invalid blocks.
-+ recovered_req_ids = self._handle_invalid_blocks(invalid_block_ids)
-+
- # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
- # loop can be a performance bottleneck. We should do our best to avoid
- # expensive operations inside the loop.
- for request in self.running:
- req_id = request.request_id
-+ # self.req_meta.stage == SequenceStage.PREFILL and self.req_meta.is_last_chunk
-+
-+
-+ if recovered_req_ids and req_id in recovered_req_ids:
-+ # Skip requests that were recovered from KV load failure
-+ new_running.append(request)
-+ continue
- num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
- if num_tokens_scheduled == 0:
- # The request was not scheduled in this step.
-@@ -792,6 +827,13 @@ class Scheduler(SchedulerInterface):
- new_token_ids = generated_token_ids
- kv_transfer_params = None
-
-+ if model_runner_output.finished_dumping is not None:
-+ request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, []))
-+ is_prefill = request.num_output_tokens == 0
-+ is_last_chunk = (num_tokens_scheduled + request.num_computed_tokens >= request.num_prompt_tokens)
-+ if is_prefill and is_last_chunk:
-+ self.connector.connector.commit(request.succeed_dumped_blocks, True)
-+
- # Append generated tokens and check for stop. Note that if
- # a request is still being prefilled, we expect the model runner
- # to return empty token ids for the request.
-@@ -842,7 +884,6 @@ class Scheduler(SchedulerInterface):
- spec_token_ids[req_index])
- else:
- request.spec_token_ids = spec_token_ids[req_index]
--
- # Get prompt logprobs for this request.
- prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
- if new_token_ids or pooler_output is not None \
-@@ -869,6 +910,7 @@ class Scheduler(SchedulerInterface):
-
- if not stopped:
- new_running.append(request)
-+
- self.running = new_running
-
- # KV Connector: update state for finished KV Transfers.
-@@ -927,6 +969,8 @@ class Scheduler(SchedulerInterface):
- def add_request(self, request: Request) -> None:
- self.waiting.add_request(request)
- self.requests[request.request_id] = request
-+ if self.ucm_sparse:
-+ self.ucm_sparse.request_begin(request.request_id, request.prompt_token_ids)
- if self.log_stats:
- request.record_event(EngineCoreEventType.QUEUED)
-
-@@ -976,6 +1020,8 @@ class Scheduler(SchedulerInterface):
-
- def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
- assert request.is_finished()
-+ if self.ucm_sparse:
-+ self.ucm_sparse.request_finished_in_scheduler(request.request_id)
-
- delay_free_blocks, kv_xfer_params = self._connector_finished(request)
- self.encoder_cache_manager.free(request)
-@@ -1113,3 +1159,117 @@ class Scheduler(SchedulerInterface):
- for req_id in (model_runner_output.finished_sending or ()):
- logger.debug("Finished sending KV transfer for request %s", req_id)
- self._free_blocks(self.requests[req_id])
-+
-+ def _update_requests_with_invalid_blocks(
-+ self, requests: Iterable[Request],
-+ invalid_block_ids: set[int]) -> tuple[set[Request], int, set[int]]:
-+ affected_requests: set[Request] = set()
-+ num_tokens_to_reschedule = 0
-+ # If a block is invalid and shared by multiple requests in the batch,
-+ # all requests must be rescheduled, but only the first will recompute
-+ # it. This set tracks blocks already marked for recomputation.
-+ marked_invalid_block_ids: set[int] = set()
-+ for request in requests:
-+ is_affected = False
-+ marked_invalid_block = False
-+ req_id = request.request_id
-+ req_block_ids = self.kv_cache_manager.get_block_ids(req_id)[0]
-+ # We iterate only over blocks that may contain externally computed
-+ # tokens
-+ if request.num_cached_tokens > 0:
-+ req_num_computed_blocks = (request.num_cached_tokens +
-+ self.block_size -
-+ 1) // self.block_size
-+ else:
-+ req_num_computed_blocks = len(req_block_ids)
-+
-+ for idx, block_id in zip(range(req_num_computed_blocks),
-+ req_block_ids):
-+
-+ if block_id not in invalid_block_ids:
-+ continue
-+
-+ is_affected = True
-+
-+ if block_id in marked_invalid_block_ids:
-+ # This invalid block is shared with a previous request
-+ # and was already marked for recomputation.
-+ # This means this request can still consider this block
-+ # as computed when rescheduled.
-+ continue
-+
-+ marked_invalid_block_ids.add(block_id)
-+
-+ if marked_invalid_block:
-+ # This request has already marked an invalid block for
-+ # recomputation and updated its num_computed_tokens.
-+ continue
-+
-+ marked_invalid_block = True
-+ num_tokens_to_reschedule += request.num_computed_tokens
-+ request.num_computed_tokens = idx * self.block_size
-+ num_tokens_to_reschedule -= request.num_computed_tokens
-+
-+ if is_affected:
-+ if not marked_invalid_block:
-+ # All invalid blocks of this request are shared with
-+ # previous requests and will be recomputed by them.
-+ # Revert to considering only cached tokens as computed.
-+ num_tokens_to_reschedule += (request.num_computed_tokens -
-+ request.num_cached_tokens)
-+ request.num_computed_tokens = request.num_cached_tokens
-+
-+ affected_requests.add(request)
-+
-+ return (affected_requests, num_tokens_to_reschedule,
-+ marked_invalid_block_ids)
-+
-+ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
-+ total_requests_to_reschedule = 0
-+ total_tokens_to_reschedule = 0
-+
-+ # --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) ---
-+ async_load_reqs = (
-+ req for req in self.waiting
-+ if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
-+ (affected_requests, num_tokens_to_reschedule,
-+ marked_invalid_block_ids) = (
-+ self._update_requests_with_invalid_blocks(async_load_reqs,
-+ invalid_block_ids))
-+
-+ total_requests_to_reschedule += len(affected_requests)
-+ total_tokens_to_reschedule += num_tokens_to_reschedule
-+
-+ for request in affected_requests:
-+ if request.num_computed_tokens:
-+ # Cache any valid computed tokens.
-+ self.kv_cache_manager.cache_blocks(request,
-+ request.num_computed_tokens)
-+ else:
-+ # No valid computed tokens, release allocated blocks.
-+ # There may be a local cache hit on retry.
-+ self.kv_cache_manager.free(request)
-+
-+ request.status = RequestStatus.WAITING
-+
-+ # Remove async loaded invalid blocks already handled,
-+ # as they cannot be shared with running requests.
-+ invalid_block_ids.difference_update(marked_invalid_block_ids)
-+
-+ # --- Handle sync KV loads (running requests) ---
-+ affected_requests, num_tokens_to_reschedule, _ = (
-+ self._update_requests_with_invalid_blocks(self.running,
-+ invalid_block_ids))
-+
-+ total_requests_to_reschedule += len(affected_requests)
-+ total_tokens_to_reschedule += num_tokens_to_reschedule
-+
-+ if total_requests_to_reschedule:
-+ logger.info(
-+ "Recovered from KV load failure: "
-+ "%d request(s) rescheduled (%d tokens affected).",
-+ total_requests_to_reschedule, total_tokens_to_reschedule)
-+
-+ # Return the IDs of affected running requests to skip in
-+ # update_from_output.
-+ return {r.request_id for r in affected_requests}
-diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
-index 5b4718038..28bd4618a 100644
---- a/vllm/v1/core/single_type_kv_cache_manager.py
-+++ b/vllm/v1/core/single_type_kv_cache_manager.py
-@@ -142,6 +142,9 @@ class SingleTypeKVCacheManager(ABC):
- num_cached_blocks = self.num_cached_block[request.request_id]
- num_full_blocks = num_tokens // self.block_size
-
-+ if num_cached_blocks >= num_full_blocks:
-+ return
-+
- self.block_pool.cache_full_blocks(
- request=request,
- blocks=self.req_to_blocks[request.request_id],
-diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
-index b06b7cc80..61cd7110f 100644
---- a/vllm/v1/executor/multiproc_executor.py
-+++ b/vllm/v1/executor/multiproc_executor.py
-@@ -26,6 +26,7 @@ from vllm.distributed import (destroy_distributed_environment,
- destroy_model_parallel)
- from vllm.distributed.device_communicators.shm_broadcast import (Handle,
- MessageQueue)
-+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
- from vllm.executor.multiproc_worker_utils import (
- _add_prefix, set_multiprocessing_worker_envs)
- from vllm.logger import init_logger
-@@ -111,10 +112,14 @@ class MultiprocExecutor(Executor):
- if self.max_concurrent_batches > 1:
- # Note: must use only 1 IO thread to keep dequeue sequence
- # from the response queue
-+ # _async_aggregate_workers_output also assumes a single IO thread
- self.io_thread_pool = ThreadPoolExecutor(
- max_workers=1, thread_name_prefix="mp_exec_io")
-
- self.output_rank = self._get_output_rank()
-+ self.has_connector = self.vllm_config.kv_transfer_config is not None
-+ self.kv_output_aggregator = KVOutputAggregator(
-+ self.parallel_config.world_size)
-
- def start_worker_monitor(self):
- workers = self.workers
-@@ -155,13 +160,30 @@ class MultiprocExecutor(Executor):
- self,
- scheduler_output,
- ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
-- (output, ) = self.collective_rpc(
-+ non_block = self.max_concurrent_batches > 1
-+
-+ if not self.has_connector or self.vllm_config.model_config.use_mla:
-+ # get output only from a single worker (output_rank)
-+ (output, ) = self.collective_rpc(
-+ "execute_model",
-+ args=(scheduler_output, ),
-+ unique_reply_rank=self.output_rank,
-+ non_block=non_block,
-+ timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
-+ return output
-+
-+ # get output from all workers
-+ outputs = self.collective_rpc(
- "execute_model",
- args=(scheduler_output, ),
-- unique_reply_rank=self.output_rank,
-- non_block=self.max_concurrent_batches > 1,
-+ non_block=non_block,
- timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
-- return output
-+
-+ # aggregate all workers output to a single output
-+ if non_block:
-+ return self.kv_output_aggregator.async_aggregate(
-+ outputs, self.output_rank)
-+ return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
-
- def collective_rpc(self,
- method: Union[str, Callable],
-diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
-index f78623f57..c7b4100e3 100644
---- a/vllm/v1/outputs.py
-+++ b/vllm/v1/outputs.py
-@@ -107,6 +107,11 @@ class ModelRunnerOutput:
- # [req_ids]
- finished_sending: Optional[set[str]] = None
- finished_recving: Optional[set[str]] = None
-+ finished_dumping: Optional[dict[str, list[str]]] = None
-+
-+ # IDs of externally computed KV blocks that failed to load.
-+ # Requests referencing these blocks should be rescheduled to recompute them.
-+ invalid_block_ids: Optional[set[int]] = None
-
- # req_id -> num_nans_in_logits
- num_nans_in_logits: Optional[dict[str, int]] = None
-diff --git a/vllm/v1/request.py b/vllm/v1/request.py
-index 9b96f4599..825b77bba 100644
---- a/vllm/v1/request.py
-+++ b/vllm/v1/request.py
-@@ -102,7 +102,7 @@ class Request:
- # State
- # The number of tokens with prefix cache hits.
- self.num_cached_tokens = -1
--
-+ self.succeed_dumped_blocks: list[str] = []
- # The number of NaNs in logits. A value greater than 0
- # indicates that the output is corrupted
- self.num_nans_in_logits = 0
-diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
-index 8f4e8d64c..f45e39f5c 100644
---- a/vllm/v1/worker/block_table.py
-+++ b/vllm/v1/worker/block_table.py
-@@ -61,6 +61,15 @@ class BlockTable:
- self.num_blocks_per_row[row_idx] += num_blocks
- self.block_table_np[row_idx, start:start + num_blocks] = block_ids
-
-+ def reset_row(
-+ self,
-+ row_idx: int,
-+ ) -> None:
-+ self.num_blocks_per_row[row_idx] = 0
-+ self.block_table[row_idx].fill_(0)
-+ self.block_table_cpu[row_idx].fill_(0)
-+ self.block_table_np[row_idx].fill(0)
-+
- def add_row(self, block_ids: list[int], row_idx: int) -> None:
- self.num_blocks_per_row[row_idx] = 0
- self.append_row(block_ids, row_idx)
-@@ -117,6 +126,10 @@ class MultiGroupBlockTable:
- for i, block_table in enumerate(self.block_tables):
- block_table.append_row(block_ids[i], row_idx)
-
-+ def reset_row(self, row_idx: int) -> None:
-+ for i, block_table in enumerate(self.block_tables):
-+ block_table.reset_row(row_idx)
-+
- def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
- for i, block_table in enumerate(self.block_tables):
- block_table.add_row(block_ids[i], row_idx)
-diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
-index 1a79d72be..0e65c98f5 100644
---- a/vllm/v1/worker/gpu_input_batch.py
-+++ b/vllm/v1/worker/gpu_input_batch.py
-@@ -46,6 +46,11 @@ class CachedRequestState:
-
- def __post_init__(self):
- self.num_prompt_tokens = len(self.prompt_token_ids)
-+ # 'last_generator_offset' and 'last_gelen_last_output_token_ids' are
-+ # used to allow safe rollback in case a sampled token turns out to be
-+ # invalid (e.g., due to KV load errors).
-+ self.last_generator_offset = 0 if self.generator else None
-+ self.len_last_output_token_ids = len(self.output_token_ids)
-
- @property
- def num_tokens(self) -> int:
-@@ -201,6 +206,7 @@ class InputBatch:
- # NOTE(woosuk): The indices of the requests that do not have their own
- # generator should not be included in the dictionary.
- self.generators: dict[int, torch.Generator] = {}
-+ self.generators_last_offset: dict[int, int] = {}
-
- self.num_logprobs: dict[str, int] = {}
- # NOTE(rob): num_prompt_logprobs only includes reqs
-@@ -335,6 +341,9 @@ class InputBatch:
- # do not have their own generator.
- if request.generator is not None:
- self.generators[req_index] = request.generator
-+ assert (request.last_generator_offset is not None)
-+ self.generators_last_offset[
-+ req_index] = request.last_generator_offset
-
- if sampling_params.logprobs is not None:
- self.num_logprobs[req_id] = sampling_params.logprobs
-diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
-index 5a26e88db..17b3d1c79 100644
---- a/vllm/v1/worker/gpu_model_runner.py
-+++ b/vllm/v1/worker/gpu_model_runner.py
-@@ -72,6 +72,9 @@ from ..sample.logits_processor import LogitsProcessorManager
- from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
- sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
-
-+from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
-+from ucm.sparse.base import UcmSparseMetadata, INVALID_SLOT
-+
- if TYPE_CHECKING:
- import xgrammar as xgr
- import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501
-@@ -365,6 +368,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
- """
- # Remove finished requests from the cached states.
- for req_id in scheduler_output.finished_req_ids:
-+ self.ucm_sparse_request_finished_in_worker(req_id)
- self.requests.pop(req_id, None)
- self.encoder_cache.pop(req_id, None)
- # Remove the finished requests from the persistent batch.
-@@ -468,13 +472,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
- # Update the states of the running/resumed requests.
- is_last_rank = get_pp_group().is_last_rank
- req_data = scheduler_output.scheduled_cached_reqs
-+ req_sparsed_slots = scheduler_output.req_sparsed_slots
- for i, req_id in enumerate(req_data.req_ids):
- req_state = self.requests[req_id]
- num_computed_tokens = req_data.num_computed_tokens[i]
- new_block_ids = req_data.new_block_ids[i]
- resumed_from_preemption = req_data.resumed_from_preemption[i]
-+ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT
-
- # Update the cached states.
-+ if (num_computed_tokens <= req_state.num_computed_tokens):
-+ # The request was rescheduled after a KV load failure. Clear
-+ # the last sampled tokens and rewind the generator state
-+ len_output_token_ids = len(req_state.output_token_ids)
-+ del req_state.output_token_ids[req_state.
-+ len_last_output_token_ids:]
-+ if req_state.generator:
-+ req_state.generator.set_offset(
-+ req_state.last_generator_offset)
-+ req_index = self.input_batch.req_id_to_index.get(req_id)
-+ if req_index is not None:
-+ len_last_sampled = (len_output_token_ids -
-+ req_state.len_last_output_token_ids)
-+ end_idx = self.input_batch.num_tokens_no_spec[
-+ req_index] - len_last_sampled
-+ self.input_batch.num_tokens[req_index] = end_idx
-+ self.input_batch.num_tokens_no_spec[req_index] = end_idx
-+
- req_state.num_computed_tokens = num_computed_tokens
-
- if not is_last_rank:
-@@ -493,16 +517,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
- req_state.output_token_ids.extend(
- new_token_ids[-num_new_tokens:])
-
-+ req_state.len_last_output_token_ids = len(
-+ req_state.output_token_ids)
-+ if req_state.generator:
-+ req_state.last_generator_offset = (
-+ req_state.generator.get_offset())
-+
- # Update the block IDs.
-- if not resumed_from_preemption:
-- # Append the new blocks to the existing block IDs.
-- for block_ids, new_ids in zip(req_state.block_ids,
-- new_block_ids):
-- block_ids.extend(new_ids)
-- else:
-+ if resumed_from_preemption or is_sparsed_request:
- # The request is resumed from preemption.
- # Replace the existing block IDs with the new ones.
- req_state.block_ids = new_block_ids
-+ else:
-+ # Append the new blocks to the existing block IDs.
-+ for block_ids, new_ids in zip(req_state.block_ids,
-+ new_block_ids):
-+ block_ids.extend(new_ids)
-
- req_index = self.input_batch.req_id_to_index.get(req_id)
- if req_index is None:
-@@ -512,9 +542,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
- req_ids_to_add.append(req_id)
- continue
-
-+ if req_state.generator:
-+ assert (req_state.last_generator_offset is not None)
-+ self.input_batch.generators_last_offset[
-+ req_index] = req_state.last_generator_offset
-+
- # Update the persistent batch.
- self.input_batch.num_computed_tokens_cpu[req_index] = (
- num_computed_tokens)
-+ if is_sparsed_request:
-+ self.input_batch.block_table.reset_row(req_index)
- self.input_batch.block_table.append_row(new_block_ids, req_index)
-
- # For the last rank, we don't need to update the token_ids_cpu
-@@ -623,6 +660,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
- if self.uses_mrope:
- self._calc_mrope_positions(scheduler_output)
-
-+ self.seq_lens_np[:num_reqs] = (
-+ self.input_batch.num_computed_tokens_cpu[:num_reqs] +
-+ num_scheduled_tokens)
-+
-+ # TODO: improve performance, no `positions_np.copy()`
-+ sparsed_positions = positions_np.copy()
-+ req_sparsed_slots = scheduler_output.req_sparsed_slots
-+ for req_id in self.input_batch.req_id_to_index:
-+ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT
-+ req_index = self.input_batch.req_id_to_index[req_id]
-+ offset = 0 if req_index == 0 else cu_num_tokens[req_index - 1] # TODO: support MTP
-+ if is_sparsed_request:
-+ sparsed_positions[offset] = req_sparsed_slots[req_id] - 1
- # Get token indices.
- # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
- # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
-@@ -652,11 +702,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
- # block_size.
- block_table_indices = (
- req_indices * block_table.max_num_blocks_per_req +
-- positions_np // block_size)
-+ sparsed_positions // block_size)
- block_table_cpu = block_table.get_cpu_tensor()
- block_numbers = block_table_cpu.flatten(
- )[block_table_indices].numpy()
-- block_offsets = positions_np % block_size
-+ block_offsets = sparsed_positions % block_size
- np.add(
- block_numbers * block_size,
- block_offsets,
-@@ -666,9 +716,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
- self.query_start_loc_np[0] = 0
- self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
-
-- self.seq_lens_np[:num_reqs] = (
-- self.input_batch.num_computed_tokens_cpu[:num_reqs] +
-- num_scheduled_tokens)
-+ for req_id in self.input_batch.req_id_to_index:
-+ req_index = self.input_batch.req_id_to_index[req_id]
-+ is_sparsed_request = scheduler_output.req_sparsed_slots[req_id] != INVALID_SLOT
-+ if is_sparsed_request:
-+ self.seq_lens_np[req_index] = scheduler_output.req_sparsed_slots[req_id]
-
- # Copy the tensors to the GPU.
- self.input_ids[:total_num_scheduled_tokens].copy_(
-@@ -680,6 +732,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
- non_blocking=True)
- else:
- # Common case (1D positions)
-+ self.positions_cpu[:total_num_scheduled_tokens] = torch.from_numpy(
-+ positions_np[:total_num_scheduled_tokens])
- self.positions[:total_num_scheduled_tokens].copy_(
- self.positions_cpu[:total_num_scheduled_tokens],
- non_blocking=True)
-@@ -1370,6 +1424,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
- skip_cuda_graphs=skip_cuda_graphs,
- ):
- self.maybe_setup_kv_connector(scheduler_output)
-+ self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata)
-
- model_output = self.model(
- input_ids=input_ids,
-@@ -1378,9 +1433,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
- inputs_embeds=inputs_embeds,
- )
-
-- self.maybe_wait_for_kv_save()
-+ finished_dumping = self.maybe_wait_for_kv_save()
-+ self.maybe_execute_ucm_sparse_finished()
-+
- finished_sending, finished_recving = (
- self.get_finished_kv_transfers(scheduler_output))
-+ invalid_block_ids = self.get_block_ids_with_load_errors()
-
- if self.use_aux_hidden_state_outputs:
- hidden_states, aux_hidden_states = model_output
-@@ -1474,7 +1532,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
- # This relies on cuda-specific torch-internal impl details
- generator = self.input_batch.generators.get(i)
- if generator is not None:
-- generator.set_offset(generator.get_offset() - 4)
-+ generator.set_offset(
-+ self.input_batch.generators_last_offset.get(i))
- # Record the index of the request that should not be sampled,
- # so that we could clear the sampled tokens before returning.
- discard_sampled_tokens_req_indices.append(i)
-@@ -1563,6 +1622,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
- finished_sending=finished_sending,
- finished_recving=finished_recving,
- num_nans_in_logits=num_nans_in_logits,
-+ finished_dumping=finished_dumping,
-+ invalid_block_ids = invalid_block_ids
- )
-
- def propose_draft_token_ids(
-@@ -1693,13 +1754,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
- self.maybe_setup_kv_connector(scheduler_output)
- finished_sending, finished_recving = (
- self.get_finished_kv_transfers(scheduler_output))
-+ invalid_block_ids = self.get_block_ids_with_load_errors()
-+ get_kv_transfer_group().clear_connector_metadata()
-
-- if not finished_sending and not finished_recving:
-+ if not finished_sending and not finished_recving and not invalid_block_ids:
- return EMPTY_MODEL_RUNNER_OUTPUT
-
- output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
- output.finished_sending = finished_sending
- output.finished_recving = finished_recving
-+ output.invalid_block_ids = invalid_block_ids
- return output
-
- @staticmethod
-@@ -1719,9 +1783,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
- kv_connector.start_load_kv(get_forward_context())
-
- @staticmethod
-- def maybe_wait_for_kv_save() -> None:
-+ def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]:
- if has_kv_transfer_group():
-- get_kv_transfer_group().wait_for_save()
-+ return get_kv_transfer_group().wait_for_save()
-+
-+ def maybe_execute_ucm_sparse_begin(self, scheduler_output: "SchedulerOutput", attn_metadata: CommonAttentionMetadata):
-+ if not has_ucm_sparse():
-+ return
-+ ucm_sparse = get_ucm_sparse()
-+ ucm_sparse.build_sparse_meta(scheduler_output, self.requests, self.input_batch, attn_metadata)
-+ ucm_sparse.execute_begin(scheduler_output)
-+
-+ def maybe_execute_ucm_sparse_finished(self):
-+ if not has_ucm_sparse():
-+ return
-+ ucm_sparse = get_ucm_sparse()
-+ ucm_sparse.execute_finished()
-+
-+ def ucm_sparse_request_finished_in_worker(self, request_id: str | int):
-+ if not has_ucm_sparse():
-+ return
-+ ucm_sparse = get_ucm_sparse()
-+ ucm_sparse.request_finished_in_worker(request_id)
-
- @staticmethod
- def get_finished_kv_transfers(
-@@ -1732,6 +1815,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
- scheduler_output.finished_req_ids)
- return None, None
-
-+ def get_block_ids_with_load_errors(self) -> Optional[set[int]]:
-+ if has_kv_transfer_group():
-+ return get_kv_transfer_group().get_block_ids_with_load_errors()
-+ return None
-+
- def propose_ngram_draft_token_ids(
- self,
- sampled_token_ids: list[list[int]],
-diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
-index 9e7e44d06..d52a49a2e 100644
---- a/vllm/v1/worker/gpu_worker.py
-+++ b/vllm/v1/worker/gpu_worker.py
-@@ -1,6 +1,7 @@
- # SPDX-License-Identifier: Apache-2.0
- # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
- """A GPU worker class."""
-+import copy
- import gc
- import os
- from typing import TYPE_CHECKING, Optional
-@@ -15,7 +16,8 @@ from vllm.device_allocator.cumem import CuMemAllocator
- from vllm.distributed import (ensure_model_parallel_initialized,
- init_distributed_environment,
- set_custom_all_reduce)
--from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
-+from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
-+ has_kv_transfer_group)
- from vllm.distributed.parallel_state import get_pp_group, get_tp_group
- from vllm.logger import init_logger
- from vllm.lora.request import LoRARequest
-@@ -24,10 +26,11 @@ from vllm.platforms import current_platform
- from vllm.sequence import IntermediateTensors
- from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
- from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
--from vllm.v1.outputs import ModelRunnerOutput
-+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
- from vllm.v1.utils import report_usage_stats
- from vllm.v1.worker.gpu_model_runner import GPUModelRunner
- from vllm.v1.worker.worker_base import WorkerBase
-+from ucm.sparse.state import ensure_ucm_sparse_initialized
-
- logger = init_logger(__name__)
-
-@@ -313,9 +316,22 @@ class Worker(WorkerBase):
- assert isinstance(output, IntermediateTensors)
- get_pp_group().send_tensor_dict(output.tensors,
- all_gather_group=get_tp_group())
-- return None
-+ if not has_kv_transfer_group():
-+ return None
-+
-+ # In case of PP with kv transfer, we need to pass through the
-+ # finished_sending and finished_recving buffers.
-+ new_output = EMPTY_MODEL_RUNNER_OUTPUT
-+ if output.finished_sending or output.finished_recving or output.finished_dumping or output.invalid_block_ids:
-+ new_output = copy.copy(new_output)
-+ new_output.finished_sending = output.finished_sending
-+ new_output.finished_recving = output.finished_recving
-+ new_output.finished_dumping = output.finished_dumping
-+ new_output.invalid_block_ids = output.invalid_block_ids
-+ output = new_output
-+
- assert isinstance(output, ModelRunnerOutput)
-- return output if self.is_driver_worker else None
-+ return output
-
- def profile(self, is_start: bool = True):
- if self.profiler is None:
-@@ -386,6 +402,7 @@ def init_worker_distributed_environment(
- parallel_config.pipeline_parallel_size)
-
- ensure_kv_transfer_initialized(vllm_config)
-+ ensure_ucm_sparse_initialized(vllm_config)
-
-
- def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
---
-2.50.1.windows.1
-diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
-index 6a9d4b4b9..ae06cf9eb 100644
---- a/vllm/v1/core/sched/scheduler.py
-+++ b/vllm/v1/core/sched/scheduler.py
-@@ -830,9 +830,8 @@ class Scheduler(SchedulerInterface):
- if model_runner_output.finished_dumping is not None:
- request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, []))
- is_prefill = request.num_output_tokens == 0
-- is_last_chunk = (num_tokens_scheduled + request.num_computed_tokens >= request.num_prompt_tokens)
-- if is_prefill and is_last_chunk:
-- self.connector.connector.commit(request.succeed_dumped_blocks, True)
-+ if is_prefill:
-+ self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True)
-
- # Append generated tokens and check for stop. Note that if
- # a request is still being prefilled, we expect the model runner
diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch b/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch
index e15f7ab52..8c459aa7f 100644
--- a/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch
+++ b/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch
@@ -1,16 +1,17 @@
-From 67b10fc431e5aac0155ca5b77cd9a99e35656521 Mon Sep 17 00:00:00 2001
+From 73de421dd3a9d3877b8903b8ee419e692da62b29 Mon Sep 17 00:00:00 2001
From: wenxinwang
-Date: Thu, 25 Sep 2025 05:31:48 -0700
-Subject: [PATCH] UCM adaptor
+Date: Mon, 10 Nov 2025 20:44:02 +0800
+Subject: [PATCH] Adapt UCM sparse attention to DeepSeek (MLA)
---
- vllm_ascend/attention/attention_v1.py | 75 ++++++++++++++++++++
+ vllm_ascend/attention/attention_v1.py | 76 ++++++++++++++++++++
+ vllm_ascend/attention/mla_v1.py | 14 +++-
vllm_ascend/worker/model_runner_v1.py | 99 +++++++++++++++++++++++----
vllm_ascend/worker/worker_v1.py | 25 +++++--
- 3 files changed, 183 insertions(+), 16 deletions(-)
+ 4 files changed, 196 insertions(+), 18 deletions(-)
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
-index 7d7f488..09c4345 100644
+index 7d7f488f..18039f42 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -24,6 +24,9 @@ import torch_npu
@@ -26,10 +27,10 @@ index 7d7f488..09c4345 100644
@@ -33,6 +36,8 @@ from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d, nd_to_nz_spec)
-
+
+from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
+
-
+
class AscendAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@@ -444,10 +449,14 @@ def unified_ascend_attention_with_output(
@@ -42,19 +43,20 @@ index 7d7f488..09c4345 100644
attn_metadata = forward_context.attn_metadata
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
-+
-+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
++ if not self.use_mla:
++ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
self.impl.forward(self,
query,
key,
-@@ -456,8 +465,74 @@ def unified_ascend_attention_with_output(
+@@ -456,8 +465,75 @@ def unified_ascend_attention_with_output(
attn_metadata,
output,
trace_flag=False)
-+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
++ if not self.use_mla:
++ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
+ maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return
-
+
+def wait_for_kv_layer_from_connector(layer_name: str):
+ if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+ return
@@ -119,11 +121,67 @@ index 7d7f488..09c4345 100644
+ return
+
+ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context)
-
+
def unified_attention_with_output_fake(
query: torch.Tensor,
+diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
+index f50fe56e..4a27c22f 100644
+--- a/vllm_ascend/attention/mla_v1.py
++++ b/vllm_ascend/attention/mla_v1.py
+@@ -13,10 +13,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size
+ from vllm.model_executor.layers.linear import (LinearBase,
+ UnquantizedLinearMethod)
+ from vllm.utils import cdiv, round_down
++from vllm.forward_context import ForwardContext, get_forward_context
++from vllm.attention.layer import (maybe_execute_sparse_attention_begin, maybe_execute_sparse_attention_finished)
+
+ from vllm_ascend.ascend_config import get_ascend_config
+ from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
+-from vllm_ascend.attention.attention_v1 import AscendAttentionState
++from vllm_ascend.attention.attention_v1 import AscendAttentionState, wait_for_kv_layer_from_connector, maybe_save_kv_layer_to_connector
+ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
+ from vllm_ascend.multistream.context import get_multistream_comm_context
+ from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
+@@ -1042,6 +1044,7 @@ class AscendMLAImpl(MLAAttentionImpl):
+ enable_multistream_mla: bool = False,
+ ckq: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
++ forward_context: ForwardContext = get_forward_context()
+ assert output is not None, "Output tensor must be provided."
+ if attn_metadata is None:
+ # Profiling run.
+@@ -1192,6 +1195,8 @@ class AscendMLAImpl(MLAAttentionImpl):
+ # FIX: aicore move should be also placed on the comm stream in dbo,
+ # otherwise it may affect the accuracy
+ # TODO: use an elegant way to overlap
++ wait_for_kv_layer_from_connector(layer.layer_name)
++ maybe_execute_sparse_attention_begin(prefill_q, prefill_k_c_normed, prefill_k_pe, layer.layer_name, forward_context, "prefill")
+ output_prefill = self._forward_prefill(prefill_q,
+ prefill_k_c_normed,
+ prefill_k_pe, kv_cache,
+@@ -1203,8 +1208,11 @@ class AscendMLAImpl(MLAAttentionImpl):
+ current_ms_metadata.after_comm_event.record()
+ else:
+ output[num_decode_tokens:] = output_prefill
+-
++ maybe_execute_sparse_attention_finished(prefill_q, prefill_k_c_normed, prefill_k_pe, output[num_decode_tokens:], layer.layer_name, forward_context, "prefill")
++ maybe_save_kv_layer_to_connector(layer.layer_name, kv_cache)
+ if has_decode:
++ wait_for_kv_layer_from_connector(layer.layer_name)
++ maybe_execute_sparse_attention_begin(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, layer.layer_name, forward_context, "decode")
+ if self.running_in_graph:
+ return self._forward_decode(decode_ql_nope, decode_q_pe,
+ decode_k_nope, decode_k_pe,
+@@ -1223,5 +1231,7 @@ class AscendMLAImpl(MLAAttentionImpl):
+ current_ms_metadata.after_comm_event.record()
+ else:
+ output[:num_decode_tokens] = output_decode
++ maybe_execute_sparse_attention_finished(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, output[:num_decode_tokens], layer.layer_name, forward_context, "decode")
++ maybe_save_kv_layer_to_connector(layer.layer_name, kv_cache)
+
+ return output_padded
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
-index eabcdbc..e51f46e 100644
+index eabcdbcc..179dffde 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -39,7 +39,10 @@ from vllm.config import CompilationLevel, VllmConfig
@@ -141,7 +199,7 @@ index eabcdbc..e51f46e 100644
@@ -88,6 +91,9 @@ from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
-
+
+from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
+from ucm.sparse.base import UcmSparseMetadata, INVALID_SLOT
+
@@ -157,7 +215,7 @@ index eabcdbc..e51f46e 100644
self.encoder_cache.pop(req_id, None)
# Remove the finished requests from the persistent batch.
@@ -453,12 +460,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
-
+
# Update the states of the running/resumed requests.
req_data = scheduler_output.scheduled_cached_reqs
+ req_sparsed_slots = scheduler_output.req_sparsed_slots
@@ -168,7 +226,7 @@ index eabcdbc..e51f46e 100644
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]
+ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT
-
+
req_state.num_computed_tokens = num_computed_tokens
if not is_last_rank:
@@ -474,15 +483,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
@@ -189,18 +247,18 @@ index eabcdbc..e51f46e 100644
- # The request is resumed from preemption.
- # Replace the existing block IDs with the new ones.
- req_state.block_ids = new_block_ids
-
+
req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None:
@@ -496,6 +505,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens)
-
+
+ if is_sparsed_request:
+ self.input_batch.block_table.reset_row(req_index)
+
self.input_batch.block_table.append_row(new_block_ids, req_index)
-
+
if not is_last_rank:
@@ -876,7 +888,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors: Optional[IntermediateTensors] = None,
@@ -215,7 +273,7 @@ index eabcdbc..e51f46e 100644
@@ -955,12 +968,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens)
seq_lens = self.seq_lens_cpu[:num_reqs]
-
+
+ # TODO: improve performance, no `positions_np.copy()`
+ sparsed_positions = positions_np.copy()
+ req_sparsed_slots = scheduler_output.req_sparsed_slots
@@ -229,7 +287,7 @@ index eabcdbc..e51f46e 100644
block_table_indices = (req_indices * self.max_num_blocks_per_req +
- positions_np // self.block_size)
+ sparsed_positions // self.block_size)
-
+
block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
- block_offsets = positions_np % self.block_size
@@ -240,7 +298,7 @@ index eabcdbc..e51f46e 100644
@@ -985,10 +1008,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
else:
attn_state = AscendAttentionState.PrefillCacheHit
-
+
+ for req_id in self.input_batch.req_id_to_index:
+ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT
+ req_index = self.input_batch.req_id_to_index[req_id]
@@ -254,10 +312,10 @@ index eabcdbc..e51f46e 100644
+ position=torch.tensor(sparsed_positions).npu(),
attn_state=attn_state)
self.attn_state = attn_state # type: ignore
-
+
@@ -1100,6 +1129,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions = self.positions[:padded_batch_size]
-
+
# Run forward pass
+ finished_dumping = None
with set_forward_context(attn_metadata,
@@ -269,7 +327,7 @@ index eabcdbc..e51f46e 100644
ACL_FORMAT_FRACTAL_ND)
+ self.maybe_setup_kv_connector(scheduler_output)
+ self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata)
-
+
hidden_states = self.model(
input_ids=input_ids,
@@ -1133,6 +1165,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
@@ -278,16 +336,16 @@ index eabcdbc..e51f46e 100644
)
+ finished_dumping = self.maybe_wait_for_kv_save()
+ self.maybe_execute_ucm_sparse_finished()
-
+
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -1163,7 +1197,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
-
+
return (attn_metadata, hidden_states, spec_decode_metadata, positions,
total_num_scheduled_tokens, logits_indices, aux_hidden_states,
- num_scheduled_tokens)
+ num_scheduled_tokens, finished_dumping)
-
+
def _get_cumsum_and_arange(
self,
@@ -1400,7 +1434,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
@@ -297,7 +355,7 @@ index eabcdbc..e51f46e 100644
- num_scheduled_tokens_np) = (self._process_reqs(
+ num_scheduled_tokens_np, finished_dumping) = (self._process_reqs(
scheduler_output, intermediate_tensors))
-
+
with ProfileExecuteDuration().capture_async("post process"):
@@ -1561,6 +1595,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logprobs=logprobs_lists,
@@ -305,7 +363,7 @@ index eabcdbc..e51f46e 100644
pooler_output=[],
+ finished_dumping=finished_dumping
)
-
+
durations = ProfileExecuteDuration().pop_captured_sync()
@@ -2369,3 +2404,43 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if batch_size <= padded_batch_size < selected_batch_size:
@@ -353,16 +411,16 @@ index eabcdbc..e51f46e 100644
+ ucm_sparse.request_finished_in_worker(request_id)
\ No newline at end of file
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
-index df03d50..a854923 100644
+index df03d508..5d5d9b5a 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -17,6 +17,7 @@
# Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
#
-
+
+import copy
from typing import Optional
-
+
import torch
@@ -27,7 +28,8 @@ from vllm import envs
from vllm.config import VllmConfig
@@ -381,15 +439,15 @@ index df03d50..a854923 100644
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerBase
-
+
import vllm_ascend.envs as envs_ascend
@@ -49,6 +51,7 @@ from vllm_ascend.utils import (check_kv_cache_bytes_cache_exist,
read_kv_cache_bytes_from_file,
sleep_mode_enabled, try_register_lib)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
+from ucm.sparse.state import ensure_ucm_sparse_initialized
-
-
+
+
class NPUWorker(WorkerBase):
@@ -222,9 +225,22 @@ class NPUWorker(WorkerBase):
assert isinstance(output, IntermediateTensors)
@@ -413,7 +471,7 @@ index df03d50..a854923 100644
assert isinstance(output, ModelRunnerOutput)
- return output if self.is_driver_worker else None
+ return output
-
+
def load_model(self) -> None:
if self.vllm_config.model_config.enable_sleep_mode:
@@ -321,6 +337,7 @@ class NPUWorker(WorkerBase):
@@ -421,8 +479,9 @@ index df03d50..a854923 100644
)
ensure_kv_transfer_initialized(self.vllm_config)
+ ensure_ucm_sparse_initialized(self.vllm_config)
-
+
def _init_profiler(self):
# Torch profiler. Enabled and configured through env vars:
---
+--
2.34.1
+
diff --git a/ucm/integration/vllm/patch/__init__.py b/ucm/integration/vllm/patch/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/ucm/integration/vllm/patch/apply_patch.py b/ucm/integration/vllm/patch/apply_patch.py
new file mode 100644
index 000000000..39f5ccbb0
--- /dev/null
+++ b/ucm/integration/vllm/patch/apply_patch.py
@@ -0,0 +1,175 @@
+#
+# MIT License
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+"""
+Monkey patching module for vLLM to apply UCM patches automatically.
+This replaces the need for manual `git apply` commands.
+"""
+
+import os
+import sys
+from typing import Optional
+
+from ucm.logger import init_logger
+
+logger = init_logger(__name__)
+
+PLATFORM = os.getenv("PLATFORM")
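+# NOTE: PLATFORM is read from the environment; when it is set to "ascend",
+# _apply_patches_v092() additionally applies the vLLM-Ascend patches via
+# _patch_ascend() / _apply_ascend_patch().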
+
+
+def _patch_ascend() -> bool:
+ return PLATFORM == "ascend"
+
+
+# Track if patches have been applied
+_patches_applied = False
+_import_hook_installed = False
+_vllm_version: Optional[str] = None
+_vllm_import_hook = None
+
+
+def get_vllm_version() -> Optional[str]:
+ """Detect vLLM version."""
+ global _vllm_version
+ if _vllm_version is not None:
+ return _vllm_version
+
+ try:
+ # Try to get version from vllm module
+ import vllm as vllm_pkg
+
+        # Cache the detected version so later calls return it without re-importing.
+        _vllm_version = vllm_pkg.__version__
+        return _vllm_version
+ except ImportError:
+ logger.warning("vLLM is not installed")
+ return None
+ except Exception as e:
+ logger.warning(f"Failed to detect vLLM version: {e}")
+ return None
+
+
+def get_supported_versions() -> list[str]:
+ """Get list of supported vLLM versions."""
+ return ["0.9.2"]
+
+
+def apply_all_patches() -> None:
+ """Apply all vLLM patches based on detected version."""
+ global _patches_applied
+ if _patches_applied:
+ return
+
+ try:
+ version = get_vllm_version()
+ if version is None:
+ raise ValueError("Could not detect vLLM version")
+
+ supported_versions = get_supported_versions()
+ if version not in supported_versions:
+ logger.warning(
+ f"vLLM version {version} is not explicitly supported to apply UCM patches. "
+ f"Supported versions: {', '.join(supported_versions)}. "
+ )
+
+ # Apply version-specific patches
+ match version:
+ case "0.9.2":
+ _apply_patches_v092()
+ case _:
+ logger.warning(
+ f"Unsupported vLLM version: {version} to apply UCM patches. "
+ f"Supported versions: {', '.join(supported_versions)}."
+ )
+
+ _patches_applied = True
+ logger.info(f"All vLLM patches applied successfully for version {version}")
+ except Exception as e:
+ logger.error(f"Failed to apply vLLM patches: {e}", exc_info=True)
+ raise
+
+
+def _apply_patches_v092() -> None:
+ """Apply patches for vLLM 0.9.2."""
+ from .patch_funcs.v092.vllm_patch import _apply_sparse_adapt
+
+ _apply_sparse_adapt() # apply vllm-sparse-adapt.patch
+ if _patch_ascend():
+ from .patch_funcs.v092.vllm_ascend_patch import _apply_ascend_patch
+
+ _apply_ascend_patch() # apply vllm-ascend-adapt.patch
+
+
+def install_import_hook() -> None:
+ """Install an import hook to automatically apply patches when vLLM is imported."""
+ global _import_hook_installed, _vllm_import_hook
+
+ if _import_hook_installed:
+ return
+
+ try:
+ # Check if vLLM is already imported
+ if "vllm" in sys.modules:
+ # vLLM already imported, apply patches immediately
+ apply_all_patches()
+ _import_hook_installed = True
+ else:
+ # Install import hook by wrapping the builtin __import__ function
+ # This intercepts all imports and applies patches when vLLM is imported
+ import builtins
+
+ original_import = builtins.__import__
+
+ def import_hook(name, globals=None, locals=None, fromlist=(), level=0):
+ # Call original import
+ module = original_import(name, globals, locals, fromlist, level)
+
+ # If the main vLLM module is being imported, apply patches
+ # We only check for 'vllm' (not submodules) to avoid multiple patch attempts
+ if name == "vllm" and not _patches_applied:
+ try:
+ apply_all_patches()
+ except Exception as e:
+ logger.warning(f"Failed to apply patches during import: {e}")
+
+ return module
+
+ # Replace builtin __import__
+ builtins.__import__ = import_hook
+ _vllm_import_hook = import_hook
+ _import_hook_installed = True
+ logger.debug("Import hook installed to intercept vLLM imports")
+
+ except Exception as e:
+ logger.warning(f"Failed to install import hook: {e}")
+
+
+def ensure_patches_applied() -> None:
+ """Ensure patches are applied, installing import hook if needed."""
+ if not _patches_applied:
+ # Try to apply patches immediately
+ try:
+ apply_all_patches()
+ except Exception:
+ # If it fails (vLLM not imported yet), install hook
+ install_import_hook()
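+
+
+# Illustrative usage sketch (not part of the module's public API): callers are
+# expected to call ensure_patches_applied() before importing vLLM, or rely on
+# the import hook to patch lazily. The import path below assumes this module
+# is importable as ucm.integration.vllm.patch.apply_patch.
+#
+#   from ucm.integration.vllm.patch.apply_patch import ensure_patches_applied
+#
+#   ensure_patches_applied()  # patches now, or installs the vLLM import hook
+#   import vllm               # UCM patches are in place once vLLM is imported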
diff --git a/ucm/integration/vllm/patch/patch_funcs/__init__.py b/ucm/integration/vllm/patch/patch_funcs/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/ucm/integration/vllm/patch/patch_funcs/v092/__init__.py b/ucm/integration/vllm/patch/patch_funcs/v092/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py
new file mode 100644
index 000000000..f3927ece9
--- /dev/null
+++ b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py
@@ -0,0 +1,1446 @@
+#
+# MIT License
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+
+from __future__ import annotations
+
+import os
+
+from ucm.logger import init_logger
+
+logger = init_logger(__name__)
+
+ENABLE_SPARSE = os.getenv("ENABLE_SPARSE")
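+# NOTE: the sparse-attention monkey patches in this module are only applied
+# when the ENABLE_SPARSE environment variable is set to "true"
+# (case-insensitive); see _enable_sparse() and _apply_ascend_patch().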
+
+
+def _enable_sparse() -> bool:
+ return ENABLE_SPARSE is not None and ENABLE_SPARSE.lower() == "true"
+
+
+def _apply_ascend_patch() -> None:
+ """Apply patch for vLLM-Ascend."""
+ try:
+ from vllm_ascend.patch import platform, worker
+
+ if _enable_sparse():
+ _patch_attention_v1()
+ _patch_mla_v1()
+ _patch_model_runner_v1()
+ _patch_worker_v1()
+ logger.info("UCM sparse adapt patches applied successfully")
+
+ except Exception as e:
+ logger.error(f"Could not apply sparse adapt patches: {e}")
+ raise e
+
+
+# ========================= vllm_ascend/attention/attention_v1.py =========================
+def _patch_attention_v1() -> None:
+ """Patch attention_v1.py for vLLM-Ascend."""
+ try:
+ from typing import List
+
+ import torch
+ from vllm.forward_context import ForwardContext, get_forward_context
+ from vllm_ascend.attention import attention_v1
+
+ from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
+
+ def maybe_execute_sparse_attention_begin(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ layer_name: str,
+ forward_context: ForwardContext,
+ ):
+ if not has_ucm_sparse():
+ return
+
+ ucm_sparse = get_ucm_sparse()
+ attn_metadata = forward_context.attn_metadata
+ if attn_metadata is None:
+ return
+ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context)
+
+ attention_v1.maybe_execute_sparse_attention_begin = (
+ maybe_execute_sparse_attention_begin
+ )
+
+ def maybe_execute_sparse_attention_finished(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_output: torch.Tensor,
+ layer_name: str,
+ forward_context: ForwardContext,
+ ):
+ if not has_ucm_sparse():
+ return
+ ucm_sparse = get_ucm_sparse()
+ attn_metadata = forward_context.attn_metadata
+ if attn_metadata is None:
+ return
+ ucm_sparse.attention_finished(
+ query, key, value, attn_output, layer_name, forward_context
+ )
+
+ attention_v1.maybe_execute_sparse_attention_finished = (
+ maybe_execute_sparse_attention_finished
+ )
+
+ vllm_ops = torch.ops.vllm
+ orig_unified_ascend_attention_with_output = (
+ vllm_ops.unified_ascend_attention_with_output
+ )
+
+ def _wrap_op_overload(orig, impl):
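+            # Wrap the torch custom-op object so that calls are routed to the
+            # patched implementation, while any other attribute access (e.g.
+            # op overloads/metadata) is still delegated to the original op.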
+ class _Wrapper:
+ def __init__(self, orig):
+ self._orig = orig
+
+ def __call__(self, *args, **kwargs):
+ return impl(*args, **kwargs)
+
+ def __getattr__(self, name):
+ return getattr(self._orig, name)
+
+ return _Wrapper(orig)
+
+ def unified_ascend_attention_with_output_impl(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ output: torch.Tensor,
+ layer_name: str,
+ ) -> None:
+
+ forward_context: ForwardContext = get_forward_context()
+ attn_metadata = forward_context.attn_metadata
+ self = forward_context.no_compile_layers[layer_name]
+ kv_cache = self.kv_cache[forward_context.virtual_engine]
+ if not self.use_mla:
+ maybe_execute_sparse_attention_begin(
+ query, key, value, layer_name, forward_context
+ )
+ self.impl.forward(
+ self,
+ query,
+ key,
+ value,
+ kv_cache,
+ attn_metadata,
+ output,
+ trace_flag=False,
+ )
+ if not self.use_mla:
+ maybe_execute_sparse_attention_finished(
+ query, key, value, output, layer_name, forward_context
+ )
+ return
+
+ vllm_ops.unified_ascend_attention_with_output = _wrap_op_overload(
+ orig_unified_ascend_attention_with_output,
+ unified_ascend_attention_with_output_impl,
+ )
+
+ attention_v1.unified_ascend_attention_with_output = (
+ unified_ascend_attention_with_output_impl
+ )
+ except ImportError as e:
+ logger.error(f"Failed to patch attention_v1.py: {e}", exc_info=True)
+ raise
+
+
+# ========================= vllm_ascend/attention/mla_v1.py =========================
+def _patch_mla_v1() -> None:
+ """Patch mla_v1.py for vLLM-Ascend."""
+ try:
+ from typing import Optional
+
+ import torch
+ import torch_npu
+ from vllm.attention.backends.abstract import AttentionLayer
+ from vllm.attention.layer import (
+ maybe_execute_sparse_attention_begin,
+ maybe_execute_sparse_attention_finished,
+ )
+ from vllm.forward_context import ForwardContext, get_forward_context
+ from vllm_ascend.attention.attention_v1 import (
+ AscendAttentionState,
+ )
+ from vllm_ascend.attention.mla_v1 import AscendMLAImpl
+ from vllm_ascend.multistream.context import get_multistream_comm_context
+ from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
+
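+        # NOTE: this is a copy of AscendMLAImpl.forward with UCM sparse-attention
+        # hooks (maybe_execute_sparse_attention_begin/finished) inserted around
+        # the prefill and decode paths; it mirrors the mla_v1.py hunk of
+        # vllm-ascend-adapt.patch above.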
+ def forward(
+ self,
+ layer: AttentionLayer,
+ hidden_states_or_q_c: torch.Tensor, # query in unified attn
+ hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn
+ k_pe: torch.Tensor, # value in unified attn
+ kv_cache: torch.Tensor,
+ attn_metadata: M,
+ output: Optional[torch.Tensor] = None,
+ enable_multistream_mla: bool = False,
+ ckq: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ forward_context: ForwardContext = get_forward_context()
+ assert output is not None, "Output tensor must be provided."
+ if attn_metadata is None:
+ # Profiling run.
+ return output
+ self.running_in_graph = (
+ self.torchair_graph_enabled
+ and attn_metadata.attn_state
+ in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding]
+ )
+ num_actual_toks = attn_metadata.num_actual_tokens
+ if k_pe is None and not self.running_in_graph:
+ if not self.torchair_graph_enabled:
+ kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states_or_kv_c_normed)[
+ 0
+ ].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+ kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+ else:
+ kv_c_normed = hidden_states_or_kv_c_normed
+ assert (
+ attn_metadata.num_decodes is not None
+ and attn_metadata.num_prefills is not None
+ and attn_metadata.num_decode_tokens is not None
+ )
+ has_decode = attn_metadata.num_decodes > 0
+ has_prefill = attn_metadata.num_prefills > 0
+ num_decode_tokens = attn_metadata.num_decode_tokens
+ if not self.running_in_graph:
+ # Inputs and outputs may be padded for CUDA graphs
+ output_padded = output
+ output = output[:num_actual_toks, ...]
+ if not self.torchair_graph_enabled:
+ kv_c_normed = kv_c_normed[:num_actual_toks, ...]
+ prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
+ if not self.running_in_graph:
+ hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
+ prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
+ if not self.torchair_graph_enabled:
+ decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+ k_pe = k_pe[:num_actual_toks, ...]
+ k_pe = k_pe.unsqueeze(1)
+ decode_k_pe = k_pe[:num_decode_tokens]
+ prefill_k_pe = k_pe[num_decode_tokens:]
+ else:
+ decode_hs_or_q_c = hidden_states_or_q_c
+ if has_decode:
+ decode_k_nope = None
+ assert attn_metadata.decode is not None
+ if self.running_in_graph:
+ seq_len = (
+ self.rotary_emb.max_position_embeddings
+ * self.rotary_emb.scaling_factor
+ )
+ cos = self.rotary_emb.cos_cached[:seq_len].to(
+ dtype=decode_hs_or_q_c.dtype
+ )
+ sin = self.rotary_emb.sin_cached[:seq_len].to(
+ dtype=decode_hs_or_q_c.dtype
+ )
+ cos = cos[attn_metadata.decode.input_positions]
+ sin = sin[attn_metadata.decode.input_positions]
+ cos = cos[:, None, None, :]
+ sin = sin[:, None, None, :]
+ with npu_stream_switch(
+ "mla_secondary", 0, enabled=enable_multistream_mla
+ ):
+ npu_wait_tensor(
+ hidden_states_or_kv_c_normed,
+ ckq,
+ enabled=enable_multistream_mla,
+ )
+ decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+ hidden_states_or_kv_c_normed,
+ cos,
+ sin,
+ kv_cache,
+ attn_metadata.slot_mapping,
+ )
+ # Without explicitly controlling the order, IndexByTensor operations
+ # would be placed after `matmul W_KV_T` hindering the overlapping of
+ # KvRmsNormRopeCache and SingleRope.
+ npu_wait_tensor(
+ decode_hs_or_q_c, cos, enabled=enable_multistream_mla
+ )
+ npu_wait_tensor(
+ decode_hs_or_q_c, sin, enabled=enable_multistream_mla
+ )
+ npu_wait_tensor(
+ decode_hs_or_q_c, decode_kv, enabled=enable_multistream_mla
+ )
+
+ decode_ql_nope, decode_q_pe = self._q_proj_and_k_up_proj(
+ decode_hs_or_q_c
+ )
+ if self.running_in_graph:
+ with npu_stream_switch(
+ "mla_secondary", 0, enabled=enable_multistream_mla
+ ):
+ npu_wait_tensor(
+ decode_q_pe, decode_k_pe, enabled=enable_multistream_mla
+ )
+ decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+ else:
+ decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
+ attn_metadata.decode.input_positions,
+ decode_q_pe.contiguous(),
+ decode_k_pe,
+ max_seq_len=attn_metadata.decode.max_seq_lens,
+ )
+ if has_prefill:
+ assert attn_metadata.prefill is not None
+ prefill_q = self.q_proj(prefill_hs_or_q_c)[0].view(
+ -1, self.num_heads, self.qk_head_dim
+ )
+ prefill_q_pe = prefill_q[..., self.qk_nope_head_dim :]
+ prefill_q_nope = prefill_q[..., : self.qk_nope_head_dim]
+ if self.torchair_graph_enabled:
+ num_tokens = prefill_hs_or_q_c.shape[0]
+ seq_len = (
+ self.rotary_emb.max_position_embeddings
+ * self.rotary_emb.scaling_factor
+ )
+ cos = self.rotary_emb.cos_cached[:seq_len].to(
+ dtype=prefill_q_pe.dtype
+ )
+ sin = self.rotary_emb.sin_cached[:seq_len].to(
+ dtype=prefill_q_pe.dtype
+ )
+ cos = cos[attn_metadata.prefill.input_positions]
+ sin = sin[attn_metadata.prefill.input_positions]
+ cos = cos[:, None, None, :]
+ sin = sin[:, None, None, :]
+
+ prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
+ prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
+ hidden_states_or_kv_c_normed,
+ cos,
+ sin,
+ kv_cache,
+ attn_metadata.slot_mapping,
+ )
+
+ kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
+ prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
+ prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads, -1)
+ prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
+ else:
+ prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
+ attn_metadata.prefill.input_positions,
+ prefill_q_pe.contiguous(),
+ prefill_k_pe,
+ max_seq_len=attn_metadata.prefill.max_seq_lens,
+ )
+ if self.torchair_graph_enabled:
+ if (
+ len(kv_cache) > 0
+ and kv_cache[0].numel() > 0
+ and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache
+ ):
+ slots = attn_metadata.slot_mapping
+ # NOTE: Separate the kv cache in advance to avoid OOM or other issues
+ torch_npu._npu_reshape_and_cache(
+ key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1),
+ value=prefill_k_pe,
+ key_cache=kv_cache[0],
+ value_cache=kv_cache[1],
+ slot_indices=slots,
+ )
+ elif kv_cache.numel() > 0:
+ key = torch.cat(
+ [kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe],
+ dim=2,
+ )
+ torch_npu._npu_reshape_and_cache_siso(
+ key=key,
+ key_cache=kv_cache,
+ slot_indices=attn_metadata.slot_mapping.flatten(),
+ )
+ if has_prefill:
+ # FIX: aicore move should be also placed on the comm stream in dbo,
+ # otherwise it may affect the accuracy
+ # TODO: use an elegant way to overlap
+ maybe_execute_sparse_attention_begin(
+ prefill_q,
+ prefill_k_c_normed,
+ prefill_k_pe,
+ layer.layer_name,
+ forward_context,
+ "prefill",
+ )
+ output_prefill = self._forward_prefill(
+ prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, attn_metadata
+ )
+ current_ms_metadata = get_multistream_comm_context()
+ if current_ms_metadata is not None:
+ with torch.npu.stream(current_ms_metadata.comm_stream):
+ output[num_decode_tokens:] = output_prefill
+ current_ms_metadata.after_comm_event.record()
+ else:
+ output[num_decode_tokens:] = output_prefill
+ maybe_execute_sparse_attention_finished(
+ prefill_q,
+ prefill_k_c_normed,
+ prefill_k_pe,
+ output[num_decode_tokens:],
+ layer.layer_name,
+ forward_context,
+ "prefill",
+ )
+ if has_decode:
+ maybe_execute_sparse_attention_begin(
+ torch.cat([decode_ql_nope, decode_q_pe], dim=-1),
+ decode_ql_nope,
+ decode_q_pe,
+ layer.layer_name,
+ forward_context,
+ "decode",
+ )
+ if self.running_in_graph:
+ return self._forward_decode(
+ decode_ql_nope,
+ decode_q_pe,
+ decode_k_nope,
+ decode_k_pe,
+ kv_cache,
+ attn_metadata,
+ enable_multistream_mla,
+ )
+ else:
+ output_decode = self._forward_decode(
+ decode_ql_nope,
+ decode_q_pe,
+ decode_k_nope,
+ decode_k_pe,
+ kv_cache,
+ attn_metadata,
+ )
+ current_ms_metadata = get_multistream_comm_context()
+ if current_ms_metadata is not None:
+ with torch.npu.stream(current_ms_metadata.comm_stream):
+ output[:num_decode_tokens] = output_decode
+ current_ms_metadata.after_comm_event.record()
+ else:
+ output[:num_decode_tokens] = output_decode
+ maybe_execute_sparse_attention_finished(
+ torch.cat([decode_ql_nope, decode_q_pe], dim=-1),
+ decode_ql_nope,
+ decode_q_pe,
+ output[:num_decode_tokens],
+ layer.layer_name,
+ forward_context,
+ "decode",
+ )
+
+ return output_padded
+
+ AscendMLAImpl.forward = forward
+ except ImportError as e:
+ logger.error(f"Failed to patch mla_v1.py: {e}", exc_info=True)
+ raise
+
+
+# ========================= vllm_ascend/worker/model_runner_v1.py =========================
+def _patch_model_runner_v1() -> None:
+ """Patch model_runner_v1.py for vLLM-Ascend."""
+ try:
+ from typing import TYPE_CHECKING, List, Optional, Union
+
+ import numpy as np
+ import torch
+ from vllm.distributed.parallel_state import get_pp_group, get_tp_group
+ from vllm.forward_context import set_forward_context
+ from vllm.logger import logger
+ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
+ from vllm.sampling_params import SamplingType
+ from vllm.sequence import IntermediateTensors
+ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+ from vllm_ascend.ascend_config import get_ascend_config
+ from vllm_ascend.attention.attention_v1 import (
+ AscendAttentionState,
+ AscendMetadata,
+ )
+ from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
+ from vllm_ascend.attention.mla_v1 import (
+ AscendMLAMetadata,
+ CommonAttentionMetadata,
+ )
+ from vllm_ascend.utils import (
+ ACL_FORMAT_FRACTAL_ND,
+ ACL_FORMAT_FRACTAL_NZ,
+ ProfileExecuteDuration,
+ maybe_converting_weight_acl_format,
+ )
+ from vllm_ascend.worker.npu_input_batch import CachedRequestState
+
+ from ucm.sparse.base import INVALID_SLOT
+
+ if TYPE_CHECKING:
+ from vllm.v1.core.sched.output import SchedulerOutput
+ from vllm.distributed.kv_transfer import (
+ get_kv_transfer_group,
+ has_kv_transfer_group,
+ )
+ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
+ from vllm.forward_context import get_forward_context, set_forward_context
+ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
+
+ from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
+
+ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
+ """Update the cached states and the persistent batch with the scheduler
+ output.
+
+ The SamplingMetadata is updated and copied to the NPU if there is a
+ new/resumed/paused/finished request in the batch.
+ """
+ # Remove finished requests from the cached states.
+ for req_id in scheduler_output.finished_req_ids:
+ self.ucm_sparse_request_finished_in_worker(req_id)
+ self.requests.pop(req_id, None)
+ self.encoder_cache.pop(req_id, None)
+ # Remove the finished requests from the persistent batch.
+ # NOTE(woosuk): There could be an edge case where finished_req_ids and
+ # scheduled_req_ids overlap. This happens when a request is aborted and
+ # then resubmitted with the same ID. In this case, we treat them as two
+ # distinct requests - clearing the cached states for the first request
+ # and handling the second as a new request.
+ removed_req_indices: List[int] = []
+ for req_id in scheduler_output.finished_req_ids:
+ req_index = self.input_batch.remove_request(req_id)
+ if req_index is not None:
+ removed_req_indices.append(req_index)
+
+ # Free the cached encoder outputs.
+ for req_id, input_id in scheduler_output.free_encoder_input_ids:
+ encoder_outputs = self.encoder_cache.get(req_id)
+ if encoder_outputs is not None:
+ encoder_outputs.pop(input_id, None)
+ if not encoder_outputs:
+ self.encoder_cache.pop(req_id, None)
+
+ # Remove the unscheduled requests from the persistent batch.
+ # NOTE(woosuk): The unscheduled requests are either preempted requests
+ # or running requests that are not scheduled in this step. We remove
+ # them from the persistent batch but keep their cached states since
+ # they will be scheduled again sometime in the future.
+ scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
+ cached_req_ids = self.input_batch.req_id_to_index.keys()
+ unscheduled_req_ids = cached_req_ids - scheduled_req_ids
+ # NOTE(woosuk): The persistent batch optimization assumes that
+ # consecutive batches contain mostly the same requests. If batches
+ # have low request overlap (e.g., alternating between two distinct
+ # sets of requests), this optimization becomes very inefficient.
+ for req_id in unscheduled_req_ids:
+ req_index = self.input_batch.remove_request(req_id)
+ assert req_index is not None
+ removed_req_indices.append(req_index)
+
+ req_ids_to_add: List[str] = []
+ # Add new requests to the cached states.
+ for new_req_data in scheduler_output.scheduled_new_reqs:
+ req_id = new_req_data.req_id
+ sampling_params = new_req_data.sampling_params
+ if (
+ sampling_params
+ and sampling_params.sampling_type == SamplingType.RANDOM_SEED
+ ):
+ generator = torch.Generator(device=self.device)
+ generator.manual_seed(sampling_params.seed)
+ else:
+ generator = None
+
+ self.requests[req_id] = CachedRequestState(
+ req_id=req_id,
+ prompt_token_ids=new_req_data.prompt_token_ids,
+ mm_inputs=new_req_data.mm_inputs,
+ mm_positions=new_req_data.mm_positions,
+ sampling_params=sampling_params,
+ pooling_params=new_req_data.pooling_params,
+ generator=generator,
+ block_ids=new_req_data.block_ids,
+ num_computed_tokens=new_req_data.num_computed_tokens,
+ output_token_ids=[],
+ lora_request=new_req_data.lora_request,
+ )
+
+ # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
+ if self.uses_mrope:
+ image_grid_thw = []
+ video_grid_thw = []
+ second_per_grid_ts = []
+ audio_feature_lengths = []
+ use_audio_in_video = False
+ for mm_input in self.requests[req_id].mm_inputs:
+ if mm_input.get("image_grid_thw") is not None:
+ image_grid_thw.extend(mm_input["image_grid_thw"].tolist())
+ if mm_input.get("video_grid_thw") is not None:
+ video_grid_thw.extend(mm_input["video_grid_thw"].tolist())
+ if mm_input.get("second_per_grid_ts") is not None:
+ second_per_grid_ts.extend(mm_input["second_per_grid_ts"])
+ if mm_input.get("audio_feature_lengths") is not None:
+ audio_feature_lengths.extend(
+ mm_input["audio_feature_lengths"]
+ )
+ if mm_input.get("use_audio_in_video") is True:
+ use_audio_in_video = True
+
+ hf_config = self.model_config.hf_config
+
+ (
+ self.requests[req_id].mrope_positions,
+ self.requests[req_id].mrope_position_delta,
+ ) = MRotaryEmbedding.get_input_positions_tensor(
+ self.requests[req_id].prompt_token_ids,
+ hf_config=hf_config,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ audio_feature_lengths=audio_feature_lengths,
+ use_audio_in_video=use_audio_in_video,
+ )
+
+ req_ids_to_add.append(req_id)
+
+ # Update the states of the running/resumed requests.
+ req_data = scheduler_output.scheduled_cached_reqs
+ req_sparsed_slots = scheduler_output.req_sparsed_slots
+ is_last_rank = get_pp_group().is_last_rank
+ for i, req_id in enumerate(req_data.req_ids):
+ req_state = self.requests[req_id]
+ num_computed_tokens = req_data.num_computed_tokens[i]
+ new_block_ids = req_data.new_block_ids[i]
+ resumed_from_preemption = req_data.resumed_from_preemption[i]
+ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT
+
+ req_state.num_computed_tokens = num_computed_tokens
+ if not is_last_rank:
+ new_token_ids = req_data.new_token_ids[i]
+ # Add the sampled token(s) from the previous step (if any).
+ # This doesn't include "unverified" tokens like spec decode tokens.
+ num_new_tokens = (
+ num_computed_tokens + len(new_token_ids) - req_state.num_tokens
+ )
+ if num_new_tokens == 1:
+ # Avoid slicing list in most common case.
+ req_state.output_token_ids.append(new_token_ids[-1])
+ elif num_new_tokens > 0:
+ req_state.output_token_ids.extend(
+ new_token_ids[-num_new_tokens:]
+ )
+ # Update the block IDs.
+ if resumed_from_preemption or is_sparsed_request:
+ # The request is resumed from preemption.
+ # Replace the existing block IDs with the new ones.
+ req_state.block_ids = new_block_ids
+ else:
+ # Append the new blocks to the existing block IDs.
+ for block_ids, new_ids in zip( # type: ignore[call-overload]
+ req_state.block_ids, new_block_ids
+ ):
+ block_ids.extend(new_ids)
+
+ req_index = self.input_batch.req_id_to_index.get(req_id)
+ if req_index is None:
+ # The request is not in the persistent batch.
+ # The request was either preempted and resumed later, or was not
+ # scheduled in the previous step and needs to be added again.
+ req_ids_to_add.append(req_id)
+ continue
+
+ # Update the persistent batch.
+ self.input_batch.num_computed_tokens_cpu[req_index] = (
+ num_computed_tokens
+ )
+
+ if is_sparsed_request:
+ self.input_batch.block_table.reset_row(req_index)
+
+ self.input_batch.block_table.append_row(new_block_ids, req_index)
+
+ if not is_last_rank:
+ # Add new_token_ids to token_ids_cpu.
+ start_token_index = num_computed_tokens
+ end_token_index = num_computed_tokens + len(new_token_ids)
+ self.input_batch.token_ids_cpu[
+ req_index, start_token_index:end_token_index
+ ] = new_token_ids
+ self.input_batch.num_tokens_no_spec[req_index] = end_token_index
+ # Add spec_token_ids to token_ids_cpu.
+ spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
+ req_id, ()
+ )
+ if spec_token_ids:
+ start_index = end_token_index
+ end_token_index += len(spec_token_ids)
+ self.input_batch.token_ids_cpu[
+ req_index, start_index:end_token_index
+ ] = spec_token_ids
+ # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
+ self.input_batch.num_tokens[req_index] = end_token_index
+
+ # Check if the batch has changed. If not, we can skip copying the
+ # sampling metadata from CPU to GPU.
+ batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
+
+ # Add the new or resumed requests to the persistent batch.
+ # The smaller empty indices are filled first.
+ removed_req_indices.sort(reverse=True)
+ for req_id in req_ids_to_add:
+ req_state = self.requests[req_id]
+ if removed_req_indices:
+ # Fill the empty index.
+ req_index = removed_req_indices.pop()
+ else:
+ # Append to the end.
+ req_index = None
+ self.input_batch.add_request(req_state, req_index)
+
+ # Condense the batched states if there are empty indices.
+ if removed_req_indices:
+ self.input_batch.condense(removed_req_indices)
+
+ if batch_changed:
+ self.input_batch.refresh_sampling_metadata()
+
+ NPUModelRunner._update_states = _update_states
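+    # Compared with the stock NPUModelRunner._update_states, the replacement above
+    # additionally treats sparsified requests (req_sparsed_slots != INVALID_SLOT)
+    # like resumed-from-preemption ones: their block-table row is reset and
+    # rebuilt from the freshly allocated blocks.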
+
+ def _process_reqs(
+ self,
+ scheduler_output: "SchedulerOutput",
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ ) -> tuple[
+ Union[AscendMetadata, AscendMLAMetadata, AscendTorchairMetadata],
+ torch.Tensor,
+ SpecDecodeMetadata,
+ torch.Tensor,
+ int,
+ torch.Tensor,
+ torch.Tensor,
+ np.ndarray,
+ Optional[dict[str, list[str]]],
+ ]:
+        # Check that the input is valid.
+ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+ assert total_num_scheduled_tokens > 0
+ num_reqs = self.input_batch.num_reqs
+ assert num_reqs > 0
+ if (
+ self.use_aclgraph
+ and total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]
+ ):
+ # Add padding to the batch size.
+ num_input_tokens = self.vllm_config.pad_for_cudagraph(
+ total_num_scheduled_tokens
+ )
+ else:
+ # Eager mode.
+ num_input_tokens = total_num_scheduled_tokens
+
+ modified_batch = self.attn_metadata_builder.reorder_batch(
+ self.input_batch, scheduler_output
+ )
+ if modified_batch:
+ self.input_batch.refresh_sampling_metadata()
+
+ # OPTIMIZATION: Start copying the block table first.
+ # This way, we can overlap the copy with the following CPU operations.
+ self.input_batch.block_table.commit(num_reqs)
+
+ # Get the number of scheduled tokens for each request.
+ # TODO: The Python loop can be slow. Optimize.
+ num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
+ num_valid_tokens = np.empty(num_reqs, dtype=np.int32)
+ max_num_scheduled_tokens = 0
+ for i, req_id in enumerate(self.input_batch.req_ids):
+ num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+ num_scheduled_tokens[i] = num_tokens
+ num_valid_tokens[i] = num_tokens - len(
+ scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
+ )
+ max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens)
+
+ # Hot-Swap lora model
+ if self.lora_config:
+ self.set_active_loras(self.input_batch, num_scheduled_tokens)
+
+ # Prepare positions
+ req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
+ cu_num_tokens = np.cumsum(num_scheduled_tokens)
+ cumsums_offsets = np.repeat(
+ cu_num_tokens - num_scheduled_tokens, num_scheduled_tokens
+ )
+ logits_indices = cu_num_tokens - 1
+ logits_indices = torch.from_numpy(logits_indices).to(
+ self.device, non_blocking=True
+ )
+ arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
+
+ positions_np = self.positions_np[:total_num_scheduled_tokens]
+ np.add(
+ self.input_batch.num_computed_tokens_cpu[req_indices],
+ arange,
+ out=positions_np,
+ )
+
+ # Calculate M-RoPE positions.
+        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+ if self.uses_mrope:
+ self._calc_mrope_positions(scheduler_output)
+
+ if self.uses_mrope:
+            # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+ self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
+ self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
+ non_blocking=True,
+ )
+
+ self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
+ self.positions[:total_num_scheduled_tokens].copy_(
+ self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True
+ )
+ positions = self.positions[:num_input_tokens]
+ self.query_lens = torch.from_numpy(num_scheduled_tokens)
+
+ self.seq_lens_np[:num_reqs] = (
+ self.input_batch.num_computed_tokens_cpu[:num_reqs]
+ + num_scheduled_tokens
+ )
+ seq_lens = self.seq_lens_cpu[:num_reqs]
+
+        # TODO: avoid the `positions_np.copy()` for better performance.
+ sparsed_positions = positions_np.copy()
+ req_sparsed_slots = scheduler_output.req_sparsed_slots
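+        # For a request that the sparse algorithm has compacted, remap the first
+        # scheduled position of that request to the last retained sparse slot so
+        # that the slot mapping below indexes the compacted KV layout
+        # (interpretation of the UCM sparse contract; the exact layout is defined
+        # by the sparse method in use).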
+ for req_id in self.input_batch.req_id_to_index:
+ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT
+ req_index = self.input_batch.req_id_to_index[req_id]
+ offset = (
+ 0 if req_index == 0 else cu_num_tokens[req_index - 1]
+ ) # TODO: support MTP
+ if is_sparsed_request:
+ sparsed_positions[offset] = req_sparsed_slots[req_id] - 1
+
+ block_table_indices = (
+ req_indices * self.max_num_blocks_per_req
+ + sparsed_positions // self.block_size
+ )
+
+ block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
+ block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
+ block_offsets = sparsed_positions % self.block_size
+ np.add(
+ block_numbers * self.block_size,
+ block_offsets,
+ out=self.slot_mapping_np[:total_num_scheduled_tokens],
+ )
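+        # Illustrative example with hypothetical numbers: for block_size=128, a
+        # token at (sparsed) position 300 falls in logical block 300 // 128 == 2
+        # at offset 300 % 128 == 44; if that request's block table maps logical
+        # block 2 to physical block 57, its slot is 57 * 128 + 44 == 7340.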
+
+ ascend_config = get_ascend_config()
+ use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
+ if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
+ attn_state = AscendAttentionState.PrefillNoCache
+        # Decode stage: every request has exactly one scheduled token (this also
+        # covers a prefill where all but the last token hit the cache).
+ elif np.all(num_scheduled_tokens == 1):
+ attn_state = AscendAttentionState.DecodeOnly
+ # Speculative decoding.
+ elif np.all(num_valid_tokens == 1):
+ if self.use_eagle:
+ attn_state = AscendAttentionState.ChunkedPrefill
+ else:
+ attn_state = AscendAttentionState.SpecDecoding
+ # splitfuse
+ elif (
+ not ascend_config.ascend_scheduler_config.enabled
+ or self.chunked_prefill_enabled
+ ):
+ attn_state = AscendAttentionState.ChunkedPrefill
+ else:
+ attn_state = AscendAttentionState.PrefillCacheHit
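+        # Summary of the branch above: full prompts with no cached prefix ->
+        # PrefillNoCache; one scheduled token per request -> DecodeOnly; one
+        # *valid* (non-speculative) token per request -> SpecDecoding (or
+        # ChunkedPrefill when EAGLE is used); chunked prefill / splitfuse ->
+        # ChunkedPrefill; otherwise -> PrefillCacheHit.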
+
+ for req_id in self.input_batch.req_id_to_index:
+ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT
+ req_index = self.input_batch.req_id_to_index[req_id]
+ if is_sparsed_request:
+ seq_lens[req_index] = req_sparsed_slots[req_id]
+
+ self.attn_mask = self._make_attention_mask(
+ seq_lens=seq_lens,
+ query_lens=num_scheduled_tokens,
+ position=torch.tensor(sparsed_positions).npu(),
+ attn_state=attn_state,
+ )
+ self.attn_state = attn_state # type: ignore
+
+ extra_builder_kwargs = {}
+
+ self.query_start_loc_np[0] = 0
+ self.query_start_loc_np[1 : num_reqs + 1] = cu_num_tokens
+ self.query_start_loc[: num_reqs + 1].copy_(
+ self.query_start_loc_cpu[: num_reqs + 1], non_blocking=True
+ )
+ self.seq_lens[:num_reqs].copy_(
+ self.seq_lens_cpu[:num_reqs], non_blocking=True
+ )
+
+        # Fill the unused tail: seq_lens with 0 and query_start_loc with -1
+        # (needed for reshape_and_cache).
+ self.seq_lens[num_reqs:].fill_(0)
+ self.query_start_loc[num_reqs + 1 :].fill_(-1)
+
+ with_prefill = attn_state not in [
+ AscendAttentionState.DecodeOnly,
+ AscendAttentionState.SpecDecoding,
+ ]
+
+ if self.dp_size > 1:
+ max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
+ total_num_scheduled_tokens, with_prefill
+ )
+ extra_builder_kwargs["max_num_tokens_across_dp"] = max_num_tokens
+ extra_builder_kwargs["with_prefill_across_dp"] = with_prefill
+
+ # Add graph_pad_size here
+ if self.torchair_graph_enabled and not with_prefill:
+ if self.dp_size > 1:
+ padded_batch_size = self.select_torchair_padded_batch_size(
+ max_num_tokens
+ )
+ else:
+ padded_batch_size = self.select_torchair_padded_batch_size(
+ total_num_scheduled_tokens
+ )
+ graph_pad_size = padded_batch_size - total_num_scheduled_tokens
+
+ extra_builder_kwargs["graph_pad_size"] = graph_pad_size
+
+ if self.vllm_config.model_config.use_mla:
+ query_start_loc = self.query_start_loc[: num_reqs + 1]
+ seq_lens = self.seq_lens[:num_reqs]
+ common_attn_metadata = CommonAttentionMetadata(
+ query_start_loc=query_start_loc, seq_lens=seq_lens
+ )
+ attn_metadata = self.attn_metadata_builder.build( # type: ignore
+ num_reqs=num_reqs,
+ num_actual_tokens=total_num_scheduled_tokens,
+ max_query_len=max_num_scheduled_tokens,
+ common_attn_metadata=common_attn_metadata,
+ common_prefix_len=None,
+ **extra_builder_kwargs,
+ )
+ else:
+ attn_metadata = self.attn_metadata_builder.build( # type: ignore
+ num_reqs=num_reqs,
+ num_actual_tokens=total_num_scheduled_tokens,
+ max_query_len=max_num_scheduled_tokens,
+ common_prefix_len=None,
+ **extra_builder_kwargs,
+ )
+ attn_metadata.num_input_tokens = num_input_tokens
+
+ # Prepare input_ids
+ token_indices = (
+ positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]
+ )
+ torch.index_select(
+ self.input_batch.token_ids_cpu_tensor.flatten(),
+ 0,
+ torch.from_numpy(token_indices),
+ out=self.input_ids_cpu[:total_num_scheduled_tokens],
+ )
+ # Copy the tensors to the NPU.
+ self.input_ids[:total_num_scheduled_tokens].copy_(
+ self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True
+ )
+
+ # _prepare_inputs may reorder the batch, so we must gather multi
+ # modal outputs after that to ensure the correct order
+ if self.is_multimodal_model:
+ # Run the multimodal encoder if any.
+ self._execute_mm_encoder(scheduler_output)
+ mm_embeds = self._gather_mm_embeddings(scheduler_output)
+ else:
+ mm_embeds = []
+
+ if self.is_multimodal_model:
+ # NOTE(woosuk): To unify token ids and soft tokens (vision
+ # embeddings), we always use embeddings (rather than token ids)
+ # as input to the multimodal model, even when the input is text.
+ input_ids = self.input_ids[:total_num_scheduled_tokens]
+ if mm_embeds:
+ inputs_embeds = self.model.get_input_embeddings(
+ input_ids, mm_embeds
+ )
+ else:
+ inputs_embeds = self.model.get_input_embeddings(input_ids)
+ # TODO(woosuk): Avoid the copy. Optimize.
+ self.inputs_embeds[:total_num_scheduled_tokens].copy_(inputs_embeds)
+ inputs_embeds = self.inputs_embeds[:num_input_tokens]
+ input_ids = None
+ else:
+ # For text-only models, we use token ids as input.
+ # While it is possible to use embeddings as input just like the
+ # multimodal models, it is not desirable for performance since
+ # then the embedding layer is not included in the ACL graph.
+ input_ids = self.input_ids[:num_input_tokens]
+ inputs_embeds = None
+ if self.uses_mrope:
+ positions = self.mrope_positions[:, :num_input_tokens]
+
+ if self.torchair_graph_enabled and not with_prefill:
+ input_ids = self.input_ids[:padded_batch_size]
+ positions = self.positions[:padded_batch_size]
+
+ # Run forward pass
+ with set_forward_context(
+ attn_metadata, self.vllm_config, num_tokens=num_input_tokens
+ ):
+ with ProfileExecuteDuration().capture_async("forward"):
+ model_kwargs = {}
+ if self.torchair_graph_enabled:
+ model_kwargs["kv_caches"] = self.kv_caches
+ model_kwargs["attn_metadata"] = attn_metadata
+ if self.torchair_graph_enabled and not with_prefill:
+ maybe_converting_weight_acl_format(
+ self.model, ACL_FORMAT_FRACTAL_NZ
+ )
+
+ compiled_model = self._get_torchair_lazy_compiled_model(
+ padded_batch_size
+ )
+ hidden_states = compiled_model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ **model_kwargs,
+ )
+ else:
+ assert self.model is not None
+ maybe_converting_weight_acl_format(
+ self.model, ACL_FORMAT_FRACTAL_ND
+ )
+ self.maybe_setup_kv_connector(scheduler_output)
+ self.maybe_execute_ucm_sparse_begin(
+ scheduler_output, attn_metadata
+ )
+
+ hidden_states = self.model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ **model_kwargs,
+ )
+ self.maybe_wait_for_kv_save()
+ self.maybe_execute_ucm_sparse_finished()
+
+ use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
+ if not use_spec_decode:
+ # NOTE(woosuk): Due to chunked prefills, the batch may contain
+ # partial requests. While we should not sample any token
+ # from these partial requests, we do so for simplicity.
+ # We will ignore the sampled tokens from the partial requests.
+ # TODO: Support prompt logprobs.
+ spec_decode_metadata = None
+ else:
+ # Get the number of draft tokens for each request.
+ # Iterate over the dictionary rather than all requests since not all
+ # requests have draft tokens.
+ num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
+ for (
+ req_id,
+ draft_token_ids,
+ ) in scheduler_output.scheduled_spec_decode_tokens.items():
+ req_idx = self.input_batch.req_id_to_index[req_id]
+ num_draft_tokens[req_idx] = len(draft_token_ids)
+
+ spec_decode_metadata = self._calc_spec_decode_metadata(
+ num_draft_tokens, cu_num_tokens
+ )
+ logits_indices = spec_decode_metadata.logits_indices
+
+ aux_hidden_states = None
+ if self.use_aux_hidden_state_outputs:
+ hidden_states, aux_hidden_states = hidden_states
+
+ return (
+ attn_metadata,
+ hidden_states,
+ spec_decode_metadata,
+ positions,
+ total_num_scheduled_tokens,
+ logits_indices,
+ aux_hidden_states,
+ num_scheduled_tokens,
+ )
+
+ NPUModelRunner._process_reqs = _process_reqs
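+    # Relative to the stock _process_reqs, the version above derives the slot
+    # mapping and per-request sequence lengths from req_sparsed_slots and wraps
+    # the forward pass with the UCM sparse begin/finished hooks.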
+
+ @torch.inference_mode()
+ def execute_model(
+ self,
+ scheduler_output: "SchedulerOutput",
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ ) -> Union[ModelRunnerOutput, torch.Tensor]:
+ with ProfileExecuteDuration().capture_async("prepare input and forward"):
+ self._update_states(scheduler_output)
+ if not scheduler_output.total_num_scheduled_tokens:
+                # Return an empty ModelRunnerOutput if there's no work to do.
+ return EMPTY_MODEL_RUNNER_OUTPUT
+ (
+ attn_metadata,
+ hidden_states,
+ spec_decode_metadata,
+ positions,
+ num_scheduled_tokens,
+ logits_indices,
+ aux_hidden_states,
+ num_scheduled_tokens_np,
+ ) = self._process_reqs(scheduler_output, intermediate_tensors)
+
+ with ProfileExecuteDuration().capture_async("post process"):
+ # Broadcast PP output for external_launcher (torchrun)
+ # to make sure we are synced across pp ranks
+            # TODO: Support overlapping micro-batches
+ # https://github.com/vllm-project/vllm/issues/18019
+ broadcast_pp_output = (
+ self.parallel_config.distributed_executor_backend
+ == "external_launcher"
+ and len(get_pp_group().ranks) > 0
+ )
+ if not get_pp_group().is_last_rank:
+ # For mid-pipeline stages, return the hidden states.
+ if not broadcast_pp_output:
+ return hidden_states
+ assert isinstance(hidden_states, IntermediateTensors)
+ get_pp_group().send_tensor_dict(
+ hidden_states.tensors, all_gather_group=get_tp_group()
+ )
+ logits = None
+ else:
+ if self.input_batch.pooling_params:
+ return self._pool(
+ hidden_states, num_scheduled_tokens, num_scheduled_tokens_np
+ )
+ sample_hidden_states = hidden_states[logits_indices]
+ logits = self.model.compute_logits(sample_hidden_states, None)
+ if broadcast_pp_output:
+ model_output_broadcast_data = (
+ {
+ "logits": logits.contiguous(),
+ }
+ if logits is not None
+ else {}
+ )
+ model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
+ model_output_broadcast_data, src=len(get_pp_group().ranks) - 1
+ )
+ assert model_output_broadcast_data is not None
+ logits = model_output_broadcast_data["logits"]
+
+ # Apply structured output bitmasks if present
+ if scheduler_output.grammar_bitmask is not None:
+ logits = self.apply_grammar_bitmask(scheduler_output, logits)
+
+ # Sample the next token and get logprobs if needed.
+ sampling_metadata = self.input_batch.sampling_metadata
+ if spec_decode_metadata is None:
+ sampler_output = self.sampler(
+ logits=logits,
+ sampling_metadata=sampling_metadata,
+ )
+ else:
+ # When indexing with a tensor (bonus_logits_indices), PyTorch
+ # creates a new tensor with separate storage from the original
+ # logits tensor. This means any in-place operations on bonus_logits
+ # won't affect the original logits tensor.
+ assert logits is not None
+ bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
+ sampler_output = self.sampler(
+ logits=bonus_logits,
+ sampling_metadata=sampling_metadata,
+ )
+ bonus_token_ids = sampler_output.sampled_token_ids
+
+ # Just like `bonus_logits`, `target_logits` is a new tensor with
+ # separate storage from the original `logits` tensor. Therefore,
+ # it is safe to update `target_logits` in place.
+ target_logits = logits[spec_decode_metadata.target_logits_indices]
+ output_token_ids = self.rejection_sampler(
+ spec_decode_metadata,
+ None, # draft_probs
+ target_logits,
+ bonus_token_ids,
+ sampling_metadata,
+ )
+ sampler_output.sampled_token_ids = output_token_ids
+
+            # TODO(woosuk): The following loop can be slow since it iterates over
+            # the requests one by one. Optimize.
+            discard_sampled_tokens_req_indices: list[int] = []
+ for i, req_id in enumerate(self.input_batch.req_ids):
+ req_state = self.requests[req_id]
+ seq_len = (
+ req_state.num_computed_tokens
+ + scheduler_output.num_scheduled_tokens[req_id]
+ )
+ if seq_len < req_state.num_tokens:
+ # Ignore the sampled token.
+ # Rewind the generator state as if the token was not sampled.
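+                    # (The rewind of 4 mirrors the upstream GPU model runner and
+                    # relies on torch-internal details of how the generator's
+                    # offset advances per sampled token.)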
+ generator = self.input_batch.generators.get(i)
+ if generator is not None:
+ generator.set_offset(generator.get_offset() - 4)
+ discard_sampled_tokens_req_indices.append(i)
+
+ # NOTE: NPU -> CPU Sync happens here.
+ # Move as many CPU operations as possible before this sync point.
+ logprobs_tensors = sampler_output.logprobs_tensors
+ logprobs_lists = (
+ logprobs_tensors.tolists() if logprobs_tensors is not None else None
+ )
+
+ # Compute prompt logprobs if needed.
+ prompt_logprobs_dict = self._get_prompt_logprobs_dict(
+ hidden_states[:num_scheduled_tokens],
+ scheduler_output,
+ )
+
+ # Get the valid generated tokens.
+ sampled_token_ids = sampler_output.sampled_token_ids
+ max_gen_len = sampled_token_ids.shape[-1]
+ if max_gen_len == 1:
+ # No spec decode tokens.
+ valid_sampled_token_ids = sampled_token_ids.tolist()
+ else:
+ # Includes spec decode tokens.
+ valid_sampled_token_ids = self.rejection_sampler.parse_output(
+ sampled_token_ids,
+ self.input_batch.vocab_size,
+ )
+
+ for i in discard_sampled_tokens_req_indices:
+ valid_sampled_token_ids[i].clear()
+            # Cache the sampled tokens in the model runner, so that the scheduler
+ # doesn't need to send them back.
+ # NOTE(woosuk): As an exception, when using PP, the scheduler sends
+ # the sampled tokens back, because there's no direct communication
+ # between the first-stage worker and the last-stage worker.
+ for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
+ if not sampled_ids:
+ continue
+
+ start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+ end_idx = start_idx + len(sampled_ids)
+ assert end_idx <= self.model_config.max_model_len, (
+ "Sampled token IDs exceed the max model length. "
+ f"Total number of tokens: {end_idx} > max_model_len: "
+ f"{self.model_config.max_model_len}"
+ )
+
+ self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = (
+ sampled_ids
+ )
+ self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+ self.input_batch.num_tokens[req_idx] = end_idx
+ req_id = self.input_batch.req_ids[req_idx]
+ req_state = self.requests[req_id]
+ req_state.output_token_ids.extend(sampled_ids)
+
+ spec_token_ids = self._get_spec_token_ids(
+ valid_sampled_token_ids,
+ sampling_metadata,
+ scheduler_output,
+ spec_decode_metadata,
+ positions,
+ num_scheduled_tokens,
+ hidden_states,
+ attn_metadata,
+ aux_hidden_states,
+ )
+
+ model_runner_output = ModelRunnerOutput(
+ req_ids=self.input_batch.req_ids,
+ req_id_to_index=self.input_batch.req_id_to_index,
+ sampled_token_ids=valid_sampled_token_ids,
+ spec_token_ids=spec_token_ids,
+ logprobs=logprobs_lists,
+ prompt_logprobs_dict=prompt_logprobs_dict,
+ pooler_output=[],
+ )
+
+ durations = ProfileExecuteDuration().pop_captured_sync()
+ if durations:
+ dr_str = [
+ f"[{tag}]:{duration:.2f}ms" for tag, duration in durations.items()
+ ]
+ captured_name = (
+ "Decode"
+ if self.attn_state == AscendAttentionState.DecodeOnly
+ else "Prefill"
+ )
+ logger.info(
+ "Profile execute duration [%s]:%s", captured_name, " ".join(dr_str)
+ )
+
+ return model_runner_output
+
+ NPUModelRunner.execute_model = execute_model
+
+ @staticmethod
+ def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
+ # Update KVConnector with the KVConnector metadata forward().
+ if has_kv_transfer_group():
+ kv_connector = get_kv_transfer_group()
+ assert isinstance(kv_connector, KVConnectorBase_V1)
+ assert scheduler_output.kv_connector_metadata is not None
+ kv_connector.bind_connector_metadata(
+ scheduler_output.kv_connector_metadata
+ )
+ # Background KV cache transfers happen here.
+ # These transfers are designed to be async and the requests
+ # involved may be disjoint from the running requests.
+ # Do this here to save a collective_rpc.
+ kv_connector.start_load_kv(get_forward_context())
+
+ NPUModelRunner.maybe_setup_kv_connector = maybe_setup_kv_connector
+
+ @staticmethod
+ def maybe_wait_for_kv_save():
+ if has_kv_transfer_group():
+ get_kv_transfer_group().wait_for_save()
+
+ NPUModelRunner.maybe_wait_for_kv_save = maybe_wait_for_kv_save
+
+ def maybe_execute_ucm_sparse_begin(
+ self,
+ scheduler_output: "SchedulerOutput",
+ attn_metadata: CommonAttentionMetadata,
+ ):
+ if not has_ucm_sparse():
+ return
+ ucm_sparse = get_ucm_sparse()
+ ucm_sparse.build_sparse_meta(
+ scheduler_output, self.requests, self.input_batch, attn_metadata
+ )
+ ucm_sparse.execute_begin(scheduler_output)
+
+ def maybe_execute_ucm_sparse_finished(self):
+ if not has_ucm_sparse():
+ return
+ ucm_sparse = get_ucm_sparse()
+ ucm_sparse.execute_finished()
+
+ def ucm_sparse_request_finished_in_worker(self, request_id: str | int):
+ if not has_ucm_sparse():
+ return
+ ucm_sparse = get_ucm_sparse()
+ ucm_sparse.request_finished_in_worker(request_id)
+
+ NPUModelRunner.maybe_execute_ucm_sparse_begin = maybe_execute_ucm_sparse_begin
+ NPUModelRunner.maybe_execute_ucm_sparse_finished = (
+ maybe_execute_ucm_sparse_finished
+ )
+ NPUModelRunner.ucm_sparse_request_finished_in_worker = (
+ ucm_sparse_request_finished_in_worker
+ )
+ except ImportError as e:
+ logger.error(f"Failed to patch model_runner_v1.py: {e}", exc_info=True)
+ raise
+
+
+# ========================= vllm_ascend/worker/worker_v1.py =========================
+def _patch_worker_v1() -> None:
+ """Patch worker_v1.py for vLLM-Ascend."""
+ try:
+ import copy
+ from typing import Optional
+
+ from vllm.distributed.parallel_state import get_pp_group, get_tp_group
+ from vllm.logger import logger
+ from vllm.sequence import IntermediateTensors
+ from vllm.v1.core.sched.output import SchedulerOutput
+ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+ from vllm_ascend.worker.worker_v1 import NPUWorker
+
+ from ucm.sparse.state import ensure_ucm_sparse_initialized
+
+ def execute_model(
+ self,
+ scheduler_output: "SchedulerOutput",
+ ) -> Optional[ModelRunnerOutput]:
+ intermediate_tensors = None
+ if not get_pp_group().is_first_rank:
+ intermediate_tensors = IntermediateTensors(
+ get_pp_group().recv_tensor_dict(all_gather_group=get_tp_group())
+ )
+
+ output = self.model_runner.execute_model(
+ scheduler_output, intermediate_tensors
+ )
+ parallel_config = self.vllm_config.parallel_config
+ if (
+ parallel_config.distributed_executor_backend != "external_launcher"
+ and not get_pp_group().is_last_rank
+ ):
+ assert isinstance(output, IntermediateTensors)
+ get_pp_group().send_tensor_dict(
+ output.tensors, all_gather_group=get_tp_group()
+ )
+
+ kv_connector_output = output.kv_connector_output
+ finished_sending = kv_connector_output.finished_sending
+ finished_recving = kv_connector_output.finished_recving
+
+ if not finished_sending and not finished_recving:
+ return EMPTY_MODEL_RUNNER_OUTPUT
+
+ new_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+ new_output.kv_connector_output = kv_connector_output
+ return new_output
+
+ assert isinstance(output, ModelRunnerOutput)
+ return output
+
+ NPUWorker.execute_model = execute_model
+
+ original_init_worker_distributed_environment = (
+ NPUWorker._init_worker_distributed_environment
+ )
+
+ def patched_init_worker_distributed_environment(self) -> None:
+ original_init_worker_distributed_environment(self)
+ ensure_ucm_sparse_initialized(self.vllm_config)
+
+ NPUWorker._init_worker_distributed_environment = (
+ patched_init_worker_distributed_environment
+ )
+ except ImportError as e:
+ logger.error(f"Failed to patch worker_v1.py: {e}", exc_info=True)
+ raise
diff --git a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py
new file mode 100644
index 000000000..2a697efb0
--- /dev/null
+++ b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py
@@ -0,0 +1,1992 @@
+#
+# MIT License
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+from __future__ import annotations
+
+import os
+
+from ucm.logger import init_logger
+
+logger = init_logger(__name__)
+
+ENABLE_SPARSE = os.getenv("ENABLE_SPARSE")
+
+
+def _enable_sparse() -> bool:
+ return ENABLE_SPARSE is not None and ENABLE_SPARSE.lower() == "true"
+
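+# Illustrative usage (assumption: any launch method works as long as the variable
+# is set before this module is imported):
+#   ENABLE_SPARSE=true <your vLLM launch command>
+# Any other value, or leaving it unset, keeps the sparse adaptation disabled.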
+
+def _apply_sparse_adapt() -> None:
+ """Apply sparse adapt patches."""
+ try:
+ if _enable_sparse():
+ _patch_block_table()
+ _patch_kv_cache_manager()
+ _patch_shared_storage_connector()
+ _patch_attention_layer()
+ _patch_mla_common()
+ _patch_gpu_model_runner()
+ _patch_gpu_worker()
+ _patch_scheduler_output()
+ _patch_scheduler()
+ logger.info("UCM sparse adapt patches applied successfully")
+ except Exception as e:
+ logger.error(f"Could not apply sparse adapt patches: {e}")
+ raise e
+
+
+# ==================== vllm/v1/core/sched/output.py ====================
+def _patch_scheduler_output() -> None:
+ """Patch scheduler output to add UCM sparse support."""
+ try:
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Optional
+
+ if TYPE_CHECKING:
+ import numpy as np
+ import numpy.typing as npt
+ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+ KVConnectorMetadata,
+ )
+ from vllm.v1.core.sched import output
+ from vllm.v1.core.sched.output import CachedRequestData, NewRequestData
+
+ @dataclass
+ class SchedulerOutput:
+
+ # list of the requests that are scheduled for the first time.
+ # We cache the request's data in each worker process, so that we don't
+ # need to re-send it every scheduling step.
+ scheduled_new_reqs: list[NewRequestData]
+ # list of the requests that have been scheduled before.
+ # Since the request's data is already cached in the worker processes,
+ # we only send the diff to minimize the communication cost.
+ scheduled_cached_reqs: CachedRequestData
+
+ # req_id -> num_scheduled_tokens
+ # Number of tokens scheduled for each request.
+ num_scheduled_tokens: dict[str, int]
+ # Total number of tokens scheduled for all requests.
+ # Equal to sum(num_scheduled_tokens.values())
+ total_num_scheduled_tokens: int
+ # req_id -> spec_token_ids
+ # If a request does not have any spec decode tokens, it will not be
+ # included in the dictionary.
+ scheduled_spec_decode_tokens: dict[str, list[int]]
+ # req_id -> encoder input indices that need processing.
+ # E.g., if a request has [0, 1], it could mean the vision encoder needs
+ # to process that the request's 0-th and 1-th images in the current step.
+ scheduled_encoder_inputs: dict[str, list[int]]
+ # Number of common prefix blocks for all requests in each KV cache group.
+ # This can be used for cascade attention.
+ num_common_prefix_blocks: list[int]
+
+ # Request IDs that are finished in between the previous and the current
+ # steps. This is used to notify the workers about the finished requests
+ # so that they can free the cached states for those requests.
+ finished_req_ids: set[str]
+ # list of (req_id, encoder_input_index) tuples.
+ # Used to free the encoder cache.
+ free_encoder_input_ids: list[tuple[str, int]]
+
+ # Dict of request ids to their index within the batch
+ # for filling the next token bitmask
+ structured_output_request_ids: dict[str, int]
+ # the bitmask for the whole batch
+ grammar_bitmask: Optional[npt.NDArray[np.int32]]
+
+ # KV Cache Connector metadata.
+ kv_connector_metadata: Optional[KVConnectorMetadata] = None
+
+ # modified slots by sparse algorithm
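+            # A value of INVALID_SLOT (from ucm.sparse.base) marks a request that
+            # is not sparsified in this step; consumers compare against it before
+            # interpreting the entry as a sparse slot count.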
+            req_sparsed_slots: Optional[dict[str, int]] = None
+
+ # Set module and qualname to make the class pickleable
+ # This ensures pickle can find the class when serializing
+ SchedulerOutput.__module__ = output.__name__
+ SchedulerOutput.__qualname__ = "SchedulerOutput"
+
+ output.SchedulerOutput = SchedulerOutput
+
+ except ImportError:
+ logger.warning("Could not patch scheduler output - module not found")
+
+
+# ==================== vllm/attention/layer.py ====================
+def _patch_attention_layer() -> None:
+ """Patch attention layer & unified_attention_with_output C++ op."""
+ try:
+ from typing import Optional
+
+ import torch
+ from vllm.attention.layer import (
+ maybe_save_kv_layer_to_connector,
+ wait_for_kv_layer_from_connector,
+ )
+ from vllm.forward_context import ForwardContext, get_forward_context
+
+ from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
+
+ def maybe_execute_sparse_attention_begin(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ layer_name: str,
+ forward_context: ForwardContext,
+ phase: Optional[str] = None,
+ ):
+ if not has_ucm_sparse():
+ return
+
+ ucm_sparse = get_ucm_sparse()
+
+ attn_metadata = forward_context.attn_metadata
+ if attn_metadata is None:
+ return
+
+ ucm_sparse.attention_begin(
+ query, key, value, layer_name, forward_context, phase
+ )
+
+ def maybe_execute_sparse_attention_finished(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_output: torch.Tensor,
+ layer_name: str,
+ forward_context: ForwardContext,
+ phase: Optional[str] = None,
+ ):
+ if not has_ucm_sparse():
+ return
+
+ ucm_sparse = get_ucm_sparse()
+
+ attn_metadata = forward_context.attn_metadata
+ if attn_metadata is None:
+ return
+
+ ucm_sparse.attention_finished(
+ query, key, value, attn_output, layer_name, forward_context, phase
+ )
+
+ vllm_ops = torch.ops.vllm
+ orig_unified_attention_with_output = vllm_ops.unified_attention_with_output
+ orig_unified_attention = vllm_ops.unified_attention
+
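+        # torch.ops.* entries are op-overload packets whose __call__ cannot be
+        # monkey-patched directly, so the attribute on torch.ops.vllm is swapped
+        # for a thin proxy: calls are routed to our implementation while every
+        # other attribute lookup (e.g. .default) is delegated to the original op.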
+ def _wrap_op_overload(orig, impl):
+ class _Wrapper:
+ def __init__(self, orig):
+ self._orig = orig
+
+ def __call__(self, *args, **kwargs):
+ return impl(*args, **kwargs)
+
+ def __getattr__(self, name):
+ return getattr(self._orig, name)
+
+ return _Wrapper(orig)
+
+ def unified_attention_impl(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ layer_name: str,
+ ) -> torch.Tensor:
+ wait_for_kv_layer_from_connector(layer_name)
+
+ forward_context: ForwardContext = get_forward_context()
+ attn_metadata = forward_context.attn_metadata
+ if isinstance(attn_metadata, dict):
+ attn_metadata = attn_metadata[layer_name]
+ self = forward_context.no_compile_layers[layer_name]
+ kv_cache = self.kv_cache[forward_context.virtual_engine]
+ maybe_execute_sparse_attention_begin(
+ query, key, value, layer_name, forward_context
+ )
+ output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
+ maybe_execute_sparse_attention_finished(
+ query, key, value, output, layer_name, forward_context
+ )
+ maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+ return output
+
+ def unified_attention_with_output_impl(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ output: torch.Tensor,
+ layer_name: str,
+ output_scale: Optional[torch.Tensor] = None,
+ ) -> None:
+ wait_for_kv_layer_from_connector(layer_name)
+ forward_context: ForwardContext = get_forward_context()
+ attn_metadata = forward_context.attn_metadata
+ if isinstance(attn_metadata, dict):
+ attn_metadata = attn_metadata[layer_name]
+ self = forward_context.no_compile_layers[layer_name]
+ kv_cache = self.kv_cache[forward_context.virtual_engine]
+ if not self.use_mla:
+ maybe_execute_sparse_attention_begin(
+ query, key, value, layer_name, forward_context
+ )
+ self.impl.forward(
+ self,
+ query,
+ key,
+ value,
+ kv_cache,
+ attn_metadata,
+ output=output,
+ output_scale=output_scale,
+ )
+ if not self.use_mla:
+ maybe_execute_sparse_attention_finished(
+ query, key, value, output, layer_name, forward_context
+ )
+
+ maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+
+ vllm_ops.unified_attention_with_output = _wrap_op_overload(
+ orig_unified_attention_with_output, unified_attention_with_output_impl
+ )
+ vllm_ops.unified_attention = _wrap_op_overload(
+ orig_unified_attention, unified_attention_impl
+ )
+ from vllm.attention import layer
+
+ layer.maybe_execute_sparse_attention_begin = (
+ maybe_execute_sparse_attention_begin
+ )
+ layer.maybe_execute_sparse_attention_finished = (
+ maybe_execute_sparse_attention_finished
+ )
+ layer.unified_attention = unified_attention_impl
+ layer.unified_attention_with_output = unified_attention_with_output_impl
+
+ except ImportError:
+ logger.warning(
+ "Could not patch unified attention with output - module not found"
+ )
+
+
+# ==================== v1/shared_storage_connector.py ====================
+def _patch_shared_storage_connector() -> None:
+ """Patch kv connector utils to add UCM sparse support."""
+ try:
+ from dataclasses import dataclass, field
+
+ from vllm.distributed.kv_transfer.kv_connector.v1 import (
+ shared_storage_connector,
+ )
+ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+ KVConnectorMetadata,
+ )
+ from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
+ ReqMeta,
+ )
+
+ @dataclass
+ class SharedStorageConnectorMetadata(KVConnectorMetadata):
+ requests: list[ReqMeta] = field(default_factory=list)
+
+ def add_request(
+ self,
+ token_ids: list[int],
+ block_ids: list[int],
+ block_size: int,
+ is_store: bool,
+ ) -> None:
+ self.requests.append(
+ ReqMeta.make_meta(token_ids, block_ids, block_size, is_store)
+ )
+
+ shared_storage_connector.SharedStorageConnectorMetadata = (
+ SharedStorageConnectorMetadata
+ )
+ except ImportError:
+ logger.warning("Could not patch shared storage connector - module not found")
+
+
+# ==================== vllm/v1/attention/backends/mla/common.py ====================
+def _patch_mla_common() -> None:
+ """Patch mla common to add UCM sparse support."""
+ try:
+ from typing import Optional, TypeVar
+
+ import torch
+ from vllm import _custom_ops as ops
+ from vllm.attention.backends.abstract import AttentionLayer
+ from vllm.attention.layer import (
+ maybe_execute_sparse_attention_begin,
+ maybe_execute_sparse_attention_finished,
+ )
+ from vllm.forward_context import ForwardContext, get_forward_context
+ from vllm.v1.attention.backends.mla.common import (
+ MLACommonImpl,
+ MLACommonMetadata,
+ )
+
+ M = TypeVar("M", bound=MLACommonMetadata)
+
+ def forward(
+ self,
+ layer: AttentionLayer,
+ q: torch.Tensor,
+ k_c_normed: torch.Tensor, # key in unified attn
+ k_pe: torch.Tensor, # value in unified attn
+ kv_cache: torch.Tensor,
+ attn_metadata: M,
+ output: Optional[torch.Tensor] = None,
+ output_scale: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ forward_context: ForwardContext = get_forward_context()
+ assert output is not None, "Output tensor must be provided."
+
+ if output_scale is not None:
+ raise NotImplementedError(
+ "fused output quantization is not yet supported"
+ " for MLACommonImpl"
+ )
+
+ if attn_metadata is None:
+ # The zero fill is required when used with DP + EP
+ # to ensure all ranks within a DP group compute the
+ # same expert outputs.
+ return output.fill_(0)
+
+ num_actual_toks = attn_metadata.num_actual_tokens
+
+ # Inputs and outputs may be padded for CUDA graphs
+ output_padded = output
+ output = output[:num_actual_toks, ...]
+ q = q[:num_actual_toks, ...]
+ k_c_normed = k_c_normed[:num_actual_toks, ...]
+ k_pe = k_pe[:num_actual_toks, ...]
+
+ assert (
+ attn_metadata.num_decodes is not None
+ and attn_metadata.num_prefills is not None
+ and attn_metadata.num_decode_tokens is not None
+ )
+
+ has_decode = attn_metadata.num_decodes > 0
+ has_prefill = attn_metadata.num_prefills > 0
+ num_decode_tokens = attn_metadata.num_decode_tokens
+
+ decode_q = q[:num_decode_tokens]
+
+ prefill_q = q[num_decode_tokens:]
+ prefill_k_pe = k_pe[num_decode_tokens:]
+ prefill_k_c_normed = k_c_normed[num_decode_tokens:]
+
+ # write the latent and rope to kv cache
+ if kv_cache.numel() > 0:
+ ops.concat_and_cache_mla(
+ k_c_normed,
+ k_pe.squeeze(1),
+ kv_cache,
+ attn_metadata.slot_mapping.flatten(),
+ kv_cache_dtype=self.kv_cache_dtype,
+ scale=layer._k_scale,
+ )
+
+ if has_prefill:
+ maybe_execute_sparse_attention_begin(
+ prefill_q,
+ prefill_k_c_normed,
+ prefill_k_pe,
+ layer.layer_name,
+ forward_context,
+ "prefill",
+ )
+ output[num_decode_tokens:] = self._forward_prefill(
+ prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, attn_metadata
+ )
+ maybe_execute_sparse_attention_finished(
+ prefill_q,
+ prefill_k_c_normed,
+ prefill_k_pe,
+ output[num_decode_tokens:],
+ layer.layer_name,
+ forward_context,
+ "prefill",
+ )
+ if has_decode:
+ assert attn_metadata.decode is not None
+ decode_q_nope, decode_q_pe = decode_q.split(
+ [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+ )
+ # Convert from (B, N, P) to (N, B, P)
+ decode_q_nope = decode_q_nope.transpose(0, 1)
+ # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+ decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+ # Convert from (N, B, L) to (B, N, L)
+ decode_ql_nope = decode_ql_nope.transpose(0, 1)
+ maybe_execute_sparse_attention_begin(
+ torch.cat([decode_ql_nope, decode_q_pe], dim=-1),
+ decode_ql_nope,
+ decode_q_pe,
+ layer.layer_name,
+ forward_context,
+ "decode",
+ )
+ output[:num_decode_tokens] = self._forward_decode(
+ decode_ql_nope, decode_q_pe, kv_cache, attn_metadata
+ )
+ maybe_execute_sparse_attention_finished(
+ torch.cat([decode_ql_nope, decode_q_pe], dim=-1),
+ decode_ql_nope,
+ decode_q_pe,
+ output[:num_decode_tokens],
+ layer.layer_name,
+ forward_context,
+ "decode",
+ )
+ return output_padded
+
+ MLACommonImpl.forward = forward
+ except ImportError:
+ logger.warning("Could not patch mla common - module not found")
+
+
+# ==================== v1/core/kv_cache_manager.py ====================
+def _patch_kv_cache_manager() -> None:
+ """Patch kv cache manager to add UCM sparse support."""
+ try:
+ from typing import Optional, Union
+
+ from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
+ from vllm.v1.request import Request
+
+ from ucm.sparse.base import INVALID_SLOT
+ from ucm.sparse.state import get_ucm_sparse
+
+ original_allocate_slots = KVCacheManager.allocate_slots
+
+ def patched_allocate_slots(
+ self,
+ request: Request,
+ num_new_tokens: int,
+ num_new_computed_tokens: int = 0,
+ new_computed_blocks: Optional[KVCacheBlocks] = None,
+ num_draft_tokens: int = 0,
+ num_lookahead_tokens: int = 0,
+ delay_cache_blocks: bool = False,
+ num_slots_sparsed: Union[None, int] = None,
+ ) -> Optional[KVCacheBlocks]:
+ if num_new_tokens == 0:
+ raise ValueError("num_new_tokens must be greater than 0")
+ # Only route to UCM sparse path when caller explicitly provided
+ # a valid sparsified slot count.
+ if (num_slots_sparsed is not None) and (num_slots_sparsed != INVALID_SLOT):
+ return get_ucm_sparse().allocate_slots(self, request, num_slots_sparsed)
+ return original_allocate_slots(
+ self,
+ request,
+ num_new_tokens,
+ num_new_computed_tokens,
+ new_computed_blocks,
+ num_draft_tokens,
+ num_lookahead_tokens,
+ delay_cache_blocks,
+ )
+
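+        # The patched scheduler (see _patch_scheduler below) passes
+        # num_slots_sparsed on every allocation; callers that omit it fall
+        # through to the original allocator unchanged.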
+ KVCacheManager.allocate_slots = patched_allocate_slots
+ except ImportError:
+ logger.warning("Could not patch kv cache manager - module not found")
+
+
+# ==================== vllm/v1/core/sched/scheduler.py ====================
+def _patch_scheduler() -> None:
+ """Patch Scheduler to add num_output_tokens field."""
+ try:
+ import itertools
+ import time
+ from collections import defaultdict
+ from collections.abc import Iterable
+ from typing import Optional
+
+ from vllm.distributed.kv_events import KVEventBatch
+ from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
+ MultiConnector,
+ )
+ from vllm.v1.core.sched.output import (
+ CachedRequestData,
+ NewRequestData,
+ SchedulerOutput,
+ )
+ from vllm.v1.core.sched.request_queue import (
+ SchedulingPolicy,
+ create_request_queue,
+ )
+ from vllm.v1.core.sched.scheduler import Scheduler
+ from vllm.v1.core.sched.utils import check_stop
+ from vllm.v1.engine import (
+ EngineCoreEventType,
+ EngineCoreOutput,
+ EngineCoreOutputs,
+ )
+ from vllm.v1.outputs import ModelRunnerOutput
+ from vllm.v1.request import Request, RequestStatus
+ from vllm.v1.spec_decode.metrics import SpecDecodingStats
+
+ from ucm.sparse.base import INVALID_SLOT, UcmSparseRole
+ from ucm.sparse.state import ensure_ucm_sparse_initialized, get_ucm_sparse
+
+ def init_ucm_sparse(self):
+ self.ucm_sparse = None
+ if self.vllm_config.kv_transfer_config is not None:
+ if (
+ "ucm_sparse_config"
+ in self.vllm_config.kv_transfer_config.kv_connector_extra_config
+ ):
+ ensure_ucm_sparse_initialized(
+ self.vllm_config, role=UcmSparseRole.SCHEDULER
+ )
+ self.ucm_sparse = get_ucm_sparse()
+ logger.info(
+ "UCM Sparse initialized successfully: {}".format(
+ self.ucm_sparse
+ )
+ )
+
+ def patched_schedule(self) -> SchedulerOutput:
+ # NOTE(woosuk) on the scheduling algorithm:
+ # There's no "decoding phase" nor "prefill phase" in the scheduler.
+ # Each request just has the num_computed_tokens and
+ # num_tokens_with_spec. num_tokens_with_spec =
+ # len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
+ # At each step, the scheduler tries to assign tokens to the requests
+ # so that each request's num_computed_tokens can catch up its
+ # num_tokens_with_spec. This is general enough to cover
+ # chunked prefills, prefix caching, speculative decoding,
+ # and the "jump decoding" optimization in the future.
+
+ scheduled_new_reqs: list[Request] = []
+ scheduled_resumed_reqs: list[Request] = []
+ scheduled_running_reqs: list[Request] = []
+ preempted_reqs: list[Request] = []
+
+ # NOTE: structured_output_request_ids maps
+ # a request's (request that uses structured output)
+ # request_id to the running request index.
+            # This helps us determine how to slice the grammar bitmask
+            # and apply a valid mask only for requests that
+            # use structured decoding.
+ structured_output_request_ids: dict[str, int] = {}
+
+ req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {}
+ num_scheduled_tokens: dict[str, int] = {}
+ token_budget = self.max_num_scheduled_tokens
+ # Encoder-related.
+ scheduled_encoder_inputs: dict[str, list[int]] = {}
+ encoder_budget = self.max_num_encoder_input_tokens
+ # Spec decode-related.
+ scheduled_spec_decode_tokens: dict[str, list[int]] = {}
+
+ # For logging.
+ scheduled_timestamp = time.monotonic()
+
+ # First, schedule the RUNNING requests.
+ req_index = 0
+ req_sparsed_slots: dict[str, int] = {}
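+            # The scheduler's __init__ is not patched, so the UCM sparse handle
+            # is created lazily on the first schedule() call.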
+ if not hasattr(self, "ucm_sparse"):
+ init_ucm_sparse(self)
+ while req_index < len(self.running) and token_budget > 0:
+ request = self.running[req_index]
+ num_slots_sparsed = INVALID_SLOT
+ if self.ucm_sparse:
+ num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(
+ request
+ )
+ req_sparsed_slots.update({request.request_id: num_slots_sparsed})
+
+ num_new_tokens = (
+ request.num_tokens_with_spec - request.num_computed_tokens
+ )
+ if (
+ 0
+ < self.scheduler_config.long_prefill_token_threshold
+ < num_new_tokens
+ ):
+ num_new_tokens = self.scheduler_config.long_prefill_token_threshold
+ num_new_tokens = min(num_new_tokens, token_budget)
+
+ # Make sure the input position does not exceed the max model len.
+ # This is necessary when using spec decoding.
+ num_new_tokens = min(
+ num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
+ )
+
+ # Schedule encoder inputs.
+ encoder_inputs_to_schedule = None
+ new_encoder_budget = encoder_budget
+ if request.has_encoder_inputs:
+ (encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget) = (
+ self._try_schedule_encoder_inputs(
+ request,
+ request.num_computed_tokens,
+ num_new_tokens,
+ encoder_budget,
+ )
+ )
+
+ if num_new_tokens == 0:
+ # The request cannot be scheduled because one of the following
+ # reasons:
+ # 1. No new tokens to schedule. This may happen when PP>1 and
+ # we have already scheduled all prompt tokens but they are
+ # not finished yet.
+ # 2. The encoder budget is exhausted.
+ # 3. The encoder cache is exhausted.
+ # NOTE(woosuk): Here, by doing `continue` instead of `break`,
+ # we do not strictly follow the FCFS scheduling policy and
+ # allow the lower-priority requests to be scheduled.
+ req_index += 1
+ continue
+
+ num_draft_tokens = max(
+ num_new_tokens + request.num_computed_tokens - request.num_tokens, 0
+ )
+
+ while True:
+ new_blocks = self.kv_cache_manager.allocate_slots(
+ request,
+ num_new_tokens,
+ num_draft_tokens=num_draft_tokens,
+ num_lookahead_tokens=self.num_lookahead_tokens,
+ num_slots_sparsed=num_slots_sparsed,
+ )
+ if new_blocks is None:
+ # The request cannot be scheduled.
+ # Preempt the lowest-priority request.
+ if self.policy == SchedulingPolicy.PRIORITY:
+ preempted_req = max(
+ self.running,
+ key=lambda r: (r.priority, r.arrival_time),
+ )
+ self.running.remove(preempted_req)
+ else:
+ preempted_req = self.running.pop()
+
+ self.kv_cache_manager.free(preempted_req)
+ preempted_req.status = RequestStatus.PREEMPTED
+ preempted_req.num_computed_tokens = 0
+ if self.log_stats:
+ preempted_req.record_event(
+ EngineCoreEventType.PREEMPTED, scheduled_timestamp
+ )
+
+ self.waiting.prepend_request(preempted_req)
+ preempted_reqs.append(preempted_req)
+ if preempted_req == request:
+ # No more request to preempt.
+ can_schedule = False
+ break
+ else:
+ # The request can be scheduled.
+ can_schedule = True
+ break
+ if not can_schedule:
+ break
+ assert new_blocks is not None
+
+ # Schedule the request.
+ scheduled_running_reqs.append(request)
+ if request.use_structured_output:
+ # PERF: in case of chunked prefill,
+ # request might not include any new tokens.
+ # Therefore, we might introduce some additional
+ # cycle to fill in the bitmask, which could be a big no-op.
+ structured_output_request_ids[request.request_id] = req_index
+ req_to_new_block_ids[request.request_id] = new_blocks.get_block_ids()
+ num_scheduled_tokens[request.request_id] = num_new_tokens
+ token_budget -= num_new_tokens
+ req_index += 1
+
+ # Speculative decode related.
+ if request.spec_token_ids:
+ num_scheduled_spec_tokens = (
+ num_new_tokens
+ + request.num_computed_tokens
+ - request.num_tokens
+ )
+ if num_scheduled_spec_tokens > 0:
+ # Trim spec_token_ids list to num_scheduled_spec_tokens.
+ del request.spec_token_ids[num_scheduled_spec_tokens:]
+ scheduled_spec_decode_tokens[request.request_id] = (
+ request.spec_token_ids
+ )
+
+ # Encoder-related.
+ if encoder_inputs_to_schedule:
+ scheduled_encoder_inputs[request.request_id] = (
+ encoder_inputs_to_schedule
+ )
+ # Allocate the encoder cache.
+ for i in encoder_inputs_to_schedule:
+ self.encoder_cache_manager.allocate(request, i)
+ encoder_budget = new_encoder_budget
+
+ # Record the LoRAs in scheduled_running_reqs
+ scheduled_loras: set[int] = set()
+ if self.lora_config:
+ scheduled_loras = set(
+ req.lora_request.lora_int_id
+ for req in scheduled_running_reqs
+ if req.lora_request and req.lora_request.lora_int_id > 0
+ )
+ assert len(scheduled_loras) <= self.lora_config.max_loras
+
+ # Use a temporary RequestQueue to collect requests that need to be
+ # skipped and put back at the head of the waiting queue later
+ skipped_waiting_requests = create_request_queue(self.policy)
+
+ # Next, schedule the WAITING requests.
+ if not preempted_reqs:
+ while self.waiting and token_budget > 0:
+ if len(self.running) == self.max_num_running_reqs:
+ break
+
+ request = self.waiting.peek_request()
+ num_slots_sparsed = INVALID_SLOT
+ if self.ucm_sparse:
+ num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(
+ request
+ )
+ req_sparsed_slots.update({request.request_id: num_slots_sparsed})
+
+ # KVTransfer: skip request if still waiting for remote kvs.
+ if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
+ is_ready = self._update_waiting_for_remote_kv(request)
+ if is_ready:
+ request.status = RequestStatus.WAITING
+ else:
+ logger.debug(
+ "%s is still in WAITING_FOR_REMOTE_KVS state.",
+ request.request_id,
+ )
+ self.waiting.pop_request()
+ skipped_waiting_requests.prepend_request(request)
+ continue
+
+ # Skip request if the structured output request is still waiting
+ # for FSM compilation.
+ if request.status == RequestStatus.WAITING_FOR_FSM:
+ structured_output_req = request.structured_output_request
+ if structured_output_req and structured_output_req.grammar:
+ request.status = RequestStatus.WAITING
+ else:
+ self.waiting.pop_request()
+ skipped_waiting_requests.prepend_request(request)
+ continue
+
+ # Check that adding the request still respects the max_loras
+ # constraint.
+ if (
+ self.lora_config
+ and request.lora_request
+ and (
+ len(scheduled_loras) == self.lora_config.max_loras
+ and request.lora_request.lora_int_id not in scheduled_loras
+ )
+ ):
+ # Scheduling would exceed max_loras, skip.
+ self.waiting.pop_request()
+ skipped_waiting_requests.prepend_request(request)
+ continue
+
+ num_external_computed_tokens = 0
+ load_kv_async = False
+
+ # Get already-cached tokens.
+ if request.num_computed_tokens == 0:
+ # Get locally-cached tokens.
+ new_computed_blocks, num_new_local_computed_tokens = (
+ self.kv_cache_manager.get_computed_blocks(request)
+ )
+
+ # Get externally-cached tokens if using a KVConnector.
+ if self.connector is not None:
+ num_external_computed_tokens, load_kv_async = (
+ self.connector.get_num_new_matched_tokens(
+ request, num_new_local_computed_tokens
+ )
+ )
+
+ # Total computed tokens (local + external).
+ num_computed_tokens = (
+ num_new_local_computed_tokens + num_external_computed_tokens
+ )
+ # KVTransfer: WAITING reqs have num_computed_tokens > 0
+ # after async KV recvs are completed.
+ else:
+ new_computed_blocks = (
+ self.kv_cache_manager.create_empty_block_list()
+ )
+ num_new_local_computed_tokens = 0
+ num_computed_tokens = request.num_computed_tokens
+
+ encoder_inputs_to_schedule = None
+ new_encoder_budget = encoder_budget
+
+ # KVTransfer: loading remote KV, do not allocate for new work.
+ if load_kv_async:
+ assert num_external_computed_tokens > 0
+ num_new_tokens = 0
+ # Number of tokens to be scheduled.
+ else:
+ # We use `request.num_tokens` instead of
+ # `request.num_prompt_tokens` to consider the resumed
+ # requests, which have output tokens.
+ num_new_tokens = request.num_tokens - num_computed_tokens
+ if (
+ 0
+ < self.scheduler_config.long_prefill_token_threshold
+ < num_new_tokens
+ ):
+ num_new_tokens = (
+ self.scheduler_config.long_prefill_token_threshold
+ )
+
+ # chunked prefill has to be enabled explicitly to allow
+ # pooling requests to be chunked
+ if (
+ not self.scheduler_config.chunked_prefill_enabled
+ and num_new_tokens > token_budget
+ ):
+ self.waiting.pop_request()
+ skipped_waiting_requests.prepend_request(request)
+ continue
+
+ num_new_tokens = min(num_new_tokens, token_budget)
+ assert num_new_tokens > 0
+
+ # Schedule encoder inputs.
+ if request.has_encoder_inputs:
+ (
+ encoder_inputs_to_schedule,
+ num_new_tokens,
+ new_encoder_budget,
+ ) = self._try_schedule_encoder_inputs(
+ request,
+ num_computed_tokens,
+ num_new_tokens,
+ encoder_budget,
+ )
+ if num_new_tokens == 0:
+ # The request cannot be scheduled.
+ break
+
+ new_blocks = self.kv_cache_manager.allocate_slots(
+ request,
+ num_new_tokens + num_external_computed_tokens,
+ num_new_local_computed_tokens,
+ new_computed_blocks,
+ num_lookahead_tokens=self.num_lookahead_tokens,
+ delay_cache_blocks=load_kv_async,
+ num_slots_sparsed=num_slots_sparsed,
+ )
+ if new_blocks is None:
+ # The request cannot be scheduled.
+ break
+
+                    # KVTransfer: the connector uses this information to
+                    # determine whether a KV load is needed for this request.
+ if self.connector is not None:
+ self.connector.update_state_after_alloc(
+ request,
+ new_computed_blocks + new_blocks,
+ num_external_computed_tokens,
+ )
+
+ # Request was already popped from self.waiting
+ # unless it was re-added above due to new_blocks being None.
+ request = self.waiting.pop_request()
+ if load_kv_async:
+ # If loading async, allocate memory and put request
+ # into the WAITING_FOR_REMOTE_KV state.
+ skipped_waiting_requests.prepend_request(request)
+ request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
+ continue
+
+ if request.use_structured_output:
+ structured_output_request_ids[request.request_id] = req_index
+ req_index += 1
+ self.running.append(request)
+ if self.log_stats:
+ request.record_event(
+ EngineCoreEventType.SCHEDULED, scheduled_timestamp
+ )
+ if request.status == RequestStatus.WAITING:
+ scheduled_new_reqs.append(request)
+ elif request.status == RequestStatus.PREEMPTED:
+ scheduled_resumed_reqs.append(request)
+ else:
+ raise RuntimeError(f"Invalid request status: {request.status}")
+
+ if self.lora_config and request.lora_request:
+ scheduled_loras.add(request.lora_request.lora_int_id)
+ req_to_new_block_ids[request.request_id] = (
+ self.kv_cache_manager.get_block_ids(request.request_id)
+ )
+ num_scheduled_tokens[request.request_id] = num_new_tokens
+ token_budget -= num_new_tokens
+ request.status = RequestStatus.RUNNING
+ request.num_computed_tokens = num_computed_tokens
+ # Count the number of prefix cached tokens.
+ if request.num_cached_tokens < 0:
+ request.num_cached_tokens = num_computed_tokens
+ # Encoder-related.
+ if encoder_inputs_to_schedule:
+ scheduled_encoder_inputs[request.request_id] = (
+ encoder_inputs_to_schedule
+ )
+ # Allocate the encoder cache.
+ for i in encoder_inputs_to_schedule:
+ self.encoder_cache_manager.allocate(request, i)
+ encoder_budget = new_encoder_budget
+
+ # Put back any skipped requests at the head of the waiting queue
+ if skipped_waiting_requests:
+ self.waiting.prepend_requests(skipped_waiting_requests)
+
+ # Check if the scheduling constraints are satisfied.
+ total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
+ assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
+ assert token_budget >= 0
+ assert len(self.running) <= self.max_num_running_reqs
+ # Since some requests in the RUNNING queue may not be scheduled in
+ # this step, the total number of scheduled requests can be smaller than
+ # len(self.running).
+ assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
+ scheduled_running_reqs
+ ) <= len(self.running)
+
+ # Get the longest common prefix among all requests in the running queue.
+ # This can be potentially used for cascade attention.
+ num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
+ if self.running:
+ any_request = self.running[0]
+ num_common_prefix_blocks = (
+ self.kv_cache_manager.get_num_common_prefix_blocks(
+ any_request, len(self.running)
+ )
+ )
+
+ grammar_bitmask = self.structured_output_manager.grammar_bitmask(
+ self.requests,
+ structured_output_request_ids,
+ scheduled_spec_decode_tokens,
+ )
+ # Construct the scheduler output.
+ new_reqs_data = [
+ NewRequestData.from_request(req, req_to_new_block_ids[req.request_id])
+ for req in scheduled_new_reqs
+ ]
+ cached_reqs_data = self._make_cached_request_data(
+ scheduled_running_reqs,
+ scheduled_resumed_reqs,
+ num_scheduled_tokens,
+ scheduled_spec_decode_tokens,
+ req_to_new_block_ids,
+ )
+ scheduler_output = SchedulerOutput(
+ scheduled_new_reqs=new_reqs_data,
+ scheduled_cached_reqs=cached_reqs_data,
+ num_scheduled_tokens=num_scheduled_tokens,
+ total_num_scheduled_tokens=total_num_scheduled_tokens,
+ scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
+ scheduled_encoder_inputs=scheduled_encoder_inputs,
+ num_common_prefix_blocks=num_common_prefix_blocks,
+ req_sparsed_slots=req_sparsed_slots,
+ # finished_req_ids is an existing state in the scheduler,
+ # instead of being newly scheduled in this step.
+ # It contains the request IDs that are finished in between
+ # the previous and the current steps.
+ finished_req_ids=self.finished_req_ids,
+ free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
+ structured_output_request_ids=structured_output_request_ids,
+ grammar_bitmask=grammar_bitmask,
+ )
+
+ # NOTE(Kuntai): this function is designed for multiple purposes:
+ # 1. Plan the KV cache store
+ # 2. Wrap up all the KV cache load / save ops into an opaque object
+ # 3. Clear the internal states of the connector
+ if self.connector is not None:
+ meta = self.connector.build_connector_meta(scheduler_output)
+ scheduler_output.kv_connector_metadata = meta
+
+ events = self.kv_cache_manager.take_events()
+ if events:
+ batch = KVEventBatch(ts=time.time(), events=events)
+ self.kv_event_publisher.publish(batch)
+
+ self._update_after_schedule(scheduler_output)
+ return scheduler_output
+
+ Scheduler.schedule = patched_schedule
+
+ def patched_add_request(self, request: Request) -> None:
+ if not hasattr(self, "ucm_sparse"):
+ init_ucm_sparse(self)
+ self.waiting.add_request(request)
+ self.requests[request.request_id] = request
+ if self.ucm_sparse:
+ self.ucm_sparse.request_begin(
+ request.request_id, request.prompt_token_ids
+ )
+ if self.log_stats:
+ request.record_event(EngineCoreEventType.QUEUED)
+
+ Scheduler.add_request = patched_add_request
+
+ original_free_request = Scheduler._free_request
+
+ def patched_free_request(self, request: Request):
+ assert request.is_finished()
+ if not hasattr(self, "ucm_sparse"):
+ init_ucm_sparse(self)
+ if self.ucm_sparse:
+ self.ucm_sparse.request_finished_in_scheduler(request.request_id)
+ original_free_request(self, request)
+
+ Scheduler._free_request = patched_free_request
+ except ImportError:
+ logger.warning("Could not patch Scheduler - module not found")
+
+
+# ==================== vllm/v1/worker/block_table.py ====================
+def _patch_block_table() -> None:
+ """Patch block table to add UCM sparse support."""
+ try:
+ from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
+
+ def reset_row(
+ self,
+ row_idx: int,
+ ) -> None:
+ self.num_blocks_per_row[row_idx] = 0
+ self.block_table[row_idx].fill_(0)
+ self.block_table_cpu[row_idx].fill_(0)
+ self.block_table_np[row_idx].fill(0)
+
+ BlockTable.reset_row = reset_row
+
+ def reset_row(self, row_idx: int) -> None:
+            for block_table in self.block_tables:
+ block_table.reset_row(row_idx)
+
+ MultiGroupBlockTable.reset_row = reset_row
+ except ImportError:
+ logger.warning("Could not patch multigroup block table - module not found")
+
+
+# ==================== vllm/v1/worker/gpu_model_runner.py ====================
+def _patch_gpu_model_runner() -> None:
+ """Patch gpu model runner to add UCM sparse support."""
+ try:
+ import copy
+        from typing import TYPE_CHECKING, Any, Optional, Union
+
+ import numpy as np
+ import torch
+ import vllm.envs as envs
+ from vllm.distributed.kv_transfer import (
+ get_kv_transfer_group,
+ has_kv_transfer_group,
+ )
+        from vllm.distributed.parallel_state import get_pp_group, get_tp_group
+ from vllm.forward_context import set_forward_context
+ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
+ from vllm.sampling_params import SamplingType
+ from vllm.sequence import IntermediateTensors
+ from vllm.utils import round_up
+ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+ from vllm.v1.worker.block_table import BlockTable
+ from vllm.v1.worker.gpu_input_batch import CachedRequestState
+
+ from ucm.sparse.base import INVALID_SLOT
+ from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
+
+ if TYPE_CHECKING:
+ from vllm.v1.core.sched.output import SchedulerOutput
+
+ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+
+ @staticmethod
+ def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]:
+ if has_kv_transfer_group():
+ return get_kv_transfer_group().wait_for_save()
+ return None
+
+ GPUModelRunner.maybe_wait_for_kv_save = maybe_wait_for_kv_save
+
+ def maybe_execute_ucm_sparse_begin(
+ self,
+ scheduler_output: "SchedulerOutput",
+ attn_metadata: CommonAttentionMetadata,
+ ):
+ if not has_ucm_sparse():
+ return
+ ucm_sparse = get_ucm_sparse()
+ ucm_sparse.build_sparse_meta(
+ scheduler_output, self.requests, self.input_batch, attn_metadata
+ )
+ ucm_sparse.execute_begin(scheduler_output)
+
+ def maybe_execute_ucm_sparse_finished(self):
+ if not has_ucm_sparse():
+ return
+ ucm_sparse = get_ucm_sparse()
+ ucm_sparse.execute_finished()
+
+ def ucm_sparse_request_finished_in_worker(self, request_id: str | int):
+ if not has_ucm_sparse():
+ return
+ ucm_sparse = get_ucm_sparse()
+ ucm_sparse.request_finished_in_worker(request_id)
+
+ GPUModelRunner.maybe_execute_ucm_sparse_begin = maybe_execute_ucm_sparse_begin
+ GPUModelRunner.maybe_execute_ucm_sparse_finished = (
+ maybe_execute_ucm_sparse_finished
+ )
+ GPUModelRunner.ucm_sparse_request_finished_in_worker = (
+ ucm_sparse_request_finished_in_worker
+ )
+
+ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
+ """Update the cached states and the persistent batch with the scheduler
+ output.
+
+ The updated states are used by the `_prepare_inputs` function to create
+ the input GPU tensors for the model.
+
+ The SamplingMetadata is updated and copied to the GPU if there is a
+ new/resumed/paused/finished request in the batch.
+ """
+ # Remove finished requests from the cached states.
+ for req_id in scheduler_output.finished_req_ids:
+ self.ucm_sparse_request_finished_in_worker(req_id)
+ self.requests.pop(req_id, None)
+ self.encoder_cache.pop(req_id, None)
+ # Remove the finished requests from the persistent batch.
+ # NOTE(woosuk): There could be an edge case where finished_req_ids and
+ # scheduled_req_ids overlap. This happens when a request is aborted and
+ # then resubmitted with the same ID. In this case, we treat them as two
+ # distinct requests - clearing the cached states for the first request
+ # and handling the second as a new request.
+ for req_id in scheduler_output.finished_req_ids:
+ self.input_batch.remove_request(req_id)
+
+ # Free the cached encoder outputs.
+ for req_id, input_id in scheduler_output.free_encoder_input_ids:
+ encoder_outputs = self.encoder_cache.get(req_id)
+ if encoder_outputs is not None:
+ encoder_outputs.pop(input_id, None)
+ if not encoder_outputs:
+ self.encoder_cache.pop(req_id, None)
+
+ # Remove the unscheduled requests from the persistent batch.
+ # NOTE(woosuk): The unscheduled requests are either preempted requests
+ # or running requests that are not scheduled in this step. We remove
+ # them from the persistent batch but keep their cached states since
+ # they will be scheduled again sometime in the future.
+ scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
+ cached_req_ids = self.input_batch.req_id_to_index.keys()
+ unscheduled_req_ids = cached_req_ids - scheduled_req_ids
+ # NOTE(woosuk): The persistent batch optimization assumes that
+ # consecutive batches contain mostly the same requests. If batches
+ # have low request overlap (e.g., alternating between two distinct
+ # sets of requests), this optimization becomes very inefficient.
+ for req_id in unscheduled_req_ids:
+ self.input_batch.remove_request(req_id)
+
+ req_ids_to_add: list[str] = []
+ # Add new requests to the cached states.
+ for new_req_data in scheduler_output.scheduled_new_reqs:
+ req_id = new_req_data.req_id
+ sampling_params = new_req_data.sampling_params
+ pooling_params = new_req_data.pooling_params
+ if (
+ sampling_params
+ and sampling_params.sampling_type == SamplingType.RANDOM_SEED
+ ):
+ generator = torch.Generator(device=self.device)
+ generator.manual_seed(sampling_params.seed)
+ else:
+ generator = None
+
+ self.requests[req_id] = CachedRequestState(
+ req_id=req_id,
+ prompt_token_ids=new_req_data.prompt_token_ids,
+ mm_inputs=new_req_data.mm_inputs,
+ mm_positions=new_req_data.mm_positions,
+ sampling_params=sampling_params,
+ pooling_params=pooling_params,
+ generator=generator,
+ block_ids=new_req_data.block_ids,
+ num_computed_tokens=new_req_data.num_computed_tokens,
+ output_token_ids=[],
+ lora_request=new_req_data.lora_request,
+ )
+
+ # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
+ if self.uses_mrope:
+ image_grid_thw = []
+ video_grid_thw = []
+ second_per_grid_ts = []
+ audio_feature_lengths = []
+ use_audio_in_video = False
+ for mm_input in self.requests[req_id].mm_inputs:
+ if mm_input.get("image_grid_thw") is not None:
+ image_grid_thw.extend(mm_input["image_grid_thw"].tolist())
+ if mm_input.get("video_grid_thw") is not None:
+ video_grid_thw.extend(mm_input["video_grid_thw"].tolist())
+ if mm_input.get("second_per_grid_ts") is not None:
+ second_per_grid_ts.extend(mm_input["second_per_grid_ts"])
+ if mm_input.get("audio_feature_lengths") is not None:
+ audio_feature_lengths.extend(
+ mm_input["audio_feature_lengths"]
+ )
+ if mm_input.get("use_audio_in_video") is True:
+ use_audio_in_video = True
+
+ hf_config = self.model_config.hf_config
+
+ (
+ self.requests[req_id].mrope_positions,
+ self.requests[req_id].mrope_position_delta,
+ ) = MRotaryEmbedding.get_input_positions_tensor(
+ self.requests[req_id].prompt_token_ids,
+ hf_config=hf_config,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ audio_feature_lengths=audio_feature_lengths,
+ use_audio_in_video=use_audio_in_video,
+ )
+
+ req_ids_to_add.append(req_id)
+
+ # Update the states of the running/resumed requests.
+ is_last_rank = get_pp_group().is_last_rank
+ req_data = scheduler_output.scheduled_cached_reqs
+ req_sparsed_slots = scheduler_output.req_sparsed_slots
+ for i, req_id in enumerate(req_data.req_ids):
+ req_state = self.requests[req_id]
+ num_computed_tokens = req_data.num_computed_tokens[i]
+ new_block_ids = req_data.new_block_ids[i]
+ resumed_from_preemption = req_data.resumed_from_preemption[i]
+ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT
+
+ # Update the cached states.
+ req_state.num_computed_tokens = num_computed_tokens
+
+ if not is_last_rank:
+ # When using PP, the scheduler sends the sampled tokens back,
+ # because there's no direct communication between the first-
+ # stage worker and the last-stage worker.
+ new_token_ids = req_data.new_token_ids[i]
+ # Add the sampled token(s) from the previous step (if any).
+ # This doesn't include "unverified" tokens like spec tokens.
+ num_new_tokens = (
+ num_computed_tokens + len(new_token_ids) - req_state.num_tokens
+ )
+ if num_new_tokens == 1:
+ # Avoid slicing list in most common case.
+ req_state.output_token_ids.append(new_token_ids[-1])
+ elif num_new_tokens > 0:
+ req_state.output_token_ids.extend(
+ new_token_ids[-num_new_tokens:]
+ )
+
+ # Update the block IDs.
+ if resumed_from_preemption or is_sparsed_request:
+                # The request was resumed from preemption or had its blocks
+                # reallocated for sparsing. Replace the existing block IDs
+                # with the new ones.
+ req_state.block_ids = new_block_ids
+ else:
+ # Append the new blocks to the existing block IDs.
+ for block_ids, new_ids in zip(req_state.block_ids, new_block_ids):
+ block_ids.extend(new_ids)
+
+ req_index = self.input_batch.req_id_to_index.get(req_id)
+ if req_index is None:
+ # The request is not in the persistent batch.
+ # The request was either preempted and resumed later, or was not
+ # scheduled in the previous step and needs to be added again.
+ req_ids_to_add.append(req_id)
+ continue
+
+ # Update the persistent batch.
+ self.input_batch.num_computed_tokens_cpu[req_index] = (
+ num_computed_tokens
+ )
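+                # Sparsed requests arrive with a full replacement set of block
+                # IDs, so clear the row before appending the new blocks.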
+ if is_sparsed_request:
+ self.input_batch.block_table.reset_row(req_index)
+ self.input_batch.block_table.append_row(new_block_ids, req_index)
+
+ # For the last rank, we don't need to update the token_ids_cpu
+ # because the sampled tokens are already cached.
+ if not is_last_rank:
+ # Add new_token_ids to token_ids_cpu.
+ start_token_index = num_computed_tokens
+ end_token_index = num_computed_tokens + len(new_token_ids)
+ self.input_batch.token_ids_cpu[
+ req_index, start_token_index:end_token_index
+ ] = new_token_ids
+ self.input_batch.num_tokens_no_spec[req_index] = end_token_index
+ # Add spec_token_ids to token_ids_cpu.
+ spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
+ req_id, ()
+ )
+ if spec_token_ids:
+ start_index = end_token_index
+ end_token_index += len(spec_token_ids)
+ self.input_batch.token_ids_cpu[
+ req_index, start_index:end_token_index
+ ] = spec_token_ids
+ # NOTE(woosuk): `num_tokens` here may include spec tokens.
+ self.input_batch.num_tokens[req_index] = end_token_index
+
+ # Add the new or resumed requests to the persistent batch.
+ # The smaller empty indices are filled first.
+ for req_id in req_ids_to_add:
+ req_state = self.requests[req_id]
+ self.input_batch.add_request(req_state)
+
+ # Condense the batched states if there are gaps left by removed requests
+ self.input_batch.condense()
+            # Allow the attention backend to reorder the batch if beneficial.
+ self._may_reorder_batch(scheduler_output)
+ # Refresh batch metadata with any pending updates.
+ self.input_batch.refresh_metadata()
+
+ GPUModelRunner._update_states = _update_states
+
+ def _prepare_inputs(
+ self,
+ scheduler_output: "SchedulerOutput",
+ ) -> tuple[
+ dict[str, Any], bool, torch.Tensor, Optional[SpecDecodeMetadata], np.ndarray
+ ]:
+ """
+            :return: tuple[
+                attn_metadata: layer-to-attention_metadata mapping,
+                attention_cuda_graphs: whether attention can run in cudagraph,
+                logits_indices, spec_decode_metadata,
+                num_scheduled_tokens (np.ndarray)
+            ]
+ """
+ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+ assert total_num_scheduled_tokens > 0
+ num_reqs = self.input_batch.num_reqs
+ assert num_reqs > 0
+
+ # OPTIMIZATION: Start copying the block table first.
+ # This way, we can overlap the copy with the following CPU operations.
+ self.input_batch.block_table.commit(num_reqs)
+
+ # Get the number of scheduled tokens for each request.
+ req_ids = self.input_batch.req_ids
+ tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
+ num_scheduled_tokens = np.array(tokens, dtype=np.int32)
+ max_num_scheduled_tokens = max(tokens)
+
+ # Get request indices.
+ # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
+ req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
+
+ # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
+ # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+ cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
+
+ # Get positions.
+ positions_np = self.positions_np[:total_num_scheduled_tokens]
+ np.add(
+ self.input_batch.num_computed_tokens_cpu[req_indices],
+ arange,
+ out=positions_np,
+ )
+
+ # Calculate M-RoPE positions.
+ # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
+ if self.uses_mrope:
+ self._calc_mrope_positions(scheduler_output)
+
+ self.seq_lens_np[:num_reqs] = (
+ self.input_batch.num_computed_tokens_cpu[:num_reqs]
+ + num_scheduled_tokens
+ )
+
+ # TODO: improve performance, no `positions_np.copy()`
+ sparsed_positions = positions_np.copy()
+ req_sparsed_slots = scheduler_output.req_sparsed_slots
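+            # For sparsed requests, remap the first scheduled token's position
+            # to the last sparsed slot so the slot mapping below indexes the
+            # compacted KV layout.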
+ for req_id in self.input_batch.req_id_to_index:
+ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT
+ req_index = self.input_batch.req_id_to_index[req_id]
+ offset = (
+ 0 if req_index == 0 else cu_num_tokens[req_index - 1]
+ ) # TODO: support MTP
+ if is_sparsed_request:
+ sparsed_positions[offset] = req_sparsed_slots[req_id] - 1
+
+ # Get token indices.
+ # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+ # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
+ # where M is the max_model_len.
+ token_indices = (
+ positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]
+ )
+
+ # NOTE(woosuk): We use torch.index_select instead of np.take here
+ # because torch.index_select is much faster than np.take for large
+ # tensors.
+ torch.index_select(
+ self.input_batch.token_ids_cpu_tensor.flatten(),
+ 0,
+ torch.from_numpy(token_indices),
+ out=self.input_ids_cpu[:total_num_scheduled_tokens],
+ )
+
+ # Calculate the slot mapping for each KV cache group.
+ for kv_cache_group_id, kv_cache_group_spec in enumerate(
+ self.kv_cache_config.kv_cache_groups
+ ):
+ block_size = kv_cache_group_spec.kv_cache_spec.block_size
+ block_table: BlockTable = self.input_batch.block_table[
+ kv_cache_group_id
+ ]
+ # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+ # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
+ # where K is the max_num_blocks_per_req and the block size is 2.
+ # NOTE(woosuk): We can't simply use `token_indices // block_size`
+ # here because M (max_model_len) is not necessarily divisible by
+ # block_size.
+ block_table_indices = (
+ req_indices * block_table.max_num_blocks_per_req
+ + sparsed_positions // block_size
+ )
+ block_table_cpu = block_table.get_cpu_tensor()
+ block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
+ block_offsets = sparsed_positions % block_size
+ np.add(
+ block_numbers * block_size,
+ block_offsets,
+ out=block_table.slot_mapping_np[:total_num_scheduled_tokens],
+ )
+
+ # Prepare the attention metadata.
+ self.query_start_loc_np[0] = 0
+ self.query_start_loc_np[1 : num_reqs + 1] = cu_num_tokens
+
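+            # For sparsed requests, override the sequence length with the
+            # number of sparsed slots so attention only spans the retained KV.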
+ for req_id in self.input_batch.req_id_to_index:
+ req_index = self.input_batch.req_id_to_index[req_id]
+ is_sparsed_request = (
+ scheduler_output.req_sparsed_slots[req_id] != INVALID_SLOT
+ )
+ if is_sparsed_request:
+ self.seq_lens_np[req_index] = scheduler_output.req_sparsed_slots[
+ req_id
+ ]
+
+ # Copy the tensors to the GPU.
+ self.input_ids[:total_num_scheduled_tokens].copy_(
+ self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True
+ )
+ if self.uses_mrope:
+ # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
+ self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
+ self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
+ non_blocking=True,
+ )
+ else:
+ # Common case (1D positions)
+ self.positions_cpu[:total_num_scheduled_tokens] = torch.from_numpy(
+ positions_np[:total_num_scheduled_tokens]
+ )
+ self.positions[:total_num_scheduled_tokens].copy_(
+ self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True
+ )
+
+ self.query_start_loc[: num_reqs + 1].copy_(
+ self.query_start_loc_cpu[: num_reqs + 1], non_blocking=True
+ )
+ self.seq_lens[:num_reqs].copy_(
+ self.seq_lens_cpu[:num_reqs], non_blocking=True
+ )
+
+            # Fill unused entries with 0.
+            self.seq_lens[num_reqs:].fill_(0)
+            # Note: pad query_start_loc to be non-decreasing, as kernels
+            # like FlashAttention require that.
+ self.query_start_loc[num_reqs + 1 :].fill_(
+ self.query_start_loc_cpu[num_reqs].item()
+ )
+
+ query_start_loc = self.query_start_loc[: num_reqs + 1]
+ seq_lens = self.seq_lens[:num_reqs]
+
+ common_attn_metadata = CommonAttentionMetadata(
+ query_start_loc=query_start_loc,
+ seq_lens=seq_lens,
+ num_reqs=num_reqs,
+ num_actual_tokens=total_num_scheduled_tokens,
+ max_query_len=max_num_scheduled_tokens,
+ )
+
+ attn_metadata: dict[str, Any] = {}
+ # Prepare the attention metadata for each KV cache group and make layers
+ # in the same group share the same metadata.
+ for kv_cache_group_id, kv_cache_group_spec in enumerate(
+ self.kv_cache_config.kv_cache_groups
+ ):
+
+ # Prepare for cascade attention if enabled & beneficial.
+ common_prefix_len = 0
+ builder = self.attn_metadata_builders[kv_cache_group_id]
+ if self.cascade_attn_enabled:
+ common_prefix_len = self._compute_cascade_attn_prefix_len(
+ num_scheduled_tokens,
+ scheduler_output.num_common_prefix_blocks[kv_cache_group_id],
+ kv_cache_group_spec.kv_cache_spec,
+ builder,
+ )
+
+ attn_metadata_i = builder.build(
+ common_prefix_len=common_prefix_len,
+ common_attn_metadata=common_attn_metadata,
+ )
+
+ for layer_name in kv_cache_group_spec.layer_names:
+ attn_metadata[layer_name] = attn_metadata_i
+
+ attention_cuda_graphs = all(
+ b.can_run_in_cudagraph(common_attn_metadata)
+ for b in self.attn_metadata_builders
+ )
+
+ use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
+ if not use_spec_decode:
+ # NOTE(woosuk): Due to chunked prefills, the batch may contain
+ # partial requests. While we should not sample any token
+ # from these partial requests, we do so for simplicity.
+ # We will ignore the sampled tokens from the partial requests.
+ # TODO: Support prompt logprobs.
+ logits_indices = query_start_loc[1:] - 1
+ spec_decode_metadata = None
+ else:
+ # Get the number of draft tokens for each request.
+ # Iterate over the dictionary rather than all requests since not all
+ # requests have draft tokens.
+ num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
+ for (
+ req_id,
+ draft_token_ids,
+ ) in scheduler_output.scheduled_spec_decode_tokens.items():
+ req_idx = self.input_batch.req_id_to_index[req_id]
+ num_draft_tokens[req_idx] = len(draft_token_ids)
+
+ spec_decode_metadata = self._calc_spec_decode_metadata(
+ num_draft_tokens, cu_num_tokens
+ )
+ logits_indices = spec_decode_metadata.logits_indices
+
+ # Hot-Swap lora model
+ if self.lora_config:
+ self.set_active_loras(self.input_batch, num_scheduled_tokens)
+
+ return (
+ attn_metadata,
+ attention_cuda_graphs,
+ logits_indices,
+ spec_decode_metadata,
+ num_scheduled_tokens,
+ )
+
+ GPUModelRunner._prepare_inputs = _prepare_inputs
+
+ @torch.inference_mode()
+ def execute_model(
+ self,
+ scheduler_output: "SchedulerOutput",
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ ) -> Union[ModelRunnerOutput, IntermediateTensors]:
+ self._update_states(scheduler_output)
+ if not scheduler_output.total_num_scheduled_tokens:
+ if not has_kv_transfer_group():
+ # Return empty ModelRunnerOutput if there's no work to do.
+ return EMPTY_MODEL_RUNNER_OUTPUT
+
+ return self.kv_connector_no_forward(scheduler_output)
+
+ # Prepare the decoder inputs.
+ (
+ attn_metadata,
+ attention_cuda_graphs,
+ logits_indices,
+ spec_decode_metadata,
+ num_scheduled_tokens_np,
+ ) = self._prepare_inputs(scheduler_output)
+ num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+ if (
+ self.use_cuda_graph
+ and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]
+ ):
+ # Use piecewise CUDA graphs.
+ # Add padding to the batch size.
+ num_input_tokens = self.vllm_config.pad_for_cudagraph(
+ num_scheduled_tokens
+ )
+ else:
+ # Eager mode.
+ # Pad tokens to multiple of tensor_parallel_size when
+ # enabled collective fusion for SP
+ tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+ if (
+ self.compilation_config.pass_config.enable_sequence_parallelism
+ and tp_size > 1
+ ):
+ num_input_tokens = round_up(num_scheduled_tokens, tp_size)
+ else:
+ num_input_tokens = num_scheduled_tokens
+
+ # Padding for DP
+ num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
+ num_input_tokens += num_pad
+
+ # _prepare_inputs may reorder the batch, so we must gather multi
+ # modal outputs after that to ensure the correct order
+ if self.is_multimodal_model:
+ # Run the multimodal encoder if any.
+ self._execute_mm_encoder(scheduler_output)
+ mm_embeds = self._gather_mm_embeddings(scheduler_output)
+ else:
+ mm_embeds = []
+
+ if self.is_multimodal_model and get_pp_group().is_first_rank:
+ # NOTE(woosuk): To unify token ids and soft tokens (vision
+ # embeddings), we always use embeddings (rather than token ids)
+ # as input to the multimodal model, even when the input is text.
+ input_ids = self.input_ids[:num_scheduled_tokens]
+ if mm_embeds:
+ inputs_embeds = self.model.get_input_embeddings(
+ input_ids, mm_embeds
+ )
+ else:
+ inputs_embeds = self.model.get_input_embeddings(input_ids)
+ # TODO(woosuk): Avoid the copy. Optimize.
+ self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
+ inputs_embeds = self.inputs_embeds[:num_input_tokens]
+ input_ids = None
+ else:
+ # For text-only models, we use token ids as input.
+ # While it is possible to use embeddings as input just like the
+ # multimodal models, it is not desirable for performance since
+ # then the embedding layer is not included in the CUDA graph.
+ input_ids = self.input_ids[:num_input_tokens]
+ inputs_embeds = None
+ if self.uses_mrope:
+ positions = self.mrope_positions[:, :num_input_tokens]
+ else:
+ positions = self.positions[:num_input_tokens]
+
+ if get_pp_group().is_first_rank:
+ intermediate_tensors = None
+ else:
+ intermediate_tensors = self.sync_and_slice_intermediate_tensors(
+ num_input_tokens, intermediate_tensors, True
+ )
+
+ # Some attention backends only support CUDA Graphs in pure decode.
+ # If attention doesn't support CUDA Graphs for this batch, but we
+ # compiled with full CUDA graphs, we have to skip them entirely.
+ skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
+
+ # Run the model.
+ # Use persistent buffers for CUDA graphs.
+ with set_forward_context(
+ attn_metadata,
+ self.vllm_config,
+ num_tokens=num_input_tokens,
+ num_tokens_across_dp=num_tokens_across_dp,
+ skip_cuda_graphs=skip_cuda_graphs,
+ ):
+ self.maybe_setup_kv_connector(scheduler_output)
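+                # UCM sparse: build the per-step sparse metadata and run the
+                # pre-forward hook before executing the model.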
+ self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata)
+
+ model_output = self.model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
+
+ self.maybe_wait_for_kv_save()
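+                # UCM sparse: run the post-forward hook once KV saves complete.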
+ self.maybe_execute_ucm_sparse_finished()
+
+ finished_sending, finished_recving = self.get_finished_kv_transfers(
+ scheduler_output
+ )
+
+ if self.use_aux_hidden_state_outputs:
+ hidden_states, aux_hidden_states = model_output
+ else:
+ hidden_states = model_output
+ aux_hidden_states = None
+
+ # Broadcast PP output for external_launcher (torchrun)
+ # to make sure we are synced across pp ranks
+ # TODO: Support overlapping mirco-batches
+ # https://github.com/vllm-project/vllm/issues/18019
+ broadcast_pp_output = (
+ self.parallel_config.distributed_executor_backend == "external_launcher"
+ and len(get_pp_group().ranks) > 0
+ )
+ if not get_pp_group().is_last_rank:
+ # For mid-pipeline stages, return the hidden states.
+ if not broadcast_pp_output:
+ return hidden_states
+ assert isinstance(hidden_states, IntermediateTensors)
+ get_pp_group().send_tensor_dict(
+ hidden_states.tensors, all_gather_group=get_tp_group()
+ )
+ logits = None
+ else:
+ if self.input_batch.pooling_params:
+ return self._pool(
+ hidden_states,
+ num_scheduled_tokens,
+ num_scheduled_tokens_np,
+ finished_sending,
+ finished_recving,
+ )
+
+ sample_hidden_states = hidden_states[logits_indices]
+ logits = self.model.compute_logits(sample_hidden_states, None)
+ if broadcast_pp_output:
+ model_output_broadcast_data = (
+ {
+ "logits": logits.contiguous(),
+ }
+ if logits is not None
+ else {}
+ )
+ model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
+ model_output_broadcast_data, src=len(get_pp_group().ranks) - 1
+ )
+ assert model_output_broadcast_data is not None
+ logits = model_output_broadcast_data["logits"]
+
+ # Apply structured output bitmasks if present
+ if scheduler_output.grammar_bitmask is not None:
+ self.apply_grammar_bitmask(scheduler_output, logits)
+
+ # Sample the next token and get logprobs if needed.
+ sampling_metadata = self.input_batch.sampling_metadata
+ if spec_decode_metadata is None:
+ sampler_output = self.sampler(
+ logits=logits,
+ sampling_metadata=sampling_metadata,
+ )
+ else:
+ # When indexing with a tensor (bonus_logits_indices), PyTorch
+ # creates a new tensor with separate storage from the original
+ # logits tensor. This means any in-place operations on bonus_logits
+ # won't affect the original logits tensor.
+ assert logits is not None
+ bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
+ sampler_output = self.sampler(
+ logits=bonus_logits,
+ sampling_metadata=sampling_metadata,
+ )
+ bonus_token_ids = sampler_output.sampled_token_ids
+
+ # Just like `bonus_logits`, `target_logits` is a new tensor with
+ # separate storage from the original `logits` tensor. Therefore,
+ # it is safe to update `target_logits` in place.
+ target_logits = logits[spec_decode_metadata.target_logits_indices]
+ output_token_ids = self.rejection_sampler(
+ spec_decode_metadata,
+ None, # draft_probs
+ target_logits,
+ bonus_token_ids,
+ sampling_metadata,
+ )
+ sampler_output.sampled_token_ids = output_token_ids
+
+ num_nans_in_logits = {}
+ if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
+ num_nans_in_logits = self._get_nans_in_logits(logits)
+
+ # TODO(woosuk): The following loop can be slow since it iterates over
+ # the requests one by one. Optimize.
+ discard_sampled_tokens_req_indices = []
+ for i, req_id in enumerate(self.input_batch.req_ids):
+ req_state = self.requests[req_id]
+ seq_len = (
+ req_state.num_computed_tokens
+ + scheduler_output.num_scheduled_tokens[req_id]
+ )
+ if seq_len < req_state.num_tokens:
+ # Ignore the sampled token for partial prefills.
+ # Rewind the generator state as if the token was not sampled.
+ # This relies on cuda-specific torch-internal impl details
+ generator = self.input_batch.generators.get(i)
+ if generator is not None:
+ generator.set_offset(generator.get_offset() - 4)
+ # Record the index of the request that should not be sampled,
+ # so that we could clear the sampled tokens before returning.
+ discard_sampled_tokens_req_indices.append(i)
+
+ # NOTE: GPU -> CPU Sync happens here.
+ # Move as many CPU operations as possible before this sync point.
+ logprobs_tensors = sampler_output.logprobs_tensors
+ logprobs_lists = (
+ logprobs_tensors.tolists() if logprobs_tensors is not None else None
+ )
+
+ # Compute prompt logprobs if needed.
+ prompt_logprobs_dict = self._get_prompt_logprobs_dict(
+ hidden_states[:num_scheduled_tokens],
+ scheduler_output,
+ )
+
+ # Get the valid generated tokens.
+ sampled_token_ids = sampler_output.sampled_token_ids
+ max_gen_len = sampled_token_ids.shape[-1]
+ if max_gen_len == 1:
+ # No spec decode tokens.
+ valid_sampled_token_ids = sampled_token_ids.tolist()
+ else:
+ # Includes spec decode tokens.
+ valid_sampled_token_ids = self.rejection_sampler.parse_output(
+ sampled_token_ids,
+ self.input_batch.vocab_size,
+ )
+ # Mask out the sampled tokens that should not be sampled.
+ for i in discard_sampled_tokens_req_indices:
+ valid_sampled_token_ids[i].clear()
+
+ # Cache the sampled tokens in the model runner, so that the scheduler
+ # doesn't need to send them back.
+ # NOTE(woosuk): As an exception, when using PP, the scheduler sends
+ # the sampled tokens back, because there's no direct communication
+ # between the first-stage worker and the last-stage worker.
+ for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
+ if not sampled_ids:
+ continue
+
+ start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+ end_idx = start_idx + len(sampled_ids)
+ assert end_idx <= self.max_model_len, (
+ "Sampled token IDs exceed the max model length. "
+ f"Total number of tokens: {end_idx} > max_model_len: "
+ f"{self.max_model_len}"
+ )
+
+ self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
+ self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+ self.input_batch.num_tokens[req_idx] = end_idx
+ req_id = self.input_batch.req_ids[req_idx]
+ req_state = self.requests[req_id]
+ req_state.output_token_ids.extend(sampled_ids)
+
+ if not self.speculative_config:
+ # Speculative decoding is not enabled.
+ spec_token_ids = None
+ else:
+ spec_token_ids = self.propose_draft_token_ids(
+ scheduler_output,
+ valid_sampled_token_ids,
+ sampling_metadata,
+ hidden_states,
+ sample_hidden_states,
+ aux_hidden_states,
+ spec_decode_metadata,
+ attn_metadata,
+ )
+
+ # Clear KVConnector state after all KVs are generated.
+ if has_kv_transfer_group():
+ get_kv_transfer_group().clear_connector_metadata()
+
+ self.eplb_step()
+
+ return ModelRunnerOutput(
+ req_ids=self.input_batch.req_ids,
+ req_id_to_index=self.input_batch.req_id_to_index,
+ sampled_token_ids=valid_sampled_token_ids,
+ spec_token_ids=spec_token_ids,
+ logprobs=logprobs_lists,
+ prompt_logprobs_dict=prompt_logprobs_dict,
+ pooler_output=[],
+ finished_sending=finished_sending,
+ finished_recving=finished_recving,
+ num_nans_in_logits=num_nans_in_logits,
+ )
+
+ GPUModelRunner.execute_model = execute_model
+
+ except ImportError:
+        logger.warning("Could not patch GPU model runner - module not found")
+
+
+# ==================== vllm/v1/worker/gpu_worker.py ====================
+def _patch_gpu_worker() -> None:
+ """Patch gpu worker to add UCM sparse support."""
+ try:
+ from typing import Optional
+
+ from vllm.config import VllmConfig
+ from vllm.v1.worker import gpu_worker
+
+ from ucm.sparse.state import ensure_ucm_sparse_initialized
+
+ original_init_worker_distributed_environment = (
+ gpu_worker.init_worker_distributed_environment
+ )
+
+ def patched_init_worker_distributed_environment(
+ vllm_config: VllmConfig,
+ rank: int,
+ distributed_init_method: Optional[str] = None,
+ local_rank: int = -1,
+ backend: str = "nccl",
+ ) -> None:
+ original_init_worker_distributed_environment(
+ vllm_config, rank, distributed_init_method, local_rank, backend
+ )
+ ensure_ucm_sparse_initialized(vllm_config)
+
+ gpu_worker.init_worker_distributed_environment = (
+ patched_init_worker_distributed_environment
+ )
+ except ImportError:
+ logger.warning("Could not patch gpu worker - module not found")
diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py
index 421011ca2..c8317007b 100644
--- a/ucm/integration/vllm/uc_connector.py
+++ b/ucm/integration/vllm/uc_connector.py
@@ -25,9 +25,10 @@
#
import hashlib
import pickle
+from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
-from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Union
import torch
from vllm.config import VllmConfig
@@ -43,6 +44,7 @@
from ucm.logger import init_logger
from ucm.store.factory import UcmConnectorFactory
from ucm.store.ucmstore import Task
+from ucm.utils import Config
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -89,7 +91,7 @@ class UnifiedCacheConnectorV1(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self.block_size = vllm_config.cache_config.block_size
- self.use_layerwise = True
+ self.use_layerwise = False
self.kv_caches: dict[str, torch.Tensor] = {}
self.total_tp_size = vllm_config.parallel_config.tensor_parallel_size
self.rank = (
@@ -98,36 +100,25 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
self.request_block_infos: dict[str, RequestBlockInfo] = {}
# dump tasks record request -> block -> list[task]
self.dump_tasks: dict[str, dict[str, List[Task]]] = {}
- self.layerwise_load_tasks: dict[str, dict[str, tuple[Task, Task]]] = {}
+ self.layerwise_load_tasks: dict[str, dict[str, Task]] = defaultdict(dict)
self.is_mla = self._vllm_config.model_config.is_deepseek_mla
self.num_layers = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config
)
self.element_size = vllm_config.model_config.dtype.itemsize
self.kv_role = vllm_config.kv_transfer_config.kv_role
- self._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {}
+ self._need_load_reqs: dict[str, Union[list[int], Task]] = {}
self._load_failed_reqs: set[str] = set()
self._load_req_to_blocks: dict[str, set[int]] = {}
self.num_head = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config
)
self.head_size = vllm_config.model_config.get_head_size()
- if (
- self._vllm_config.kv_transfer_config is not None
- and "ucm_connector_name"
- in self._vllm_config.kv_transfer_config.kv_connector_extra_config
- ):
- name = self._vllm_config.kv_transfer_config.kv_connector_extra_config[
- "ucm_connector_name"
- ]
- config = {}
- if (
- "ucm_connector_config"
- in self._vllm_config.kv_transfer_config.kv_connector_extra_config
- ):
- config = self._vllm_config.kv_transfer_config.kv_connector_extra_config[
- "ucm_connector_config"
- ]
+ ucm_config = Config(vllm_config.kv_transfer_config)
+ launch_config = ucm_config.get_config()
+ if "ucm_connector_name" in launch_config:
+ name = launch_config.get("ucm_connector_name")
+ config = launch_config.get("ucm_connector_config") or {}
config["device"] = self.rank
config["role"] = (
"scheduler" if role == KVConnectorRole.SCHEDULER else "worker"
@@ -187,9 +178,9 @@ def DataOffset(self, kv_layer, rank, layer_id, is_v):
kv_layer[1][0].numel() if not self.is_mla else 0
) * elem_size
# When tp > 1 layer_size = (k_min_data_block_size + v_min_data_block_size) * tp_size
- layer_size = (
- k_min_data_block_size + v_min_data_block_size
- ) * self.total_tp_size
+ layer_size = (k_min_data_block_size + v_min_data_block_size) * (
+ self.total_tp_size if not self.is_mla else 1
+ )
if is_v:
# Offset of v = Offset of k + k_min_data_block_size
return int(
@@ -258,65 +249,49 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
if len(self.kv_caches) == 0:
self._init_kv_caches_from_forward_context(forward_context)
+ if len(list(self.kv_caches.values())[0]) == 2:
+ self.is_mla = False
self.layerwise_load_tasks.clear()
self.current_layer = 0
+ need_load_tasks: dict[str, Task] = {}
for request in metadata.requests:
if not request.load_blocks:
continue
storage_block_ids = [block[0] for block in request.load_blocks]
vllm_block_ids = [block[1] for block in request.load_blocks]
- blocks_len = len(storage_block_ids)
self._load_req_to_blocks.setdefault(request.request_id, set()).update(
vllm_block_ids
)
+ is_load_async = request.load_async
+ total_offsets = []
+ total_tensors = []
+ storage_block_ids = storage_block_ids * (1 if self.is_mla else 2)
for layer_name, kv_layer in self.kv_caches.items():
tensors, offsets = self.get_tensor_and_offset_layerwise(
vllm_block_ids, kv_layer, layer_name
)
- k_task_id = self.connector.load(
- storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
- )
- v_task_id = None
- if not self.is_mla:
- v_task_id = self.connector.load(
- storage_block_ids,
- offsets[blocks_len:],
- tensors[blocks_len:],
- )
- if request.request_id not in self.layerwise_load_tasks:
- self.layerwise_load_tasks[request.request_id] = {}
- self.layerwise_load_tasks[request.request_id][layer_name] = (
- k_task_id,
- v_task_id,
+ if self.use_layerwise and not is_load_async:
+ task_id = self.connector.load(storage_block_ids, offsets, tensors)
+ self.layerwise_load_tasks[request.request_id][layer_name] = task_id
+ continue
+ else:
+ total_offsets.extend(offsets)
+ total_tensors.extend(tensors)
+ if total_offsets and total_tensors:
+ storage_block_ids = storage_block_ids * self.num_layers
+ task_id = self.connector.load(
+ storage_block_ids, total_offsets, total_tensors
)
-
- if request.load_async and request.request_id in self.layerwise_load_tasks:
- for _, (k_task, v_task) in self.layerwise_load_tasks[
- request.request_id
- ].items():
- if request.request_id not in self._need_load_reqs:
- self._need_load_reqs[request.request_id] = []
- self._need_load_reqs[request.request_id].append(k_task)
- if not self.is_mla:
- self._need_load_reqs[request.request_id].append(v_task)
- self.layerwise_load_tasks.pop(request.request_id)
- continue
-
- if (
- not self.use_layerwise
- and request.request_id in self.layerwise_load_tasks
- ):
- for _, (k_task, v_task) in self.layerwise_load_tasks[
- request.request_id
- ].items():
- if self.connector.wait(k_task) != 0:
- self._load_failed_reqs.add(request.request_id)
- break
- if v_task and self.connector.wait(v_task) != 0:
- self._load_failed_reqs.add(request.request_id)
- break
+ if is_load_async:
+ self._need_load_reqs[request.request_id] = task_id
+ else:
+ need_load_tasks[request.request_id] = task_id
+ for req_id, task_id in need_load_tasks.items():
+ if self.connector.wait(task_id) != 0:
+ self._load_failed_reqs.add(req_id)
+ logger.error(f"Failed to load blocks for req {req_id}")
def wait_for_layer_load(self, layer_name: str) -> None:
"""
@@ -334,26 +309,19 @@ def wait_for_layer_load(self, layer_name: str) -> None:
if self.layerwise_load_tasks:
logger.debug(f"Waiting for layer {self.current_layer} to be loaded")
- assert (
- self.current_layer < self.num_layers
- ), "The current layer should be less than total layers!"
+ if self.current_layer >= self.num_layers:
+ return
+
for request_id, layer_to_task in self.layerwise_load_tasks.items():
if request_id in self._load_failed_reqs:
continue
- k_task, v_task = layer_to_task[layer_name]
- if self.connector.wait(k_task) != 0:
+ task = layer_to_task[layer_name]
+ if self.connector.wait(task) != 0:
self._load_failed_reqs.add(request_id)
logger.error(
f"Failed to load block for request {request_id} on layer {layer_name}"
)
continue
- if not self.is_mla:
- if self.connector.wait(v_task) != 0:
- self._load_failed_reqs.add(request_id)
- logger.error(
- f"Failed to load block for request {request_id} on layer {layer_name}"
- )
- continue
logger.debug(f"Load tasks for {request_id} on layer {layer_name} finished.")
def save_kv_layer(
@@ -384,6 +352,9 @@ def save_kv_layer(
if not self.use_layerwise:
return
+ if self.current_layer > self.num_layers:
+ return
+
metadata = self._get_connector_metadata()
assert isinstance(metadata, UCConnectorV1Metadata)
@@ -407,6 +378,8 @@ def save_kv_layer(
torch.npu.current_stream().synchronize()
elif kv_layer[0].device.type == "cuda":
torch.cuda.current_stream().synchronize()
+ elif kv_layer[0].device.type == "musa":
+ torch.musa.current_stream().synchronize()
for block_id, offset, tensor in zip(
storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
@@ -434,6 +407,8 @@ def wait_for_save(self) -> Optional[dict[str, list[str]]]:
"""
if hasattr(self, "kv_role") and self.kv_role == "kv_consumer":
return
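+        # Only rank 0 dumps when the model uses MLA, since the MLA KV cache is
+        # not sharded across TP ranks and is identical on every rank.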
+ if self.is_mla and self.rank != 0:
+ return
# request id -> succeed dumped blocks
success_dumped_blocks: dict[str, list[str]] = {}
@@ -452,57 +427,57 @@ def wait_for_tasks():
self.dump_tasks.clear()
return success_dumped_blocks if success_dumped_blocks else None
+ req_to_dump_blocks: dict[str, list[str]] = {}
+ need_dump_tasks: dict[str, Task] = {}
for request in metadata.requests:
if not request.dump_blocks:
continue
storage_block_ids = [block[0] for block in request.dump_blocks]
vllm_block_ids = [block[1] for block in request.dump_blocks]
- blocks_len = len(storage_block_ids)
+ req_to_dump_blocks[request.request_id] = storage_block_ids
+ total_offsets = []
+ total_tensors = []
+ total_block_ids = (
+ storage_block_ids * (1 if self.is_mla else 2) * self.num_layers
+ )
for layer_name, kv_layer in self.kv_caches.items():
tensors, offsets = self.get_tensor_and_offset_layerwise(
vllm_block_ids, kv_layer, layer_name
)
- for block_id, offset, tensor in zip(
- storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
- ):
- task = self.connector.dump([block_id], [offset], [tensor])
- self.dump_tasks.setdefault(request.request_id, {}).setdefault(
- block_id, []
- ).append(task)
- if not self.is_mla:
- for block_id, offset, tensor in zip(
- storage_block_ids,
- offsets[blocks_len:],
- tensors[blocks_len:],
- ):
- task = self.connector.dump([block_id], [offset], [tensor])
- self.dump_tasks.setdefault(request.request_id, {}).setdefault(
- block_id, []
- ).append(task)
- wait_for_tasks()
- self.dump_tasks.clear()
+ total_offsets.extend(offsets)
+ total_tensors.extend(tensors)
+ task_id = self.connector.dump(total_block_ids, total_offsets, total_tensors)
+ need_dump_tasks[request.request_id] = task_id
+
+ for req_id, task_id in need_dump_tasks.items():
+ if self.connector.wait(task_id) != 0:
+                logger.error(f"Failed to dump blocks for req {req_id}")
+ else:
+ success_dumped_blocks[req_id] = req_to_dump_blocks[req_id]
return success_dumped_blocks if success_dumped_blocks else None
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""Get the finished recving and sending requests."""
done_recving: set[str] = set()
- for req_id, tasks in self._need_load_reqs.items():
+ for req_id, task in self._need_load_reqs.items():
if req_id in self._load_failed_reqs:
+ done_recving.add(req_id)
continue
- unfinished_tasks = []
- for task in tasks:
- ret = self.connector.check(task)
- if ret == -1:
- unfinished_tasks.append(task)
- continue
- elif ret == 0 and self.connector.wait(task) == 0:
- continue
+ ret, finish = self.connector.check(task)
+ if ret != 0:
+ logger.error(
+ f"Task {task} failed, check return {ret} for request {req_id}"
+ )
self._load_failed_reqs.add(req_id)
- break
- if not unfinished_tasks:
- done_recving.add(req_id)
- self._need_load_reqs[req_id] = unfinished_tasks
+ elif not finish:
+ continue
+ elif (wret := self.connector.wait(task)) != 0:
+ logger.error(
+ f"Task {task} failed, wait return {wret} for request {req_id}"
+ )
+ self._load_failed_reqs.add(req_id)
+ done_recving.add(req_id)
# remove the finished requests
for req_id in list(done_recving):
@@ -596,9 +571,9 @@ def hash_request_tokens(
# TODO we will fix hole match later
break
logger.info(
- f"\nnum_total_blocks: {len(block_hashes)}\n"
- f"\nnum_lookup_hits on hbm: {start_position}\n"
- f"\nnum_lookup_hits on storage except hbm: {num_lookup_hits}\n"
+ f"num_total_blocks: {len(block_hashes)}, "
+ f"num_lookup_hits on hbm: {start_position}, "
+ f"num_lookup_hits on storage except hbm: {num_lookup_hits}"
)
# Load async when Decode instance need to load
@@ -776,7 +751,8 @@ def request_finished(
if cancel_blocks:
logger.debug(f"commit {cancel_blocks} to False.")
self.connector.commit(cancel_blocks, False)
- request.succeed_dumped_blocks.clear()
+ if hasattr(request, "succeed_dumped_blocks"):
+ request.succeed_dumped_blocks.clear()
return False, None
def _extract_blocks(
diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py
new file mode 100644
index 000000000..e843a3e7e
--- /dev/null
+++ b/ucm/integration/vllm/ucm_connector.py
@@ -0,0 +1,868 @@
+import hashlib
+import itertools
+import os
+import pickle
+import time
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Callable, List, Optional
+
+import torch
+from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+ KVConnectorBase_V1,
+ KVConnectorMetadata,
+ KVConnectorRole,
+)
+from vllm.distributed.parallel_state import get_tp_group, get_world_group
+from vllm.platforms import current_platform
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.request import Request
+
+from ucm.logger import init_logger
+from ucm.shared.metrics import ucmmonitor
+from ucm.shared.metrics.observability import UCMStatsLogger
+from ucm.store.factory import UcmConnectorFactory
+from ucm.store.ucmstore import Task, UcmKVStoreBase
+from ucm.utils import Config
+
+if TYPE_CHECKING:
+ from vllm.attention.backends.abstract import AttentionMetadata
+ from vllm.forward_context import ForwardContext
+ from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class RequestMeta:
+ ucm_block_ids: list[str] = field(default_factory=list)
+ hbm_hit_block_num: int = 0
+ # local_computed_block + external_computed_block
+ total_hit_block_num: int = 0
+ num_token_ids: int = 0
+ vllm_block_ids: list[int] = field(default_factory=list)
+ token_processed: int = 0
+
+
+@dataclass
+class RequestDispatchMeta:
+ load_block_ids: tuple[
+ list[str], list[int]
+    ] # [0] means ucm_block_ids, [1] means vllm_block_ids
+ dump_block_ids: tuple[list[str], list[int]]
+
+
+@dataclass
+class UCMConnectorMetadata(KVConnectorMetadata):
+ request_meta: dict[str, RequestDispatchMeta] = field(default_factory=dict)
+
+
+class RequestHasher:
+    """Hash (MD5) request tokens to generate UCM block IDs."""
+
+ _SEED_HASH = None
+
+ def __init__(self, vllm_config, rank_id):
+ meta = f"{vllm_config.model_config.model}:{vllm_config.parallel_config.world_size}:{vllm_config.model_config.dtype}:{rank_id}"
+ self.meta_bytes = meta.encode("utf-8")
+
+ if RequestHasher._SEED_HASH is None:
+ RequestHasher._SEED_HASH = self("UCM_HASH_SEED")
+
+ def __call__(self, input_data) -> int:
+ if isinstance(input_data, str):
+ input_bytes = input_data.encode("utf-8")
+ else:
+ input_bytes = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL)
+
+ h = hashlib.md5(self.meta_bytes + input_bytes)
+ return int.from_bytes(h.digest(), byteorder="big")
+
+
+class UCMDirectConnector(KVConnectorBase_V1):
+ """
+    This connector operates synchronously:
+        load -> forward -> save
+ """
+
+ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
+ super().__init__(vllm_config=vllm_config, role=role)
+ self.kv_caches: dict[str, torch.Tensor] = {}
+ self.local_rank = (
+ -1 if role == KVConnectorRole.SCHEDULER else get_world_group().local_rank
+ )
+ self.global_rank = self._vllm_config.parallel_config.rank
+ self.block_size = self._vllm_config.cache_config.block_size
+ self.is_mla = self._vllm_config.model_config.is_deepseek_mla
+ self.is_dsa = False
+        self.kv_cache_dtype: Optional[torch.dtype] = None
+
+ if current_platform.is_cuda_alike():
+ logger.info("CUDA device is available.")
+ torch_dev = torch
+ dev_name = "cuda"
+ elif current_platform.device_type == "npu":
+ logger.info("NPU device is available.")
+ torch_dev = torch.npu
+ dev_name = "npu"
+ else:
+ raise RuntimeError("Unsupported device platform for UCMDirectConnector.")
+
+ if self.local_rank >= 0:
+ self.device = torch_dev.device(f"{dev_name}:{self.local_rank}")
+ self._layer_offset_cache = {}
+
+ self.store: UcmKVStoreBase
+
+ if role == KVConnectorRole.SCHEDULER:
+ self.request_hasher = RequestHasher(vllm_config, 0)
+ else:
+ self.request_hasher = RequestHasher(vllm_config, self.global_rank)
+
+        # Cache per-request block info so a request is hashed only once, and track it until the request finishes
+ self.requests_meta: dict[str, RequestMeta] = {}
+
+ ucm_config = Config(vllm_config.kv_transfer_config)
+ self.launch_config = ucm_config.get_config()
+
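+        # With MLA the KV cache is the same on every TP rank, so optionally
+        # only the first rank loads from storage and broadcasts to the others.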
+ self.load_only_first_rank: bool = (
+ self.launch_config.get("load_only_first_rank", self.is_mla) and self.is_mla
+ )
+ if self.load_only_first_rank:
+ if role == KVConnectorRole.WORKER:
+ self.group_coordinator = get_tp_group()
+ self.broadcast_fn = self.group_coordinator.broadcast
+ self.broadcast_stream = torch.cuda.Stream()
+
+ logger.info(f"self.launch_config: {self.launch_config}")
+ connector_configs = self.launch_config.get("ucm_connectors", [])
+ assert len(connector_configs) > 0, "no storage connector name in config."
+
+ name = connector_configs[0].get("ucm_connector_name")
+ config = connector_configs[0].get("ucm_connector_config") or {}
+ config["device"] = self.local_rank
+ config["role"] = "scheduler" if role == KVConnectorRole.SCHEDULER else "worker"
+ element_size = vllm_config.model_config.dtype.itemsize
+ single_head_dim = vllm_config.model_config.get_head_size()
+ num_head_per_tp = vllm_config.model_config.get_num_kv_heads(
+ vllm_config.parallel_config
+ )
+ total_tp_size = vllm_config.parallel_config.tensor_parallel_size
+ num_layers = vllm_config.model_config.get_num_layers(
+ vllm_config.parallel_config
+ )
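+        # kv_block_size: bytes of one full KV block (all layers, K and V) on
+        # this rank; io_size: bytes of a single layer's K (or V) slab per block.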
+ block_size_per_layer = self.block_size * element_size * single_head_dim
+ config["kv_block_size"] = (
+ block_size_per_layer
+ * num_layers
+ * (1 if self.is_mla else num_head_per_tp * 2)
+ )
+ config["io_size"] = block_size_per_layer * (
+ 1 if self.is_mla else num_head_per_tp
+ )
+ self.store = UcmConnectorFactory.create_connector(name, config)
+ self.block_data_size = config["kv_block_size"]
+
+        logger.info("init UCMDirectConnector, connector: %s", name)
+ logger.info(
+ "single file size = %d MB, io_size = %d KB,",
+ config["kv_block_size"] / 1024 / 1024,
+ config["io_size"] / 1024,
+ )
+
+ self.metrics_config = self.launch_config.get("metrics_config_path", "")
+ if self.metrics_config:
+ self.stats_logger = UCMStatsLogger(
+ vllm_config.model_config.served_model_name,
+ self.global_rank,
+ self.metrics_config,
+ )
+ self.monitor = ucmmonitor.StatsMonitor.get_instance()
+ self.synchronize = (
+ torch.cuda.synchronize
+ if current_platform.is_cuda_alike()
+ else torch.npu.synchronize
+ )
+
+ def generate_hash(self, block_size: int, request: "Request") -> list[str]:
+ token_ids = request.all_token_ids
+
+ ret = []
+ parent_block_hash_value = RequestHasher._SEED_HASH
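+        # Chain each block's hash with its parent's so a block ID encodes the
+        # full token prefix, similar to vLLM's prefix-caching block hashes.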
+ for start in range(0, len(token_ids), block_size):
+ end = start + block_size
+ block_token_ids = token_ids[start:end]
+ # Do not hash the block if it is not full.
+ if len(block_token_ids) < block_size:
+ break
+
+ block_token_ids_tuple = tuple(block_token_ids)
+ hash_value = self.request_hasher(
+ (parent_block_hash_value, block_token_ids_tuple)
+ )
+ parent_block_hash_value = hash_value
+ ret.append(str(hash_value))
+
+ return ret
+
+ def get_num_new_matched_tokens(
+ self,
+ request: "Request",
+ num_computed_tokens: int,
+ ) -> tuple[int, bool]:
+ assert num_computed_tokens % self.block_size == 0
+ hbm_hit_block_num = num_computed_tokens // self.block_size
+
+ ucm_block_ids = self.generate_hash(self.block_size, request)
+
+ external_block_ids = ucm_block_ids[hbm_hit_block_num:]
+ if not external_block_ids:
+ return 0, False
+
+ lookup_results = self.store.lookup(external_block_ids)
+ external_hit_blocks = 0
+        for hit in lookup_results:
+ if not hit:
+ break
+ external_hit_blocks += 1
+ logger.info(
+ f"request_id: {request.request_id}, "
+ f"total_blocks_num: {len(ucm_block_ids)}, "
+ f"hit hbm: {hbm_hit_block_num}, "
+ f"hit external: {external_hit_blocks}"
+ )
+ if self.metrics_config:
+ self.monitor.update_stats(
+ "ConnStats",
+ {"interval_lookup_hit_rates": external_hit_blocks / len(ucm_block_ids)},
+ )
+
+ total_hit_block_num = hbm_hit_block_num + external_hit_blocks
+
+ external_hit_tokens = external_hit_blocks * self.block_size
+
+ # When all the tokens are cached in ssd or hbm,
+ # we need to recompute the last token. This if condition will be removed
+ # once vLLM scheduler provides a better solution in the future.
+ num_total_hit_tokens = total_hit_block_num * self.block_size
+ if num_total_hit_tokens == request.num_tokens:
+ external_hit_tokens -= 1
+
+ self.requests_meta[request.request_id] = RequestMeta(
+ ucm_block_ids=ucm_block_ids,
+ hbm_hit_block_num=hbm_hit_block_num,
+ total_hit_block_num=total_hit_block_num,
+ num_token_ids=len(request.all_token_ids),
+ token_processed=num_total_hit_tokens,
+ )
+
+ return external_hit_tokens, False
+
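A worked example of the accounting above, under assumed numbers: with block_size = 128 and a 1024-token prompt (8 full blocks), 3 blocks already in HBM and the next 4 present in external storage yield 4 * 128 = 512 reportable external tokens; only on a full hit is one token subtracted so the scheduler still has work to do:

    block_size, num_prompt_tokens = 128, 1024
    hbm_hit_blocks, external_hit_blocks = 3, 4           # assumed lookup result

    external_hit_tokens = external_hit_blocks * block_size                  # 512
    total_hit_tokens = (hbm_hit_blocks + external_hit_blocks) * block_size  # 896 < 1024
    if total_hit_tokens == num_prompt_tokens:
        external_hit_tokens -= 1   # only triggers on a full hit, forcing last-token recompute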
+ def update_state_after_alloc(
+ self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
+ ):
+ pass
+
+ def _generate_dispatch_meta(
+ self,
+ req_meta: RequestMeta,
+ new_tokens: int,
+ vllm_block_ids: list[int],
+ need_load: bool = True,
+ ) -> RequestDispatchMeta:
+ """
+ Request Blocks layout:
+ ----------------------------------------------------------------------------------------------------
+ | local_computed_block(HBM hit) | external_computed_block(external hit) | new_block(need to dump) |
+ ----------------------------------------------------------------------------------------------------
+ | hbm_hit_block_num | LOAD | new_blocks_num |
+ ----------------------------------------------------------------------------------------------------
+ | total_hit_block_num |
+ ----------------------------------------------------------------------------------------------------
+ | scheduled_block_num |
+ """
+
+ hbm_hit_block_num = req_meta.hbm_hit_block_num
+ total_hit_block_num = req_meta.total_hit_block_num
+ ucm_block_ids = req_meta.ucm_block_ids
+ req_meta.vllm_block_ids.extend(vllm_block_ids)
+
+ load_ucm_block_ids, load_vllm_block_ids = [], []
+ dump_ucm_block_ids, dump_vllm_block_ids = [], []
+ if need_load:
+ load_ucm_block_ids = ucm_block_ids[hbm_hit_block_num:total_hit_block_num]
+ load_vllm_block_ids = vllm_block_ids[hbm_hit_block_num:total_hit_block_num]
+
+ if req_meta.token_processed < req_meta.num_token_ids:
+ start_idx = req_meta.token_processed // self.block_size
+ end_idx = (req_meta.token_processed + new_tokens) // self.block_size
+ dump_ucm_block_ids = ucm_block_ids[start_idx:end_idx]
+ dump_vllm_block_ids = req_meta.vllm_block_ids[start_idx:end_idx]
+ req_meta.token_processed += new_tokens
+
+ return RequestDispatchMeta(
+ (load_ucm_block_ids, load_vllm_block_ids),
+ (dump_ucm_block_ids, dump_vllm_block_ids),
+ )
+
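Tying the layout diagram above to the slicing: a small sketch with assumed counts showing which block-ID ranges land in the load list versus the dump list on the first scheduling step of a request:

    block_size = 128
    ucm_block_ids = [f"hash{i}" for i in range(8)]   # 8 full prompt blocks
    hbm_hit, total_hit = 2, 5                        # 2 HBM hits, 3 external hits
    token_processed = total_hit * block_size         # 640 tokens already covered
    new_tokens = 3 * block_size                      # this step computes blocks 5..7

    load_ucm = ucm_block_ids[hbm_hit:total_hit]      # ['hash2', 'hash3', 'hash4'] -> LOAD
    start = token_processed // block_size            # 5
    end = (token_processed + new_tokens) // block_size
    dump_ucm = ucm_block_ids[start:end]              # ['hash5', 'hash6', 'hash7'] -> dump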
+ def build_connector_meta(
+ self, scheduler_output: SchedulerOutput
+ ) -> KVConnectorMetadata:
+ requests_dispatch_meta = {}
+ # for new request, we need to load and dump
+ for request in scheduler_output.scheduled_new_reqs:
+ request_id, vllm_block_ids = request.req_id, request.block_ids[0]
+ req_meta = self.requests_meta.get(request_id)
+ if req_meta:
+ requests_dispatch_meta[request_id] = self._generate_dispatch_meta(
+ req_meta,
+ scheduler_output.num_scheduled_tokens[request_id],
+ vllm_block_ids,
+ )
+
+ # for cached requests, there are 3 situations:
+ # 1. chunked prefill: we only need to dump
+ # 2. resumed: handle it like a new request
+ # 3. TODO decode stage: nothing happens
+ scheduled_cached_reqs = scheduler_output.scheduled_cached_reqs
+ if not isinstance(scheduled_cached_reqs, list):
+ # vLLM >= 0.9.2: scheduled_cached_reqs is a single batched object, not a list
+ for i, request_id in enumerate(scheduled_cached_reqs.req_ids):
+ req_meta = self.requests_meta.get(request_id)
+ if req_meta:
+ new_block_ids = []
+ if scheduled_cached_reqs.new_block_ids[i] is not None:
+ new_block_ids = scheduled_cached_reqs.new_block_ids[i][0]
+ requests_dispatch_meta[request_id] = self._generate_dispatch_meta(
+ req_meta,
+ scheduler_output.num_scheduled_tokens[request_id],
+ new_block_ids,
+ scheduled_cached_reqs.resumed_from_preemption[i],
+ )
+ else:
+ for request in scheduled_cached_reqs:
+ request_id = request.req_id
+ req_meta = self.requests_meta.get(request_id)
+ if req_meta:
+ requests_dispatch_meta[request_id] = self._generate_dispatch_meta(
+ req_meta,
+ scheduler_output.num_scheduled_tokens[request_id],
+ request.new_block_ids[0],
+ request.resumed_from_preemption,
+ )
+
+ # clear finished request
+ for request_id in scheduler_output.finished_req_ids:
+ self.requests_meta.pop(request_id, None)
+
+ return UCMConnectorMetadata(requests_dispatch_meta)
+
+ def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"):
+ if len(self.kv_caches) > 0:
+ return
+ for layer_name in forward_context.no_compile_layers:
+ attn_layer = forward_context.no_compile_layers[layer_name]
+ if not hasattr(attn_layer, "kv_cache"):
+ continue
+
+ if layer_name not in self.kv_caches:
+ self.kv_caches[layer_name] = attn_layer.kv_cache[
+ forward_context.virtual_engine
+ ]
+ # Since vllm_ascend >= 0.10.0, the MLA model's tensor shape has changed to
+ # (2, num_blocks, block_size, num_kv_heads, nope_dim/rope_dim).
+ # Currently, we treat it as GQA, and use is_dsa to mark it,
+ # which works but leads to space inefficiency.
+ # TODO: Optimize this to avoid unnecessary space usage.
+ sample_kv_layer = next(iter(self.kv_caches.values()))
+ if self.is_mla and len(sample_kv_layer) == 2:
+ self.is_mla = False
+ self.is_dsa = True
+ if self.kv_cache_dtype is None:
+ self.kv_cache_dtype = sample_kv_layer[0].dtype
+
+ @staticmethod
+ def _extract_layer_index(layer_name: str) -> Optional[int]:
+ """
+ Extract the layer index from the layer name.
+ """
+ for chunk in layer_name.split("."):
+ if chunk.isdigit():
+ return int(chunk)
+ return None
+
+ def _precompute_layer_offsets(self):
+ if not self.kv_caches:
+ return
+
+ sample_kv_layer = next(iter(self.kv_caches.values()))
+ elem_size = sample_kv_layer[0].element_size()
+ block_data_size = (
+ sample_kv_layer[0].numel() if self.is_mla else sample_kv_layer[0][0].numel()
+ ) * elem_size
+ layer_data_size = block_data_size if self.is_mla else block_data_size * 2
+
+ # precompute all layers offset
+ for layer_name, _ in self.kv_caches.items():
+ layer_id = self._extract_layer_index(layer_name)
+ assert layer_id is not None
+ k_offset = layer_data_size * layer_id
+ v_offset = k_offset + block_data_size if not self.is_mla else 0
+ self._layer_offset_cache[layer_name] = (k_offset, v_offset)
+
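The offsets computed above place each UCM block's payload as [L0-K | L0-V | L1-K | L1-V | ...] for GQA/MHA, and simply [L0 | L1 | ...] for MLA, where every segment is block_data_size bytes. The same computation as a standalone sketch with an assumed segment size:

    def layer_offsets(num_layers, block_data_size, is_mla):
        """Return {layer_id: (k_offset, v_offset)} inside one UCM block."""
        layer_data_size = block_data_size if is_mla else block_data_size * 2
        offsets = {}
        for layer_id in range(num_layers):
            k_off = layer_data_size * layer_id
            v_off = 0 if is_mla else k_off + block_data_size
            offsets[layer_id] = (k_off, v_off)
        return offsets

    # GQA with an assumed 256 KiB per-layer K (or V) segment:
    print(layer_offsets(2, 256 * 1024, is_mla=False))
    # {0: (0, 262144), 1: (524288, 786432)}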
+ def _get_tensor_and_offset(
+ self, vllm_block_ids: list[int], kv_layer: torch.Tensor, layer_name: str
+ ) -> tuple[list[torch.Tensor], list[int]]:
+ """
+ GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
+ MLA: one layer shape is (num_blocks, block_size, head_size)
+ """
+ k_tensors, k_offsets = [], []
+ v_tensors, v_offsets = [], []
+ k_offset, v_offset = self._layer_offset_cache[layer_name]
+
+ for vllm_block_id in vllm_block_ids:
+ k_tensors.append(
+ kv_layer[vllm_block_id] if self.is_mla else kv_layer[0][vllm_block_id]
+ )
+ k_offsets.append(k_offset)
+ if not self.is_mla:
+ v_tensors.append(kv_layer[1][vllm_block_id])
+ v_offsets.append(v_offset)
+ return k_tensors + v_tensors, k_offsets + v_offsets
+
+ def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[str]):
+ if not self._layer_offset_cache:
+ self._precompute_layer_offsets()
+
+ num_layers = len(self.kv_caches)
+ num_blocks_per_layer = len(vllm_block_ids)
+ num_tensors_per_layer = num_blocks_per_layer * (1 if self.is_mla else 2)
+ dst_tensor_addr = [None] * (num_layers * num_tensors_per_layer)
+ ucm_offsets = [0] * (num_layers * num_tensors_per_layer)
+
+ idx = 0
+ for layer_name, one_layer_kv_cache in self.kv_caches.items():
+ tensors, offsets = self._get_tensor_and_offset(
+ vllm_block_ids, one_layer_kv_cache, layer_name
+ )
+ dst_tensor_addr[idx : idx + len(tensors)] = tensors
+ ucm_offsets[idx : idx + len(offsets)] = offsets
+ idx += len(tensors)
+
+ repeat_times = len(self.kv_caches) * (1 if self.is_mla else 2)
+ ucm_total_block_ids = ucm_block_ids * repeat_times
+
+ assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr)
+ return ucm_total_block_ids, ucm_offsets, dst_tensor_addr
+
+ def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
+ rec_tensor: Optional[torch.Tensor] = None
+ with torch.cuda.stream(self.broadcast_stream):
+ # TODO support broadcast when PP
+ if self.global_rank == 0:
+ tensor_to_broadcast = torch.stack(dst_tensor_addr, dim=0)
+ self.broadcast_fn(tensor_to_broadcast, 0)
+ else:
+ shape = (len(dst_tensor_addr),) + dst_tensor_addr[0].shape
+ # TODO create earlier
+ rec_tensor = torch.empty(
+ shape, dtype=self.kv_cache_dtype, device=self.device
+ )
+ self.broadcast_fn(rec_tensor, 0)
+ self.broadcast_stream.synchronize()
+ if self.global_rank != 0 and rec_tensor is not None:
+ for i, tensor in enumerate(dst_tensor_addr):
+ tensor.copy_(rec_tensor[i])
+
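When load_only_first_rank is set, only rank 0 reads from storage; the other ranks receive the stacked tensor over the broadcast stream and scatter it back into their paged-cache block views. The scatter step is an ordinary per-slice copy, sketched here without the collective (torch.distributed.broadcast is assumed to fill rec_tensor exactly as rank 0 stacked it):

    import torch

    # Stand-ins for the paged-cache block views on a non-zero rank.
    dst_tensor_addr = [torch.zeros(4, 8) for _ in range(3)]

    # What rank 0 stacked and broadcast: one slice per loaded block.
    rec_tensor = torch.stack([torch.full((4, 8), float(i)) for i in range(3)], dim=0)

    # Same copy-back loop as the tail of _broadcast().
    for i, tensor in enumerate(dst_tensor_addr):
        tensor.copy_(rec_tensor[i])

    assert torch.equal(dst_tensor_addr[1], torch.full((4, 8), 1.0))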
+ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
+ metadata = self._get_connector_metadata()
+ assert isinstance(metadata, UCMConnectorMetadata)
+
+ self._init_kv_caches_from_forward_context(forward_context)
+
+ request_to_task: dict[str, Optional[Task]] = {}
+ req_broadcast_addr = {}
+ is_load = False
+ num_loaded_block = 0
+ num_loaded_request = 0
+ load_start_time = time.perf_counter() * 1000
+ for request_id, request in metadata.request_meta.items():
+ if len(request.load_block_ids[0]) == 0:
+ continue
+ is_load = True
+ num_loaded_block += len(request.load_block_ids[0])
+ num_loaded_request += 1
+
+ ucm_block_ids, vllm_block_ids = request.load_block_ids
+ if self.global_rank != 0 and not self.is_mla and not self.is_dsa:
+ for i, ucm_block_id in enumerate(ucm_block_ids):
+ ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
+ ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
+ vllm_block_ids, ucm_block_ids
+ )
+ if self.global_rank == 0 or not self.load_only_first_rank:
+ request_to_task[request_id] = self.store.load(
+ ucm_total_block_ids, ucm_offsets, dst_tensor_addr
+ )
+ else:
+ request_to_task[request_id] = None
+ req_broadcast_addr[request_id] = dst_tensor_addr
+
+ for request_id, task in request_to_task.items():
+ # TODO error handling
+ if self.global_rank == 0 or not self.load_only_first_rank:
+ if self.store.wait(task) != 0:
+ logger.error(f"request {request_id} load kv cache failed.")
+ if self.load_only_first_rank:
+ self._broadcast(req_broadcast_addr[request_id])
+ load_end_time = time.perf_counter() * 1000
+ load_speed = (
+ num_loaded_block
+ * self.block_data_size
+ / (load_end_time - load_start_time)
+ / 1024
+ / 1024
+ ) # GB/s
+ if self.metrics_config and is_load:
+ self.monitor.update_stats(
+ "ConnStats",
+ {
+ "load_requests_num": num_loaded_request,
+ "load_blocks_num": num_loaded_block,
+ "load_duration": load_end_time - load_start_time,
+ "load_speed": load_speed,
+ },
+ )
+
+ def wait_for_layer_load(self, layer_name: str) -> None:
+ pass
+
+ def save_kv_layer(
+ self,
+ layer_name: str,
+ kv_layer: torch.Tensor,
+ attn_metadata: "AttentionMetadata",
+ **kwargs,
+ ) -> None:
+ pass
+
+ def wait_for_save(self) -> None:
+
+ # TODO support PP
+ if (self.is_mla or self.is_dsa) and self.global_rank != 0:
+ return
+ if self.metrics_config:
+ self.synchronize()
+
+ metadata = self._get_connector_metadata()
+ assert isinstance(metadata, UCMConnectorMetadata)
+
+ request_to_task: dict[str, Task] = {}
+ request_to_blocks: dict[str, list[str]] = {}
+ is_save = False
+ num_saved_block = 0
+ num_saved_request = 0
+ save_start_time = time.perf_counter() * 1000
+ for request_id, request in metadata.request_meta.items():
+ if len(request.dump_block_ids[0]) == 0:
+ continue
+ is_save = True
+ num_saved_block += len(request.dump_block_ids[0])
+ num_saved_request += 1
+
+ ucm_block_ids, vllm_block_ids = request.dump_block_ids
+ if self.global_rank != 0:
+ for i, ucm_block_id in enumerate(ucm_block_ids):
+ ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
+ rets = self.store.create(ucm_block_ids)
+ end = 0
+ for i, ret in enumerate(rets):
+ if ret != 0:
+ logger.error(
+ f"create blocks for {request_id} failed, block index: {i}, ret code: {ret}"
+ )
+ break
+ end += 1
+
+ if end == 0:
+ continue
+ ucm_block_ids = ucm_block_ids[:end]
+ vllm_block_ids = vllm_block_ids[:end]
+ ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
+ vllm_block_ids, ucm_block_ids
+ )
+ request_to_task[request_id] = self.store.dump(
+ ucm_total_block_ids, ucm_offsets, dst_tensor_addr
+ )
+ request_to_blocks[request_id] = ucm_block_ids
+
+ for request_id, task in request_to_task.items():
+ ucm_block_ids = request_to_blocks[request_id]
+ if self.store.wait(task) == 0:
+ self.store.commit(ucm_block_ids, True)
+ else:
+ logger.error(f"request {request_id} dump kv cache failed.")
+ self.store.commit(ucm_block_ids, False)
+ save_end_time = time.perf_counter() * 1000
+ save_speed = (
+ num_saved_block
+ * self.block_data_size
+ / (save_end_time - save_start_time)
+ / 1024
+ / 1024
+ ) # GB/s
+ if self.metrics_config and is_save:
+ self.monitor.update_stats(
+ "ConnStats",
+ {
+ "save_requests_num": num_saved_request,
+ "save_blocks_num": num_saved_block,
+ "save_duration": save_end_time - save_start_time,
+ "save_speed": save_speed,
+ },
+ )
+
+ def clear_connector_metadata(self) -> None:
+ super().clear_connector_metadata()
+
+
+class UCMLayerWiseConnector(UCMDirectConnector):
+ """
+ This Connector means overlap:
+ load l0 -> forward l0 -> save l0
+ load l1 -> forward l1 -> save l1
+ load l2 -> forward l2 -> save l2
+ """
+
+ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
+ super().__init__(vllm_config, role)
+
+ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
+ raise NotImplementedError
+
+ def wait_for_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
+ raise NotImplementedError
+
+ def save_kv_layer(
+ self,
+ layer_name: str,
+ kv_layer: torch.Tensor,
+ attn_metadata: "AttentionMetadata",
+ **kwargs,
+ ) -> None:
+ raise NotImplementedError
+
+ def wait_for_save(self) -> None:
+ raise NotImplementedError
+
+
+class UCMPDConnector(UCMDirectConnector):
+ """
+ This Connector means overlap (especially for Decode Instance):
+ step (req0,1,2) forward -> step (req0,1,2,3) forward
+ load req3 -> load req4
+ """
+
+ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
+ super().__init__(vllm_config, role)
+
+ def get_num_new_matched_tokens(
+ self,
+ request: "Request",
+ num_computed_tokens: int,
+ ) -> tuple[int, bool]:
+ raise NotImplementedError
+
+ def get_finished(
+ self, finished_req_ids: set[str]
+ ) -> tuple[Optional[set[str]], Optional[set[str]]]:
+ """
+ Notifies worker-side connector ids of requests that have
+ finished generating tokens.
+
+ Returns:
+ A tuple of (sending/saving ids, recving/loading ids): the ids of requests
+ that have finished asynchronous transfer (requests that previously
+ returned True from request_finished()).
+ The finished saves/sends req ids must belong to a set provided in a
+ call to this method (this call or a prior one).
+ """
+ raise NotImplementedError
+
+
+class UCMMockConnector(UCMDirectConnector):
+ """
+ This connector caps the reported hit ratio. For example, even if the real hit
+ ratio is 100%, you can set "hit_ratio" via config or environment variables and
+ get_num_new_matched_tokens() will reduce the reported hit tokens to stay under it.
+ """
+
+ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
+ super().__init__(vllm_config, role)
+ self._hit_ratio = float(self.launch_config["hit_ratio"])
+ logger.info(f"hit_ratio: {self._hit_ratio}")
+
+ def get_num_new_matched_tokens(
+ self,
+ request: "Request",
+ num_computed_tokens: int,
+ ) -> tuple[int, bool]:
+ hit_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens)
+ expect_hit_tokens = int(self._hit_ratio * request.num_prompt_tokens)
+ if hit_tokens <= expect_hit_tokens:
+ return hit_tokens, False
+ expect_hit_block_num = expect_hit_tokens // self.block_size
+ request_meta = self.requests_meta[request.request_id]
+ request_meta.total_hit_block_num = expect_hit_block_num
+ request_meta.hbm_hit_block_num = min(
+ expect_hit_block_num, request_meta.hbm_hit_block_num
+ )
+
+ logger.info(
+ "Hijacked By MockConnector,"
+ f"request_id: {request.request_id}, "
+ f"total_blocks_num: {len(request_meta.ucm_block_ids)}, "
+ f"hit hbm: {request_meta.hbm_hit_block_num}, "
+ f"hit external: {request_meta.total_hit_block_num - request_meta.hbm_hit_block_num}"
+ )
+
+ return expect_hit_block_num * self.block_size, False
+
+
+class UCMConnector(KVConnectorBase_V1):
+ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
+ super().__init__(vllm_config=vllm_config, role=role)
+ self.connector: KVConnectorBase_V1
+ # TODO: select the connector implementation via config
+ if (
+ self._vllm_config.kv_transfer_config is not None
+ and "hit_ratio"
+ in self._vllm_config.kv_transfer_config.kv_connector_extra_config
+ ):
+ self.connector = UCMMockConnector(vllm_config, role)
+ else:
+ self.connector = UCMDirectConnector(vllm_config, role)
+
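For reference, a hedged configuration sketch of how this dispatcher is driven: the key names ("ucm_connectors", "ucm_connector_name", "ucm_connector_config", "hit_ratio") are the ones read in this file, while the connector name and config values below are purely illustrative:

    # Goes into vLLM's kv_transfer_config.kv_connector_extra_config.
    kv_connector_extra_config = {
        "ucm_connectors": [
            {
                "ucm_connector_name": "UcmNfsStore",                       # assumed backend name
                "ucm_connector_config": {"storage_backends": "/mnt/ucm"},  # assumed keys
            }
        ],
        # Presence of "hit_ratio" routes requests through UCMMockConnector:
        # "hit_ratio": 0.5,
    }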
+ def get_num_new_matched_tokens(
+ self,
+ request: "Request",
+ num_computed_tokens: int,
+ ) -> tuple[int, bool]:
+ """
+ Get number of new tokens that can be loaded from the
+ external KV cache beyond the num_computed_tokens.
+
+ Args:
+ request (Request): the request object.
+ num_computed_tokens (int): the number of locally
+ computed tokens for this request
+
+ Returns:
+ the number of tokens that can be loaded from the
+ external KV cache beyond what is already computed.
+ """
+ return self.connector.get_num_new_matched_tokens(request, num_computed_tokens)
+
+ def update_state_after_alloc(
+ self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
+ ):
+ """
+ Update KVConnector state after block allocation.
+ """
+ self.connector.update_state_after_alloc(request, blocks, num_external_tokens)
+
+ def build_connector_meta(
+ self, scheduler_output: SchedulerOutput
+ ) -> KVConnectorMetadata:
+ """
+ Build the connector metadata for this step.
+
+ This function should NOT modify fields in the scheduler_output.
+ Also, calling this function will reset the state of the connector.
+
+ Args:
+ scheduler_output (SchedulerOutput): the scheduler output object.
+ """
+ return self.connector.build_connector_meta(scheduler_output)
+
+ def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
+ """Set the connector metadata from the scheduler.
+
+ This function should be called by the model runner every time
+ before the model execution. The metadata will be used for runtime
+ KV cache loading and saving.
+
+ Args:
+ connector_metadata (dict): the connector metadata.
+ """
+ self.connector.bind_connector_metadata(connector_metadata)
+
+ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
+ """
+ Start loading the KV cache from the connector to vLLM's paged
+ KV buffer. This is called from the forward context before the
+ forward pass to enable async loading during model execution.
+
+ Args:
+ forward_context (ForwardContext): the forward context.
+ **kwargs: additional arguments for the load operation
+
+ Note:
+ The number of elements in kv_caches and layer_names should be
+ the same.
+
+ """
+ self.connector.start_load_kv(forward_context, **kwargs)
+
+ def wait_for_layer_load(self, layer_name: str) -> None:
+ """
+ Block until the KV for a specific layer is loaded into vLLM's
+ paged buffer. This is called from within attention layer to ensure
+ async copying from start_load_kv is complete.
+
+ This interface will be useful for layer-by-layer pipelining.
+
+ Args:
+ layer_name: the name of that layer
+ """
+ self.connector.wait_for_layer_load(layer_name)
+
+ def save_kv_layer(
+ self,
+ layer_name: str,
+ kv_layer: torch.Tensor,
+ attn_metadata: "AttentionMetadata",
+ **kwargs,
+ ) -> None:
+ """
+ Start saving a layer of KV cache from vLLM's paged buffer
+ to the connector. This is called from within attention layer to
+ enable async copying during execution.
+
+ Args:
+ layer_name (str): the name of the layer.
+ kv_layer (torch.Tensor): the paged KV buffer of the current
+ layer in vLLM.
+ attn_metadata (AttentionMetadata): the attention metadata.
+ **kwargs: additional arguments for the save operation.
+ """
+ self.connector.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)
+
+ def wait_for_save(self) -> None:
+ """
+ Block until all the save operations are done. This is called
+ as the forward context exits to ensure that the async saving
+ from save_kv_layer is complete before finishing the forward.
+
+ This prevents overwrites of paged KV buffer before saving done.
+ """
+ self.connector.wait_for_save()
+
+ def clear_connector_metadata(self) -> None:
+ """Clear the connector metadata.
+
+ This function should be called by the model runner every time
+ after the model execution.
+ """
+ self.connector.clear_connector_metadata()
diff --git a/ucm/pd/toy_proxy_server.py b/ucm/pd/toy_proxy_server.py
index 9b7f15799..99dc75825 100644
--- a/ucm/pd/toy_proxy_server.py
+++ b/ucm/pd/toy_proxy_server.py
@@ -17,53 +17,78 @@
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
- Lifespan context manager to handle startup and shutdown events.
+ Lifespan context manager to initialize clients based on mode.
"""
- # Startup: Initialize client pools for prefiller and decoder services
app.state.prefill_clients = []
app.state.decode_clients = []
-
- # Create prefill clients
- for i, (host, port) in enumerate(global_args.prefiller_instances):
- prefiller_base_url = f"http://{host}:{port}/v1"
- app.state.prefill_clients.append(
- {
- "client": httpx.AsyncClient(timeout=None, base_url=prefiller_base_url),
- "host": host,
- "port": port,
- "id": i,
- }
+ app.state.worker_clients = [] # For PD-mixed workers
+
+ if global_args.pd_disaggregation:
+ # === PD disaggregation ===
+ for i, (host, port) in enumerate(global_args.prefiller_instances):
+ base_url = f"http://{host}:{port}/v1"
+ app.state.prefill_clients.append(
+ {
+ "client": httpx.AsyncClient(timeout=None, base_url=base_url),
+ "host": host,
+ "port": port,
+ "id": i,
+ }
+ )
+
+ for i, (host, port) in enumerate(global_args.decoder_instances):
+ base_url = f"http://{host}:{port}/v1"
+ app.state.decode_clients.append(
+ {
+ "client": httpx.AsyncClient(timeout=None, base_url=base_url),
+ "host": host,
+ "port": port,
+ "id": i,
+ }
+ )
+
+ app.state.prefill_iterator = itertools.cycle(
+ range(len(app.state.prefill_clients))
)
-
- # Create decode clients
- for i, (host, port) in enumerate(global_args.decoder_instances):
- decoder_base_url = f"http://{host}:{port}/v1"
- app.state.decode_clients.append(
- {
- "client": httpx.AsyncClient(timeout=None, base_url=decoder_base_url),
- "host": host,
- "port": port,
- "id": i,
- }
+ app.state.decode_iterator = itertools.cycle(
+ range(len(app.state.decode_clients))
)
- # Initialize round-robin iterators
- app.state.prefill_iterator = itertools.cycle(range(len(app.state.prefill_clients)))
- app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients)))
+ print(
+ f"[PD Mode] Initialized {len(app.state.prefill_clients)} prefillers "
+ f"and {len(app.state.decode_clients)} decoders."
+ )
- print(
- f"Initialized {len(app.state.prefill_clients)} prefill clients "
- f"and {len(app.state.decode_clients)} decode clients."
- )
+ else:
+ # === PD mix ===
+ for i, (host, port) in enumerate(global_args.worker_instances):
+ base_url = f"http://{host}:{port}/v1"
+ app.state.worker_clients.append(
+ {
+ "client": httpx.AsyncClient(timeout=None, base_url=base_url),
+ "host": host,
+ "port": port,
+ "id": i,
+ }
+ )
+
+ app.state.worker_iterator = itertools.cycle(
+ range(len(app.state.worker_clients))
+ )
+ print(
+ f"[Mixed Mode] Initialized {len(app.state.worker_clients)} PD-mixed workers."
+ )
yield
- # Shutdown: Close all clients
- for client_info in app.state.prefill_clients:
- await client_info["client"].aclose()
-
- for client_info in app.state.decode_clients:
- await client_info["client"].aclose()
+ # Close all clients
+ for client_list in [
+ app.state.prefill_clients,
+ app.state.decode_clients,
+ app.state.worker_clients,
+ ]:
+ for client_info in client_list:
+ await client_info["client"].aclose()
# Update FastAPI app initialization to use lifespan
@@ -75,6 +100,26 @@ def parse_args():
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument(
+ "--pd-disaggregation",
+ action="store_true",
+ help="Enable PD disaggregation mode (prefill and decode separation)",
+ )
+ # For PD mix instances
+ parser.add_argument(
+ "--worker-hosts",
+ "--work-host",
+ type=str,
+ nargs="+",
+ default=["localhost"],
+ )
+ parser.add_argument(
+ "--worker-ports",
+ "--work-port",
+ type=int,
+ nargs="+",
+ default=[8100],
+ )
# For prefiller instances
parser.add_argument(
@@ -107,9 +152,15 @@ def parse_args():
if len(args.decoder_hosts) != len(args.decoder_ports):
raise ValueError("Number of decoder hosts must match number of decoder ports")
- # Create tuples of (host, port) for each service type
+ if len(args.worker_hosts) != len(args.worker_ports):
+ raise ValueError("Number of worker hosts must match number of worker ports")
+
+ # Create instance tuples
args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports))
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
+ args.worker_instances = list(
+ zip(args.worker_hosts, args.worker_ports)
+ ) # Mixed workers
return args
@@ -120,12 +171,15 @@ def get_next_client(app, service_type: str):
Args:
app: The FastAPI app instance
- service_type: Either 'prefill' or 'decode'
+ service_type: one of 'worker', 'prefill', or 'decode'
Returns:
The next client to use
"""
- if service_type == "prefill":
+ if service_type == "worker":
+ worker_idx = next(app.state.worker_iterator)
+ return app.state.worker_clients[worker_idx]
+ elif service_type == "prefill":
client_idx = next(app.state.prefill_iterator)
return app.state.prefill_clients[client_idx]
elif service_type == "decode":
@@ -183,37 +237,72 @@ async def _handle_completions(api: str, request: Request):
req_data = await request.json()
request_id = str(uuid.uuid4())
- # Get the next prefill client in round-robin fashion
- prefill_client_info = get_next_client(request.app, "prefill")
-
- # Send request to prefill service
- response = await send_request_to_service(
- prefill_client_info, api, req_data, request_id
- )
-
- # Extract the needed fields
- response_json = response.json()
-
- # Get the next decode client in round-robin fashion
- decode_client_info = get_next_client(request.app, "decode")
-
- logger.debug("Using %s %s", prefill_client_info, decode_client_info)
-
- # Stream response from decode service
- async def generate_stream():
- async for chunk in stream_service_response(
- decode_client_info, api, req_data, request_id=request_id
- ):
- yield chunk
-
- return StreamingResponse(generate_stream(), media_type="application/json")
+ headers = {
+ "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+ "X-Request-Id": request_id,
+ }
+
+ if global_args.pd_disaggregation:
+ # === PD disaggregation logic ===
+
+ # Step 1: Send request to prefiller (to trigger computation and cache KV)
+ prefill_client_info = get_next_client(request.app, "prefill")
+ prefill_req_data = req_data.copy()
+ prefill_req_data["stream"] = False
+ prefill_req_data["max_tokens"] = 1
+ if "stream_options" in prefill_req_data:
+ del prefill_req_data["stream_options"]
+
+ response = await prefill_client_info["client"].post(
+ api, json=prefill_req_data, headers=headers
+ )
+ response.raise_for_status()
+
+ # Step 2: Stream full output from decoder
+ decode_client_info = get_next_client(request.app, "decode")
+
+ logger.debug(
+ "PD-DISAGG: Prefill=%s:%d, Decode=%s:%d",
+ prefill_client_info["host"],
+ prefill_client_info["port"],
+ decode_client_info["host"],
+ decode_client_info["port"],
+ )
+
+ async def generate_stream():
+ async for chunk in stream_service_response(
+ decode_client_info, api, req_data, request_id
+ ):
+ yield chunk
+
+ return StreamingResponse(generate_stream(), media_type="application/json")
+
+ else:
+ # === PD mixed mode: Directly forward the entire stream using round-robin ===
+ worker_client_info = get_next_client(request.app, "worker")
+
+ logger.debug(
+ "PD-MIXED: Forwarding to %s:%d",
+ worker_client_info["host"],
+ worker_client_info["port"],
+ )
+
+ async def generate_stream():
+ async with worker_client_info["client"].stream(
+ "POST", api, json=req_data, headers=headers
+ ) as resp:
+ resp.raise_for_status()
+ async for chunk in resp.aiter_bytes():
+ yield chunk
+
+ return StreamingResponse(generate_stream(), media_type="application/json")
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
- print("Error occurred in disagg prefill proxy server" f" - {api} endpoint")
+ print(f"Error in proxy server - {api} endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@@ -231,12 +320,19 @@ async def handle_chat_completions(request: Request):
@app.get("/healthcheck")
async def healthcheck():
- """Simple endpoint to check if the server is running."""
- return {
- "status": "ok",
- "prefill_instances": len(app.state.prefill_clients),
- "decode_instances": len(app.state.decode_clients),
- }
+ if global_args.pd_disaggregation:
+ return {
+ "status": "ok",
+ "mode": "pd-disaggregation",
+ "prefill_instances": len(app.state.prefill_clients),
+ "decode_instances": len(app.state.decode_clients),
+ }
+ else:
+ return {
+ "status": "ok",
+ "mode": "pd-mixed",
+ "worker_instances": len(app.state.worker_clients),
+ }
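A minimal client-side check against the proxy above; the endpoint path follows the handler names in this file and the payload fields are standard OpenAI-style assumptions:

    import httpx

    proxy = "http://localhost:8000"

    # Reports the mode and instance counts from the healthcheck handler above.
    print(httpx.get(f"{proxy}/healthcheck").json())

    # In PD-disaggregation mode the proxy prefills first (max_tokens=1, non-stream),
    # then streams the full output from a decoder; in mixed mode it forwards directly.
    payload = {"model": "my-model", "prompt": "Hello", "max_tokens": 16, "stream": True}
    with httpx.stream("POST", f"{proxy}/v1/completions", json=payload, timeout=None) as resp:
        for chunk in resp.iter_bytes():
            print(chunk.decode(errors="ignore"), end="")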
if __name__ == "__main__":
diff --git a/ucm/shared/CMakeLists.txt b/ucm/shared/CMakeLists.txt
new file mode 100644
index 000000000..1f73d1e8c
--- /dev/null
+++ b/ucm/shared/CMakeLists.txt
@@ -0,0 +1,5 @@
+add_subdirectory(vendor)
+add_subdirectory(infra)
+add_subdirectory(trans)
+add_subdirectory(metrics)
+add_subdirectory(test)
diff --git a/ucm/shared/__init__.py b/ucm/shared/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/ucm/shared/infra/CMakeLists.txt b/ucm/shared/infra/CMakeLists.txt
new file mode 100644
index 000000000..ba4345dce
--- /dev/null
+++ b/ucm/shared/infra/CMakeLists.txt
@@ -0,0 +1,22 @@
+file(GLOB_RECURSE UCMINFRA_STATUS_SOURCE_FILES "status/*.*")
+add_library(infra_status OBJECT ${UCMINFRA_STATUS_SOURCE_FILES})
+target_include_directories(infra_status PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
+target_link_libraries(infra_status PUBLIC fmt)
+
+file(GLOB UCMINFRA_LOGGER_SOURCE_FILES "logger/*.*")
+file(GLOB_RECURSE UCMINFRA_LOGGER_DETAIL_SOURCE_FILES "logger/${LOGGER_BACKEND}/*.cc")
+add_library(infra_logger OBJECT ${UCMINFRA_LOGGER_SOURCE_FILES} ${UCMINFRA_LOGGER_DETAIL_SOURCE_FILES})
+target_include_directories(infra_logger PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
+target_link_libraries(infra_logger PUBLIC fmt spdlog)
+
+file(GLOB_RECURSE UCMINFRA_TEMPLATE_SOURCE_FILES "template/*.*")
+add_library(infra_template OBJECT ${UCMINFRA_TEMPLATE_SOURCE_FILES})
+target_include_directories(infra_template PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
+
+file(GLOB_RECURSE UCMINFRA_THREAD_SOURCE_FILES "thread/*.*")
+add_library(infra_thread OBJECT ${UCMINFRA_THREAD_SOURCE_FILES})
+target_include_directories(infra_thread PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
+
+file(GLOB_RECURSE UCMINFRA_TIME_SOURCE_FILES "time/*.*")
+add_library(infra_time OBJECT ${UCMINFRA_TIME_SOURCE_FILES})
+target_include_directories(infra_time PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
diff --git a/ucm/store/infra/logger/flux/flux_logger.cc b/ucm/shared/infra/logger/flux/flux_logger.cc
similarity index 100%
rename from ucm/store/infra/logger/flux/flux_logger.cc
rename to ucm/shared/infra/logger/flux/flux_logger.cc
diff --git a/ucm/store/infra/logger/logger.h b/ucm/shared/infra/logger/logger.h
similarity index 97%
rename from ucm/store/infra/logger/logger.h
rename to ucm/shared/infra/logger/logger.h
index f27dd23df..516b9e663 100644
--- a/ucm/store/infra/logger/logger.h
+++ b/ucm/shared/infra/logger/logger.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
* */
-#ifndef UNIFIEDCACHE_LOGGER_H
-#define UNIFIEDCACHE_LOGGER_H
+#ifndef UNIFIEDCACHE_INFRA_LOGGER_H
+#define UNIFIEDCACHE_INFRA_LOGGER_H
#include
#include
diff --git a/ucm/store/infra/logger/spdlog/spdlog_logger.cc b/ucm/shared/infra/logger/spdlog/spdlog_logger.cc
similarity index 100%
rename from ucm/store/infra/logger/spdlog/spdlog_logger.cc
rename to ucm/shared/infra/logger/spdlog/spdlog_logger.cc
diff --git a/ucm/shared/infra/status/status.h b/ucm/shared/infra/status/status.h
new file mode 100644
index 000000000..3711de842
--- /dev/null
+++ b/ucm/shared/infra/status/status.h
@@ -0,0 +1,90 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#ifndef UNIFIEDCACHE_INFRA_STATUS_H
+#define UNIFIEDCACHE_INFRA_STATUS_H
+
+#include <cstdint>
+#include <string>
+#include <fmt/format.h>
+
+namespace UC {
+
+template <int32_t i>
+static inline constexpr int32_t __MakeStatusCode()
+{
+ return -50000 - i;
+}
+
+class Status {
+ static constexpr int32_t OK_ = 0;
+ static constexpr int32_t ERROR_ = -1;
+ static constexpr int32_t EPARAM_ = __MakeStatusCode<0>();
+ static constexpr int32_t EOOM_ = __MakeStatusCode<1>();
+ static constexpr int32_t EOSERROR_ = __MakeStatusCode<2>();
+ static constexpr int32_t EDUPLICATE_ = __MakeStatusCode<3>();
+ static constexpr int32_t ERETRY_ = __MakeStatusCode<4>();
+ static constexpr int32_t ENOOBJ_ = __MakeStatusCode<5>();
+ static constexpr int32_t ESERIALIZE_ = __MakeStatusCode<6>();
+ static constexpr int32_t EDESERIALIZE_ = __MakeStatusCode<7>();
+ static constexpr int32_t EUNSUPPORTED_ = __MakeStatusCode<8>();
+ static constexpr int32_t ENOSPACE_ = __MakeStatusCode<9>();
+ int32_t code_;
+ std::string message_;
+ explicit Status(int32_t code) : code_(code) {}
+
+public:
+ bool operator==(const Status& other) const noexcept { return code_ == other.code_; }
+ bool operator!=(const Status& other) const noexcept { return !(*this == other); }
+ int32_t Underlying() const { return code_; }
+ std::string ToString() const
+ {
+ auto str = std::to_string(code_);
+ if (message_.empty()) { return str; }
+ return fmt::format("{}, {}", str, message_);
+ }
+ constexpr bool Success() const noexcept { return code_ == OK_; }
+ constexpr bool Failure() const noexcept { return !Success(); }
+
+public:
+ Status(int32_t code, std::string message) : code_{code}, message_{std::move(message)} {}
+ static Status OK() { return Status{OK_}; }
+ static Status Error(std::string message) { return {ERROR_, std::move(message)}; }
+ static Status Error() { return Status{ERROR_}; }
+ static Status InvalidParam() { return Status{EPARAM_}; }
+ static Status OutOfMemory() { return Status{EOOM_}; }
+ static Status OsApiError() { return Status{EOSERROR_}; }
+ static Status DuplicateKey() { return Status{EDUPLICATE_}; }
+ static Status Retry() { return Status{ERETRY_}; }
+ static Status NotFound() { return Status{ENOOBJ_}; }
+ static Status SerializeFailed() { return Status{ESERIALIZE_}; }
+ static Status DeserializeFailed() { return Status{EDESERIALIZE_}; }
+ static Status Unsupported() { return Status{EUNSUPPORTED_}; }
+ static Status NoSpace() { return Status{ENOSPACE_}; }
+};
+
+inline std::string format_as(const Status& status) { return status.ToString(); }
+
+} // namespace UC
+
+#endif
diff --git a/ucm/store/infra/template/hashset.h b/ucm/shared/infra/template/hashset.h
similarity index 98%
rename from ucm/store/infra/template/hashset.h
rename to ucm/shared/infra/template/hashset.h
index b09692bc1..102f69b62 100644
--- a/ucm/store/infra/template/hashset.h
+++ b/ucm/shared/infra/template/hashset.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
* */
-#ifndef UNIFIEDCACHE_HASHSET_H
-#define UNIFIEDCACHE_HASHSET_H
+#ifndef UNIFIEDCACHE_INFRA_HASHSET_H
+#define UNIFIEDCACHE_INFRA_HASHSET_H
#include
#include
diff --git a/ucm/store/infra/template/singleton.h b/ucm/shared/infra/template/singleton.h
similarity index 94%
rename from ucm/store/infra/template/singleton.h
rename to ucm/shared/infra/template/singleton.h
index fda4957b5..f667288ee 100644
--- a/ucm/store/infra/template/singleton.h
+++ b/ucm/shared/infra/template/singleton.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
* */
-#ifndef UNIFIEDCACHE_SINGLETON_H
-#define UNIFIEDCACHE_SINGLETON_H
+#ifndef UNIFIEDCACHE_INFRA_SINGLETON_H
+#define UNIFIEDCACHE_INFRA_SINGLETON_H
namespace UC {
diff --git a/ucm/shared/infra/template/timer.h b/ucm/shared/infra/template/timer.h
new file mode 100644
index 000000000..0c9db149d
--- /dev/null
+++ b/ucm/shared/infra/template/timer.h
@@ -0,0 +1,90 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#ifndef UNIFIEDCACHE_INFRA_TIMER_H
+#define UNIFIEDCACHE_INFRA_TIMER_H
+
+#include <atomic>
+#include <chrono>
+#include <condition_variable>
+#include <mutex>
+#include <thread>
+
+namespace UC {
+
+template <typename Callable>
+class Timer {
+public:
+ Timer(const std::chrono::seconds& interval, Callable&& callable)
+ : interval_(interval), callable_(callable), running_(false)
+ {
+ }
+ ~Timer()
+ {
+ {
+ std::lock_guard lg(this->mutex_);
+ this->running_ = false;
+ this->cv_.notify_one();
+ }
+ if (this->thread_.joinable()) { this->thread_.join(); }
+ }
+ bool Start()
+ {
+ {
+ std::lock_guard lg(this->mutex_);
+ if (this->running_) { return true; }
+ }
+ try {
+ this->running_ = true;
+ this->thread_ = std::thread(&Timer::Runner, this);
+ return true;
+ } catch (...) {
+ return false;
+ }
+ }
+
+private:
+ void Runner()
+ {
+ while (this->running_) {
+ {
+ std::unique_lock lg(this->mutex_);
+ this->cv_.wait_for(lg, this->interval_, [this] { return !this->running_; });
+ if (!this->running_) { break; }
+ }
+ this->callable_();
+ }
+ }
+
+private:
+ std::chrono::seconds interval_;
+ Callable callable_;
+ std::thread thread_;
+ std::mutex mutex_;
+ std::condition_variable cv_;
+ std::atomic<bool> running_;
+};
+
+} // namespace UC
+
+#endif
diff --git a/ucm/shared/infra/template/topn_heap.h b/ucm/shared/infra/template/topn_heap.h
new file mode 100644
index 000000000..737d0b19a
--- /dev/null
+++ b/ucm/shared/infra/template/topn_heap.h
@@ -0,0 +1,121 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+
+#ifndef UNIFIEDCACHE_INFRA_TOP_N_HEAP_H
+#define UNIFIEDCACHE_INFRA_TOP_N_HEAP_H
+
+#include <cstdint>
+#include <functional>
+#include <vector>
+
+namespace UC {
+
+template <typename T, typename Compare = std::less<T>>
+class TopNHeap {
+public:
+ using ValueType = T;
+ using SizeType = uint32_t;
+ using ConstRef = const T&;
+
+private:
+ using IndexType = uint32_t;
+ std::vector<ValueType> val_{};
+ std::vector<IndexType> idx_{};
+ SizeType capacity_{0};
+ SizeType size_{0};
+ Compare cmp_{};
+
+public:
+ explicit TopNHeap(const SizeType capacity) noexcept(
+ std::is_nothrow_default_constructible_v<ValueType>)
+ : capacity_{capacity} {
+ val_.resize(capacity);
+ idx_.resize(capacity);
+ }
+ TopNHeap(const TopNHeap&) = delete;
+ TopNHeap(const TopNHeap&&) = delete;
+ TopNHeap& operator=(const TopNHeap&) = delete;
+ TopNHeap& operator=(const TopNHeap&&) = delete;
+ ~TopNHeap() { Clear(); }
+
+ SizeType Size() const noexcept { return size_; }
+ SizeType Capacity() const noexcept { return capacity_; }
+ bool Empty() const noexcept { return size_ == 0; }
+
+ void Push(ConstRef value) noexcept {
+ if (size_ < capacity_) {
+ val_[size_] = value;
+ idx_[size_] = size_;
+ SiftUp(size_);
+ size_++;
+ return;
+ }
+ if (cmp_(val_[idx_.front()], value)) {
+ val_[idx_.front()] = value;
+ SiftDown(0);
+ }
+ }
+ ConstRef Top() const noexcept { return val_[idx_.front()]; }
+ void Pop() noexcept {
+ idx_[0] = idx_[--size_];
+ if (size_) { SiftDown(0); }
+ }
+ void Clear() noexcept { size_ = 0; }
+private:
+ static IndexType Parent(IndexType i) noexcept { return (i - 1) / 2; }
+ static IndexType Left(IndexType i) noexcept { return 2 * i + 1; }
+ static IndexType Right(IndexType i) noexcept { return 2 * i + 2; }
+ void SiftUp(IndexType i) noexcept {
+ auto pos = i;
+ while (pos > 0) {
+ auto p = Parent(pos);
+ if (!cmp_(val_[idx_[pos]], val_[idx_[p]])) { break; }
+ std::swap(idx_[pos], idx_[p]);
+ pos = p;
+ }
+ }
+ void SiftDown(IndexType i) noexcept {
+ auto pos = i;
+ for (;;) {
+ auto l = Left(pos);
+ auto r = Right(pos);
+ auto best = pos;
+ if (l < size_ && cmp_(val_[idx_[l]], val_[idx_[best]])) { best = l; }
+ if (r < size_ && cmp_(val_[idx_[r]], val_[idx_[best]])) { best = r; }
+ if (best == pos) { break; }
+ std::swap(idx_[pos], idx_[best]);
+ pos = best;
+ }
+ }
+};
+
+template <typename T, uint32_t N, typename Compare = std::less<T>>
+class TopNFixedHeap : public TopNHeap<T, Compare> {
+public:
+ TopNFixedHeap() : TopNHeap<T, Compare>{N} {}
+};
+
+} // namespace UC
+
+#endif
diff --git a/ucm/store/infra/thread/index_pool.h b/ucm/shared/infra/thread/index_pool.h
similarity index 97%
rename from ucm/store/infra/thread/index_pool.h
rename to ucm/shared/infra/thread/index_pool.h
index 225ee8842..4217b7a0e 100644
--- a/ucm/store/infra/thread/index_pool.h
+++ b/ucm/shared/infra/thread/index_pool.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
* */
-#ifndef UNIFIEDCACHE_INDEX_POOL_H
-#define UNIFIEDCACHE_INDEX_POOL_H
+#ifndef UNIFIEDCACHE_INFRA_INDEX_POOL_H
+#define UNIFIEDCACHE_INFRA_INDEX_POOL_H
#include
#include
diff --git a/ucm/store/infra/thread/latch.h b/ucm/shared/infra/thread/latch.h
similarity index 68%
rename from ucm/store/infra/thread/latch.h
rename to ucm/shared/infra/thread/latch.h
index 5d1ccf2bc..fb1dcf583 100644
--- a/ucm/store/infra/thread/latch.h
+++ b/ucm/shared/infra/thread/latch.h
@@ -21,11 +21,12 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
* */
-#ifndef UNIFIEDCACHE_LATCH_H
-#define UNIFIEDCACHE_LATCH_H
+#ifndef UNIFIEDCACHE_INFRA_LATCH_H
+#define UNIFIEDCACHE_INFRA_LATCH_H
#include
#include
+#include <functional>
#include
namespace UC {
@@ -34,8 +35,22 @@ class Latch {
public:
explicit Latch(const size_t expected = 0) : counter_{expected} {}
void Up() { ++this->counter_; }
- size_t Done() { return --this->counter_; }
- void Notify() { this->cv_.notify_all(); }
+ void Done(std::function<void()> finish) noexcept
+ {
+ auto counter = this->counter_.load(std::memory_order_acquire);
+ while (counter > 0) {
+ auto desired = counter - 1;
+ if (this->counter_.compare_exchange_weak(counter, desired, std::memory_order_acq_rel)) {
+ if (desired == 0) {
+ if (finish) { finish(); }
+ std::lock_guard lg(this->mutex_);
+ this->cv_.notify_all();
+ }
+ return;
+ }
+ counter = this->counter_.load(std::memory_order_acquire);
+ }
+ }
void Wait()
{
std::unique_lock lk(this->mutex_);
@@ -51,4 +66,4 @@ class Latch {
} // namespace UC
-#endif // UNIFIEDCACHE_LATCH_H
+#endif // UNIFIEDCACHE_INFRA_LATCH_H
diff --git a/ucm/store/infra/thread/thread_pool.h b/ucm/shared/infra/thread/thread_pool.h
similarity index 70%
rename from ucm/store/infra/thread/thread_pool.h
rename to ucm/shared/infra/thread/thread_pool.h
index b212f9e0f..baa514ed7 100644
--- a/ucm/store/infra/thread/thread_pool.h
+++ b/ucm/shared/infra/thread/thread_pool.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
* */
-#ifndef UNIFIEDCACHE_THREAD_POOL_H
-#define UNIFIEDCACHE_THREAD_POOL_H
+#ifndef UNIFIEDCACHE_INFRA_THREAD_POOL_H
+#define UNIFIEDCACHE_INFRA_THREAD_POOL_H
#include
#include
@@ -33,11 +33,11 @@
namespace UC {
-template <typename Task>
+template <typename Task, typename WorkerArgs>
 class ThreadPool {
- using WorkerInitFn = std::function<bool()>;
- using WorkerFn = std::function<void(Task&)>;
- using WorkerExitFn = std::function<void()>;
+ using WorkerInitFn = std::function<bool(WorkerArgs&)>;
+ using WorkerFn = std::function<void(Task&, WorkerArgs&)>;
+ using WorkerExitFn = std::function<void(WorkerArgs&)>;
public:
ThreadPool() = default;
@@ -54,14 +54,31 @@ class ThreadPool {
if (w.joinable()) { w.join(); }
}
}
- bool Setup(
- WorkerFn&& fn, WorkerInitFn&& initFn = [] { return true; }, WorkerExitFn&& exitFn = [] {},
- const size_t nWorker = 1) noexcept
+ ThreadPool& SetWorkerFn(WorkerFn&& fn)
{
- this->initFn_ = initFn;
- this->fn_ = fn;
- this->exitFn_ = exitFn;
- std::list<std::promise<bool>> start(nWorker);
+ this->fn_ = std::move(fn);
+ return *this;
+ }
+ ThreadPool& SetWorkerInitFn(WorkerInitFn&& fn)
+ {
+ this->initFn_ = std::move(fn);
+ return *this;
+ }
+ ThreadPool& SetWorkerExitFn(WorkerExitFn&& fn)
+ {
+ this->exitFn_ = std::move(fn);
+ return *this;
+ }
+ ThreadPool& SetNWorker(const size_t nWorker)
+ {
+ this->nWorker_ = nWorker;
+ return *this;
+ }
+ bool Run()
+ {
+ if (this->nWorker_ == 0) { return false; }
+ if (!this->fn_) { return false; }
+ std::list<std::promise<bool>> start(this->nWorker_);
 std::list<std::future<bool>> fut;
for (auto& s : start) {
fut.push_back(s.get_future());
@@ -82,31 +99,33 @@ class ThreadPool {
void Push(Task&& task) noexcept
{
std::unique_lock lk(this->mtx_);
- this->taskQ_.push_back(task);
+ this->taskQ_.push_back(std::move(task));
this->cv_.notify_one();
}
private:
void Worker(std::promise<bool>& started) noexcept
{
- auto success = this->initFn_();
+ WorkerArgs args = nullptr;
+ auto success = true;
+ if (this->initFn_) { success = this->initFn_(args); }
started.set_value(success);
while (success) {
- Task task;
std::unique_lock lk(this->mtx_);
this->cv_.wait(lk, [this] { return this->stop_ || !this->taskQ_.empty(); });
if (this->stop_) { break; }
if (this->taskQ_.empty()) { continue; }
- task = std::move(this->taskQ_.front());
+ auto task = std::make_shared<Task>(std::move(this->taskQ_.front()));
this->taskQ_.pop_front();
lk.unlock();
- this->fn_(task);
+ this->fn_(*task, args);
}
- this->exitFn_();
+ if (this->exitFn_) { this->exitFn_(args); }
}
private:
bool stop_{false};
+ size_t nWorker_{1};
std::list workers_;
WorkerInitFn initFn_;
WorkerFn fn_;
diff --git a/ucm/store/infra/time/stopwatch.h b/ucm/shared/infra/time/stopwatch.h
similarity index 95%
rename from ucm/store/infra/time/stopwatch.h
rename to ucm/shared/infra/time/stopwatch.h
index 2386f394b..c2a5bb331 100644
--- a/ucm/store/infra/time/stopwatch.h
+++ b/ucm/shared/infra/time/stopwatch.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
* */
-#ifndef UNIFIEDCACHE_STOPWATCH_H
-#define UNIFIEDCACHE_STOPWATCH_H
+#ifndef UNIFIEDCACHE_INFRA_STOPWATCH_H
+#define UNIFIEDCACHE_INFRA_STOPWATCH_H
#include
diff --git a/ucm/shared/metrics/CMakeLists.txt b/ucm/shared/metrics/CMakeLists.txt
new file mode 100644
index 000000000..3933b9f0b
--- /dev/null
+++ b/ucm/shared/metrics/CMakeLists.txt
@@ -0,0 +1,15 @@
+file(GLOB_RECURSE CORE_SRCS CONFIGURE_DEPENDS
+ "${CMAKE_CURRENT_SOURCE_DIR}/cc/stats/*.cc"
+ "${CMAKE_CURRENT_SOURCE_DIR}/cc/*.cc")
+add_library(monitor_static STATIC ${CORE_SRCS})
+set_property(TARGET monitor_static PROPERTY POSITION_INDEPENDENT_CODE ON)
+target_include_directories(monitor_static PUBLIC
+ $
+ $)
+set_target_properties(monitor_static PROPERTIES OUTPUT_NAME monitor)
+
+file(GLOB_RECURSE BINDINGS_SRCS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/cpy/*.cc")
+pybind11_add_module(ucmmonitor ${BINDINGS_SRCS})
+target_link_libraries(ucmmonitor PRIVATE -Wl,--whole-archive monitor_static -Wl,--no-whole-archive)
+target_include_directories(ucmmonitor PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cc)
+set_target_properties(ucmmonitor PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
\ No newline at end of file
diff --git a/ucm/shared/metrics/cc/stats/conn_stats.cc b/ucm/shared/metrics/cc/stats/conn_stats.cc
new file mode 100644
index 000000000..edf18ac2e
--- /dev/null
+++ b/ucm/shared/metrics/cc/stats/conn_stats.cc
@@ -0,0 +1,83 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#include "conn_stats.h"
+
+namespace UC::Metrics {
+
+ConnStats::ConnStats() = default;
+
+std::string ConnStats::Name() const { return "ConnStats"; }
+
+void ConnStats::Reset()
+{
+ for (auto& v : data_) v.clear();
+}
+
+void ConnStats::Update(const std::unordered_map<std::string, double>& params)
+{
+ for (const auto& [k, v] : params) {
+ Key id = KeyFromString(k);
+ if (id == Key::COUNT) continue;
+ EmplaceBack(id, v);
+ }
+}
+
+std::unordered_map<std::string, std::vector<double>> ConnStats::Data()
+{
+ std::unordered_map<std::string, std::vector<double>> result;
+ result["save_requests_num"] = data_[static_cast<std::size_t>(Key::save_requests_num)];
+ result["save_blocks_num"] = data_[static_cast<std::size_t>(Key::save_blocks_num)];
+ result["save_duration"] = data_[static_cast<std::size_t>(Key::save_duration)];
+ result["save_speed"] = data_[static_cast<std::size_t>(Key::save_speed)];
+ result["load_requests_num"] = data_[static_cast<std::size_t>(Key::load_requests_num)];
+ result["load_blocks_num"] = data_[static_cast<std::size_t>(Key::load_blocks_num)];
+ result["load_duration"] = data_[static_cast<std::size_t>(Key::load_duration)];
+ result["load_speed"] = data_[static_cast<std::size_t>(Key::load_speed)];
+ result["interval_lookup_hit_rates"] =
+ data_[static_cast<std::size_t>(Key::interval_lookup_hit_rates)];
+ return result;
+}
+
+Key ConnStats::KeyFromString(const std::string& k)
+{
+ if (k == "save_requests_num") return Key::save_requests_num;
+ if (k == "save_blocks_num") return Key::save_blocks_num;
+ if (k == "save_duration") return Key::save_duration;
+ if (k == "save_speed") return Key::save_speed;
+ if (k == "load_requests_num") return Key::load_requests_num;
+ if (k == "load_blocks_num") return Key::load_blocks_num;
+ if (k == "load_duration") return Key::load_duration;
+ if (k == "load_speed") return Key::load_speed;
+ if (k == "interval_lookup_hit_rates") return Key::interval_lookup_hit_rates;
+ return Key::COUNT;
+}
+
+void ConnStats::EmplaceBack(Key id, double value)
+{
+ data_[static_cast<std::size_t>(id)].push_back(value);
+}
+
+static Registrar registrar;
+
+} // namespace UC::Metrics
\ No newline at end of file
diff --git a/ucm/shared/metrics/cc/stats/conn_stats.h b/ucm/shared/metrics/cc/stats/conn_stats.h
new file mode 100644
index 000000000..e8cc94559
--- /dev/null
+++ b/ucm/shared/metrics/cc/stats/conn_stats.h
@@ -0,0 +1,78 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#ifndef UNIFIEDCACHE_CONNSTATS_H
+#define UNIFIEDCACHE_CONNSTATS_H
+
+#include <array>
+#include <cstdint>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "istats.h"
+#include "stats_registry.h"
+
+namespace UC::Metrics {
+
+enum class Key : uint8_t {
+ interval_lookup_hit_rates = 0,
+ save_requests_num,
+ save_blocks_num,
+ save_duration,
+ save_speed,
+ load_requests_num,
+ load_blocks_num,
+ load_duration,
+ load_speed,
+ COUNT
+};
+
+class ConnStats : public IStats {
+public:
+ ConnStats();
+ ~ConnStats() = default;
+
+ std::string Name() const override;
+ void Reset() override;
+ void Update(const std::unordered_map<std::string, double>& params) override;
+ std::unordered_map<std::string, std::vector<double>> Data() override;
+
+private:
+ static constexpr std::size_t N = static_cast<std::size_t>(Key::COUNT);
+ std::array<std::vector<double>, N> data_;
+
+ static Key KeyFromString(const std::string& k);
+ void EmplaceBack(Key id, double value);
+};
+
+struct Registrar {
+ Registrar()
+ {
+ StatsRegistry::RegisterStats(
+ "ConnStats", []() -> std::unique_ptr { return std::make_unique(); });
+ }
+};
+
+} // namespace UC::Metrics
+
+#endif // UNIFIEDCACHE_CONNSTATS_H
\ No newline at end of file
diff --git a/ucm/shared/metrics/cc/stats/istats.h b/ucm/shared/metrics/cc/stats/istats.h
new file mode 100644
index 000000000..6e8de7b32
--- /dev/null
+++ b/ucm/shared/metrics/cc/stats/istats.h
@@ -0,0 +1,45 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#ifndef UNIFIEDCACHE_ISTATS_H
+#define UNIFIEDCACHE_ISTATS_H
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace UC::Metrics {
+
+class IStats {
+public:
+ virtual ~IStats() = default;
+ virtual std::string Name() const = 0;
+    virtual void Update(const std::unordered_map<std::string, double>& params) = 0;
+    virtual void Reset() = 0;
+    virtual std::unordered_map<std::string, std::vector<double>> Data() = 0;
+};
+
+} // namespace UC::Metrics
+
+#endif
\ No newline at end of file
diff --git a/ucm/shared/metrics/cc/stats_monitor.cc b/ucm/shared/metrics/cc/stats_monitor.cc
new file mode 100644
index 000000000..2d3d80266
--- /dev/null
+++ b/ucm/shared/metrics/cc/stats_monitor.cc
@@ -0,0 +1,82 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#include "stats_monitor.h"
+#include <memory>
+#include <mutex>
+#include "stats/istats.h"
+#include "stats_registry.h"
+
+namespace UC::Metrics {
+
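+// Eagerly instantiates every stats object known to the registry, so all registered stats
+// are available immediately after construction.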
+StatsMonitor::StatsMonitor()
+{
+ auto& registry = StatsRegistry::GetInstance();
+ for (const auto& name : registry.GetRegisteredStatsNames()) {
+ stats_map_[name] = registry.CreateStats(name);
+ }
+}
+
+void StatsMonitor::CreateStats(const std::string& name)
+{
+    std::lock_guard<std::mutex> lock(mutex_);
+ auto& registry = StatsRegistry::GetInstance();
+ stats_map_[name] = registry.CreateStats(name);
+}
+
+std::unordered_map<std::string, std::vector<double>> StatsMonitor::GetStats(const std::string& name)
+{
+    std::lock_guard<std::mutex> lock(mutex_);
+ return stats_map_[name]->Data();
+}
+
+void StatsMonitor::ResetStats(const std::string& name)
+{
+    std::lock_guard<std::mutex> lock(mutex_);
+ stats_map_[name]->Reset();
+}
+
+std::unordered_map<std::string, std::vector<double>>
+StatsMonitor::GetStatsAndClear(const std::string& name)
+{
+    std::lock_guard<std::mutex> lock(mutex_);
+ auto result = stats_map_[name]->Data();
+ stats_map_[name]->Reset();
+ return result;
+}
+
+void StatsMonitor::UpdateStats(const std::string& name,
+                               const std::unordered_map<std::string, double>& params)
+{
+    std::lock_guard<std::mutex> lock(mutex_);
+ auto it = stats_map_.find(name);
+ if (it != stats_map_.end()) { it->second->Update(params); }
+}
+
+void StatsMonitor::ResetAllStats()
+{
+    std::lock_guard<std::mutex> lock(mutex_);
+ for (auto& [n, ptr] : stats_map_) { ptr->Reset(); }
+}
+
+} // namespace UC::Metrics
\ No newline at end of file
diff --git a/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_manager.h b/ucm/shared/metrics/cc/stats_monitor.h
similarity index 55%
rename from ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_manager.h
rename to ucm/shared/metrics/cc/stats_monitor.h
index d9d6a1976..1545d4b5c 100644
--- a/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_manager.h
+++ b/ucm/shared/metrics/cc/stats_monitor.h
@@ -21,39 +21,50 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
* */
-#ifndef UNIFIEDCACHE_TSF_TASK_MANAGER_H
-#define UNIFIEDCACHE_TSF_TASK_MANAGER_H
+#ifndef UNIFIEDCACHE_MONITOR_H
+#define UNIFIEDCACHE_MONITOR_H
#include
+#include
+#include
#include
#include
-#include "tsf_task_queue.h"
+#include "stats/istats.h"
-namespace UC {
+namespace UC::Metrics {
-class TsfTaskManager {
+class StatsMonitor {
public:
- Status Setup(const int32_t deviceId, const size_t streamNumber, const size_t bufferSize,
- const size_t bufferNumber, const size_t timeoutMs, const SpaceLayout* layout);
- Status Submit(std::list& tasks, const size_t size, const size_t number,
- const std::string& brief, size_t& taskId);
- Status Wait(const size_t taskId);
- Status Check(const size_t taskId, bool& finish);
+ static StatsMonitor& GetInstance()
+ {
+ static StatsMonitor inst;
+ return inst;
+ }
-private:
- void Dispatch(std::list& tasks, std::vector>& targets,
- const size_t taskId, std::shared_ptr waiter) const;
+ ~StatsMonitor() = default;
+
+ void CreateStats(const std::string& name);
+
+    std::unordered_map<std::string, std::vector<double>> GetStats(const std::string& name);
+
+ void ResetStats(const std::string& name);
+
+    std::unordered_map<std::string, std::vector<double>> GetStatsAndClear(const std::string& name);
+
+ void UpdateStats(const std::string& name,
+                     const std::unordered_map<std::string, double>& params);
+
+ void ResetAllStats();
private:
- std::mutex _mutex;
- TsfTaskSet _failureSet;
- std::unordered_map> _waiters;
- std::vector> _queues;
- size_t _qIdx{0};
- size_t _taskIdSeed{0};
- size_t _timeoutMs{0};
+ std::mutex mutex_;
+    std::unordered_map<std::string, std::unique_ptr<IStats>> stats_map_;
+
+ StatsMonitor();
+ StatsMonitor(const StatsMonitor&) = delete;
+ StatsMonitor& operator=(const StatsMonitor&) = delete;
};
-} // namespace UC
+} // namespace UC::Metrics
-#endif
+#endif // UNIFIEDCACHE_MONITOR_H
\ No newline at end of file
diff --git a/ucm/shared/metrics/cc/stats_registry.cc b/ucm/shared/metrics/cc/stats_registry.cc
new file mode 100644
index 000000000..c2551d9ad
--- /dev/null
+++ b/ucm/shared/metrics/cc/stats_registry.cc
@@ -0,0 +1,59 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#include "stats_registry.h"
+
+namespace UC::Metrics {
+
+StatsRegistry& StatsRegistry::GetInstance()
+{
+ static StatsRegistry inst;
+ return inst;
+}
+
+void StatsRegistry::RegisterStats(std::string name, Creator creator)
+{
+ auto& reg = GetInstance();
+    std::lock_guard<std::mutex> lk(reg.mutex_);
+ reg.registry_[name] = creator;
+}
+
+std::unique_ptr<IStats> StatsRegistry::CreateStats(const std::string& name)
+{
+    auto& reg = GetInstance();
+    std::lock_guard<std::mutex> lk(reg.mutex_);
+ if (auto it = reg.registry_.find(name); it != reg.registry_.end()) return it->second();
+ return nullptr;
+}
+
+std::vector<std::string> StatsRegistry::GetRegisteredStatsNames()
+{
+    auto& reg = GetInstance();
+    std::lock_guard<std::mutex> lk(reg.mutex_);
+    std::vector<std::string> names;
+ names.reserve(reg.registry_.size());
+ for (auto& [n, _] : reg.registry_) names.push_back(n);
+ return names;
+}
+
+} // namespace UC::Metrics
\ No newline at end of file
diff --git a/ucm/shared/metrics/cc/stats_registry.h b/ucm/shared/metrics/cc/stats_registry.h
new file mode 100644
index 000000000..c22b6617c
--- /dev/null
+++ b/ucm/shared/metrics/cc/stats_registry.h
@@ -0,0 +1,58 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#ifndef UNIFIEDCACHE_REGISTRY_H
+#define UNIFIEDCACHE_REGISTRY_H
+
+#include <memory>
+#include <mutex>
+#include <vector>
+#include "stats/istats.h"
+
+namespace UC::Metrics {
+
+using Creator = std::unique_ptr<IStats> (*)();
+
+class StatsRegistry {
+public:
+ static StatsRegistry& GetInstance();
+
+ static void RegisterStats(std::string name, Creator creator);
+
+    std::unique_ptr<IStats> CreateStats(const std::string& name);
+
+    std::vector<std::string> GetRegisteredStatsNames();
+
+private:
+ StatsRegistry() = default;
+ ~StatsRegistry() = default;
+ StatsRegistry(const StatsRegistry&) = delete;
+ StatsRegistry& operator=(const StatsRegistry&) = delete;
+
+ std::mutex mutex_;
+    std::unordered_map<std::string, Creator> registry_;
+};
+
+} // namespace UC::Metrics
+
+#endif // UNIFIEDCACHE_REGISTRY_H
\ No newline at end of file
diff --git a/ucm/shared/metrics/cpy/metrics.py.cc b/ucm/shared/metrics/cpy/metrics.py.cc
new file mode 100644
index 000000000..10bfc2f97
--- /dev/null
+++ b/ucm/shared/metrics/cpy/metrics.py.cc
@@ -0,0 +1,50 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include "stats_monitor.h"
+
+namespace py = pybind11;
+namespace UC::Metrics {
+
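+// pybind11 bindings: the C++ singleton is exposed by reference so Python never owns or deletes it.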
+void bind_monitor(py::module_& m)
+{
+    py::class_<StatsMonitor>(m, "StatsMonitor")
+ .def_static("get_instance", &StatsMonitor::GetInstance, py::return_value_policy::reference)
+ .def("update_stats", &StatsMonitor::UpdateStats)
+ .def("reset_all", &StatsMonitor::ResetAllStats)
+ .def("get_stats", &StatsMonitor::GetStats)
+ .def("get_stats_and_clear", &StatsMonitor::GetStatsAndClear);
+}
+
+} // namespace UC::Metrics
+
+PYBIND11_MODULE(ucmmonitor, module)
+{
+ module.attr("project") = UCM_PROJECT_NAME;
+ module.attr("version") = UCM_PROJECT_VERSION;
+ module.attr("commit_id") = UCM_COMMIT_ID;
+ module.attr("build_type") = UCM_BUILD_TYPE;
+ UC::Metrics::bind_monitor(module);
+}
\ No newline at end of file
diff --git a/ucm/shared/metrics/observability.py b/ucm/shared/metrics/observability.py
new file mode 100644
index 000000000..fb33400ce
--- /dev/null
+++ b/ucm/shared/metrics/observability.py
@@ -0,0 +1,305 @@
+#
+# MIT License
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+
+
+import os
+import threading
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import prometheus_client
+import yaml
+
+# Third Party
+from prometheus_client import REGISTRY
+from vllm.distributed.parallel_state import get_world_group
+
+from ucm.logger import init_logger
+from ucm.shared.metrics import ucmmonitor
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class UCMEngineMetadata:
+ """Metadata for UCM engine"""
+
+ model_name: str
+ worker_id: str
+
+
+class PrometheusLogger:
+ _gauge_cls = prometheus_client.Gauge
+ _counter_cls = prometheus_client.Counter
+ _histogram_cls = prometheus_client.Histogram
+
+ def __init__(self, metadata: UCMEngineMetadata, config: Dict[str, Any]):
+ # Ensure PROMETHEUS_MULTIPROC_DIR is set before any metric registration
+ prometheus_config = config.get("prometheus", {})
+ multiproc_dir = prometheus_config.get("multiproc_dir", "/vllm-workspace")
+ if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
+ os.environ["PROMETHEUS_MULTIPROC_DIR"] = multiproc_dir
+ if not os.path.exists(multiproc_dir):
+ os.makedirs(multiproc_dir, exist_ok=True)
+
+ self.metadata = metadata
+ self.config = config
+ self.labels = self._metadata_to_labels(metadata)
+ labelnames = list(self.labels.keys())
+
+ # Initialize metrics based on configuration
+ self._init_metrics_from_config(labelnames, prometheus_config)
+
+ def _init_metrics_from_config(
+ self, labelnames: List[str], prometheus_config: Dict[str, Any]
+ ):
+ """Initialize metrics based on configuration"""
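+        # Illustrative config shape (example only; real names come from the YAML file):
+        #   prometheus:
+        #     metric_prefix: "ucm:"
+        #     enabled_metrics: {counters: true, gauges: true, histograms: true}
+        #     counters:   [{name: ..., documentation: ...}]
+        #     gauges:     [{name: ..., documentation: ..., multiprocess_mode: live}]
+        #     histograms: [{name: ..., documentation: ..., buckets: [...]}]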
+ enabled = prometheus_config.get("enabled_metrics", {})
+
+ # Get metric name prefix from config (e.g., "ucm:")
+ # If not specified, use empty string
+ metric_prefix = prometheus_config.get("metric_prefix", "ucm:")
+
+ # Store metric mapping: metric_name -> (metric_type, attribute_name, stats_field_name)
+ # This mapping will be used in log_prometheus to dynamically log metrics
+ self.metric_mappings: Dict[str, Dict[str, str]] = {}
+
+ # Initialize counters
+ if enabled.get("counters", True):
+ counters = prometheus_config.get("counters", [])
+ for counter_cfg in counters:
+ name = counter_cfg.get("name")
+ doc = counter_cfg.get("documentation", "")
+ # Prometheus metric name with prefix
+ prometheus_name = f"{metric_prefix}{name}" if metric_prefix else name
+ # Internal attribute name for storing the metric object
+ attr_name = f"counter_{name}"
+
+ if not hasattr(self, attr_name):
+ setattr(
+ self,
+ attr_name,
+ self._counter_cls(
+ name=prometheus_name,
+ documentation=doc,
+ labelnames=labelnames,
+ ),
+ )
+ # Store mapping for dynamic logging
+ self.metric_mappings[name] = {
+ "type": "counter",
+ "attr": attr_name,
+ }
+
+ # Initialize gauges
+ if enabled.get("gauges", True):
+ gauges = prometheus_config.get("gauges", [])
+ for gauge_cfg in gauges:
+ name = gauge_cfg.get("name")
+ doc = gauge_cfg.get("documentation", "")
+ multiprocess_mode = gauge_cfg.get("multiprocess_mode", "live")
+ # Prometheus metric name with prefix
+ prometheus_name = f"{metric_prefix}{name}" if metric_prefix else name
+ # Internal attribute name
+ attr_name = f"gauge_{name}"
+
+ if not hasattr(self, attr_name):
+ setattr(
+ self,
+ attr_name,
+ self._gauge_cls(
+ name=prometheus_name,
+ documentation=doc,
+ labelnames=labelnames,
+ multiprocess_mode=multiprocess_mode,
+ ),
+ )
+ # Store mapping for dynamic logging
+ self.metric_mappings[name] = {
+ "type": "gauge",
+ "attr": attr_name,
+ }
+
+ # Initialize histograms
+ if enabled.get("histograms", True):
+ histograms = prometheus_config.get("histograms", [])
+ for hist_cfg in histograms:
+ name = hist_cfg.get("name")
+ doc = hist_cfg.get("documentation", "")
+ buckets = hist_cfg.get("buckets", [])
+ # Prometheus metric name with prefix
+ prometheus_name = f"{metric_prefix}{name}" if metric_prefix else name
+ # Internal attribute name
+ attr_name = f"histogram_{name}"
+
+ if not hasattr(self, attr_name):
+ setattr(
+ self,
+ attr_name,
+ self._histogram_cls(
+ name=prometheus_name,
+ documentation=doc,
+ labelnames=labelnames,
+ buckets=buckets,
+ ),
+ )
+ # Store mapping for dynamic logging
+ self.metric_mappings[name] = {
+ "type": "histogram",
+ "attr": attr_name,
+ }
+
+ def _log_gauge(self, gauge, data: Union[int, float]) -> None:
+ # Convenience function for logging to gauge.
+ gauge.labels(**self.labels).set(data)
+
+ def _log_counter(self, counter, data: Union[int, float]) -> None:
+ # Convenience function for logging to counter.
+ # Prevent ValueError from negative increment
+ if data < 0:
+ return
+ counter.labels(**self.labels).inc(data)
+
+ def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None:
+ # Convenience function for logging to histogram.
+ for value in data:
+ histogram.labels(**self.labels).observe(value)
+
+ def log_prometheus(self, stats: Any):
+ """Log metrics to Prometheus based on configuration file"""
+ # Dynamically log metrics based on what's configured in YAML
+ for stat_name, value in stats.items():
+ try:
+                metric_mapped = self.metric_mappings.get(stat_name)
+ if metric_mapped is None:
+ logger.warning(f"Stat {stat_name} not initialized.")
+ continue
+ metric_obj = getattr(self, metric_mapped["attr"], None)
+ metric_type = metric_mapped["type"]
+
+ # Log based on metric type
+ if metric_type == "counter":
+ self._log_counter(metric_obj, value)
+ elif metric_type == "gauge":
+ self._log_gauge(metric_obj, value)
+ elif metric_type == "histogram":
+ # Histograms expect a list
+                    if not isinstance(value, list):
+                        # Wrap scalars; keep 0.0 observations, drop only None.
+                        value = [value] if value is not None else []
+ self._log_histogram(metric_obj, value)
+ except Exception as e:
+ logger.warning(f"Failed to log metric {stat_name}: {e}")
+
+ @staticmethod
+ def _metadata_to_labels(metadata: UCMEngineMetadata):
+ return {
+ "model_name": metadata.model_name,
+ "worker_id": metadata.worker_id,
+ }
+
+ _instance = None
+
+ @staticmethod
+ def GetOrCreate(
+ metadata: UCMEngineMetadata,
+        config: Optional[Dict[str, Any]] = None,
+    ) -> "PrometheusLogger":
+        if PrometheusLogger._instance is None:
+            PrometheusLogger._instance = PrometheusLogger(metadata, config or {})
+ # assert PrometheusLogger._instance.metadata == metadata, \
+ # "PrometheusLogger instance already created with different metadata"
+ if PrometheusLogger._instance.metadata != metadata:
+            logger.error(
+                "PrometheusLogger instance already created with "
+                "different metadata; this should not happen outside of tests."
+            )
+ return PrometheusLogger._instance
+
+ @staticmethod
+ def GetInstance() -> "PrometheusLogger":
+ assert (
+ PrometheusLogger._instance is not None
+ ), "PrometheusLogger instance not created yet"
+ return PrometheusLogger._instance
+
+ @staticmethod
+ def GetInstanceOrNone() -> Optional["PrometheusLogger"]:
+ """
+ Returns the singleton instance of PrometheusLogger if it exists,
+ otherwise returns None.
+ """
+ return PrometheusLogger._instance
+
+
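+# Bridges the native StatsMonitor into Prometheus: a daemon thread periodically drains
+# the "ConnStats" samples and mirrors them onto the configured counters/gauges/histograms.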
+class UCMStatsLogger:
+ def __init__(self, model_name: str, rank: int, config_path: str = ""):
+ # Create metadata
+ self.metadata = UCMEngineMetadata(
+ model_name=str(model_name), worker_id=str(rank)
+ )
+ # Load configuration
+ config = self._load_config(config_path)
+ self.log_interval = config.get("log_interval", 10)
+
+ self.monitor = ucmmonitor.StatsMonitor.get_instance()
+ self.prometheus_logger = PrometheusLogger.GetOrCreate(self.metadata, config)
+ self.is_running = True
+
+ self.thread = threading.Thread(target=self.log_worker, daemon=True)
+ self.thread.start()
+
+ def _load_config(self, config_path: str) -> Dict[str, Any]:
+ """Load configuration from YAML file"""
+ try:
+ with open(config_path, "r") as f:
+ config = yaml.safe_load(f)
+ if config is None:
+ logger.warning(
+ f"Config file {config_path} is empty, using defaults"
+ )
+ return {}
+ return config
+ except FileNotFoundError:
+ logger.warning(f"Config file {config_path} not found, using defaults")
+ return {}
+ except yaml.YAMLError as e:
+ logger.error(f"Error parsing YAML config file {config_path}: {e}")
+ return {}
+
+ def log_worker(self):
+ while self.is_running:
+            # Drain the native StatsMonitor's "ConnStats" samples and forward them to Prometheus.
+ stats = self.monitor.get_stats_and_clear("ConnStats")
+ self.prometheus_logger.log_prometheus(stats)
+ time.sleep(self.log_interval)
+
+ def shutdown(self):
+ self.is_running = False
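+        # The worker sleeps between polls, so join() may block for up to log_interval seconds.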
+ self.thread.join()
diff --git a/ucm/shared/metrics/test/test.py b/ucm/shared/metrics/test/test.py
new file mode 100644
index 000000000..246e6f880
--- /dev/null
+++ b/ucm/shared/metrics/test/test.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+#
+# MIT License
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+
+
+import os
+import sys
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from ucm.shared.metrics import ucmmonitor
+
+mon = ucmmonitor.StatsMonitor.get_instance()
+mon.update_stats(
+ "ConnStats",
+ {
+ "save_duration": 1.2,
+ "save_speed": 300.5,
+ "load_duration": 0.8,
+ "load_speed": 450.0,
+ "interval_lookup_hit_rates": 0.95,
+ },
+)
+mon.update_stats(
+ "ConnStats",
+ {
+ "save_duration": 1.2,
+ "save_speed": 300.5,
+ "load_duration": 0.8,
+ "load_speed": 450.0,
+ "interval_lookup_hit_rates": 0.95,
+ },
+)
+
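+# Each key should now map to a list holding the two recorded samples.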
+data = mon.get_stats("ConnStats")
+print(data)
diff --git a/ucm/shared/test/CMakeLists.txt b/ucm/shared/test/CMakeLists.txt
new file mode 100644
index 000000000..07241d814
--- /dev/null
+++ b/ucm/shared/test/CMakeLists.txt
@@ -0,0 +1,11 @@
+if(BUILD_UNIT_TESTS)
+ include(GoogleTest)
+ file(GLOB_RECURSE UCMSHARED_TEST_SOURCE_FILES "./case/*.cc")
+ add_executable(ucmshared.test ${UCMSHARED_TEST_SOURCE_FILES})
+ target_include_directories(ucmshared.test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/case)
+ target_link_libraries(ucmshared.test PRIVATE
+ trans
+ gtest_main gtest mockcpp
+ )
+ gtest_discover_tests(ucmshared.test)
+endif()
diff --git a/ucm/store/test/case/infra/hashset_test.cc b/ucm/shared/test/case/infra/hashset_test.cc
similarity index 100%
rename from ucm/store/test/case/infra/hashset_test.cc
rename to ucm/shared/test/case/infra/hashset_test.cc
diff --git a/ucm/shared/test/case/trans/trans_test.cc b/ucm/shared/test/case/trans/trans_test.cc
new file mode 100644
index 000000000..d8bef64b3
--- /dev/null
+++ b/ucm/shared/test/case/trans/trans_test.cc
@@ -0,0 +1,141 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#include <gtest/gtest.h>
+#include "trans/device.h"
+
+class UCTransUnitTest : public ::testing::Test {};
+
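+// Round-trips data host -> device -> host through the copy-engine stream and verifies each block.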
+TEST_F(UCTransUnitTest, CopyDataWithCE)
+{
+ const auto ok = UC::Status::OK();
+ constexpr int32_t deviceId = 0;
+ constexpr size_t size = 36 * 1024;
+ constexpr size_t number = 64 * 61;
+ UC::Trans::Device device;
+ ASSERT_EQ(device.Setup(deviceId), ok);
+ auto buffer = device.MakeBuffer();
+ auto stream = device.MakeStream();
+ auto hPtr1 = buffer->MakeHostBuffer(size * number);
+ ASSERT_NE(hPtr1, nullptr);
+ ASSERT_EQ(buffer->MakeDeviceBuffers(size, number), ok);
+    std::vector<std::shared_ptr<void>> ptrHolder;
+ ptrHolder.reserve(number);
+ void* dPtrArr[number];
+ for (size_t i = 0; i < number; i++) {
+ *(size_t*)(((char*)hPtr1.get()) + size * i) = i;
+ auto ptr = buffer->GetDeviceBuffer(size);
+ dPtrArr[i] = ptr.get();
+ ptrHolder.emplace_back(ptr);
+ }
+ auto hPtr2 = buffer->MakeHostBuffer(size * number);
+ ASSERT_NE(hPtr2, nullptr);
+ ASSERT_EQ(stream->HostToDeviceAsync(hPtr1.get(), dPtrArr, size, number), ok);
+ ASSERT_EQ(stream->DeviceToHostAsync(dPtrArr, hPtr2.get(), size, number), ok);
+ ASSERT_EQ(stream->Synchronized(), ok);
+ for (size_t i = 0; i < number; i++) {
+ ASSERT_EQ(*(size_t*)(((char*)hPtr2.get()) + size * i), i);
+ }
+}
+
+TEST_F(UCTransUnitTest, CopyDataWithSM)
+{
+ const auto ok = UC::Status::OK();
+ constexpr int32_t deviceId = 0;
+ constexpr size_t size = 36 * 1024;
+ constexpr size_t number = 64 * 61;
+ UC::Trans::Device device;
+ ASSERT_EQ(device.Setup(deviceId), ok);
+ auto buffer = device.MakeBuffer();
+ auto stream = device.MakeSMStream();
+ if (!stream) { return; }
+ auto hPtr1 = buffer->MakeHostBuffer(size * number);
+ ASSERT_NE(hPtr1, nullptr);
+ ASSERT_EQ(buffer->MakeDeviceBuffers(size, number), ok);
+    std::vector<std::shared_ptr<void>> ptrHolder;
+ ptrHolder.reserve(number);
+ void* dPtrArr[number];
+ for (size_t i = 0; i < number; i++) {
+ *(size_t*)(((char*)hPtr1.get()) + size * i) = i;
+ auto ptr = buffer->GetDeviceBuffer(size);
+ dPtrArr[i] = ptr.get();
+ ptrHolder.emplace_back(ptr);
+ }
+ auto dPtrArrOnDev = buffer->MakeDeviceBuffer(sizeof(dPtrArr));
+ ASSERT_EQ(stream->HostToDevice((void*)dPtrArr, dPtrArrOnDev.get(), sizeof(dPtrArr)), ok);
+ auto hPtr2 = buffer->MakeHostBuffer(size * number);
+ ASSERT_NE(hPtr2, nullptr);
+ ASSERT_EQ(stream->HostToDeviceAsync(hPtr1.get(), (void**)dPtrArrOnDev.get(), size, number), ok);
+ ASSERT_EQ(stream->DeviceToHostAsync((void**)dPtrArrOnDev.get(), hPtr2.get(), size, number), ok);
+ ASSERT_EQ(stream->Synchronized(), ok);
+ for (size_t i = 0; i < number; i++) {
+ ASSERT_EQ(*(size_t*)(((char*)hPtr2.get()) + size * i), i);
+ }
+}
+
+TEST_F(UCTransUnitTest, CopyDataBatchWithSM)
+{
+ const auto ok = UC::Status::OK();
+ constexpr int32_t deviceId = 0;
+ constexpr size_t size = 36 * 1024;
+ constexpr size_t number = 64 * 61;
+ UC::Trans::Device device;
+ ASSERT_EQ(device.Setup(deviceId), ok);
+ auto stream = device.MakeSMStream();
+ if (!stream) { return; }
+ auto bDev = device.MakeBuffer();
+ auto bHost1 = device.MakeBuffer();
+ auto bHost2 = device.MakeBuffer();
+ ASSERT_EQ(bDev->MakeDeviceBuffers(size, number), ok);
+ ASSERT_EQ(bHost1->MakeHostBuffers(size, number), ok);
+ ASSERT_EQ(bHost2->MakeHostBuffers(size, number), ok);
+    std::vector<std::shared_ptr<void>> devPtrHolder, host1PtrHolder, host2PtrHolder;
+ void *dPtrArr[number], *h1PtrArr[number], *h2PtrArr[number];
+ for (size_t i = 0; i < number; i++) {
+ auto d = bDev->GetDeviceBuffer(size);
+ auto h1 = bHost1->GetHostBuffer(size);
+ auto h2 = bHost2->GetHostBuffer(size);
+ dPtrArr[i] = d.get();
+ h1PtrArr[i] = h1.get();
+ *(size_t*)h1PtrArr[i] = i;
+ h2PtrArr[i] = h2.get();
+ devPtrHolder.emplace_back(d);
+ host1PtrHolder.emplace_back(h1);
+ host2PtrHolder.emplace_back(h2);
+ }
+ constexpr const auto arrSize = sizeof(void*) * number;
+ auto dPtrArrOnDev = bDev->MakeDeviceBuffer(arrSize);
+ auto h1PtrArrOnDev = bHost1->MakeDeviceBuffer(arrSize);
+ auto h2PtrArrOnDev = bHost2->MakeDeviceBuffer(arrSize);
+ ASSERT_EQ(stream->HostToDeviceAsync((void*)dPtrArr, dPtrArrOnDev.get(), arrSize), ok);
+ ASSERT_EQ(stream->HostToDeviceAsync((void*)h1PtrArr, h1PtrArrOnDev.get(), arrSize), ok);
+ ASSERT_EQ(stream->HostToDeviceAsync((void*)h2PtrArr, h2PtrArrOnDev.get(), arrSize), ok);
+ auto src = (void**)h1PtrArrOnDev.get();
+ auto dst = (void**)dPtrArrOnDev.get();
+ ASSERT_EQ(stream->HostToDeviceAsync(src, dst, size, number), ok);
+ src = (void**)dPtrArrOnDev.get();
+ dst = (void**)h2PtrArrOnDev.get();
+ ASSERT_EQ(stream->DeviceToHostAsync(src, dst, size, number), ok);
+ ASSERT_EQ(stream->Synchronized().Underlying(), ok.Underlying());
+ for (size_t i = 0; i < number; i++) { ASSERT_EQ(*(size_t*)h2PtrArr[i], i); }
+}
diff --git a/ucm/shared/test/example/trans/trans_on_cuda_example.py b/ucm/shared/test/example/trans/trans_on_cuda_example.py
new file mode 100644
index 000000000..01ffd864d
--- /dev/null
+++ b/ucm/shared/test/example/trans/trans_on_cuda_example.py
@@ -0,0 +1,265 @@
+# -*- coding: utf-8 -*-
+#
+# MIT License
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+import time
+from functools import wraps
+
+import cupy
+import numpy as np
+
+from ucm.shared.trans import ucmtrans
+
+
+def test_wrap(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ print(f"========>> Running in {func.__name__}:")
+ result = func(*args, **kwargs)
+ print()
+ return result
+
+ return wrapper
+
+
+def make_host_memory(size, number, dtype, fill=False):
+ element_size = np.dtype(dtype).itemsize
+ num_elements = size // element_size
+ host = cupy.cuda.alloc_pinned_memory(size * number)
+ host_np = np.frombuffer(host, dtype=dtype, count=num_elements)
+ if fill:
+ fixed_len = min(1024, number)
+ host_np[:fixed_len] = np.arange(fixed_len, dtype=dtype)
+ print("make:", host_np.shape, host_np.itemsize, host_np)
+ return host
+
+
+def make_batch_host_memory(size, number, dtype, fill=False):
+ element_size = np.dtype(dtype).itemsize
+ num_elements = size // element_size
+ host = []
+ for i in range(number):
+ pinned_mem = cupy.cuda.alloc_pinned_memory(size)
+ np_array = np.frombuffer(pinned_mem, dtype=dtype, count=num_elements)
+ if fill:
+ value = np.uint64(1023 + i)
+ np_array[0] = value
+ np_array[-1] = value
+ host.append(pinned_mem)
+ if i == 0:
+ print("make:", np_array.shape, np_array.itemsize, np_array)
+ return host
+
+
+def compare(host1, host2, size, dtype, show_detail=True):
+ element_size = np.dtype(dtype).itemsize
+ num_elements = size // element_size
+ host1_np = np.frombuffer(host1, dtype=dtype, count=num_elements)
+ host2_np = np.frombuffer(host2, dtype=dtype, count=num_elements)
+ if show_detail:
+ print("compare[1]:", host1_np.shape, host1_np.itemsize, host1_np)
+ print("compare[2]:", host2_np.shape, host2_np.itemsize, host2_np)
+ return np.array_equal(host1_np, host2_np)
+
+
+@test_wrap
+def trans_with_ce(d, size, number, dtype):
+ s = d.MakeStream()
+ host1 = make_host_memory(size, number, dtype, True)
+ device = [cupy.empty(size, dtype=np.uint8) for _ in range(number)]
+ device_ptr = np.array([d.data.ptr for d in device], dtype=np.uint64)
+ host2 = make_host_memory(size, number, dtype)
+ tp = time.perf_counter()
+ s.HostToDeviceScatter(host1.ptr, device_ptr, size, number)
+ s.DeviceToHostGather(device_ptr, host2.ptr, size, number)
+ cost = time.perf_counter() - tp
+ print(f"cost: {cost}s")
+ print(f"bandwidth: {size * number / cost / 1e9}GB/s")
+ assert compare(host1, host2, size, dtype)
+
+
+@test_wrap
+def trans_with_sm(d, size, number, dtype):
+ s = d.MakeSMStream()
+ host1 = make_host_memory(size, number, dtype, True)
+ device = [cupy.empty(size, dtype=np.uint8) for _ in range(number)]
+ device_ptr = np.array([d.data.ptr for d in device], dtype=np.uint64)
+ device_ptr_cupy = cupy.empty(number, dtype=np.uint64)
+ device_ptr_cupy.set(device_ptr)
+ host2 = make_host_memory(size, number, dtype)
+ tp = time.perf_counter()
+ s.HostToDeviceScatter(host1.ptr, device_ptr_cupy.data.ptr, size, number)
+ s.DeviceToHostGather(device_ptr_cupy.data.ptr, host2.ptr, size, number)
+ cost = time.perf_counter() - tp
+ print(f"cost: {cost}s")
+ print(f"bandwidth: {size * number / cost / 1e9}GB/s")
+ assert compare(host1, host2, size, dtype)
+
+
+@test_wrap
+def trans_with_ce_async(d, size, number, dtype):
+ s = d.MakeStream()
+ host1 = make_host_memory(size, number, dtype, True)
+ device = [cupy.empty(size, dtype=np.uint8) for _ in range(number)]
+ device_ptr = np.array([d.data.ptr for d in device], dtype=np.uint64)
+ host2 = make_host_memory(size, number, dtype)
+ tp = time.perf_counter()
+ s.HostToDeviceScatterAsync(host1.ptr, device_ptr, size, number)
+ s.DeviceToHostGatherAsync(device_ptr, host2.ptr, size, number)
+ s.Synchronized()
+ cost = time.perf_counter() - tp
+ print(f"cost: {cost}s")
+ print(f"bandwidth: {size * number / cost / 1e9}GB/s")
+ assert compare(host1, host2, size, dtype)
+
+
+@test_wrap
+def trans_with_sm_async(d, size, number, dtype):
+ s = d.MakeSMStream()
+ host1 = make_host_memory(size, number, dtype, True)
+ device = [cupy.empty(size, dtype=np.uint8) for _ in range(number)]
+ device_ptr = np.array([d.data.ptr for d in device], dtype=np.uint64)
+ device_ptr_cupy = cupy.empty(number, dtype=np.uint64)
+ device_ptr_cupy.set(device_ptr)
+ host2 = make_host_memory(size, number, dtype)
+ tp = time.perf_counter()
+ s.HostToDeviceScatterAsync(host1.ptr, device_ptr_cupy.data.ptr, size, number)
+ s.DeviceToHostGatherAsync(device_ptr_cupy.data.ptr, host2.ptr, size, number)
+ s.Synchronized()
+ cost = time.perf_counter() - tp
+ print(f"cost: {cost}s")
+ print(f"bandwidth: {size * number / cost / 1e9}GB/s")
+ assert compare(host1, host2, size, dtype)
+
+
+@test_wrap
+def trans_batch_with_ce(d, size, number, dtype):
+ s = d.MakeStream()
+ host1 = make_batch_host_memory(size, number, dtype, True)
+ host1_ptr = np.array([h.ptr for h in host1], dtype=np.uint64)
+ device = [cupy.empty(size, dtype=np.uint8) for _ in range(number)]
+ device_ptr = np.array([d.data.ptr for d in device], dtype=np.uint64)
+ host2 = make_batch_host_memory(size, number, dtype)
+ host2_ptr = np.array([h.ptr for h in host2], dtype=np.uint64)
+ tp = time.perf_counter()
+ s.HostToDeviceBatch(host1_ptr, device_ptr, size, number)
+ s.DeviceToHostBatch(device_ptr, host2_ptr, size, number)
+ cost = time.perf_counter() - tp
+ print(f"cost: {cost}s")
+ print(f"bandwidth: {size * number / cost / 1e9}GB/s")
+ for h1, h2 in zip(host1, host2):
+ assert compare(h1, h2, size, dtype, False)
+
+
+@test_wrap
+def trans_batch_with_sm(dev, size, number, dtype):
+ s = dev.MakeSMStream()
+ h1 = make_batch_host_memory(size, number, dtype, True)
+ h1_ptr = np.array([h.ptr for h in h1], dtype=np.uint64)
+ h1_ptr_cupy = cupy.empty(number, dtype=np.uint64)
+ h1_ptr_cupy.set(h1_ptr)
+ d = [cupy.empty(size, dtype=np.uint8) for _ in range(number)]
+ d_ptr = np.array([d.data.ptr for d in d], dtype=np.uint64)
+ d_ptr_cupy = cupy.empty(number, dtype=np.uint64)
+ d_ptr_cupy.set(d_ptr)
+ h2 = make_batch_host_memory(size, number, dtype)
+ h2_ptr = np.array([h.ptr for h in h2], dtype=np.uint64)
+ h2_ptr_cupy = cupy.empty(number, dtype=np.uint64)
+ h2_ptr_cupy.set(h2_ptr)
+ tp = time.perf_counter()
+ s.HostToDeviceBatch(h1_ptr_cupy.data.ptr, d_ptr_cupy.data.ptr, size, number)
+ s.DeviceToHostBatch(d_ptr_cupy.data.ptr, h2_ptr_cupy.data.ptr, size, number)
+ cost = time.perf_counter() - tp
+ print(f"cost: {cost}s")
+ print(f"bandwidth: {size * number / cost / 1e9}GB/s")
+ for x, y in zip(h1, h2):
+ assert compare(x, y, size, dtype, False)
+
+
+@test_wrap
+def trans_batch_with_ce_async(d, size, number, dtype):
+ s = d.MakeStream()
+ host1 = make_batch_host_memory(size, number, dtype, True)
+ host1_ptr = np.array([h.ptr for h in host1], dtype=np.uint64)
+ device = [cupy.empty(size, dtype=np.uint8) for _ in range(number)]
+ device_ptr = np.array([d.data.ptr for d in device], dtype=np.uint64)
+ host2 = make_batch_host_memory(size, number, dtype)
+ host2_ptr = np.array([h.ptr for h in host2], dtype=np.uint64)
+ tp = time.perf_counter()
+ s.HostToDeviceBatchAsync(host1_ptr, device_ptr, size, number)
+ s.DeviceToHostBatchAsync(device_ptr, host2_ptr, size, number)
+ s.Synchronized()
+ cost = time.perf_counter() - tp
+ print(f"cost: {cost}s")
+ print(f"bandwidth: {size * number / cost / 1e9}GB/s")
+ for h1, h2 in zip(host1, host2):
+ assert compare(h1, h2, size, dtype, False)
+
+
+@test_wrap
+def trans_batch_with_sm_async(dev, size, number, dtype):
+ s = dev.MakeSMStream()
+ h1 = make_batch_host_memory(size, number, dtype, True)
+ h1_ptr = np.array([h.ptr for h in h1], dtype=np.uint64)
+ h1_ptr_cupy = cupy.empty(number, dtype=np.uint64)
+ h1_ptr_cupy.set(h1_ptr)
+ d = [cupy.empty(size, dtype=np.uint8) for _ in range(number)]
+ d_ptr = np.array([d.data.ptr for d in d], dtype=np.uint64)
+ d_ptr_cupy = cupy.empty(number, dtype=np.uint64)
+ d_ptr_cupy.set(d_ptr)
+ h2 = make_batch_host_memory(size, number, dtype)
+ h2_ptr = np.array([h.ptr for h in h2], dtype=np.uint64)
+ h2_ptr_cupy = cupy.empty(number, dtype=np.uint64)
+ h2_ptr_cupy.set(h2_ptr)
+ tp = time.perf_counter()
+ s.HostToDeviceBatchAsync(h1_ptr_cupy.data.ptr, d_ptr_cupy.data.ptr, size, number)
+ s.DeviceToHostBatchAsync(d_ptr_cupy.data.ptr, h2_ptr_cupy.data.ptr, size, number)
+ s.Synchronized()
+ cost = time.perf_counter() - tp
+ print(f"cost: {cost}s")
+ print(f"bandwidth: {size * number / cost / 1e9}GB/s")
+ for x, y in zip(h1, h2):
+ assert compare(x, y, size, dtype, False)
+
+
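+# Runs each transfer variant once (61*64 blocks of 36 KiB) and prints the measured bandwidth.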
+def main():
+ device_id = 0
+ size = 36 * 1024
+ number = 61 * 64
+ dtype = np.float16
+ print(f"ucmtrans: {ucmtrans.commit_id}-{ucmtrans.build_type}")
+ cupy.cuda.Device(device_id).use()
+ d = ucmtrans.Device()
+ d.Setup(device_id)
+ trans_with_ce(d, size, number, dtype)
+ trans_with_sm(d, size, number, dtype)
+ trans_with_ce_async(d, size, number, dtype)
+ trans_with_sm_async(d, size, number, dtype)
+ trans_batch_with_ce(d, size, number, dtype)
+ trans_batch_with_sm(d, size, number, dtype)
+ trans_batch_with_ce_async(d, size, number, dtype)
+ trans_batch_with_sm_async(d, size, number, dtype)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/ucm/shared/trans/CMakeLists.txt b/ucm/shared/trans/CMakeLists.txt
new file mode 100644
index 000000000..57a1bd0aa
--- /dev/null
+++ b/ucm/shared/trans/CMakeLists.txt
@@ -0,0 +1,16 @@
+if(RUNTIME_ENVIRONMENT STREQUAL "ascend")
+ add_subdirectory(ascend)
+endif()
+if(RUNTIME_ENVIRONMENT STREQUAL "cuda")
+ add_subdirectory(cuda)
+endif()
+if(RUNTIME_ENVIRONMENT STREQUAL "simu")
+ add_subdirectory(simu)
+endif()
+target_include_directories(trans PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..)
+target_link_libraries(trans PUBLIC infra_status)
+
+file(GLOB_RECURSE UCMTRANS_CPY_SOURCE_FILES "./cpy/*.cc")
+pybind11_add_module(ucmtrans ${UCMTRANS_CPY_SOURCE_FILES})
+target_link_libraries(ucmtrans PRIVATE trans)
+set_target_properties(ucmtrans PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
diff --git a/ucm/shared/trans/__init__.py b/ucm/shared/trans/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/ucm/shared/trans/ascend/CMakeLists.txt b/ucm/shared/trans/ascend/CMakeLists.txt
new file mode 100644
index 000000000..5a1c3670c
--- /dev/null
+++ b/ucm/shared/trans/ascend/CMakeLists.txt
@@ -0,0 +1,15 @@
+set(ASCEND_ROOT "/usr/local/Ascend/ascend-toolkit/latest" CACHE PATH "Path to Ascend root directory")
+add_library(Ascend::ascendcl UNKNOWN IMPORTED)
+set_target_properties(Ascend::ascendcl PROPERTIES
+ INTERFACE_INCLUDE_DIRECTORIES "${ASCEND_ROOT}/include"
+ IMPORTED_LOCATION "${ASCEND_ROOT}/lib64/libascendcl.so"
+)
+add_library(trans STATIC
+ ascend_device.cc
+ ascend_buffer.cc
+ ascend_stream.cc
+)
+target_link_libraries(trans PUBLIC
+ fmt
+ Ascend::ascendcl
+)
diff --git a/ucm/store/test/case/nfsstore/tsf_task_waiter_test.cc b/ucm/shared/trans/ascend/ascend_buffer.cc
similarity index 53%
rename from ucm/store/test/case/nfsstore/tsf_task_waiter_test.cc
rename to ucm/shared/trans/ascend/ascend_buffer.cc
index bcf11a44a..cb748bda2 100644
--- a/ucm/store/test/case/nfsstore/tsf_task_waiter_test.cc
+++ b/ucm/shared/trans/ascend/ascend_buffer.cc
@@ -21,42 +21,36 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
* */
-#include
-#include
-#include
-#include "tsf_task/tsf_task_waiter.h"
+#include "ascend_buffer.h"
+#include <acl/acl.h>
-class UCTsfTaskWaiterTest : public ::testing::Test {};
+namespace UC::Trans {
-TEST_F(UCTsfTaskWaiterTest, TaskTimeout)
+std::shared_ptr<void> Trans::AscendBuffer::MakeDeviceBuffer(size_t size)
{
- UC::TsfTaskWaiter waiter{1, 1024, 1, "xxx"};
- auto fut = std::async([&] {
- std::this_thread::sleep_for(std::chrono::milliseconds(10));
- waiter.Done();
- });
- ASSERT_FALSE(waiter.Wait(1));
- fut.get();
- ASSERT_TRUE(waiter.Finish());
+ void* device = nullptr;
+ auto ret = aclrtMalloc(&device, size, ACL_MEM_TYPE_HIGH_BAND_WIDTH);
+    if (ret == ACL_SUCCESS) { return std::shared_ptr<void>(device, aclrtFree); }
+ return nullptr;
}
-TEST_F(UCTsfTaskWaiterTest, TaskSuccess)
+std::shared_ptr<void> Trans::AscendBuffer::MakeHostBuffer(size_t size)
{
- UC::TsfTaskWaiter waiter{1, 1024, 1, "xxx"};
- auto fut = std::async([&] { waiter.Done(); });
- ASSERT_TRUE(waiter.Wait(1000));
- ASSERT_TRUE(waiter.Finish());
- fut.get();
+ void* host = nullptr;
+ auto ret = aclrtMallocHost(&host, size);
+    if (ret == ACL_SUCCESS) { return std::shared_ptr<void>(host, aclrtFreeHost); }
+ return nullptr;
}
-TEST_F(UCTsfTaskWaiterTest, TaskTimeoutButSuccess)
+Status Buffer::RegisterHostBuffer(void* host, size_t size, void** pDevice)
{
- UC::TsfTaskWaiter waiter{1, 1024, 1, "xxx"};
- auto fut = std::async([&] {
- std::this_thread::sleep_for(std::chrono::milliseconds(10));
- waiter.Done();
- });
- fut.get();
- ASSERT_TRUE(waiter.Finish());
- ASSERT_TRUE(waiter.Wait(1));
+ void* device = nullptr;
+ auto ret = aclrtHostRegister(host, size, ACL_HOST_REGISTER_MAPPED, &device);
+ if (ret != ACL_SUCCESS) [[unlikely]] { return Status{ret, std::to_string(ret)}; }
+ if (pDevice) { *pDevice = device; }
+ return Status::OK();
}
+
+void Buffer::UnregisterHostBuffer(void* host) { aclrtHostUnregister(host); }
+
+} // namespace UC::Trans
diff --git a/ucm/shared/trans/ascend/ascend_buffer.h b/ucm/shared/trans/ascend/ascend_buffer.h
new file mode 100644
index 000000000..3f64ce237
--- /dev/null
+++ b/ucm/shared/trans/ascend/ascend_buffer.h
@@ -0,0 +1,39 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#ifndef UNIFIEDCACHE_TRANS_ASCEND_BUFFER_H
+#define UNIFIEDCACHE_TRANS_ASCEND_BUFFER_H
+
+#include "trans/detail/reserved_buffer.h"
+
+namespace UC::Trans {
+
+class AscendBuffer : public ReservedBuffer {
+public:
+    std::shared_ptr<void> MakeDeviceBuffer(size_t size) override;
+    std::shared_ptr<void> MakeHostBuffer(size_t size) override;
+};
+
+} // namespace UC::Trans
+
+#endif
diff --git a/ucm/shared/trans/ascend/ascend_device.cc b/ucm/shared/trans/ascend/ascend_device.cc
new file mode 100644
index 000000000..bfc92f987
--- /dev/null
+++ b/ucm/shared/trans/ascend/ascend_device.cc
@@ -0,0 +1,62 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#include <fmt/format.h>
+#include "ascend_buffer.h"
+#include "ascend_stream.h"
+#include "trans/device.h"
+
+namespace UC::Trans {
+
+Status Device::Setup(int32_t deviceId)
+{
+ if (deviceId < 0) { return Status::Error(fmt::format("invalid device id({})", deviceId)); }
+ auto ret = aclrtSetDevice(deviceId);
+ if (ret == ACL_SUCCESS) { return Status::OK(); }
+ return Status{ret, std::to_string(ret)};
+}
+
+std::unique_ptr<Stream> Device::MakeStream()
+{
+    std::unique_ptr<Stream> stream = nullptr;
+    try {
+        stream = std::make_unique<AscendStream>();
+ } catch (...) {
+ return nullptr;
+ }
+ if (stream->Setup().Success()) { return stream; }
+ return nullptr;
+}
+
+std::unique_ptr<Stream> Device::MakeSMStream() { return nullptr; }
+
+std::unique_ptr<Buffer> Device::MakeBuffer()
+{
+    try {
+        return std::make_unique<AscendBuffer>();
+ } catch (...) {
+ return nullptr;
+ }
+}
+
+} // namespace UC::Trans
diff --git a/ucm/shared/trans/ascend/ascend_stream.cc b/ucm/shared/trans/ascend/ascend_stream.cc
new file mode 100644
index 000000000..104612716
--- /dev/null
+++ b/ucm/shared/trans/ascend/ascend_stream.cc
@@ -0,0 +1,177 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#include "ascend_stream.h"
+
+namespace UC::Trans {
+
+AscendStream::~AscendStream()
+{
+ if (cbThread_.joinable()) {
+ auto tid = cbThread_.native_handle();
+ (void)aclrtUnSubscribeReport(tid, stream_);
+ stop_ = true;
+ cbThread_.join();
+ }
+ if (stream_) {
+ (void)aclrtDestroyStream(stream_);
+ stream_ = nullptr;
+ }
+}
+
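+// Creates the ACL stream and starts a report thread that pumps aclrtProcessReport so that
+// callbacks queued via AppendCallback() get dispatched.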
+Status AscendStream::Setup()
+{
+ auto ret = aclrtCreateStream(&stream_);
+ if (ret != ACL_SUCCESS) [[unlikely]] { return Status{ret, std::to_string(ret)}; }
+ cbThread_ = std::thread([this] {
+ while (!this->stop_) { (void)aclrtProcessReport(10); }
+ });
+ auto tid = cbThread_.native_handle();
+ ret = aclrtSubscribeReport(tid, stream_);
+ if (ret != ACL_SUCCESS) [[unlikely]] { return Status{ret, std::to_string(ret)}; }
+ return Status::OK();
+}
+
+Status AscendStream::DeviceToHost(void* device, void* host, size_t size)
+{
+ auto ret = aclrtMemcpy(host, size, device, size, ACL_MEMCPY_DEVICE_TO_HOST);
+ if (ret == ACL_SUCCESS) { return Status::OK(); }
+ return Status{ret, std::to_string(ret)};
+}
+
+Status AscendStream::DeviceToHost(void* device[], void* host[], size_t size, size_t number)
+{
+ auto s = DeviceToHostAsync(device, host, size, number);
+ if (s.Failure()) [[unlikely]] { return s; }
+ return Synchronized();
+}
+
+Status AscendStream::DeviceToHost(void* device[], void* host, size_t size, size_t number)
+{
+ auto s = DeviceToHostAsync(device, host, size, number);
+ if (s.Failure()) [[unlikely]] { return s; }
+ return Synchronized();
+}
+
+Status AscendStream::DeviceToHostAsync(void* device, void* host, size_t size)
+{
+ auto ret = aclrtMemcpyAsync(host, size, device, size, ACL_MEMCPY_DEVICE_TO_HOST, stream_);
+ if (ret == ACL_SUCCESS) { return Status::OK(); }
+ return Status{ret, std::to_string(ret)};
+}
+
+Status AscendStream::DeviceToHostAsync(void* device[], void* host[], size_t size, size_t number)
+{
+ for (size_t i = 0; i < number; i++) {
+ auto s = DeviceToHostAsync(device[i], host[i], size);
+ if (s.Failure()) [[unlikely]] { return s; }
+ }
+ return Status::OK();
+}
+
+Status AscendStream::DeviceToHostAsync(void* device[], void* host, size_t size, size_t number)
+{
+ for (size_t i = 0; i < number; i++) {
+ auto pHost = (void*)(((int8_t*)host) + size * i);
+ auto s = DeviceToHostAsync(device[i], pHost, size);
+ if (s.Failure()) [[unlikely]] { return s; }
+ }
+ return Status::OK();
+}
+
+Status AscendStream::HostToDevice(void* host, void* device, size_t size)
+{
+ auto ret = aclrtMemcpy(device, size, host, size, ACL_MEMCPY_HOST_TO_DEVICE);
+ if (ret == ACL_SUCCESS) { return Status::OK(); }
+ return Status{ret, std::to_string(ret)};
+}
+
+Status AscendStream::HostToDevice(void* host[], void* device[], size_t size, size_t number)
+{
+ auto s = HostToDeviceAsync(host, device, size, number);
+ if (s.Failure()) [[unlikely]] { return s; }
+ return Synchronized();
+}
+
+Status AscendStream::HostToDevice(void* host, void* device[], size_t size, size_t number)
+{
+ auto s = HostToDeviceAsync(host, device, size, number);
+ if (s.Failure()) [[unlikely]] { return s; }
+ return Synchronized();
+}
+
+Status AscendStream::HostToDeviceAsync(void* host, void* device, size_t size)
+{
+ auto ret = aclrtMemcpyAsync(device, size, host, size, ACL_MEMCPY_HOST_TO_DEVICE, stream_);
+ if (ret == ACL_SUCCESS) { return Status::OK(); }
+ return Status{ret, std::to_string(ret)};
+}
+
+Status AscendStream::HostToDeviceAsync(void* host[], void* device[], size_t size, size_t number)
+{
+ for (size_t i = 0; i < number; i++) {
+ auto s = HostToDeviceAsync(host[i], device[i], size);
+ if (s.Failure()) [[unlikely]] { return s; }
+ }
+ return Status::OK();
+}
+
+Status AscendStream::HostToDeviceAsync(void* host, void* device[], size_t size, size_t number)
+{
+ for (size_t i = 0; i < number; i++) {
+ auto pHost = (void*)(((int8_t*)host) + size * i);
+ auto s = HostToDeviceAsync(pHost, device[i], size);
+ if (s.Failure()) [[unlikely]] { return s; }
+ }
+ return Status::OK();
+}
+
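+// Callbacks are heap-allocated and handed to ACL; Trampoline invokes them on the report thread and frees them.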
+using Closure = std::function<void(bool)>;
+
+static void Trampoline(void* data)
+{
+    auto c = static_cast<Closure*>(data);
+ (*c)(true);
+ delete c;
+}
+
+Status Trans::AscendStream::AppendCallback(std::function<void(bool)> cb)
+{
+ auto c = new (std::nothrow) Closure{std::move(cb)};
+ if (!c) [[unlikely]] { return Status::Error("out of memory for appending callback"); }
+ auto ret = aclrtLaunchCallback(Trampoline, (void*)c, ACL_CALLBACK_NO_BLOCK, stream_);
+ if (ret != ACL_SUCCESS) [[unlikely]] {
+ delete c;
+ return Status{ret, std::to_string(ret)};
+ }
+ return Status::OK();
+}
+
+Status AscendStream::Synchronized()
+{
+ auto ret = aclrtSynchronizeStream(stream_);
+ if (ret == ACL_SUCCESS) { return Status::OK(); }
+ return Status{ret, std::to_string(ret)};
+}
+
+} // namespace UC::Trans
diff --git a/ucm/shared/trans/ascend/ascend_stream.h b/ucm/shared/trans/ascend/ascend_stream.h
new file mode 100644
index 000000000..8ae53d595
--- /dev/null
+++ b/ucm/shared/trans/ascend/ascend_stream.h
@@ -0,0 +1,64 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#ifndef UNIFIEDCACHE_TRANS_ASCEND_STREAM_H
+#define UNIFIEDCACHE_TRANS_ASCEND_STREAM_H
+
+#include <acl/acl.h>
+#include <atomic>
+#include <thread>
+#include "trans/stream.h"
+
+namespace UC::Trans {
+
+class AscendStream : public Stream {
+protected:
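+ // ACL stream handle; stop_ and cbThread_ presumably drive a worker thread that services the
+ // callbacks queued via AppendCallback.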
+ aclrtStream stream_{nullptr};
+ std::atomic_bool stop_{false};
+ std::thread cbThread_;
+
+public:
+ ~AscendStream() override;
+ Status Setup() override;
+
+ Status DeviceToHost(void* device, void* host, size_t size) override;
+ Status DeviceToHost(void* device[], void* host[], size_t size, size_t number) override;
+ Status DeviceToHost(void* device[], void* host, size_t size, size_t number) override;
+ Status DeviceToHostAsync(void* device, void* host, size_t size) override;
+ Status DeviceToHostAsync(void* device[], void* host[], size_t size, size_t number) override;
+ Status DeviceToHostAsync(void* device[], void* host, size_t size, size_t number) override;
+
+ Status HostToDevice(void* host, void* device, size_t size) override;
+ Status HostToDevice(void* host[], void* device[], size_t size, size_t number) override;
+ Status HostToDevice(void* host, void* device[], size_t size, size_t number) override;
+ Status HostToDeviceAsync(void* host, void* device, size_t size) override;
+ Status HostToDeviceAsync(void* host[], void* device[], size_t size, size_t number) override;
+ Status HostToDeviceAsync(void* host, void* device[], size_t size, size_t number) override;
+
+ Status AppendCallback(std::function<void(bool)> cb) override;
+ Status Synchronized() override;
+};
+
+} // namespace UC::Trans
+
+#endif
diff --git a/ucm/shared/trans/buffer.h b/ucm/shared/trans/buffer.h
new file mode 100644
index 000000000..a73752513
--- /dev/null
+++ b/ucm/shared/trans/buffer.h
@@ -0,0 +1,50 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#ifndef UNIFIEDCACHE_TRANS_BUFFER_H
+#define UNIFIEDCACHE_TRANS_BUFFER_H
+
+#include <memory>
+#include "status/status.h"
+
+namespace UC::Trans {
+
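+// Abstract transfer-buffer allocator: the Make*Buffer methods return one-off allocations,
+// Make*Buffers/Get*Buffer suggest a pre-allocated pool, and RegisterHostBuffer/UnregisterHostBuffer
+// pin (register) existing host memory for device access. This reading is based on the method
+// names; concrete behaviour lives in the per-device implementations.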
+class Buffer {
+public:
+    virtual ~Buffer() = default;
+
+    virtual std::shared_ptr<void> MakeDeviceBuffer(size_t size) = 0;
+    virtual Status MakeDeviceBuffers(size_t size, size_t number) = 0;
+    virtual std::shared_ptr<void> GetDeviceBuffer(size_t size) = 0;
+
+    virtual std::shared_ptr<void> MakeHostBuffer(size_t size) = 0;
+    virtual Status MakeHostBuffers(size_t size, size_t number) = 0;
+    virtual std::shared_ptr<void> GetHostBuffer(size_t size) = 0;
+
+    static Status RegisterHostBuffer(void* host, size_t size, void** pDevice = nullptr);
+    static void UnregisterHostBuffer(void* host);
+};
+
+} // namespace UC::Trans
+
+#endif
diff --git a/ucm/shared/trans/cpy/trans.py.cc b/ucm/shared/trans/cpy/trans.py.cc
new file mode 100644
index 000000000..952fc8e69
--- /dev/null
+++ b/ucm/shared/trans/cpy/trans.py.cc
@@ -0,0 +1,192 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#include <pybind11/numpy.h>
+#include <pybind11/pybind11.h>
+#include "trans/device.h"
+
+namespace py = pybind11;
+
+namespace UC::Trans {
+
+using Ptr = uintptr_t;
+using PtrArray = py::array_t<Ptr>;
+
+inline void ThrowIfFailed(const Status& s)
+{
+ if (s.Failure()) [[unlikely]] { throw std::runtime_error{s.ToString()}; }
+}
+
+inline void DeviceToHost(Stream& self, Ptr src, Ptr dst, size_t size)
+{
+ ThrowIfFailed(self.DeviceToHost((void*)src, (void*)dst, size));
+}
+
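+// The Batch/Gather/Scatter wrappers accept the pointer-list argument either as a numpy array of
+// addresses (PtrArray) or as a single integer that is the address of an equivalent void* table
+// already in memory; Gather/Scatter additionally take one contiguous buffer address for the other
+// side of the copy.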
+inline void DeviceToHostBatch(Stream& self, py::object src, py::object dst, size_t size,
+                              size_t number)
+{
+    if (py::isinstance<PtrArray>(src)) {
+        auto device = static_cast<void**>(src.cast<PtrArray>().request().ptr);
+        auto host = static_cast<void**>(dst.cast<PtrArray>().request().ptr);
+        ThrowIfFailed(self.DeviceToHost(device, host, size, number));
+    } else {
+        auto device = static_cast<void**>((void*)src.cast<Ptr>());
+        auto host = static_cast<void**>((void*)dst.cast<Ptr>());
+        ThrowIfFailed(self.DeviceToHost(device, host, size, number));
+    }
+}
+
+inline void DeviceToHostGather(Stream& self, py::object src, Ptr dst, size_t size, size_t number)
+{
+    if (py::isinstance<PtrArray>(src)) {
+        auto device = static_cast<void**>(src.cast<PtrArray>().request().ptr);
+        ThrowIfFailed(self.DeviceToHost(device, (void*)dst, size, number));
+    } else {
+        auto device = static_cast<void**>((void*)src.cast<Ptr>());
+        ThrowIfFailed(self.DeviceToHost(device, (void*)dst, size, number));
+    }
+}
+
+inline void DeviceToHostAsync(Stream& self, Ptr src, Ptr dst, size_t size)
+{
+ ThrowIfFailed(self.DeviceToHostAsync((void*)src, (void*)dst, size));
+}
+
+inline void DeviceToHostBatchAsync(Stream& self, py::object src, py::object dst, size_t size,
+                                   size_t number)
+{
+    if (py::isinstance<PtrArray>(src)) {
+        auto device = static_cast<void**>(src.cast<PtrArray>().request().ptr);
+        auto host = static_cast<void**>(dst.cast<PtrArray>().request().ptr);
+        ThrowIfFailed(self.DeviceToHostAsync(device, host, size, number));
+    } else {
+        auto device = static_cast<void**>((void*)src.cast<Ptr>());
+        auto host = static_cast<void**>((void*)dst.cast<Ptr>());
+        ThrowIfFailed(self.DeviceToHostAsync(device, host, size, number));
+    }
+}
+
+inline void DeviceToHostGatherAsync(Stream& self, py::object src, Ptr dst, size_t size,
+                                    size_t number)
+{
+    if (py::isinstance<PtrArray>(src)) {
+        auto device = static_cast<void**>(src.cast<PtrArray>().request().ptr);
+        ThrowIfFailed(self.DeviceToHostAsync(device, (void*)dst, size, number));
+    } else {
+        auto device = static_cast<void**>((void*)src.cast<Ptr>());
+        ThrowIfFailed(self.DeviceToHostAsync(device, (void*)dst, size, number));
+    }
+}
+
+inline void HostToDevice(Stream& self, Ptr src, Ptr dst, size_t size)
+{
+ ThrowIfFailed(self.HostToDevice((void*)src, (void*)dst, size));
+}
+
+inline void HostToDeviceBatch(Stream& self, py::object src, py::object dst, size_t size,
+                              size_t number)
+{
+    if (py::isinstance<PtrArray>(src)) {
+        auto host = static_cast<void**>(src.cast<PtrArray>().request().ptr);
+        auto device = static_cast<void**>(dst.cast<PtrArray>().request().ptr);
+        ThrowIfFailed(self.HostToDevice(host, device, size, number));
+    } else {
+        auto host = static_cast<void**>((void*)src.cast<Ptr>());
+        auto device = static_cast<void**>((void*)dst.cast<Ptr>());
+        ThrowIfFailed(self.HostToDevice(host, device, size, number));
+    }
+}
+
+inline void HostToDeviceScatter(Stream& self, Ptr src, py::object dst, size_t size, size_t number)
+{
+    if (py::isinstance<PtrArray>(dst)) {
+        auto device = static_cast<void**>(dst.cast<PtrArray>().request().ptr);
+        ThrowIfFailed(self.HostToDevice((void*)src, device, size, number));
+    } else {
+        auto device = static_cast<void**>((void*)dst.cast<Ptr>());
+        ThrowIfFailed(self.HostToDevice((void*)src, device, size, number));
+    }
+}
+
+inline void HostToDeviceAsync(Stream& self, Ptr src, Ptr dst, size_t size)
+{
+ ThrowIfFailed(self.HostToDeviceAsync((void*)src, (void*)dst, size));
+}
+
+inline void HostToDeviceBatchAsync(Stream& self, py::object src, py::object dst, size_t size,
+                                   size_t number)
+{
+    if (py::isinstance<PtrArray>(src)) {
+        auto host = static_cast<void**>(src.cast<PtrArray>().request().ptr);
+        auto device = static_cast<void**>(dst.cast<PtrArray>().request().ptr);
+        ThrowIfFailed(self.HostToDeviceAsync(host, device, size, number));
+    } else {
+        auto host = static_cast<void**>((void*)src.cast<Ptr>());
+        auto device = static_cast<void**>((void*)dst.cast<Ptr>());
+        ThrowIfFailed(self.HostToDeviceAsync(host, device, size, number));
+    }
+}
+
+inline void HostToDeviceScatterAsync(Stream& self, Ptr src, py::object dst, size_t size,
+                                     size_t number)
+{
+    if (py::isinstance<PtrArray>(dst)) {
+        auto device = static_cast<void**>(dst.cast<PtrArray>().request().ptr);
+        ThrowIfFailed(self.HostToDeviceAsync((void*)src, device, size, number));
+    } else {
+        auto device = static_cast<void**>((void*)dst.cast<Ptr>());
+        ThrowIfFailed(self.HostToDeviceAsync((void*)src, device, size, number));
+    }
+}
+
+} // namespace UC::Trans
+
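+// Python-facing module. A minimal usage sketch (module name as registered below; MakeStream's
+// exact signature and return type are assumed):
+//   import ucmtrans
+//   dev = ucmtrans.Device(); dev.Setup(0)
+//   s = dev.MakeStream()
+//   s.HostToDevice(src_addr, dst_addr, nbytes)  # addresses are plain integers (uintptr_t)
+//   s.Synchronized()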
+PYBIND11_MODULE(ucmtrans, m)
+{
+ using namespace UC::Trans;
+ m.attr("project") = UCM_PROJECT_NAME;
+ m.attr("version") = UCM_PROJECT_VERSION;
+ m.attr("commit_id") = UCM_COMMIT_ID;
+ m.attr("build_type") = UCM_BUILD_TYPE;
+
+ auto s = py::class_<Stream, std::shared_ptr<Stream>>(m, "Stream");
+ s.def("DeviceToHost", &DeviceToHost);
+ s.def("DeviceToHostBatch", &DeviceToHostBatch);
+ s.def("DeviceToHostGather", &DeviceToHostGather);
+ s.def("DeviceToHostAsync", &DeviceToHostAsync);
+ s.def("DeviceToHostBatchAsync", &DeviceToHostBatchAsync);
+ s.def("DeviceToHostGatherAsync", &DeviceToHostGatherAsync);
+ s.def("HostToDevice", &HostToDevice);
+ s.def("HostToDeviceBatch", &HostToDeviceBatch);
+ s.def("HostToDeviceScatter", &HostToDeviceScatter);
+ s.def("HostToDeviceAsync", &HostToDeviceAsync);
+ s.def("HostToDeviceBatchAsync", &HostToDeviceBatchAsync);
+ s.def("HostToDeviceScatterAsync", &HostToDeviceScatterAsync);
+ s.def("Synchronized", [](Stream& self) { ThrowIfFailed(self.Synchronized()); });
+
+ auto d = py::class_<Device>(m, "Device");
+ d.def(py::init<>());
+ d.def("Setup", [](Device& self, int32_t deviceId) { ThrowIfFailed(self.Setup(deviceId)); });
+ d.def("MakeStream", &Device::MakeStream);
+ d.def("MakeSMStream", &Device::MakeSMStream);
+}
diff --git a/ucm/shared/trans/cuda/CMakeLists.txt b/ucm/shared/trans/cuda/CMakeLists.txt
new file mode 100644
index 000000000..b98de9985
--- /dev/null
+++ b/ucm/shared/trans/cuda/CMakeLists.txt
@@ -0,0 +1,23 @@
+set(CUDA_ROOT "/usr/local/cuda/" CACHE PATH "Path to CUDA root directory")
+set(CMAKE_CUDA_COMPILER ${CUDA_ROOT}/bin/nvcc)
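+# Architectures 75/80/86/89/90 cover Turing, Ampere, Ada and Hopper GPUs.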
+set(CMAKE_CUDA_ARCHITECTURES 75 80 86 89 90)
+enable_language(CUDA)
+add_library(kernel OBJECT cuda_sm_kernel.cu)
+target_compile_options(kernel PRIVATE
+ --diag-suppress=128 --diag-suppress=2417 --diag-suppress=2597
+ -Wall -fPIC
+)
+add_library(trans STATIC
+ cuda_device.cc
+ cuda_buffer.cc
+ cuda_stream.cc
+ cuda_sm_stream.cc
+)
+target_include_directories(trans PUBLIC ${CUDA_ROOT}/include)
+target_link_directories(trans PUBLIC ${CUDA_ROOT}/lib64)
+target_link_libraries(trans PUBLIC
+ fmt
+ cudart
+ nvidia-ml
+ kernel
+)
diff --git a/ucm/shared/trans/cuda/cuda_buffer.cc b/ucm/shared/trans/cuda/cuda_buffer.cc
new file mode 100644
index 000000000..11f25a7e2
--- /dev/null
+++ b/ucm/shared/trans/cuda/cuda_buffer.cc
@@ -0,0 +1,58 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */
+#include "cuda_buffer.h"
+#include <cuda_runtime.h>
+
+namespace UC::Trans {
+
+std::shared_ptr<void> CudaBuffer::MakeDeviceBuffer(size_t size)
+{
+    void* device = nullptr;
+    auto ret = cudaMalloc(&device, size);
+    if (ret == cudaSuccess) { return std::shared_ptr<void>(device, cudaFree); }
+    return nullptr;
+}
+
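+// cudaMallocHost returns page-locked (pinned) host memory, which cudaMemcpyAsync needs in order to
+// overlap copies with kernel execution.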
+std::shared_ptr<void> CudaBuffer::MakeHostBuffer(size_t size)
+{
+ void* host = nullptr;
+ auto ret = cudaMallocHost(&host, size);
+ if (ret == cudaSuccess) { return std::shared_ptr