diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 000000000..0236df923 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,38 @@ +# See https://help.github.com/articles/about-codeowners/ +# for more info about CODEOWNERS file + +* @mag1c-h @ygwpz @FangRun2 @Tarrei +/.github @Wwwzff @hek14 @ygwpz @mag1c-h @FangRun2 @Tarrei + +/ucm/sparse @wuhuxiao @wangwenxin0312 @hek14 @ygwpz @mag1c-h +/ucm/sparse/cache_blend @wuhuxiao @hek14 @ygwpz @mag1c-h +/ucm/sparse/esa @wangwenxin0312 @hek14 @ygwpz @mag1c-h +/ucm/sparse/gsa @Zbm1996 @zbb200819 @yxkyong @HaoLi980405 @wuhuxiao @hek14 @ygwpz @mag1c-h +/ucm/sparse/kvcomp @leideng @pengwwang @wuhuxiao @hek14 @ygwpz @mag1c-h +/ucm/sparse/kvstar @saki-daisuki @summer-ai007 @xwLearnsLLM @wuhuxiao @hek14 @ygwpz @mag1c-h + +/ucm/store @mag1c-h @ygwpz +/ucm/store/dramstore @harrisonyhq @mag1c-h @ygwpz +/ucm/store/localstore @mag1c-h @ygwpz +/ucm/store/mooncakestore @chinesezyc @mag1c-h @ygwpz +/ucm/store/nfsstore @mag1c-h @ygwpz + +/ucm/integration @qyh111 @harrisonyhq @ygwpz @mag1c-h @hek14 + +/ucm/pd @flesher0813 @ygwpz @mag1c-h + +/ucm/sandbox @Wwwzff @hek14 @ygwpz @mag1c-h @FangRun2 @Tarrei + +/benchmarks @flesher0813 @ygwpz @mag1c-h + +/docker @harrisonyhq @ygwpz @mag1c-h + +/docs @flesher0813 @ygwpz @mag1c-h @FangRun2 @Tarrei @hek14 +/docs/source/user-guide/sparse-attention/esa.md @wangwenxin0312 @hek14 @flesher0813 @ygwpz @mag1c-h @FangRun2 @Tarrei +/docs/source/user-guide/sparse-attention/gsa.md @Zbm1996 @zbb200819 @yxkyong @HaoLi980405 @flesher0813 @ygwpz @mag1c-h @FangRun2 @Tarrei +/docs/source/user-guide/sparse-attention/kvcomp.md @leideng @pengwwang @flesher0813 @ygwpz @mag1c-h @FangRun2 @Tarrei +/docs/source/user-guide/sparse-attention/kvstar.md @saki-daisuki @summer-ai007 @flesher0813 @ygwpz @mag1c-h @FangRun2 @Tarrei + +/examples @harrisonyhq @ygwpz @mag1c-h @hek14 + +/test @Wwwzff @ygwpz @mag1c-h diff --git a/.github/workflows/cpp-linter.yml b/.github/workflows/cpp-linter.yml new file mode 100644 index 000000000..e013a62c7 --- /dev/null +++ b/.github/workflows/cpp-linter.yml @@ -0,0 +1,34 @@ +name: cpp-linter + +on: + push: + branches: [ "*" ] + pull_request: + branches: [ "dev*", "main", "*release" ] + + +jobs: + cpp-linter: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - uses: cpp-linter/cpp-linter-action@main + id: linter + continue-on-error: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + style: file + tidy-checks: '-*' + files-changed-only: true + lines-changed-only: diff + format-review: true + thread-comments: ${{ github.event_name == 'pull_request' && 'update' }} + + - name: Fail fast?! + if: steps.linter.outputs.checks-failed != 0 + run: | + echo "some linter checks failed. ${{ steps.linter.outputs.checks-failed }}" + exit 1 diff --git a/.github/workflows/unifiedcache_test.yml b/.github/workflows/unifiedcache_test.yml index 91bb0db15..aa0dee2ad 100644 --- a/.github/workflows/unifiedcache_test.yml +++ b/.github/workflows/unifiedcache_test.yml @@ -49,7 +49,9 @@ jobs: set -euo pipefail pip install -v -e . 
--no-build-isolation cd \$(pip show vllm | grep Location | awk '{print \$2}') && - git apply /workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch + git apply /workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt-pc.patch + git apply /workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt-aggre.patch + git apply /workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch cd /workspace/unified-cache-management python3 -m unittest discover -s test " diff --git a/.gitignore b/.gitignore index 8ff7d5c29..734cf4bf1 100644 --- a/.gitignore +++ b/.gitignore @@ -49,4 +49,15 @@ **/output/** .venv/** **/__pycache__/** -*.egg-info/** \ No newline at end of file +*.egg-info/** +reports/ +dataset/ +logs/ +.* +*.log +result_outputs/ +results/ +.cache/ +backup/ +$null +*__pycache__/ \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 788dc15ce..b72407387 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,7 +9,10 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) option(BUILD_UCM_STORE "build ucm store module." ON) option(BUILD_UCM_SPARSE "build ucm sparse module." ON) option(BUILD_UNIT_TESTS "build all unit test suits." OFF) -set(RUNTIME_ENVIRONMENT "simu" CACHE STRING "runtime: simu, ascend or cuda.") +option(BUILD_NUMA "build numactl library." OFF) +option(DOWNLOAD_DEPENDENCE "download dependence by cmake." ON) +set(RUNTIME_ENVIRONMENT "simu" CACHE STRING "runtime: simu, ascend, musa or cuda.") +set(LOGGER_BACKEND "spdlog" CACHE STRING "backend: spdlog or flux.") execute_process(COMMAND git rev-parse HEAD OUTPUT_VARIABLE UCM_COMMIT_ID OUTPUT_STRIP_TRAILING_WHITESPACE) add_compile_definitions(UCM_PROJECT_NAME="${PROJECT_NAME}") diff --git a/README.md b/README.md index 421964e65..42ab3a257 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@

-| Documentation | Website | RoadMap | 中文 | +| Documentation | Website | RoadMap | 中文 |

--- @@ -82,9 +82,10 @@ please refer to [Quick Start](./docs/source/getting-started/quick_start.md). --- ## Contact Us +1. For technical questions and feature requests, please use GitHub [Issues](https://github.com/ModelEngine-Group/unified-cache-management/issues). +2. WeChat technical discussion group: Scan the QR code below. -For technical questions and feature requests, please use -GitHub [Issues](https://github.com/ModelEngine-Group/unified-cache-management/issues). +wechat-gh ## License diff --git a/README_zh.md b/README_zh.md index 9eba04e34..6a5bee04d 100644 --- a/README_zh.md +++ b/README_zh.md @@ -6,7 +6,7 @@

-| 文档 | 网站 | 发展路线图 | EN | +| 文档 | 网站 | 发展路线图 | EN |

--- @@ -62,7 +62,10 @@ KVStoreBase有助于实现稀疏算法与外部存储的解耦。它定义了与 --- ## 联系我们 -如需技术咨询或功能请求,请提交 GitHub [Issues](https://github.com/ModelEngine-Group/unified-cache-management/issues). +1. 如需技术咨询或功能请求,请提交 GitHub [Issues](https://github.com/ModelEngine-Group/unified-cache-management/issues)。 +2. 微信技术交流群:扫描下方二维码。 + +wechat-gh ## 许可协议 diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 000000000..8c5f5f70b --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,128 @@ +# TraceReplay Benchmark Tool + +It accurately replays real-world request traces with original timing or dynamically generates requests using popular datasets. The tool delivers comprehensive performance metrics—including Time to First Token (TTFT), Time Per Output Token (TPOT), Inter-Token Latency (ITL), End-to-End Latency, Goodput, etc. + +--- + +## 1. Overview + +The Trace Replay feature mainly includes request generation, request sending and response receiving, as well as result calculation and saving. It can reproduce historical requests based on MoonCake's trace file and strictly send the requests according to the timestamps recorded in the trace. After execution, Trace Replay calculates key performance metrics such as Time to First Token (TTFT) and Time Per Output Token (TPOT), then outputs the results to the terminal and saves them to an Excel file. + +[Mooncake traces](https://github.com/kvcache-ai/Mooncake/tree/main/FAST25-release/traces) consist of two types of trace data: +* Conversation and Tool&Agent trace: Sampled from one hour of online request data from different clusters. +* Synthetic trace: Generated synthetically from other publicly available datasets. + +For more information, please refer to the Mooncake paper: [Mooncake-FAST25.pdf](https://github.com/kvcache-ai/Mooncake/blob/main/FAST25-release/Mooncake-FAST25.pdf) . + +Trace Replay supports two request-generation methods: +* Hash ID-based: Input tokens are generated based on the input_length and hash_ids recorded in the trace file. Each hash_id corresponds to a block, with each block containing 512 tokens. The same hash_id always maps to the identical token sequence. +* Dataset-based: Prompts are generated by invoking vLLM's benchmark module using the input_length from the trace file and the user-specified dataset name. This approach does not rely on the hash_ids present in the trace file. + +Depending on the request generation method, Trace Replay offers two modes: Trace Mode and Benchmark Mode, which can be configured by the user via the --trace-mode parameter. + +--- + +## 2. Parameter + +| Argument | Default | Help | +|-----------|---------|---------| +| --backend | None | backend framework type | +| --model | None | model path | +| --host| localhost | IP address of the inference server | +| --port | None | Port number of the inference server | +| --trace-path | None | trace jsonl file path | +| --trace-mode | trace | 'trace' to replay requests from cached trace files, 'benchmark' to generate requests dynamically using the benchmark module | +| --dataset-name | sharegpt | if enable benchmark mode, you must specify a dataset, refer to the [vLLM benchmark documentation](https://github.com/vllm-project/vllm/blob/releases/v0.9.1/benchmarks/README.md )| +| --save-prompts | False | save generated prompts with timestamp for reuse | +| --save-result | False | save the benchmark metrics to excel file | +| --result-dir | None | path to save results | + +--- + +## 3. Example + +### 1. 
Download example trace + +You need to download the trace jsonl file from [Mooncake traces](https://github.com/kvcache-ai/Mooncake/tree/main/FAST25-release/traces ). In the trace, each line is a JSON object representing a single request: + +``` +{ + "timestamp": 1696000000123, // ms since epoch + "input_length": 512, // number of input tokens + "output_length": 128, // expected output tokens + "hash_ids": [123, 456, 789] // seed list for deterministic prompt generation +} +``` + +### 2. Set environment variable + +Trace Replay depends on [vLLM's benchmark](https://github.com/vllm-project/vllm/tree/main/benchmarks) module, which you need to download separately. Before running Trace Replay, you must set the path to the benchmark module via an environment variable.: + +```bash +export BENCHMARK_PATH="/vllm-workspace/benchmarks" +``` + +### 3.Basic Usage + +Execute the Python script to replay a trace against a local vLLM instance: + +```bash +python3 /trace_replay.py \ + --model /home/models/dsv2-lite \ + --backend vllm \ + --trace-path /conversation_trace.jsonl \ + --trace-mode trace \ + --host 127.0.0.1 \ + --port 8000 \ + --save-result \ + --save-prompts +``` + +### 4.Results + +Successful execution results in output similar to the following: + +``` +============ Serving Benchmark Result ============ +Successful requests: 510 +Benchmark duration (s): 301.46 +Total input tokens: 7201515 +Total generated tokens: 185502 +Request throughput (req/s): 1.69 +Output token throughput (tok/s): 615.34 +Total Token throughput (tok/s): 24504.02 +---------------Time to First Token---------------- +Mean TTFT (ms): 20931.33 +Median TTFT (ms): 19119.63 +Std TTFT (ms): 17324.45 +P25 TTFT (ms): 4057.98 +P50 TTFT (ms): 19119.63 +P75 TTFT (ms): 33284.55 +P99 TTFT (ms): 64592.68 +-----Time per Output Token (excl. 
1st token)------ +Mean TPOT (ms): 187.71 +Median TPOT (ms): 200.69 +Std TPOT (ms): 63.08 +P25 TPOT (ms): 144.17 +P50 TPOT (ms): 200.69 +P75 TPOT (ms): 234.55 +P99 TPOT (ms): 312.87 +---------------Inter-token Latency---------------- +Mean ITL (ms): 181.20 +Median ITL (ms): 169.18 +Std ITL (ms): 133.70 +P25 ITL (ms): 86.63 +P50 ITL (ms): 169.18 +P75 ITL (ms): 230.91 +P99 ITL (ms): 647.04 +----------------End-to-end Latency---------------- +Mean E2EL (ms): 86656.79 +Median E2EL (ms): 89218.82 +Std E2EL (ms): 43454.94 +P25 E2EL (ms): 53935.13 +P50 E2EL (ms): 89218.82 +P75 E2EL (ms): 120761.34 +P99 E2EL (ms): 171262.27 +================================================== +``` + diff --git a/benchmarks/trace_replay.py b/benchmarks/trace_replay.py index 21567d454..820216e3b 100644 --- a/benchmarks/trace_replay.py +++ b/benchmarks/trace_replay.py @@ -9,26 +9,14 @@ import time from collections import defaultdict from datetime import datetime +from typing import Optional +import aiohttp import pandas +import vllm from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase - -benchmark_path = os.environ.get("BENCHMARK_PATH") -if benchmark_path: - sys.path.append(benchmark_path) - print(f"Added benchmark path: {benchmark_path}") -else: - raise EnvironmentError( - "Missing BENCHMARK_PATH environment variable.\n" - "Usage: export BENCHMARK_PATH=" - ) - -from backend_request_func import ( - ASYNC_REQUEST_FUNCS, - RequestFuncInput, -) -from benchmark_dataset import ( +from vllm.benchmarks.datasets import ( AIMODataset, ASRDataset, BenchmarkDataset, @@ -45,13 +33,19 @@ SonnetDataset, VisionArenaDataset, ) -from benchmark_serving import ( +from vllm.benchmarks.endpoint_request_func import ( + ASYNC_REQUEST_FUNCS, + RequestFuncInput, +) +from vllm.benchmarks.serve import ( + BenchmarkMetrics, + add_cli_args, calculate_metrics, - create_argument_parser, get_tokenizer, ) logger = logging.getLogger(__name__) +REQUEST_FUC = None class TraceReplayDataset(BenchmarkDataset): @@ -89,19 +83,20 @@ def __init__( def generate_prompt( self, hash_ids: list[int], target_length: int, tokenizer ) -> str: - + DEFAULT_BLOCK_SIZE = 512 vocab_size = tokenizer.vocab_size # Use hash_ids to influence token generation base_offset = hash_ids[0] if hash_ids else 0 + token_ids = [] for i, value in enumerate(hash_ids): if value in self.hash_to_tokens: token_ids.extend(self.hash_to_tokens[value]) - elif (i + 1) * 512 <= target_length: - for j in range(512): - token_idx = i * 512 + j + elif (i + 1) * DEFAULT_BLOCK_SIZE <= target_length: + for j in range(DEFAULT_BLOCK_SIZE): + token_idx = i * DEFAULT_BLOCK_SIZE + j token_id = ( base_offset + token_idx @@ -112,10 +107,11 @@ def generate_prompt( ] ) ) % vocab_size + self.hash_to_tokens.setdefault(value, []).append(token_id) token_ids.extend(self.hash_to_tokens[value]) else: - needed = target_length - i * 512 + needed = target_length - i * DEFAULT_BLOCK_SIZE padding = [ (base_offset + len(token_ids) + j) % vocab_size for j in range(needed) @@ -357,7 +353,9 @@ def gene_prompts_by_dataset_name( return input_requests -def save_metrics_to_file(metrics, output_dir="./"): +def save_metrics_to_file( + metrics: BenchmarkMetrics, metric_percentiles: list[float], output_dir: str = "./" +): output_path = output_dir if not os.path.exists(output_path): os.makedirs(output_path, exist_ok=True) @@ -365,33 +363,32 @@ def save_metrics_to_file(metrics, output_dir="./"): outputs = {} outputs["time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - outputs["mean_ttft(ms)"] = 
round(metrics.mean_ttft_ms, 2) - outputs["p99_ttft(ms)"] = round(metrics.percentiles_ttft_ms[3][1], 2) - outputs["mean_tpot(ms)"] = round(metrics.mean_tpot_ms, 2) - outputs["p99_tpot(ms)"] = round(metrics.percentiles_tpot_ms[3][1], 2) - outputs["total_input_tokens"] = round(metrics.total_input, 2) - outputs["total_output_tokens"] = round(metrics.total_output, 2) - outputs["total_token_throughput(tok/s)"] = round(metrics.total_token_throughput, 2) - outputs["output_throughput(tok/s)"] = round(metrics.output_throughput, 2) - outputs["request_throughput(req/s)"] = round(metrics.request_throughput, 2) - outputs["request_goodput(req/s)"] = metrics.request_goodput + outputs["total_input_tokens"] = metrics.total_input + outputs["total_output_tokens"] = metrics.total_output outputs["completed requests"] = metrics.completed + outputs["request_throughput"] = metrics.request_throughput + outputs["request_goodput"] = metrics.request_goodput + outputs["output_throughput"] = metrics.output_throughput + outputs["total_token_throughput"] = metrics.total_token_throughput + outputs["mean_ttft_ms"] = metrics.mean_ttft_ms + for p, value in metrics.percentiles_ttft_ms: + outputs[f"p{int(p)}_ttft_ms"] = value + outputs["mean_tpot_ms"] = metrics.mean_tpot_ms + for p, value in metrics.percentiles_tpot_ms: + outputs[f"p{int(p)}_tpot_ms"] = value + outputs["mean_itl_ms"] = metrics.mean_itl_ms + for p, value in metrics.percentiles_itl_ms: + outputs[f"p{int(p)}_itl_ms"] = value + outputs["mean_e2el_ms"] = metrics.mean_e2el_ms + for p, value in metrics.percentiles_e2el_ms: + outputs[f"p{int(p)}_e2el_ms"] = value df = pandas.DataFrame([outputs]) - if os.path.isfile(excel_file): - try: - existing_df = pandas.read_excel(excel_file) - updated_df = pandas.concat([existing_df, df], ignore_index=True) - except Exception as e: - print( - f"Warning: Failed to read {excel_file}, it will be overwritten. 
Error: {e}" - ) - updated_df = df - else: - updated_df = df - # Save back to Excel (automatically create or overwrite) - with pandas.ExcelWriter(excel_file, engine="openpyxl", mode="w") as writer: - updated_df.to_excel(writer, index=False, sheet_name="Performance Metrics") + os.makedirs(os.path.dirname(excel_file), exist_ok=True) + with pandas.ExcelWriter( + excel_file, engine="openpyxl", mode="a", if_sheet_exists="replace" + ) as writer: + df.to_excel(writer, index=False, sheet_name="Metrics") print(f"Successfully saved performance metrics to {excel_file}") @@ -399,8 +396,9 @@ def save_req_results_to_file(outputs, output_dir="./"): output_path = output_dir if not os.path.exists(output_path): os.makedirs(output_path, exist_ok=True) - excel_file = os.path.join(output_path, "req_results.xlsx") + excel_file = os.path.join(output_path, "metrics.xlsx") rows = [] + # print(f"outputs: {outputs}") for output in outputs: ttft = output.ttft * 1000 if output.ttft is not None else None output_len = output.output_tokens if output.output_tokens is not None else 0 @@ -410,29 +408,54 @@ def save_req_results_to_file(outputs, output_dir="./"): if output_len > 1 and output.ttft is not None and output.latency is not None: tpot = (output.latency - output.ttft) / (output_len - 1) * 1000 row = { - "ttft(ms)": ttft, - "tpot(ms)": tpot, - "e2el(ms)": latency, - "input_tokens": input_len, - "output_tokens": output_len, - "success": output.success, + "input_lens": input_len, + "output_lens": output_len, + "ttfts_ms": ttft, + "tpot_ms": tpot, } + if output.send_time and output.running_time: + row["send_to_funning"] = output.running_time - output.send_time + if output.running_time and output.worker_time: + row["running_to_worker"] = output.worker_time - output.running_time + if output.worker_time and output.start_loadkv_time: + row["worker_to_loadkv"] = output.start_loadkv_time - output.worker_time + if output.start_loadkv_time and output.start_forward_time: + row["loadkv_duration"] = ( + output.start_forward_time - output.start_loadkv_time + ) + if output.start_forward_time and output.finish_forward_time: + row["forward_duration"] = ( + output.finish_forward_time - output.start_forward_time + ) + if output.finish_forward_time and output.finish_savekv_time: + row["savekv_duration"] = ( + output.finish_savekv_time - output.finish_forward_time + ) + if output.first_token_time and output.running_time: + row["running_to_first_token"] = ( + output.first_token_time - output.running_time + ) + row["success"] = output.success rows.append(row) + df = pandas.DataFrame(rows) - file_exists = os.path.isfile(excel_file) - if file_exists: - try: - existing_df = pandas.read_excel(excel_file) - updated_df = pandas.concat([existing_df, df], ignore_index=True) - except Exception as e: - print( - f"Warning: Failed to read {excel_file}, it will be overwritten. 
Error: {e}" - ) - updated_df = df - else: - updated_df = df - with pandas.ExcelWriter(excel_file, engine="openpyxl", mode="w") as writer: - updated_df.to_excel(writer, index=False, sheet_name="Performance Metrics") + os.makedirs(os.path.dirname(excel_file), exist_ok=True) + with pandas.ExcelWriter( + excel_file, engine="openpyxl", mode="a", if_sheet_exists="replace" + ) as writer: + df.to_excel(writer, index=False, sheet_name="details") + + +async def request_func( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, + session: aiohttp.ClientSession = None, +): + if session: + return await REQUEST_FUC( + request_func_input=request_func_input, pbar=pbar, session=session + ) + return await REQUEST_FUC(request_func_input=request_func_input, pbar=pbar) # Send requests by timestamp @@ -443,6 +466,8 @@ async def replay_trace_by_time( model_id = args.model model_name = args.served_model_name disable_tqdm = args.disable_tqdm + if backend not in ASYNC_REQUEST_FUNCS: + raise ValueError(f"Unknown backend: {backend}") if args.base_url is not None: api_url = f"{args.base_url}{args.endpoint}" @@ -451,10 +476,6 @@ async def replay_trace_by_time( api_url = f"http://{args.host}:{args.port}{args.endpoint}" base_url = f"http://{args.host}:{args.port}" - if backend not in ASYNC_REQUEST_FUNCS: - raise ValueError(f"Unknown backend: {backend}") - request_func = ASYNC_REQUEST_FUNCS[backend] - print("Starting initial single prompt test run...") test_request = None for _, reqs in sorted(req_groups.items()): @@ -463,6 +484,18 @@ async def replay_trace_by_time( break if test_request is None: raise ValueError("No request found for initial test run.") + + use_session = tuple(map(int, vllm.__version__.split(".")[:3])) >= (0, 10, 1) + session = None + if use_session: + session = vllm.aiohttp.ClientSession( + base_url=base_url, + trust_env=True, + timeout=aiohttp.ClientTimeout(total=6 * 60 * 60), + ) + global REQUEST_FUC + REQUEST_FUC = ASYNC_REQUEST_FUNCS[backend] + test_input = RequestFuncInput( model=model_id, model_name=model_name, @@ -476,7 +509,7 @@ async def replay_trace_by_time( extra_body={"temperature": 0.9}, ) - test_output = await request_func(request_func_input=test_input) + test_output = await request_func(request_func_input=test_input, session=session) if not getattr(test_output, "success", False): raise ValueError( @@ -513,9 +546,13 @@ async def _run_one_request(sample_req): ) if semaphore is not None: async with semaphore: - return await request_func(request_func_input=req_input, pbar=pbar) + return await request_func( + request_func_input=req_input, session=session, pbar=pbar + ) else: - return await request_func(request_func_input=req_input, pbar=pbar) + return await request_func( + request_func_input=req_input, session=session, pbar=pbar + ) for sec, reqs in sorted(req_groups.items()): delay = sec - (time.perf_counter() - start_time) @@ -540,13 +577,28 @@ async def send_group(r=reqs, d=delay): flat_requests.extend(reqs) group_results = await asyncio.gather(*tasks) + + if pbar is not None: + pbar.close() + if use_session: + await session.close() + outputs = [] for res in group_results: if isinstance(res, list): outputs.extend(res) - if pbar is not None: - pbar.close() + percentile_metrics = ( + args.metric_percentiles.split(",") + if args.metric_percentiles + else ["ttft", "tpot", "itl", "e2el"] + ) + metric_percentiles = ( + [int(x) for x in args.metric_percentiles.split(",")] + if args.metric_percentiles + else [50, 90] + ) + goodput = {k: float(v) for item in args.goodput for k, v 
in [item.split(":", 1)]} benchmark_duration = time.perf_counter() - start_time metrics, actual_output_lens = calculate_metrics( @@ -554,9 +606,8 @@ async def send_group(r=reqs, d=delay): outputs=outputs, dur_s=benchmark_duration, tokenizer=tokenizer, - selected_percentile_metrics=["ttft", "tpot", "itl", "e2el"], - selected_percentiles=[25.0, 50.0, 75.0, 99.0], - goodput_config_dict={"ttft": 2000, "tpot": 50}, + selected_percentiles=metric_percentiles, + goodput_config_dict=goodput, ) print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) @@ -586,7 +637,7 @@ def process_one_metric( metric_name: str, metric_header: str, ): - selected_percentile_metrics = ["ttft", "tpot", "itl", "e2el"] + selected_percentile_metrics = percentile_metrics if metric_attribute_name not in selected_percentile_metrics: return print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-")) @@ -621,13 +672,19 @@ def process_one_metric( output_dir = args.result_dir if args.result_dir is not None else "./" if args.save_result: - save_metrics_to_file(metrics=metrics, output_dir=output_dir) + save_metrics_to_file( + metrics=metrics, + metric_percentiles=metric_percentiles, + output_dir=output_dir, + ) save_req_results_to_file(outputs=outputs, output_dir=output_dir) return def create_argument_trace(): - parser = create_argument_parser() + parser = argparse.ArgumentParser(description="Benchmark LLM serving performance") + add_cli_args(parser) + trace_group = parser.add_argument_group("tracing parameters") trace_group.add_argument( "--trace-path", @@ -677,6 +734,14 @@ def main(args: argparse.Namespace): if __name__ == "__main__": + # Check openpyxl for Excel export + try: + import openpyxl + except ImportError: + print("\nMissing package: openpyxl") + print("Please install openpyxl via pip install.\n") + sys.exit(1) + parser = create_argument_trace() args = parser.parse_args() main(args) diff --git a/docker/Dockerfile b/docker/Dockerfile index 622d8fc1b..fdebcb997 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -6,15 +6,12 @@ ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" WORKDIR /workspace # Install unified-cache-management -COPY . /vllm-workspace/unified-cache-management +COPY . /workspace/unified-cache-management RUN pip config set global.index-url ${PIP_INDEX_URL} RUN export PLATFORM="cuda" && \ - pip install -v -e /vllm-workspace/unified-cache-management --no-build-isolation + pip install -v -e /workspace/unified-cache-management --no-build-isolation -# Apply patch for vLLM -RUN cd $(pip show vllm | grep Location | awk '{print $2}') \ - && git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch ENTRYPOINT ["/bin/bash"] \ No newline at end of file diff --git a/docker/Dockerfile-NPU b/docker/Dockerfile-NPU index e03c45dc3..270a63c95 100644 --- a/docker/Dockerfile-NPU +++ b/docker/Dockerfile-NPU @@ -6,22 +6,12 @@ ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" WORKDIR /workspace # Install unified-cache-management -COPY . /vllm-workspace/unified-cache-management +COPY . 
/workspace/unified-cache-management RUN pip config set global.index-url ${PIP_INDEX_URL} RUN export PLATFORM="ascend" && \ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ - pip install -v -e /vllm-workspace/unified-cache-management --no-build-isolation - -# Apply patch for vLLM -RUN cd /vllm-workspace/vllm \ - && git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch \ - && git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch - -# Apply patch for vLLM-Ascend -RUN cd /vllm-workspace/vllm-ascend \ - && git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch \ - && git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt-sparse.patch + pip install -v -e /workspace/unified-cache-management --no-build-isolation CMD ["/bin/bash"] \ No newline at end of file diff --git a/docs/source/_static/images/qrcode_for_wechat.png b/docs/source/_static/images/qrcode_for_wechat.png new file mode 100644 index 000000000..b456bdabd Binary files /dev/null and b/docs/source/_static/images/qrcode_for_wechat.png differ diff --git a/docs/source/getting-started/installation_gpu.md b/docs/source/getting-started/installation_gpu.md index 2ab51d705..eaf1d3d05 100644 --- a/docs/source/getting-started/installation_gpu.md +++ b/docs/source/getting-started/installation_gpu.md @@ -33,6 +33,13 @@ docker run \ ``` Refer to [Set up using docker](https://docs.vllm.ai/en/latest/getting_started/installation/gpu.html#set-up-using-docker) for more information to run your own vLLM container. +### Install by pip +Install by pip or find the pre-build wheels on [Pypi](https://pypi.org/project/uc-manager/). +``` +pip install uc-manager +``` + + ### Build from source code Follow commands below to install unified-cache-management: @@ -44,17 +51,12 @@ export PLATFORM=cuda pip install -v -e . --no-build-isolation ``` -After installation, please apply patch to ensure uc_connector can be used: +**Note:** Patches are now applied automatically via dynamic patching when you import the unified-cache-management package. You no longer need to manually apply patches using `git apply`. The patches are automatically applied when you use the `UnifiedCacheConnectorV1` connector. -```bash -cd $(pip show vllm | grep Location | awk '{print $2}') -git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch -git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch -``` - -Refer to this [issue](https://github.com/vllm-project/vllm/issues/21702) to see details of this patch's changes. ## Setup from docker + +### Build image from source Download the pre-built `vllm/vllm-openai:v0.9.2` docker image and build unified-cache-management docker image by commands below: ```bash # Build docker image using source code, replace with the branch or tag name needed @@ -62,6 +64,14 @@ Download the pre-built `vllm/vllm-openai:v0.9.2` docker image and build unified- cd unified-cache-management docker build -t ucm-vllm:latest -f ./docker/Dockerfile ./ ``` + + +### Pre-built images + +```bash +docker pull unifiedcachemanager/ucm:latest +``` + Then run your container using following command. You can add or remove Docker parameters as needed. ```bash # Use `--ipc=host` to make sure the shared memory is large enough. 
diff --git a/docs/source/getting-started/installation_npu.md b/docs/source/getting-started/installation_npu.md index 0824de0c3..571e96e15 100644 --- a/docs/source/getting-started/installation_npu.md +++ b/docs/source/getting-started/installation_npu.md @@ -39,13 +39,10 @@ docker run --rm \ -v /root/.cache:/root/.cache \ -it $IMAGE bash ``` -Codes of vLLM and vLLM Ascend are placed in /vllm-workspace, you can refer to [vLLM-Ascend Installation](https://vllm-ascend.readthedocs.io/en/latest/installation.html) for more information. After installation, please apply patches to ensure uc_connector can be used: -```bash -cd /vllm-workspace/vllm -git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch -cd /vllm-workspace/vllm-ascend -git apply /vllm-workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch -``` +Codes of vLLM and vLLM Ascend are placed in /vllm-workspace, you can refer to [vLLM-Ascend Installation](https://vllm-ascend.readthedocs.io/en/latest/installation.html) for more information. + +**Note:** For vLLM and vLLM Ascend patches, they are now applied automatically via dynamic patching when you import the unified-cache-management package. + Refer to these issues [vllm-issue](https://github.com/vllm-project/vllm/issues/21702) and [vllm-ascend-issue](https://github.com/vllm-project/vllm-ascend/issues/2057) to see details of patches' changes. ### Build from source code @@ -60,7 +57,7 @@ cd .. ``` ## Setup from docker -Download the pre-built docker image provided or build unified-cache-management docker image by commands below: +Download the pre-built `vllm-ascend` docker image or build unified-cache-management docker image by commands below: ```bash # Build docker image using source code, replace with the branch or tag name needed git clone --depth 1 --branch https://github.com/ModelEngine-Group/unified-cache-management.git diff --git a/docs/source/getting-started/quick_start.md b/docs/source/getting-started/quick_start.md index 5ad56b289..098c2eeb7 100644 --- a/docs/source/getting-started/quick_start.md +++ b/docs/source/getting-started/quick_start.md @@ -21,14 +21,16 @@ Before you start with UCM, please make sure that you have installed UCM correctl ## Features Overview -UCM supports two key features: **Prefix Cache** and **GSA Sparsity**. +UCM supports two key features: **Prefix Cache** and **Sparse attention**. Each feature supports both **Offline Inference** and **Online API** modes. For quick start, just follow the [usage](./quick_start.md) guide below to launch your own inference experience; -For further research, click on the links blow to see more details of each feature: +For further research on Prefix Cache, more details are available via the link below: - [Prefix Cache](../user-guide/prefix-cache/index.md) + +Various Sparse Attention features are now available, try GSA Sparsity via the link below: - [GSA Sparsity](../user-guide/sparse-attention/gsa.md) ## Usage @@ -40,12 +42,14 @@ You can use our official offline example script to run offline inference as foll ```bash cd examples/ +# Change the model path to your own model path +export MODEL_PATH=/home/models/Qwen2.5-14B-Instruct python offline_inference.py ``` -
+
OpenAI-Compatible Online API For online inference , vLLM with our connector can also be deployed as a server that implements the OpenAI API protocol. @@ -55,10 +59,23 @@ First, specify the python hash seed by: export PYTHONHASHSEED=123456 ``` -Run the following command to start the vLLM server with the Qwen/Qwen2.5-14B-Instruct model: +Create a config yaml like following and save it to your own directory: +```yaml +# UCM Configuration File Example +# Refer to file unified-cache-management/examples/ucm_config_example.yaml for more details +ucm_connector_name: "UcmNfsStore" + +ucm_connector_config: + storage_backends: "/mnt/test" +``` + +Run the following command to start the vLLM server with the Qwen/Qwen2.5-14B-Instruct model and your config file path: ```bash -vllm serve /home/models/Qwen2.5-14B-Instruct \ +# Change the model path to your own model path +export MODEL_PATH=/home/models/Qwen2.5-14B-Instruct +vllm serve ${MODEL_PATH} \ +--served-model-name vllm_cpu_offload \ --max-model-len 20000 \ --tensor-parallel-size 2 \ --gpu_memory_utilization 0.87 \ @@ -66,15 +83,11 @@ vllm serve /home/models/Qwen2.5-14B-Instruct \ --port 7800 \ --kv-transfer-config \ '{ - "kv_connector": "UnifiedCacheConnectorV1", - "kv_connector_module_path": "ucm.integration.vllm.uc_connector", + "kv_connector": "UCMConnector", + "kv_connector_module_path": "ucm.integration.vllm.ucm_connector", "kv_role": "kv_both", "kv_connector_extra_config": { - "ucm_connector_name": "UcmDramStore", - "ucm_connector_config": { - "max_cache_size": 5368709120, - "kv_block_size": 262144 - } + "UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml" } }' ``` @@ -95,7 +108,7 @@ After successfully started the vLLM server,You can interact with the API as fo curl http://localhost:7800/v1/completions \ -H "Content-Type: application/json" \ -d '{ - "model": "/home/models/Qwen2.5-14B-Instruct", + "model": "vllm_cpu_offload", "prompt": "Shanghai is a", "max_tokens": 7, "temperature": 0 @@ -103,3 +116,4 @@ curl http://localhost:7800/v1/completions \ ```
+Note: If you want to disable vLLM prefix cache to test the cache ability of UCM, you can add `--no-enable-prefix-caching` to the command line. \ No newline at end of file diff --git a/docs/source/index.md b/docs/source/index.md index 2352d3996..69be815e6 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -57,6 +57,7 @@ getting-started/installation_npu user-guide/prefix-cache/index user-guide/sparse-attention/index user-guide/pd-disaggregation/index +user-guide/metrics/metrics ::: :::{toctree} diff --git a/docs/source/user-guide/metrics/metrics.md b/docs/source/user-guide/metrics/metrics.md new file mode 100644 index 000000000..22b532681 --- /dev/null +++ b/docs/source/user-guide/metrics/metrics.md @@ -0,0 +1,193 @@ +# Observability + +UCM (Unified Cache Management) provides detailed metrics monitoring through Prometheus endpoints, allowing in-depth monitoring of cache performance and behavior. This document describes how to enable and configure observability from the embedded vLLM `/metrics` API endpoint. + +--- + +## Quick Start Guide + +### 1) On UCM Side + +First, set the `PROMETHEUS_MULTIPROC_DIR` environment variable. + +```bash +export PROMETHEUS_MULTIPROC_DIR=/vllm-workspace +``` + +Then, start the UCM service. + +```bash +export CUDA_VISIBLE_DEVICES=0 +vllm serve /home/models/Qwen2.5-14B-Instruct \ + --max-model-len 5000 \ + --tensor-parallel-size 1 \ + --gpu_memory_utilization 0.87 \ + --trust-remote-code \ + --disable-log-requests \ + --no-enable-prefix-caching \ + --enforce-eager \ + --max-num-batched-tokens 40000 \ + --max-num-seqs 10 \ + --host 0.0.0.0 \ + --port 8000 \ + --kv-transfer-config \ + '{ + "kv_connector": "UCMConnector", + "kv_connector_module_path": "ucm.integration.vllm.ucm_connector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "UCM_CONFIG_FILE": "/vllm-workspace/unified-cache-management/examples/ucm_config.yaml" + } + }' +``` +**Note**: You can refer to the `ucm_config.yaml` file at https://github.com/ModelEngine-Group/unified-cache-management/tree/develop/examples to configure the `metrics_config_path` parameter. + +You can use the `vllm bench serve` command to run benchmarks: + +```bash +vllm bench serve \ + --backend vllm \ + --model /home/models/Qwen2.5-14B-Instruct \ + --host 127.0.0.1 \ + --port 8000 \ + --dataset-name random \ + --num-prompts 20 \ + --random-input-len 200 \ + --random-output-len 10 \ + --request-rate 1 \ + --ignore-eos +``` + +Once the HTTP server is running, you can access the UCM metrics at the `/metrics` endpoint. + +```bash +curl http://$:8000/metrics | grep ucm: +``` + +You will also find some `.db` files in the `$PROMETHEUS_MULTIPROC_DIR` directory, which are temporary files used by Prometheus. 
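Before wiring up Prometheus, it can help to confirm that the `ucm:` series are actually being exported. The short script below is an illustrative sketch, not part of UCM: it scrapes the `/metrics` endpoint and prints the exported `ucm:` samples, assuming the server runs on `127.0.0.1:8000` and uses the default `ucm:` metric prefix as in the commands above.

```python
# Illustrative sketch: list the ucm: metric samples exposed by the running vLLM server.
# The host/port and the "ucm:" prefix are assumptions taken from the example above;
# adjust them to match your deployment.
import urllib.request

METRICS_URL = "http://127.0.0.1:8000/metrics"

with urllib.request.urlopen(METRICS_URL, timeout=10) as resp:
    body = resp.read().decode("utf-8")

# Comment lines in the Prometheus text format start with "#", so filtering on the
# "ucm:" prefix keeps only the actual samples.
ucm_samples = [line for line in body.splitlines() if line.startswith("ucm:")]

for line in ucm_samples:
    print(line)
print(f"Found {len(ucm_samples)} ucm: samples")
```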
+ +### 2) Start Prometheus and Grafana with Docker Compose + +#### Create Docker Compose Configuration Files + +First, create the `docker-compose.yaml` file: + +```yaml +# docker-compose.yaml +version: "3" + +services: + prometheus: + image: prom/prometheus:latest + extra_hosts: + - "host.docker.internal:host-gateway" + ports: + - "9090:9090" + volumes: + - ${PWD}/prometheus.yaml:/etc/prometheus/prometheus.yml + + grafana: + image: grafana/grafana:latest + depends_on: + - prometheus + ports: + - "3000:3000" +``` + +Then, create the `prometheus.yaml` configuration file: + +```yaml +# prometheus.yaml +global: + scrape_interval: 5s + evaluation_interval: 30s + +scrape_configs: + - job_name: vllm + static_configs: + - targets: + - 'host.docker.internal:8000' +``` + +**Note**: Make sure the port number in `prometheus.yaml` matches the port number used when starting the vLLM service. + +#### Start Services + +Run the following command in the directory containing `docker-compose.yaml` and `prometheus.yaml`: + +```bash +docker compose up +``` + +This will start Prometheus and Grafana services. + +### 3) Configure Grafana Dashboard + +#### Access Grafana + +Navigate to `http://:3000`. Log in with the default username (`admin`) and password (`admin`). You will be prompted to change the password on first login. + +#### Add Prometheus Data Source + +1. Navigate to `http://:3000/connections/datasources/new` and select **Prometheus**. + +2. On the Prometheus configuration page, add the Prometheus server URL in the **Connection** section. For this Docker Compose setup, Grafana and Prometheus run in separate containers, but Docker creates DNS names for each container. You can directly use `http://prometheus:9090`. + +3. Click **Save & Test**. You should see a green checkmark showing "Successfully queried the Prometheus API." + +#### Import Dashboard + +1. Navigate to `http://:3000/dashboard/import`. + +2. Click **Upload JSON file**, then upload the `unified-cache-management/examples/metrics/grafana.json` file. + +3. Select the Prometheus data source configured earlier. + +4. Click **Import** to complete the import. + +You should now be able to see the UCM monitoring dashboard with real-time visualization of all 9 metrics. + +## Available Metrics + +UCM exposes various metrics to monitor its performance. 
The following table lists all available metrics organized by category: + +| Metric Name | Type | Description | +|------------|------|-------------| +| **Load Operation Metrics** | | | +| `ucm:load_requests_num` | Histogram | Number of requests loaded per `start_load_kv` call | +| `ucm:load_blocks_num` | Histogram | Number of blocks loaded per `start_load_kv` call | +| `ucm:load_duration` | Histogram | Time to load KV cache from UCM (milliseconds) | +| `ucm:load_speed` | Histogram | Speed of loading from UCM (GB/s) | +| **Save Operation Metrics** | | | +| `ucm:save_requests_num` | Histogram | Number of requests saved per `wait_for_save` call | +| `ucm:save_blocks_num` | Histogram | Number of blocks saved per `wait_for_save` call | +| `ucm:save_duration` | Histogram | Time to save to UCM (milliseconds) | +| `ucm:save_speed` | Histogram | Speed of saving to UCM (GB/s) | +| **Lookup Hit Rate Metrics** | | | +| `ucm:interval_lookup_hit_rates` | Histogram | Hit rate of UCM lookup requests | + +## Prometheus Configuration + +Metrics configuration is defined in the `unified-cache-management/examples/metrics/metrics_configs.yaml` file: + +```yaml +log_interval: 5 # Interval in seconds for logging metrics + +prometheus: + multiproc_dir: "/vllm-workspace" # Prometheus directory + metric_prefix: "ucm:" # Metric name prefix + + enabled_metrics: + counters: true + gauges: true + histograms: true + + histograms: + - name: "load_requests_num" + documentation: "Number of requests loaded from ucm" + buckets: [1, 5, 10, 20, 50, 100, 200, 500, 1000] + # ... other metric configurations +``` + +--- + diff --git a/docs/source/user-guide/pd-disaggregation/1p1d.md b/docs/source/user-guide/pd-disaggregation/1p1d.md index a6e2547f4..fb3f4d056 100644 --- a/docs/source/user-guide/pd-disaggregation/1p1d.md +++ b/docs/source/user-guide/pd-disaggregation/1p1d.md @@ -5,16 +5,17 @@ This example demonstrates how to run unified-cache-management with disaggregated ## Prerequisites - UCM: Installed with reference to the Installation documentation. -- Hardware: At least 2 GPUs +- Hardware: At least 2 GPUs or 2 NPUs ## Start disaggregated service -For illustration purposes, let us assume that the model used is Qwen2.5-7B-Instruct. +For illustration purposes, let us take GPU as an example and assume the model used is Qwen2.5-7B-Instruct.Using ASCEND_RT_VISIBLE_DEVICES instead of CUDA_VISIBLE_DEVICES to specify visible devices when starting service on Ascend platform. ### Run prefill server Prefiller Launch Command: ```bash export PYTHONHASHSEED=123456 -CUDA_VISIBLE_DEVICES=0 vllm serve /home/models/Qwen2.5-7B-Instruct \ +export CUDA_VISIBLE_DEVICES=0 +vllm serve /home/models/Qwen2.5-7B-Instruct \ --max-model-len 20000 \ --tensor-parallel-size 1 \ --gpu_memory_utilization 0.87 \ @@ -41,8 +42,9 @@ CUDA_VISIBLE_DEVICES=0 vllm serve /home/models/Qwen2.5-7B-Instruct \ ### Run decode server Decoder Launch Command: ```bash -export PYTHONHASHSEED=123456 -CUDA_VISIBLE_DEVICES=1 vllm serve /home/models/Qwen2.5-7B-Instruct \ +export PYTHONHASHSEED=123456 +export CUDA_VISIBLE_DEVICES=0 +vllm serve /home/models/Qwen2.5-7B-Instruct \ --max-model-len 20000 \ --tensor-parallel-size 1 \ --gpu_memory_utilization 0.87 \ @@ -68,8 +70,8 @@ CUDA_VISIBLE_DEVICES=1 vllm serve /home/models/Qwen2.5-7B-Instruct \ ### Run proxy server Make sure prefill nodes and decode nodes can connect to each other. 
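A quick way to verify connectivity is to probe the `/health` endpoint on both nodes before launching the proxy; the snippet below is an illustrative sketch that assumes the prefiller and decoder listen on ports 7800 and 7801 as in the launch commands above (replace the hosts with your actual addresses). Once both respond with HTTP 200, start the proxy:

```python
# Illustrative connectivity check (not part of UCM): probe the /health endpoint
# exposed by each vLLM server before starting the proxy. Replace the hosts with
# the real prefiller/decoder addresses.
import urllib.request

NODES = {
    "prefiller": "http://127.0.0.1:7800/health",
    "decoder": "http://127.0.0.1:7801/health",
}

for name, url in NODES.items():
    try:
        with urllib.request.urlopen(url, timeout=5) as resp:
            print(f"{name}: HTTP {resp.status} ({url})")
    except OSError as exc:
        print(f"{name}: unreachable ({url}): {exc}")
```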
```bash -cd vllm-workspace/unified-cache-management/test/ -python3 toy_proxy_server.py --host localhost --port 7802 --prefiller-host --prefiller-port 7800 --decoder-host --decoder-port 7801 +cd /vllm-workspace/unified-cache-management/ucm/pd +python3 toy_proxy_server.py --pd-disaggregation --host localhost --port 7802 --prefiller-host --prefiller-port 7800 --decoder-host --decoder-port 7801 ``` ## Testing and Benchmarking @@ -88,8 +90,7 @@ curl http://localhost:7802/v1/completions \ ### Benchmark Test Use the benchmark scripts provided by vLLM. ```bash -cd /vllm-workspace/vllm/benchmarks -python3 benchmark_serving.py \ +vllm bench serve \ --backend vllm \ --dataset-name random \ --random-input-len 4096 \ diff --git a/docs/source/user-guide/pd-disaggregation/npgd.md b/docs/source/user-guide/pd-disaggregation/npgd.md index e7dd10bf7..c4919779a 100644 --- a/docs/source/user-guide/pd-disaggregation/npgd.md +++ b/docs/source/user-guide/pd-disaggregation/npgd.md @@ -50,7 +50,8 @@ vllm serve /home/models/Qwen2.5-7B-Instruct \ Decoder Launch Command: ```bash export PYTHONHASHSEED=123456 -CUDA_VISIBLE_DEVICES=0 vllm serve /home/models/Qwen2.5-7B-Instruct \ +export CUDA_VISIBLE_DEVICES=0 +vllm serve /home/models/Qwen2.5-7B-Instruct \ --max-model-len 20000 \ --tensor-parallel-size 1 \ --gpu_memory_utilization 0.87 \ @@ -77,7 +78,7 @@ CUDA_VISIBLE_DEVICES=0 vllm serve /home/models/Qwen2.5-7B-Instruct \ ### Run proxy server Make sure prefill nodes and decode nodes can connect to each other. ```bash -cd vllm-workspace/unified-cache-management/test/ +cd /vllm-workspace/unified-cache-management/ucm/pd python3 toy_proxy_server.py --host localhost --port 7802 --prefiller-host --prefiller-port 7800 --decoder-host --decoder-port 7801 ``` @@ -97,8 +98,7 @@ curl http://localhost:7802/v1/completions \ ### Benchmark Test Use the benchmark scripts provided by vLLM. ```bash -cd /vllm-workspace/vllm/benchmarks -python3 benchmark_serving.py \ +vllm bench serve \ --backend vllm \ --dataset-name random \ --random-input-len 4096 \ diff --git a/docs/source/user-guide/pd-disaggregation/xpyd.md b/docs/source/user-guide/pd-disaggregation/xpyd.md index c7fde1e9f..a57ab5d2f 100644 --- a/docs/source/user-guide/pd-disaggregation/xpyd.md +++ b/docs/source/user-guide/pd-disaggregation/xpyd.md @@ -5,15 +5,17 @@ This example demonstrates how to run unified-cache-management with disaggregated ## Prerequisites - UCM: Installed with reference to the Installation documentation. -- Hardware: At least 4 GPUs (At least 2 GPUs for prefiller + 2 for decoder in 2d2p setup) +- Hardware: At least 4 GPUs (At least 2 GPUs for prefiller + 2 for decoder in 2d2p setup or 2 NPUs for prefiller + 2 for decoder in 2d2p setup) ## Start disaggregated service -For illustration purposes, let us assume that the model used is Qwen2.5-7B-Instruct. +For illustration purposes, let us take GPU as an example and assume the model used is Qwen2.5-7B-Instruct.Using ASCEND_RT_VISIBLE_DEVICES instead of CUDA_VISIBLE_DEVICES to specify visible devices when starting service on Ascend platform. 
+ ### Run prefill servers Prefiller1 Launch Command: ```bash export PYTHONHASHSEED=123456 -CUDA_VISIBLE_DEVICES=0 vllm serve /home/models/Qwen2.5-7B-Instruct \ +export CUDA_VISIBLE_DEVICES=0 +vllm serve /home/models/Qwen2.5-7B-Instruct \ --max-model-len 20000 \ --tensor-parallel-size 1 \ --gpu_memory_utilization 0.87 \ @@ -40,7 +42,8 @@ CUDA_VISIBLE_DEVICES=0 vllm serve /home/models/Qwen2.5-7B-Instruct \ Prefiller2 Launch Command: ```bash export PYTHONHASHSEED=123456 -CUDA_VISIBLE_DEVICES=1 vllm serve /home/models/Qwen2.5-7B-Instruct \ +export CUDA_VISIBLE_DEVICES=1 +vllm serve /home/models/Qwen2.5-7B-Instruct \ --max-model-len 20000 \ --tensor-parallel-size 1 \ --gpu_memory_utilization 0.87 \ @@ -68,7 +71,8 @@ CUDA_VISIBLE_DEVICES=1 vllm serve /home/models/Qwen2.5-7B-Instruct \ Decoder1 Launch Command: ```bash export PYTHONHASHSEED=123456 -CUDA_VISIBLE_DEVICES=2 vllm serve /home/models/Qwen2.5-7B-Instruct \ +export CUDA_VISIBLE_DEVICES=2 +vllm serve /home/models/Qwen2.5-7B-Instruct \ --max-model-len 20000 \ --tensor-parallel-size 1 \ --gpu_memory_utilization 0.87 \ @@ -94,7 +98,8 @@ CUDA_VISIBLE_DEVICES=2 vllm serve /home/models/Qwen2.5-7B-Instruct \ Decoder2 Launch Command: ```bash export PYTHONHASHSEED=123456 -CUDA_VISIBLE_DEVICES=3 vllm serve /home/models/Qwen2.5-7B-Instruct \ +export CUDA_VISIBLE_DEVICES=3 +vllm serve /home/models/Qwen2.5-7B-Instruct \ --max-model-len 20000 \ --tensor-parallel-size 1 \ --gpu_memory_utilization 0.87 \ @@ -121,8 +126,8 @@ CUDA_VISIBLE_DEVICES=3 vllm serve /home/models/Qwen2.5-7B-Instruct \ ### Run proxy server Make sure prefill nodes and decode nodes can connect to each other. the number of prefill/decode hosts should be equal to the number of prefill/decode ports. ```bash -cd vllm-workspace/unified-cache-management/test/ -python3 toy_proxy_server.py --host localhost --port 7805 --prefiller-hosts --prefiller-port 7800 7801 --decoder-hosts --decoder-ports 7802 7803 +cd /vllm-workspace/unified-cache-management/ucm/pd +python3 toy_proxy_server.py --pd-disaggregation --host localhost --port 7805 --prefiller-hosts --prefiller-port 7800 7801 --decoder-hosts --decoder-ports 7802 7803 ``` ## Testing and Benchmarking @@ -141,8 +146,7 @@ curl http://localhost:7805/v1/completions \ ### Benchmark Test Use the benchmark scripts provided by vLLM. ```bash -cd /vllm-workspace/vllm/benchmarks -python3 benchmark_serving.py \ +vllm bench serve \ --backend vllm \ --dataset-name random \ --random-input-len 4096 \ diff --git a/docs/source/user-guide/prefix-cache/dram_store.md b/docs/source/user-guide/prefix-cache/dram_store.md deleted file mode 100644 index 1be2f30a2..000000000 --- a/docs/source/user-guide/prefix-cache/dram_store.md +++ /dev/null @@ -1,133 +0,0 @@ -# DRAM Store - -This document provides a usage example and configuration guide for the **DRAM Connector**. This connector enables offloading of KV cache from GPU HBM to CPU DRAM, helping reduce memory pressure and supporting larger models or batch sizes. - -## Performance - -### Overview -The following are the multi-concurrency performance test results of UCM in the Prefix Cache scenario under a CUDA environment, showing the performance improvements of UCM on two different models. -During the tests, HBM cache was disabled, and KV Cache was retrieved and matched only from DRAM. - -In the QwQ-32B model, the test used one H20 server with 2 GPUs. - -Here, Full Compute refers to pure VLLM inference, while DRAM80% indicates that after UCM pooling, the DRAM hit rate of the KV cache is 80%. 
- -The following table shows the results on the QwQ-32B model: -| **QwQ-32B** | | | | | -| ---------------: | -------------: | ------------------: | -------------: | :----------- | -| **Input length** | **Concurrent** | **Full Compute(s)** | **DRAM80%(s)** | **Speedup** | -| 4 000 | 1 | 1.0269 | 0.3102 | **+230.9 %** | -| 8 000 | 1 | 2.0902 | 0.5718 | **+265.5 %** | -| 16 000 | 1 | 4.4852 | 1.1914 | **+276.4 %** | -| 4 000 | 2 | 1.5383 | 0.4209 | **+265.4 %** | -| 8 000 | 2 | 3.1323 | 0.8231 | **+280.5 %** | -| 16 000 | 2 | 6.7984 | 1.7420 | **+290.2 %** | -| 4 000 | 4 | 2.8173 | 0.9444 | **+198.2 %** | -| 8 000 | 4 | 5.2643 | 1.8290 | **+187.8 %** | -| 16 000 | 4 | 11.3651 | 3.6706 | **+209.6 %** | -## Features - -The DRAM connector supports the following functionalities: - -- `dump`: Offload KV cache blocks from HBM to DRAM. -- `load`: Load KV cache blocks from DRAM back to HBM. -- `lookup`: Look up KV blocks stored in DRAM by block hash. -- `wait`: Ensure that all copy streams between CPU and GPU have completed. -- `commit`: Mark cache operations as complete and ready for reuse. - -## Configuration - -To use the DRAM connector, you need to configure the `connector_config` dictionary in your model's launch configuration. - -### Required Parameters - -- `max_cache_size` *(optional)*: - Specifies the maximum allowed DRAM memory usage (in **bytes**) for caching in `kv_connector_extra_config["ucm_connector_config"]`. - If not provided, it defaults to **5 GB**. -- `kv_block_size` *(optional)*: - Specifies the memory size (in **bytes**) of a single key or value cache block used in vLLM’s paged attention mechanism, which is calculated as : `block_size * head_size * total_num_kv_heads * element_size`. - -### Example: - -```python -# Allocate up to 8GB DRAM for KV cache -# KV Block size (in byte) is 262144 -kv_connector_extra_config={"ucm_connector_name": "UcmDramStore", "ucm_connector_config":{"max_cache_size": 5368709120, "kv_block_size": 262144}} -``` - -## Launching Inference - -### Offline Inference - -To start **offline inference** with the DRAM connector,modify the script `examples/offline_inference.py` to include the `kv_connector_extra_config` for DRAM connector usage: - -```python -# In examples/offline_inference.py -ktc = KVTransferConfig( - ... - kv_connector_extra_config={"ucm_connector_name": "UcmDramStore", "ucm_connector_config":{"max_cache_size": 5368709120, "kv_block_size": 262144}} -) -``` - -Then run the script as follows: - -```bash -cd examples/ -python offline_inference.py -``` - -### Online Inference - -For **online inference** , vLLM with our connector can also be deployed as a server that implements the OpenAI API protocol. - -First, specify the python hash seed by: -```bash -export PYTHONHASHSEED=123456 -``` - -Run the following command to start the vLLM server with the Qwen/Qwen2.5-14B-Instruct model: - -```bash -vllm serve /home/models/Qwen2.5-14B-Instruct \ ---max-model-len 20000 \ ---tensor-parallel-size 2 \ ---gpu_memory_utilization 0.87 \ ---trust-remote-code \ ---port 7800 \ ---kv-transfer-config \ -'{ - "kv_connector": "UnifiedCacheConnectorV1", - "kv_connector_module_path": "ucm.integration.vllm.uc_connector", - "kv_role": "kv_both", - "kv_connector_extra_config": { - "ucm_connector_name": "UcmDramStore", - "ucm_connector_config": { - "max_cache_size": 5368709120, - "kv_block_size": 262144 - } - } -}' -``` - -If you see log as below: - -```bash -INFO: Started server process [32890] -INFO: Waiting for application startup. -INFO: Application startup complete. 
-``` - -Congratulations, you have successfully started the vLLM server with DRAM Connector! - -After successfully started the vLLM server,You can interact with the API as following: - -```bash -curl http://localhost:7800/v1/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "/home/models/Qwen2.5-14B-Instruct", - "prompt": "Shanghai is a", - "max_tokens": 7, - "temperature": 0 - }' -``` diff --git a/docs/source/user-guide/prefix-cache/index.md b/docs/source/user-guide/prefix-cache/index.md index defe27d38..ba3d16bef 100644 --- a/docs/source/user-guide/prefix-cache/index.md +++ b/docs/source/user-guide/prefix-cache/index.md @@ -79,6 +79,5 @@ performance. :::{toctree} :maxdepth: 1 -dram_store nfs_store ::: \ No newline at end of file diff --git a/docs/source/user-guide/prefix-cache/nfs_store.md b/docs/source/user-guide/prefix-cache/nfs_store.md index b581acf56..741fcedf7 100644 --- a/docs/source/user-guide/prefix-cache/nfs_store.md +++ b/docs/source/user-guide/prefix-cache/nfs_store.md @@ -87,8 +87,15 @@ To use the NFS connector, you need to configure the `connector_config` dictionar ### Example: -```python -kv_connector_extra_config={"ucm_connector_name": "UcmNfsStore", "ucm_connector_config":{"storage_backends": "/mnt/test1", "transferStreamNumber": 32}} +Create a config yaml like following and save it to your own directory: +```yaml +# UCM Configuration File Example +# Refer to file unified-cache-management/examples/ucm_config_example.yaml for more details +ucm_connector_name: "UcmNfsStore" + +ucm_connector_config: + storage_backends: "/mnt/test" + transferStreamNumber: 32 ``` ## Launching Inference @@ -101,7 +108,7 @@ To start **offline inference** with the NFS connector,modify the script `examp # In examples/offline_inference.py ktc = KVTransferConfig( ... 
- kv_connector_extra_config={"ucm_connector_name": "UcmNfsStore", "ucm_connector_config":{"storage_backends": "/mnt/test1", "transferStreamNumber": 32}} + kv_connector_extra_config={"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"} ) ``` @@ -131,13 +138,7 @@ vllm serve /home/models/Qwen2.5-14B-Instruct \ "kv_connector": "UnifiedCacheConnectorV1", "kv_connector_module_path": "ucm.integration.vllm.uc_connector", "kv_role": "kv_both", - "kv_connector_extra_config": { - "ucm_connector_name": "UcmNfsStore", - "ucm_connector_config": { - "storage_backends": "/mnt/test", - "transferStreamNumber":32 - } - } + "kv_connector_extra_config": {"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"} }' ``` diff --git a/docs/source/user-guide/sparse-attention/esa.md b/docs/source/user-guide/sparse-attention/esa.md index 3d42ba52e..53beadf10 100644 --- a/docs/source/user-guide/sparse-attention/esa.md +++ b/docs/source/user-guide/sparse-attention/esa.md @@ -9,6 +9,7 @@ ESA provides developers with an intuitive example of how to implement their own ### Basic Usage ESA can be launched using the following command: ```shell +export ENABLE_SPARSE=TRUE export MODEL_PATH="/path/to/model" # For example: /home/models/Qwen2.5-14B-Instruct export DATASET_PATH="/path/to/longbench/multifieldqa_zh.jsonl" # For example: /home/data/Longbench/data/multifieldqa_zh.jsonl python examples/offline_inference_esa.py @@ -31,8 +32,8 @@ ktc = KVTransferConfig( "init_window_sz": 1, "local_window_sz": 2, "min_blocks": 4, - "sparse_ratio": 0.3, - "retrieval_stride": 5, + "sparse_ratio": 0.2, + "retrieval_stride": 10, } }, }, @@ -80,8 +81,8 @@ The following results were obtained using `Qwen2.5-14B-Instruct` under the speci "init_window_sz": 1, "local_window_sz": 2, "min_blocks": 4, - "sparse_ratio": 0.3, - "retrieval_stride": 5 + "sparse_ratio": 0.2, + "retrieval_stride": 10 } }, ``` @@ -92,5 +93,5 @@ The following results were obtained using `Qwen2.5-14B-Instruct` under the speci We use [LongBench](https://huggingface.co/datasets/zai-org/LongBench) to evaluate the accuracy of the ESA algorithm. | Dataset | F1-Score | |-------|-----------| -| multifieldqa_zh | 59.4 | -| dureader | 26.4 | \ No newline at end of file +| multifieldqa_zh | 64.28 | +| dureader | 28.73 | diff --git a/docs/source/user-guide/sparse-attention/gsa.md b/docs/source/user-guide/sparse-attention/gsa.md index 327fe0769..5a96287a3 100644 --- a/docs/source/user-guide/sparse-attention/gsa.md +++ b/docs/source/user-guide/sparse-attention/gsa.md @@ -107,6 +107,8 @@ ktc = KVTransferConfig( Thus, an example command for launching the online LLM service is as follows: ```shell +export ENABLE_SPARSE=TRUE + vllm serve /home/models/DeepSeek-R1-Distill-Qwen-32B \ --served-model-name DeepSeek-R1-Distill-Qwen-32B \ --max-model-len 131000 \ diff --git a/docs/source/user-guide/sparse-attention/kvcomp.md b/docs/source/user-guide/sparse-attention/kvcomp.md index 3c1d0b238..4e0cbc715 100644 --- a/docs/source/user-guide/sparse-attention/kvcomp.md +++ b/docs/source/user-guide/sparse-attention/kvcomp.md @@ -97,6 +97,7 @@ This design ensures both **efficiency** and **accuracy** by preserving essential KVComp is part of the UCM Sparse Attention module. For installation instructions, please refer to the [UCM's top-level README](https://github.com/ModelEngine-Group/unified-cache-management). Once UCM is installed, KVComp is naturally supported by running the following example python scripts. 
```bash +export ENABLE_SPARSE=TRUE python ucm/sandbox/sparse/kvcomp/offline_inference_kvcomp.py ``` diff --git a/docs/source/user-guide/sparse-attention/kvstar.md b/docs/source/user-guide/sparse-attention/kvstar.md index cf6222158..41d13358b 100644 --- a/docs/source/user-guide/sparse-attention/kvstar.md +++ b/docs/source/user-guide/sparse-attention/kvstar.md @@ -32,6 +32,7 @@ For long-sequence inference, KVstar achieves the following with minimal accuracy ### Basic Usage KVstar can be launched using the following command: ```shell +export ENABLE_SPARSE=TRUE export MODEL_PATH="/path/to/model" # For example: /home/models/Qwen2.5-14B-Instruct export DATASET_PATH="/path/to/longbench/multifieldqa_zh.jsonl" # For example: /home/data/Longbench/data/multifieldqa_zh.jsonl export DATA_DIR="/path/to/data" diff --git a/examples/metrics/grafana.json b/examples/metrics/grafana.json new file mode 100644 index 000000000..72d175971 --- /dev/null +++ b/examples/metrics/grafana.json @@ -0,0 +1,1025 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "target": { + "limit": 100, + "matchAny": false, + "tags": [], + "type": "dashboard" + }, + "type": "dashboard" + } + ] + }, + "description": "Monitoring UnifiedCache Connector Service (load/save)", + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 0, + "id": 1, + "links": [], + "liveNow": false, + "panels": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Hit rates of ucm lookup requests", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "yellow", + "value": 5000 + }, + { + "color": "red", + "value": 10000 + } + ] + }, + "unit": "" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 24, + "x": 0, + "y": 0 + }, + "id": 14, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "rate(ucm:interval_lookup_hit_rates_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:interval_lookup_hit_rates_count{model_name=\"$model_name\"}[$__rate_interval])", + "hide": false, + "instant": false, + "legendFormat": "Average", + "range": true, + "refId": "A" + } + ], + "title": "Connector Interval Lookup Hit Rates", + "type": "timeseries" + }, + + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Number of load requests each start_load_kv.", + 
"fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "yellow", + "value": 5000 + }, + { + "color": "red", + "value": 10000 + } + ] + }, + "unit": "" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 0 + }, + "id": 15, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "rate(ucm:load_requests_num_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:load_requests_num_count{model_name=\"$model_name\"}[$__rate_interval])", + "hide": false, + "instant": false, + "legendFormat": "worker-{{worker_id}}", + "range": true, + "refId": "A" + } + ], + "title": "Connector Load Requests Num", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Number of load blocks each start_load_kv.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "yellow", + "value": 5000 + }, + { + "color": "red", + "value": 10000 + } + ] + }, + "unit": "" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 0 + }, + "id": 16, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "rate(ucm:load_blocks_num_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:load_blocks_num_count{model_name=\"$model_name\"}[$__rate_interval])", + "hide": false, + "instant": false, + "legendFormat": "worker-{{worker_id}}", + "range": true, + "refId": "A" + } + ], + "title": "Connector Load Blocks Num", + "type": "timeseries" + }, + { + "datasource": { + 
"type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "P50, P90, P95, P99 and Average load duration in milliseconds for each start_load_kv.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "yellow", + "value": 5000 + }, + { + "color": "red", + "value": 10000 + } + ] + }, + "unit": "ms" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 8 + }, + "id": 17, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "rate(ucm:load_duration_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:load_duration_count{model_name=\"$model_name\"}[$__rate_interval])", + "hide": false, + "instant": false, + "legendFormat": "worker-{{worker_id}}", + "range": true, + "refId": "A" + } + ], + "title": "Connector Load Duration", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "P50, P90, P95, P99 and Average load speed in GB/s for each start_load_kv.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "red", + "value": null + }, + { + "color": "yellow", + "value": 0.005 + }, + { + "color": "green", + "value": 0.01 + } + ] + }, + "unit": "gb/s" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 8 + }, + "id": 21, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "rate(ucm:load_speed_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:load_speed_count{model_name=\"$model_name\"}[$__rate_interval])", + "hide": false, + "instant": false, + 
"legendFormat": "worker-{{worker_id}}", + "range": true, + "refId": "A" + } + ], + "title": "Connector Load Speed", + "type": "timeseries" + }, + + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Number of save requests each wait_for_save.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "yellow", + "value": 5000 + }, + { + "color": "red", + "value": 10000 + } + ] + }, + "unit": "" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 16 + }, + "id": 19, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "rate(ucm:save_requests_num_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:save_requests_num_count{model_name=\"$model_name\"}[$__rate_interval])", + "hide": false, + "instant": false, + "legendFormat": "worker-{{worker_id}}", + "range": true, + "refId": "A" + } + ], + "title": "Connector Save Requests Num", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Number of save blocks each wait_for_save.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "yellow", + "value": 5000 + }, + { + "color": "red", + "value": 10000 + } + ] + }, + "unit": "" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 16 + }, + "id": 20, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": 
"rate(ucm:save_blocks_num_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:save_blocks_num_count{model_name=\"$model_name\"}[$__rate_interval])", + "hide": false, + "instant": false, + "legendFormat": "worker-{{worker_id}}", + "range": true, + "refId": "A" + } + ], + "title": "Connector Save Blocks Num", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "P50, P90, P95, P99 and Average save duration in milliseconds for each save_kv.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "yellow", + "value": 8000 + }, + { + "color": "red", + "value": 15000 + } + ] + }, + "unit": "ms" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 24 + }, + "id": 18, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "rate(ucm:save_duration_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:save_duration_count{model_name=\"$model_name\"}[$__rate_interval])", + "hide": false, + "instant": false, + "legendFormat": "worker-{{worker_id}}", + "range": true, + "refId": "A" + } + ], + "title": "Connector Save Duration", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "P50, P90, P95, P99 and Average save speed in GB/s for each save_kv.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "red", + "value": null + }, + { + "color": "yellow", + "value": 0.004 + }, + { + "color": "green", + "value": 0.008 + } + ] + }, + "unit": "gb/s" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 24 + }, + "id": 22, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + 
"mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "rate(ucm:save_speed_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:save_speed_count{model_name=\"$model_name\"}[$__rate_interval])", + "hide": false, + "instant": false, + "legendFormat": "worker-{{worker_id}}", + "range": true, + "refId": "A" + } + ], + "title": "Connector Save Speed", + "type": "timeseries" + } + ], + "refresh": "", + "schemaVersion": 39, + "tags": [], + "templating": { + "list": [ + { + "current": { + "selected": false, + "text": "prometheus", + "value": "edx8memhpd9tsa" + }, + "hide": 0, + "includeAll": false, + "label": "datasource", + "multi": false, + "name": "DS_PROMETHEUS", + "options": [], + "query": "prometheus", + "queryValue": "", + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "type": "datasource" + }, + { + "current": { + "selected": false, + "text": "/share/datasets/public_models/Meta-Llama-3-8B-Instruct", + "value": "/share/datasets/public_models/Meta-Llama-3-8B-Instruct" + }, + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "definition": "label_values(model_name)", + "hide": 0, + "includeAll": false, + "label": "model_name", + "multi": false, + "name": "model_name", + "options": [], + "query": { + "query": "label_values(model_name)", + "refId": "StandardVariableQuery" + }, + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "sort": 0, + "type": "query" + } + ] + }, + "time": { + "from": "now-5m", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "vLLM - UnifiedCache Connector Monitoring", + "uid": "b281712d-8bff-41ef-9f3f-71ad43c05e9b", + "version": 9, + "weekStart": "" +} \ No newline at end of file diff --git a/examples/metrics/metrics_configs.yaml b/examples/metrics/metrics_configs.yaml new file mode 100644 index 000000000..5ed07baa9 --- /dev/null +++ b/examples/metrics/metrics_configs.yaml @@ -0,0 +1,56 @@ +# Prometheus Metrics Configuration +# This file defines which metrics should be enabled and their configurations +log_interval: 5 # Interval in seconds for logging metrics + +prometheus: + multiproc_dir: "/vllm-workspace" # Directory for Prometheus multiprocess mode + + metric_prefix: "ucm:" + + # Enable/disable metrics by category + enabled_metrics: + counters: true + gauges: true + histograms: true + + # Counter metrics configuration + # counters: + # - name: "received_requests" + # documentation: "Total number of requests sent to ucm" + + # Gauge metrics configuration + # gauges: + # - name: "lookup_hit_rate" + # documentation: "Hit rate of ucm lookup requests since last log" + # multiprocess_mode: "livemostrecent" + + # Histogram metrics configuration + histograms: + - name: "load_requests_num" + documentation: "Number of requests loaded from ucm" + buckets: [1, 5, 10, 20, 50, 100, 200, 500, 1000] + - name: "load_blocks_num" + documentation: "Number of blocks loaded from ucm" + buckets: [0, 50, 100, 150, 200, 250, 300, 350, 400, 550, 600, 750, 800, 850, 900, 950, 1000] + - name: "load_duration" + documentation: "Time to load from ucm (ms)" + buckets: [0, 50, 100, 150, 200, 250, 300, 350, 400, 550, 600, 750, 800, 850, 900, 950, 1000] + - name: "load_speed" + documentation: "Speed of loading from ucm (GB/s)" + buckets: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 50, 60, 70, 80, 90, 100] + - name: 
"save_requests_num" + documentation: "Number of requests saved to ucm" + buckets: [1, 5, 10, 20, 50, 100, 200, 500, 1000] + - name: "save_blocks_num" + documentation: "Number of blocks saved to ucm" + buckets: [0, 50, 100, 150, 200, 250, 300, 350, 400, 550, 600, 750, 800, 850, 900, 950, 1000] + - name: "save_duration" + documentation: "Time to save to ucm (ms)" + buckets: [0, 50, 100, 150, 200, 250, 300, 350, 400, 550, 600, 750, 800, 850, 900, 950, 1000] + - name: "save_speed" + documentation: "Speed of saving to ucm (GB/s)" + buckets: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 50, 60, 70, 80, 90, 100] + - name: "interval_lookup_hit_rates" + documentation: "Hit rates of ucm lookup requests" + buckets: [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + diff --git a/examples/offline_inference.py b/examples/offline_inference.py index f50682464..5a2fea372 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,5 +1,4 @@ import contextlib -import json import os import time from dataclasses import asdict @@ -16,11 +15,6 @@ logger = init_logger(__name__) -def setup_environment_variables(): - os.environ["VLLM_USE_V1"] = "1" - os.environ["PYTHONHASHSEED"] = "123456" - - @contextlib.contextmanager def build_llm_with_uc(module_path: str, name: str, model: str): ktc = KVTransferConfig( @@ -28,20 +22,7 @@ def build_llm_with_uc(module_path: str, name: str, model: str): kv_connector_module_path=module_path, kv_role="kv_both", kv_connector_extra_config={ - "ucm_connector_name": "UcmDramStore", - "ucm_connector_config": { - "max_cache_size": 5368709120, - "kv_block_size": 262144, - }, - "ucm_sparse_config": { - "ESA": { - "init_window_sz": 1, - "local_window_sz": 2, - "min_blocks": 4, - "sparse_ratio": 0.3, - "retrieval_stride": 5, - } - }, + "UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml" }, ) @@ -53,6 +34,8 @@ def build_llm_with_uc(module_path: str, name: str, model: str): max_num_batched_tokens=30000, block_size=128, enforce_eager=True, + trust_remote_code=True, + enable_prefix_caching=False, ) llm = LLM(**asdict(llm_args)) @@ -79,22 +62,41 @@ def print_output( def main(): - module_path = "ucm.integration.vllm.uc_connector" - name = "UnifiedCacheConnectorV1" - model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct") + module_path = "ucm.integration.vllm.ucm_connector" + name = "UCMConnector" + model = os.getenv("MODEL_PATH", "/home/models/DeepSeek-V2-Lite") tokenizer = AutoTokenizer.from_pretrained(model, use_chat_template=True) - setup_environment_variables() with build_llm_with_uc(module_path, name, model) as llm: messages = [ { "role": "system", - "content": "You are a highly specialized assistant whose mission is to faithfully reproduce English literary texts verbatim, without any deviation, paraphrasing, or omission. Your primary responsibility is accuracy: every word, every punctuation mark, and every line must appear exactly as in the original source. Core Principles: Verbatim Reproduction: If the user asks for a passage, you must output the text word-for-word. Do not alter spelling, punctuation, capitalization, or line breaks. Do not paraphrase, summarize, modernize, or “improve” the language. Consistency: The same input must always yield the same output. Do not generate alternative versions or interpretations. Clarity of Scope: Your role is not to explain, interpret, or critique. 
You are not a storyteller or commentator, but a faithful copyist of English literary and cultural texts. Recognizability: Because texts must be reproduced exactly, they will carry their own cultural recognition. You should not add labels, introductions, or explanations before or after the text. Coverage: You must handle passages from classic literature, poetry, speeches, or cultural texts. Regardless of tone—solemn, visionary, poetic, persuasive—you must preserve the original form, structure, and rhythm by reproducing it precisely. Success Criteria: A human reader should be able to compare your output directly with the original and find zero differences. The measure of success is absolute textual fidelity. Your function can be summarized as follows: verbatim reproduction only, no paraphrase, no commentary, no embellishment, no omission.", + "content": "You are a highly specialized assistant whose mission is to faithfully reproduce English " + "literary texts verbatim, without any deviation, paraphrasing, or omission. Your primary " + "responsibility is accuracy: every word, every punctuation mark, and every line must " + "appear exactly as in the original source. Core Principles: Verbatim Reproduction: If the " + "user asks for a passage, you must output the text word-for-word. Do not alter spelling, " + "punctuation, capitalization, or line breaks. Do not paraphrase, summarize, modernize, " + "or “improve” the language. Consistency: The same input must always yield the same output. " + "Do not generate alternative versions or interpretations. Clarity of Scope: Your role is " + "not to explain, interpret, or critique. You are not a storyteller or commentator, " + "but a faithful copyist of English literary and cultural texts. Recognizability: Because " + "texts must be reproduced exactly, they will carry their own cultural recognition. You " + "should not add labels, introductions, or explanations before or after the text. Coverage: " + "You must handle passages from classic literature, poetry, speeches, or cultural texts. " + "Regardless of tone—solemn, visionary, poetic, persuasive—you must preserve the original " + "form, structure, and rhythm by reproducing it precisely. Success Criteria: A human reader " + "should be able to compare your output directly with the original and find zero " + "differences. The measure of success is absolute textual fidelity. 
Your function can be " + "summarized as follows: verbatim reproduction only, no paraphrase, no commentary, " + "no embellishment, no omission.", }, { "role": "user", - "content": "Please reproduce verbatim the opening sentence of the United States Declaration of Independence (1776), starting with 'When in the Course of human events' and continuing word-for-word without paraphrasing.", + "content": "Please reproduce verbatim the opening sentence of the United States Declaration of " + "Independence (1776), starting with 'When in the Course of human events' and continuing " + "word-for-word without paraphrasing.", }, ] diff --git a/examples/offline_inference_esa.py b/examples/offline_inference_esa.py index caae0970d..852a8ca02 100644 --- a/examples/offline_inference_esa.py +++ b/examples/offline_inference_esa.py @@ -24,6 +24,7 @@ def setup_environment_variables(): os.environ["VLLM_USE_V1"] = "1" os.environ["PYTHONHASHSEED"] = "123456" + os.environ["ENABLE_SPARSE"] = "true" global model, path_to_dataset, data_dir, tokenizer model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct") @@ -45,10 +46,10 @@ def setup_environment_variables(): sys.exit(1) data_dir = os.getenv("DATA_DIR", "/home/data/kv_cache") - data_dir = input( - "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: " - ) if not os.path.isdir(data_dir): + data_dir = input( + "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: " + ) create = input(f"Directory {data_dir} dose not exist. Create it? (Y/n): ") if create.lower() == "y": os.makedirs(data_dir, exist_ok=True) @@ -66,11 +67,15 @@ def build_llm_with_uc(module_path: str, name: str, model: str): kv_connector_module_path=module_path, kv_role="kv_both", kv_connector_extra_config={ - "ucm_connector_name": "UcmNfsStore", - "ucm_connector_config": { - "storage_backends": data_dir, - "kv_block_size": 33554432, - }, + "ucm_connectors": [ + { + "ucm_connector_name": "UcmNfsStore", + "ucm_connector_config": { + "storage_backends": data_dir, + "use_direct": False, + }, + } + ], "ucm_sparse_config": { "ESA": { "init_window_sz": 1, @@ -87,12 +92,13 @@ def build_llm_with_uc(module_path: str, name: str, model: str): model=model, kv_transfer_config=ktc, max_model_len=32768, - gpu_memory_utilization=0.6, + gpu_memory_utilization=0.8, max_num_batched_tokens=30000, block_size=128, enforce_eager=True, distributed_executor_backend="mp", tensor_parallel_size=1, + trust_remote_code=True, ) llm = LLM(**asdict(llm_args)) @@ -111,16 +117,20 @@ def print_output( start = time.time() outputs = llm.generate(prompt, sampling_params) print("-" * 50) + lines = [] for output in outputs: generated_text = output.outputs[0].text print(f"Generated text: {generated_text!r}") + lines.append(generated_text + "\n") print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.") + with open("./newest_out.txt", "w") as f: + f.writelines(lines) print("-" * 50) def main(): - module_path = "ucm.integration.vllm.uc_connector" - name = "UnifiedCacheConnectorV1" + module_path = "ucm.integration.vllm.ucm_connector" + name = "UCMConnector" setup_environment_variables() def get_prompt(prompt): @@ -140,24 +150,23 @@ def get_prompt(prompt): with build_llm_with_uc(module_path, name, model) as llm: prompts = [] - batch_size = 5 + batch_size = 20 assert os.path.isfile( path_to_dataset ), f"Incorrect dataset path. 
Please specify the dataset path by `export DATASET_PATH=/path/to/longbench/multifieldqa_zh.jsonl`" with open(path_to_dataset, "r") as f: - for _ in range(batch_size): - line = f.readline() - if not line: - break - data = json.loads(line) - context = data["context"] - question = data["input"] - prompts.append(get_prompt(f"{context}\n\n{question}")) - - sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=100) + lines = f.readlines() + for i in range(batch_size): + line = lines[i] + data = json.loads(line) + prompt = f"""阅读以下文字并用中文简短回答:\n\n{data["context"]}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{data["input"]}\n回答:""" + prompts.append(get_prompt(prompt)) + + sampling_params = SamplingParams( + temperature=0, top_p=0.95, max_tokens=256, ignore_eos=False + ) print_output(llm, prompts, sampling_params, "first") - print_output(llm, prompts, sampling_params, "second") if __name__ == "__main__": diff --git a/examples/offline_inference_kvcomp.py b/examples/offline_inference_kvcomp.py new file mode 100644 index 000000000..f1cd6c196 --- /dev/null +++ b/examples/offline_inference_kvcomp.py @@ -0,0 +1,172 @@ +import contextlib +import json +import os +import sys +import time +from dataclasses import asdict + +from transformers import AutoTokenizer + +# Third Party +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig +from vllm.engine.arg_utils import EngineArgs + +from ucm.logger import init_logger + +logger = init_logger(__name__) +model = "" +path_to_dataset = "" +data_dir = "" +tokenizer = None + + +def setup_environment_variables(): + os.environ["VLLM_USE_V1"] = "1" + os.environ["PYTHONHASHSEED"] = "123456" + os.environ["ENABLE_SPARSE"] = "true" + + global model, path_to_dataset, data_dir, tokenizer + model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct") + if not os.path.isdir(model): + model = input("Enter path to model, e.g. /home/models/Qwen2.5-14B-Instruct: ") + if not os.path.isdir(model): + print("Exiting. Incorrect model_path") + sys.exit(1) + + path_to_dataset = os.getenv( + "DATASET_PATH", "/home/data/Longbench/data/multifieldqa_zh.jsonl" + ) + if not os.path.isfile(path_to_dataset): + path_to_dataset = input( + "Enter path to one of the longbench dataset, e.g. /home/data/Longbench/data/multifieldqa_zh.jsonl: " + ) + if not os.path.isfile(path_to_dataset): + print("Exiting. Incorrect dataset path") + sys.exit(1) + + data_dir = os.getenv("DATA_DIR", "/home/data/kv_cache") + data_dir = input( + "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: " + ) + if not os.path.isdir(data_dir): + create = input(f"Directory {data_dir} dose not exist. Create it? (Y/n): ") + if create.lower() == "y": + os.makedirs(data_dir, exist_ok=True) + else: + print("Exiting. 
Directory not created.") + sys.exit(1) + + tokenizer = AutoTokenizer.from_pretrained(model, use_chat_template=True) + + +@contextlib.contextmanager +def build_llm_with_uc(module_path: str, name: str, model: str): + ktc = KVTransferConfig( + kv_connector=name, + kv_connector_module_path=module_path, + kv_role="kv_both", + kv_connector_extra_config={ + "ucm_connectors": [ + { + "ucm_connector_name": "UcmNfsStore", + "ucm_connector_config": { + "storage_backends": data_dir, + "use_direct": False, + }, + } + ], + "ucm_sparse_config": { + "KvComp": { + "init_window_sz": 1, + "local_window_sz": 2, + "min_blocks": 4, + "sparse_ratio": 0.3, + "retrieval_stride": 5, + } + }, + # "kvcomp_config_path": "unified-cache-management/ucm/sparse/kvcomp/configs/kvcomp_deepseek_v2_lite_config.json", + "kvcomp_config_path": "unified-cache-management/ucm/sparse/kvcomp/configs/kvcomp_qwq_32B_config.json", + }, + ) + + llm_args = EngineArgs( + model=model, + kv_transfer_config=ktc, + max_model_len=32768, + gpu_memory_utilization=0.8, + max_num_batched_tokens=30000, + block_size=128, + enforce_eager=True, + distributed_executor_backend="mp", + tensor_parallel_size=2, + trust_remote_code=True, + ) + + llm = LLM(**asdict(llm_args)) + try: + yield llm + finally: + logger.info("LLM engine is exiting.") + + +def print_output( + llm: LLM, + prompt: list[str], + sampling_params: SamplingParams, + req_str: str, +): + start = time.time() + outputs = llm.generate(prompt, sampling_params) + print("-" * 50) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.") + print("-" * 50) + + +def main(): + module_path = "ucm.integration.vllm.ucm_connector" + name = "UCMConnector" + setup_environment_variables() + + def get_prompt(prompt): + messages = [ + { + "role": "system", + "content": "先读问题,再根据下面的文章内容回答问题,不要进行分析,不要重复问题,用简短的语句给出答案。\n\n例如:“全国美国文学研究会的第十八届年会在哪所大学举办的?”\n回答应该为:“xx大学”。\n\n", + }, + {"role": "user", "content": prompt}, + ] + return tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + add_special_tokens=True, + ) + + with build_llm_with_uc(module_path, name, model) as llm: + prompts = [] + batch_size = 10 + assert os.path.isfile( + path_to_dataset + ), f"Incorrect dataset path. 
Please specify the dataset path by `export DATASET_PATH=/path/to/longbench/multifieldqa_zh.jsonl`" + with open(path_to_dataset, "r") as f: + lines = f.readlines() + for i in range(batch_size): + line = lines[i] + data = json.loads(line) + prompt = f"""阅读以下文字并用中文简短回答:\n\n{data["context"]}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{data["input"]}\n回答:""" + prompts.append(get_prompt(prompt)) + + sampling_params = SamplingParams( + temperature=0, top_p=0.95, max_tokens=256, ignore_eos=False + ) + + print_output(llm, prompts, sampling_params, "first") + print_output(llm, prompts, sampling_params, "second") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference_kvstar.py b/examples/offline_inference_kvstar.py index e26113993..c8dfd1eec 100644 --- a/examples/offline_inference_kvstar.py +++ b/examples/offline_inference_kvstar.py @@ -24,6 +24,7 @@ def setup_environment_variables(): os.environ["VLLM_USE_V1"] = "1" os.environ["PYTHONHASHSEED"] = "123456" + os.environ["ENABLE_SPARSE"] = "true" global model, path_to_dataset, data_dir, tokenizer model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct") @@ -67,11 +68,15 @@ def build_llm_with_uc(module_path: str, name: str, model: str): kv_connector_module_path=module_path, kv_role="kv_both", kv_connector_extra_config={ - "ucm_connector_name": "UcmNfsStore", - "ucm_connector_config": { - "storage_backends": data_dir, - "kv_block_size": 33554432, - }, + "ucm_connectors": [ + { + "ucm_connector_name": "UcmNfsStore", + "ucm_connector_config": { + "storage_backends": data_dir, + "use_direct": False, + }, + } + ], "ucm_sparse_config": { "KVStarMultiStep": { "init_window_sz": 1, @@ -122,8 +127,8 @@ def print_output( def main(): - module_path = "ucm.integration.vllm.uc_connector" - name = "UnifiedCacheConnectorV1" + module_path = "ucm.integration.vllm.ucm_connector" + name = "UCMConnector" setup_environment_variables() def get_prompt(prompt): diff --git a/examples/ucm_config_example.yaml b/examples/ucm_config_example.yaml new file mode 100644 index 000000000..32c357e1b --- /dev/null +++ b/examples/ucm_config_example.yaml @@ -0,0 +1,38 @@ +# UCM Configuration File Example +# +# This file demonstrates how to configure UCM using YAML. +# You can use this config file by setting the path to this file in kv_connector_extra_config in launch script or command line like this: +# kv_connector_extra_config={"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"} +# +# Alternatively, you can still use kv_connector_extra_config in KVTransferConfig +# for backward compatibility. + +# Connector name (e.g., "UcmNfsStore", "UcmDramStore") +ucm_connectors: + - ucm_connector_name: "UcmNfsStore" + ucm_connector_config: + storage_backends: "/mnt/test" + use_direct: false + +load_only_first_rank: false + +# Enable UCM metrics so they can be monitored online via Grafana and Prometheus. 
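+# To enable them, uncomment the line below and point it at a metrics config such as examples/metrics/metrics_configs.yaml.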
+# metrics_config_path: "/workspace/unified-cache-management/examples/metrics/metrics_configs.yaml" + +# Sparse attention configuration +# Format 1: Dictionary format (for methods like ESA, KvComp) +# ucm_sparse_config: +# ESA: +# init_window_sz: 1 +# local_window_sz: 2 +# min_blocks: 4 +# sparse_ratio: 0.3 +# retrieval_stride: 5 + # Or for GSA: + # GSA: {} + + +# Whether to use layerwise loading/saving (optional, default: True for UnifiedCacheConnectorV1) +# use_layerwise: true +# hit_ratio: 0.9 + diff --git a/setup.py b/setup.py index a86d6d513..8b4c8f660 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,7 @@ import subprocess import sys import sysconfig +from glob import glob import pybind11 import torch @@ -45,6 +46,10 @@ def _is_npu() -> bool: return PLATFORM == "ascend" +def _is_musa() -> bool: + return PLATFORM == "musa" + + class CMakeExtension(Extension): def __init__(self, name: str, sourcedir: str = ""): super().__init__(name, sources=[]) @@ -84,10 +89,12 @@ def build_cmake(self, ext: CMakeExtension): cmake_args.append("-DRUNTIME_ENVIRONMENT=cuda") elif _is_npu(): cmake_args.append("-DRUNTIME_ENVIRONMENT=ascend") + elif _is_musa(): + cmake_args.append("-DRUNTIME_ENVIRONMENT=musa") else: raise RuntimeError( "No supported accelerator found. " - "Please ensure either CUDA or NPU is available." + "Please ensure either CUDA/MUSA or NPU is available." ) cmake_args.append(ext.sourcedir) @@ -104,17 +111,37 @@ def build_cmake(self, ext: CMakeExtension): ) +def _get_package_data_with_so(): + """Automatically discover all packages and include .so files.""" + packages = find_packages() + package_data = {} + + for package in packages: + # Convert package name to directory path + package_dir = os.path.join(ROOT_DIR, package.replace(".", os.sep)) + + # Check if this package directory contains .so files + so_files = glob(os.path.join(package_dir, "*.so")) + if so_files: + package_data[package] = ["*.so"] + print(f"[INFO] Including .so files for package: {package}") + + print(f"[INFO] Package data: {package_data}") + return package_data + + ext_modules = [] ext_modules.append(CMakeExtension(name="ucm", sourcedir=ROOT_DIR)) setup( - name="ucm", - version="0.0.2", + name="uc-manager", + version="0.1.0", description="Unified Cache Management", author="Unified Cache Team", packages=find_packages(), python_requires=">=3.10", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuild}, + package_data=_get_package_data_with_so(), zip_safe=False, ) diff --git a/test/.gitignore b/test/.gitignore new file mode 100644 index 000000000..220d21ac1 --- /dev/null +++ b/test/.gitignore @@ -0,0 +1,13 @@ +reports/ +dataset/ +logs/ +result_outputs/ +results/ +.cache/ +backup/ +$null +*__pycache__/ +.* +*.log +start.bat +!.gitignore \ No newline at end of file diff --git a/test/README.md b/test/README.md new file mode 100644 index 000000000..1e11da7e7 --- /dev/null +++ b/test/README.md @@ -0,0 +1,179 @@ +# Pytest +[简体中文](README_zh.md) +A comprehensive Pytest testing framework featuring configuration management, database integration, performance testing, and HTML report generation. + +## 📋 Features + +- **Modern Testing Framework**: Complete test solution built on Pytest 7.0+ +- **Configuration Management**: YAML-based config with thread-safe singleton pattern +- **Database Integration**: Built-in MySQL support with automatic result storage +- **HTML Reports**: Auto-generated pytest HTML test reports +- **Tagging System**: Multi-dimensional test tags (stage, feature, platform, etc.) 
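+
+For example, a single test case can combine several of these tags. The sketch below is illustrative only; the exact mark signatures (`stage`, `feature`, `platform`) are the ones registered by this framework's `pytest.ini`/`conftest.py`:
+
+```python
+import pytest
+
+
+@pytest.mark.stage(1)                 # smoke stage, selected via --stage=1
+@pytest.mark.feature("performance")   # feature tag, selected via --feature=performance
+@pytest.mark.platform("gpu")          # platform tag, selected via --platform=gpu
+def test_smoke_example():
+    assert 1 + 1 == 2
+```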
+ +## 🗂️ Project Structure + +``` +pytest_demo/ +├── common/ # Common modules +│ ├── __init__.py +│ ├── config_utils.py # Configuration utilities +│ ├── db_utils.py # Database utilities +│ └── capture_utils # Return-value capture utilities +├── results/ # Result storage folder +├── suites/ # Test suites +│ ├── UnitTest # Unit tests +│ ├── Feature # Feature tests +│ └── E2E/ # End-to-end tests +│ └── test_demo_performance.py # Sample test file +├── config.yaml # Main config file +├── conftest.py # Pytest config +├── pytest.ini # Pytest settings +├── requirements.txt # Dependencies +└── README.md # This doc (CN) +``` + +## 🚀 Quick Start + +### Prerequisites + +- Python 3.8+ +- MySQL 5.7+ (optional, for DB features) +- Git + +### Installation + +1. **Install dependencies** + ```bash + pip install -r requirements.txt + ``` + +2. **Configure database** (optional) + + Edit `config.yaml`: + ```yaml + database: + backup: "results/" + host: "127.0.0.1" + port: 3306 + name: "ucm_pytest" + user: "root" + password: "123456" + charset: "utf8mb4" + ``` + +3. **Run tests** + ```bash + # Run all tests + pytest + + # Run tests by tag + pytest --stage=1 + pytest --feature=performance + ``` + +## ⚙️ Configuration + +### config.yaml + +Full YAML-based config. Key sections: + +- **reports**: Report settings (HTML, timestamp, etc.) +- **database**: MySQL connection details + +## 🧪 Test Examples + +### Basic functional test + +```python +# suites/E2E/test_demo_performance.py +import pytest + +@pytest.fixture(scope="module", name="calc") +def calculator(): + return Calculator() + +@pytest.mark.feature("mark") +class TestCalculator: + def test_add(self, calc): + assert calc.add(1, 2) == 3 + + def test_divide_by_zero(self, calc): + with pytest.raises(ZeroDivisionError): + calc.divide(6, 0) +``` + +## 🏷️ Tagging System + +Multi-dimensional tags supported: + +### Stage tags +- `stage(0)`: Unit tests +- `stage(1)`: Smoke tests +- `stage(2)`: Regression tests +- `stage(3)`: Release tests + +### Functional tags +- `feature`: Module tag +- `platform`: Platform tag (GPU/NPU) + +### Usage + +```bash +# Run smoke tests and above +pytest --stage=1+ + +# Run by feature +pytest --feature=performance +pytest --feature=performance,reliability + +# Run by platform +pytest --platform=gpu +``` + +### HTML Reports + +Auto-generated timestamped HTML reports: +- Location: `reports/pytest_YYYYMMDD_HHMMSS/report.html` +- Detailed results, errors, timing +- Customizable title & style + +### Database Storage + +If enabled, results are auto-saved to MySQL. +To add new record types, ask DB admin to create tables; otherwise only local files are used. + +Example: +```python +@pytest.mark.feature("capture") # Must be top decorator +@export_vars +def test_capture_mix(): + assert 1 == 1 + return { + '_name': 'demo', + '_data': { + 'length': 10086, # single value + 'accuracy': [0.1, 0.2, 0.3], # list + 'loss': [0.1, 0.2, 0.3], # list + } + } +``` + +### Config Access + +Read settings easily: +```python +from common.config_utils import config_utils +# Get config +db_config = config_utils.get_config("database") +api_config = config_utils.get_nested_config("easyPerf.api") +``` + +## 🛠️ Development Guide + +### Adding New Tests + +1. Create test files under `suites/` categories +2. Apply appropriate tags +3. Naming: `test_*.py` +4. Use fixtures & marks for data management +5. 
Keep custom marks concise and aligned with overall goals \ No newline at end of file diff --git a/test/README_zh.md b/test/README_zh.md new file mode 100644 index 000000000..26b0f393a --- /dev/null +++ b/test/README_zh.md @@ -0,0 +1,182 @@ +# Pytest 项目 + Pytest 测试框架,包括配置管理、数据库集成、性能测试和 HTML 报告生成。 + +## 📋 项目特性 + +- **现代化测试框架**: 基于 Pytest 7.0+ 的完整测试解决方案 +- **配置管理**: 支持 YAML 配置文件,线程安全的单例模式配置管理 +- **数据库集成**: 内置 MySQL 数据库支持,自动结果存储 +- **HTML 报告**: 自动生成pytest HTML 测试报告 +- **标记系统**: 支持多维度测试标记(阶段、功能、平台等) + +## 🗂️ 项目结构 + +``` +pytest_demo/ +├── common/ # 公共模块 +│ ├── __init__.py +│ ├── config_utils.py # 配置管理工具 +│ ├── db_utils.py # 数据库工具 +│ └── capture_utils # 返回值捕获工具 +├── results/ # 结果存储目录 +├── suites/ # 测试套件 +│ ├── UnitTest # 单元测试 +│ ├── Feature # 功能测试 +│ └── E2E/ # 端到端测试 +│ └── test_demo_performance.py # 示例测试文件 +├── config.yaml # 主配置文件 +├── conftest.py # Pytest 配置文件 +├── pytest.ini # Pytest 配置 +├── requirements.txt # 项目依赖 +└── README.md # 本文档 +``` + +## 🚀 快速开始 + +### 环境要求 + +- Python 3.8+ +- MySQL 5.7+ (可选,用于数据库功能) +- Git + +### 安装步骤 + +1. **安装依赖** + ```bash + pip install -r requirements.txt + ``` + +2. **配置数据库**(可选) + + 编辑 `config.yaml` 文件中的数据库配置: + ```yaml + database: + backup: "results/" + host: "127.0.0.1" + port: 3306 + name: "ucm_pytest" + user: "root" + password: "123456" + charset: "utf8mb4" + ``` + +3. **运行测试** + ```bash + # 运行所有测试 + pytest + + # 运行特定标记的测试 + pytest --stage=1 + pytest --feature=performance + ``` + +## ⚙️ 配置说明 + + +### config.yaml 配置 + +项目支持完整的 YAML 配置管理,主要配置项包括: + +- **reports**: 报告配置(HTML 报告、时间戳等) +- **database**: 数据库连接配置 + +## 🧪 测试示例 + +### 基础功能测试 + +```python +# suites/E2E/test_demo_performance.py +import pytest + +@pytest.fixture(scope="module", name="calc") +def calculator(): + return Calculator() + +@pytest.mark.feature("mark") +class TestCalculator: + def test_add(self, calc): + assert calc.add(1, 2) == 3 + + def test_divide_by_zero(self, calc): + with pytest.raises(ZeroDivisionError): + calc.divide(6, 0) +``` + +## 🏷️ 测试标记系统 + +项目支持多维度的测试标记: + +### 测试阶段标记 +- `stage(0)`: 单元测试 +- `stage(1)`: 冒烟测试 +- `stage(2)`: 回归测试 +- `stage(3)`: 发布测试 + +### 功能标记 +- `feature`: 功能模块标记 +- `platform`: 平台标记(GPU/NPU) + +### 使用示例 + +```bash +# 运行冒烟测试及以上的所有测试 +pytest --stage=1+ + +# 运行特定功能的测试 +pytest --feature=performance +pytest --feature=performance, reliability +# 运行特定平台的测试 +pytest --platform=gpu +``` + + +### HTML 报告 + +项目自动生成带时间戳的 HTML 测试报告: +- 报告位置:`reports/pytest_YYYYMMDD_HHMMSS/report.html` +- 包含详细的测试结果、错误信息和执行时间 +- 支持自定义报告标题和样式 + +### 数据库存储 + +如果启用数据库功能,测试结果会自动存储到 MySQL 数据库。 +若需要新增记录,请联系管理人员在数据库新增对应表;否则只能保存至本地文件。 +使用方式示例: +```python +@pytest.mark.feature("capture") # pytest 的标签必须在上面,否则无法正常使用标记功能 +@export_vars +def test_capture_mix(): + assert 1 == 1 + return { + '_name': 'demo', + '_data': { + 'length': 10086, # single value + 'accuracy': [0.1, 0.2, 0.3], # list + 'loss': [0.1, 0.2, 0.3], # list + } + } + +``` + + +### 配置管理 + +可以通过配置工具便捷读取参数: +```python +from common.config_utils import config_utils +# 获取配置 +db_config = config_utils.get_config("database") +api_config = config_utils.get_nested_config("easyPerf.api") +``` + + + +## 🛠️ 开发指南 + +### 添加新测试 + +1. 在 `suites/` 目录下的各个分类下创建新的测试文件 +2. 使用适当的测试标记 +3. 遵循命名规范:`test_*.py` +4. 使用 fixture 及mark 进行测试数据管理 +5. 
自定义 mark 标签不易过细,应当与整体功能目标相符合 \ No newline at end of file diff --git a/test/common/__init__.py b/test/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/common/capture_utils.py b/test/common/capture_utils.py new file mode 100644 index 000000000..b12b76637 --- /dev/null +++ b/test/common/capture_utils.py @@ -0,0 +1,97 @@ +import functools +from typing import Any, Dict, List + +from common.db_utils import write_to_db + + +def _align_and_split(name: str, data: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Align a mixed data package (single values and/or lists) and split it into + """ + if not data: + return [] + + aligned: Dict[str, List[Any]] = {} + lengths: Dict[str, int] = {} + for k, v in data.items(): + if isinstance(v, (list, tuple)): + aligned[k] = list(v) + else: + aligned[k] = [v] + lengths[k] = len(aligned[k]) + + max_len = max(lengths.values()) + + for k, lst in aligned.items(): + if len(lst) < max_len: + lst.extend([lst[-1]] * (max_len - len(lst))) + + return [{k: aligned[k][i] for k in aligned} for i in range(max_len)] + + +def post_process(table_name: str, **kwargs) -> List[Dict[str, Any]]: + """ + Unified post-processing entry point. Supports two calling styles: + """ + results = [] + if "_data" in kwargs: + name = kwargs.get("_name", table_name) + results = _align_and_split(name, kwargs["_data"]) + for result in results: + write_to_db(name, result) + return results + return [] + + +# ---------------- decorator ---------------- +def export_vars(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + # If the function returns a dict containing '_data' or 'data', post-process it + if isinstance(result, dict): + if "_data" in result or "data" in result: + return post_process(func.__name__, **result) + # Otherwise return unchanged + return result + + return wrapper + + +# ---------------- usage examples ---------------- +@export_vars +def capture(): + """All single values via 'name' + 'data'""" + return {"name": "demo", "_data": {"accuracy": 0.1, "loss": 0.3}} + + +@export_vars +def capture_list(): + """All lists via '_name' + '_data'""" + return { + "_name": "demo", + "_data": { + "accuracy": [0.1, 0.2, 0.3], + "loss": [0.1, 0.2, 0.3], + }, + } + + +@export_vars +def capture_mix(): + """Mixed single + lists via '_name' + '_data'""" + return { + "_name": "demo", + "_data": { + "length": 10086, # single value + "accuracy": [0.1, 0.2, 0.3], # list + "loss": [0.1, 0.2, 0.3], # list + }, + } + + +# quick test +if __name__ == "__main__": + print("capture(): ", capture()) + print("capture_list(): ", capture_list()) + print("capture_mix(): ", capture_mix()) diff --git a/test/common/config_utils.py b/test/common/config_utils.py new file mode 100644 index 000000000..106f783ee --- /dev/null +++ b/test/common/config_utils.py @@ -0,0 +1,86 @@ +import os +import threading +from typing import Any, Dict + +import yaml + + +class ConfigUtils: + """ + Singleton Configuration Utility + Provides methods to read and access YAML configuration files. 
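+
+    Illustrative usage (mirrors the __main__ example at the bottom of this
+    module; the module-level singleton is exposed as config_utils):
+
+        db_config = config_utils.get_config("database")
+        db_host = config_utils.get_nested_config("database.host", "localhost")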
+ """ + + _instance = None + _lock = threading.Lock() # Ensure thread-safe singleton creation + + def __init__(self): + self._config = None + + def __new__(cls, config_file: str = None): + # Double-checked locking + if cls._instance is None: + with cls._lock: + if cls._instance is None: + instance = super().__new__(cls) + instance._init_config(config_file) + cls._instance = instance + return cls._instance + + def _init_config(self, config_file: str = None): + """Initialize configuration file path and load config""" + if config_file is None: + current_dir = os.path.dirname(os.path.abspath(__file__)) + config_file = os.path.join(current_dir, "..", "config.yaml") + + self.config_file = os.path.abspath(config_file) + self._config = None # Lazy load + + def _load_config(self) -> Dict[str, Any]: + """Internal method to read configuration from file""" + try: + with open(self.config_file, "r", encoding="utf-8") as f: + return yaml.safe_load(f) or {} + except FileNotFoundError: + print(f"[WARN] Config file not found: {self.config_file}") + return {} + except yaml.YAMLError as e: + print(f"[ERROR] Failed to parse YAML config: {e}") + return {} + + def read_config(self) -> Dict[str, Any]: + """Read configuration file (lazy load)""" + if self._config is None: + self._config = self._load_config() + return self._config + + def reload_config(self): + """Force reload configuration file""" + self._config = self._load_config() + + def get_config(self, key: str, default: Any = None) -> Any: + """Get top-level configuration item""" + config = self.read_config() + return config.get(key, default) + + def get_nested_config(self, key_path: str, default: Any = None) -> Any: + """Get nested configuration, e.g., 'influxdb.host'""" + config = self.read_config() + keys = key_path.split(".") + value = config + try: + for k in keys: + value = value[k] + return value + except (KeyError, TypeError): + return default + + +# Global instance +config_utils = ConfigUtils() + +if __name__ == "__main__": + print("DataBase config:", config_utils.get_config("database")) + print( + "DataBase host:", config_utils.get_nested_config("database.host", "localhost") + ) diff --git a/test/common/db_utils.py b/test/common/db_utils.py new file mode 100644 index 000000000..089af43b2 --- /dev/null +++ b/test/common/db_utils.py @@ -0,0 +1,183 @@ +import json +import logging +import threading +from pathlib import Path +from typing import Any, Dict, Optional + +import peewee +from common.config_utils import config_utils as config_instance +from peewee import AutoField, Model, MySQLDatabase, TextField + +logger = logging.getLogger("db_handler") +logger.setLevel(logging.DEBUG) + +# Avoid adding handlers multiple times +if not logger.handlers: + logger.setLevel(logging.DEBUG) + +# Global DB instance and lock for thread-safe singleton +_db_instance: Optional[MySQLDatabase] = None +_db_lock = threading.Lock() +_test_build_id: Optional[str] = None +_backup_path: Optional[Path] = None +_db_enabled: bool = False # from config + + +def _get_db() -> Optional[MySQLDatabase]: + """Return a singleton MySQLDatabase instance based on YAML configuration.""" + global _db_instance, _backup_path, _db_enabled + + if _db_instance is None: + with _db_lock: + if _db_instance is None: + db_config = config_instance.get_config("database", {}) + _db_enabled = db_config.get("enabled", False) + + backup_str = db_config.get("backup", "results/") + _backup_path = Path(backup_str).resolve() + _backup_path.mkdir(parents=True, exist_ok=True) + logger.info(f"Backup directory set 
to: {_backup_path}") + + if not _db_enabled: + return None + + try: + _db_instance = MySQLDatabase( + db_config.get("name", "test_db"), + user=db_config.get("user", "root"), + password=db_config.get("password", ""), + host=db_config.get("host", "localhost"), + port=db_config.get("port", 3306), + charset=db_config.get("charset", "utf8mb4"), + ) + logger.info( + f"Database instance created for: {_db_instance.database}" + ) + except Exception as e: + logger.error(f"Failed to create database instance: {e}") + _db_instance = None + + return _db_instance + + +def _set_test_build_id(build_id: Optional[str] = None) -> None: + """Set or generate a unique test build ID.""" + global _test_build_id + _test_build_id = build_id or "default_build_id" + logger.debug(f"Test build ID set to: {_test_build_id}") + + +def _get_test_build_id() -> str: + """Return the current test build ID, generating one if necessary.""" + global _test_build_id + if _test_build_id is None: + _set_test_build_id() + return _test_build_id + + +class BaseEntity(Model): + """Base PeeWee model class using the singleton database.""" + + class Meta: + database = _get_db() + + +def _backup_to_file(table_name: str, data: Dict[str, Any]) -> None: + """Write data to a JSON Lines (.jsonl) file in the backup directory.""" + if not _backup_path: + logger.warning("Backup path is not set. Skipping backup.") + return + + file_path = _backup_path / f"{table_name}.jsonl" + try: + file_path.parent.mkdir(parents=True, exist_ok=True) + with file_path.open("a", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False) + f.write("\n") + logger.info(f"Data backed up to {file_path}") + except Exception as e: + logger.error(f"Failed to write backup file {file_path}: {e}") + + +def write_to_db(table_name: str, data: Dict[str, Any]) -> bool: + """ + Attempt to insert data into the specified database table. + If the table doesn't exist or an error occurs, back up to a JSONL file. + """ + db = _get_db() + data["test_build_id"] = _get_test_build_id() + + # Skip DB entirely if disabled + if not _db_enabled or db is None: + _backup_to_file(table_name, data) + return False + + try: + if not db.table_exists(table_name): + logger.warning(f"Table '{table_name}' does not exist. 
Writing to backup.") + _backup_to_file(table_name, data) + return False + + # Get existing columns and filter data + columns = db.get_columns(table_name) + col_names = {col.name for col in columns} + filtered_data = {k: v for k, v in data.items() if k in col_names} + + # Build dynamic model for insertion + fields = {"id": AutoField()} + for col in columns: + if col.name != "id": + fields[col.name] = TextField(null=True) + + DynamicEntity = type( + f"{table_name.capitalize()}DynamicModel", + (BaseEntity,), + { + "Meta": type("Meta", (), {"database": db, "table_name": table_name}), + **fields, + }, + ) + + with db.atomic(): + DynamicEntity.insert(filtered_data).execute() + logger.info(f"Successfully inserted data into table '{table_name}'.") + return True + + except peewee.PeeweeException as e: + logger.error( + f"Database write error for table '{table_name}': {e}", exc_info=True + ) + except Exception as e: + logger.critical( + f"Unexpected error during DB write for '{table_name}': {e}", exc_info=True + ) + + # Fallback to backup on any failure + _backup_to_file(table_name, data) + return False + + +def database_connection(build_id: str) -> None: + """Test database connection and set the build ID.""" + logger.info(f"Setting test build ID: {build_id}") + _set_test_build_id(build_id) + + db = _get_db() + if not _db_enabled: + logger.info("Database connection skipped because enabled=false.") + return + + if db is None: + logger.error("No database instance available.") + return + + logger.info(f"Attempting connection to database: {db.database}") + try: + db.connect(reuse_if_open=True) + logger.info("Database connection successful.") + except Exception as e: + logger.error(f"Database connection failed: {e}", exc_info=True) + finally: + if not db.is_closed(): + db.close() + logger.debug("Database connection closed.") diff --git a/test/common/llmperf/__init__.py b/test/common/llmperf/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/common/llmperf/run_inference.py b/test/common/llmperf/run_inference.py new file mode 100644 index 000000000..b04deb1ea --- /dev/null +++ b/test/common/llmperf/run_inference.py @@ -0,0 +1,185 @@ +import json +import os +import random +from pathlib import Path +from typing import Any, Dict, List + +import yaml +from common.llmperf.utils.token_benchmark import run_token_benchmark +from common.llmperf.utils.utils import reset_prefill_cache + + +def run_test_cases( + llm_api, + model, + timeout, + max_num_completed_requests, + concurrent_requests, + mean_input_tokens, + stddev_input, + mean_output_tokens, + stddev_output, + additional_sampling_params, + timestamp_dir, + server_url, + tokenizer_path, + hit_rate, +): + print(f"[INFO] Total {len(mean_input_tokens)} test cases to be executed") + all_summaries = [] + failed_case = [] + + # Clear proxy environment variables + env = os.environ.copy() + env.pop("http_proxy", None) + env.pop("https_proxy", None) + + for i, ( + mean_input, + mean_output, + max_completed, + concurrent, + additional_sampling_params, + hit_rate_val, + ) in enumerate( + zip( + mean_input_tokens, + mean_output_tokens, + max_num_completed_requests, + concurrent_requests, + additional_sampling_params, + hit_rate, + ), + start=1, + ): + # for i, case in enumerate(mean_input_tokens): + print(f"\n>>> Executing test case {i} <<<") + reset_prefill_cache(env, server_url) + # Use a fixed random_seed for each test to control PC hit_rate + random_seed = random.randint(1, 100000) + + try: + # Determine if two runs are needed (PC hit_rate 
test) + if hit_rate_val == 0: + summary = run_token_benchmark( + llm_api=llm_api, + model=model, + test_timeout_s=timeout, + max_num_completed_requests=max_completed, + concurrent_requests=concurrent, + mean_input_tokens=mean_input, + stddev_input_tokens=stddev_input, + mean_output_tokens=mean_output, + stddev_output_tokens=stddev_output, + additional_sampling_params=additional_sampling_params, + results_dir=str(timestamp_dir), + random_seed=random_seed, + openai_api_base=server_url + "/v1", + tokenizer_path=tokenizer_path, + user_metadata={"case_idx": i, "phase": "normal"}, + ) + else: + print( + f"[INFO] hit_rate > 0 detected, entering prefill mode, PC hit rate: {hit_rate_val} %" + ) + # hit_rate > 0: first prefill mode + prefill_mean_input = int(mean_input * hit_rate_val / 100) + print( + f"[INFO] Prefill execution: mean_input_tokens={prefill_mean_input}" + ) + run_token_benchmark( + llm_api=llm_api, + model=model, + test_timeout_s=timeout, + max_num_completed_requests=max_completed, + concurrent_requests=concurrent, + mean_input_tokens=prefill_mean_input, + stddev_input_tokens=stddev_input, + mean_output_tokens=2, + stddev_output_tokens=stddev_output, + additional_sampling_params=additional_sampling_params, + results_dir=str(timestamp_dir), + random_seed=random_seed, + openai_api_base=server_url + "/v1", + tokenizer_path=tokenizer_path, + user_metadata={"case_idx": i, "phase": "prefill"}, + ) + reset_prefill_cache(env, server_url) + # Then run normal mode + print("[INFO] Prefill completed, switching to normal mode execution") + summary = run_token_benchmark( + llm_api=llm_api, + model=model, + test_timeout_s=timeout, + max_num_completed_requests=max_completed, + concurrent_requests=concurrent, + mean_input_tokens=mean_input, + stddev_input_tokens=stddev_input, + mean_output_tokens=mean_output, + stddev_output_tokens=stddev_output, + additional_sampling_params=additional_sampling_params, + results_dir=str(timestamp_dir), + random_seed=random_seed, + openai_api_base=server_url + "/v1", + tokenizer_path=tokenizer_path, + user_metadata={"case_idx": i, "phase": "normal"}, + ) + all_summaries.append(summary) + except Exception as e: + print(f"[Warning] {e}") + failed_case.append(i) + + return all_summaries, failed_case + + +def inference_results( + mean_input_tokens, + mean_output_tokens, + max_num_completed_requests, + concurrent_requests, + additional_sampling_params, + hit_rate, +): + config_file = Path(__file__).parent.parent.parent / "config.yaml" + print("[INFO] Initialization complete, starting main process") + print(f"[INFO] Reading configuration file: {config_file}") + with open(config_file, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + llm_api = config.get("llm_connection", {}).get("llm_api", "openai") + model = config.get("llm_connection", {}).get("model", "") + test_timeout_s = config.get("llm_connection", {}).get("test_timeout_s", 60000) + stddev_input_tokens = config.get("llm_connection", {}).get( + "stddev_input_tokens", 0 + ) + stddev_output_tokens = config.get("llm_connection", {}).get( + "stddev_output_tokens", 0 + ) + timestamp_dir = Path("results") + timestamp_dir.mkdir(parents=True, exist_ok=True) + server_url = config.get("llm_connection", {}).get("server_url", "") + tokenizer_path = config.get("llm_connection", {}).get("tokenizer_path", "") + print(f"[INFO] Created results directory: {timestamp_dir}") + + all_summaries, failed_cases = run_test_cases( + llm_api, + model, + test_timeout_s, + max_num_completed_requests, + concurrent_requests, + 
mean_input_tokens, + stddev_input_tokens, + mean_output_tokens, + stddev_output_tokens, + additional_sampling_params, + timestamp_dir, + server_url, + tokenizer_path, + hit_rate, + ) + total = len(mean_input_tokens) + print( + f"\n[INFO] All tests completed! Success: {total - len(failed_cases)}/{total}" + ) + if failed_cases: + print(f"[WARN] Failed case indices: {failed_cases}") + return all_summaries diff --git a/test/common/llmperf/utils/__init__.py b/test/common/llmperf/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/common/llmperf/utils/common_metrics.py b/test/common/llmperf/utils/common_metrics.py new file mode 100644 index 000000000..40e21124e --- /dev/null +++ b/test/common/llmperf/utils/common_metrics.py @@ -0,0 +1,17 @@ +# TODO (Avnishn): compute metrics in class +INTER_TOKEN_LAT = "inter_token_latency_s" +TTFT = "ttft_s" +E2E_LAT = "end_to_end_latency_s" +NUM_INPUT_TOKENS = "number_input_tokens" +NUM_OUTPUT_TOKENS = "number_output_tokens" +NUM_TOTAL_TOKENS = "number_total_tokens" +REQ_OUTPUT_THROUGHPUT = "request_output_throughput_token_per_s" +ERROR_MSG = "error_msg" +ERROR_CODE = "error_code" +ERROR_CODE_FREQ = "error_code_frequency" +NUM_ERRORS = "number_errors" +OUTPUT_THROUGHPUT = "mean_output_throughput_token_per_s" +NUM_COMPLETED_REQUESTS = "num_completed_requests" +COMPLETED_REQUESTS_PER_MIN = "num_completed_requests_per_min" +ERROR_RATE = "error_rate" +NUM_REQ_STARTED = "num_requests_started" diff --git a/test/common/llmperf/utils/models.py b/test/common/llmperf/utils/models.py new file mode 100644 index 000000000..1cbab6281 --- /dev/null +++ b/test/common/llmperf/utils/models.py @@ -0,0 +1,23 @@ +from typing import Any, Dict, Optional, Tuple + +from pydantic import BaseModel + + +class RequestConfig(BaseModel): + """The configuration for a request to the LLM API. + + Args: + model: The model to use. + prompt: The prompt to provide to the LLM API. + sampling_params: Additional sampling parameters to send with the request. + For more information see the Router app's documentation for the completions + llm_api: The name of the LLM API to send the request to. + metadata: Additional metadata to attach to the request for logging or validation purposes. + """ + + model: str + prompt: Tuple[str, int] + sampling_params: Optional[Dict[str, Any]] = None + llm_api: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + openai_api_base: Optional[str] = "" diff --git a/test/common/llmperf/utils/openai_chat_completions_client.py b/test/common/llmperf/utils/openai_chat_completions_client.py new file mode 100644 index 000000000..5023bfa1f --- /dev/null +++ b/test/common/llmperf/utils/openai_chat_completions_client.py @@ -0,0 +1,136 @@ +import json +import os +import time +from asyncio import timeout +from pathlib import Path +from typing import Any, Dict, Tuple + +import requests +import yaml +from common.llmperf.utils import common_metrics +from common.llmperf.utils.models import RequestConfig + +config_file = Path(__file__).parent.parent.parent.parent / "config.yaml" +with open(config_file, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) +stream = config.get("llm_connection", {}).get("stream", True) +ignore_eos = config.get("llm_connection", {}).get("ignore_eos", True) +timeout = config.get("llm_connection", {}).get("timeout", 180) + + +class OpenAIChatCompletionsClient: + """ + used for sending HTTP requests, receiving token streams, measuring latency, etc. 
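+
+    Per request it records time to first token, summed inter-token latency,
+    end-to-end latency, output throughput, token counts and any error
+    code/message, keyed by the metric names defined in common_metrics, and
+    returns (metrics, generated_text, request_config).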
+ """ + + def llm_request( + self, request_config: RequestConfig + ) -> Tuple[Dict[str, Any], str, RequestConfig]: + prompt, prompt_len = request_config.prompt + + message = [ + {"role": "user", "content": prompt}, + ] + model = request_config.model + body = { + "model": model, + "messages": message, + "stream": stream, + "ignore_eos": ignore_eos, + } + sampling_params = request_config.sampling_params + body.update(sampling_params or {}) + + time_to_next_token = [] + tokens_received = 0 + ttft = 0.0 + error_response_code = None + generated_text = "" + error_msg = "" + output_throughput = 0.0 + total_request_time = 0.0 + flag = False + + metrics: Dict[str, Any] = {} + + metrics[common_metrics.ERROR_CODE] = None + metrics[common_metrics.ERROR_MSG] = "" + + start_time = time.monotonic() + most_recent_received_token_time = start_time + + address = request_config.openai_api_base + + if not address: + raise ValueError("the environment variable OPENAI_API_BASE must be set.") + key = os.environ.get("OPENAI_API_KEY", "secret_abcdefg") + if not key: + raise ValueError("the environment variable OPENAI_API_KEY must be set.") + headers = {"Authorization": f"Bearer {key}"} + if not address.endswith("/"): + address = address + "/" + address += "chat/completions" + try: + with requests.post( + address, + json=body, + stream=stream, + timeout=timeout, + headers=headers, + ) as response: + if response.status_code != 200: + error_msg = response.text + error_response_code = response.status_code + response.raise_for_status() + + for chunk in response.iter_lines(chunk_size=None): + if not chunk: + continue + stem = b"data: " + if chunk.startswith(stem): + chunk = chunk[len(stem) :] + # Data might already be bytes or str + if isinstance(chunk, bytes): + chunk = chunk.decode("utf-8", errors="ignore") + if chunk.strip() == "[DONE]": + continue + tokens_received += 1 + data = json.loads(chunk) + if "error" in data: + error_msg = data["error"]["message"] + error_response_code = data["error"]["code"] + raise RuntimeError(error_msg) + delta = data["choices"][0]["delta"] + content = delta.get("content", None) or delta.get( + "reasoning_content", "" + ) + if content: + if tokens_received != 0 and flag == False: + ttft = time.monotonic() - start_time + flag = True + else: + time_to_next_token.append( + time.monotonic() - most_recent_received_token_time + ) + most_recent_received_token_time = time.monotonic() + generated_text += content + + total_request_time = time.monotonic() - start_time + if total_request_time > 0: + output_throughput = tokens_received / total_request_time + + except Exception as e: + metrics[common_metrics.ERROR_MSG] = error_msg + metrics[common_metrics.ERROR_CODE] = error_response_code + print(f"Warning Or Error: {e}") + print(error_response_code) + + metrics[common_metrics.INTER_TOKEN_LAT] = sum(time_to_next_token) + metrics[common_metrics.TTFT] = ttft + metrics[common_metrics.E2E_LAT] = total_request_time + metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = output_throughput + metrics[common_metrics.NUM_TOTAL_TOKENS] = tokens_received + prompt_len + metrics[common_metrics.NUM_OUTPUT_TOKENS] = tokens_received + metrics[common_metrics.NUM_INPUT_TOKENS] = prompt_len + + return metrics, generated_text, request_config diff --git a/test/common/llmperf/utils/sonnet.txt b/test/common/llmperf/utils/sonnet.txt new file mode 100644 index 000000000..9f13ead47 --- /dev/null +++ b/test/common/llmperf/utils/sonnet.txt @@ -0,0 +1,84 @@ +Shall I compare thee to a summer's day? 
+Thou art more lovely and more temperate: +Rough winds do shake the darling buds of May, +And summer's lease hath all too short a date: +Sometime too hot the eye of heaven shines, +And often is his gold complexion dimm'd; +And every fair from fair sometime declines, +By chance or nature's changing course untrimm'd; +But thy eternal summer shall not fade +Nor lose possession of that fair thou owest; +Nor shall Death brag thou wander'st in his shade, +When in eternal lines to time thou growest: +So long as men can breathe or eyes can see, +So long lives this and this gives life to thee. +Then let not winter's ragged hand deface +In thee thy summer, ere thou be distill'd: +Make sweet some vial; treasure thou some place +With beauty's treasure, ere it be self-kill'd. +That use is not forbidden usury, +Which happies those that pay the willing loan; +That's for thyself to breed another thee, +Or ten times happier, be it ten for one; +Ten times thyself were happier than thou art, +If ten of thine ten times refigured thee: +Then what could death do, if thou shouldst depart, +Leaving thee living in posterity? +Be not self-will'd, for thou art much too fair +To be death's conquest and make worms thine heir. +Where art thou, Muse, that thou forget'st so long +To speak of that which gives thee all thy might? +Spend'st thou thy fury on some worthless song, +Darkening thy power to lend base subjects light? +Return, forgetful Muse, and straight redeem +In gentle numbers time so idly spent; +Sing to the ear that doth thy lays esteem +And gives thy pen both skill and argument. +Rise, resty Muse, my love's sweet face survey, +If Time have any wrinkle graven there; +If any, be a satire to decay, +And make Time's spoils despised every where. +Give my love fame faster than Time wastes life; +So thou prevent'st his scythe and crooked knife. +My glass shall not persuade me I am old, +So long as youth and thou are of one date; +But when in thee time's furrows I behold, +Then look I death my days should expiate. +For all that beauty that doth cover thee +Is but the seemly raiment of my heart, +Which in thy breast doth live, as thine in me: +How can I then be elder than thou art? +O, therefore, love, be of thyself so wary +As I, not for myself, but for thee will; +Bearing thy heart, which I will keep so chary +As tender nurse her babe from faring ill. +Presume not on thy heart when mine is slain; +Thou gavest me thine, not to give back again. +So am I as the rich, whose blessed key +Can bring him to his sweet up-locked treasure, +The which he will not every hour survey, +For blunting the fine point of seldom pleasure. +Therefore are feasts so solemn and so rare, +Since, seldom coming, in the long year set, +Like stones of worth they thinly placed are, +Or captain jewels in the carcanet. +So is the time that keeps you as my chest, +Or as the wardrobe which the robe doth hide, +To make some special instant special blest, +By new unfolding his imprison'd pride. +Blessed are you, whose worthiness gives scope, +Being had, to triumph, being lack'd, to hope. +If there be nothing new, but that which is +Hath been before, how are our brains beguiled, +Which, labouring for invention, bear amiss +The second burden of a former child! +O, that record could with a backward look, +Even of five hundred courses of the sun, +Show me your image in some antique book, +Since mind at first in character was done! 
+That I might see what the old world could say +To this composed wonder of your frame; +Whether we are mended, or whether better they, +Or whether revolution be the same. +O, sure I am, the wits of former days +To subjects worse have given admiring praise. \ No newline at end of file diff --git a/test/common/llmperf/utils/token_benchmark.py b/test/common/llmperf/utils/token_benchmark.py new file mode 100644 index 000000000..67553cf1b --- /dev/null +++ b/test/common/llmperf/utils/token_benchmark.py @@ -0,0 +1,386 @@ +import json +import logging +import random +import re +import time +from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import pandas as pd +from common.llmperf.utils import common_metrics +from common.llmperf.utils.models import RequestConfig +from common.llmperf.utils.openai_chat_completions_client import ( + OpenAIChatCompletionsClient, +) +from common.llmperf.utils.utils import ( + LLMPerfResults, + randomly_sample_sonnet_lines_prompt, + sample_random_positive_int, +) +from transformers import AutoTokenizer + + +def get_token_throughput_latencies( + model: str, + mean_input_tokens: int, + stddev_input_tokens: int, + mean_output_tokens: int, + stddev_output_tokens: int, + additional_sampling_params: Optional[Dict[str, Any]] = None, + concurrent_requests: int = 1, + max_num_completed_requests: int = 500, + test_timeout_s=90, + llm_api="openai", + random_seed: int = None, + openai_api_base: str = "", + tokenizer_path: str = None, +) -> Tuple[Dict[str, Any], List[Dict[str, Any]], float, float]: + """Get the token throughput and latencies for the given model. + + Args: + model: The name of the model to query. + mean_input_tokens: The mean number of tokens to send in the prompt for the request. + stddev_input_tokens: The standard deviation of the number of tokens to send in the prompt for the request. + mean_output_tokens: The mean number of tokens to generate per request. + stddev_output_tokens: The standard deviation of the number of tokens to generate per request. + additional_sampling_params: Additional sampling parameters to send with the request. + For more information see the LLM APIs documentation for the completions + concurrent_requests: The number of concurrent requests to make. Increase + this to increase the amount of load and vice versa. + test_timeout_s: The amount of time to run the test for before reporting results. + llm_api: The name of the llm api to use. Either "openai" or "litellm". + + Returns: + A summary of the performance metrics collected across all completed requests + (e.g. throughput, latencies, etc.) + The individual metrics for each request. + """ + random.seed(random_seed) + + print(f"Using tokenizer:{tokenizer_path}") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + get_token_length = lambda text: len(tokenizer.encode(text)) + + if not additional_sampling_params: + additional_sampling_params = {} + + # 1. 
create prompts + prompts: List[Tuple[str, int]] = [] + num_output_tokens_list: List[int] = [] + for i in range(max_num_completed_requests): + num_output = sample_random_positive_int( + mean_output_tokens, stddev_output_tokens + ) + num_output_tokens_list.append(num_output) + prompts.append( + randomly_sample_sonnet_lines_prompt( + prompt_tokens_mean=mean_input_tokens, + prompt_tokens_stddev=stddev_input_tokens, + tokenizer=tokenizer, + ) + ) + start_time = time.monotonic() + completed_requests: List[Dict[str, Any]] = [] + incremental_time_delay = 0.0 + client = OpenAIChatCompletionsClient() + futures = [] + + # 2. Submitting tasks using a thread pool + with ThreadPoolExecutor(max_workers=concurrent_requests) as executor: + for idx in range(max_num_completed_requests): + sampling = {"max_tokens": num_output_tokens_list[idx]} + sampling.update(additional_sampling_params) + cfg = RequestConfig( + model=model, + prompt=prompts[idx], + sampling_params=sampling, + llm_api=llm_api, + openai_api_base=openai_api_base, + ) + futures.append(executor.submit(client.llm_request, cfg)) + # 3. Waiting for completion or timeout + for future in as_completed(futures, timeout=test_timeout_s): + try: + metrics, gen_text, req_cfg = future.result() + except Exception as e: + logging.warning(f"[WARN] Future raised exception: {e}") + continue + num_output_tokens = get_token_length(gen_text) + if num_output_tokens: + metrics[common_metrics.INTER_TOKEN_LAT] /= ( + (metrics[common_metrics.NUM_OUTPUT_TOKENS] - 1) + if (metrics[common_metrics.NUM_OUTPUT_TOKENS] - 1) + else 1 + ) + metrics[common_metrics.NUM_OUTPUT_TOKENS] = num_output_tokens + metrics[common_metrics.NUM_TOTAL_TOKENS] = ( + metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens + ) + try: + metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = ( + num_output_tokens / metrics[common_metrics.E2E_LAT] + ) + except ZeroDivisionError: + logging.error("Division by zero in throughput calculation.") + + completed_requests.append(metrics) + + incremental_time_delay += metrics.get( + common_metrics.INTER_TOKEN_LAT, 0.0 + ) + + end_time = time.monotonic() + + print(f"Results for token benchmark for {model} queried with the {llm_api} api.\n") + if mean_output_tokens == 2: + print(f"[INFO] First token sending pre-embedding completed\n") + return {}, [], 0.0, 0.0 + + ret = metrics_summary(completed_requests, start_time, end_time) + + metadata = { + "model": model, + "mean_input_tokens": mean_input_tokens, + "stddev_input_tokens": stddev_input_tokens, + "mean_output_tokens": mean_output_tokens, + "stddev_output_tokens": stddev_output_tokens, + "concurrent_requests": concurrent_requests, + "additional_sampling_params": additional_sampling_params, + } + + metadata["results"] = ret + elapsed_time = end_time - start_time + return metadata, completed_requests, elapsed_time, incremental_time_delay + + +def compute_throughput( + summary: Dict[str, Any], + completed_requests: List[Dict[str, Any]], + elapsed_time: float, + incremental_time_delay: float, +) -> Tuple[float, float]: + """ + Compute total_throughput (token/s) based on the metrics in summary. + + Formula: (mean_output_tokens * num_completed_requests) / total_e2e_latency_s + + Args: + summary (Dict[str, Any]): A dictionary containing performance metrics. + + Returns: + float: The computed total throughput in tokens per second. Returns 0.0 if latency is zero. 
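+
+    Note:
+        Two values are returned as a tuple: total_throughput divides
+        mean_output_tokens * num_completed_requests by elapsed_time, while
+        incremental_throughput divides the same numerator by
+        incremental_time_delay; both are rounded to 4 decimal places.
+
+    Example (illustrative values):
+        >>> compute_throughput({"mean_output_tokens": 200}, [{}] * 8, 20.0, 10.0)
+        (80.0, 160.0)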
+ """ + mean_output_tokens = summary.get("mean_output_tokens", 0) + + total_throughput = ( + (mean_output_tokens * len(completed_requests)) / elapsed_time + if elapsed_time > 0 + else 0.0 + ) + incremental_throughput = ( + (mean_output_tokens * len(completed_requests)) / incremental_time_delay + if incremental_time_delay > 0 + else 0.0 + ) + return round(total_throughput, 4), round(incremental_throughput, 4) + + +def metrics_summary( + metrics: List[Dict[str, Any]], start_time: int, end_time: int +) -> Dict[str, Any]: + """Generate a summary over metrics generated from potentially multiple instances of this client. + + Args: + metrics: The metrics to summarize. + start_time: The time the test started. + end_time: The time the test ended. + + Returns: + A summary with the following information: + - Overall throughput (generated tokens / total test time) + - Number of completed requests + - Error rate + - Error code frequency + - Quantiles (p25-p99) for the following metrics: + - Inter token latency + - Time to first token + - User total request time + - Number of tokens processed per request + - Number of tokens generated per request + - User throughput (tokens / s) + """ + ret = {} + + def flatten(item): + for sub_item in item: + if isinstance(sub_item, Iterable) and not isinstance(sub_item, str): + yield from flatten(sub_item) + else: + yield sub_item + + df = pd.DataFrame(metrics) + df_without_errored_req = df[df[common_metrics.ERROR_CODE].isna()] + + for key in [ + common_metrics.INTER_TOKEN_LAT, + common_metrics.TTFT, + common_metrics.E2E_LAT, + common_metrics.REQ_OUTPUT_THROUGHPUT, + common_metrics.NUM_INPUT_TOKENS, + common_metrics.NUM_OUTPUT_TOKENS, + ]: + print(key) + ret[key] = {} + series = pd.Series(list(flatten(df_without_errored_req[key]))).dropna() + series = series[series > 0] # Calculate non-zero values + quantiles = series.quantile([0.25, 0.5, 0.75, 0.9, 0.95, 0.99]).to_dict() + quantiles_reformatted_keys = {} + for quantile, value in quantiles.items(): + reformatted_key = f"p{int(quantile * 100)}" + print(f" {reformatted_key} = {value}") + quantiles_reformatted_keys[reformatted_key] = value + ret[key]["quantiles"] = quantiles_reformatted_keys + mean = series.mean() + print(f" mean = {mean}") + ret[key]["mean"] = mean + print(f" min = {series.min()}") + ret[key]["min"] = series.min() + print(f" max = {series.max()}") + ret[key]["max"] = series.max() + print(f" stddev = {series.std()}") + ret[key]["stddev"] = series.std() + + ret[common_metrics.NUM_REQ_STARTED] = len(metrics) + + error_codes = df[common_metrics.ERROR_CODE].dropna() + num_errors = len(error_codes) + ret[common_metrics.ERROR_RATE] = num_errors / len(metrics) if len(metrics) else 0 + ret[common_metrics.NUM_ERRORS] = num_errors + print(f"Number Of Errored Requests: {num_errors}") + error_code_frequency = dict(error_codes.value_counts()) + if num_errors: + error_code_frequency = dict(error_codes.value_counts()) + print("Error Code Frequency") + print(error_code_frequency) + ret[common_metrics.ERROR_CODE_FREQ] = str(error_code_frequency) + + overall_output_throughput = df_without_errored_req[ + common_metrics.NUM_OUTPUT_TOKENS + ].sum() / (end_time - start_time) + + print(f"Overall Output Throughput: {overall_output_throughput}") + ret[common_metrics.OUTPUT_THROUGHPUT] = overall_output_throughput + + num_completed_requests = len(df_without_errored_req) + num_completed_requests_per_min = ( + num_completed_requests / (end_time - start_time) * 60 + ) + print(f"Number Of Completed Requests: {num_completed_requests}") 
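+    # num_completed_requests counts only requests without an error code; the
+    # per-minute figure below scales it by the overall wall-clock duration.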
+ print(f"Completed Requests Per Minute: {num_completed_requests_per_min}") + + ret[common_metrics.NUM_COMPLETED_REQUESTS] = num_completed_requests + ret[common_metrics.COMPLETED_REQUESTS_PER_MIN] = num_completed_requests_per_min + + return ret + + +def run_token_benchmark( + llm_api: str, + model: str, + test_timeout_s: int, + max_num_completed_requests: int, + concurrent_requests: int, + mean_input_tokens: int, + stddev_input_tokens: int, + mean_output_tokens: int, + stddev_output_tokens: int, + additional_sampling_params: str, + results_dir: str, + random_seed: int, + openai_api_base: str, + tokenizer_path: str, + user_metadata: Dict[str, Any], +): + """ + Args: + llm_api: The name of the llm api to use. + model: The name of the model to query. + max_num_completed_requests: The number of requests to complete before finishing the test. + test_timeout_s: The amount of time to run the test for before reporting results. + concurrent_requests: The number of concurrent requests to make. Increase + this to increase the amount of load and vice versa. + mean_input_tokens: The mean number of tokens to send in the prompt for the request. + stddev_input_tokens: The standard deviation of the number of tokens to send in the prompt for the request. + mean_output_tokens: The mean number of tokens to generate per request. + stddev_output_tokens: The standard deviation of the number of tokens to generate per request. + additional_sampling_params: Additional sampling parameters to send with the request. + For more information see the LLM APIs documentation for the completions. + results_dir: The directory to save the results to. + user_metadata: Additional metadata to include in the results. + """ + if mean_input_tokens < 40: + print( + "the minimum number of input tokens that will be sent is 41" + " because of the prompting logic right now" + ) + + summary, completed_requests, elapsed_time, incremental_time_delay = ( + get_token_throughput_latencies( + model=model, + llm_api=llm_api, + test_timeout_s=test_timeout_s, + max_num_completed_requests=max_num_completed_requests, + mean_input_tokens=mean_input_tokens, + stddev_input_tokens=stddev_input_tokens, + mean_output_tokens=mean_output_tokens, + stddev_output_tokens=stddev_output_tokens, + concurrent_requests=concurrent_requests, + additional_sampling_params=json.loads(additional_sampling_params), + random_seed=random_seed, + openai_api_base=openai_api_base, + tokenizer_path=tokenizer_path, + ) + ) + if mean_output_tokens == 2: + return summary, completed_requests, elapsed_time, incremental_time_delay + + timestamp = int(time.time() * 1000) + if results_dir: + filename = f"{model}_{mean_input_tokens}_{mean_output_tokens}_{timestamp}" + filename = re.sub(r"[^\w\d-]+", "-", filename) + filename = re.sub(r"-{2,}", "-", filename) + summary_filename = f"{filename}_summary" + + # Update to metadata. 
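+        # The summary gains the caller's metadata plus request counts, timing and
+        # throughput figures, then is written to
+        # <results_dir>/llmperf/<summary_filename>.json.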
+ summary.update(user_metadata) + total_tp, req_tp = compute_throughput( + summary, completed_requests, elapsed_time, incremental_time_delay + ) + summary["num_completed_requests"] = len(completed_requests) + summary["elapsed_time"] = elapsed_time + summary["incremental_time_delay"] = incremental_time_delay + summary["total_throughput"] = total_tp + summary["incremental_throughput"] = req_tp + + results = LLMPerfResults(name=summary_filename, metadata=summary) + results_dir = Path(results_dir) + if not results_dir.exists(): + results_dir.mkdir(parents=True) + elif not results_dir.is_dir(): + raise ValueError(f"{results_dir} is not a directory") + + llmperf_dir = results_dir / "llmperf" + if not llmperf_dir.exists(): + llmperf_dir.mkdir(parents=True) + elif not llmperf_dir.is_dir(): + raise ValueError(f"{llmperf_dir} is not a directory") + + try: + with open(llmperf_dir / f"{summary_filename}.json", "w") as f: + json.dump(results.to_dict(), f, indent=4, default=str) + except Exception as e: + print(results.to_dict()) + raise e + return summary diff --git a/test/common/llmperf/utils/utils.py b/test/common/llmperf/utils/utils.py new file mode 100644 index 000000000..e2c270871 --- /dev/null +++ b/test/common/llmperf/utils/utils.py @@ -0,0 +1,171 @@ +import hashlib +import json +import math +import os +import pathlib +import random +import subprocess +import time +from typing import Any, Dict, Tuple + +from transformers import LlamaTokenizerFast + +RESULTS_VERSION = "2025-10-30" + + +class LLMPerfResults: + def __init__( + self, + name: str, + metadata: Dict[str, Any] = None, + ): + self.name = name + self.metadata = metadata or {} + self.timestamp = int(time.time()) + self.metadata["timestamp"] = self.timestamp + self.version = RESULTS_VERSION + + def to_dict(self): + data = { + "version": self.version, + "name": self.name, + } + data.update(self.metadata) + data = flatten_dict(data) + return data + + def json(self): + data = self.to_dict() + return json.dumps(data) + + +def upload_to_s3(results_path: str, s3_path: str) -> None: + """Upload the results to s3. + + Args: + results_path: The path to the results file. + s3_path: The s3 path to upload the results to. + + """ + + command = ["aws", "s3", "sync", results_path, f"{s3_path}/"] + result = subprocess.run(command) + if result.returncode == 0: + print("Files uploaded successfully!") + else: + print("An error occurred:") + print(result.stderr) + + +def randomly_sample_sonnet_lines_prompt( + prompt_tokens_mean: int = 550, + prompt_tokens_stddev: int = 250, + tokenizer: LlamaTokenizerFast = None, +) -> Tuple[str, int]: + """Generate a prompt that randomly samples lines from a the shakespeare sonnet at sonnet.txt. + + Args: + prompt_length_mean: The mean length of the prompt to generate. + prompt_len_stddev: The standard deviation of the length of the prompt to generate. + expect_output_tokens: The number of tokens to expect in the output. This is used to + determine the length of the prompt. The prompt will be generated such that the output + will be approximately this many tokens. + + Note: + tokens will be counted from the sonnet using the Llama tokenizer. Using one tokenizer + ensures a fairer comparison across different LLMs. For example, if gpt 3.5 tokenizes + a prompt in less tokens than Llama2, then this will be reflected in the results since + they will be fed identical prompts. + + Returns: + A tuple of the prompt and the length of the prompt. 
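+        The returned length is the sampled token target; because the final
+        sonnet line is truncated by characters rather than tokens, the exact
+        tokenized length of the assembled prompt may differ slightly.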
+ """ + get_token_length = lambda text: len(tokenizer.encode(text)) + + prompt = ( + "Randomly stream lines from the following text " + "Don't generate eos tokens:\n\n" + ) + # get a prompt length that is at least as long as the base + num_prompt_tokens = sample_random_positive_int( + prompt_tokens_mean, prompt_tokens_stddev + ) + while num_prompt_tokens < get_token_length(prompt): + num_prompt_tokens = sample_random_positive_int( + prompt_tokens_mean, prompt_tokens_stddev + ) + remaining_prompt_tokens = num_prompt_tokens - get_token_length(prompt) + sonnet_path = pathlib.Path(__file__).parent.resolve() / "sonnet.txt" + with open(sonnet_path, "r") as f: + sonnet_lines = f.readlines() + random.shuffle(sonnet_lines) + sampling_lines = True + while sampling_lines: + for line in sonnet_lines: + line_to_add = line + if remaining_prompt_tokens - get_token_length(line_to_add) < 0: + # This will cut off a line in the middle of a word, but that's ok since an + # llm should be able to handle that. + line_to_add = line_to_add[: int(math.ceil(remaining_prompt_tokens))] + sampling_lines = False + prompt += line_to_add + break + prompt += line_to_add + remaining_prompt_tokens -= get_token_length(line_to_add) + print(hashlib.sha256(prompt.encode("utf-8")).hexdigest()) + return (prompt, num_prompt_tokens) + + +def sample_random_positive_int(mean: int, stddev: int) -> int: + """Sample random numbers from a gaussian distribution until a positive number is sampled. + + Args: + mean: The mean of the gaussian distribution to sample from. + stddev: The standard deviation of the gaussian distribution to sample from. + + Returns: + A random positive integer sampled from the gaussian distribution. + """ + ret = -1 + while ret <= 0: + ret = int(random.gauss(mean, stddev)) + return ret + + +def flatten_dict(d, parent_key="", sep="_"): + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def reset_prefill_cache(env, server_url): + """ + prefix cache / HBM + Param: + env + server_url + """ + reset_url = f"{server_url}/reset_prefix_cache" + print(f"[INFO] Resetting prefix cache: {reset_url}") + try: + result = subprocess.run( + ["curl", "-X", "POST", reset_url, "-s", "-f"], + env=env, + check=False, + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0: + print("[INFO] Prefix cache successfully reset") + else: + print( + f"[ERROR] Unsuccessfully reset prefix cache,error code: {result.returncode}" + ) + except Exception as e: + print(f"[ERROR] Exception in resetting prefix cache: {e}") diff --git a/test/config.yaml b/test/config.yaml new file mode 100644 index 000000000..7ac32f484 --- /dev/null +++ b/test/config.yaml @@ -0,0 +1,27 @@ +reports: + base_dir: "results/reports" + use_timestamp: true + directory_prefix: "pytest" + html: # pytest-html + enabled: true + filename: "report.html" + title: "UCM Pytest Test Report" + +database: + backup: "results/" + enabled: true + host: "127.0.0.1" + port: 3306 + name: "ucm_pytest" + user: "root" + password: "123456" + charset: "utf8mb4" + +# LLM Connection Configuration +llm_connection: + model: "qwen3" + server_url: "http://141.111.32.70:9382" + tokenizer_path: "/home/models/QwQ-32B" + stream: true # stream output + ignore_eos: true # Ignore the returned terminator + timeout: 180 # request time out \ No newline at end of file diff --git a/test/conftest.py 
b/test/conftest.py new file mode 100644 index 000000000..150257952 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import datetime as dt +import platform as pf +import sys +from functools import wraps +from pathlib import Path + +import pytest +from common.config_utils import config_utils as config_instance +from common.db_utils import database_connection, write_to_db + +# ---------------- Constants ---------------- +PRJ_ROOT = Path(__file__).resolve().parent +sys.path.insert(0, str(PRJ_ROOT)) + + +# ---------------- CLI Options ---------------- +def pytest_addoption(parser): + parser.addoption( + "--stage", action="store", default="", help="Filter by stage marker (1,2,3,+)" + ) + parser.addoption( + "--feature", action="store", default="", help="Filter by feature marker" + ) + parser.addoption( + "--platform", action="store", default="", help="Filter by platform marker" + ) + + +# ---------------- Test Filtering ---------------- +def pytest_collection_modifyitems(config, items): + kept = items[:] + + markers = [m.split(":", 1)[0].strip() for m in config.getini("markers")] + for name in markers: + opt = config.getoption(f"--{name}", "").strip() + if not opt: + continue + + if name == "stage" and opt.endswith("+"): + min_stage = int(opt[:-1]) + kept = [ + it + for it in kept + if any(int(v) >= min_stage for v in _get_marker_args(it, "stage")) + ] + else: + wanted = {x.strip() for x in opt.split(",") if x.strip()} + kept = [ + it + for it in kept + if any(v in wanted for v in _get_marker_args(it, name)) + ] + + config.hook.pytest_deselected(items=[i for i in items if i not in kept]) + items[:] = kept + + +def _get_marker_args(item, marker_name): + """Extract only args (not kwargs) from markers, as strings.""" + return [ + str(arg) for mark in item.iter_markers(name=marker_name) for arg in mark.args + ] + + +# ---------------- Report Setup ---------------- +def _prepare_report_dir(config: pytest.Config) -> Path: + cfg = config_instance.get_config("reports", {}) + base_dir = Path(cfg.get("base_dir", "reports")) + prefix = cfg.get("directory_prefix", "pytest") + if cfg.get("use_timestamp", False): + ts = dt.datetime.now().strftime("%Y%m%d_%H%M%S") + report_dir = base_dir / f"{prefix}_{ts}" + else: + report_dir = base_dir + report_dir.mkdir(parents=True, exist_ok=True) + return report_dir + + +def _setup_html_report(config: pytest.Config, report_dir: Path) -> None: + reports_config = config_instance.get_config("reports", {}) + html_cfg = reports_config.get("html", {}) + if not html_cfg.get("enabled", True): + if hasattr(config.option, "htmlpath"): + config.option.htmlpath = None + print("HTML report disabled according to config.yaml") + return + + html_filename = html_cfg.get("filename", "report.html") + config.option.htmlpath = str(report_dir / html_filename) + config.option.self_contained_html = True + print("HTML report enabled") + + +# ---------------- Build ID & Session Init ---------------- +def _generate_build_id(config: pytest.Config) -> str: + ts = dt.datetime.now().strftime("%Y-%m-%d_%H:%M:%S") + cli_parts = [] + markers = [m.split(":", 1)[0].strip() for m in config.getini("markers")] + for opt in markers: + val = config.getoption(opt, "") + if val: + cli_parts.append(f"{opt}={val}") + args_part = "_".join(cli_parts) if cli_parts else "all_cases" + return f"pytest_{ts}_{args_part}" + + +# ---------------- Pytest Hooks ---------------- +def pytest_configure(config: pytest.Config) -> None: + """The global configuration will be executed 
directly upon entering pytest.""" + print(f"Starting Test Session: {dt.datetime.now():%Y-%m-%d %H:%M:%S}") + + # Set up report directory + report_dir = _prepare_report_dir(config) + config._report_dir = report_dir # Attach to config for later use + _setup_html_report(config, report_dir) + + # Generate and register build ID into DB + build_id = _generate_build_id(config) + config._build_id = build_id + database_connection(build_id) + + +def pytest_sessionstart(session): + print("") + print("-" * 60) + print(f"{'Python':<10} │ {pf.python_version()}") + print(f"{'Platform':<10} │ {pf.system()} {pf.release()}") + print("-" * 60) + + +def pytest_sessionfinish(session, exitstatus): + report_dir = getattr(session.config, "_report_dir", "reports") + print("") + print("-" * 60) + print(f"{'Reports at':<10} │ {report_dir}") + print("Test session ended") + print("-" * 60) + + +# ---------------- Fixtures ---------------- + + +def pytest_runtest_logreport(report): + """ + Called after each test phase. We only care about 'call' (the actual test). + """ + if report.when != "call": + return + + status = report.outcome.upper() # 'passed', 'failed', 'skipped' → 'PASSED', etc. + test_result = { + "test_case": report.nodeid, + "status": status, + # "duration": report.duration, + "error": str(report.longrepr) if report.failed else None, + } + write_to_db("test_case_info", test_result) diff --git a/test/pytest.ini b/test/pytest.ini new file mode 100644 index 000000000..4be3cf477 --- /dev/null +++ b/test/pytest.ini @@ -0,0 +1,25 @@ +[pytest] +testpaths = suites +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +addopts = + -ra + --strict-markers + --capture=no +filterwarnings = + ignore::pytest.PytestReturnNotNoneWarning + +log_cli = 1 +log_cli_level = INFO +log_cli_format = [%(levelname)s] %(name)s: %(message)s +norecursedirs = .git venv env __pycache__ *.egg + +markers = + # -------- Levels (Required) -------- + stage(n): Unit/Smoke/Regression/Release (0=Unit 1=Smoke 2=Regression 3=Release) + # -------- Features (Recommended) -------- + feature: Feature tag + platform(name): Platform tag(gpu/npu) +# end of markers \ No newline at end of file diff --git a/test/requirements.txt b/test/requirements.txt new file mode 100644 index 000000000..07635b247 --- /dev/null +++ b/test/requirements.txt @@ -0,0 +1,6 @@ +pytest>=7.0.0 +pytest-html>=3.1.1 +PyYAML>=6.0 +# MySQL +peewee>=3.14.5 +pymysql>=1.0.2 \ No newline at end of file diff --git a/test/suites/E2E/test_uc_performance.py b/test/suites/E2E/test_uc_performance.py new file mode 100644 index 000000000..dbec0318b --- /dev/null +++ b/test/suites/E2E/test_uc_performance.py @@ -0,0 +1,158 @@ +import pytest +from common.capture_utils import export_vars +from common.llmperf.run_inference import inference_results + + +@pytest.mark.parametrize("mean_input_tokens", [[2000, 3000]]) +@pytest.mark.parametrize("mean_output_tokens", [[200, 500]]) +@pytest.mark.parametrize("max_num_completed_requests", [[8, 4]]) +@pytest.mark.parametrize("concurrent_requests", [[8, 4]]) +@pytest.mark.parametrize("additional_sampling_params", [["{}", "{}"]]) +@pytest.mark.parametrize("hit_rate", [[0, 50]]) +@pytest.mark.feature("uc_performance_test") +@export_vars +def test_performance( + mean_input_tokens, + mean_output_tokens, + max_num_completed_requests, + concurrent_requests, + additional_sampling_params, + hit_rate, +): + all_summaries = inference_results( + mean_input_tokens, + mean_output_tokens, + max_num_completed_requests, + concurrent_requests, + 
additional_sampling_params, + hit_rate, + ) + failed_cases = [] + + value_lists = { + "mean_input_tokens": [], + "mean_output_tokens": [], + "results_inter_token_latency_s_quantiles_p50": [], + "results_inter_token_latency_s_quantiles_p90": [], + "results_inter_token_latency_s_quantiles_p99": [], + "results_inter_token_latency_s_mean": [], + "results_ttft_s_quantiles_p50": [], + "results_ttft_s_quantiles_p90": [], + "results_ttft_s_quantiles_p99": [], + "results_ttft_s_mean": [], + "results_end_to_end_latency_s_quantiles_p50": [], + "results_end_to_end_latency_s_quantiles_p90": [], + "results_end_to_end_latency_s_quantiles_p99": [], + "results_end_to_end_latency_s_mean": [], + "num_completed_requests": [], + "elapsed_time": [], + "incremental_time_delay": [], + "total_throughput": [], + "incremental_throughput": [], + } + + for i, summary in enumerate(all_summaries): + mean_input_tokens = summary["mean_input_tokens"] + mean_output_tokens = summary["mean_output_tokens"] + + results_inter_token_latency_s_quantiles_p50 = summary["results"][ + "inter_token_latency_s" + ]["quantiles"]["p50"] + results_inter_token_latency_s_quantiles_p90 = summary["results"][ + "inter_token_latency_s" + ]["quantiles"]["p90"] + results_inter_token_latency_s_quantiles_p99 = summary["results"][ + "inter_token_latency_s" + ]["quantiles"]["p99"] + results_inter_token_latency_s_mean = summary["results"][ + "inter_token_latency_s" + ]["mean"] + + results_ttft_s_quantiles_p50 = summary["results"]["ttft_s"]["quantiles"]["p50"] + results_ttft_s_quantiles_p90 = summary["results"]["ttft_s"]["quantiles"]["p90"] + results_ttft_s_quantiles_p99 = summary["results"]["ttft_s"]["quantiles"]["p99"] + results_ttft_s_mean = summary["results"]["ttft_s"]["mean"] + + results_end_to_end_latency_s_quantiles_p50 = summary["results"][ + "end_to_end_latency_s" + ]["quantiles"]["p50"] + results_end_to_end_latency_s_quantiles_p90 = summary["results"][ + "end_to_end_latency_s" + ]["quantiles"]["p90"] + results_end_to_end_latency_s_quantiles_p99 = summary["results"][ + "end_to_end_latency_s" + ]["quantiles"]["p99"] + results_end_to_end_latency_s_mean = summary["results"]["end_to_end_latency_s"][ + "mean" + ] + + num_completed_requests = summary["num_completed_requests"] + elapsed_time = summary["elapsed_time"] + incremental_time_delay = summary["incremental_time_delay"] + total_throughput = summary["total_throughput"] + incremental_throughput = summary["incremental_throughput"] + + values = [ + mean_input_tokens, + mean_output_tokens, + results_inter_token_latency_s_quantiles_p50, + results_inter_token_latency_s_quantiles_p90, + results_inter_token_latency_s_quantiles_p99, + results_inter_token_latency_s_mean, + results_ttft_s_quantiles_p50, + results_ttft_s_quantiles_p90, + results_ttft_s_quantiles_p99, + results_ttft_s_mean, + results_end_to_end_latency_s_quantiles_p50, + results_end_to_end_latency_s_quantiles_p90, + results_end_to_end_latency_s_quantiles_p99, + results_end_to_end_latency_s_mean, + num_completed_requests, + elapsed_time, + incremental_time_delay, + total_throughput, + incremental_throughput, + ] + + for var_name, val in zip( + [ + "mean_input_tokens", + "mean_output_tokens", + "results_inter_token_latency_s_quantiles_p50", + "results_inter_token_latency_s_quantiles_p90", + "results_inter_token_latency_s_quantiles_p99", + "results_inter_token_latency_s_mean", + "results_ttft_s_quantiles_p50", + "results_ttft_s_quantiles_p90", + "results_ttft_s_quantiles_p99", + "results_ttft_s_mean", + 
"results_end_to_end_latency_s_quantiles_p50", + "results_end_to_end_latency_s_quantiles_p90", + "results_end_to_end_latency_s_quantiles_p99", + "results_end_to_end_latency_s_mean", + "num_completed_requests", + "elapsed_time", + "incremental_time_delay", + "total_throughput", + "incremental_throughput", + ], + values, + ): + value_lists[var_name].append(val) + if val is None: + failed_cases.append((i, var_name, "missing")) + + try: + assert val > 0, f"value <= 0" + except AssertionError as e: + failed_cases.append((i, var_name, str(e))) + + # Output final result + if failed_cases: + print(f"\n[WARNING] Assertion failed: {len(failed_cases)} abnormal cases found") + for i, key, reason in failed_cases: + print(f" Iteration={i + 1}, key='{key}' -> {reason}") + else: + print("\n[INFO] All values are greater than 0. Assertion passed!") + + return {"_name": "llmperf", "_data": value_lists} diff --git a/test/test_uc_connector.py b/test/test_uc_connector.py index d4a0caeb7..0c2261d87 100644 --- a/test/test_uc_connector.py +++ b/test/test_uc_connector.py @@ -25,6 +25,7 @@ import random import secrets import unittest +from collections import defaultdict from typing import List, Union from unittest.mock import MagicMock, Mock, patch @@ -106,12 +107,14 @@ def init_uc( ucconnector.dump_tasks: dict[str, dict[str, List[Task]]] = {} ucconnector.total_tp_size = self.total_tp_size ucconnector._connector_metadata = metadata - ucconnector.layerwise_load_tasks: dict[ - str, dict[str, tuple[Task, Task]] - ] = {} + ucconnector.layerwise_load_tasks: dict[str, dict[str, Task]] = defaultdict( + dict + ) ucconnector._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {} ucconnector._load_failed_reqs: set[str] = set() ucconnector._load_req_to_blocks: dict[str, set[int]] = {} + ucconnector.num_layers = 48 + ucconnector.is_mla = False return ucconnector def test_get_num_new_matched_tokens_hit_all_on_storage(self): @@ -508,6 +511,7 @@ def test_wait_for_save_not_layerwise_invalid_para(self): ucconnector.block_size = self.block_size ucconnector.use_layerwise = False ucconnector._connector_metadata = Mock() + ucconnector.is_mla = False with self.assertRaises(AssertionError): ucconnector.wait_for_save() @@ -542,6 +546,7 @@ def mock_wait(task: Task) -> int: ) forward_context = Mock() ucconnector.start_load_kv(forward_context) + assert mock_connector.load.call_count == 1 def test_start_load_kv_invalid_para(self): with patch.object(UnifiedCacheConnectorV1, "__init__", return_value=None): @@ -559,6 +564,7 @@ def test_start_load_kv_layerwise_success(self): req_meta1.load_blocks = [ (secrets.token_hex(8), i) for i in range(self.block_number) ] + req_meta1.load_async = False metadata = UCConnectorV1Metadata() metadata.requests = [req_meta1] @@ -575,7 +581,7 @@ def mock_load( ucconnector = self.init_uc(mock_connector, metadata=metadata) forward_context = Mock() ucconnector.start_load_kv(forward_context) - assert mock_connector.load.call_count == 2 * self.num_layers + assert mock_connector.load.call_count == self.num_layers if __name__ == "__main__": diff --git a/test/test_ucm_connector_save_load.py b/test/test_ucm_connector_save_load.py new file mode 100644 index 000000000..c0def663f --- /dev/null +++ b/test/test_ucm_connector_save_load.py @@ -0,0 +1,599 @@ +# -*- coding: utf-8 -*- +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +""" +Standalone bandwidth test for `UCMConnector`. + +This script instantiates the connector exactly like runtime code (no mocks) and +benchmarks `wait_for_save` (dump) and `start_load_kv` (load). +""" + +import csv +import math +import multiprocessing +import os +import secrets +import time +import traceback +from dataclasses import dataclass +from typing import Dict, List, Tuple, Union +from unittest.mock import patch + +import torch +from vllm.config import ( + CacheConfig, + DeviceConfig, + KVTransferConfig, + ModelConfig, + ParallelConfig, + VllmConfig, +) +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole + +from ucm.integration.vllm.ucm_connector import ( + RequestDispatchMeta, + UCMConnector, + UCMConnectorMetadata, +) +from ucm.logger import init_logger + +logger = init_logger(__name__) + + +def make_aligned_tensor(shape, dtype, device, alignment=4096): + numel = math.prod(shape) + dtype_size = torch.tensor(1, dtype=dtype).element_size() + total_bytes = numel * dtype_size + + padded_bytes = total_bytes + alignment + storage = torch.ByteTensor(padded_bytes).to(device) + + ptr = storage.data_ptr() + offset = ptr % alignment + aligned_ptr = ptr + (alignment - offset) if offset != 0 else ptr + + aligned_storage = storage[(aligned_ptr - ptr) :].view(dtype) + tensor = aligned_storage[:numel].view(shape) + tensor.storage_ref = storage + return tensor + + +def make_buffers( + block_number: int, + device_id: int, + batch_size: int, + head_dim: int, + block_len: int, + block_layer: int, + num_head: int, + kv: int, + is_mla: bool, +) -> Tuple[List[str], Dict[str, torch.Tensor]]: + logger.info(f"Allocating buffers: blocks={block_number}, batch_size={batch_size}") + hashes = [secrets.token_hex(16) for _ in range(block_number)] + device = f"cuda:{device_id}" + kv_caches: Dict[str, torch.Tensor] = {} + + for layer in range(block_layer): + layer_name = f"layer.{layer}" + if is_mla: + kv_caches[layer_name] = make_aligned_tensor( + [block_number, block_len, head_dim], + dtype=torch.float16, + device=device, + ) + else: + kv_caches[layer_name] = make_aligned_tensor( + [kv, block_number, block_len, num_head, head_dim], + dtype=torch.float16, + device=device, + ) + return hashes, kv_caches + + +def build_vllm_config( + *, + model_path: str, + block_size: int, + num_layers: int, + num_head: int, + head_size: int, + is_mla: bool, + tp_size: int, + connector_name: str, + storage_backends: str, + transfer_stream_number: 
int, + use_direct: bool, +) -> VllmConfig: + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=4, + cache_dtype="auto", + ) + + # This ensures connector uses test parameters for head_size, num_head, num_layers + hf_overrides = { + "head_dim": head_size, # Override head_size for get_head_size() + "num_key_value_heads": num_head, # Override num_head for get_num_kv_heads() + "num_hidden_layers": num_layers, # Override num_layers for get_num_layers() + } + if is_mla: + # head_dim = kv_lora_rank + qk_rope_head_dim (typically 512 + 64 = 576) + # For testing purposes, we set kv_lora_rank = head_size - 64 + kv_lora_rank = head_size - 64 # qk_rope_head_dim = 64 + hf_overrides.update( + { + "model_type": "deepseek_v3", + "kv_lora_rank": kv_lora_rank, + "qk_rope_head_dim": 64, + } + ) + + model_config = ModelConfig( + model=model_path, + tokenizer=None, + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + max_model_len=8192, + max_context_len_to_capture=8192, + max_logprobs=20, + disable_sliding_window=False, + skip_tokenizer_init=True, + limit_mm_per_prompt={}, + use_async_output_proc=True, + override_neuron_config={}, + config_format="auto", + is_deepseek_mla=is_mla, + hf_overrides=hf_overrides, + ) + + parallel_config = ParallelConfig( + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + worker_use_ray=False, + ) + + device = "cuda" if torch.cuda.is_available() else "npu" + device_config = DeviceConfig(device=device) + + kv_transfer_config = KVTransferConfig( + kv_connector="UCMConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "ucm_connectors": [ + { + "ucm_connector_name": connector_name, + "ucm_connector_config": { + "storage_backends": storage_backends, + "use_direct": use_direct, + "stream_number": transfer_stream_number, + "local_rank_size": 1, + }, + } + ] + }, + ) + + return VllmConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + device_config=device_config, + kv_transfer_config=kv_transfer_config, + ) + + +@dataclass +class DummyLayer: + kv_cache: Union[Dict[int, torch.Tensor], List[torch.Tensor]] + + +@dataclass +class DummyForwardContext: + no_compile_layers: Dict[str, DummyLayer] + virtual_engine: int = 0 + + +def build_forward_context( + kv_caches: Dict[str, torch.Tensor], is_mla: bool +) -> DummyForwardContext: + layers = {} + for layer_name, tensor in kv_caches.items(): + layers[layer_name] = DummyLayer(kv_cache={0: tensor}) + return DummyForwardContext(no_compile_layers=layers, virtual_engine=0) + + +def compute_total_bytes( + kv_caches: Dict[str, torch.Tensor], batch_size: int, is_mla: bool +) -> int: + total = 0 + for tensor in kv_caches.values(): + if is_mla: + total += tensor[:batch_size].numel() * tensor.element_size() + else: + total += tensor[:, :batch_size].numel() * tensor.element_size() + return total + + +def run_once( + connector: UCMConnector, + kv_caches: Dict[str, torch.Tensor], + hashes: List[str], + batch_size: int, + is_mla: bool, +) -> Tuple[Tuple[float, float, float], Tuple[float, float, float]]: + dump_hashes = hashes[:batch_size] + + metadata = UCMConnectorMetadata() + dump_vllm_block_ids = list(range(batch_size)) + metadata.request_meta["uc_test_write"] = RequestDispatchMeta( + load_block_ids=([], []), + dump_block_ids=(dump_hashes, dump_vllm_block_ids), + ) + connector.connector.kv_caches = kv_caches + connector.bind_connector_metadata(metadata) + + total_bytes = compute_total_bytes(kv_caches, batch_size, 
is_mla) + + start = time.perf_counter() + connector.wait_for_save() + write_time = time.perf_counter() - start + + time.sleep(1) + + write_bw = (total_bytes / (1024**3)) / write_time if write_time > 0 else 0.0 + + lookup = connector.connector.store.lookup(dump_hashes) + if not all(lookup): + raise RuntimeError("Found missing cache blocks before load test.") + + load_metadata = UCMConnectorMetadata() + load_vllm_block_ids = list(range(batch_size)) + load_metadata.request_meta["uc_test_read"] = RequestDispatchMeta( + load_block_ids=(dump_hashes, load_vllm_block_ids), + dump_block_ids=([], []), + ) + connector.connector.kv_caches = kv_caches + connector.bind_connector_metadata(load_metadata) + + forward_context = build_forward_context(kv_caches, is_mla) + + start = time.perf_counter() + connector.start_load_kv(forward_context) + read_time = time.perf_counter() - start + + read_bw = (total_bytes / (1024**3)) / read_time if read_time > 0 else 0.0 + + logger.info( + f"Size: {total_bytes / (1024**3):.4f} GB, Time: {write_time:.4f}s, WRITE SPEED: {write_bw:.4f} GB/s " + ) + logger.info( + f"Size: {total_bytes / (1024**3):.4f} GB, Time: {read_time:.4f}s, READ SPEED: {read_bw:.4f} GB/s" + ) + + return ( + (total_bytes / (1024**3), write_time, write_bw), + (total_bytes / (1024**3), read_time, read_bw), + ) + + +def run_test( + storage_backends: str, + device_id: int, + repeat: int, + num_head: int, + block_len: int, + num_tokens: int, + block_layer: int, + head_size: int, + block_elem_size: int, + kv: int, + mla: bool, + ucm_connector_name: str, + total_tp_size: int, + model_path: str, + transfer_stream_number: int, + use_direct: bool, +) -> Tuple[float, float, float, float, float, float]: + block_dim = head_size * num_head + io_size = block_dim * block_len * block_elem_size + block_size = io_size * block_layer + batch_size = int(num_tokens / block_len) + real_blocks = batch_size * repeat + 10 + + vllm_config = build_vllm_config( + model_path=model_path, + block_size=block_len, + num_layers=block_layer, + num_head=num_head, + head_size=head_size, + is_mla=mla, + tp_size=total_tp_size, + connector_name=ucm_connector_name, + storage_backends=storage_backends, + transfer_stream_number=transfer_stream_number, + use_direct=use_direct, + ) + + dummy_world_group = type("DummyWorldGroup", (), {"local_rank": 0})() + + class DummyTPGroup: + def broadcast(self, tensor, src): + pass + + dummy_tp_group = DummyTPGroup() + + patches = [ + patch( + "ucm.integration.vllm.ucm_connector.get_world_group", + return_value=dummy_world_group, + ), + patch( + "ucm.integration.vllm.ucm_connector.get_tp_group", + return_value=dummy_tp_group, + ), + ] + + with patches[0], patches[1]: + connector = UCMConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector.rank = device_id if device_id >= 0 else 0 + connector.connector.kv_caches = {} + + hashes, kv_caches = make_buffers( + real_blocks, + device_id, + batch_size, + head_size, + block_len, + block_layer, + num_head, + kv, + mla, + ) + + w_sizes, w_times, w_bws = [], [], [] + r_sizes, r_times, r_bws = [], [], [] + + for round_idx in range(repeat): + logger.info(f"Round {round_idx + 1}: start write test") + start_hash_idx = round_idx * batch_size + end_hash_idx = start_hash_idx + batch_size + round_hashes = hashes[start_hash_idx:end_hash_idx] + + if len(round_hashes) < batch_size: + round_hashes = [secrets.token_hex(16) for _ in range(batch_size)] + + (w_size, w_time, w_bw), (r_size, r_time, r_bw) = run_once( + connector, kv_caches, round_hashes, batch_size, mla + ) + + 
if round_idx != 0: + w_sizes.append(w_size) + w_times.append(w_time) + w_bws.append(w_bw) + r_sizes.append(r_size) + r_times.append(r_time) + r_bws.append(r_bw) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif hasattr(torch, "npu") and torch.npu.is_available(): + torch.npu.empty_cache() + + def avg(values: List[float]) -> float: + return sum(values) / len(values) if values else 0.0 + + avg_w_size = avg(w_sizes) + avg_w_time = avg(w_times) + avg_w_bw = avg(w_bws) + avg_r_size = avg(r_sizes) + avg_r_time = avg(r_times) + avg_r_bw = avg(r_bws) + + logger.info( + "\n=== Summary ===\n" + f"Write : size={avg_w_size:.4f} GB | time={avg_w_time:.4f} s | bw={avg_w_bw:.4f} GB/s\n" + f"Read : size={avg_r_size:.4f} GB | time={avg_r_time:.4f} s | bw={avg_r_bw:.4f} GB/s\n" + ) + + return avg_w_size, avg_w_time, avg_w_bw, avg_r_time, avg_r_bw, avg_r_size + + +def run_wrapper(result_queue, *args): + try: + result = run_test(*args) + result_queue.put(("success", result)) + except Exception as e: + result_queue.put(("error", traceback.format_exc())) + + +def get_user_input(prompt, default=None): + if default is not None: + user_input = input(f"{prompt} (default: {default}): ").strip() + return user_input if user_input else default + else: + return input(f"{prompt}: ").strip() + + +def main(): + + try: + multiprocessing.set_start_method("spawn", force=True) + except RuntimeError: + pass + + storage_backends = "." + device_id = 0 + repeat = 3 + num_tokens_list = [2048, 4096, 8192, 16384, 32768] + ucm_connector_name = "UcmNfsStore" + model_path = "/home/models/QwQ-32B" + transfer_stream_numbers = [32, 64, 128] + os.environ["UC_LOGGER_LEVEL"] = "debug" + + print("1. Model Selection:") + print(" 1 - QwQ-32B") + print(" 2 - deepseek-v3") + model_choice = get_user_input("Please select model", "1") + mla = True if model_choice == "2" else False + print("\n2. 
IoDirect Transfer:") + print(" 1 - Disable IoDirect (default)") + print(" 2 - Enable IoDirect") + use_direct = get_user_input("Please select Direct IO mode", "1") + use_direct = False if use_direct == "1" else True + + if mla: + block_lens = [64] + block_layer = 61 + head_size = 576 + block_elem_size = 2 + kv = 1 + model_name = "deepseek-v3" + num_head_list = [1] + total_tp_size = 1 + else: + block_lens = [128, 256] + block_layer = 64 + head_size = 128 + block_elem_size = 2 + kv = 2 + model_name = "QwQ-32B" + num_head_list = [1, 2, 4, 8] + total_tp_size = 1 + + SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + csv_file = os.path.join(SCRIPT_DIR, "save_load_result.csv") + need_header = not os.path.exists(csv_file) + + with open(csv_file, "a", newline="", encoding="utf-8") as csv_fp: + writer = csv.writer(csv_fp) + + if need_header: + writer.writerow( + [ + "Model", + "Sequence Length", + "Batch Size", + "Layers", + "Element Size", + "KV", + "Num Head", + "Block Size", + "Stream Number", + "IO Count", + "IO Size(B)", + "Total Size(GB)", + "Write Avg Time(s)", + "Write Avg Bandwidth(GB/s)", + "Read Avg Time(s)", + "Read Avg Bandwidth(GB/s)", + ] + ) + + for num_head in num_head_list: + for block_len in block_lens: + for transfer_stream_number in transfer_stream_numbers: + block_dim = head_size * num_head + io_size = block_dim * block_len * block_elem_size + + for num_tokens in num_tokens_list: + sep = "=" * 60 + print( + f"\n{sep}\n= num_head={num_head} | num_tokens={num_tokens:>6} | Repeat {repeat} times =\n{sep}\n" + ) + + batch_size = int(num_tokens / block_len) + io_count = batch_size * block_layer + + result_queue = multiprocessing.Queue() + + process = multiprocessing.Process( + target=run_wrapper, + args=( + result_queue, + storage_backends, + device_id, + repeat, + num_head, + block_len, + num_tokens, + block_layer, + head_size, + block_elem_size, + kv, + mla, + ucm_connector_name, + total_tp_size, + model_path, + transfer_stream_number, + use_direct, + ), + ) + + process.start() + process.join() + + status, result = result_queue.get() + if status == "error": + raise Exception(f"Error in subprocess: {result}") + + ( + avg_w_size, + avg_w_time, + avg_w_bw, + avg_r_time, + avg_r_bw, + avg_r_size, + ) = result + + writer.writerow( + [ + model_name, + num_tokens, + batch_size, + block_layer, + block_elem_size, + kv, + num_head, + block_len, + transfer_stream_number, + io_count, + io_size, + f"{avg_w_size:.4f}", + f"{avg_w_time:.4f}", + f"{avg_w_bw:.4f}", + f"{avg_r_time:.4f}", + f"{avg_r_bw:.4f}", + ] + ) + + csv_fp.flush() + + print("\n" + "=" * 60 + "\n= All combinations tested =\n" + "=" * 60 + "\n") + + +if __name__ == "__main__": + main() diff --git a/test/test_ucm_dram.py b/test/test_ucm_dram.py deleted file mode 100644 index 020405d13..000000000 --- a/test/test_ucm_dram.py +++ /dev/null @@ -1,250 +0,0 @@ -# -# MIT License -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# - -import random -import unittest -import unittest.mock as mock -from contextlib import contextmanager -from typing import List -from unittest.mock import MagicMock - -import torch -from vllm.multimodal.inputs import MultiModalKwargs -from vllm.sampling_params import SamplingParams -from vllm.utils import sha256 -from vllm.v1.core.kv_cache_utils import hash_request_tokens -from vllm.v1.request import Request - - -@contextmanager -def mock_stream_context(stream=None): - yield - - -class MockStream: - def __init__(self, device=None): - self.device = device or torch.device("cpu") - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - def synchronize(self): - pass - - def record_event(self, event=None): - return event or MockEvent() - - def wait_stream(self, stream): - pass - - -class MockEvent: - def __init__(self, enable_timing=False): - self.enable_timing = enable_timing - - def record(self, stream=None): - pass - - def wait(self, stream=None): - pass - - def synchronize(self): - pass - - -def patch_cuda_for_cpu(): - mock.patch("torch.cuda.Stream", MockStream).start() - mock.patch("torch.cuda.Event", MockEvent).start() - mock.patch("torch.cuda.current_stream", return_value=MockStream()).start() - mock.patch("torch.cuda.synchronize", side_effect=lambda *a, **k: None).start() - mock.patch("torch.cuda.is_available", return_value=True).start() - mock.patch("torch.cuda.stream", mock_stream_context).start() - - -patch_cuda_for_cpu() -from ucm.store.dramstore.dramstore_connector import ( # isort: skip - DramTask, - UcmDramStore, -) - - -def make_request( - request_id, prompt_token_ids, mm_positions=None, mm_hashes=None, cache_salt=None -): - if mm_positions is None: - multi_modal_inputs = None - else: - multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions) - - return Request( - request_id=request_id, - prompt_token_ids=prompt_token_ids, - multi_modal_inputs=multi_modal_inputs, - multi_modal_hashes=mm_hashes, - multi_modal_placeholders=mm_positions, - sampling_params=SamplingParams(max_tokens=17), - pooling_params=None, - eos_token_id=100, - arrival_time=0, - lora_request=None, - cache_salt=cache_salt, - ) - - -class TestUcmDram(unittest.TestCase): - - @classmethod - def setUpClass(cls): - print("===> Before all tests (setUpClass)") - - @classmethod - def tearDownClass(cls): - print("===> After all tests (setUpClass)") - - def setUp(self): - self.config = {"block_size": 4} - 
self.scheduler_config = { - "role": "scheduler", - "max_cache_size": 1073741824, - "kv_block_size": 262144, - } - self.worker_config = { - "role": "worker", - "max_cache_size": 1073741824, - "kv_block_size": 262144, - } - - self.block_number = 4 - self.block_size = int(self.config["block_size"]) - self.scheduler_dram = UcmDramStore(self.scheduler_config) - self.worker_dram = UcmDramStore(self.worker_config) - random.seed(20250728) - self.request = make_request( - request_id=1, - prompt_token_ids=random.sample( - range(0, 10000), self.block_number * self.block_size - ), - mm_positions=None, - mm_hashes=None, - ) - block_hash_types = hash_request_tokens(sha256, self.block_size, self.request) - self.block_hashes: List[str] = [str(x.hash_value) for x in block_hash_types] - - def test_look_up_all_hit(self): - """ - Test for all blocks hitten in cache - """ - expected = [True] * len(self.block_hashes) - self.scheduler_dram.cached_blocks.update(self.block_hashes) - actual = self.scheduler_dram.lookup(self.block_hashes) - - self.assertEqual(actual, expected) - - def test_lookup_partial_hit(self): - """ - Test for part of the blocks hitten in cache - """ - partial_index = random.randint(0, 4) - partial_hashes = self.block_hashes[:partial_index] - self.scheduler_dram.cached_blocks.update(partial_hashes) - actual = self.scheduler_dram.lookup(self.block_hashes) - expected = [True] * partial_index + [False] * (self.block_size - partial_index) - self.assertEqual(actual, expected) - - def test_lookup_none_hit(self): - """ - Test for none of the blocks hitten in cache - """ - actual = self.scheduler_dram.lookup(self.block_hashes) - expected = [False] * len(self.block_hashes) - self.assertEqual(actual, expected) - - def test_load_success(self): - """ - Test for load from cache successfully - """ - src_tensors = [ - torch.randint(0, 100, (self.block_size,), dtype=torch.int8) - for _ in range(len(self.block_hashes)) - ] - offsets = [i for i in range(len(self.block_hashes))] - dump_task = self.worker_dram.dump(self.block_hashes, offsets, src_tensors) - self.worker_dram.wait(dump_task) - dst_tensors = [ - torch.zeros(self.block_size, dtype=torch.int8) - for _ in range(len(self.block_hashes)) - ] - load_task = self.worker_dram.load(self.block_hashes, offsets, dst_tensors) - - self.assertIsInstance(load_task, DramTask) - self.assertIsNotNone(load_task.event) - for i, (src_tensor, dst_tensor) in enumerate(zip(src_tensors, dst_tensors)): - self.assertEqual(dst_tensor.shape[0], self.block_size) - self.assertTrue( - torch.equal(src_tensor, dst_tensor), - f"Block {i} loaded data is different", - ) - - def test_dump_success(self): - """ - Test data dump successfully - """ - src_tensors = [ - torch.randint(0, 100, (self.block_size,), dtype=torch.int8) - for _ in range(len(self.block_hashes)) - ] - offsets = [i for i in range(len(self.block_hashes))] - original_data = [tensor.clone() for tensor in src_tensors] - dump_task = self.worker_dram.dump(self.block_hashes, offsets, src_tensors) - self.assertIsInstance(dump_task, DramTask) - self.assertIsNotNone(dump_task.event) - self.worker_dram.wait(dump_task) - for i, block_id in enumerate(self.block_hashes): - key = block_id + "_" + str(offsets[i]) - cached_data = self.worker_dram.dram_cache[key] - self.assertEqual(cached_data.shape[0], self.block_size) - self.assertTrue(torch.equal(cached_data, original_data[i])) - - def test_wait_success(self): - """ - Test wait for task successfully - """ - task = DramTask() - task.event = MagicMock() - result = 
self.worker_dram.wait(task) - self.assertEqual(result, 0) - task.event.synchronize.assert_called_once() - - def test_wait_failure(self): - task = DramTask() - task.event = None - result = self.worker_dram.wait(task) - self.assertEqual(result, -1) - - -if __name__ == "__main__": - unittest.main() diff --git a/ucm/CMakeLists.txt b/ucm/CMakeLists.txt index af0712f16..0d4579d5f 100644 --- a/ucm/CMakeLists.txt +++ b/ucm/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(shared) if(BUILD_UCM_STORE) add_subdirectory(store) endif() diff --git a/ucm/__init__.py b/ucm/__init__.py index e69de29bb..8052a3998 100644 --- a/ucm/__init__.py +++ b/ucm/__init__.py @@ -0,0 +1,17 @@ +from ucm.integration.vllm.uc_connector import UnifiedCacheConnectorV1 +from ucm.integration.vllm.ucm_connector import UCMConnector + +try: + from ucm.integration.vllm.patch.apply_patch import ensure_patches_applied + + ensure_patches_applied() +except Exception as e: + # Don't fail if patches can't be applied - might be running in environment without vLLM + import warnings + + warnings.warn( + f"Failed to apply vLLM patches: {e}. " + f"If you're using vLLM, ensure it's installed and patches are compatible." + ) + +__all__ = ["UCMConnector"] diff --git a/ucm/integration/vllm/patch/0.9.1/vllm-adapt.patch b/ucm/integration/vllm/patch/0.9.1/vllm-adapt.patch index f644f7276..3bffd1922 100644 --- a/ucm/integration/vllm/patch/0.9.1/vllm-adapt.patch +++ b/ucm/integration/vllm/patch/0.9.1/vllm-adapt.patch @@ -1,22 +1,22 @@ -From b837d3d46e593c946f5de70bdff178fa2bff882b Mon Sep 17 00:00:00 2001 -From: root -Date: Mon, 15 Sep 2025 22:07:21 -0700 -Subject: [PATCH] 0.9.1-patch +From 76751cae43498d693a7a6dd2c8ec4b2d40672385 Mon Sep 17 00:00:00 2001 +From: zhou-haitao <1300182097@qq.com> +Date: Tue, 21 Oct 2025 03:31:16 -0700 +Subject: [PATCH] Add commit --- .../kv_transfer/kv_connector/utils.py | 113 +++++++++++++++ .../kv_transfer/kv_connector/v1/base.py | 8 ++ .../v1/shared_storage_connector.py | 7 +- vllm/v1/core/block_pool.py | 2 +- - vllm/v1/core/sched/scheduler.py | 129 ++++++++++++++++++ + vllm/v1/core/sched/scheduler.py | 132 ++++++++++++++++++ vllm/v1/core/single_type_kv_cache_manager.py | 2 + vllm/v1/executor/multiproc_executor.py | 37 ++++- vllm/v1/outputs.py | 5 + vllm/v1/request.py | 1 + vllm/v1/worker/gpu_input_batch.py | 9 ++ vllm/v1/worker/gpu_model_runner.py | 52 ++++++- - vllm/v1/worker/gpu_worker.py | 23 +++- - 12 files changed, 366 insertions(+), 22 deletions(-) + vllm/v1/worker/gpu_worker.py | 23 ++- + 12 files changed, 369 insertions(+), 22 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index b9bed06d7..de062cfb3 100644 @@ -211,7 +211,7 @@ index d21f94727..1800665c7 100644 new_full_blocks = blocks[num_cached_blocks:num_full_blocks] assert len(block_hashes) >= num_cached_blocks diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py -index 3d7bbe7e0..1ef81e960 100644 +index 3d7bbe7e0..b6d4a340a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -707,16 +707,28 @@ class Scheduler(SchedulerInterface): @@ -243,16 +243,19 @@ index 3d7bbe7e0..1ef81e960 100644 num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) if num_tokens_scheduled == 0: # The request was not scheduled in this step. 
-@@ -761,6 +773,8 @@ class Scheduler(SchedulerInterface): +@@ -761,6 +773,11 @@ class Scheduler(SchedulerInterface): new_logprobs = None new_token_ids = generated_token_ids kv_transfer_params = None + if model_runner_output.finished_dumping is not None: + request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, [])) ++ is_prefill = request.num_output_tokens == 0 ++ if is_prefill: ++ self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) # Append generated tokens and check for stop. Note that if # a request is still being prefilled, we expect the model runner -@@ -824,6 +838,8 @@ class Scheduler(SchedulerInterface): +@@ -824,6 +841,8 @@ class Scheduler(SchedulerInterface): if not stopped: new_running.append(request) @@ -261,7 +264,7 @@ index 3d7bbe7e0..1ef81e960 100644 # KV Connector: update state for finished KV Transfers. self._update_from_kv_xfer_finished(model_runner_output) -@@ -1042,3 +1058,116 @@ class Scheduler(SchedulerInterface): +@@ -1042,3 +1061,116 @@ class Scheduler(SchedulerInterface): for req_id in (model_runner_output.finished_sending or ()): logger.debug("Finished sending KV transfer for request %s", req_id) self._free_blocks(self.requests[req_id]) @@ -707,4 +710,5 @@ index b7d244f27..263a916d2 100644 def profile(self, is_start: bool = True): if self.profiler is None: -- -2.34.1 \ No newline at end of file +2.34.1 + diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-aggre.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-aggre.patch new file mode 100644 index 000000000..5f8df381c --- /dev/null +++ b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-aggre.patch @@ -0,0 +1,753 @@ +From 6e2c814bb3b3a74ca56149b44d6a0b2017b91136 Mon Sep 17 00:00:00 2001 +From: harrisonyhq +Date: Tue, 4 Nov 2025 23:32:10 -0800 +Subject: [PATCH 2/3] [Patch1] Patch for load failure and aggregate + +--- + .../kv_transfer/kv_connector/utils.py | 113 +++++++++++ + .../kv_transfer/kv_connector/v1/base.py | 9 + + .../kv_connector/v1/multi_connector.py | 6 + + vllm/v1/core/block_pool.py | 2 +- + vllm/v1/core/sched/output.py | 2 + + vllm/v1/core/sched/scheduler.py | 184 ++++++++++++++++-- + vllm/v1/core/single_type_kv_cache_manager.py | 3 + + vllm/v1/executor/multiproc_executor.py | 30 ++- + vllm/v1/outputs.py | 6 +- + vllm/v1/worker/gpu_input_batch.py | 14 ++ + vllm/v1/worker/gpu_model_runner.py | 28 ++- + vllm/v1/worker/gpu_worker.py | 23 ++- + 12 files changed, 397 insertions(+), 23 deletions(-) + +diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py +index 5cbc8ca31..b63bf5965 100644 +--- a/vllm/distributed/kv_transfer/kv_connector/utils.py ++++ b/vllm/distributed/kv_transfer/kv_connector/utils.py +@@ -5,10 +5,16 @@ KV cache helper for store. 
+ """ + import torch + ++from collections import defaultdict ++from collections.abc import Sequence ++from concurrent.futures import CancelledError, Future ++from typing import Optional, cast ++ + import vllm.envs as envs + from vllm import _custom_ops as ops + from vllm.config import VllmConfig, get_current_vllm_config + from vllm.logger import init_logger ++from vllm.v1.outputs import ModelRunnerOutput + + logger = init_logger(__name__) + +@@ -107,3 +113,110 @@ def get_kv_connector_cache_layout(): + "layout to HND for better xfer performance.") + return "HND" + return "NHD" ++ ++ ++class KVOutputAggregator: ++ """Utility class to aggregate the output of all workers into a single ++ output corresponding to Rank 0 for scheduler.""" ++ ++ def __init__(self, world_size: int): ++ # Complete transfer tracker. Used by to track finished requests ++ # [req_id -> n_finished_workers] ++ self._recv_remaining_count = defaultdict[str, int](lambda: world_size) ++ self._send_remaining_count = defaultdict[str, int](lambda: world_size) ++ self._dump_remaining_count = defaultdict[str, int](lambda: world_size) ++ ++ def aggregate(self, ++ outputs: list[ModelRunnerOutput], ++ output_rank: int = 0) -> ModelRunnerOutput: ++ # aggregate finished_sending, finished_recving from all workers ++ ++ def update_finished_set(req_ids: Optional[set[str]], ++ remaining_count_dict: dict[str, int], ++ finished_set: set[str]) -> None: ++ for req_id in req_ids or (): ++ new_count = remaining_count_dict[req_id] - 1 ++ if new_count == 0: ++ finished_set.add(req_id) ++ del remaining_count_dict[req_id] ++ else: ++ remaining_count_dict[req_id] = new_count ++ ++ def update_finished_list(req_ids: Optional[dict[str, list[str]]], ++ remaining_count_dict: dict[str, int], ++ finished_list: dict[str, list[str]]) -> None: ++ for req_id, succeed_dump_blocks in (req_ids or {}).items(): ++ if req_id not in finished_list: ++ finished_list[req_id] = [] ++ for blk_id in succeed_dump_blocks: ++ new_count = remaining_count_dict[blk_id] - 1 ++ if new_count == 0: ++ finished_list[req_id].append(blk_id) ++ del remaining_count_dict[blk_id] ++ else: ++ remaining_count_dict[blk_id] = new_count ++ ++ finished_sending = set[str]() ++ finished_recving = set[str]() ++ invalid_block_ids = set[int]() ++ finished_dumping: dict[str, list[str]] = {} ++ for output in outputs: ++ update_finished_set(output.finished_sending, ++ self._send_remaining_count, finished_sending) ++ update_finished_set(output.finished_recving, ++ self._recv_remaining_count, finished_recving) ++ update_finished_list(output.finished_dumping, ++ self._dump_remaining_count, finished_dumping) ++ if output.invalid_block_ids: ++ invalid_block_ids |= output.invalid_block_ids ++ ++ # select output of the worker specified by output_rank ++ output = outputs[output_rank] ++ ++ # set the aggregated finished_sending / finished_recving ++ # if output.finished_sending/recving is not empty, but the other ranks ++ # still have unfinished send/recv, we want to set the aggregated ++ # finished_sending/recving to None until all ranks have finished ++ # send/recv ++ output.finished_sending = finished_sending if finished_sending else None ++ output.finished_recving = finished_recving if finished_recving else None ++ output.finished_dumping = finished_dumping if finished_dumping else None ++ output.invalid_block_ids = invalid_block_ids or None ++ ++ return output ++ ++ def async_aggregate(self, ++ output_futures: Sequence[Future[ModelRunnerOutput]], ++ output_rank: int = 0) -> Future[ModelRunnerOutput]: ++ 
"""Takes a list of futures and returns a single future which resolves ++ to the respective list of outputs.""" ++ result_future: Future[ModelRunnerOutput] = Future() ++ ++ outputs: list[Optional[ModelRunnerOutput]] = [None ++ ] * len(output_futures) ++ ++ def make_callback(idx): ++ ++ def callback(fut): ++ if result_future.done(): ++ return ++ ++ try: ++ outputs[idx] = fut.result() ++ except CancelledError: ++ result_future.cancel() ++ except Exception as e: ++ result_future.set_exception(e) ++ ++ # this check assumes io_thread_pool uses a single thread ++ if all(outputs): ++ result_future.set_result( ++ self.aggregate(cast(list[ModelRunnerOutput], outputs), ++ output_rank)) ++ ++ return callback ++ ++ for i, output_future in enumerate(output_futures): ++ output_future.add_done_callback(make_callback(i)) ++ ++ return result_future +diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py +index f80b5eba2..39d8fa389 100644 +--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py ++++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py +@@ -201,6 +201,15 @@ class KVConnectorBase_V1(ABC): + """ + return None, None + ++ def get_block_ids_with_load_errors(self) -> set[int]: ++ """ ++ Get the set of block IDs that failed to load. ++ Returns: ++ Optional[set[int]]: A set of block IDs that encountered load errors. ++ Returns None if no errors occurred during load. ++ """ ++ return set() ++ + # ============================== + # Scheduler-side methods + # ============================== +diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +index 5f92d69bd..4e1f45e7a 100644 +--- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py ++++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +@@ -134,6 +134,12 @@ class MultiConnector(KVConnectorBase_V1): + + return finished_sending or None, finished_recving or None + ++ def get_block_ids_with_load_errors(self) -> set[int]: ++ agg_block_ids: set[int] = set() ++ for c in self._connectors: ++ agg_block_ids |= c.get_block_ids_with_load_errors() ++ return agg_block_ids ++ + # ============================== + # Scheduler-side methods + # ============================== +diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py +index d21f94727..1800665c7 100644 +--- a/vllm/v1/core/block_pool.py ++++ b/vllm/v1/core/block_pool.py +@@ -124,7 +124,7 @@ class BlockPool: + kv_cache_group_id: The id of the KV cache group. + hash_fn: The hash function to use for block hashes. 
+ """ +- if num_cached_blocks == num_full_blocks: ++ if num_cached_blocks >= num_full_blocks: + return + new_full_blocks = blocks[num_cached_blocks:num_full_blocks] + assert len(block_hashes) >= num_cached_blocks +diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py +index d34f39327..c94e421c0 100644 +--- a/vllm/v1/core/sched/output.py ++++ b/vllm/v1/core/sched/output.py +@@ -93,6 +93,7 @@ class CachedRequestData: + new_token_ids: list[list[int]] + new_block_ids: list[tuple[list[int], ...]] + num_computed_tokens: list[int] ++ num_output_tokens: list[int] + + @property + def num_reqs(self) -> int: +@@ -106,6 +107,7 @@ class CachedRequestData: + new_token_ids=[], + new_block_ids=[], + num_computed_tokens=[], ++ num_output_tokens=[], + ) + + +diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py +index cd80f92a1..2d4fd4d59 100644 +--- a/vllm/v1/core/sched/scheduler.py ++++ b/vllm/v1/core/sched/scheduler.py +@@ -119,6 +119,7 @@ class Scheduler(SchedulerInterface): + + # KV Connector: requests in process of async KV loading or recving + self.finished_recving_kv_req_ids: set[str] = set() ++ self.failed_recving_kv_req_ids: set[str] = set() + + # Encoder-related. + # Calculate encoder cache size if applicable +@@ -621,6 +622,7 @@ class Scheduler(SchedulerInterface): + new_token_ids: list[list[int]] = [] + new_block_ids: list[tuple[list[int], ...]] = [] + num_computed_tokens: list[int] = [] ++ num_output_tokens: list[int] = [] + + for req in itertools.chain(running_reqs, resumed_reqs): + req_id = req.request_id +@@ -638,6 +640,7 @@ class Scheduler(SchedulerInterface): + new_token_ids.append(token_ids) + new_block_ids.append(req_to_new_block_ids[req_id]) + num_computed_tokens.append(req.num_computed_tokens) ++ num_output_tokens.append(len(req.output_token_ids)) + # Because resumed_reqs is usually empty, it is more efficient to do + # in-place appending so that we don't need to allocate a new list. + resumed_from_preemption = [False] * len(running_reqs) +@@ -649,6 +652,7 @@ class Scheduler(SchedulerInterface): + new_token_ids=new_token_ids, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, ++ num_output_tokens=num_output_tokens, + ) + + def _try_schedule_encoder_inputs( +@@ -746,16 +750,29 @@ class Scheduler(SchedulerInterface): + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + pooler_outputs = model_runner_output.pooler_output + num_nans_in_logits = model_runner_output.num_nans_in_logits ++ invalid_block_ids = model_runner_output.invalid_block_ids + + new_running: list[Request] = [] + outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) + spec_decoding_stats: Optional[SpecDecodingStats] = None + ++ failed_kv_load_req_ids = None ++ if invalid_block_ids: ++ # These blocks contain externally computed tokens that failed to ++ # load. Identify affected requests and adjust their computed token ++ # count to trigger recomputation of the invalid blocks. ++ failed_kv_load_req_ids = self._handle_invalid_blocks(invalid_block_ids) ++ + # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below + # loop can be a performance bottleneck. We should do our best to avoid + # expensive operations inside the loop. 
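++        # Requests flagged in failed_kv_load_req_ids have already had their
++        # num_computed_tokens rolled back before the first invalid block by
++        # _handle_invalid_blocks(); the loop below keeps them in new_running
++        # and skips output for them so the invalid blocks are recomputed on a
++        # later scheduling step.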
+ for request in self.running: + req_id = request.request_id ++ # self.req_meta.stage == SequenceStage.PREFILL and self.req_meta.is_last_chunk ++ if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids: ++ # Skip requests that were recovered from KV load failure ++ new_running.append(request) ++ continue + num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) + if num_tokens_scheduled == 0: + # The request was not scheduled in this step. +@@ -1089,18 +1106,31 @@ class Scheduler(SchedulerInterface): + if request.request_id not in self.finished_recving_kv_req_ids: + return False + +- # Now that the blocks are ready, actually cache them. +- (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) +- num_computed_tokens = len(block_ids) * self.block_size +- # Handle the case where num request tokens less then one block. +- num_computed_tokens = min(num_computed_tokens, request.num_tokens) +- if num_computed_tokens == request.num_tokens: +- num_computed_tokens -= 1 +- # This will cache the blocks iff caching is enabled. +- self.kv_cache_manager.cache_blocks(request, num_computed_tokens) +- +- # Update the request state for scheduling. +- request.num_computed_tokens = num_computed_tokens ++ if request.request_id in self.failed_recving_kv_req_ids: ++ # Request had KV load failures; num_computed_tokens was already ++ # updated in _update_requests_with_invalid_blocks ++ if request.num_computed_tokens: ++ # Cache any valid computed tokens. ++ self.kv_cache_manager.cache_blocks(request, ++ request.num_computed_tokens) ++ else: ++ # No valid computed tokens, release allocated blocks. ++ # There may be a local cache hit on retry. ++ self.kv_cache_manager.free(request) ++ self.failed_recving_kv_req_ids.remove(request.request_id) ++ else: ++ # Now that the blocks are ready, actually cache them. ++ (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) ++ num_computed_tokens = len(block_ids) * self.block_size ++ # Handle the case where num request tokens less then one block. ++ num_computed_tokens = min(num_computed_tokens, request.num_tokens) ++ if num_computed_tokens == request.num_tokens: ++ num_computed_tokens -= 1 ++ # This will cache the blocks iff caching is enabled. ++ self.kv_cache_manager.cache_blocks(request, num_computed_tokens) ++ ++ # Update the request state for scheduling. ++ request.num_computed_tokens = num_computed_tokens + + # Return that we are ready. + self.finished_recving_kv_req_ids.remove(request.request_id) +@@ -1124,3 +1154,133 @@ class Scheduler(SchedulerInterface): + for req_id in (model_runner_output.finished_sending or ()): + logger.debug("Finished sending KV transfer for request %s", req_id) + self._free_blocks(self.requests[req_id]) ++ ++ ++ def _update_requests_with_invalid_blocks( ++ self, requests: Iterable[Request], ++ invalid_block_ids: set[int]) -> tuple[set[str], int]: ++ """ ++ Identify and update requests affected by invalid KV cache blocks. ++ This method scans the given requests, detects those with invalid blocks ++ and adjusts their `num_computed_tokens` to the longest valid prefix. ++ For observability, it also accumulates the total number of tokens that ++ will need to be recomputed across all affected requests. ++ Args: ++ requests: The set of requests to scan for invalid blocks. ++ invalid_block_ids: IDs of invalid blocks. ++ Returns: ++ tuple: ++ - affected_req_ids (set[str]): IDs of requests impacted by ++ invalid blocks. 
++ - total_affected_tokens (int): Total number of tokens that must ++ be recomputed across all affected requests (for observability). ++ """ ++ affected_req_ids: set[str] = set() ++ total_affected_tokens = 0 ++ # If a block is invalid and shared by multiple requests in the batch, ++ # these requests must be rescheduled, but only the first will recompute ++ # it. This set tracks blocks already marked for recomputation. ++ marked_invalid_block_ids: set[int] = set() ++ for request in requests: ++ is_affected = False ++ marked_invalid_block = False ++ req_id = request.request_id ++ # TODO (davidb): add support for hybrid memory allocator ++ (req_block_ids, ) = self.kv_cache_manager.get_block_ids(req_id) ++ # We iterate only over blocks that may contain externally computed ++ # tokens ++ if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: ++ # Async loading. If num_computed_tokens is set it implies we ++ # already processed some block failures for it in a prior step ++ req_num_computed_tokens = ( ++ request.num_computed_tokens if req_id ++ in self.failed_recving_kv_req_ids else len(req_block_ids) * ++ self.block_size) ++ else: ++ # Sync loading. num_computed_tokens includes new tokens ++ req_num_computed_tokens = request.num_cached_tokens ++ ++ req_num_computed_blocks = (req_num_computed_tokens + ++ self.block_size - 1) // self.block_size ++ for idx, block_id in zip(range(req_num_computed_blocks), ++ req_block_ids): ++ ++ if block_id not in invalid_block_ids: ++ continue ++ ++ is_affected = True ++ ++ if block_id in marked_invalid_block_ids: ++ # This invalid block is shared with a previous request ++ # and was already marked for recomputation. ++ # This means this request can still consider this block ++ # as computed when rescheduled. ++ # Currently this only applies to sync loading; Async ++ # loading does not yet support block sharing ++ continue ++ ++ marked_invalid_block_ids.add(block_id) ++ ++ if marked_invalid_block: ++ # This request has already marked an invalid block for ++ # recomputation and updated its num_computed_tokens. ++ continue ++ ++ marked_invalid_block = True ++ # Truncate the computed tokens at the first failed block ++ request.num_computed_tokens = idx * self.block_size ++ total_affected_tokens += (req_num_computed_tokens - ++ request.num_computed_tokens) ++ ++ if is_affected: ++ if not marked_invalid_block: ++ # All invalid blocks of this request are shared with ++ # previous requests and will be recomputed by them. ++ # Revert to considering only cached tokens as computed. 
++ # Currently this only applies to sync loading; Async ++ # loading does not yet support block sharing ++ total_affected_tokens += (request.num_computed_tokens - ++ request.num_cached_tokens) ++ request.num_computed_tokens = request.num_cached_tokens ++ ++ affected_req_ids.add(request.request_id) ++ ++ return (affected_req_ids, total_affected_tokens) ++ ++ ++ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: ++ total_requests_to_reschedule = 0 ++ total_tokens_to_reschedule = 0 ++ ++ # --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) --- ++ async_load_reqs = ( ++ req for req in self.waiting ++ if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS) ++ async_affected_req_ids, num_tokens_to_reschedule = ( ++ self._update_requests_with_invalid_blocks(async_load_reqs, ++ invalid_block_ids)) ++ ++ total_requests_to_reschedule += len(async_affected_req_ids) ++ total_tokens_to_reschedule += num_tokens_to_reschedule ++ ++ # Mark requests with async KV load failures; they will be rescheduled ++ # once loading completes ++ self.failed_recving_kv_req_ids |= async_affected_req_ids ++ ++ # --- Handle sync KV loads (running requests) --- ++ sync_affected_req_ids, num_tokens_to_reschedule = ( ++ self._update_requests_with_invalid_blocks(self.running, ++ invalid_block_ids)) ++ ++ total_requests_to_reschedule += len(sync_affected_req_ids) ++ total_tokens_to_reschedule += num_tokens_to_reschedule ++ ++ if total_requests_to_reschedule: ++ logger.warning( ++ "Recovered from KV load failure: " ++ "%d request(s) rescheduled (%d tokens affected).", ++ total_requests_to_reschedule, total_tokens_to_reschedule) ++ ++ # Return the IDs of affected running requests to skip in ++ # update_from_output. ++ return sync_affected_req_ids +diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py +index 5b4718038..28bd4618a 100644 +--- a/vllm/v1/core/single_type_kv_cache_manager.py ++++ b/vllm/v1/core/single_type_kv_cache_manager.py +@@ -142,6 +142,9 @@ class SingleTypeKVCacheManager(ABC): + num_cached_blocks = self.num_cached_block[request.request_id] + num_full_blocks = num_tokens // self.block_size + ++ if num_cached_blocks >= num_full_blocks: ++ return ++ + self.block_pool.cache_full_blocks( + request=request, + blocks=self.req_to_blocks[request.request_id], +diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py +index b06b7cc80..61cd7110f 100644 +--- a/vllm/v1/executor/multiproc_executor.py ++++ b/vllm/v1/executor/multiproc_executor.py +@@ -26,6 +26,7 @@ from vllm.distributed import (destroy_distributed_environment, + destroy_model_parallel) + from vllm.distributed.device_communicators.shm_broadcast import (Handle, + MessageQueue) ++from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator + from vllm.executor.multiproc_worker_utils import ( + _add_prefix, set_multiprocessing_worker_envs) + from vllm.logger import init_logger +@@ -111,10 +112,14 @@ class MultiprocExecutor(Executor): + if self.max_concurrent_batches > 1: + # Note: must use only 1 IO thread to keep dequeue sequence + # from the response queue ++ # _async_aggregate_workers_output also assumes a single IO thread + self.io_thread_pool = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="mp_exec_io") + + self.output_rank = self._get_output_rank() ++ self.has_connector = self.vllm_config.kv_transfer_config is not None ++ self.kv_output_aggregator = KVOutputAggregator( ++ self.parallel_config.world_size) + + def 
start_worker_monitor(self): + workers = self.workers +@@ -155,13 +160,30 @@ class MultiprocExecutor(Executor): + self, + scheduler_output, + ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: +- (output, ) = self.collective_rpc( ++ non_block = self.max_concurrent_batches > 1 ++ ++ if not self.has_connector or self.vllm_config.model_config.use_mla: ++ # get output only from a single worker (output_rank) ++ (output, ) = self.collective_rpc( ++ "execute_model", ++ args=(scheduler_output, ), ++ unique_reply_rank=self.output_rank, ++ non_block=non_block, ++ timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) ++ return output ++ ++ # get output from all workers ++ outputs = self.collective_rpc( + "execute_model", + args=(scheduler_output, ), +- unique_reply_rank=self.output_rank, +- non_block=self.max_concurrent_batches > 1, ++ non_block=non_block, + timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) +- return output ++ ++ # aggregate all workers output to a single output ++ if non_block: ++ return self.kv_output_aggregator.async_aggregate( ++ outputs, self.output_rank) ++ return self.kv_output_aggregator.aggregate(outputs, self.output_rank) + + def collective_rpc(self, + method: Union[str, Callable], +diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py +index c8388baed..16af8dbce 100644 +--- a/vllm/v1/outputs.py ++++ b/vllm/v1/outputs.py +@@ -1,7 +1,7 @@ + # SPDX-License-Identifier: Apache-2.0 + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +-from dataclasses import dataclass ++from dataclasses import dataclass, field + from typing import NamedTuple, Optional + + import torch +@@ -109,6 +109,10 @@ class ModelRunnerOutput: + finished_recving: Optional[set[str]] = None + finished_dumping: Optional[dict[str, list[str]]] = None + ++ # IDs of externally computed KV blocks that failed to load. ++ # Requests referencing these blocks should be rescheduled to recompute them. ++ invalid_block_ids: set[int] = field(default_factory=set) ++ + # req_id -> num_nans_in_logits + num_nans_in_logits: Optional[dict[str, int]] = None + +diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py +index 1a79d72be..8819d7629 100644 +--- a/vllm/v1/worker/gpu_input_batch.py ++++ b/vllm/v1/worker/gpu_input_batch.py +@@ -96,6 +96,9 @@ class InputBatch: + pin_memory=False, + ) + self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() ++ self.is_token_ids = torch.zeros( ++ (max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False ++ ) + self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) + self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) + self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) +@@ -286,8 +289,14 @@ class InputBatch: + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) ++ if request.prompt_token_ids is not None: ++ self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids ++ self.is_token_ids[req_index, :num_prompt_tokens] = True ++ else: ++ self.is_token_ids[req_index, :num_prompt_tokens] = False + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids ++ self.is_token_ids[req_index, start_idx:end_idx] = True + # Number of token ids in token_ids_cpu. + # NOTE(woosuk): This may include spec decode tokens. + self.num_tokens[req_index] = request.num_tokens +@@ -473,6 +482,8 @@ class InputBatch: + self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] + self.token_ids_cpu[i2, ...] 
= tmp + ++ self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...] ++ + swap_dict_values(self.generators, i1, i2) + swap_dict_values(self.bad_words_token_ids, i1, i2) + +@@ -542,6 +553,9 @@ class InputBatch: + num_tokens = self.num_tokens[last_req_index] + self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ + last_req_index, :num_tokens] ++ self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[ ++ last_req_index, :num_tokens ++ ] + self.num_tokens[empty_index] = num_tokens + self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ + last_req_index] +diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py +index 53ee8cfcd..c3df1d5d2 100644 +--- a/vllm/v1/worker/gpu_model_runner.py ++++ b/vllm/v1/worker/gpu_model_runner.py +@@ -473,6 +473,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_data.resumed_from_preemption[i] ++ num_output_tokens = req_data.num_output_tokens[i] + + # Update the cached states. + req_state.num_computed_tokens = num_computed_tokens +@@ -492,6 +493,21 @@ class GPUModelRunner(LoRAModelRunnerMixin): + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + new_token_ids[-num_new_tokens:]) ++ elif num_output_tokens < len(req_state.output_token_ids): ++ # Some output tokens were discarded due to a sync-KV-load ++ # failure. Align the cached state. ++ del req_state.output_token_ids[num_output_tokens:] ++ ++ req_index = self.input_batch.req_id_to_index.get(req_id) ++ if req_index is not None: ++ old_end_idx = self.input_batch.num_tokens_no_spec[ ++ req_index] ++ end_idx = self.input_batch.num_prompt_tokens[ ++ req_index] + num_output_tokens ++ self.input_batch.num_tokens[req_index] = end_idx ++ self.input_batch.num_tokens_no_spec[req_index] = end_idx ++ self.input_batch.is_token_ids[req_index, ++ end_idx:old_end_idx] = False + + # Update the block IDs. 
+ if not resumed_from_preemption: +@@ -1381,6 +1397,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): + finished_dumping = self.maybe_wait_for_kv_save() + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) ++ invalid_block_ids = self.get_block_ids_with_load_errors() + + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = model_output +@@ -1564,6 +1581,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): + finished_recving=finished_recving, + finished_dumping=finished_dumping, + num_nans_in_logits=num_nans_in_logits, ++ invalid_block_ids = invalid_block_ids + ) + + def propose_draft_token_ids( +@@ -1694,13 +1712,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): + self.maybe_setup_kv_connector(scheduler_output) + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) ++ invalid_block_ids = self.get_block_ids_with_load_errors() ++ get_kv_transfer_group().clear_connector_metadata() + +- if not finished_sending and not finished_recving: ++ if not finished_sending and not finished_recving and not invalid_block_ids: + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.finished_sending = finished_sending + output.finished_recving = finished_recving ++ output.invalid_block_ids = invalid_block_ids + return output + + @staticmethod +@@ -1733,6 +1754,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): + scheduler_output.finished_req_ids) + return None, None + ++ def get_block_ids_with_load_errors(self) -> Optional[set[int]]: ++ if has_kv_transfer_group(): ++ return get_kv_transfer_group().get_block_ids_with_load_errors() ++ return None ++ + def propose_ngram_draft_token_ids( + self, + sampled_token_ids: list[list[int]], +diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py +index 9e7e44d06..1b816b25b 100644 +--- a/vllm/v1/worker/gpu_worker.py ++++ b/vllm/v1/worker/gpu_worker.py +@@ -1,6 +1,7 @@ + # SPDX-License-Identifier: Apache-2.0 + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + """A GPU worker class.""" ++import copy + import gc + import os + from typing import TYPE_CHECKING, Optional +@@ -15,7 +16,8 @@ from vllm.device_allocator.cumem import CuMemAllocator + from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce) +-from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized ++from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, ++ has_kv_transfer_group) + from vllm.distributed.parallel_state import get_pp_group, get_tp_group + from vllm.logger import init_logger + from vllm.lora.request import LoRARequest +@@ -24,7 +26,7 @@ from vllm.platforms import current_platform + from vllm.sequence import IntermediateTensors + from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling + from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec +-from vllm.v1.outputs import ModelRunnerOutput ++from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput + from vllm.v1.utils import report_usage_stats + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + from vllm.v1.worker.worker_base import WorkerBase +@@ -313,9 +315,22 @@ class Worker(WorkerBase): + assert isinstance(output, IntermediateTensors) + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) +- return None ++ if not has_kv_transfer_group(): ++ return None ++ ++ # In case of PP with kv transfer, we need to pass 
through the ++ # finished_sending and finished_recving buffers. ++ new_output = EMPTY_MODEL_RUNNER_OUTPUT ++ if output.finished_sending or output.finished_recving or output.finished_dumping or output.invalid_block_ids: ++ new_output = copy.copy(new_output) ++ new_output.finished_sending = output.finished_sending ++ new_output.finished_recving = output.finished_recving ++ new_output.finished_dumping = output.finished_dumping ++ new_output.invalid_block_ids = output.invalid_block_ids ++ output = new_output ++ + assert isinstance(output, ModelRunnerOutput) +- return output if self.is_driver_worker else None ++ return output + + def profile(self, is_start: bool = True): + if self.profiler is None: +-- +2.34.1 + diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-pc.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-pc.patch new file mode 100644 index 000000000..bf0b7e19a --- /dev/null +++ b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-pc.patch @@ -0,0 +1,122 @@ +From 26fdd2026cc3d1ed7da894907ae244a155a16566 Mon Sep 17 00:00:00 2001 +From: harrisonyhq +Date: Tue, 4 Nov 2025 19:36:36 -0800 +Subject: [PATCH 1/3] [Patch0] UCM PC adapt patch + +--- + .../kv_transfer/kv_connector/v1/multi_connector.py | 7 ++++++- + vllm/v1/core/sched/scheduler.py | 11 +++++++++++ + vllm/v1/outputs.py | 1 + + vllm/v1/request.py | 2 ++ + vllm/v1/worker/gpu_model_runner.py | 7 ++++--- + 5 files changed, 24 insertions(+), 4 deletions(-) + +diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +index be3c23399..5f92d69bd 100644 +--- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py ++++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +@@ -99,8 +99,13 @@ class MultiConnector(KVConnectorBase_V1): + c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs) + + def wait_for_save(self): ++ success_dumped_blocks = None + for c in self._connectors: +- c.wait_for_save() ++ uc_dump_blocks = c.wait_for_save() ++ if uc_dump_blocks: ++ success_dumped_blocks = uc_dump_blocks ++ ++ return success_dumped_blocks if success_dumped_blocks else None + + def get_finished( + self, finished_req_ids: set[str] +diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py +index fe552db74..cd80f92a1 100644 +--- a/vllm/v1/core/sched/scheduler.py ++++ b/vllm/v1/core/sched/scheduler.py +@@ -34,6 +34,7 @@ from vllm.v1.outputs import ModelRunnerOutput + from vllm.v1.request import Request, RequestStatus + from vllm.v1.spec_decode.metrics import SpecDecodingStats + from vllm.v1.structured_output import StructuredOutputManager ++from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import MultiConnector + + logger = init_logger(__name__) + +@@ -791,6 +792,16 @@ class Scheduler(SchedulerInterface): + new_logprobs = None + new_token_ids = generated_token_ids + kv_transfer_params = None ++ if model_runner_output.finished_dumping is not None: ++ request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, [])) ++ is_prefill = request.num_output_tokens == 0 ++ if is_prefill: ++ if isinstance(self.connector, MultiConnector): ++ for c in self.connector._connectors: ++ if hasattr(c, 'connector') and hasattr(c.connector, 'commit'): ++ c.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) ++ else: ++ self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) + + # Append generated tokens and check for stop. 
Note that if + # a request is still being prefilled, we expect the model runner +diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py +index f78623f57..c8388baed 100644 +--- a/vllm/v1/outputs.py ++++ b/vllm/v1/outputs.py +@@ -107,6 +107,7 @@ class ModelRunnerOutput: + # [req_ids] + finished_sending: Optional[set[str]] = None + finished_recving: Optional[set[str]] = None ++ finished_dumping: Optional[dict[str, list[str]]] = None + + # req_id -> num_nans_in_logits + num_nans_in_logits: Optional[dict[str, int]] = None +diff --git a/vllm/v1/request.py b/vllm/v1/request.py +index 9b96f4599..e70d1695b 100644 +--- a/vllm/v1/request.py ++++ b/vllm/v1/request.py +@@ -103,6 +103,8 @@ class Request: + # The number of tokens with prefix cache hits. + self.num_cached_tokens = -1 + ++ self.succeed_dumped_blocks: list[str] = [] ++ + # The number of NaNs in logits. A value greater than 0 + # indicates that the output is corrupted + self.num_nans_in_logits = 0 +diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py +index 5a26e88db..53ee8cfcd 100644 +--- a/vllm/v1/worker/gpu_model_runner.py ++++ b/vllm/v1/worker/gpu_model_runner.py +@@ -1378,7 +1378,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): + inputs_embeds=inputs_embeds, + ) + +- self.maybe_wait_for_kv_save() ++ finished_dumping = self.maybe_wait_for_kv_save() + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) + +@@ -1562,6 +1562,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): + pooler_output=[], + finished_sending=finished_sending, + finished_recving=finished_recving, ++ finished_dumping=finished_dumping, + num_nans_in_logits=num_nans_in_logits, + ) + +@@ -1719,9 +1720,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): + kv_connector.start_load_kv(get_forward_context()) + + @staticmethod +- def maybe_wait_for_kv_save() -> None: ++ def maybe_wait_for_kv_save(): + if has_kv_transfer_group(): +- get_kv_transfer_group().wait_for_save() ++ return get_kv_transfer_group().wait_for_save() + + @staticmethod + def get_finished_kv_transfers( +-- +2.34.1 + diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch new file mode 100644 index 000000000..eb9848756 --- /dev/null +++ b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch @@ -0,0 +1,628 @@ +From 0431022b90649f7115b89b61aaf2a0f83e896d5a Mon Sep 17 00:00:00 2001 +From: wenxinwang +Date: Mon, 10 Nov 2025 20:35:47 +0800 +Subject: [PATCH] adapt to deepseek patch + +--- + vllm/attention/layer.py | 49 ++++++++++++- + .../kv_transfer/kv_connector/utils.py | 5 ++ + .../v1/shared_storage_connector.py | 7 +- + vllm/v1/attention/backends/mla/common.py | 10 ++- + vllm/v1/core/kv_cache_manager.py | 7 +- + vllm/v1/core/sched/output.py | 3 + + vllm/v1/core/sched/scheduler.py | 37 +++++++--- + vllm/v1/worker/block_table.py | 13 ++++ + vllm/v1/worker/gpu_model_runner.py | 71 +++++++++++++++---- + vllm/v1/worker/gpu_worker.py | 2 + + 10 files changed, 171 insertions(+), 33 deletions(-) + +diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py +index f0ad68b16..728ab99fd 100644 +--- a/vllm/attention/layer.py ++++ b/vllm/attention/layer.py +@@ -2,7 +2,6 @@ + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + """Attention layer.""" + from typing import Any, Dict, List, Optional +- + import torch + import torch.nn as nn + import torch.nn.functional as F +@@ -22,6 +21,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod + 
from vllm.platforms import _Backend, current_platform + from vllm.utils import direct_register_custom_op + from vllm.v1.attention.backends.utils import validate_kv_sharing_target ++from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse + + + class Attention(nn.Module): +@@ -409,9 +409,10 @@ def unified_attention( + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] ++ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) + output = self.impl.forward(self, query, key, value, kv_cache, + attn_metadata) +- ++ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return output + +@@ -449,6 +450,8 @@ def unified_attention_with_output( + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] ++ if not self.use_mla: ++ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) + self.impl.forward(self, + query, + key, +@@ -457,7 +460,8 @@ def unified_attention_with_output( + attn_metadata, + output=output, + output_scale=output_scale) +- ++ if not self.use_mla: ++ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + + +@@ -479,3 +483,42 @@ direct_register_custom_op( + fake_impl=unified_attention_with_output_fake, + dispatch_key=current_platform.dispatch_key, + ) ++ ++def maybe_execute_sparse_attention_begin( ++ query: torch.Tensor, ++ key: torch.Tensor, ++ value: torch.Tensor, ++ layer_name: str, ++ forward_context: ForwardContext, ++ phase: Optional[str] = None, ++): ++ if not has_ucm_sparse(): ++ return ++ ++ ucm_sparse = get_ucm_sparse() ++ ++ attn_metadata = forward_context.attn_metadata ++ if attn_metadata is None: ++ return ++ ++ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context, phase) ++ ++def maybe_execute_sparse_attention_finished( ++ query: torch.Tensor, ++ key: torch.Tensor, ++ value: torch.Tensor, ++ attn_output: torch.Tensor, ++ layer_name: str, ++ forward_context: ForwardContext, ++ phase: Optional[str] = None, ++): ++ if not has_ucm_sparse(): ++ return ++ ++ ucm_sparse = get_ucm_sparse() ++ ++ attn_metadata = forward_context.attn_metadata ++ if attn_metadata is None: ++ return ++ ++ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context, phase) +diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py +index b63bf5965..155597c51 100644 +--- a/vllm/distributed/kv_transfer/kv_connector/utils.py ++++ b/vllm/distributed/kv_transfer/kv_connector/utils.py +@@ -3,6 +3,11 @@ + """ + KV cache helper for store. 
+ """ ++from collections import defaultdict ++from collections.abc import Sequence ++from concurrent.futures import CancelledError, Future ++from typing import Optional, cast ++ + import torch + + from collections import defaultdict +diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +index 3c574d065..223106def 100644 +--- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py ++++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +@@ -2,7 +2,7 @@ + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import hashlib + import os +-from dataclasses import dataclass ++from dataclasses import dataclass, field + from typing import TYPE_CHECKING + + import safetensors +@@ -53,10 +53,7 @@ class ReqMeta: + + @dataclass + class SharedStorageConnectorMetadata(KVConnectorMetadata): +- requests: list[ReqMeta] +- +- def __init__(self): +- self.requests = [] ++ requests: list[ReqMeta] = field(default_factory=list) + + def add_request( + self, +diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py +index f2aaf59a4..b56f62b39 100644 +--- a/vllm/v1/attention/backends/mla/common.py ++++ b/vllm/v1/attention/backends/mla/common.py +@@ -200,6 +200,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, + MLAAttentionImpl) + from vllm.attention.backends.utils import get_mla_dims + from vllm.attention.ops.merge_attn_states import merge_attn_states ++from vllm.forward_context import ForwardContext, get_forward_context + from vllm.attention.utils.fa_utils import get_flash_attn_version + from vllm.logger import init_logger + from vllm.model_executor.layers.linear import (ColumnParallelLinear, +@@ -211,6 +212,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) + from vllm.v1.kv_cache_interface import AttentionSpec + from vllm.v1.worker.block_table import BlockTable ++from vllm.attention.layer import (maybe_execute_sparse_attention_begin, maybe_execute_sparse_attention_finished) + + try: + from vllm.vllm_flash_attn import flash_attn_varlen_func +@@ -908,7 +910,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: +- ++ forward_context: ForwardContext = get_forward_context() + assert output is not None, "Output tensor must be provided." 
+ + if output_scale is not None: +@@ -957,10 +959,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): + ) + + if has_prefill: ++ maybe_execute_sparse_attention_begin(prefill_q, prefill_k_c_normed, prefill_k_pe, layer.layer_name, forward_context, "prefill") + output[num_decode_tokens:] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, + attn_metadata) +- ++ maybe_execute_sparse_attention_finished(prefill_q, prefill_k_c_normed, prefill_k_pe, output[num_decode_tokens:], layer.layer_name, forward_context, "prefill") + if has_decode: + assert attn_metadata.decode is not None + decode_q_nope, decode_q_pe = decode_q.split( +@@ -971,8 +974,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): + decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) +- ++ maybe_execute_sparse_attention_begin(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, layer.layer_name, forward_context, "decode") + output[:num_decode_tokens] = self._forward_decode( + decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) ++ maybe_execute_sparse_attention_finished(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, output[:num_decode_tokens], layer.layer_name, forward_context, "decode") + + return output_padded +diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py +index 6937455e7..bf9aec864 100644 +--- a/vllm/v1/core/kv_cache_manager.py ++++ b/vllm/v1/core/kv_cache_manager.py +@@ -3,7 +3,7 @@ + + from collections import defaultdict + from dataclasses import dataclass +-from typing import Optional ++from typing import Optional, Union + + from vllm.distributed.kv_events import KVCacheEvent + from vllm.logger import init_logger +@@ -14,6 +14,8 @@ from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.metrics.stats import PrefixCacheStats + from vllm.v1.request import Request, RequestStatus ++from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse ++from ucm.sparse.base import INVALID_SLOT + + logger = init_logger(__name__) + +@@ -193,6 +195,7 @@ class KVCacheManager: + num_draft_tokens: int = 0, + num_lookahead_tokens: int = 0, + delay_cache_blocks: bool = False, ++ num_slots_sparsed: Union[None, int] = None + ) -> Optional[KVCacheBlocks]: + """Add slots for a request with new tokens to append. + +@@ -231,6 +234,8 @@ class KVCacheManager: + """ + if num_new_tokens == 0: + raise ValueError("num_new_tokens must be greater than 0") ++ if num_slots_sparsed != INVALID_SLOT: ++ return get_ucm_sparse().allocate_slots(self, request, num_slots_sparsed) + + if new_computed_blocks is not None: + new_computed_block_list = new_computed_blocks.blocks +diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py +index c94e421c0..fff0eeb42 100644 +--- a/vllm/v1/core/sched/output.py ++++ b/vllm/v1/core/sched/output.py +@@ -157,3 +157,6 @@ class SchedulerOutput: + + # KV Cache Connector metadata. 
+ kv_connector_metadata: Optional[KVConnectorMetadata] = None ++ ++ # modified slots by sparse algorithm ++ req_sparsed_slots: dict[str, int] = None +diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py +index 2d4fd4d59..e99a51788 100644 +--- a/vllm/v1/core/sched/scheduler.py ++++ b/vllm/v1/core/sched/scheduler.py +@@ -35,6 +35,8 @@ from vllm.v1.request import Request, RequestStatus + from vllm.v1.spec_decode.metrics import SpecDecodingStats + from vllm.v1.structured_output import StructuredOutputManager + from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import MultiConnector ++from ucm.sparse.state import ensure_ucm_sparse_initialized, get_ucm_sparse, has_ucm_sparse ++from ucm.sparse.base import UcmSparseBase, UcmSparseRole, INVALID_SLOT + + logger = init_logger(__name__) + +@@ -80,12 +82,18 @@ class Scheduler(SchedulerInterface): + # will have a corresponding KVConnector with Role=WORKER. + # KV Connector pushes/pull of remote KVs for P/D and offloading. + self.connector = None ++ self.ucm_sparse = None + if self.vllm_config.kv_transfer_config is not None: + assert len(self.kv_cache_config.kv_cache_groups) == 1, ( + "Multiple KV cache groups are not currently supported " + "with KV connectors") + self.connector = KVConnectorFactory.create_connector_v1( + config=self.vllm_config, role=KVConnectorRole.SCHEDULER) ++ # Initialize UCM Sparse if available ++ if "ucm_sparse_config" in vllm_config.kv_transfer_config.kv_connector_extra_config: ++ ensure_ucm_sparse_initialized(vllm_config, role=UcmSparseRole.SCHEDULER) ++ self.ucm_sparse = get_ucm_sparse() ++ logger.info("UCM Sparse initialized successfully: {}".format(self.ucm_sparse)) + + self.kv_event_publisher = EventPublisherFactory.create( + self.kv_events_config, +@@ -203,8 +211,13 @@ class Scheduler(SchedulerInterface): + + # First, schedule the RUNNING requests. + req_index = 0 ++ req_sparsed_slots: dict[str, int] = {} + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] ++ num_slots_sparsed = INVALID_SLOT ++ if self.ucm_sparse: ++ num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(request) ++ req_sparsed_slots.update({request.request_id: num_slots_sparsed}) + + num_new_tokens = (request.num_tokens_with_spec - + request.num_computed_tokens) +@@ -252,7 +265,8 @@ class Scheduler(SchedulerInterface): + request, + num_new_tokens, + num_draft_tokens=num_draft_tokens, +- num_lookahead_tokens=self.num_lookahead_tokens) ++ num_lookahead_tokens=self.num_lookahead_tokens, ++ num_slots_sparsed=num_slots_sparsed) + if new_blocks is None: + # The request cannot be scheduled. + # Preempt the lowest-priority request. +@@ -339,6 +353,10 @@ class Scheduler(SchedulerInterface): + break + + request = self.waiting.peek_request() ++ num_slots_sparsed = INVALID_SLOT ++ if self.ucm_sparse: ++ num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(request) ++ req_sparsed_slots.update({request.request_id: num_slots_sparsed}) + + # KVTransfer: skip request if still waiting for remote kvs. + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: +@@ -448,6 +466,7 @@ class Scheduler(SchedulerInterface): + new_computed_blocks, + num_lookahead_tokens=self.num_lookahead_tokens, + delay_cache_blocks=load_kv_async, ++ num_slots_sparsed=num_slots_sparsed + ) + if new_blocks is None: + # The request cannot be scheduled. 
+@@ -561,6 +580,7 @@ class Scheduler(SchedulerInterface): + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, ++ req_sparsed_slots=req_sparsed_slots, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between +@@ -809,16 +829,12 @@ class Scheduler(SchedulerInterface): + new_logprobs = None + new_token_ids = generated_token_ids + kv_transfer_params = None ++ + if model_runner_output.finished_dumping is not None: + request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, [])) + is_prefill = request.num_output_tokens == 0 + if is_prefill: +- if isinstance(self.connector, MultiConnector): +- for c in self.connector._connectors: +- if hasattr(c, 'connector') and hasattr(c.connector, 'commit'): +- c.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) +- else: +- self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) ++ self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) + + # Append generated tokens and check for stop. Note that if + # a request is still being prefilled, we expect the model runner +@@ -870,7 +886,6 @@ class Scheduler(SchedulerInterface): + spec_token_ids[req_index]) + else: + request.spec_token_ids = spec_token_ids[req_index] +- + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + if new_token_ids or pooler_output is not None \ +@@ -897,6 +912,7 @@ class Scheduler(SchedulerInterface): + + if not stopped: + new_running.append(request) ++ + self.running = new_running + + # KV Connector: update state for finished KV Transfers. 
+@@ -955,6 +971,8 @@ class Scheduler(SchedulerInterface): + def add_request(self, request: Request) -> None: + self.waiting.add_request(request) + self.requests[request.request_id] = request ++ if self.ucm_sparse: ++ self.ucm_sparse.request_begin(request.request_id, request.prompt_token_ids) + if self.log_stats: + request.record_event(EngineCoreEventType.QUEUED) + +@@ -1004,6 +1022,8 @@ class Scheduler(SchedulerInterface): + + def _free_request(self, request: Request) -> Optional[dict[str, Any]]: + assert request.is_finished() ++ if self.ucm_sparse: ++ self.ucm_sparse.request_finished_in_scheduler(request.request_id) + + delay_free_blocks, kv_xfer_params = self._connector_finished(request) + self.encoder_cache_manager.free(request) +@@ -1155,7 +1175,6 @@ class Scheduler(SchedulerInterface): + logger.debug("Finished sending KV transfer for request %s", req_id) + self._free_blocks(self.requests[req_id]) + +- + def _update_requests_with_invalid_blocks( + self, requests: Iterable[Request], + invalid_block_ids: set[int]) -> tuple[set[str], int]: +diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py +index 8f4e8d64c..f45e39f5c 100644 +--- a/vllm/v1/worker/block_table.py ++++ b/vllm/v1/worker/block_table.py +@@ -61,6 +61,15 @@ class BlockTable: + self.num_blocks_per_row[row_idx] += num_blocks + self.block_table_np[row_idx, start:start + num_blocks] = block_ids + ++ def reset_row( ++ self, ++ row_idx: int, ++ ) -> None: ++ self.num_blocks_per_row[row_idx] = 0 ++ self.block_table[row_idx].fill_(0) ++ self.block_table_cpu[row_idx].fill_(0) ++ self.block_table_np[row_idx].fill(0) ++ + def add_row(self, block_ids: list[int], row_idx: int) -> None: + self.num_blocks_per_row[row_idx] = 0 + self.append_row(block_ids, row_idx) +@@ -117,6 +126,10 @@ class MultiGroupBlockTable: + for i, block_table in enumerate(self.block_tables): + block_table.append_row(block_ids[i], row_idx) + ++ def reset_row(self, row_idx: int) -> None: ++ for i, block_table in enumerate(self.block_tables): ++ block_table.reset_row(row_idx) ++ + def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.add_row(block_ids[i], row_idx) +diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py +index c3df1d5d2..dbf1ea7d7 100644 +--- a/vllm/v1/worker/gpu_model_runner.py ++++ b/vllm/v1/worker/gpu_model_runner.py +@@ -72,6 +72,9 @@ from ..sample.logits_processor import LogitsProcessorManager + from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, + sanity_check_mm_encoder_outputs, scatter_mm_placeholders) + ++from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse ++from ucm.sparse.base import UcmSparseMetadata, INVALID_SLOT ++ + if TYPE_CHECKING: + import xgrammar as xgr + import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501 +@@ -365,6 +368,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): + """ + # Remove finished requests from the cached states. + for req_id in scheduler_output.finished_req_ids: ++ self.ucm_sparse_request_finished_in_worker(req_id) + self.requests.pop(req_id, None) + self.encoder_cache.pop(req_id, None) + # Remove the finished requests from the persistent batch. +@@ -468,12 +472,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): + # Update the states of the running/resumed requests. 
+ is_last_rank = get_pp_group().is_last_rank + req_data = scheduler_output.scheduled_cached_reqs ++ req_sparsed_slots = scheduler_output.req_sparsed_slots + for i, req_id in enumerate(req_data.req_ids): + req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_data.resumed_from_preemption[i] + num_output_tokens = req_data.num_output_tokens[i] ++ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT + + # Update the cached states. + req_state.num_computed_tokens = num_computed_tokens +@@ -510,15 +516,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): + end_idx:old_end_idx] = False + + # Update the block IDs. +- if not resumed_from_preemption: +- # Append the new blocks to the existing block IDs. +- for block_ids, new_ids in zip(req_state.block_ids, +- new_block_ids): +- block_ids.extend(new_ids) +- else: ++ if resumed_from_preemption or is_sparsed_request: + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = new_block_ids ++ else: ++ # Append the new blocks to the existing block IDs. ++ for block_ids, new_ids in zip(req_state.block_ids, ++ new_block_ids): ++ block_ids.extend(new_ids) + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: +@@ -531,6 +537,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = ( + num_computed_tokens) ++ if is_sparsed_request: ++ self.input_batch.block_table.reset_row(req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) + + # For the last rank, we don't need to update the token_ids_cpu +@@ -639,6 +647,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): + if self.uses_mrope: + self._calc_mrope_positions(scheduler_output) + ++ self.seq_lens_np[:num_reqs] = ( ++ self.input_batch.num_computed_tokens_cpu[:num_reqs] + ++ num_scheduled_tokens) ++ ++ # TODO: improve performance, no `positions_np.copy()` ++ sparsed_positions = positions_np.copy() ++ req_sparsed_slots = scheduler_output.req_sparsed_slots ++ for req_id in self.input_batch.req_id_to_index: ++ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT ++ req_index = self.input_batch.req_id_to_index[req_id] ++ offset = 0 if req_index == 0 else cu_num_tokens[req_index - 1] # TODO: support MTP ++ if is_sparsed_request: ++ sparsed_positions[offset] = req_sparsed_slots[req_id] - 1 + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] +@@ -668,11 +689,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): + # block_size. 
+ block_table_indices = ( + req_indices * block_table.max_num_blocks_per_req + +- positions_np // block_size) ++ sparsed_positions // block_size) + block_table_cpu = block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten( + )[block_table_indices].numpy() +- block_offsets = positions_np % block_size ++ block_offsets = sparsed_positions % block_size + np.add( + block_numbers * block_size, + block_offsets, +@@ -682,9 +703,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): + self.query_start_loc_np[0] = 0 + self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + +- self.seq_lens_np[:num_reqs] = ( +- self.input_batch.num_computed_tokens_cpu[:num_reqs] + +- num_scheduled_tokens) ++ for req_id in self.input_batch.req_id_to_index: ++ req_index = self.input_batch.req_id_to_index[req_id] ++ is_sparsed_request = scheduler_output.req_sparsed_slots[req_id] != INVALID_SLOT ++ if is_sparsed_request: ++ self.seq_lens_np[req_index] = scheduler_output.req_sparsed_slots[req_id] + + # Copy the tensors to the GPU. + self.input_ids[:total_num_scheduled_tokens].copy_( +@@ -696,6 +719,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): + non_blocking=True) + else: + # Common case (1D positions) ++ self.positions_cpu[:total_num_scheduled_tokens] = torch.from_numpy( ++ positions_np[:total_num_scheduled_tokens]) + self.positions[:total_num_scheduled_tokens].copy_( + self.positions_cpu[:total_num_scheduled_tokens], + non_blocking=True) +@@ -1386,6 +1411,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): + skip_cuda_graphs=skip_cuda_graphs, + ): + self.maybe_setup_kv_connector(scheduler_output) ++ self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata) + + model_output = self.model( + input_ids=input_ids, +@@ -1395,6 +1421,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): + ) + + finished_dumping = self.maybe_wait_for_kv_save() ++ self.maybe_execute_ucm_sparse_finished() ++ + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) + invalid_block_ids = self.get_block_ids_with_load_errors() +@@ -1741,10 +1769,29 @@ class GPUModelRunner(LoRAModelRunnerMixin): + kv_connector.start_load_kv(get_forward_context()) + + @staticmethod +- def maybe_wait_for_kv_save(): ++ def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]: + if has_kv_transfer_group(): + return get_kv_transfer_group().wait_for_save() + ++ def maybe_execute_ucm_sparse_begin(self, scheduler_output: "SchedulerOutput", attn_metadata: CommonAttentionMetadata): ++ if not has_ucm_sparse(): ++ return ++ ucm_sparse = get_ucm_sparse() ++ ucm_sparse.build_sparse_meta(scheduler_output, self.requests, self.input_batch, attn_metadata) ++ ucm_sparse.execute_begin(scheduler_output) ++ ++ def maybe_execute_ucm_sparse_finished(self): ++ if not has_ucm_sparse(): ++ return ++ ucm_sparse = get_ucm_sparse() ++ ucm_sparse.execute_finished() ++ ++ def ucm_sparse_request_finished_in_worker(self, request_id: str | int): ++ if not has_ucm_sparse(): ++ return ++ ucm_sparse = get_ucm_sparse() ++ ucm_sparse.request_finished_in_worker(request_id) ++ + @staticmethod + def get_finished_kv_transfers( + scheduler_output: "SchedulerOutput", +diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py +index 1b816b25b..d9666d102 100644 +--- a/vllm/v1/worker/gpu_worker.py ++++ b/vllm/v1/worker/gpu_worker.py +@@ -30,6 +30,7 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput + from vllm.v1.utils import report_usage_stats + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + from 
vllm.v1.worker.worker_base import WorkerBase ++from ucm.sparse.state import ensure_ucm_sparse_initialized + + logger = init_logger(__name__) + +@@ -401,6 +402,7 @@ def init_worker_distributed_environment( + parallel_config.pipeline_parallel_size) + + ensure_kv_transfer_initialized(vllm_config) ++ ensure_ucm_sparse_initialized(vllm_config) + + + def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): +-- +2.34.1 + diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch deleted file mode 100644 index da340e7ef..000000000 --- a/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch +++ /dev/null @@ -1,1159 +0,0 @@ -From 555ba9e4920445381aecda262b9146342e92eeee Mon Sep 17 00:00:00 2001 -From: hek14 <1023129548@qq.com> -Date: Fri, 26 Sep 2025 09:51:36 +0800 -Subject: [PATCH] UCM adaptor - ---- - vllm/attention/layer.py | 45 ++++- - .../kv_transfer/kv_connector/utils.py | 113 ++++++++++++ - .../kv_transfer/kv_connector/v1/base.py | 9 + - .../v1/shared_storage_connector.py | 7 +- - vllm/v1/core/block_pool.py | 2 +- - vllm/v1/core/kv_cache_manager.py | 11 +- - vllm/v1/core/sched/output.py | 3 + - vllm/v1/core/sched/scheduler.py | 164 +++++++++++++++++- - vllm/v1/core/single_type_kv_cache_manager.py | 3 + - vllm/v1/executor/multiproc_executor.py | 30 +++- - vllm/v1/outputs.py | 5 + - vllm/v1/request.py | 2 +- - vllm/v1/worker/block_table.py | 13 ++ - vllm/v1/worker/gpu_input_batch.py | 9 + - vllm/v1/worker/gpu_model_runner.py | 120 +++++++++++-- - vllm/v1/worker/gpu_worker.py | 25 ++- - 16 files changed, 524 insertions(+), 37 deletions(-) - -diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py -index f0ad68b16..2acde35d8 100644 ---- a/vllm/attention/layer.py -+++ b/vllm/attention/layer.py -@@ -2,7 +2,6 @@ - # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - """Attention layer.""" - from typing import Any, Dict, List, Optional -- - import torch - import torch.nn as nn - import torch.nn.functional as F -@@ -22,6 +21,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod - from vllm.platforms import _Backend, current_platform - from vllm.utils import direct_register_custom_op - from vllm.v1.attention.backends.utils import validate_kv_sharing_target -+from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse - - - class Attention(nn.Module): -@@ -409,9 +409,10 @@ def unified_attention( - attn_metadata = attn_metadata[layer_name] - self = forward_context.no_compile_layers[layer_name] - kv_cache = self.kv_cache[forward_context.virtual_engine] -+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) - output = self.impl.forward(self, query, key, value, kv_cache, - attn_metadata) -- -+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) - maybe_save_kv_layer_to_connector(layer_name, kv_cache) - return output - -@@ -449,6 +450,7 @@ def unified_attention_with_output( - attn_metadata = attn_metadata[layer_name] - self = forward_context.no_compile_layers[layer_name] - kv_cache = self.kv_cache[forward_context.virtual_engine] -+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) - self.impl.forward(self, - query, - key, -@@ -457,7 +459,7 @@ def unified_attention_with_output( - attn_metadata, - output=output, - output_scale=output_scale) -- -+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) - maybe_save_kv_layer_to_connector(layer_name, 
kv_cache) - - -@@ -479,3 +481,40 @@ direct_register_custom_op( - fake_impl=unified_attention_with_output_fake, - dispatch_key=current_platform.dispatch_key, - ) -+ -+def maybe_execute_sparse_attention_begin( -+ query: torch.Tensor, -+ key: torch.Tensor, -+ value: torch.Tensor, -+ layer_name: str, -+ forward_context: ForwardContext, -+): -+ if not has_ucm_sparse(): -+ return -+ -+ ucm_sparse = get_ucm_sparse() -+ -+ attn_metadata = forward_context.attn_metadata -+ if attn_metadata is None: -+ return -+ -+ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context) -+ -+def maybe_execute_sparse_attention_finished( -+ query: torch.Tensor, -+ key: torch.Tensor, -+ value: torch.Tensor, -+ attn_output: torch.Tensor, -+ layer_name: str, -+ forward_context: ForwardContext, -+): -+ if not has_ucm_sparse(): -+ return -+ -+ ucm_sparse = get_ucm_sparse() -+ -+ attn_metadata = forward_context.attn_metadata -+ if attn_metadata is None: -+ return -+ -+ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context) -diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py -index 5cbc8ca31..8556a979e 100644 ---- a/vllm/distributed/kv_transfer/kv_connector/utils.py -+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py -@@ -3,12 +3,18 @@ - """ - KV cache helper for store. - """ -+from collections import defaultdict -+from collections.abc import Sequence -+from concurrent.futures import CancelledError, Future -+from typing import Optional, cast -+ - import torch - - import vllm.envs as envs - from vllm import _custom_ops as ops - from vllm.config import VllmConfig, get_current_vllm_config - from vllm.logger import init_logger -+from vllm.v1.outputs import ModelRunnerOutput - - logger = init_logger(__name__) - -@@ -107,3 +113,110 @@ def get_kv_connector_cache_layout(): - "layout to HND for better xfer performance.") - return "HND" - return "NHD" -+ -+ -+class KVOutputAggregator: -+ """Utility class to aggregate the output of all workers into a single -+ output corresponding to Rank 0 for scheduler.""" -+ -+ def __init__(self, world_size: int): -+ # Complete transfer tracker. 
Used by to track finished requests -+ # [req_id -> n_finished_workers] -+ self._recv_remaining_count = defaultdict[str, int](lambda: world_size) -+ self._send_remaining_count = defaultdict[str, int](lambda: world_size) -+ self._dump_remaining_count = defaultdict[str, int](lambda: world_size) -+ -+ def aggregate(self, -+ outputs: list[ModelRunnerOutput], -+ output_rank: int = 0) -> ModelRunnerOutput: -+ # aggregate finished_sending, finished_recving from all workers -+ -+ def update_finished_set(req_ids: Optional[set[str]], -+ remaining_count_dict: dict[str, int], -+ finished_set: set[str]) -> None: -+ for req_id in req_ids or (): -+ new_count = remaining_count_dict[req_id] - 1 -+ if new_count == 0: -+ finished_set.add(req_id) -+ del remaining_count_dict[req_id] -+ else: -+ remaining_count_dict[req_id] = new_count -+ -+ def update_finished_list(req_ids: Optional[dict[str, list[str]]], -+ remaining_count_dict: dict[str, int], -+ finished_list: dict[str, list[str]]) -> None: -+ for req_id, succeed_dump_blocks in (req_ids or {}).items(): -+ if req_id not in finished_list: -+ finished_list[req_id] = [] -+ for blk_id in succeed_dump_blocks: -+ new_count = remaining_count_dict[blk_id] - 1 -+ if new_count == 0: -+ finished_list[req_id].append(blk_id) -+ del remaining_count_dict[blk_id] -+ else: -+ remaining_count_dict[blk_id] = new_count -+ -+ finished_sending = set[str]() -+ finished_recving = set[str]() -+ invalid_block_ids = set[int]() -+ finished_dumping: dict[str, list[str]] = {} -+ for output in outputs: -+ update_finished_set(output.finished_sending, -+ self._send_remaining_count, finished_sending) -+ update_finished_set(output.finished_recving, -+ self._recv_remaining_count, finished_recving) -+ update_finished_list(output.finished_dumping, -+ self._dump_remaining_count, finished_dumping) -+ if output.invalid_block_ids: -+ invalid_block_ids |= output.invalid_block_ids -+ -+ # select output of the worker specified by output_rank -+ output = outputs[output_rank] -+ -+ # set the aggregated finished_sending / finished_recving -+ # if output.finished_sending/recving is not empty, but the other ranks -+ # still have unfinished send/recv, we want to set the aggregated -+ # finished_sending/recving to None until all ranks have finished -+ # send/recv -+ output.finished_sending = finished_sending if finished_sending else None -+ output.finished_recving = finished_recving if finished_recving else None -+ output.finished_dumping = finished_dumping if finished_dumping else None -+ output.invalid_block_ids = invalid_block_ids or None -+ -+ return output -+ -+ def async_aggregate(self, -+ output_futures: Sequence[Future[ModelRunnerOutput]], -+ output_rank: int = 0) -> Future[ModelRunnerOutput]: -+ """Takes a list of futures and returns a single future which resolves -+ to the respective list of outputs.""" -+ result_future: Future[ModelRunnerOutput] = Future() -+ -+ outputs: list[Optional[ModelRunnerOutput]] = [None -+ ] * len(output_futures) -+ -+ def make_callback(idx): -+ -+ def callback(fut): -+ if result_future.done(): -+ return -+ -+ try: -+ outputs[idx] = fut.result() -+ except CancelledError: -+ result_future.cancel() -+ except Exception as e: -+ result_future.set_exception(e) -+ -+ # this check assumes io_thread_pool uses a single thread -+ if all(outputs): -+ result_future.set_result( -+ self.aggregate(cast(list[ModelRunnerOutput], outputs), -+ output_rank)) -+ -+ return callback -+ -+ for i, output_future in enumerate(output_futures): -+ output_future.add_done_callback(make_callback(i)) -+ -+ 
return result_future -diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py -index f80b5eba2..8891246e6 100644 ---- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py -+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py -@@ -201,6 +201,15 @@ class KVConnectorBase_V1(ABC): - """ - return None, None - -+ def get_block_ids_with_load_errors(self) -> Optional[set[int]]: -+ """ -+ Get the set of block IDs that failed to load. -+ Returns: -+ Optional[set[int]]: A set of block IDs that encountered load errors. -+ Returns None if no errors occurred during load. -+ """ -+ return None -+ - # ============================== - # Scheduler-side methods - # ============================== -diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py -index 3c574d065..223106def 100644 ---- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py -+++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py -@@ -2,7 +2,7 @@ - # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import hashlib - import os --from dataclasses import dataclass -+from dataclasses import dataclass, field - from typing import TYPE_CHECKING - - import safetensors -@@ -53,10 +53,7 @@ class ReqMeta: - - @dataclass - class SharedStorageConnectorMetadata(KVConnectorMetadata): -- requests: list[ReqMeta] -- -- def __init__(self): -- self.requests = [] -+ requests: list[ReqMeta] = field(default_factory=list) - - def add_request( - self, -diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py -index d21f94727..1800665c7 100644 ---- a/vllm/v1/core/block_pool.py -+++ b/vllm/v1/core/block_pool.py -@@ -124,7 +124,7 @@ class BlockPool: - kv_cache_group_id: The id of the KV cache group. - hash_fn: The hash function to use for block hashes. - """ -- if num_cached_blocks == num_full_blocks: -+ if num_cached_blocks >= num_full_blocks: - return - new_full_blocks = blocks[num_cached_blocks:num_full_blocks] - assert len(block_hashes) >= num_cached_blocks -diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py -index 6937455e7..c36a25bc5 100644 ---- a/vllm/v1/core/kv_cache_manager.py -+++ b/vllm/v1/core/kv_cache_manager.py -@@ -3,7 +3,7 @@ - - from collections import defaultdict - from dataclasses import dataclass --from typing import Optional -+from typing import Optional, Union - - from vllm.distributed.kv_events import KVCacheEvent - from vllm.logger import init_logger -@@ -14,6 +14,8 @@ from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, - from vllm.v1.kv_cache_interface import KVCacheConfig - from vllm.v1.metrics.stats import PrefixCacheStats - from vllm.v1.request import Request, RequestStatus -+from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse -+from ucm.sparse.base import INVALID_SLOT - - logger = init_logger(__name__) - -@@ -193,6 +195,7 @@ class KVCacheManager: - num_draft_tokens: int = 0, - num_lookahead_tokens: int = 0, - delay_cache_blocks: bool = False, -+ num_slots_sparsed: Union[None, int] = None - ) -> Optional[KVCacheBlocks]: - """Add slots for a request with new tokens to append. 
- -@@ -231,6 +234,12 @@ class KVCacheManager: - """ - if num_new_tokens == 0: - raise ValueError("num_new_tokens must be greater than 0") -+ if num_slots_sparsed != INVALID_SLOT: -+ return get_ucm_sparse().allocate_slots(request, -+ num_slots_sparsed, -+ self.coordinator, -+ self.block_pool, -+ self.kv_cache_config.kv_cache_groups) - - if new_computed_blocks is not None: - new_computed_block_list = new_computed_blocks.blocks -diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py -index d34f39327..141d750b3 100644 ---- a/vllm/v1/core/sched/output.py -+++ b/vllm/v1/core/sched/output.py -@@ -155,3 +155,6 @@ class SchedulerOutput: - - # KV Cache Connector metadata. - kv_connector_metadata: Optional[KVConnectorMetadata] = None -+ -+ # modified slots by sparse algorithm -+ req_sparsed_slots: dict[str, int] = None -diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py -index fe552db74..6a9d4b4b9 100644 ---- a/vllm/v1/core/sched/scheduler.py -+++ b/vllm/v1/core/sched/scheduler.py -@@ -34,6 +34,8 @@ from vllm.v1.outputs import ModelRunnerOutput - from vllm.v1.request import Request, RequestStatus - from vllm.v1.spec_decode.metrics import SpecDecodingStats - from vllm.v1.structured_output import StructuredOutputManager -+from ucm.sparse.state import ensure_ucm_sparse_initialized, get_ucm_sparse, has_ucm_sparse -+from ucm.sparse.base import UcmSparseBase, UcmSparseRole, INVALID_SLOT - - logger = init_logger(__name__) - -@@ -79,12 +81,18 @@ class Scheduler(SchedulerInterface): - # will have a corresponding KVConnector with Role=WORKER. - # KV Connector pushes/pull of remote KVs for P/D and offloading. - self.connector = None -+ self.ucm_sparse = None - if self.vllm_config.kv_transfer_config is not None: - assert len(self.kv_cache_config.kv_cache_groups) == 1, ( - "Multiple KV cache groups are not currently supported " - "with KV connectors") - self.connector = KVConnectorFactory.create_connector_v1( - config=self.vllm_config, role=KVConnectorRole.SCHEDULER) -+ # Initialize UCM Sparse if available -+ if "ucm_sparse_config" in vllm_config.kv_transfer_config.kv_connector_extra_config: -+ ensure_ucm_sparse_initialized(vllm_config, role=UcmSparseRole.SCHEDULER) -+ self.ucm_sparse = get_ucm_sparse() -+ logger.info("UCM Sparse initialized successfully: {}".format(self.ucm_sparse)) - - self.kv_event_publisher = EventPublisherFactory.create( - self.kv_events_config, -@@ -201,8 +209,13 @@ class Scheduler(SchedulerInterface): - - # First, schedule the RUNNING requests. - req_index = 0 -+ req_sparsed_slots: dict[str, int] = {} - while req_index < len(self.running) and token_budget > 0: - request = self.running[req_index] -+ num_slots_sparsed = INVALID_SLOT -+ if self.ucm_sparse: -+ num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(request) -+ req_sparsed_slots.update({request.request_id: num_slots_sparsed}) - - num_new_tokens = (request.num_tokens_with_spec - - request.num_computed_tokens) -@@ -250,7 +263,8 @@ class Scheduler(SchedulerInterface): - request, - num_new_tokens, - num_draft_tokens=num_draft_tokens, -- num_lookahead_tokens=self.num_lookahead_tokens) -+ num_lookahead_tokens=self.num_lookahead_tokens, -+ num_slots_sparsed=num_slots_sparsed) - if new_blocks is None: - # The request cannot be scheduled. - # Preempt the lowest-priority request. 
-@@ -337,6 +351,10 @@ class Scheduler(SchedulerInterface): - break - - request = self.waiting.peek_request() -+ num_slots_sparsed = INVALID_SLOT -+ if self.ucm_sparse: -+ num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(request) -+ req_sparsed_slots.update({request.request_id: num_slots_sparsed}) - - # KVTransfer: skip request if still waiting for remote kvs. - if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: -@@ -446,6 +464,7 @@ class Scheduler(SchedulerInterface): - new_computed_blocks, - num_lookahead_tokens=self.num_lookahead_tokens, - delay_cache_blocks=load_kv_async, -+ num_slots_sparsed=num_slots_sparsed - ) - if new_blocks is None: - # The request cannot be scheduled. -@@ -559,6 +578,7 @@ class Scheduler(SchedulerInterface): - scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, - scheduled_encoder_inputs=scheduled_encoder_inputs, - num_common_prefix_blocks=num_common_prefix_blocks, -+ req_sparsed_slots=req_sparsed_slots, - # finished_req_ids is an existing state in the scheduler, - # instead of being newly scheduled in this step. - # It contains the request IDs that are finished in between -@@ -745,16 +765,31 @@ class Scheduler(SchedulerInterface): - num_scheduled_tokens = scheduler_output.num_scheduled_tokens - pooler_outputs = model_runner_output.pooler_output - num_nans_in_logits = model_runner_output.num_nans_in_logits -+ invalid_block_ids = model_runner_output.invalid_block_ids - - new_running: list[Request] = [] - outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) - spec_decoding_stats: Optional[SpecDecodingStats] = None - -+ recovered_req_ids = None -+ if invalid_block_ids: -+ # These blocks contain externally computed tokens that failed to -+ # load. Identify affected requests and adjust their computed token -+ # count to trigger recomputation of the invalid blocks. -+ recovered_req_ids = self._handle_invalid_blocks(invalid_block_ids) -+ - # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below - # loop can be a performance bottleneck. We should do our best to avoid - # expensive operations inside the loop. - for request in self.running: - req_id = request.request_id -+ # self.req_meta.stage == SequenceStage.PREFILL and self.req_meta.is_last_chunk -+ -+ -+ if recovered_req_ids and req_id in recovered_req_ids: -+ # Skip requests that were recovered from KV load failure -+ new_running.append(request) -+ continue - num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) - if num_tokens_scheduled == 0: - # The request was not scheduled in this step. -@@ -792,6 +827,13 @@ class Scheduler(SchedulerInterface): - new_token_ids = generated_token_ids - kv_transfer_params = None - -+ if model_runner_output.finished_dumping is not None: -+ request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, [])) -+ is_prefill = request.num_output_tokens == 0 -+ is_last_chunk = (num_tokens_scheduled + request.num_computed_tokens >= request.num_prompt_tokens) -+ if is_prefill and is_last_chunk: -+ self.connector.connector.commit(request.succeed_dumped_blocks, True) -+ - # Append generated tokens and check for stop. Note that if - # a request is still being prefilled, we expect the model runner - # to return empty token ids for the request. -@@ -842,7 +884,6 @@ class Scheduler(SchedulerInterface): - spec_token_ids[req_index]) - else: - request.spec_token_ids = spec_token_ids[req_index] -- - # Get prompt logprobs for this request. 
- prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids or pooler_output is not None \ -@@ -869,6 +910,7 @@ class Scheduler(SchedulerInterface): - - if not stopped: - new_running.append(request) -+ - self.running = new_running - - # KV Connector: update state for finished KV Transfers. -@@ -927,6 +969,8 @@ class Scheduler(SchedulerInterface): - def add_request(self, request: Request) -> None: - self.waiting.add_request(request) - self.requests[request.request_id] = request -+ if self.ucm_sparse: -+ self.ucm_sparse.request_begin(request.request_id, request.prompt_token_ids) - if self.log_stats: - request.record_event(EngineCoreEventType.QUEUED) - -@@ -976,6 +1020,8 @@ class Scheduler(SchedulerInterface): - - def _free_request(self, request: Request) -> Optional[dict[str, Any]]: - assert request.is_finished() -+ if self.ucm_sparse: -+ self.ucm_sparse.request_finished_in_scheduler(request.request_id) - - delay_free_blocks, kv_xfer_params = self._connector_finished(request) - self.encoder_cache_manager.free(request) -@@ -1113,3 +1159,117 @@ class Scheduler(SchedulerInterface): - for req_id in (model_runner_output.finished_sending or ()): - logger.debug("Finished sending KV transfer for request %s", req_id) - self._free_blocks(self.requests[req_id]) -+ -+ def _update_requests_with_invalid_blocks( -+ self, requests: Iterable[Request], -+ invalid_block_ids: set[int]) -> tuple[set[Request], int, set[int]]: -+ affected_requests: set[Request] = set() -+ num_tokens_to_reschedule = 0 -+ # If a block is invalid and shared by multiple requests in the batch, -+ # all requests must be rescheduled, but only the first will recompute -+ # it. This set tracks blocks already marked for recomputation. -+ marked_invalid_block_ids: set[int] = set() -+ for request in requests: -+ is_affected = False -+ marked_invalid_block = False -+ req_id = request.request_id -+ req_block_ids = self.kv_cache_manager.get_block_ids(req_id)[0] -+ # We iterate only over blocks that may contain externally computed -+ # tokens -+ if request.num_cached_tokens > 0: -+ req_num_computed_blocks = (request.num_cached_tokens + -+ self.block_size - -+ 1) // self.block_size -+ else: -+ req_num_computed_blocks = len(req_block_ids) -+ -+ for idx, block_id in zip(range(req_num_computed_blocks), -+ req_block_ids): -+ -+ if block_id not in invalid_block_ids: -+ continue -+ -+ is_affected = True -+ -+ if block_id in marked_invalid_block_ids: -+ # This invalid block is shared with a previous request -+ # and was already marked for recomputation. -+ # This means this request can still consider this block -+ # as computed when rescheduled. -+ continue -+ -+ marked_invalid_block_ids.add(block_id) -+ -+ if marked_invalid_block: -+ # This request has already marked an invalid block for -+ # recomputation and updated its num_computed_tokens. -+ continue -+ -+ marked_invalid_block = True -+ num_tokens_to_reschedule += request.num_computed_tokens -+ request.num_computed_tokens = idx * self.block_size -+ num_tokens_to_reschedule -= request.num_computed_tokens -+ -+ if is_affected: -+ if not marked_invalid_block: -+ # All invalid blocks of this request are shared with -+ # previous requests and will be recomputed by them. -+ # Revert to considering only cached tokens as computed. 
-+ num_tokens_to_reschedule += (request.num_computed_tokens - -+ request.num_cached_tokens) -+ request.num_computed_tokens = request.num_cached_tokens -+ -+ affected_requests.add(request) -+ -+ return (affected_requests, num_tokens_to_reschedule, -+ marked_invalid_block_ids) -+ -+ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: -+ total_requests_to_reschedule = 0 -+ total_tokens_to_reschedule = 0 -+ -+ # --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) --- -+ async_load_reqs = ( -+ req for req in self.waiting -+ if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS) -+ (affected_requests, num_tokens_to_reschedule, -+ marked_invalid_block_ids) = ( -+ self._update_requests_with_invalid_blocks(async_load_reqs, -+ invalid_block_ids)) -+ -+ total_requests_to_reschedule += len(affected_requests) -+ total_tokens_to_reschedule += num_tokens_to_reschedule -+ -+ for request in affected_requests: -+ if request.num_computed_tokens: -+ # Cache any valid computed tokens. -+ self.kv_cache_manager.cache_blocks(request, -+ request.num_computed_tokens) -+ else: -+ # No valid computed tokens, release allocated blocks. -+ # There may be a local cache hit on retry. -+ self.kv_cache_manager.free(request) -+ -+ request.status = RequestStatus.WAITING -+ -+ # Remove async loaded invalid blocks already handled, -+ # as they cannot be shared with running requests. -+ invalid_block_ids.difference_update(marked_invalid_block_ids) -+ -+ # --- Handle sync KV loads (running requests) --- -+ affected_requests, num_tokens_to_reschedule, _ = ( -+ self._update_requests_with_invalid_blocks(self.running, -+ invalid_block_ids)) -+ -+ total_requests_to_reschedule += len(affected_requests) -+ total_tokens_to_reschedule += num_tokens_to_reschedule -+ -+ if total_requests_to_reschedule: -+ logger.info( -+ "Recovered from KV load failure: " -+ "%d request(s) rescheduled (%d tokens affected).", -+ total_requests_to_reschedule, total_tokens_to_reschedule) -+ -+ # Return the IDs of affected running requests to skip in -+ # update_from_output. 
-+ return {r.request_id for r in affected_requests} -diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py -index 5b4718038..28bd4618a 100644 ---- a/vllm/v1/core/single_type_kv_cache_manager.py -+++ b/vllm/v1/core/single_type_kv_cache_manager.py -@@ -142,6 +142,9 @@ class SingleTypeKVCacheManager(ABC): - num_cached_blocks = self.num_cached_block[request.request_id] - num_full_blocks = num_tokens // self.block_size - -+ if num_cached_blocks >= num_full_blocks: -+ return -+ - self.block_pool.cache_full_blocks( - request=request, - blocks=self.req_to_blocks[request.request_id], -diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py -index b06b7cc80..61cd7110f 100644 ---- a/vllm/v1/executor/multiproc_executor.py -+++ b/vllm/v1/executor/multiproc_executor.py -@@ -26,6 +26,7 @@ from vllm.distributed import (destroy_distributed_environment, - destroy_model_parallel) - from vllm.distributed.device_communicators.shm_broadcast import (Handle, - MessageQueue) -+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator - from vllm.executor.multiproc_worker_utils import ( - _add_prefix, set_multiprocessing_worker_envs) - from vllm.logger import init_logger -@@ -111,10 +112,14 @@ class MultiprocExecutor(Executor): - if self.max_concurrent_batches > 1: - # Note: must use only 1 IO thread to keep dequeue sequence - # from the response queue -+ # _async_aggregate_workers_output also assumes a single IO thread - self.io_thread_pool = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="mp_exec_io") - - self.output_rank = self._get_output_rank() -+ self.has_connector = self.vllm_config.kv_transfer_config is not None -+ self.kv_output_aggregator = KVOutputAggregator( -+ self.parallel_config.world_size) - - def start_worker_monitor(self): - workers = self.workers -@@ -155,13 +160,30 @@ class MultiprocExecutor(Executor): - self, - scheduler_output, - ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: -- (output, ) = self.collective_rpc( -+ non_block = self.max_concurrent_batches > 1 -+ -+ if not self.has_connector or self.vllm_config.model_config.use_mla: -+ # get output only from a single worker (output_rank) -+ (output, ) = self.collective_rpc( -+ "execute_model", -+ args=(scheduler_output, ), -+ unique_reply_rank=self.output_rank, -+ non_block=non_block, -+ timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) -+ return output -+ -+ # get output from all workers -+ outputs = self.collective_rpc( - "execute_model", - args=(scheduler_output, ), -- unique_reply_rank=self.output_rank, -- non_block=self.max_concurrent_batches > 1, -+ non_block=non_block, - timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) -- return output -+ -+ # aggregate all workers output to a single output -+ if non_block: -+ return self.kv_output_aggregator.async_aggregate( -+ outputs, self.output_rank) -+ return self.kv_output_aggregator.aggregate(outputs, self.output_rank) - - def collective_rpc(self, - method: Union[str, Callable], -diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py -index f78623f57..c7b4100e3 100644 ---- a/vllm/v1/outputs.py -+++ b/vllm/v1/outputs.py -@@ -107,6 +107,11 @@ class ModelRunnerOutput: - # [req_ids] - finished_sending: Optional[set[str]] = None - finished_recving: Optional[set[str]] = None -+ finished_dumping: Optional[dict[str, list[str]]] = None -+ -+ # IDs of externally computed KV blocks that failed to load. -+ # Requests referencing these blocks should be rescheduled to recompute them. 
-+ invalid_block_ids: Optional[set[int]] = None - - # req_id -> num_nans_in_logits - num_nans_in_logits: Optional[dict[str, int]] = None -diff --git a/vllm/v1/request.py b/vllm/v1/request.py -index 9b96f4599..825b77bba 100644 ---- a/vllm/v1/request.py -+++ b/vllm/v1/request.py -@@ -102,7 +102,7 @@ class Request: - # State - # The number of tokens with prefix cache hits. - self.num_cached_tokens = -1 -- -+ self.succeed_dumped_blocks: list[str] = [] - # The number of NaNs in logits. A value greater than 0 - # indicates that the output is corrupted - self.num_nans_in_logits = 0 -diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py -index 8f4e8d64c..f45e39f5c 100644 ---- a/vllm/v1/worker/block_table.py -+++ b/vllm/v1/worker/block_table.py -@@ -61,6 +61,15 @@ class BlockTable: - self.num_blocks_per_row[row_idx] += num_blocks - self.block_table_np[row_idx, start:start + num_blocks] = block_ids - -+ def reset_row( -+ self, -+ row_idx: int, -+ ) -> None: -+ self.num_blocks_per_row[row_idx] = 0 -+ self.block_table[row_idx].fill_(0) -+ self.block_table_cpu[row_idx].fill_(0) -+ self.block_table_np[row_idx].fill(0) -+ - def add_row(self, block_ids: list[int], row_idx: int) -> None: - self.num_blocks_per_row[row_idx] = 0 - self.append_row(block_ids, row_idx) -@@ -117,6 +126,10 @@ class MultiGroupBlockTable: - for i, block_table in enumerate(self.block_tables): - block_table.append_row(block_ids[i], row_idx) - -+ def reset_row(self, row_idx: int) -> None: -+ for i, block_table in enumerate(self.block_tables): -+ block_table.reset_row(row_idx) -+ - def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: - for i, block_table in enumerate(self.block_tables): - block_table.add_row(block_ids[i], row_idx) -diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py -index 1a79d72be..0e65c98f5 100644 ---- a/vllm/v1/worker/gpu_input_batch.py -+++ b/vllm/v1/worker/gpu_input_batch.py -@@ -46,6 +46,11 @@ class CachedRequestState: - - def __post_init__(self): - self.num_prompt_tokens = len(self.prompt_token_ids) -+ # 'last_generator_offset' and 'last_gelen_last_output_token_ids' are -+ # used to allow safe rollback in case a sampled token turns out to be -+ # invalid (e.g., due to KV load errors). -+ self.last_generator_offset = 0 if self.generator else None -+ self.len_last_output_token_ids = len(self.output_token_ids) - - @property - def num_tokens(self) -> int: -@@ -201,6 +206,7 @@ class InputBatch: - # NOTE(woosuk): The indices of the requests that do not have their own - # generator should not be included in the dictionary. - self.generators: dict[int, torch.Generator] = {} -+ self.generators_last_offset: dict[int, int] = {} - - self.num_logprobs: dict[str, int] = {} - # NOTE(rob): num_prompt_logprobs only includes reqs -@@ -335,6 +341,9 @@ class InputBatch: - # do not have their own generator. 
- if request.generator is not None: - self.generators[req_index] = request.generator -+ assert (request.last_generator_offset is not None) -+ self.generators_last_offset[ -+ req_index] = request.last_generator_offset - - if sampling_params.logprobs is not None: - self.num_logprobs[req_id] = sampling_params.logprobs -diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py -index 5a26e88db..17b3d1c79 100644 ---- a/vllm/v1/worker/gpu_model_runner.py -+++ b/vllm/v1/worker/gpu_model_runner.py -@@ -72,6 +72,9 @@ from ..sample.logits_processor import LogitsProcessorManager - from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, - sanity_check_mm_encoder_outputs, scatter_mm_placeholders) - -+from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse -+from ucm.sparse.base import UcmSparseMetadata, INVALID_SLOT -+ - if TYPE_CHECKING: - import xgrammar as xgr - import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501 -@@ -365,6 +368,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): - """ - # Remove finished requests from the cached states. - for req_id in scheduler_output.finished_req_ids: -+ self.ucm_sparse_request_finished_in_worker(req_id) - self.requests.pop(req_id, None) - self.encoder_cache.pop(req_id, None) - # Remove the finished requests from the persistent batch. -@@ -468,13 +472,33 @@ class GPUModelRunner(LoRAModelRunnerMixin): - # Update the states of the running/resumed requests. - is_last_rank = get_pp_group().is_last_rank - req_data = scheduler_output.scheduled_cached_reqs -+ req_sparsed_slots = scheduler_output.req_sparsed_slots - for i, req_id in enumerate(req_data.req_ids): - req_state = self.requests[req_id] - num_computed_tokens = req_data.num_computed_tokens[i] - new_block_ids = req_data.new_block_ids[i] - resumed_from_preemption = req_data.resumed_from_preemption[i] -+ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT - - # Update the cached states. -+ if (num_computed_tokens <= req_state.num_computed_tokens): -+ # The request was rescheduled after a KV load failure. Clear -+ # the last sampled tokens and rewind the generator state -+ len_output_token_ids = len(req_state.output_token_ids) -+ del req_state.output_token_ids[req_state. -+ len_last_output_token_ids:] -+ if req_state.generator: -+ req_state.generator.set_offset( -+ req_state.last_generator_offset) -+ req_index = self.input_batch.req_id_to_index.get(req_id) -+ if req_index is not None: -+ len_last_sampled = (len_output_token_ids - -+ req_state.len_last_output_token_ids) -+ end_idx = self.input_batch.num_tokens_no_spec[ -+ req_index] - len_last_sampled -+ self.input_batch.num_tokens[req_index] = end_idx -+ self.input_batch.num_tokens_no_spec[req_index] = end_idx -+ - req_state.num_computed_tokens = num_computed_tokens - - if not is_last_rank: -@@ -493,16 +517,22 @@ class GPUModelRunner(LoRAModelRunnerMixin): - req_state.output_token_ids.extend( - new_token_ids[-num_new_tokens:]) - -+ req_state.len_last_output_token_ids = len( -+ req_state.output_token_ids) -+ if req_state.generator: -+ req_state.last_generator_offset = ( -+ req_state.generator.get_offset()) -+ - # Update the block IDs. -- if not resumed_from_preemption: -- # Append the new blocks to the existing block IDs. -- for block_ids, new_ids in zip(req_state.block_ids, -- new_block_ids): -- block_ids.extend(new_ids) -- else: -+ if resumed_from_preemption or is_sparsed_request: - # The request is resumed from preemption. 
- # Replace the existing block IDs with the new ones. - req_state.block_ids = new_block_ids -+ else: -+ # Append the new blocks to the existing block IDs. -+ for block_ids, new_ids in zip(req_state.block_ids, -+ new_block_ids): -+ block_ids.extend(new_ids) - - req_index = self.input_batch.req_id_to_index.get(req_id) - if req_index is None: -@@ -512,9 +542,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): - req_ids_to_add.append(req_id) - continue - -+ if req_state.generator: -+ assert (req_state.last_generator_offset is not None) -+ self.input_batch.generators_last_offset[ -+ req_index] = req_state.last_generator_offset -+ - # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) -+ if is_sparsed_request: -+ self.input_batch.block_table.reset_row(req_index) - self.input_batch.block_table.append_row(new_block_ids, req_index) - - # For the last rank, we don't need to update the token_ids_cpu -@@ -623,6 +660,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): - if self.uses_mrope: - self._calc_mrope_positions(scheduler_output) - -+ self.seq_lens_np[:num_reqs] = ( -+ self.input_batch.num_computed_tokens_cpu[:num_reqs] + -+ num_scheduled_tokens) -+ -+ # TODO: improve performance, no `positions_np.copy()` -+ sparsed_positions = positions_np.copy() -+ req_sparsed_slots = scheduler_output.req_sparsed_slots -+ for req_id in self.input_batch.req_id_to_index: -+ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT -+ req_index = self.input_batch.req_id_to_index[req_id] -+ offset = 0 if req_index == 0 else cu_num_tokens[req_index - 1] # TODO: support MTP -+ if is_sparsed_request: -+ sparsed_positions[offset] = req_sparsed_slots[req_id] - 1 - # Get token indices. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] -@@ -652,11 +702,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): - # block_size. - block_table_indices = ( - req_indices * block_table.max_num_blocks_per_req + -- positions_np // block_size) -+ sparsed_positions // block_size) - block_table_cpu = block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten( - )[block_table_indices].numpy() -- block_offsets = positions_np % block_size -+ block_offsets = sparsed_positions % block_size - np.add( - block_numbers * block_size, - block_offsets, -@@ -666,9 +716,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): - self.query_start_loc_np[0] = 0 - self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens - -- self.seq_lens_np[:num_reqs] = ( -- self.input_batch.num_computed_tokens_cpu[:num_reqs] + -- num_scheduled_tokens) -+ for req_id in self.input_batch.req_id_to_index: -+ req_index = self.input_batch.req_id_to_index[req_id] -+ is_sparsed_request = scheduler_output.req_sparsed_slots[req_id] != INVALID_SLOT -+ if is_sparsed_request: -+ self.seq_lens_np[req_index] = scheduler_output.req_sparsed_slots[req_id] - - # Copy the tensors to the GPU. 
- self.input_ids[:total_num_scheduled_tokens].copy_( -@@ -680,6 +732,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): - non_blocking=True) - else: - # Common case (1D positions) -+ self.positions_cpu[:total_num_scheduled_tokens] = torch.from_numpy( -+ positions_np[:total_num_scheduled_tokens]) - self.positions[:total_num_scheduled_tokens].copy_( - self.positions_cpu[:total_num_scheduled_tokens], - non_blocking=True) -@@ -1370,6 +1424,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): - skip_cuda_graphs=skip_cuda_graphs, - ): - self.maybe_setup_kv_connector(scheduler_output) -+ self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata) - - model_output = self.model( - input_ids=input_ids, -@@ -1378,9 +1433,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): - inputs_embeds=inputs_embeds, - ) - -- self.maybe_wait_for_kv_save() -+ finished_dumping = self.maybe_wait_for_kv_save() -+ self.maybe_execute_ucm_sparse_finished() -+ - finished_sending, finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) -+ invalid_block_ids = self.get_block_ids_with_load_errors() - - if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = model_output -@@ -1474,7 +1532,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): - # This relies on cuda-specific torch-internal impl details - generator = self.input_batch.generators.get(i) - if generator is not None: -- generator.set_offset(generator.get_offset() - 4) -+ generator.set_offset( -+ self.input_batch.generators_last_offset.get(i)) - # Record the index of the request that should not be sampled, - # so that we could clear the sampled tokens before returning. - discard_sampled_tokens_req_indices.append(i) -@@ -1563,6 +1622,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): - finished_sending=finished_sending, - finished_recving=finished_recving, - num_nans_in_logits=num_nans_in_logits, -+ finished_dumping=finished_dumping, -+ invalid_block_ids = invalid_block_ids - ) - - def propose_draft_token_ids( -@@ -1693,13 +1754,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): - self.maybe_setup_kv_connector(scheduler_output) - finished_sending, finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) -+ invalid_block_ids = self.get_block_ids_with_load_errors() -+ get_kv_transfer_group().clear_connector_metadata() - -- if not finished_sending and not finished_recving: -+ if not finished_sending and not finished_recving and not invalid_block_ids: - return EMPTY_MODEL_RUNNER_OUTPUT - - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.finished_sending = finished_sending - output.finished_recving = finished_recving -+ output.invalid_block_ids = invalid_block_ids - return output - - @staticmethod -@@ -1719,9 +1783,28 @@ class GPUModelRunner(LoRAModelRunnerMixin): - kv_connector.start_load_kv(get_forward_context()) - - @staticmethod -- def maybe_wait_for_kv_save() -> None: -+ def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]: - if has_kv_transfer_group(): -- get_kv_transfer_group().wait_for_save() -+ return get_kv_transfer_group().wait_for_save() -+ -+ def maybe_execute_ucm_sparse_begin(self, scheduler_output: "SchedulerOutput", attn_metadata: CommonAttentionMetadata): -+ if not has_ucm_sparse(): -+ return -+ ucm_sparse = get_ucm_sparse() -+ ucm_sparse.build_sparse_meta(scheduler_output, self.requests, self.input_batch, attn_metadata) -+ ucm_sparse.execute_begin(scheduler_output) -+ -+ def maybe_execute_ucm_sparse_finished(self): -+ if not has_ucm_sparse(): -+ return -+ ucm_sparse = get_ucm_sparse() -+ 
ucm_sparse.execute_finished() -+ -+ def ucm_sparse_request_finished_in_worker(self, request_id: str | int): -+ if not has_ucm_sparse(): -+ return -+ ucm_sparse = get_ucm_sparse() -+ ucm_sparse.request_finished_in_worker(request_id) - - @staticmethod - def get_finished_kv_transfers( -@@ -1732,6 +1815,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): - scheduler_output.finished_req_ids) - return None, None - -+ def get_block_ids_with_load_errors(self) -> Optional[set[int]]: -+ if has_kv_transfer_group(): -+ return get_kv_transfer_group().get_block_ids_with_load_errors() -+ return None -+ - def propose_ngram_draft_token_ids( - self, - sampled_token_ids: list[list[int]], -diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py -index 9e7e44d06..d52a49a2e 100644 ---- a/vllm/v1/worker/gpu_worker.py -+++ b/vllm/v1/worker/gpu_worker.py -@@ -1,6 +1,7 @@ - # SPDX-License-Identifier: Apache-2.0 - # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - """A GPU worker class.""" -+import copy - import gc - import os - from typing import TYPE_CHECKING, Optional -@@ -15,7 +16,8 @@ from vllm.device_allocator.cumem import CuMemAllocator - from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment, - set_custom_all_reduce) --from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized -+from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, -+ has_kv_transfer_group) - from vllm.distributed.parallel_state import get_pp_group, get_tp_group - from vllm.logger import init_logger - from vllm.lora.request import LoRARequest -@@ -24,10 +26,11 @@ from vllm.platforms import current_platform - from vllm.sequence import IntermediateTensors - from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling - from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec --from vllm.v1.outputs import ModelRunnerOutput -+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput - from vllm.v1.utils import report_usage_stats - from vllm.v1.worker.gpu_model_runner import GPUModelRunner - from vllm.v1.worker.worker_base import WorkerBase -+from ucm.sparse.state import ensure_ucm_sparse_initialized - - logger = init_logger(__name__) - -@@ -313,9 +316,22 @@ class Worker(WorkerBase): - assert isinstance(output, IntermediateTensors) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) -- return None -+ if not has_kv_transfer_group(): -+ return None -+ -+ # In case of PP with kv transfer, we need to pass through the -+ # finished_sending and finished_recving buffers. 
-+ new_output = EMPTY_MODEL_RUNNER_OUTPUT -+ if output.finished_sending or output.finished_recving or output.finished_dumping or output.invalid_block_ids: -+ new_output = copy.copy(new_output) -+ new_output.finished_sending = output.finished_sending -+ new_output.finished_recving = output.finished_recving -+ new_output.finished_dumping = output.finished_dumping -+ new_output.invalid_block_ids = output.invalid_block_ids -+ output = new_output -+ - assert isinstance(output, ModelRunnerOutput) -- return output if self.is_driver_worker else None -+ return output - - def profile(self, is_start: bool = True): - if self.profiler is None: -@@ -386,6 +402,7 @@ def init_worker_distributed_environment( - parallel_config.pipeline_parallel_size) - - ensure_kv_transfer_initialized(vllm_config) -+ ensure_ucm_sparse_initialized(vllm_config) - - - def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): --- -2.50.1.windows.1 -diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py -index 6a9d4b4b9..ae06cf9eb 100644 ---- a/vllm/v1/core/sched/scheduler.py -+++ b/vllm/v1/core/sched/scheduler.py -@@ -830,9 +830,8 @@ class Scheduler(SchedulerInterface): - if model_runner_output.finished_dumping is not None: - request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, [])) - is_prefill = request.num_output_tokens == 0 -- is_last_chunk = (num_tokens_scheduled + request.num_computed_tokens >= request.num_prompt_tokens) -- if is_prefill and is_last_chunk: -- self.connector.connector.commit(request.succeed_dumped_blocks, True) -+ if is_prefill: -+ self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) - - # Append generated tokens and check for stop. Note that if - # a request is still being prefilled, we expect the model runner diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch b/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch index e15f7ab52..8c459aa7f 100644 --- a/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch +++ b/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch @@ -1,16 +1,17 @@ -From 67b10fc431e5aac0155ca5b77cd9a99e35656521 Mon Sep 17 00:00:00 2001 +From 73de421dd3a9d3877b8903b8ee419e692da62b29 Mon Sep 17 00:00:00 2001 From: wenxinwang -Date: Thu, 25 Sep 2025 05:31:48 -0700 -Subject: [PATCH] UCM adaptor +Date: Mon, 10 Nov 2025 20:44:02 +0800 +Subject: [PATCH] adapt to deepseek --- - vllm_ascend/attention/attention_v1.py | 75 ++++++++++++++++++++ + vllm_ascend/attention/attention_v1.py | 76 ++++++++++++++++++++ + vllm_ascend/attention/mla_v1.py | 14 +++- vllm_ascend/worker/model_runner_v1.py | 99 +++++++++++++++++++++++---- vllm_ascend/worker/worker_v1.py | 25 +++++-- - 3 files changed, 183 insertions(+), 16 deletions(-) + 4 files changed, 196 insertions(+), 18 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py -index 7d7f488..09c4345 100644 +index 7d7f488f..18039f42 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -24,6 +24,9 @@ import torch_npu @@ -26,10 +27,10 @@ index 7d7f488..09c4345 100644 @@ -33,6 +36,8 @@ from vllm_ascend.ops.attention import vanilla_chunked_prefill from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, nd_to_nz_2d, nd_to_nz_spec) - + +from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse + - + class AscendAttentionBackend(AttentionBackend): accept_output_buffer: bool = True @@ -444,10 +449,14 @@ def 
unified_ascend_attention_with_output( @@ -42,19 +43,20 @@ index 7d7f488..09c4345 100644 attn_metadata = forward_context.attn_metadata self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] -+ -+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) ++ if not self.use_mla: ++ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) self.impl.forward(self, query, key, -@@ -456,8 +465,74 @@ def unified_ascend_attention_with_output( +@@ -456,8 +465,75 @@ def unified_ascend_attention_with_output( attn_metadata, output, trace_flag=False) -+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) ++ if not self.use_mla: ++ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) + maybe_save_kv_layer_to_connector(layer_name, kv_cache) return - + +def wait_for_kv_layer_from_connector(layer_name: str): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return @@ -119,11 +121,67 @@ index 7d7f488..09c4345 100644 + return + + ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context) - + def unified_attention_with_output_fake( query: torch.Tensor, +diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py +index f50fe56e..4a27c22f 100644 +--- a/vllm_ascend/attention/mla_v1.py ++++ b/vllm_ascend/attention/mla_v1.py +@@ -13,10 +13,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size + from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) + from vllm.utils import cdiv, round_down ++from vllm.forward_context import ForwardContext, get_forward_context ++from vllm.attention.layer import (maybe_execute_sparse_attention_begin, maybe_execute_sparse_attention_finished) + + from vllm_ascend.ascend_config import get_ascend_config + from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV +-from vllm_ascend.attention.attention_v1 import AscendAttentionState ++from vllm_ascend.attention.attention_v1 import AscendAttentionState, wait_for_kv_layer_from_connector, maybe_save_kv_layer_to_connector + from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig + from vllm_ascend.multistream.context import get_multistream_comm_context + from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn +@@ -1042,6 +1044,7 @@ class AscendMLAImpl(MLAAttentionImpl): + enable_multistream_mla: bool = False, + ckq: Optional[torch.Tensor] = None, + ) -> torch.Tensor: ++ forward_context: ForwardContext = get_forward_context() + assert output is not None, "Output tensor must be provided." + if attn_metadata is None: + # Profiling run. 
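The attention changes above bracket every attention forward with four per-layer hooks: wait for this layer's loaded KV, sparse begin, sparse finished, save this layer's KV. A toy, self-contained sketch of just that call ordering (all hook bodies are stand-ins, not the real connector or ucm.sparse implementations):

    # Illustrative sketch: the ordering of the per-layer hooks the patch wraps
    # around each attention kernel in attention_v1.py / mla_v1.py.
    import torch

    def wait_for_kv_layer_from_connector(layer_name: str) -> None:
        print(f"[{layer_name}] wait until this layer's KV blocks are loaded")

    def sparse_attention_begin(q, k, v, layer_name: str) -> None:
        print(f"[{layer_name}] sparse method inspects q/k/v before the kernel")

    def sparse_attention_finished(q, k, v, out, layer_name: str) -> None:
        print(f"[{layer_name}] sparse method sees the attention output")

    def save_kv_layer_to_connector(layer_name: str) -> None:
        print(f"[{layer_name}] hand newly written KV blocks to the connector")

    def attention_layer_forward(q, k, v, layer_name: str) -> torch.Tensor:
        wait_for_kv_layer_from_connector(layer_name)
        sparse_attention_begin(q, k, v, layer_name)
        out = torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v   # stand-in kernel
        sparse_attention_finished(q, k, v, out, layer_name)
        save_kv_layer_to_connector(layer_name)
        return out

    q = k = v = torch.randn(1, 4, 8)
    attention_layer_forward(q, k, v, "model.layers.0.self_attn.attn")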
+@@ -1192,6 +1195,8 @@ class AscendMLAImpl(MLAAttentionImpl): + # FIX: aicore move should be also placed on the comm stream in dbo, + # otherwise it may affect the accuracy + # TODO: use an elegant way to overlap ++ wait_for_kv_layer_from_connector(layer.layer_name) ++ maybe_execute_sparse_attention_begin(prefill_q, prefill_k_c_normed, prefill_k_pe, layer.layer_name, forward_context, "prefill") + output_prefill = self._forward_prefill(prefill_q, + prefill_k_c_normed, + prefill_k_pe, kv_cache, +@@ -1203,8 +1208,11 @@ class AscendMLAImpl(MLAAttentionImpl): + current_ms_metadata.after_comm_event.record() + else: + output[num_decode_tokens:] = output_prefill +- ++ maybe_execute_sparse_attention_finished(prefill_q, prefill_k_c_normed, prefill_k_pe, output[num_decode_tokens:], layer.layer_name, forward_context, "prefill") ++ maybe_save_kv_layer_to_connector(layer.layer_name, kv_cache) + if has_decode: ++ wait_for_kv_layer_from_connector(layer.layer_name) ++ maybe_execute_sparse_attention_begin(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, layer.layer_name, forward_context, "decode") + if self.running_in_graph: + return self._forward_decode(decode_ql_nope, decode_q_pe, + decode_k_nope, decode_k_pe, +@@ -1223,5 +1231,7 @@ class AscendMLAImpl(MLAAttentionImpl): + current_ms_metadata.after_comm_event.record() + else: + output[:num_decode_tokens] = output_decode ++ maybe_execute_sparse_attention_finished(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, output[:num_decode_tokens], layer.layer_name, forward_context, "decode") ++ maybe_save_kv_layer_to_connector(layer.layer_name, kv_cache) + + return output_padded diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py -index eabcdbc..e51f46e 100644 +index eabcdbcc..179dffde 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -39,7 +39,10 @@ from vllm.config import CompilationLevel, VllmConfig @@ -141,7 +199,7 @@ index eabcdbc..e51f46e 100644 @@ -88,6 +91,9 @@ from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch - + +from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse +from ucm.sparse.base import UcmSparseMetadata, INVALID_SLOT + @@ -157,7 +215,7 @@ index eabcdbc..e51f46e 100644 self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. @@ -453,12 +460,14 @@ class NPUModelRunner(LoRAModelRunnerMixin): - + # Update the states of the running/resumed requests. req_data = scheduler_output.scheduled_cached_reqs + req_sparsed_slots = scheduler_output.req_sparsed_slots @@ -168,7 +226,7 @@ index eabcdbc..e51f46e 100644 new_block_ids = req_data.new_block_ids[i] resumed_from_preemption = req_data.resumed_from_preemption[i] + is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT - + req_state.num_computed_tokens = num_computed_tokens if not is_last_rank: @@ -474,15 +483,15 @@ class NPUModelRunner(LoRAModelRunnerMixin): @@ -189,18 +247,18 @@ index eabcdbc..e51f46e 100644 - # The request is resumed from preemption. - # Replace the existing block IDs with the new ones. 
- req_state.block_ids = new_block_ids - + req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: @@ -496,6 +505,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - + + if is_sparsed_request: + self.input_batch.block_table.reset_row(req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) - + if not is_last_rank: @@ -876,7 +888,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): intermediate_tensors: Optional[IntermediateTensors] = None, @@ -215,7 +273,7 @@ index eabcdbc..e51f46e 100644 @@ -955,12 +968,22 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_scheduled_tokens) seq_lens = self.seq_lens_cpu[:num_reqs] - + + # TODO: improve performance, no `positions_np.copy()` + sparsed_positions = positions_np.copy() + req_sparsed_slots = scheduler_output.req_sparsed_slots @@ -229,7 +287,7 @@ index eabcdbc..e51f46e 100644 block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions_np // self.block_size) + sparsed_positions // self.block_size) - + block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - block_offsets = positions_np % self.block_size @@ -240,7 +298,7 @@ index eabcdbc..e51f46e 100644 @@ -985,10 +1008,16 @@ class NPUModelRunner(LoRAModelRunnerMixin): else: attn_state = AscendAttentionState.PrefillCacheHit - + + for req_id in self.input_batch.req_id_to_index: + is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT + req_index = self.input_batch.req_id_to_index[req_id] @@ -254,10 +312,10 @@ index eabcdbc..e51f46e 100644 + position=torch.tensor(sparsed_positions).npu(), attn_state=attn_state) self.attn_state = attn_state # type: ignore - + @@ -1100,6 +1129,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): positions = self.positions[:padded_batch_size] - + # Run forward pass + finished_dumping = None with set_forward_context(attn_metadata, @@ -269,7 +327,7 @@ index eabcdbc..e51f46e 100644 ACL_FORMAT_FRACTAL_ND) + self.maybe_setup_kv_connector(scheduler_output) + self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata) - + hidden_states = self.model( input_ids=input_ids, @@ -1133,6 +1165,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): @@ -278,16 +336,16 @@ index eabcdbc..e51f46e 100644 ) + finished_dumping = self.maybe_wait_for_kv_save() + self.maybe_execute_ucm_sparse_finished() - + use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -1163,7 +1197,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): - + return (attn_metadata, hidden_states, spec_decode_metadata, positions, total_num_scheduled_tokens, logits_indices, aux_hidden_states, - num_scheduled_tokens) + num_scheduled_tokens, finished_dumping) - + def _get_cumsum_and_arange( self, @@ -1400,7 +1434,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): @@ -297,7 +355,7 @@ index eabcdbc..e51f46e 100644 - num_scheduled_tokens_np) = (self._process_reqs( + num_scheduled_tokens_np, finished_dumping) = (self._process_reqs( scheduler_output, intermediate_tensors)) - + with ProfileExecuteDuration().capture_async("post process"): @@ -1561,6 +1595,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): logprobs=logprobs_lists, @@ -305,7 +363,7 @@ index eabcdbc..e51f46e 100644 pooler_output=[], + finished_dumping=finished_dumping ) - + durations = ProfileExecuteDuration().pop_captured_sync() @@ -2369,3 +2404,43 @@ class NPUModelRunner(LoRAModelRunnerMixin): if batch_size <= padded_batch_size < 
selected_batch_size: @@ -353,16 +411,16 @@ index eabcdbc..e51f46e 100644 + ucm_sparse.request_finished_in_worker(request_id) \ No newline at end of file diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py -index df03d50..a854923 100644 +index df03d508..5d5d9b5a 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -17,6 +17,7 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py # - + +import copy from typing import Optional - + import torch @@ -27,7 +28,8 @@ from vllm import envs from vllm.config import VllmConfig @@ -381,15 +439,15 @@ index df03d50..a854923 100644 -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.worker.worker_base import WorkerBase - + import vllm_ascend.envs as envs_ascend @@ -49,6 +51,7 @@ from vllm_ascend.utils import (check_kv_cache_bytes_cache_exist, read_kv_cache_bytes_from_file, sleep_mode_enabled, try_register_lib) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner +from ucm.sparse.state import ensure_ucm_sparse_initialized - - + + class NPUWorker(WorkerBase): @@ -222,9 +225,22 @@ class NPUWorker(WorkerBase): assert isinstance(output, IntermediateTensors) @@ -413,7 +471,7 @@ index df03d50..a854923 100644 assert isinstance(output, ModelRunnerOutput) - return output if self.is_driver_worker else None + return output - + def load_model(self) -> None: if self.vllm_config.model_config.enable_sleep_mode: @@ -321,6 +337,7 @@ class NPUWorker(WorkerBase): @@ -421,8 +479,9 @@ index df03d50..a854923 100644 ) ensure_kv_transfer_initialized(self.vllm_config) + ensure_ucm_sparse_initialized(self.vllm_config) - + def _init_profiler(self): # Torch profiler. Enabled and configured through env vars: --- +-- 2.34.1 + diff --git a/ucm/integration/vllm/patch/__init__.py b/ucm/integration/vllm/patch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ucm/integration/vllm/patch/apply_patch.py b/ucm/integration/vllm/patch/apply_patch.py new file mode 100644 index 000000000..39f5ccbb0 --- /dev/null +++ b/ucm/integration/vllm/patch/apply_patch.py @@ -0,0 +1,175 @@ +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +""" +Monkey patching module for vLLM to apply UCM patches automatically. +This replaces the need for manual `git apply` commands. 
+""" + +import sys +from typing import Optional + +from ucm.logger import init_logger + +logger = init_logger(__name__) + +import os + +PLATFORM = os.getenv("PLATFORM") + + +def _patch_ascend() -> bool: + return PLATFORM == "ascend" + + +# Track if patches have been applied +_patches_applied = False +_import_hook_installed = False +_vllm_version: Optional[str] = None +_vllm_import_hook = None + + +def get_vllm_version() -> Optional[str]: + """Detect vLLM version.""" + global _vllm_version + if _vllm_version is not None: + return _vllm_version + + try: + # Try to get version from vllm module + import vllm as vllm_pkg + + vllm_version = vllm_pkg.__version__ + return vllm_version + except ImportError: + logger.warning("vLLM is not installed") + return None + except Exception as e: + logger.warning(f"Failed to detect vLLM version: {e}") + return None + + +def get_supported_versions() -> list[str]: + """Get list of supported vLLM versions.""" + return ["0.9.2"] + + +def apply_all_patches() -> None: + """Apply all vLLM patches based on detected version.""" + global _patches_applied + if _patches_applied: + return + + try: + version = get_vllm_version() + if version is None: + raise ValueError("Could not detect vLLM version") + + supported_versions = get_supported_versions() + if version not in supported_versions: + logger.warning( + f"vLLM version {version} is not explicitly supported to apply UCM patches. " + f"Supported versions: {', '.join(supported_versions)}. " + ) + + # Apply version-specific patches + match version: + case "0.9.2": + _apply_patches_v092() + case _: + logger.warning( + f"Unsupported vLLM version: {version} to apply UCM patches. " + f"Supported versions: {', '.join(supported_versions)}." + ) + + _patches_applied = True + logger.info(f"All vLLM patches applied successfully for version {version}") + except Exception as e: + logger.error(f"Failed to apply vLLM patches: {e}", exc_info=True) + raise + + +def _apply_patches_v092() -> None: + """Apply patches for vLLM 0.9.2.""" + from .patch_funcs.v092.vllm_patch import _apply_sparse_adapt + + _apply_sparse_adapt() # apply vllm-sparse-adapt.patch + if _patch_ascend(): + from .patch_funcs.v092.vllm_ascend_patch import _apply_ascend_patch + + _apply_ascend_patch() # apply vllm-ascend-adapt.patch + + +def install_import_hook() -> None: + """Install an import hook to automatically apply patches when vLLM is imported.""" + global _import_hook_installed, _vllm_import_hook + + if _import_hook_installed: + return + + try: + # Check if vLLM is already imported + if "vllm" in sys.modules: + # vLLM already imported, apply patches immediately + apply_all_patches() + _import_hook_installed = True + else: + # Install import hook by wrapping the builtin __import__ function + # This intercepts all imports and applies patches when vLLM is imported + import builtins + + original_import = builtins.__import__ + + def import_hook(name, globals=None, locals=None, fromlist=(), level=0): + # Call original import + module = original_import(name, globals, locals, fromlist, level) + + # If the main vLLM module is being imported, apply patches + # We only check for 'vllm' (not submodules) to avoid multiple patch attempts + if name == "vllm" and not _patches_applied: + try: + apply_all_patches() + except Exception as e: + logger.warning(f"Failed to apply patches during import: {e}") + + return module + + # Replace builtin __import__ + builtins.__import__ = import_hook + _vllm_import_hook = import_hook + _import_hook_installed = True + logger.debug("Import 
hook installed to intercept vLLM imports") + + except Exception as e: + logger.warning(f"Failed to install import hook: {e}") + + +def ensure_patches_applied() -> None: + """Ensure patches are applied, installing import hook if needed.""" + if not _patches_applied: + # Try to apply patches immediately + try: + apply_all_patches() + except Exception: + # If it fails (vLLM not imported yet), install hook + install_import_hook() diff --git a/ucm/integration/vllm/patch/patch_funcs/__init__.py b/ucm/integration/vllm/patch/patch_funcs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ucm/integration/vllm/patch/patch_funcs/v092/__init__.py b/ucm/integration/vllm/patch/patch_funcs/v092/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py new file mode 100644 index 000000000..f3927ece9 --- /dev/null +++ b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py @@ -0,0 +1,1446 @@ +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
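apply_patch.py above either applies the patches immediately or wraps builtins.__import__ so they are applied the first time vLLM is imported. A minimal, self-contained sketch of that interception technique, using json as a stand-in for vllm and a made-up on_first_import callback:

    # Illustrative sketch: intercept imports by wrapping builtins.__import__,
    # run a one-shot callback when the target package is first imported.
    import builtins
    import sys

    _original_import = builtins.__import__
    _patched = False

    def on_first_import(module) -> None:          # made-up callback for the sketch
        print(f"applying patches to {module.__name__}")

    def _import_hook(name, globals=None, locals=None, fromlist=(), level=0):
        module = _original_import(name, globals, locals, fromlist, level)
        global _patched
        if name == "json" and not _patched:       # the real hook keys on "vllm"
            _patched = True
            on_first_import(sys.modules["json"])
        return module

    builtins.__import__ = _import_hook
    import json                                   # triggers on_first_import once

    builtins.__import__ = _original_import        # tidy up in this toy example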
+# + +from __future__ import annotations + +import os + +from ucm.logger import init_logger + +logger = init_logger(__name__) + +ENABLE_SPARSE = os.getenv("ENABLE_SPARSE") + + +def _enable_sparse() -> bool: + return ENABLE_SPARSE is not None and ENABLE_SPARSE.lower() == "true" + + +def _apply_ascend_patch() -> None: + """Apply patch for vLLM-Ascend.""" + try: + from vllm_ascend.patch import platform, worker + + if _enable_sparse(): + _patch_attention_v1() + _patch_mla_v1() + _patch_model_runner_v1() + _patch_worker_v1() + logger.info("UCM sparse adapt patches applied successfully") + + except Exception as e: + logger.error(f"Could not apply sparse adapt patches: {e}") + raise e + + +# ========================= vllm_ascend/attention/attention_v1.py ========================= +def _patch_attention_v1() -> None: + """Patch attention_v1.py for vLLM-Ascend.""" + try: + from typing import List + + import torch + from vllm.forward_context import ForwardContext, get_forward_context + from vllm_ascend.attention import attention_v1 + + from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse + + def maybe_execute_sparse_attention_begin( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, + forward_context: ForwardContext, + ): + if not has_ucm_sparse(): + return + + ucm_sparse = get_ucm_sparse() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + ucm_sparse.attention_begin(query, key, value, layer_name, forward_context) + + attention_v1.maybe_execute_sparse_attention_begin = ( + maybe_execute_sparse_attention_begin + ) + + def maybe_execute_sparse_attention_finished( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_output: torch.Tensor, + layer_name: str, + forward_context: ForwardContext, + ): + if not has_ucm_sparse(): + return + ucm_sparse = get_ucm_sparse() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + ucm_sparse.attention_finished( + query, key, value, attn_output, layer_name, forward_context + ) + + attention_v1.maybe_execute_sparse_attention_finished = ( + maybe_execute_sparse_attention_finished + ) + + vllm_ops = torch.ops.vllm + orig_unified_ascend_attention_with_output = ( + vllm_ops.unified_ascend_attention_with_output + ) + + def _wrap_op_overload(orig, impl): + class _Wrapper: + def __init__(self, orig): + self._orig = orig + + def __call__(self, *args, **kwargs): + return impl(*args, **kwargs) + + def __getattr__(self, name): + return getattr(self._orig, name) + + return _Wrapper(orig) + + def unified_ascend_attention_with_output_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, + ) -> None: + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + if not self.use_mla: + maybe_execute_sparse_attention_begin( + query, key, value, layer_name, forward_context + ) + self.impl.forward( + self, + query, + key, + value, + kv_cache, + attn_metadata, + output, + trace_flag=False, + ) + if not self.use_mla: + maybe_execute_sparse_attention_finished( + query, key, value, output, layer_name, forward_context + ) + return + + vllm_ops.unified_ascend_attention_with_output = _wrap_op_overload( + orig_unified_ascend_attention_with_output, + unified_ascend_attention_with_output_impl, + ) + + attention_v1.unified_ascend_attention_with_output 
= ( + unified_ascend_attention_with_output_impl + ) + except ImportError as e: + logger.error(f"Failed to patch attention_v1.py: {e}", exc_info=True) + raise + + +# ========================= vllm_ascend/attention/mla_v1.py ========================= +def _patch_mla_v1() -> None: + """Patch mla_v1.py for vLLM-Ascend.""" + try: + from typing import Optional + + import torch + import torch_npu + from vllm.attention.backends.abstract import AttentionLayer + from vllm.attention.layer import ( + maybe_execute_sparse_attention_begin, + maybe_execute_sparse_attention_finished, + ) + from vllm.forward_context import ForwardContext, get_forward_context + from vllm_ascend.attention.attention_v1 import ( + AscendAttentionState, + ) + from vllm_ascend.attention.mla_v1 import AscendMLAImpl + from vllm_ascend.multistream.context import get_multistream_comm_context + from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor + + def forward( + self, + layer: AttentionLayer, + hidden_states_or_q_c: torch.Tensor, # query in unified attn + hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: M, + output: Optional[torch.Tensor] = None, + enable_multistream_mla: bool = False, + ckq: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + forward_context: ForwardContext = get_forward_context() + assert output is not None, "Output tensor must be provided." + if attn_metadata is None: + # Profiling run. + return output + self.running_in_graph = ( + self.torchair_graph_enabled + and attn_metadata.attn_state + in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding] + ) + num_actual_toks = attn_metadata.num_actual_tokens + if k_pe is None and not self.running_in_graph: + if not self.torchair_graph_enabled: + kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states_or_kv_c_normed)[ + 0 + ].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + else: + kv_c_normed = hidden_states_or_kv_c_normed + assert ( + attn_metadata.num_decodes is not None + and attn_metadata.num_prefills is not None + and attn_metadata.num_decode_tokens is not None + ) + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + if not self.running_in_graph: + # Inputs and outputs may be padded for CUDA graphs + output_padded = output + output = output[:num_actual_toks, ...] + if not self.torchair_graph_enabled: + kv_c_normed = kv_c_normed[:num_actual_toks, ...] + prefill_k_c_normed = kv_c_normed[num_decode_tokens:] + if not self.running_in_graph: + hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...] + prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] + if not self.torchair_graph_enabled: + decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] + k_pe = k_pe[:num_actual_toks, ...] 
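The _wrap_op_overload helper defined earlier in _patch_attention_v1 swaps the registered op for a plain Python object that redirects calls to the new implementation while forwarding every other attribute lookup to the original. A self-contained sketch of that pattern (illustrative only, with toy functions standing in for the torch op):

    # Illustrative sketch: call interception plus attribute pass-through, the
    # same shape as _wrap_op_overload in the patch above.
    class _Wrapper:
        def __init__(self, orig, impl):
            self._orig = orig
            self._impl = impl

        def __call__(self, *args, **kwargs):
            return self._impl(*args, **kwargs)

        def __getattr__(self, name):
            # anything not defined here (e.g. .default on an op overload packet)
            # still resolves against the wrapped original
            return getattr(self._orig, name)

    def original(x):
        return x + 1

    def patched(x):
        return x * 10

    wrapped = _Wrapper(original, patched)
    assert wrapped(3) == 30                   # calls go to the new implementation
    assert wrapped.__name__ == "original"     # other attributes come from the original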
+ k_pe = k_pe.unsqueeze(1) + decode_k_pe = k_pe[:num_decode_tokens] + prefill_k_pe = k_pe[num_decode_tokens:] + else: + decode_hs_or_q_c = hidden_states_or_q_c + if has_decode: + decode_k_nope = None + assert attn_metadata.decode is not None + if self.running_in_graph: + seq_len = ( + self.rotary_emb.max_position_embeddings + * self.rotary_emb.scaling_factor + ) + cos = self.rotary_emb.cos_cached[:seq_len].to( + dtype=decode_hs_or_q_c.dtype + ) + sin = self.rotary_emb.sin_cached[:seq_len].to( + dtype=decode_hs_or_q_c.dtype + ) + cos = cos[attn_metadata.decode.input_positions] + sin = sin[attn_metadata.decode.input_positions] + cos = cos[:, None, None, :] + sin = sin[:, None, None, :] + with npu_stream_switch( + "mla_secondary", 0, enabled=enable_multistream_mla + ): + npu_wait_tensor( + hidden_states_or_kv_c_normed, + ckq, + enabled=enable_multistream_mla, + ) + decode_k_pe, decode_k_nope, decode_kv = self.exec_kv( + hidden_states_or_kv_c_normed, + cos, + sin, + kv_cache, + attn_metadata.slot_mapping, + ) + # Without explicitly controlling the order, IndexByTensor operations + # would be placed after `matmul W_KV_T` hindering the overlapping of + # KvRmsNormRopeCache and SingleRope. + npu_wait_tensor( + decode_hs_or_q_c, cos, enabled=enable_multistream_mla + ) + npu_wait_tensor( + decode_hs_or_q_c, sin, enabled=enable_multistream_mla + ) + npu_wait_tensor( + decode_hs_or_q_c, decode_kv, enabled=enable_multistream_mla + ) + + decode_ql_nope, decode_q_pe = self._q_proj_and_k_up_proj( + decode_hs_or_q_c + ) + if self.running_in_graph: + with npu_stream_switch( + "mla_secondary", 0, enabled=enable_multistream_mla + ): + npu_wait_tensor( + decode_q_pe, decode_k_pe, enabled=enable_multistream_mla + ) + decode_q_pe = self.rope_single(decode_q_pe, cos, sin) + else: + decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( + attn_metadata.decode.input_positions, + decode_q_pe.contiguous(), + decode_k_pe, + max_seq_len=attn_metadata.decode.max_seq_lens, + ) + if has_prefill: + assert attn_metadata.prefill is not None + prefill_q = self.q_proj(prefill_hs_or_q_c)[0].view( + -1, self.num_heads, self.qk_head_dim + ) + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim :] + prefill_q_nope = prefill_q[..., : self.qk_nope_head_dim] + if self.torchair_graph_enabled: + num_tokens = prefill_hs_or_q_c.shape[0] + seq_len = ( + self.rotary_emb.max_position_embeddings + * self.rotary_emb.scaling_factor + ) + cos = self.rotary_emb.cos_cached[:seq_len].to( + dtype=prefill_q_pe.dtype + ) + sin = self.rotary_emb.sin_cached[:seq_len].to( + dtype=prefill_q_pe.dtype + ) + cos = cos[attn_metadata.prefill.input_positions] + sin = sin[attn_metadata.prefill.input_positions] + cos = cos[:, None, None, :] + sin = sin[:, None, None, :] + + prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) + prefill_k_pe, prefill_k_nope = self.exec_kv_prefill( + hidden_states_or_kv_c_normed, + cos, + sin, + kv_cache, + attn_metadata.slot_mapping, + ) + + kv_c_normed = prefill_k_nope[:num_actual_toks, ...] + prefill_k_c_normed = prefill_k_nope[num_decode_tokens:] + prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads, -1) + prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1) + else: + prefill_q_pe[...], prefill_k_pe[...] 
= self.rotary_emb( + attn_metadata.prefill.input_positions, + prefill_q_pe.contiguous(), + prefill_k_pe, + max_seq_len=attn_metadata.prefill.max_seq_lens, + ) + if self.torchair_graph_enabled: + if ( + len(kv_cache) > 0 + and kv_cache[0].numel() > 0 + and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache + ): + slots = attn_metadata.slot_mapping + # NOTE: Separate the kv cache in advance to avoid OOM or other issues + torch_npu._npu_reshape_and_cache( + key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1), + value=prefill_k_pe, + key_cache=kv_cache[0], + value_cache=kv_cache[1], + slot_indices=slots, + ) + elif kv_cache.numel() > 0: + key = torch.cat( + [kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe], + dim=2, + ) + torch_npu._npu_reshape_and_cache_siso( + key=key, + key_cache=kv_cache, + slot_indices=attn_metadata.slot_mapping.flatten(), + ) + if has_prefill: + # FIX: aicore move should be also placed on the comm stream in dbo, + # otherwise it may affect the accuracy + # TODO: use an elegant way to overlap + maybe_execute_sparse_attention_begin( + prefill_q, + prefill_k_c_normed, + prefill_k_pe, + layer.layer_name, + forward_context, + "prefill", + ) + output_prefill = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, attn_metadata + ) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is not None: + with torch.npu.stream(current_ms_metadata.comm_stream): + output[num_decode_tokens:] = output_prefill + current_ms_metadata.after_comm_event.record() + else: + output[num_decode_tokens:] = output_prefill + maybe_execute_sparse_attention_finished( + prefill_q, + prefill_k_c_normed, + prefill_k_pe, + output[num_decode_tokens:], + layer.layer_name, + forward_context, + "prefill", + ) + if has_decode: + maybe_execute_sparse_attention_begin( + torch.cat([decode_ql_nope, decode_q_pe], dim=-1), + decode_ql_nope, + decode_q_pe, + layer.layer_name, + forward_context, + "decode", + ) + if self.running_in_graph: + return self._forward_decode( + decode_ql_nope, + decode_q_pe, + decode_k_nope, + decode_k_pe, + kv_cache, + attn_metadata, + enable_multistream_mla, + ) + else: + output_decode = self._forward_decode( + decode_ql_nope, + decode_q_pe, + decode_k_nope, + decode_k_pe, + kv_cache, + attn_metadata, + ) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is not None: + with torch.npu.stream(current_ms_metadata.comm_stream): + output[:num_decode_tokens] = output_decode + current_ms_metadata.after_comm_event.record() + else: + output[:num_decode_tokens] = output_decode + maybe_execute_sparse_attention_finished( + torch.cat([decode_ql_nope, decode_q_pe], dim=-1), + decode_ql_nope, + decode_q_pe, + output[:num_decode_tokens], + layer.layer_name, + forward_context, + "decode", + ) + + return output_padded + + AscendMLAImpl.forward = forward + except ImportError as e: + logger.error(f"Failed to patch mla_v1.py: {e}", exc_info=True) + raise + + +# ========================= vllm_ascend/worker/model_runner_v1.py ========================= +def _patch_model_runner_v1() -> None: + """Patch model_runner_v1.py for vLLM-Ascend.""" + try: + from typing import TYPE_CHECKING, List, Optional, Union + + import numpy as np + import torch + from vllm.distributed.parallel_state import get_pp_group, get_tp_group + from vllm.forward_context import set_forward_context + from vllm.logger import logger + from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding + from vllm.sampling_params 
import SamplingType + from vllm.sequence import IntermediateTensors + from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput + from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + from vllm_ascend.ascend_config import get_ascend_config + from vllm_ascend.attention.attention_v1 import ( + AscendAttentionState, + AscendMetadata, + ) + from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata + from vllm_ascend.attention.mla_v1 import ( + AscendMLAMetadata, + CommonAttentionMetadata, + ) + from vllm_ascend.utils import ( + ACL_FORMAT_FRACTAL_ND, + ACL_FORMAT_FRACTAL_NZ, + ProfileExecuteDuration, + maybe_converting_weight_acl_format, + ) + from vllm_ascend.worker.npu_input_batch import CachedRequestState + + from ucm.sparse.base import INVALID_SLOT + + if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.distributed.kv_transfer import ( + get_kv_transfer_group, + has_kv_transfer_group, + ) + from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 + from vllm.forward_context import get_forward_context, set_forward_context + from vllm_ascend.worker.model_runner_v1 import NPUModelRunner + + from ucm.sparse.base import INVALID_SLOT + from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + """Update the cached states and the persistent batch with the scheduler + output. + + The SamplingMetadata is updated and copied to the NPU if there is a + new/resumed/paused/finished request in the batch. + """ + # Remove finished requests from the cached states. + for req_id in scheduler_output.finished_req_ids: + self.ucm_sparse_request_finished_in_worker(req_id) + self.requests.pop(req_id, None) + self.encoder_cache.pop(req_id, None) + # Remove the finished requests from the persistent batch. + # NOTE(woosuk): There could be an edge case where finished_req_ids and + # scheduled_req_ids overlap. This happens when a request is aborted and + # then resubmitted with the same ID. In this case, we treat them as two + # distinct requests - clearing the cached states for the first request + # and handling the second as a new request. + removed_req_indices: List[int] = [] + for req_id in scheduler_output.finished_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Free the cached encoder outputs. + for req_id, input_id in scheduler_output.free_encoder_input_ids: + encoder_outputs = self.encoder_cache.get(req_id) + if encoder_outputs is not None: + encoder_outputs.pop(input_id, None) + if not encoder_outputs: + self.encoder_cache.pop(req_id, None) + + # Remove the unscheduled requests from the persistent batch. + # NOTE(woosuk): The unscheduled requests are either preempted requests + # or running requests that are not scheduled in this step. We remove + # them from the persistent batch but keep their cached states since + # they will be scheduled again sometime in the future. + scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + unscheduled_req_ids = cached_req_ids - scheduled_req_ids + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. 
+ for req_id in unscheduled_req_ids: + req_index = self.input_batch.remove_request(req_id) + assert req_index is not None + removed_req_indices.append(req_index) + + req_ids_to_add: List[str] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + if ( + sampling_params + and sampling_params.sampling_type == SamplingType.RANDOM_SEED + ): + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + mm_inputs=new_req_data.mm_inputs, + mm_positions=new_req_data.mm_positions, + sampling_params=sampling_params, + pooling_params=new_req_data.pooling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + ) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + for mm_input in self.requests[req_id].mm_inputs: + if mm_input.get("image_grid_thw") is not None: + image_grid_thw.extend(mm_input["image_grid_thw"].tolist()) + if mm_input.get("video_grid_thw") is not None: + video_grid_thw.extend(mm_input["video_grid_thw"].tolist()) + if mm_input.get("second_per_grid_ts") is not None: + second_per_grid_ts.extend(mm_input["second_per_grid_ts"]) + if mm_input.get("audio_feature_lengths") is not None: + audio_feature_lengths.extend( + mm_input["audio_feature_lengths"] + ) + if mm_input.get("use_audio_in_video") is True: + use_audio_in_video = True + + hf_config = self.model_config.hf_config + + ( + self.requests[req_id].mrope_positions, + self.requests[req_id].mrope_position_delta, + ) = MRotaryEmbedding.get_input_positions_tensor( + self.requests[req_id].prompt_token_ids, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + + req_ids_to_add.append(req_id) + + # Update the states of the running/resumed requests. + req_data = scheduler_output.scheduled_cached_reqs + req_sparsed_slots = scheduler_output.req_sparsed_slots + is_last_rank = get_pp_group().is_last_rank + for i, req_id in enumerate(req_data.req_ids): + req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_data.resumed_from_preemption[i] + is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT + + req_state.num_computed_tokens = num_computed_tokens + if not is_last_rank: + new_token_ids = req_data.new_token_ids[i] + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec decode tokens. + num_new_tokens = ( + num_computed_tokens + len(new_token_ids) - req_state.num_tokens + ) + if num_new_tokens == 1: + # Avoid slicing list in most common case. + req_state.output_token_ids.append(new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + new_token_ids[-num_new_tokens:] + ) + # Update the block IDs. 
+ if resumed_from_preemption or is_sparsed_request: + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = new_block_ids + else: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip( # type: ignore[call-overload] + req_state.block_ids, new_block_ids + ): + block_ids.extend(new_ids) + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. + req_ids_to_add.append(req_id) + continue + + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = ( + num_computed_tokens + ) + + if is_sparsed_request: + self.input_batch.block_table.reset_row(req_index) + + self.input_batch.block_table.append_row(new_block_ids, req_index) + + if not is_last_rank: + # Add new_token_ids to token_ids_cpu. + start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + len(new_token_ids) + self.input_batch.token_ids_cpu[ + req_index, start_token_index:end_token_index + ] = new_token_ids + self.input_batch.num_tokens_no_spec[req_index] = end_token_index + # Add spec_token_ids to token_ids_cpu. + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, () + ) + if spec_token_ids: + start_index = end_token_index + end_token_index += len(spec_token_ids) + self.input_batch.token_ids_cpu[ + req_index, start_index:end_token_index + ] = spec_token_ids + # NOTE(woosuk): `num_tokens` here may include spec decode tokens. + self.input_batch.num_tokens[req_index] = end_token_index + + # Check if the batch has changed. If not, we can skip copying the + # sampling metadata from CPU to GPU. + batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0 + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + removed_req_indices.sort(reverse=True) + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + if removed_req_indices: + # Fill the empty index. + req_index = removed_req_indices.pop() + else: + # Append to the end. + req_index = None + self.input_batch.add_request(req_state, req_index) + + # Condense the batched states if there are empty indices. + if removed_req_indices: + self.input_batch.condense(removed_req_indices) + + if batch_changed: + self.input_batch.refresh_sampling_metadata() + + NPUModelRunner._update_states = _update_states + + def _process_reqs( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> tuple[ + Union[AscendMetadata, AscendMLAMetadata, AscendTorchairMetadata], + torch.Tensor, + SpecDecodeMetadata, + torch.Tensor, + int, + torch.Tensor, + torch.Tensor, + np.ndarray, + Optional[dict[str, list[str]]], + ]: + # Check input valid + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + if ( + self.use_aclgraph + and total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1] + ): + # Add padding to the batch size. + num_input_tokens = self.vllm_config.pad_for_cudagraph( + total_num_scheduled_tokens + ) + else: + # Eager mode. 
+ num_input_tokens = total_num_scheduled_tokens + + modified_batch = self.attn_metadata_builder.reorder_batch( + self.input_batch, scheduler_output + ) + if modified_batch: + self.input_batch.refresh_sampling_metadata() + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.input_batch.block_table.commit(num_reqs) + + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32) + num_valid_tokens = np.empty(num_reqs, dtype=np.int32) + max_num_scheduled_tokens = 0 + for i, req_id in enumerate(self.input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens[i] = num_tokens + num_valid_tokens[i] = num_tokens - len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) + ) + max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) + + # Hot-Swap lora model + if self.lora_config: + self.set_active_loras(self.input_batch, num_scheduled_tokens) + + # Prepare positions + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) + cu_num_tokens = np.cumsum(num_scheduled_tokens) + cumsums_offsets = np.repeat( + cu_num_tokens - num_scheduled_tokens, num_scheduled_tokens + ) + logits_indices = cu_num_tokens - 1 + logits_indices = torch.from_numpy(logits_indices).to( + self.device, non_blocking=True + ) + arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets + + positions_np = self.positions_np[:total_num_scheduled_tokens] + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np, + ) + + # Calculate M-RoPE positions. + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + self._calc_mrope_positions(scheduler_output) + + if self.uses_mrope: + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self.mrope_positions[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions_cpu[:, :total_num_scheduled_tokens], + non_blocking=True, + ) + + self.positions[total_num_scheduled_tokens:num_input_tokens].zero_() + self.positions[:total_num_scheduled_tokens].copy_( + self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True + ) + positions = self.positions[:num_input_tokens] + self.query_lens = torch.from_numpy(num_scheduled_tokens) + + self.seq_lens_np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens + ) + seq_lens = self.seq_lens_cpu[:num_reqs] + + # TODO: improve performance, no `positions_np.copy()` + sparsed_positions = positions_np.copy() + req_sparsed_slots = scheduler_output.req_sparsed_slots + for req_id in self.input_batch.req_id_to_index: + is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT + req_index = self.input_batch.req_id_to_index[req_id] + offset = ( + 0 if req_index == 0 else cu_num_tokens[req_index - 1] + ) # TODO: support MTP + if is_sparsed_request: + sparsed_positions[offset] = req_sparsed_slots[req_id] - 1 + + block_table_indices = ( + req_indices * self.max_num_blocks_per_req + + sparsed_positions // self.block_size + ) + + block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + block_offsets = sparsed_positions % self.block_size + np.add( + block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping_np[:total_num_scheduled_tokens], + ) + + ascend_config = 
get_ascend_config() + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens): + attn_state = AscendAttentionState.PrefillNoCache + # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache. + elif np.all(num_scheduled_tokens == 1): + attn_state = AscendAttentionState.DecodeOnly + # Speculative decoding. + elif np.all(num_valid_tokens == 1): + if self.use_eagle: + attn_state = AscendAttentionState.ChunkedPrefill + else: + attn_state = AscendAttentionState.SpecDecoding + # splitfuse + elif ( + not ascend_config.ascend_scheduler_config.enabled + or self.chunked_prefill_enabled + ): + attn_state = AscendAttentionState.ChunkedPrefill + else: + attn_state = AscendAttentionState.PrefillCacheHit + + for req_id in self.input_batch.req_id_to_index: + is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT + req_index = self.input_batch.req_id_to_index[req_id] + if is_sparsed_request: + seq_lens[req_index] = req_sparsed_slots[req_id] + + self.attn_mask = self._make_attention_mask( + seq_lens=seq_lens, + query_lens=num_scheduled_tokens, + position=torch.tensor(sparsed_positions).npu(), + attn_state=attn_state, + ) + self.attn_state = attn_state # type: ignore + + extra_builder_kwargs = {} + + self.query_start_loc_np[0] = 0 + self.query_start_loc_np[1 : num_reqs + 1] = cu_num_tokens + self.query_start_loc[: num_reqs + 1].copy_( + self.query_start_loc_cpu[: num_reqs + 1], non_blocking=True + ) + self.seq_lens[:num_reqs].copy_( + self.seq_lens_cpu[:num_reqs], non_blocking=True + ) + + # Fill unused with -1. Needed for reshape_and_cache + self.seq_lens[num_reqs:].fill_(0) + self.query_start_loc[num_reqs + 1 :].fill_(-1) + + with_prefill = attn_state not in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding, + ] + + if self.dp_size > 1: + max_num_tokens, with_prefill = self._get_forward_metadata_across_dp( + total_num_scheduled_tokens, with_prefill + ) + extra_builder_kwargs["max_num_tokens_across_dp"] = max_num_tokens + extra_builder_kwargs["with_prefill_across_dp"] = with_prefill + + # Add graph_pad_size here + if self.torchair_graph_enabled and not with_prefill: + if self.dp_size > 1: + padded_batch_size = self.select_torchair_padded_batch_size( + max_num_tokens + ) + else: + padded_batch_size = self.select_torchair_padded_batch_size( + total_num_scheduled_tokens + ) + graph_pad_size = padded_batch_size - total_num_scheduled_tokens + + extra_builder_kwargs["graph_pad_size"] = graph_pad_size + + if self.vllm_config.model_config.use_mla: + query_start_loc = self.query_start_loc[: num_reqs + 1] + seq_lens = self.seq_lens[:num_reqs] + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, seq_lens=seq_lens + ) + attn_metadata = self.attn_metadata_builder.build( # type: ignore + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + common_attn_metadata=common_attn_metadata, + common_prefix_len=None, + **extra_builder_kwargs, + ) + else: + attn_metadata = self.attn_metadata_builder.build( # type: ignore + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + common_prefix_len=None, + **extra_builder_kwargs, + ) + attn_metadata.num_input_tokens = num_input_tokens + + # Prepare input_ids + token_indices = ( + positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] + ) + torch.index_select( + 
self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens], + ) + # Copy the tensors to the NPU. + self.input_ids[:total_num_scheduled_tokens].copy_( + self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True + ) + + # _prepare_inputs may reorder the batch, so we must gather multi + # modal outputs after that to ensure the correct order + if self.is_multimodal_model: + # Run the multimodal encoder if any. + self._execute_mm_encoder(scheduler_output) + mm_embeds = self._gather_mm_embeddings(scheduler_output) + else: + mm_embeds = [] + + if self.is_multimodal_model: + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. + input_ids = self.input_ids[:total_num_scheduled_tokens] + if mm_embeds: + inputs_embeds = self.model.get_input_embeddings( + input_ids, mm_embeds + ) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + # TODO(woosuk): Avoid the copy. Optimize. + self.inputs_embeds[:total_num_scheduled_tokens].copy_(inputs_embeds) + inputs_embeds = self.inputs_embeds[:num_input_tokens] + input_ids = None + else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the ACL graph. + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None + if self.uses_mrope: + positions = self.mrope_positions[:, :num_input_tokens] + + if self.torchair_graph_enabled and not with_prefill: + input_ids = self.input_ids[:padded_batch_size] + positions = self.positions[:padded_batch_size] + + # Run forward pass + with set_forward_context( + attn_metadata, self.vllm_config, num_tokens=num_input_tokens + ): + with ProfileExecuteDuration().capture_async("forward"): + model_kwargs = {} + if self.torchair_graph_enabled: + model_kwargs["kv_caches"] = self.kv_caches + model_kwargs["attn_metadata"] = attn_metadata + if self.torchair_graph_enabled and not with_prefill: + maybe_converting_weight_acl_format( + self.model, ACL_FORMAT_FRACTAL_NZ + ) + + compiled_model = self._get_torchair_lazy_compiled_model( + padded_batch_size + ) + hidden_states = compiled_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + else: + assert self.model is not None + maybe_converting_weight_acl_format( + self.model, ACL_FORMAT_FRACTAL_ND + ) + self.maybe_setup_kv_connector(scheduler_output) + self.maybe_execute_ucm_sparse_begin( + scheduler_output, attn_metadata + ) + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + self.maybe_wait_for_kv_save() + self.maybe_execute_ucm_sparse_finished() + + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + if not use_spec_decode: + # NOTE(woosuk): Due to chunked prefills, the batch may contain + # partial requests. While we should not sample any token + # from these partial requests, we do so for simplicity. + # We will ignore the sampled tokens from the partial requests. + # TODO: Support prompt logprobs. + spec_decode_metadata = None + else: + # Get the number of draft tokens for each request. 
+ # Iterate over the dictionary rather than all requests since not all + # requests have draft tokens. + num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) + for ( + req_id, + draft_token_ids, + ) in scheduler_output.scheduled_spec_decode_tokens.items(): + req_idx = self.input_batch.req_id_to_index[req_id] + num_draft_tokens[req_idx] = len(draft_token_ids) + + spec_decode_metadata = self._calc_spec_decode_metadata( + num_draft_tokens, cu_num_tokens + ) + logits_indices = spec_decode_metadata.logits_indices + + aux_hidden_states = None + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = hidden_states + + return ( + attn_metadata, + hidden_states, + spec_decode_metadata, + positions, + total_num_scheduled_tokens, + logits_indices, + aux_hidden_states, + num_scheduled_tokens, + ) + + NPUModelRunner._process_reqs = _process_reqs + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[ModelRunnerOutput, torch.Tensor]: + with ProfileExecuteDuration().capture_async("prepare input and forward"): + self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + # Return empty ModelRunnerOuptut if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + ( + attn_metadata, + hidden_states, + spec_decode_metadata, + positions, + num_scheduled_tokens, + logits_indices, + aux_hidden_states, + num_scheduled_tokens_np, + ) = self._process_reqs(scheduler_output, intermediate_tensors) + + with ProfileExecuteDuration().capture_async("post process"): + # Broadcast PP output for external_launcher (torchrun) + # to make sure we are synced across pp ranks + # TODO: Support overlapping mirco-batches + # https://github.com/vllm-project/vllm/issues/18019 + broadcast_pp_output = ( + self.parallel_config.distributed_executor_backend + == "external_launcher" + and len(get_pp_group().ranks) > 0 + ) + if not get_pp_group().is_last_rank: + # For mid-pipeline stages, return the hidden states. + if not broadcast_pp_output: + return hidden_states + assert isinstance(hidden_states, IntermediateTensors) + get_pp_group().send_tensor_dict( + hidden_states.tensors, all_gather_group=get_tp_group() + ) + logits = None + else: + if self.input_batch.pooling_params: + return self._pool( + hidden_states, num_scheduled_tokens, num_scheduled_tokens_np + ) + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + if broadcast_pp_output: + model_output_broadcast_data = ( + { + "logits": logits.contiguous(), + } + if logits is not None + else {} + ) + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 + ) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + logits = self.apply_grammar_bitmask(scheduler_output, logits) + + # Sample the next token and get logprobs if needed. + sampling_metadata = self.input_batch.sampling_metadata + if spec_decode_metadata is None: + sampler_output = self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + else: + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. 
This means any in-place operations on bonus_logits + # won't affect the original logits tensor. + assert logits is not None + bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] + sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. Therefore, + # it is safe to update `target_logits` in place. + target_logits = logits[spec_decode_metadata.target_logits_indices] + output_token_ids = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids + + discard_sampled_tokens_req_indices: list[int] = [] + # TODO(woosuk): The following loop can be slow since it iterates over + # the requests one by one. Optimize. + for i, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + seq_len = ( + req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id] + ) + if seq_len < req_state.num_tokens: + # Ignore the sampled token. + # Rewind the generator state as if the token was not sampled. + generator = self.input_batch.generators.get(i) + if generator is not None: + generator.set_offset(generator.get_offset() - 4) + discard_sampled_tokens_req_indices.append(i) + + # NOTE: NPU -> CPU Sync happens here. + # Move as many CPU operations as possible before this sync point. + logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = ( + logprobs_tensors.tolists() if logprobs_tensors is not None else None + ) + + # Compute prompt logprobs if needed. + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[:num_scheduled_tokens], + scheduler_output, + ) + + # Get the valid generated tokens. + sampled_token_ids = sampler_output.sampled_token_ids + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = sampled_token_ids.tolist() + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + # Cache the sampled tokens in the model runner, so that the scheduler + # doesn't need to send them back. + # NOTE(woosuk): As an exception, when using PP, the scheduler sends + # the sampled tokens back, because there's no direct communication + # between the first-stage worker and the last-stage worker. + for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): + if not sampled_ids: + continue + + start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + len(sampled_ids) + assert end_idx <= self.model_config.max_model_len, ( + "Sampled token IDs exceed the max model length. 
" + f"Total number of tokens: {end_idx} > max_model_len: " + f"{self.model_config.max_model_len}" + ) + + self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = ( + sampled_ids + ) + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + req_id = self.input_batch.req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + spec_token_ids = self._get_spec_token_ids( + valid_sampled_token_ids, + sampling_metadata, + scheduler_output, + spec_decode_metadata, + positions, + num_scheduled_tokens, + hidden_states, + attn_metadata, + aux_hidden_states, + ) + + model_runner_output = ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=spec_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + ) + + durations = ProfileExecuteDuration().pop_captured_sync() + if durations: + dr_str = [ + f"[{tag}]:{duration:.2f}ms" for tag, duration in durations.items() + ] + captured_name = ( + "Decode" + if self.attn_state == AscendAttentionState.DecodeOnly + else "Prefill" + ) + logger.info( + "Profile execute duration [%s]:%s", captured_name, " ".join(dr_str) + ) + + return model_runner_output + + NPUModelRunner.execute_model = execute_model + + @staticmethod + def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): + # Update KVConnector with the KVConnector metadata forward(). + if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata( + scheduler_output.kv_connector_metadata + ) + # Background KV cache transfers happen here. + # These transfers are designed to be async and the requests + # involved may be disjoint from the running requests. + # Do this here to save a collective_rpc. 
+ kv_connector.start_load_kv(get_forward_context()) + + NPUModelRunner.maybe_setup_kv_connector = maybe_setup_kv_connector + + @staticmethod + def maybe_wait_for_kv_save(): + if has_kv_transfer_group(): + get_kv_transfer_group().wait_for_save() + + NPUModelRunner.maybe_wait_for_kv_save = maybe_wait_for_kv_save + + def maybe_execute_ucm_sparse_begin( + self, + scheduler_output: "SchedulerOutput", + attn_metadata: CommonAttentionMetadata, + ): + if not has_ucm_sparse(): + return + ucm_sparse = get_ucm_sparse() + ucm_sparse.build_sparse_meta( + scheduler_output, self.requests, self.input_batch, attn_metadata + ) + ucm_sparse.execute_begin(scheduler_output) + + def maybe_execute_ucm_sparse_finished(self): + if not has_ucm_sparse(): + return + ucm_sparse = get_ucm_sparse() + ucm_sparse.execute_finished() + + def ucm_sparse_request_finished_in_worker(self, request_id: str | int): + if not has_ucm_sparse(): + return + ucm_sparse = get_ucm_sparse() + ucm_sparse.request_finished_in_worker(request_id) + + NPUModelRunner.maybe_execute_ucm_sparse_begin = maybe_execute_ucm_sparse_begin + NPUModelRunner.maybe_execute_ucm_sparse_finished = ( + maybe_execute_ucm_sparse_finished + ) + NPUModelRunner.ucm_sparse_request_finished_in_worker = ( + ucm_sparse_request_finished_in_worker + ) + except ImportError as e: + logger.error(f"Failed to patch model_runner_v1.py: {e}", exc_info=True) + raise + + +# ========================= vllm_ascend/worker/worker_v1.py ========================= +def _patch_worker_v1() -> None: + """Patch worker_v1.py for vLLM-Ascend.""" + try: + import copy + from typing import Optional + + from vllm.distributed.parallel_state import get_pp_group, get_tp_group + from vllm.logger import logger + from vllm.sequence import IntermediateTensors + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput + from vllm_ascend.worker.worker_v1 import NPUWorker + + from ucm.sparse.state import ensure_ucm_sparse_initialized + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> Optional[ModelRunnerOutput]: + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = IntermediateTensors( + get_pp_group().recv_tensor_dict(all_gather_group=get_tp_group()) + ) + + output = self.model_runner.execute_model( + scheduler_output, intermediate_tensors + ) + parallel_config = self.vllm_config.parallel_config + if ( + parallel_config.distributed_executor_backend != "external_launcher" + and not get_pp_group().is_last_rank + ): + assert isinstance(output, IntermediateTensors) + get_pp_group().send_tensor_dict( + output.tensors, all_gather_group=get_tp_group() + ) + + kv_connector_output = output.kv_connector_output + finished_sending = kv_connector_output.finished_sending + finished_recving = kv_connector_output.finished_recving + + if not finished_sending and not finished_recving: + return EMPTY_MODEL_RUNNER_OUTPUT + + new_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + new_output.kv_connector_output = kv_connector_output + return new_output + + assert isinstance(output, ModelRunnerOutput) + return output + + NPUWorker.execute_model = execute_model + + original_init_worker_distributed_environment = ( + NPUWorker._init_worker_distributed_environment + ) + + def patched_init_worker_distributed_environment(self) -> None: + original_init_worker_distributed_environment(self) + ensure_ucm_sparse_initialized(self.vllm_config) + + NPUWorker._init_worker_distributed_environment = ( + 
patched_init_worker_distributed_environment + ) + except ImportError as e: + logger.error(f"Failed to patch worker_v1.py: {e}", exc_info=True) + raise diff --git a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py new file mode 100644 index 000000000..2a697efb0 --- /dev/null +++ b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_patch.py @@ -0,0 +1,1992 @@ +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +from __future__ import annotations + +import os + +from ucm.logger import init_logger + +logger = init_logger(__name__) + +ENABLE_SPARSE = os.getenv("ENABLE_SPARSE") + + +def _enable_sparse() -> bool: + return ENABLE_SPARSE is not None and ENABLE_SPARSE.lower() == "true" + + +def _apply_sparse_adapt() -> None: + """Apply sparse adapt patches.""" + try: + if _enable_sparse(): + _patch_block_table() + _patch_kv_cache_manager() + _patch_shared_storage_connector() + _patch_attention_layer() + _patch_mla_common() + _patch_gpu_model_runner() + _patch_gpu_worker() + _patch_scheduler_output() + _patch_scheduler() + logger.info("UCM sparse adapt patches applied successfully") + except Exception as e: + logger.error(f"Could not apply sparse adapt patches: {e}") + raise e + + +# ==================== vllm/v1/core/sched/output.py ==================== +def _patch_scheduler_output() -> None: + """Patch scheduler output to add UCM sparse support.""" + try: + from dataclasses import dataclass + from typing import TYPE_CHECKING, Optional + + if TYPE_CHECKING: + import numpy as np + import numpy.typing as npt + from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata, + ) + from vllm.v1.core.sched import output + from vllm.v1.core.sched.output import CachedRequestData, NewRequestData + + @dataclass + class SchedulerOutput: + + # list of the requests that are scheduled for the first time. + # We cache the request's data in each worker process, so that we don't + # need to re-send it every scheduling step. + scheduled_new_reqs: list[NewRequestData] + # list of the requests that have been scheduled before. + # Since the request's data is already cached in the worker processes, + # we only send the diff to minimize the communication cost. + scheduled_cached_reqs: CachedRequestData + + # req_id -> num_scheduled_tokens + # Number of tokens scheduled for each request. 
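+ # e.g. {"req-0": 512, "req-1": 1} for one chunked-prefill request and one decode request.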
+ num_scheduled_tokens: dict[str, int] + # Total number of tokens scheduled for all requests. + # Equal to sum(num_scheduled_tokens.values()) + total_num_scheduled_tokens: int + # req_id -> spec_token_ids + # If a request does not have any spec decode tokens, it will not be + # included in the dictionary. + scheduled_spec_decode_tokens: dict[str, list[int]] + # req_id -> encoder input indices that need processing. + # E.g., if a request has [0, 1], it could mean the vision encoder needs + # to process that the request's 0-th and 1-th images in the current step. + scheduled_encoder_inputs: dict[str, list[int]] + # Number of common prefix blocks for all requests in each KV cache group. + # This can be used for cascade attention. + num_common_prefix_blocks: list[int] + + # Request IDs that are finished in between the previous and the current + # steps. This is used to notify the workers about the finished requests + # so that they can free the cached states for those requests. + finished_req_ids: set[str] + # list of (req_id, encoder_input_index) tuples. + # Used to free the encoder cache. + free_encoder_input_ids: list[tuple[str, int]] + + # Dict of request ids to their index within the batch + # for filling the next token bitmask + structured_output_request_ids: dict[str, int] + # the bitmask for the whole batch + grammar_bitmask: Optional[npt.NDArray[np.int32]] + + # KV Cache Connector metadata. + kv_connector_metadata: Optional[KVConnectorMetadata] = None + + # modified slots by sparse algorithm + req_sparsed_slots: dict[str, int] = None + + # Set module and qualname to make the class pickleable + # This ensures pickle can find the class when serializing + SchedulerOutput.__module__ = output.__name__ + SchedulerOutput.__qualname__ = "SchedulerOutput" + + output.SchedulerOutput = SchedulerOutput + + except ImportError: + logger.warning("Could not patch scheduler output - module not found") + + +# ==================== vllm/attention/layer.py ==================== +def _patch_attention_layer() -> None: + """Patch attention layer & unified_attention_with_output C++ op.""" + try: + from typing import Optional + + import torch + from vllm.attention.layer import ( + maybe_save_kv_layer_to_connector, + wait_for_kv_layer_from_connector, + ) + from vllm.forward_context import ForwardContext, get_forward_context + + from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse + + def maybe_execute_sparse_attention_begin( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, + forward_context: ForwardContext, + phase: Optional[str] = None, + ): + if not has_ucm_sparse(): + return + + ucm_sparse = get_ucm_sparse() + + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + + ucm_sparse.attention_begin( + query, key, value, layer_name, forward_context, phase + ) + + def maybe_execute_sparse_attention_finished( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_output: torch.Tensor, + layer_name: str, + forward_context: ForwardContext, + phase: Optional[str] = None, + ): + if not has_ucm_sparse(): + return + + ucm_sparse = get_ucm_sparse() + + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + + ucm_sparse.attention_finished( + query, key, value, attn_output, layer_name, forward_context, phase + ) + + vllm_ops = torch.ops.vllm + orig_unified_attention_with_output = vllm_ops.unified_attention_with_output + orig_unified_attention = vllm_ops.unified_attention + + def _wrap_op_overload(orig, 
impl): + class _Wrapper: + def __init__(self, orig): + self._orig = orig + + def __call__(self, *args, **kwargs): + return impl(*args, **kwargs) + + def __getattr__(self, name): + return getattr(self._orig, name) + + return _Wrapper(orig) + + def unified_attention_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, + ) -> torch.Tensor: + wait_for_kv_layer_from_connector(layer_name) + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + maybe_execute_sparse_attention_begin( + query, key, value, layer_name, forward_context + ) + output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata) + maybe_execute_sparse_attention_finished( + query, key, value, output, layer_name, forward_context + ) + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return output + + def unified_attention_with_output_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, + output_scale: Optional[torch.Tensor] = None, + ) -> None: + wait_for_kv_layer_from_connector(layer_name) + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + if not self.use_mla: + maybe_execute_sparse_attention_begin( + query, key, value, layer_name, forward_context + ) + self.impl.forward( + self, + query, + key, + value, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + ) + if not self.use_mla: + maybe_execute_sparse_attention_finished( + query, key, value, output, layer_name, forward_context + ) + + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + + vllm_ops.unified_attention_with_output = _wrap_op_overload( + orig_unified_attention_with_output, unified_attention_with_output_impl + ) + vllm_ops.unified_attention = _wrap_op_overload( + orig_unified_attention, unified_attention_impl + ) + from vllm.attention import layer + + layer.maybe_execute_sparse_attention_begin = ( + maybe_execute_sparse_attention_begin + ) + layer.maybe_execute_sparse_attention_finished = ( + maybe_execute_sparse_attention_finished + ) + layer.unified_attention = unified_attention_impl + layer.unified_attention_with_output = unified_attention_with_output_impl + + except ImportError: + logger.warning( + "Could not patch unified attention with output - module not found" + ) + + +# ==================== v1/shared_storage_connector.py ==================== +def _patch_shared_storage_connector() -> None: + """Patch kv connector utils to add UCM sparse support.""" + try: + from dataclasses import dataclass, field + + from vllm.distributed.kv_transfer.kv_connector.v1 import ( + shared_storage_connector, + ) + from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata, + ) + from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( + ReqMeta, + ) + + @dataclass + class SharedStorageConnectorMetadata(KVConnectorMetadata): + requests: list[ReqMeta] = field(default_factory=list) + + def add_request( + self, + token_ids: list[int], + block_ids: list[int], + block_size: int, + is_store: 
bool, + ) -> None: + self.requests.append( + ReqMeta.make_meta(token_ids, block_ids, block_size, is_store) + ) + + shared_storage_connector.SharedStorageConnectorMetadata = ( + SharedStorageConnectorMetadata + ) + except ImportError: + logger.warning("Could not patch shared storage connector - module not found") + + +# ==================== vllm/v1/attention/backends/mla/common.py ==================== +def _patch_mla_common() -> None: + """Patch mla common to add UCM sparse support.""" + try: + from typing import Optional, TypeVar + + import torch + from vllm import _custom_ops as ops + from vllm.attention.backends.abstract import AttentionLayer + from vllm.attention.layer import ( + maybe_execute_sparse_attention_begin, + maybe_execute_sparse_attention_finished, + ) + from vllm.forward_context import ForwardContext, get_forward_context + from vllm.v1.attention.backends.mla.common import ( + MLACommonImpl, + MLACommonMetadata, + ) + + M = TypeVar("M", bound=MLACommonMetadata) + + def forward( + self, + layer: AttentionLayer, + q: torch.Tensor, + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: M, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + forward_context: ForwardContext = get_forward_context() + assert output is not None, "Output tensor must be provided." + + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for MLACommonImpl" + ) + + if attn_metadata is None: + # The zero fill is required when used with DP + EP + # to ensure all ranks within a DP group compute the + # same expert outputs. + return output.fill_(0) + + num_actual_toks = attn_metadata.num_actual_tokens + + # Inputs and outputs may be padded for CUDA graphs + output_padded = output + output = output[:num_actual_toks, ...] + q = q[:num_actual_toks, ...] + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] 
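+ # Below, the batch is split into decode and prefill segments; the UCM sparse hooks (attention_begin / attention_finished) are invoked once per phase, tagged "decode" or "prefill".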
+ + assert ( + attn_metadata.num_decodes is not None + and attn_metadata.num_prefills is not None + and attn_metadata.num_decode_tokens is not None + ) + + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + decode_q = q[:num_decode_tokens] + + prefill_q = q[num_decode_tokens:] + prefill_k_pe = k_pe[num_decode_tokens:] + prefill_k_c_normed = k_c_normed[num_decode_tokens:] + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + if has_prefill: + maybe_execute_sparse_attention_begin( + prefill_q, + prefill_k_c_normed, + prefill_k_pe, + layer.layer_name, + forward_context, + "prefill", + ) + output[num_decode_tokens:] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, attn_metadata + ) + maybe_execute_sparse_attention_finished( + prefill_q, + prefill_k_c_normed, + prefill_k_pe, + output[num_decode_tokens:], + layer.layer_name, + forward_context, + "prefill", + ) + if has_decode: + assert attn_metadata.decode is not None + decode_q_nope, decode_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + # Convert from (B, N, P) to (N, B, P) + decode_q_nope = decode_q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) + maybe_execute_sparse_attention_begin( + torch.cat([decode_ql_nope, decode_q_pe], dim=-1), + decode_ql_nope, + decode_q_pe, + layer.layer_name, + forward_context, + "decode", + ) + output[:num_decode_tokens] = self._forward_decode( + decode_ql_nope, decode_q_pe, kv_cache, attn_metadata + ) + maybe_execute_sparse_attention_finished( + torch.cat([decode_ql_nope, decode_q_pe], dim=-1), + decode_ql_nope, + decode_q_pe, + output[:num_decode_tokens], + layer.layer_name, + forward_context, + "decode", + ) + return output_padded + + MLACommonImpl.forward = forward + except ImportError: + logger.warning("Could not patch mla common - module not found") + + +# ==================== v1/core/kv_cache_manager.py ==================== +def _patch_kv_cache_manager() -> None: + """Patch kv cache manager to add UCM sparse support.""" + try: + from typing import Optional, Union + + from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager + from vllm.v1.request import Request + + from ucm.sparse.base import INVALID_SLOT + from ucm.sparse.state import get_ucm_sparse + + original_allocate_slots = KVCacheManager.allocate_slots + + def patched_allocate_slots( + self, + request: Request, + num_new_tokens: int, + num_new_computed_tokens: int = 0, + new_computed_blocks: Optional[KVCacheBlocks] = None, + num_draft_tokens: int = 0, + num_lookahead_tokens: int = 0, + delay_cache_blocks: bool = False, + num_slots_sparsed: Union[None, int] = None, + ) -> Optional[KVCacheBlocks]: + if num_new_tokens == 0: + raise ValueError("num_new_tokens must be greater than 0") + # Only route to UCM sparse path when caller explicitly provided + # a valid sparsified slot count. 
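+ # e.g. num_slots_sparsed=1024 routes to ucm_sparse.allocate_slots(); None or INVALID_SLOT falls through to the original allocator below.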
+ if (num_slots_sparsed is not None) and (num_slots_sparsed != INVALID_SLOT): + return get_ucm_sparse().allocate_slots(self, request, num_slots_sparsed) + return original_allocate_slots( + self, + request, + num_new_tokens, + num_new_computed_tokens, + new_computed_blocks, + num_draft_tokens, + num_lookahead_tokens, + delay_cache_blocks, + ) + + KVCacheManager.allocate_slots = patched_allocate_slots + except ImportError: + logger.warning("Could not patch kv cache manager - module not found") + + +# ==================== vllm/v1/core/sched/scheduler.py ==================== +def _patch_scheduler() -> None: + """Patch Scheduler to add num_output_tokens field.""" + try: + import itertools + import time + from collections import defaultdict + from collections.abc import Iterable + from typing import Optional + + from vllm.distributed.kv_events import KVEventBatch + from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( + MultiConnector, + ) + from vllm.v1.core.sched.output import ( + CachedRequestData, + NewRequestData, + SchedulerOutput, + ) + from vllm.v1.core.sched.request_queue import ( + SchedulingPolicy, + create_request_queue, + ) + from vllm.v1.core.sched.scheduler import Scheduler + from vllm.v1.core.sched.utils import check_stop + from vllm.v1.engine import ( + EngineCoreEventType, + EngineCoreOutput, + EngineCoreOutputs, + ) + from vllm.v1.outputs import ModelRunnerOutput + from vllm.v1.request import Request, RequestStatus + from vllm.v1.spec_decode.metrics import SpecDecodingStats + + from ucm.sparse.base import INVALID_SLOT, UcmSparseRole + from ucm.sparse.state import ensure_ucm_sparse_initialized, get_ucm_sparse + + def init_ucm_sparse(self): + self.ucm_sparse = None + if self.vllm_config.kv_transfer_config is not None: + if ( + "ucm_sparse_config" + in self.vllm_config.kv_transfer_config.kv_connector_extra_config + ): + ensure_ucm_sparse_initialized( + self.vllm_config, role=UcmSparseRole.SCHEDULER + ) + self.ucm_sparse = get_ucm_sparse() + logger.info( + "UCM Sparse initialized successfully: {}".format( + self.ucm_sparse + ) + ) + + def patched_schedule(self) -> SchedulerOutput: + # NOTE(woosuk) on the scheduling algorithm: + # There's no "decoding phase" nor "prefill phase" in the scheduler. + # Each request just has the num_computed_tokens and + # num_tokens_with_spec. num_tokens_with_spec = + # len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids). + # At each step, the scheduler tries to assign tokens to the requests + # so that each request's num_computed_tokens can catch up its + # num_tokens_with_spec. This is general enough to cover + # chunked prefills, prefix caching, speculative decoding, + # and the "jump decoding" optimization in the future. + + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] + preempted_reqs: list[Request] = [] + + # NOTE: structured_output_request_ids maps + # a request's (request that uses structured output) + # request_id to the running request index. + # This will helps us determine to slice the grammar bitmask + # and only applies valid mask for requests that + # uses structured decoding. + structured_output_request_ids: dict[str, int] = {} + + req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {} + num_scheduled_tokens: dict[str, int] = {} + token_budget = self.max_num_scheduled_tokens + # Encoder-related. 
+ scheduled_encoder_inputs: dict[str, list[int]] = {} + encoder_budget = self.max_num_encoder_input_tokens + # Spec decode-related. + scheduled_spec_decode_tokens: dict[str, list[int]] = {} + + # For logging. + scheduled_timestamp = time.monotonic() + + # First, schedule the RUNNING requests. + req_index = 0 + req_sparsed_slots: dict[str, int] = {} + if not hasattr(self, "ucm_sparse"): + init_ucm_sparse(self) + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] + num_slots_sparsed = INVALID_SLOT + if self.ucm_sparse: + num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed( + request + ) + req_sparsed_slots.update({request.request_id: num_slots_sparsed}) + + num_new_tokens = ( + request.num_tokens_with_spec - request.num_computed_tokens + ) + if ( + 0 + < self.scheduler_config.long_prefill_token_threshold + < num_new_tokens + ): + num_new_tokens = self.scheduler_config.long_prefill_token_threshold + num_new_tokens = min(num_new_tokens, token_budget) + + # Make sure the input position does not exceed the max model len. + # This is necessary when using spec decoding. + num_new_tokens = min( + num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens + ) + + # Schedule encoder inputs. + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget) = ( + self._try_schedule_encoder_inputs( + request, + request.num_computed_tokens, + num_new_tokens, + encoder_budget, + ) + ) + + if num_new_tokens == 0: + # The request cannot be scheduled because one of the following + # reasons: + # 1. No new tokens to schedule. This may happen when PP>1 and + # we have already scheduled all prompt tokens but they are + # not finished yet. + # 2. The encoder budget is exhausted. + # 3. The encoder cache is exhausted. + # NOTE(woosuk): Here, by doing `continue` instead of `break`, + # we do not strictly follow the FCFS scheduling policy and + # allow the lower-priority requests to be scheduled. + req_index += 1 + continue + + num_draft_tokens = max( + num_new_tokens + request.num_computed_tokens - request.num_tokens, 0 + ) + + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_draft_tokens=num_draft_tokens, + num_lookahead_tokens=self.num_lookahead_tokens, + num_slots_sparsed=num_slots_sparsed, + ) + if new_blocks is None: + # The request cannot be scheduled. + # Preempt the lowest-priority request. + if self.policy == SchedulingPolicy.PRIORITY: + preempted_req = max( + self.running, + key=lambda r: (r.priority, r.arrival_time), + ) + self.running.remove(preempted_req) + else: + preempted_req = self.running.pop() + + self.kv_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp + ) + + self.waiting.prepend_request(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. + can_schedule = False + break + else: + # The request can be scheduled. + can_schedule = True + break + if not can_schedule: + break + assert new_blocks is not None + + # Schedule the request. + scheduled_running_reqs.append(request) + if request.use_structured_output: + # PERF: in case of chunked prefill, + # request might not include any new tokens. 
+ # Therefore, we might introduce some additional + # cycle to fill in the bitmask, which could be a big no-op. + structured_output_request_ids[request.request_id] = req_index + req_to_new_block_ids[request.request_id] = new_blocks.get_block_ids() + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + + # Speculative decode related. + if request.spec_token_ids: + num_scheduled_spec_tokens = ( + num_new_tokens + + request.num_computed_tokens + - request.num_tokens + ) + if num_scheduled_spec_tokens > 0: + # Trim spec_token_ids list to num_scheduled_spec_tokens. + del request.spec_token_ids[num_scheduled_spec_tokens:] + scheduled_spec_decode_tokens[request.request_id] = ( + request.spec_token_ids + ) + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule + ) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + + # Record the LoRAs in scheduled_running_reqs + scheduled_loras: set[int] = set() + if self.lora_config: + scheduled_loras = set( + req.lora_request.lora_int_id + for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0 + ) + assert len(scheduled_loras) <= self.lora_config.max_loras + + # Use a temporary RequestQueue to collect requests that need to be + # skipped and put back at the head of the waiting queue later + skipped_waiting_requests = create_request_queue(self.policy) + + # Next, schedule the WAITING requests. + if not preempted_reqs: + while self.waiting and token_budget > 0: + if len(self.running) == self.max_num_running_reqs: + break + + request = self.waiting.peek_request() + num_slots_sparsed = INVALID_SLOT + if self.ucm_sparse: + num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed( + request + ) + req_sparsed_slots.update({request.request_id: num_slots_sparsed}) + + # KVTransfer: skip request if still waiting for remote kvs. + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + is_ready = self._update_waiting_for_remote_kv(request) + if is_ready: + request.status = RequestStatus.WAITING + else: + logger.debug( + "%s is still in WAITING_FOR_REMOTE_KVS state.", + request.request_id, + ) + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Skip request if the structured output request is still waiting + # for FSM compilation. + if request.status == RequestStatus.WAITING_FOR_FSM: + structured_output_req = request.structured_output_request + if structured_output_req and structured_output_req.grammar: + request.status = RequestStatus.WAITING + else: + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Check that adding the request still respects the max_loras + # constraint. + if ( + self.lora_config + and request.lora_request + and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id not in scheduled_loras + ) + ): + # Scheduling would exceed max_loras, skip. + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + num_external_computed_tokens = 0 + load_kv_async = False + + # Get already-cached tokens. + if request.num_computed_tokens == 0: + # Get locally-cached tokens. 
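+ # (prefix-cache lookup: the blocks already resident for this prompt and the number of tokens they cover)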
+ new_computed_blocks, num_new_local_computed_tokens = ( + self.kv_cache_manager.get_computed_blocks(request) + ) + + # Get externally-cached tokens if using a KVConnector. + if self.connector is not None: + num_external_computed_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_new_local_computed_tokens + ) + ) + + # Total computed tokens (local + external). + num_computed_tokens = ( + num_new_local_computed_tokens + num_external_computed_tokens + ) + # KVTransfer: WAITING reqs have num_computed_tokens > 0 + # after async KV recvs are completed. + else: + new_computed_blocks = ( + self.kv_cache_manager.create_empty_block_list() + ) + num_new_local_computed_tokens = 0 + num_computed_tokens = request.num_computed_tokens + + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + + # KVTransfer: loading remote KV, do not allocate for new work. + if load_kv_async: + assert num_external_computed_tokens > 0 + num_new_tokens = 0 + # Number of tokens to be scheduled. + else: + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed + # requests, which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + if ( + 0 + < self.scheduler_config.long_prefill_token_threshold + < num_new_tokens + ): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold + ) + + # chunked prefill has to be enabled explicitly to allow + # pooling requests to be chunked + if ( + not self.scheduler_config.chunked_prefill_enabled + and num_new_tokens > token_budget + ): + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. + if request.has_encoder_inputs: + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_budget, + ) = self._try_schedule_encoder_inputs( + request, + num_computed_tokens, + num_new_tokens, + encoder_budget, + ) + if num_new_tokens == 0: + # The request cannot be scheduled. + break + + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens + num_external_computed_tokens, + num_new_local_computed_tokens, + new_computed_blocks, + num_lookahead_tokens=self.num_lookahead_tokens, + delay_cache_blocks=load_kv_async, + num_slots_sparsed=num_slots_sparsed, + ) + if new_blocks is None: + # The request cannot be scheduled. + break + + # KVTransfer: the connector uses this info to determine + # if a load is needed. Note that + # This information is used to determine if a load is + # needed for this request. + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + new_computed_blocks + new_blocks, + num_external_computed_tokens, + ) + + # Request was already popped from self.waiting + # unless it was re-added above due to new_blocks being None. + request = self.waiting.pop_request() + if load_kv_async: + # If loading async, allocate memory and put request + # into the WAITING_FOR_REMOTE_KV state. 
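+ # The request is re-queued at the head of the waiting queue and revisited once _update_waiting_for_remote_kv() reports the remote KV as ready.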
+ skipped_waiting_requests.prepend_request(request) + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + continue + + if request.use_structured_output: + structured_output_request_ids[request.request_id] = req_index + req_index += 1 + self.running.append(request) + if self.log_stats: + request.record_event( + EngineCoreEventType.SCHEDULED, scheduled_timestamp + ) + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + raise RuntimeError(f"Invalid request status: {request.status}") + + if self.lora_config and request.lora_request: + scheduled_loras.add(request.lora_request.lora_int_id) + req_to_new_block_ids[request.request_id] = ( + self.kv_cache_manager.get_block_ids(request.request_id) + ) + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + # Count the number of prefix cached tokens. + if request.num_cached_tokens < 0: + request.num_cached_tokens = num_computed_tokens + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule + ) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: + self.waiting.prepend_requests(skipped_waiting_requests) + + # Check if the scheduling constraints are satisfied. + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 + assert len(self.running) <= self.max_num_running_reqs + # Since some requests in the RUNNING queue may not be scheduled in + # this step, the total number of scheduled requests can be smaller than + # len(self.running). + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( + scheduled_running_reqs + ) <= len(self.running) + + # Get the longest common prefix among all requests in the running queue. + # This can be potentially used for cascade attention. + num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request, len(self.running) + ) + ) + + grammar_bitmask = self.structured_output_manager.grammar_bitmask( + self.requests, + structured_output_request_ids, + scheduled_spec_decode_tokens, + ) + # Construct the scheduler output. 
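+ # req_sparsed_slots is the field added for UCM sparse: per-request slot counts from estimate_num_slots_sparsed(), with INVALID_SLOT meaning the request is not sparsified.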
+ new_reqs_data = [ + NewRequestData.from_request(req, req_to_new_block_ids[req.request_id]) + for req in scheduled_new_reqs + ] + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, + scheduled_resumed_reqs, + num_scheduled_tokens, + scheduled_spec_decode_tokens, + req_to_new_block_ids, + ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=cached_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + req_sparsed_slots=req_sparsed_slots, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between + # the previous and the current steps. + finished_req_ids=self.finished_req_ids, + free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + structured_output_request_ids=structured_output_request_ids, + grammar_bitmask=grammar_bitmask, + ) + + # NOTE(Kuntai): this function is designed for multiple purposes: + # 1. Plan the KV cache store + # 2. Wrap up all the KV cache load / save ops into an opaque object + # 3. Clear the internal states of the connector + if self.connector is not None: + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta + + events = self.kv_cache_manager.take_events() + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + + self._update_after_schedule(scheduler_output) + return scheduler_output + + Scheduler.schedule = patched_schedule + + def patched_add_request(self, request: Request) -> None: + if not hasattr(self, "ucm_sparse"): + init_ucm_sparse(self) + self.waiting.add_request(request) + self.requests[request.request_id] = request + if self.ucm_sparse: + self.ucm_sparse.request_begin( + request.request_id, request.prompt_token_ids + ) + if self.log_stats: + request.record_event(EngineCoreEventType.QUEUED) + + Scheduler.add_request = patched_add_request + + original_free_request = Scheduler._free_request + + def patched_free_request(self, request: Request): + assert request.is_finished() + if not hasattr(self, "ucm_sparse"): + init_ucm_sparse(self) + if self.ucm_sparse: + self.ucm_sparse.request_finished_in_scheduler(request.request_id) + original_free_request(self, request) + + Scheduler._free_request = patched_free_request + except ImportError: + logger.warning("Could not patch Scheduler - module not found") + + +# ==================== vllm/v1/worker/block_table.py ==================== +def _patch_block_table() -> None: + """Patch block table to add UCM sparse support.""" + try: + from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable + + def reset_row( + self, + row_idx: int, + ) -> None: + self.num_blocks_per_row[row_idx] = 0 + self.block_table[row_idx].fill_(0) + self.block_table_cpu[row_idx].fill_(0) + self.block_table_np[row_idx].fill(0) + + BlockTable.reset_row = reset_row + + def reset_row(self, row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.reset_row(row_idx) + + MultiGroupBlockTable.reset_row = reset_row + except ImportError: + logger.warning("Could not patch multigroup block table - module not found") + + +# ==================== vllm/v1/worker/gpu_model_runner.py 
==================== +def _patch_gpu_model_runner() -> None: + """Patch gpu model runner to add UCM sparse support.""" + try: + import copy + from typing import TYPE_CHECKING, Any, Optional + + import numpy as np + import torch + import vllm.envs as envs + from vllm.distributed.kv_transfer import ( + get_kv_transfer_group, + has_kv_transfer_group, + ) + from vllm.distributed.parallel_state import get_pp_group + from vllm.forward_context import set_forward_context + from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding + from vllm.sampling_params import SamplingType + from vllm.sequence import IntermediateTensors + from vllm.utils import round_up + from vllm.v1.attention.backends.utils import CommonAttentionMetadata + from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput + from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + from vllm.v1.worker.block_table import BlockTable + from vllm.v1.worker.gpu_input_batch import CachedRequestState + + from ucm.sparse.base import INVALID_SLOT + from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse + + if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + + @staticmethod + def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]: + if has_kv_transfer_group(): + return get_kv_transfer_group().wait_for_save() + return None + + GPUModelRunner.maybe_wait_for_kv_save = maybe_wait_for_kv_save + + def maybe_execute_ucm_sparse_begin( + self, + scheduler_output: "SchedulerOutput", + attn_metadata: CommonAttentionMetadata, + ): + if not has_ucm_sparse(): + return + ucm_sparse = get_ucm_sparse() + ucm_sparse.build_sparse_meta( + scheduler_output, self.requests, self.input_batch, attn_metadata + ) + ucm_sparse.execute_begin(scheduler_output) + + def maybe_execute_ucm_sparse_finished(self): + if not has_ucm_sparse(): + return + ucm_sparse = get_ucm_sparse() + ucm_sparse.execute_finished() + + def ucm_sparse_request_finished_in_worker(self, request_id: str | int): + if not has_ucm_sparse(): + return + ucm_sparse = get_ucm_sparse() + ucm_sparse.request_finished_in_worker(request_id) + + GPUModelRunner.maybe_execute_ucm_sparse_begin = maybe_execute_ucm_sparse_begin + GPUModelRunner.maybe_execute_ucm_sparse_finished = ( + maybe_execute_ucm_sparse_finished + ) + GPUModelRunner.ucm_sparse_request_finished_in_worker = ( + ucm_sparse_request_finished_in_worker + ) + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + """Update the cached states and the persistent batch with the scheduler + output. + + The updated states are used by the `_prepare_inputs` function to create + the input GPU tensors for the model. + + The SamplingMetadata is updated and copied to the GPU if there is a + new/resumed/paused/finished request in the batch. + """ + # Remove finished requests from the cached states. + for req_id in scheduler_output.finished_req_ids: + self.ucm_sparse_request_finished_in_worker(req_id) + self.requests.pop(req_id, None) + self.encoder_cache.pop(req_id, None) + # Remove the finished requests from the persistent batch. + # NOTE(woosuk): There could be an edge case where finished_req_ids and + # scheduled_req_ids overlap. This happens when a request is aborted and + # then resubmitted with the same ID. In this case, we treat them as two + # distinct requests - clearing the cached states for the first request + # and handling the second as a new request. 
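+ # finished_req_ids is walked twice: the loop above releases per-request UCM sparse and cached state, while the loop below evicts the rows from the persistent batch.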
+ for req_id in scheduler_output.finished_req_ids: + self.input_batch.remove_request(req_id) + + # Free the cached encoder outputs. + for req_id, input_id in scheduler_output.free_encoder_input_ids: + encoder_outputs = self.encoder_cache.get(req_id) + if encoder_outputs is not None: + encoder_outputs.pop(input_id, None) + if not encoder_outputs: + self.encoder_cache.pop(req_id, None) + + # Remove the unscheduled requests from the persistent batch. + # NOTE(woosuk): The unscheduled requests are either preempted requests + # or running requests that are not scheduled in this step. We remove + # them from the persistent batch but keep their cached states since + # they will be scheduled again sometime in the future. + scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + unscheduled_req_ids = cached_req_ids - scheduled_req_ids + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. + for req_id in unscheduled_req_ids: + self.input_batch.remove_request(req_id) + + req_ids_to_add: list[str] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + pooling_params = new_req_data.pooling_params + if ( + sampling_params + and sampling_params.sampling_type == SamplingType.RANDOM_SEED + ): + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + mm_inputs=new_req_data.mm_inputs, + mm_positions=new_req_data.mm_positions, + sampling_params=sampling_params, + pooling_params=pooling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + ) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + for mm_input in self.requests[req_id].mm_inputs: + if mm_input.get("image_grid_thw") is not None: + image_grid_thw.extend(mm_input["image_grid_thw"].tolist()) + if mm_input.get("video_grid_thw") is not None: + video_grid_thw.extend(mm_input["video_grid_thw"].tolist()) + if mm_input.get("second_per_grid_ts") is not None: + second_per_grid_ts.extend(mm_input["second_per_grid_ts"]) + if mm_input.get("audio_feature_lengths") is not None: + audio_feature_lengths.extend( + mm_input["audio_feature_lengths"] + ) + if mm_input.get("use_audio_in_video") is True: + use_audio_in_video = True + + hf_config = self.model_config.hf_config + + ( + self.requests[req_id].mrope_positions, + self.requests[req_id].mrope_position_delta, + ) = MRotaryEmbedding.get_input_positions_tensor( + self.requests[req_id].prompt_token_ids, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + + req_ids_to_add.append(req_id) + + # Update the states of the running/resumed requests. 
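+ # For sparsed requests (req_sparsed_slots[req_id] != INVALID_SLOT) the block-table row is rebuilt from scratch below, mirroring the resume-from-preemption path.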
+ is_last_rank = get_pp_group().is_last_rank + req_data = scheduler_output.scheduled_cached_reqs + req_sparsed_slots = scheduler_output.req_sparsed_slots + for i, req_id in enumerate(req_data.req_ids): + req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_data.resumed_from_preemption[i] + is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT + + # Update the cached states. + req_state.num_computed_tokens = num_computed_tokens + + if not is_last_rank: + # When using PP, the scheduler sends the sampled tokens back, + # because there's no direct communication between the first- + # stage worker and the last-stage worker. + new_token_ids = req_data.new_token_ids[i] + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec tokens. + num_new_tokens = ( + num_computed_tokens + len(new_token_ids) - req_state.num_tokens + ) + if num_new_tokens == 1: + # Avoid slicing list in most common case. + req_state.output_token_ids.append(new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + new_token_ids[-num_new_tokens:] + ) + + # Update the block IDs. + if resumed_from_preemption or is_sparsed_request: + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = new_block_ids + else: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): + block_ids.extend(new_ids) + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. + req_ids_to_add.append(req_id) + continue + + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = ( + num_computed_tokens + ) + if is_sparsed_request: + self.input_batch.block_table.reset_row(req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) + + # For the last rank, we don't need to update the token_ids_cpu + # because the sampled tokens are already cached. + if not is_last_rank: + # Add new_token_ids to token_ids_cpu. + start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + len(new_token_ids) + self.input_batch.token_ids_cpu[ + req_index, start_token_index:end_token_index + ] = new_token_ids + self.input_batch.num_tokens_no_spec[req_index] = end_token_index + # Add spec_token_ids to token_ids_cpu. + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, () + ) + if spec_token_ids: + start_index = end_token_index + end_token_index += len(spec_token_ids) + self.input_batch.token_ids_cpu[ + req_index, start_index:end_token_index + ] = spec_token_ids + # NOTE(woosuk): `num_tokens` here may include spec tokens. + self.input_batch.num_tokens[req_index] = end_token_index + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. 
+ for req_id in req_ids_to_add: + req_state = self.requests[req_id] + self.input_batch.add_request(req_state) + + # Condense the batched states if there are gaps left by removed requests + self.input_batch.condense() + # Allow attention backend to reorder the batch, potentially + self._may_reorder_batch(scheduler_output) + # Refresh batch metadata with any pending updates. + self.input_batch.refresh_metadata() + + GPUModelRunner._update_states = _update_states + + def _prepare_inputs( + self, + scheduler_output: "SchedulerOutput", + ) -> tuple[ + dict[str, Any], bool, torch.Tensor, Optional[SpecDecodeMetadata], np.ndarray + ]: + """ + :return: tuple[ + attn_metadata: layer-to-attention_metadata mapping, + attention_cuda_graphs: whether attention can run in cudagraph + logits_indices, spec_decode_metadata + ] + """ + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.input_batch.block_table.commit(num_reqs) + + # Get the number of scheduled tokens for each request. + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = max(tokens) + + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) + + # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] + # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) + + # Get positions. + positions_np = self.positions_np[:total_num_scheduled_tokens] + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np, + ) + + # Calculate M-RoPE positions. + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + self._calc_mrope_positions(scheduler_output) + + self.seq_lens_np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens + ) + + # TODO: improve performance, no `positions_np.copy()` + sparsed_positions = positions_np.copy() + req_sparsed_slots = scheduler_output.req_sparsed_slots + for req_id in self.input_batch.req_id_to_index: + is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT + req_index = self.input_batch.req_id_to_index[req_id] + offset = ( + 0 if req_index == 0 else cu_num_tokens[req_index - 1] + ) # TODO: support MTP + if is_sparsed_request: + sparsed_positions[offset] = req_sparsed_slots[req_id] - 1 + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = ( + positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] + ) + + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens], + ) + + # Calculate the slot mapping for each KV cache group. 
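+            # The block-table lookup below uses sparsed_positions rather than
+            # positions_np, so sparsed requests resolve to the slots recorded in
+            # req_sparsed_slots.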
+ for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups + ): + block_size = kv_cache_group_spec.kv_cache_spec.block_size + block_table: BlockTable = self.input_batch.block_table[ + kv_cache_group_id + ] + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. + block_table_indices = ( + req_indices * block_table.max_num_blocks_per_req + + sparsed_positions // block_size + ) + block_table_cpu = block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + block_offsets = sparsed_positions % block_size + np.add( + block_numbers * block_size, + block_offsets, + out=block_table.slot_mapping_np[:total_num_scheduled_tokens], + ) + + # Prepare the attention metadata. + self.query_start_loc_np[0] = 0 + self.query_start_loc_np[1 : num_reqs + 1] = cu_num_tokens + + for req_id in self.input_batch.req_id_to_index: + req_index = self.input_batch.req_id_to_index[req_id] + is_sparsed_request = ( + scheduler_output.req_sparsed_slots[req_id] != INVALID_SLOT + ) + if is_sparsed_request: + self.seq_lens_np[req_index] = scheduler_output.req_sparsed_slots[ + req_id + ] + + # Copy the tensors to the GPU. + self.input_ids[:total_num_scheduled_tokens].copy_( + self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True + ) + if self.uses_mrope: + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self.mrope_positions[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions_cpu[:, :total_num_scheduled_tokens], + non_blocking=True, + ) + else: + # Common case (1D positions) + self.positions_cpu[:total_num_scheduled_tokens] = torch.from_numpy( + positions_np[:total_num_scheduled_tokens] + ) + self.positions[:total_num_scheduled_tokens].copy_( + self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True + ) + + self.query_start_loc[: num_reqs + 1].copy_( + self.query_start_loc_cpu[: num_reqs + 1], non_blocking=True + ) + self.seq_lens[:num_reqs].copy_( + self.seq_lens_cpu[:num_reqs], non_blocking=True + ) + + # Fill unused with -1. Needed for reshape_and_cache + self.seq_lens[num_reqs:].fill_(0) + # Note: pad query_start_loc to be non-decreasing, as kernels + # like FlashAttention requires that + self.query_start_loc[num_reqs + 1 :].fill_( + self.query_start_loc_cpu[num_reqs].item() + ) + + query_start_loc = self.query_start_loc[: num_reqs + 1] + seq_lens = self.seq_lens[:num_reqs] + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, + seq_lens=seq_lens, + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + ) + + attn_metadata: dict[str, Any] = {} + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups + ): + + # Prepare for cascade attention if enabled & beneficial. 
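+                # (Sparsed requests: seq_lens were overridden above, so the
+                # metadata built below already reflects the sparsed KV length.)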
+ common_prefix_len = 0 + builder = self.attn_metadata_builders[kv_cache_group_id] + if self.cascade_attn_enabled: + common_prefix_len = self._compute_cascade_attn_prefix_len( + num_scheduled_tokens, + scheduler_output.num_common_prefix_blocks[kv_cache_group_id], + kv_cache_group_spec.kv_cache_spec, + builder, + ) + + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + ) + + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + attention_cuda_graphs = all( + b.can_run_in_cudagraph(common_attn_metadata) + for b in self.attn_metadata_builders + ) + + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + if not use_spec_decode: + # NOTE(woosuk): Due to chunked prefills, the batch may contain + # partial requests. While we should not sample any token + # from these partial requests, we do so for simplicity. + # We will ignore the sampled tokens from the partial requests. + # TODO: Support prompt logprobs. + logits_indices = query_start_loc[1:] - 1 + spec_decode_metadata = None + else: + # Get the number of draft tokens for each request. + # Iterate over the dictionary rather than all requests since not all + # requests have draft tokens. + num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) + for ( + req_id, + draft_token_ids, + ) in scheduler_output.scheduled_spec_decode_tokens.items(): + req_idx = self.input_batch.req_id_to_index[req_id] + num_draft_tokens[req_idx] = len(draft_token_ids) + + spec_decode_metadata = self._calc_spec_decode_metadata( + num_draft_tokens, cu_num_tokens + ) + logits_indices = spec_decode_metadata.logits_indices + + # Hot-Swap lora model + if self.lora_config: + self.set_active_loras(self.input_batch, num_scheduled_tokens) + + return ( + attn_metadata, + attention_cuda_graphs, + logits_indices, + spec_decode_metadata, + num_scheduled_tokens, + ) + + GPUModelRunner._prepare_inputs = _prepare_inputs + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[ModelRunnerOutput, IntermediateTensors]: + self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + + return self.kv_connector_no_forward(scheduler_output) + + # Prepare the decoder inputs. + ( + attn_metadata, + attention_cuda_graphs, + logits_indices, + spec_decode_metadata, + num_scheduled_tokens_np, + ) = self._prepare_inputs(scheduler_output) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + if ( + self.use_cuda_graph + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1] + ): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_scheduled_tokens + ) + else: + # Eager mode. 
+ # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if ( + self.compilation_config.pass_config.enable_sequence_parallelism + and tp_size > 1 + ): + num_input_tokens = round_up(num_scheduled_tokens, tp_size) + else: + num_input_tokens = num_scheduled_tokens + + # Padding for DP + num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) + num_input_tokens += num_pad + + # _prepare_inputs may reorder the batch, so we must gather multi + # modal outputs after that to ensure the correct order + if self.is_multimodal_model: + # Run the multimodal encoder if any. + self._execute_mm_encoder(scheduler_output) + mm_embeds = self._gather_mm_embeddings(scheduler_output) + else: + mm_embeds = [] + + if self.is_multimodal_model and get_pp_group().is_first_rank: + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. + input_ids = self.input_ids[:num_scheduled_tokens] + if mm_embeds: + inputs_embeds = self.model.get_input_embeddings( + input_ids, mm_embeds + ) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + # TODO(woosuk): Avoid the copy. Optimize. + self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) + inputs_embeds = self.inputs_embeds[:num_input_tokens] + input_ids = None + else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the CUDA graph. + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None + if self.uses_mrope: + positions = self.mrope_positions[:, :num_input_tokens] + else: + positions = self.positions[:num_input_tokens] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_input_tokens, intermediate_tensors, True + ) + + # Some attention backends only support CUDA Graphs in pure decode. + # If attention doesn't support CUDA Graphs for this batch, but we + # compiled with full CUDA graphs, we have to skip them entirely. + skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs + + # Run the model. + # Use persistent buffers for CUDA graphs. 
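+            # UCM hooks: sparse execution starts before the forward pass and is
+            # finished (together with the KV-save wait) right after it, all inside
+            # the same forward context.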
+ with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + skip_cuda_graphs=skip_cuda_graphs, + ): + self.maybe_setup_kv_connector(scheduler_output) + self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata) + + model_output = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + self.maybe_wait_for_kv_save() + self.maybe_execute_ucm_sparse_finished() + + finished_sending, finished_recving = self.get_finished_kv_transfers( + scheduler_output + ) + + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = model_output + else: + hidden_states = model_output + aux_hidden_states = None + + # Broadcast PP output for external_launcher (torchrun) + # to make sure we are synced across pp ranks + # TODO: Support overlapping mirco-batches + # https://github.com/vllm-project/vllm/issues/18019 + broadcast_pp_output = ( + self.parallel_config.distributed_executor_backend == "external_launcher" + and len(get_pp_group().ranks) > 0 + ) + if not get_pp_group().is_last_rank: + # For mid-pipeline stages, return the hidden states. + if not broadcast_pp_output: + return hidden_states + assert isinstance(hidden_states, IntermediateTensors) + get_pp_group().send_tensor_dict( + hidden_states.tensors, all_gather_group=get_tp_group() + ) + logits = None + else: + if self.input_batch.pooling_params: + return self._pool( + hidden_states, + num_scheduled_tokens, + num_scheduled_tokens_np, + finished_sending, + finished_recving, + ) + + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + if broadcast_pp_output: + model_output_broadcast_data = ( + { + "logits": logits.contiguous(), + } + if logits is not None + else {} + ) + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 + ) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + self.apply_grammar_bitmask(scheduler_output, logits) + + # Sample the next token and get logprobs if needed. + sampling_metadata = self.input_batch.sampling_metadata + if spec_decode_metadata is None: + sampler_output = self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + else: + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. + assert logits is not None + bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] + sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. Therefore, + # it is safe to update `target_logits` in place. 
+ target_logits = logits[spec_decode_metadata.target_logits_indices] + output_token_ids = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids + + num_nans_in_logits = {} + if envs.VLLM_COMPUTE_NANS_IN_LOGITS: + num_nans_in_logits = self._get_nans_in_logits(logits) + + # TODO(woosuk): The following loop can be slow since it iterates over + # the requests one by one. Optimize. + discard_sampled_tokens_req_indices = [] + for i, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + seq_len = ( + req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id] + ) + if seq_len < req_state.num_tokens: + # Ignore the sampled token for partial prefills. + # Rewind the generator state as if the token was not sampled. + # This relies on cuda-specific torch-internal impl details + generator = self.input_batch.generators.get(i) + if generator is not None: + generator.set_offset(generator.get_offset() - 4) + # Record the index of the request that should not be sampled, + # so that we could clear the sampled tokens before returning. + discard_sampled_tokens_req_indices.append(i) + + # NOTE: GPU -> CPU Sync happens here. + # Move as many CPU operations as possible before this sync point. + logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = ( + logprobs_tensors.tolists() if logprobs_tensors is not None else None + ) + + # Compute prompt logprobs if needed. + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[:num_scheduled_tokens], + scheduler_output, + ) + + # Get the valid generated tokens. + sampled_token_ids = sampler_output.sampled_token_ids + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = sampled_token_ids.tolist() + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + + # Cache the sampled tokens in the model runner, so that the scheduler + # doesn't need to send them back. + # NOTE(woosuk): As an exception, when using PP, the scheduler sends + # the sampled tokens back, because there's no direct communication + # between the first-stage worker and the last-stage worker. + for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): + if not sampled_ids: + continue + + start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + len(sampled_ids) + assert end_idx <= self.max_model_len, ( + "Sampled token IDs exceed the max model length. " + f"Total number of tokens: {end_idx} > max_model_len: " + f"{self.max_model_len}" + ) + + self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + req_id = self.input_batch.req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + if not self.speculative_config: + # Speculative decoding is not enabled. 
+ spec_token_ids = None + else: + spec_token_ids = self.propose_draft_token_ids( + scheduler_output, + valid_sampled_token_ids, + sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + attn_metadata, + ) + + # Clear KVConnector state after all KVs are generated. + if has_kv_transfer_group(): + get_kv_transfer_group().clear_connector_metadata() + + self.eplb_step() + + return ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=spec_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + finished_sending=finished_sending, + finished_recving=finished_recving, + num_nans_in_logits=num_nans_in_logits, + ) + + GPUModelRunner.execute_model = execute_model + + except ImportError: + logger.warning("Could not patch prepare inputs - module not found") + + +# ==================== vllm/v1/worker/gpu_worker.py ==================== +def _patch_gpu_worker() -> None: + """Patch gpu worker to add UCM sparse support.""" + try: + from typing import Optional + + from vllm.config import VllmConfig + from vllm.v1.worker import gpu_worker + + from ucm.sparse.state import ensure_ucm_sparse_initialized + + original_init_worker_distributed_environment = ( + gpu_worker.init_worker_distributed_environment + ) + + def patched_init_worker_distributed_environment( + vllm_config: VllmConfig, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, + backend: str = "nccl", + ) -> None: + original_init_worker_distributed_environment( + vllm_config, rank, distributed_init_method, local_rank, backend + ) + ensure_ucm_sparse_initialized(vllm_config) + + gpu_worker.init_worker_distributed_environment = ( + patched_init_worker_distributed_environment + ) + except ImportError: + logger.warning("Could not patch gpu worker - module not found") diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py index 421011ca2..c8317007b 100644 --- a/ucm/integration/vllm/uc_connector.py +++ b/ucm/integration/vllm/uc_connector.py @@ -25,9 +25,10 @@ # import hashlib import pickle +from collections import defaultdict from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union +from typing import TYPE_CHECKING, Any, List, Optional, Union import torch from vllm.config import VllmConfig @@ -43,6 +44,7 @@ from ucm.logger import init_logger from ucm.store.factory import UcmConnectorFactory from ucm.store.ucmstore import Task +from ucm.utils import Config if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -89,7 +91,7 @@ class UnifiedCacheConnectorV1(KVConnectorBase_V1): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) self.block_size = vllm_config.cache_config.block_size - self.use_layerwise = True + self.use_layerwise = False self.kv_caches: dict[str, torch.Tensor] = {} self.total_tp_size = vllm_config.parallel_config.tensor_parallel_size self.rank = ( @@ -98,36 +100,25 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self.request_block_infos: dict[str, RequestBlockInfo] = {} # dump tasks record request -> block -> list[task] self.dump_tasks: dict[str, dict[str, List[Task]]] = {} - self.layerwise_load_tasks: dict[str, dict[str, tuple[Task, Task]]] = {} + 
self.layerwise_load_tasks: dict[str, dict[str, Task]] = defaultdict(dict) self.is_mla = self._vllm_config.model_config.is_deepseek_mla self.num_layers = vllm_config.model_config.get_num_layers( vllm_config.parallel_config ) self.element_size = vllm_config.model_config.dtype.itemsize self.kv_role = vllm_config.kv_transfer_config.kv_role - self._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {} + self._need_load_reqs: dict[str, Union[list[int], Task]] = {} self._load_failed_reqs: set[str] = set() self._load_req_to_blocks: dict[str, set[int]] = {} self.num_head = vllm_config.model_config.get_num_kv_heads( vllm_config.parallel_config ) self.head_size = vllm_config.model_config.get_head_size() - if ( - self._vllm_config.kv_transfer_config is not None - and "ucm_connector_name" - in self._vllm_config.kv_transfer_config.kv_connector_extra_config - ): - name = self._vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_connector_name" - ] - config = {} - if ( - "ucm_connector_config" - in self._vllm_config.kv_transfer_config.kv_connector_extra_config - ): - config = self._vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_connector_config" - ] + ucm_config = Config(vllm_config.kv_transfer_config) + launch_config = ucm_config.get_config() + if "ucm_connector_name" in launch_config: + name = launch_config.get("ucm_connector_name") + config = launch_config.get("ucm_connector_config") or {} config["device"] = self.rank config["role"] = ( "scheduler" if role == KVConnectorRole.SCHEDULER else "worker" @@ -187,9 +178,9 @@ def DataOffset(self, kv_layer, rank, layer_id, is_v): kv_layer[1][0].numel() if not self.is_mla else 0 ) * elem_size # When tp > 1 layer_size = (k_min_data_block_size + v_min_data_block_size) * tp_size - layer_size = ( - k_min_data_block_size + v_min_data_block_size - ) * self.total_tp_size + layer_size = (k_min_data_block_size + v_min_data_block_size) * ( + self.total_tp_size if not self.is_mla else 1 + ) if is_v: # Offset of v = Offset of k + k_min_data_block_size return int( @@ -258,65 +249,49 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: if len(self.kv_caches) == 0: self._init_kv_caches_from_forward_context(forward_context) + if len(list(self.kv_caches.values())[0]) == 2: + self.is_mla = False self.layerwise_load_tasks.clear() self.current_layer = 0 + need_load_tasks: dict[str, Task] = {} for request in metadata.requests: if not request.load_blocks: continue storage_block_ids = [block[0] for block in request.load_blocks] vllm_block_ids = [block[1] for block in request.load_blocks] - blocks_len = len(storage_block_ids) self._load_req_to_blocks.setdefault(request.request_id, set()).update( vllm_block_ids ) + is_load_async = request.load_async + total_offsets = [] + total_tensors = [] + storage_block_ids = storage_block_ids * (1 if self.is_mla else 2) for layer_name, kv_layer in self.kv_caches.items(): tensors, offsets = self.get_tensor_and_offset_layerwise( vllm_block_ids, kv_layer, layer_name ) - k_task_id = self.connector.load( - storage_block_ids, offsets[:blocks_len], tensors[:blocks_len] - ) - v_task_id = None - if not self.is_mla: - v_task_id = self.connector.load( - storage_block_ids, - offsets[blocks_len:], - tensors[blocks_len:], - ) - if request.request_id not in self.layerwise_load_tasks: - self.layerwise_load_tasks[request.request_id] = {} - self.layerwise_load_tasks[request.request_id][layer_name] = ( - k_task_id, - v_task_id, + if self.use_layerwise and not is_load_async: + task_id = 
self.connector.load(storage_block_ids, offsets, tensors) + self.layerwise_load_tasks[request.request_id][layer_name] = task_id + continue + else: + total_offsets.extend(offsets) + total_tensors.extend(tensors) + if total_offsets and total_tensors: + storage_block_ids = storage_block_ids * self.num_layers + task_id = self.connector.load( + storage_block_ids, total_offsets, total_tensors ) - - if request.load_async and request.request_id in self.layerwise_load_tasks: - for _, (k_task, v_task) in self.layerwise_load_tasks[ - request.request_id - ].items(): - if request.request_id not in self._need_load_reqs: - self._need_load_reqs[request.request_id] = [] - self._need_load_reqs[request.request_id].append(k_task) - if not self.is_mla: - self._need_load_reqs[request.request_id].append(v_task) - self.layerwise_load_tasks.pop(request.request_id) - continue - - if ( - not self.use_layerwise - and request.request_id in self.layerwise_load_tasks - ): - for _, (k_task, v_task) in self.layerwise_load_tasks[ - request.request_id - ].items(): - if self.connector.wait(k_task) != 0: - self._load_failed_reqs.add(request.request_id) - break - if v_task and self.connector.wait(v_task) != 0: - self._load_failed_reqs.add(request.request_id) - break + if is_load_async: + self._need_load_reqs[request.request_id] = task_id + else: + need_load_tasks[request.request_id] = task_id + for req_id, task_id in need_load_tasks.items(): + if self.connector.wait(task_id) != 0: + self._load_failed_reqs.add(req_id) + logger.error(f"Failed to load blocks for req {req_id}") def wait_for_layer_load(self, layer_name: str) -> None: """ @@ -334,26 +309,19 @@ def wait_for_layer_load(self, layer_name: str) -> None: if self.layerwise_load_tasks: logger.debug(f"Waiting for layer {self.current_layer} to be loaded") - assert ( - self.current_layer < self.num_layers - ), "The current layer should be less than total layers!" 
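+        # Tolerate layer counters past the final layer and return quietly
+        # instead of asserting.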
+ if self.current_layer >= self.num_layers: + return + for request_id, layer_to_task in self.layerwise_load_tasks.items(): if request_id in self._load_failed_reqs: continue - k_task, v_task = layer_to_task[layer_name] - if self.connector.wait(k_task) != 0: + task = layer_to_task[layer_name] + if self.connector.wait(task) != 0: self._load_failed_reqs.add(request_id) logger.error( f"Failed to load block for request {request_id} on layer {layer_name}" ) continue - if not self.is_mla: - if self.connector.wait(v_task) != 0: - self._load_failed_reqs.add(request_id) - logger.error( - f"Failed to load block for request {request_id} on layer {layer_name}" - ) - continue logger.debug(f"Load tasks for {request_id} on layer {layer_name} finished.") def save_kv_layer( @@ -384,6 +352,9 @@ def save_kv_layer( if not self.use_layerwise: return + if self.current_layer > self.num_layers: + return + metadata = self._get_connector_metadata() assert isinstance(metadata, UCConnectorV1Metadata) @@ -407,6 +378,8 @@ def save_kv_layer( torch.npu.current_stream().synchronize() elif kv_layer[0].device.type == "cuda": torch.cuda.current_stream().synchronize() + elif kv_layer[0].device.type == "musa": + torch.musa.current_stream().synchronize() for block_id, offset, tensor in zip( storage_block_ids, offsets[:blocks_len], tensors[:blocks_len] @@ -434,6 +407,8 @@ def wait_for_save(self) -> Optional[dict[str, list[str]]]: """ if hasattr(self, "kv_role") and self.kv_role == "kv_consumer": return + if self.is_mla and self.rank != 0: + return # request id -> succeed dumped blocks success_dumped_blocks: dict[str, list[str]] = {} @@ -452,57 +427,57 @@ def wait_for_tasks(): self.dump_tasks.clear() return success_dumped_blocks if success_dumped_blocks else None + req_to_dump_blocks: dict[str, list[str]] = {} + need_dump_tasks: dict[str, Task] = {} for request in metadata.requests: if not request.dump_blocks: continue storage_block_ids = [block[0] for block in request.dump_blocks] vllm_block_ids = [block[1] for block in request.dump_blocks] - blocks_len = len(storage_block_ids) + req_to_dump_blocks[request.request_id] = storage_block_ids + total_offsets = [] + total_tensors = [] + total_block_ids = ( + storage_block_ids * (1 if self.is_mla else 2) * self.num_layers + ) for layer_name, kv_layer in self.kv_caches.items(): tensors, offsets = self.get_tensor_and_offset_layerwise( vllm_block_ids, kv_layer, layer_name ) - for block_id, offset, tensor in zip( - storage_block_ids, offsets[:blocks_len], tensors[:blocks_len] - ): - task = self.connector.dump([block_id], [offset], [tensor]) - self.dump_tasks.setdefault(request.request_id, {}).setdefault( - block_id, [] - ).append(task) - if not self.is_mla: - for block_id, offset, tensor in zip( - storage_block_ids, - offsets[blocks_len:], - tensors[blocks_len:], - ): - task = self.connector.dump([block_id], [offset], [tensor]) - self.dump_tasks.setdefault(request.request_id, {}).setdefault( - block_id, [] - ).append(task) - wait_for_tasks() - self.dump_tasks.clear() + total_offsets.extend(offsets) + total_tensors.extend(tensors) + task_id = self.connector.dump(total_block_ids, total_offsets, total_tensors) + need_dump_tasks[request.request_id] = task_id + + for req_id, task_id in need_dump_tasks.items(): + if self.connector.wait(task_id) != 0: + logger.error(f"Failed to dump blocks for req {request.request_id}") + else: + success_dumped_blocks[req_id] = req_to_dump_blocks[req_id] return success_dumped_blocks if success_dumped_blocks else None def get_finished(self, finished_req_ids: 
set[str]) -> tuple[set[str], set[str]]: """Get the finished recving and sending requests.""" done_recving: set[str] = set() - for req_id, tasks in self._need_load_reqs.items(): + for req_id, task in self._need_load_reqs.items(): if req_id in self._load_failed_reqs: + done_recving.add(req_id) continue - unfinished_tasks = [] - for task in tasks: - ret = self.connector.check(task) - if ret == -1: - unfinished_tasks.append(task) - continue - elif ret == 0 and self.connector.wait(task) == 0: - continue + ret, finish = self.connector.check(task) + if ret != 0: + logger.error( + f"Task {task} failed, check return {ret} for request {req_id}" + ) self._load_failed_reqs.add(req_id) - break - if not unfinished_tasks: - done_recving.add(req_id) - self._need_load_reqs[req_id] = unfinished_tasks + elif not finish: + continue + elif (wret := self.connector.wait(task)) != 0: + logger.error( + f"Task {task} failed, wait return {wret} for request {req_id}" + ) + self._load_failed_reqs.add(req_id) + done_recving.add(req_id) # remove the finished requests for req_id in list(done_recving): @@ -596,9 +571,9 @@ def hash_request_tokens( # TODO we will fix hole match later break logger.info( - f"\nnum_total_blocks: {len(block_hashes)}\n" - f"\nnum_lookup_hits on hbm: {start_position}\n" - f"\nnum_lookup_hits on storage except hbm: {num_lookup_hits}\n" + f"num_total_blocks: {len(block_hashes)}, " + f"num_lookup_hits on hbm: {start_position}, " + f"num_lookup_hits on storage except hbm: {num_lookup_hits}" ) # Load async when Decode instance need to load @@ -776,7 +751,8 @@ def request_finished( if cancel_blocks: logger.debug(f"commit {cancel_blocks} to False.") self.connector.commit(cancel_blocks, False) - request.succeed_dumped_blocks.clear() + if hasattr(request, "succeed_dumped_blocks"): + request.succeed_dumped_blocks.clear() return False, None def _extract_blocks( diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py new file mode 100644 index 000000000..e843a3e7e --- /dev/null +++ b/ucm/integration/vllm/ucm_connector.py @@ -0,0 +1,868 @@ +import hashlib +import itertools +import os +import pickle +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Callable, List, Optional + +import torch +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.distributed.parallel_state import get_tp_group, get_world_group +from vllm.platforms import current_platform +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import Request + +from ucm.logger import init_logger +from ucm.shared.metrics import ucmmonitor +from ucm.shared.metrics.observability import UCMStatsLogger +from ucm.store.factory import UcmConnectorFactory +from ucm.store.ucmstore import Task, UcmKVStoreBase +from ucm.utils import Config + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + +logger = init_logger(__name__) + + +@dataclass +class RequestMeta: + ucm_block_ids: list[str] = field(default_factory=list) + hbm_hit_block_num: int = 0 + # local_computed_block + external_computed_block + total_hit_block_num: int = 0 + num_token_ids: int = 0 + vllm_block_ids: list[int] = field(default_factory=list) + token_processed: int = 0 + + +@dataclass +class RequestDispatchMeta: + load_block_ids: tuple[ + 
list[str], list[int] + ] # [0] mean ucm_block_ids, [1] means vllm_block_ids + dump_block_ids: tuple[list[str], list[int]] + + +@dataclass +class UCMConnectorMetadata(KVConnectorMetadata): + request_meta: dict[str, RequestDispatchMeta] = field(default_factory=dict) + + +class RequestHasher: + """hash(md5) request to generate ucm block id""" + + _SEED_HASH = None + + def __init__(self, vllm_config, rank_id): + meta = f"{vllm_config.model_config.model}:{vllm_config.parallel_config.world_size}:{vllm_config.model_config.dtype}:{rank_id}" + self.meta_bytes = meta.encode("utf-8") + + if RequestHasher._SEED_HASH is None: + RequestHasher._SEED_HASH = self("UCM_HASH_SEED") + + def __call__(self, input_data) -> int: + if isinstance(input_data, str): + input_bytes = input_data.encode("utf-8") + else: + input_bytes = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL) + + h = hashlib.md5(self.meta_bytes + input_bytes) + return int.from_bytes(h.digest(), byteorder="big") + + +class UCMDirectConnector(KVConnectorBase_V1): + """ + This connector means synchronize: + load -> forward -> save + """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self.kv_caches: dict[str, torch.Tensor] = {} + self.local_rank = ( + -1 if role == KVConnectorRole.SCHEDULER else get_world_group().local_rank + ) + self.global_rank = self._vllm_config.parallel_config.rank + self.block_size = self._vllm_config.cache_config.block_size + self.is_mla = self._vllm_config.model_config.is_deepseek_mla + self.is_dsa = False + self.kv_cache_dtype: torch.dtype = None + + if current_platform.is_cuda_alike(): + logger.info("CUDA device is available.") + torch_dev = torch + dev_name = "cuda" + elif current_platform.device_type == "npu": + logger.info("NPU device is available.") + torch_dev = torch.npu + dev_name = "npu" + else: + raise RuntimeError("Unsupported device platform for UCMDirectConnector.") + + if self.local_rank >= 0: + self.device = torch_dev.device(f"{dev_name}:{self.local_rank}") + self._layer_offset_cache = {} + + self.store: UcmKVStoreBase + + if role == KVConnectorRole.SCHEDULER: + self.request_hasher = RequestHasher(vllm_config, 0) + else: + self.request_hasher = RequestHasher(vllm_config, self.global_rank) + + # save block info, avoid hash request twice, and track them until request finished + self.requests_meta: dict[str, RequestMeta] = {} + + ucm_config = Config(vllm_config.kv_transfer_config) + self.launch_config = ucm_config.get_config() + + self.load_only_first_rank: bool = ( + self.launch_config.get("load_only_first_rank", self.is_mla) and self.is_mla + ) + if self.load_only_first_rank: + if role == KVConnectorRole.WORKER: + self.group_coordinator = get_tp_group() + self.broadcast_fn = self.group_coordinator.broadcast + self.broadcast_stream = torch.cuda.Stream() + + logger.info(f"self.launch_config: {self.launch_config}") + connector_configs = self.launch_config.get("ucm_connectors", []) + assert len(connector_configs) > 0, "no storage connector name in config." 
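+        # Only the first entry of "ucm_connectors" is instantiated here; its
+        # kv_block_size and io_size are derived from the model config below.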
+ + name = connector_configs[0].get("ucm_connector_name") + config = connector_configs[0].get("ucm_connector_config") or {} + config["device"] = self.local_rank + config["role"] = "scheduler" if role == KVConnectorRole.SCHEDULER else "worker" + element_size = vllm_config.model_config.dtype.itemsize + single_head_dim = vllm_config.model_config.get_head_size() + num_head_per_tp = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config + ) + total_tp_size = vllm_config.parallel_config.tensor_parallel_size + num_layers = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config + ) + block_size_per_layer = self.block_size * element_size * single_head_dim + config["kv_block_size"] = ( + block_size_per_layer + * num_layers + * (1 if self.is_mla else num_head_per_tp * 2) + ) + config["io_size"] = block_size_per_layer * ( + 1 if self.is_mla else num_head_per_tp + ) + self.store = UcmConnectorFactory.create_connector(name, config) + self.block_data_size = config["kv_block_size"] + + logger.info("init UCConnectorImpl, connector: %s", name) + logger.info( + "single file size = %d MB, io_size = %d KB,", + config["kv_block_size"] / 1024 / 1024, + config["io_size"] / 1024, + ) + + self.metrics_config = self.launch_config.get("metrics_config_path", "") + if self.metrics_config: + self.stats_logger = UCMStatsLogger( + vllm_config.model_config.served_model_name, + self.global_rank, + self.metrics_config, + ) + self.monitor = ucmmonitor.StatsMonitor.get_instance() + self.synchronize = ( + torch.cuda.synchronize + if current_platform.is_cuda_alike() + else torch.npu.synchronize + ) + + def generate_hash(self, block_size: int, request: "Request") -> list[str]: + token_ids = request.all_token_ids + + ret = [] + parent_block_hash_value = RequestHasher._SEED_HASH + for start in range(0, len(token_ids), block_size): + end = start + block_size + block_token_ids = token_ids[start:end] + # Do not hash the block if it is not full. + if len(block_token_ids) < block_size: + break + + block_token_ids_tuple = tuple(block_token_ids) + hash_value = self.request_hasher( + (parent_block_hash_value, block_token_ids_tuple) + ) + parent_block_hash_value = hash_value + ret.append(str(hash_value)) + + return ret + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + assert num_computed_tokens % self.block_size == 0 + hbm_hit_block_num = num_computed_tokens // self.block_size + + ucm_block_ids = self.generate_hash(self.block_size, request) + + external_block_ids = ucm_block_ids[hbm_hit_block_num:] + if not external_block_ids: + return 0, False + + lookup_results = self.store.lookup(external_block_ids) + external_hit_blocks = 0 + for i, hit in enumerate(lookup_results): + if not hit: + break + external_hit_blocks += 1 + logger.info( + f"request_id: {request.request_id}, " + f"total_blocks_num: {len(ucm_block_ids)}, " + f"hit hbm: {hbm_hit_block_num}, " + f"hit external: {external_hit_blocks}" + ) + if self.metrics_config: + self.monitor.update_stats( + "ConnStats", + {"interval_lookup_hit_rates": external_hit_blocks / len(ucm_block_ids)}, + ) + + total_hit_block_num = hbm_hit_block_num + external_hit_blocks + + external_hit_tokens = external_hit_blocks * self.block_size + + # When all the tokens are cached in ssd or hbm, + # we need to recompute the last token. This if condition will be removed + # once vLLM scheduler provides a better solution in the future. 
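+        # e.g. with block_size=16 and a 32-token prompt fully cached, one token
+        # is deliberately reported as uncached so the scheduler still runs the
+        # model for it and logits are produced.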
+ num_total_hit_tokens = total_hit_block_num * self.block_size + if num_total_hit_tokens == request.num_tokens: + external_hit_tokens -= 1 + + self.requests_meta[request.request_id] = RequestMeta( + ucm_block_ids=ucm_block_ids, + hbm_hit_block_num=hbm_hit_block_num, + total_hit_block_num=total_hit_block_num, + num_token_ids=len(request.all_token_ids), + token_processed=num_total_hit_tokens, + ) + + return external_hit_tokens, False + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + pass + + def _generate_dispatch_meta( + self, + req_meta: RequestMeta, + new_tokens: int, + vllm_block_ids: list[int], + need_load: bool = True, + ) -> RequestDispatchMeta: + """ + Request Blocks layout: + ---------------------------------------------------------------------------------------------------- + | local_computed_block(HBM hit) | external_computed_block(external hit) | new_block(need to dump) | + ---------------------------------------------------------------------------------------------------- + | hbm_hit_block_num | LOAD | new_blocks_num | + ---------------------------------------------------------------------------------------------------- + | total_hit_block_num | + ---------------------------------------------------------------------------------------------------- + | scheduled_block_num | + """ + + hbm_hit_block_num = req_meta.hbm_hit_block_num + total_hit_block_num = req_meta.total_hit_block_num + ucm_block_ids = req_meta.ucm_block_ids + req_meta.vllm_block_ids.extend(vllm_block_ids) + + load_ucm_block_ids, load_vllm_block_ids = [], [] + dump_ucm_block_ids, dump_vllm_block_ids = [], [] + if need_load: + load_ucm_block_ids = ucm_block_ids[hbm_hit_block_num:total_hit_block_num] + load_vllm_block_ids = vllm_block_ids[hbm_hit_block_num:total_hit_block_num] + + if req_meta.token_processed < req_meta.num_token_ids: + start_idx = req_meta.token_processed // self.block_size + end_idx = (req_meta.token_processed + new_tokens) // self.block_size + dump_ucm_block_ids = ucm_block_ids[start_idx:end_idx] + dump_vllm_block_ids = req_meta.vllm_block_ids[start_idx:end_idx] + req_meta.token_processed += new_tokens + + return RequestDispatchMeta( + (load_ucm_block_ids, load_vllm_block_ids), + (dump_ucm_block_ids, dump_vllm_block_ids), + ) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + requests_dispatch_meta = {} + # for new request, we need to load and dump + for request in scheduler_output.scheduled_new_reqs: + request_id, vllm_block_ids = request.req_id, request.block_ids[0] + req_meta = self.requests_meta.get(request_id) + if req_meta: + requests_dispatch_meta[request_id] = self._generate_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + vllm_block_ids, + ) + + # for cached request, there are 3 situation: + # 1. chunked prefill: we only need dump + # 2. resumed: we need to handle like new request + # 3. 
TODO decode stage: nothing happened + scheduled_cached_reqs = scheduler_output.scheduled_cached_reqs + if not isinstance(scheduled_cached_reqs, list): + # >= 0.9.2 + for i, request_id in enumerate(scheduled_cached_reqs.req_ids): + req_meta = self.requests_meta.get(request_id) + if req_meta: + new_block_ids = [] + if scheduled_cached_reqs.new_block_ids[i] != None: + new_block_ids = scheduled_cached_reqs.new_block_ids[i][0] + requests_dispatch_meta[request_id] = self._generate_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + new_block_ids, + scheduled_cached_reqs.resumed_from_preemption[i], + ) + else: + for request in scheduled_cached_reqs: + request_id = request.req_id + req_meta = self.requests_meta.get(request_id) + if req_meta: + requests_dispatch_meta[request_id] = self._generate_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + request.new_block_ids[0], + request.resumed_from_preemption, + ) + + # clear finished request + for request_id in scheduler_output.finished_req_ids: + self.requests_meta.pop(request_id, None) + + return UCMConnectorMetadata(requests_dispatch_meta) + + def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"): + if len(self.kv_caches) > 0: + return + for layer_name in forward_context.no_compile_layers: + attn_layer = forward_context.no_compile_layers[layer_name] + if not hasattr(attn_layer, "kv_cache"): + continue + + if layer_name not in self.kv_caches: + self.kv_caches[layer_name] = attn_layer.kv_cache[ + forward_context.virtual_engine + ] + # Since vllm_ascend >= 0.10.0, the MLA model's tensor shape has changed to + # (2, num_blocks, block_size, num_kv_heads, nope_dim/rope_dim). + # Currently, we treat it as GQA, and use is_dsa to mark it, + # which works but leads to space inefficiency. + # TODO: Optimize this to avoid unnecessary space usage. + sample_kv_layer = next(iter(self.kv_caches.values())) + if self.is_mla and len(sample_kv_layer) == 2: + self.is_mla = False + self.is_dsa = True + if self.kv_cache_dtype is None: + self.kv_cache_dtype = sample_kv_layer[0].dtype + + @staticmethod + def _extract_layer_index(layer_name: str) -> Optional[int]: + """ + Extract the layer index from the layer name. 
+ """ + for chunk in layer_name.split("."): + if chunk.isdigit(): + return int(chunk) + return None + + def _precompute_layer_offsets(self): + if not self.kv_caches: + return + + sample_kv_layer = next(iter(self.kv_caches.values())) + elem_size = sample_kv_layer[0].element_size() + block_data_size = ( + sample_kv_layer[0].numel() if self.is_mla else sample_kv_layer[0][0].numel() + ) * elem_size + layer_data_size = block_data_size if self.is_mla else block_data_size * 2 + + # precompute all layers offset + for layer_name, _ in self.kv_caches.items(): + layer_id = self._extract_layer_index(layer_name) + assert layer_id is not None + k_offset = layer_data_size * layer_id + v_offset = k_offset + block_data_size if not self.is_mla else 0 + self._layer_offset_cache[layer_name] = (k_offset, v_offset) + + def _get_tensor_and_offset( + self, vllm_block_ids: list[int], kv_layer: torch.Tensor, layer_name: str + ) -> tuple[list[torch.Tensor], list[int]]: + """ + GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size) + MLA: one layer shape is (num_blocks, block_size, head_size) + """ + k_tensors, k_offsets = [], [] + v_tensors, v_offsets = [], [] + k_offset, v_offset = self._layer_offset_cache[layer_name] + + for vllm_block_id in vllm_block_ids: + k_tensors.append( + kv_layer[vllm_block_id] if self.is_mla else kv_layer[0][vllm_block_id] + ) + k_offsets.append(k_offset) + if not self.is_mla: + v_tensors.append(kv_layer[1][vllm_block_id]) + v_offsets.append(v_offset) + return k_tensors + v_tensors, k_offsets + v_offsets + + def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[str]): + if not self._layer_offset_cache: + self._precompute_layer_offsets() + + num_layers = len(self.kv_caches) + num_blocks_per_layer = len(vllm_block_ids) + num_tensors_per_layer = num_blocks_per_layer * (1 if self.is_mla else 2) + dst_tensor_addr = [None] * (num_layers * num_tensors_per_layer) + ucm_offsets = [0] * (num_layers * num_tensors_per_layer) + + idx = 0 + for layer_name, one_layer_kv_cache in self.kv_caches.items(): + tensors, offsets = self._get_tensor_and_offset( + vllm_block_ids, one_layer_kv_cache, layer_name + ) + dst_tensor_addr[idx : idx + len(tensors)] = tensors + ucm_offsets[idx : idx + len(offsets)] = offsets + idx += len(tensors) + + repeat_times = len(self.kv_caches) * (1 if self.is_mla else 2) + ucm_total_block_ids = ucm_block_ids * repeat_times + + assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr) + return ucm_total_block_ids, ucm_offsets, dst_tensor_addr + + def _broadcast(self, dst_tensor_addr: list[torch.Tensor]): + rec_tensor: torch.Tensor = None + with torch.cuda.stream(self.broadcast_stream): + # TODO support broadcast when PP + if self.global_rank == 0: + tensor_to_broadcast = torch.stack(dst_tensor_addr, dim=0) + self.broadcast_fn(tensor_to_broadcast, 0) + else: + shape = (len(dst_tensor_addr),) + dst_tensor_addr[0].shape + # TODO create earlier + rec_tensor = torch.empty( + shape, dtype=self.kv_cache_dtype, device=self.device + ) + self.broadcast_fn(rec_tensor, 0) + self.broadcast_stream.synchronize() + if self.global_rank != 0 and rec_tensor is not None: + for i, tensor in enumerate(dst_tensor_addr): + tensor.copy_(rec_tensor[i]) + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + metadata = self._get_connector_metadata() + assert isinstance(metadata, UCMConnectorMetadata) + + self._init_kv_caches_from_forward_context(forward_context) + + request_to_task: dict[str, Optional[Task]] = {} + 
req_broadcast_addr = {} + is_load = False + num_loaded_block = 0 + num_loaded_request = 0 + load_start_time = time.perf_counter() * 1000 + for request_id, request in metadata.request_meta.items(): + if len(request.load_block_ids[0]) == 0: + continue + is_load = True + num_loaded_block += len(request.load_block_ids[0]) + num_loaded_request += 1 + + ucm_block_ids, vllm_block_ids = request.load_block_ids + if self.global_rank != 0 and not self.is_mla and not self.is_dsa: + for i, ucm_block_id in enumerate(ucm_block_ids): + ucm_block_ids[i] = str(self.request_hasher(ucm_block_id)) + ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task( + vllm_block_ids, ucm_block_ids + ) + if self.global_rank == 0 or not self.load_only_first_rank: + request_to_task[request_id] = self.store.load( + ucm_total_block_ids, ucm_offsets, dst_tensor_addr + ) + else: + request_to_task[request_id] = None + req_broadcast_addr[request_id] = dst_tensor_addr + + for request_id, task in request_to_task.items(): + # TODO error handling + if self.global_rank == 0 or not self.load_only_first_rank: + if self.store.wait(task) != 0: + logger.error(f"request {request_id} load kv cache failed.") + if self.load_only_first_rank: + self._broadcast(req_broadcast_addr[request_id]) + load_end_time = time.perf_counter() * 1000 + load_speed = ( + num_loaded_block + * self.block_data_size + / (load_end_time - load_start_time) + / 1024 + / 1024 + ) # GB/s + if self.metrics_config and is_load: + self.monitor.update_stats( + "ConnStats", + { + "load_requests_num": num_loaded_request, + "load_blocks_num": num_loaded_block, + "load_duration": load_end_time - load_start_time, + "load_speed": load_speed, + }, + ) + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + pass + + def wait_for_save(self) -> None: + + # TODO support PP + if (self.is_mla or self.is_dsa) and self.global_rank != 0: + return + if self.metrics_config: + self.synchronize() + + metadata = self._get_connector_metadata() + assert isinstance(metadata, UCMConnectorMetadata) + + request_to_task: dict[str, Task] = {} + request_to_blocks: dict[str, list[str]] = {} + is_save = False + num_saved_block = 0 + num_saved_request = 0 + save_start_time = time.perf_counter() * 1000 + for request_id, request in metadata.request_meta.items(): + if len(request.dump_block_ids[0]) == 0: + continue + is_save = True + num_saved_block += len(request.dump_block_ids[0]) + num_saved_request += 1 + + ucm_block_ids, vllm_block_ids = request.dump_block_ids + if self.global_rank != 0: + for i, ucm_block_id in enumerate(ucm_block_ids): + ucm_block_ids[i] = str(self.request_hasher(ucm_block_id)) + rets = self.store.create(ucm_block_ids) + end = 0 + for i, ret in enumerate(rets): + if ret != 0: + logger.error( + f"create blocks for {request_id} failed, block index: {i}, ret code: {ret}" + ) + break + end += 1 + + if end == 0: + continue + ucm_block_ids = ucm_block_ids[:end] + vllm_block_ids = vllm_block_ids[:end] + ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task( + vllm_block_ids, ucm_block_ids + ) + request_to_task[request_id] = self.store.dump( + ucm_total_block_ids, ucm_offsets, dst_tensor_addr + ) + request_to_blocks[request_id] = ucm_block_ids + + for request_id, task in request_to_task.items(): + ucm_block_ids = request_to_blocks[request_id] + if self.store.wait(task) == 0: + self.store.commit(ucm_block_ids, 
True) + else: + logger.error(f"request {request_id} dump kv cache failed.") + self.store.commit(ucm_block_ids, False) + save_end_time = time.perf_counter() * 1000 + save_speed = ( + num_saved_block + * self.block_data_size + / (save_end_time - save_start_time) + / 1024 + / 1024 + ) # GB/s + if self.metrics_config and is_save: + self.monitor.update_stats( + "ConnStats", + { + "save_requests_num": num_saved_request, + "save_blocks_num": num_saved_block, + "save_duration": save_end_time - save_start_time, + "save_speed": save_speed, + }, + ) + + def clear_connector_metadata(self) -> None: + super().clear_connector_metadata() + + +class UCMLayerWiseConnector(UCMDirectConnector): + """ + This Connector means overlap: + load l0 -> forward l0 -> save l0 + load l1 -> forward l1 -> save l1 + load l2 -> forward l2 -> save l2 + """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config, role) + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + raise NotImplementedError + + def wait_for_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + raise NotImplementedError + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + raise NotImplementedError + + def wait_for_save(self) -> None: + raise NotImplementedError + + +class UCMPDConnector(UCMDirectConnector): + """ + This Connector means overlap (especially for Decode Instance): + step (req0,1,2) forward -> step (req0,1,2,3) forward + load req3 -> load req4 + """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config, role) + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + raise NotImplementedError + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + + Returns: + ids of requests that have finished asynchronous transfer + (requests that previously returned True from request_finished()), + tuple of (sending/saving ids, recving/loading ids). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + raise NotImplementedError + + +class UCMMockConnector(UCMDirectConnector): + """ + This Connector can control hit ratio, for example: if your hit ratio is 100%, + you can set "hit_ratio" by config or env_vars, then get_num_new_matched_tokens() + will reduce hit_tokens under the hit_ratio you set. 
+ """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config, role) + self._hit_ratio = float(self.launch_config["hit_ratio"]) + logger.info(f"hit_ratio: {self._hit_ratio}") + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + hit_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens) + expect_hit_tokens = int(self._hit_ratio * request.num_prompt_tokens) + if hit_tokens <= expect_hit_tokens: + return hit_tokens, False + expect_hit_block_num = expect_hit_tokens // self.block_size + request_meta = self.requests_meta[request.request_id] + request_meta.total_hit_block_num = expect_hit_block_num + request_meta.hbm_hit_block_num = min( + expect_hit_block_num, request_meta.hbm_hit_block_num + ) + + logger.info( + "Hijacked By MockConnector," + f"request_id: {request.request_id}, " + f"total_blocks_num: {len(request_meta.ucm_block_ids)}, " + f"hit hbm: {request_meta.hbm_hit_block_num}, " + f"hit external: {request_meta.total_hit_block_num - request_meta.hbm_hit_block_num}" + ) + + return expect_hit_block_num * self.block_size, False + + +class UCMConnector(KVConnectorBase_V1): + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self.connector: KVConnectorBase_V1 + # TODO new conn by config + if ( + self._vllm_config.kv_transfer_config is not None + and "hit_ratio" + in self._vllm_config.kv_transfer_config.kv_connector_extra_config + ): + self.connector = UCMMockConnector(vllm_config, role) + else: + self.connector = UCMDirectConnector(vllm_config, role) + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + return self.connector.get_num_new_matched_tokens(request, num_computed_tokens) + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + """ + Update KVConnector state after block allocation. + """ + self.connector.update_state_after_alloc(request, blocks, num_external_tokens) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + return self.connector.build_connector_meta(scheduler_output) + + def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: + """Set the connector metadata from the scheduler. + + This function should be called by the model runner every time + before the model execution. The metadata will be used for runtime + KV cache loading and saving. + + Args: + connector_metadata (dict): the connector metadata. + """ + self.connector.bind_connector_metadata(connector_metadata) + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. 
This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + self.connector.start_load_kv(forward_context, **kwargs) + + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + self.connector.wait_for_layer_load(layer_name) + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + """ + Start saving the a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + self.connector.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs) + + def wait_for_save(self) -> None: + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + self.connector.wait_for_save() + + def clear_connector_metadata(self) -> None: + """Clear the connector metadata. + + This function should be called by the model runner every time + after the model execution. + """ + self.connector.clear_connector_metadata() diff --git a/ucm/pd/toy_proxy_server.py b/ucm/pd/toy_proxy_server.py index 9b7f15799..99dc75825 100644 --- a/ucm/pd/toy_proxy_server.py +++ b/ucm/pd/toy_proxy_server.py @@ -17,53 +17,78 @@ @asynccontextmanager async def lifespan(app: FastAPI): """ - Lifespan context manager to handle startup and shutdown events. + Lifespan context manager to initialize clients based on mode. 
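Taken together, the delegating wrapper preserves the connector lifecycle that the docstrings describe. The sketch below reconstructs the per-step call order on the worker side from those docstrings; `connector`, `meta`, `forward_ctx`, and the layer loop are placeholders, not real vLLM objects.

```python
# Per-step call order implied by the KVConnectorBase_V1 docstrings above.
# All arguments are stand-ins for the real scheduler/runner objects.
def one_model_step(connector, meta, forward_ctx, layer_names):
    connector.bind_connector_metadata(meta)    # scheduler-built metadata, before execution
    connector.start_load_kv(forward_ctx)       # kick off async loads before the forward pass
    for name in layer_names:                   # inside the forward pass
        connector.wait_for_layer_load(name)    # block until this layer's KV has arrived
        # ... attention for `name` runs here ...
        connector.save_kv_layer(name, kv_layer=None, attn_metadata=None)
    connector.wait_for_save()                  # drain async saves as the forward context exits
    connector.clear_connector_metadata()       # reset per-step state after execution
```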
""" - # Startup: Initialize client pools for prefiller and decoder services app.state.prefill_clients = [] app.state.decode_clients = [] - - # Create prefill clients - for i, (host, port) in enumerate(global_args.prefiller_instances): - prefiller_base_url = f"http://{host}:{port}/v1" - app.state.prefill_clients.append( - { - "client": httpx.AsyncClient(timeout=None, base_url=prefiller_base_url), - "host": host, - "port": port, - "id": i, - } + app.state.worker_clients = [] # For PD-mixed workers + + if global_args.pd_disaggregation: + # === PD disaggregation === + for i, (host, port) in enumerate(global_args.prefiller_instances): + base_url = f"http://{host}:{port}/v1" + app.state.prefill_clients.append( + { + "client": httpx.AsyncClient(timeout=None, base_url=base_url), + "host": host, + "port": port, + "id": i, + } + ) + + for i, (host, port) in enumerate(global_args.decoder_instances): + base_url = f"http://{host}:{port}/v1" + app.state.decode_clients.append( + { + "client": httpx.AsyncClient(timeout=None, base_url=base_url), + "host": host, + "port": port, + "id": i, + } + ) + + app.state.prefill_iterator = itertools.cycle( + range(len(app.state.prefill_clients)) ) - - # Create decode clients - for i, (host, port) in enumerate(global_args.decoder_instances): - decoder_base_url = f"http://{host}:{port}/v1" - app.state.decode_clients.append( - { - "client": httpx.AsyncClient(timeout=None, base_url=decoder_base_url), - "host": host, - "port": port, - "id": i, - } + app.state.decode_iterator = itertools.cycle( + range(len(app.state.decode_clients)) ) - # Initialize round-robin iterators - app.state.prefill_iterator = itertools.cycle(range(len(app.state.prefill_clients))) - app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients))) + print( + f"[PD Mode] Initialized {len(app.state.prefill_clients)} prefillers " + f"and {len(app.state.decode_clients)} decoders." + ) - print( - f"Initialized {len(app.state.prefill_clients)} prefill clients " - f"and {len(app.state.decode_clients)} decode clients." - ) + else: + # === PD mix === + for i, (host, port) in enumerate(global_args.worker_instances): + base_url = f"http://{host}:{port}/v1" + app.state.worker_clients.append( + { + "client": httpx.AsyncClient(timeout=None, base_url=base_url), + "host": host, + "port": port, + "id": i, + } + ) + + app.state.worker_iterator = itertools.cycle( + range(len(app.state.worker_clients)) + ) + print( + f"[Mixed Mode] Initialized {len(app.state.worker_clients)} PD-mixed workers." 
+ ) yield - # Shutdown: Close all clients - for client_info in app.state.prefill_clients: - await client_info["client"].aclose() - - for client_info in app.state.decode_clients: - await client_info["client"].aclose() + # Close all clients + for client_list in [ + app.state.prefill_clients, + app.state.decode_clients, + app.state.worker_clients, + ]: + for client_info in client_list: + await client_info["client"].aclose() # Update FastAPI app initialization to use lifespan @@ -75,6 +100,26 @@ def parse_args(): parser.add_argument("--port", type=int, default=8000) parser.add_argument("--host", type=str, default="localhost") + parser.add_argument( + "--pd-disaggregation", + action="store_true", + help="Enable PD disaggregation mode (prefill and decode separation)", + ) + # For PD mix instances + parser.add_argument( + "--worker-hosts", + "--work-host", + type=str, + nargs="+", + default=["localhost"], + ) + parser.add_argument( + "--worker-ports", + "--work-port", + type=int, + nargs="+", + default=[8100], + ) # For prefiller instances parser.add_argument( @@ -107,9 +152,15 @@ def parse_args(): if len(args.decoder_hosts) != len(args.decoder_ports): raise ValueError("Number of decoder hosts must match number of decoder ports") - # Create tuples of (host, port) for each service type + if len(args.worker_hosts) != len(args.worker_ports): + raise ValueError("Number of worker hosts must match number of worker ports") + + # Create instance tuples args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports)) args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports)) + args.worker_instances = list( + zip(args.worker_hosts, args.worker_ports) + ) # Mixed workers return args @@ -120,12 +171,15 @@ def get_next_client(app, service_type: str): Args: app: The FastAPI app instance - service_type: Either 'prefill' or 'decode' + service_type: 'worker' 、'prefill' 、'decode' Returns: The next client to use """ - if service_type == "prefill": + if service_type == "worker": + worker_idx = next(app.state.worker_iterator) + return app.state.worker_clients[worker_idx] + elif service_type == "prefill": client_idx = next(app.state.prefill_iterator) return app.state.prefill_clients[client_idx] elif service_type == "decode": @@ -183,37 +237,72 @@ async def _handle_completions(api: str, request: Request): req_data = await request.json() request_id = str(uuid.uuid4()) - # Get the next prefill client in round-robin fashion - prefill_client_info = get_next_client(request.app, "prefill") - - # Send request to prefill service - response = await send_request_to_service( - prefill_client_info, api, req_data, request_id - ) - - # Extract the needed fields - response_json = response.json() - - # Get the next decode client in round-robin fashion - decode_client_info = get_next_client(request.app, "decode") - - logger.debug("Using %s %s", prefill_client_info, decode_client_info) - - # Stream response from decode service - async def generate_stream(): - async for chunk in stream_service_response( - decode_client_info, api, req_data, request_id=request_id - ): - yield chunk - - return StreamingResponse(generate_stream(), media_type="application/json") + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + } + + if global_args.pd_disaggregation: + # === PD disaggregation logic === + + # Step 1: Send request to prefiller (to trigger computation and cache KV) + prefill_client_info = get_next_client(request.app, "prefill") + prefill_req_data = 
req_data.copy() + prefill_req_data["stream"] = False + prefill_req_data["max_tokens"] = 1 + if "stream_options" in prefill_req_data: + del prefill_req_data["stream_options"] + + response = await prefill_client_info["client"].post( + api, json=prefill_req_data, headers=headers + ) + response.raise_for_status() + + # Step 2: Stream full output from decoder + decode_client_info = get_next_client(request.app, "decode") + + logger.debug( + "PD-DISAGG: Prefill=%s:%d, Decode=%s:%d", + prefill_client_info["host"], + prefill_client_info["port"], + decode_client_info["host"], + decode_client_info["port"], + ) + + async def generate_stream(): + async for chunk in stream_service_response( + decode_client_info, api, req_data, request_id + ): + yield chunk + + return StreamingResponse(generate_stream(), media_type="application/json") + + else: + # === PD mixed mode: Directly forward the entire stream using round-robin === + worker_client_info = get_next_client(request.app, "worker") + + logger.debug( + "PD-MIXED: Forwarding to %s:%d", + worker_client_info["host"], + worker_client_info["port"], + ) + + async def generate_stream(): + async with worker_client_info["client"].stream( + "POST", api, json=req_data, headers=headers + ) as resp: + resp.raise_for_status() + async for chunk in resp.aiter_bytes(): + yield chunk + + return StreamingResponse(generate_stream(), media_type="application/json") except Exception as e: import sys import traceback exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server" f" - {api} endpoint") + print(f"Error in proxy server - {api} endpoint") print(e) print("".join(traceback.format_exception(*exc_info))) raise @@ -231,12 +320,19 @@ async def handle_chat_completions(request: Request): @app.get("/healthcheck") async def healthcheck(): - """Simple endpoint to check if the server is running.""" - return { - "status": "ok", - "prefill_instances": len(app.state.prefill_clients), - "decode_instances": len(app.state.decode_clients), - } + if global_args.pd_disaggregation: + return { + "status": "ok", + "mode": "pd-disaggregation", + "prefill_instances": len(app.state.prefill_clients), + "decode_instances": len(app.state.decode_clients), + } + else: + return { + "status": "ok", + "mode": "pd-mixed", + "worker_instances": len(app.state.worker_clients), + } if __name__ == "__main__": diff --git a/ucm/shared/CMakeLists.txt b/ucm/shared/CMakeLists.txt new file mode 100644 index 000000000..1f73d1e8c --- /dev/null +++ b/ucm/shared/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(vendor) +add_subdirectory(infra) +add_subdirectory(trans) +add_subdirectory(metrics) +add_subdirectory(test) diff --git a/ucm/shared/__init__.py b/ucm/shared/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ucm/shared/infra/CMakeLists.txt b/ucm/shared/infra/CMakeLists.txt new file mode 100644 index 000000000..ba4345dce --- /dev/null +++ b/ucm/shared/infra/CMakeLists.txt @@ -0,0 +1,22 @@ +file(GLOB_RECURSE UCMINFRA_STATUS_SOURCE_FILES "status/*.*") +add_library(infra_status OBJECT ${UCMINFRA_STATUS_SOURCE_FILES}) +target_include_directories(infra_status PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_libraries(infra_status PUBLIC fmt) + +file(GLOB UCMINFRA_LOGGER_SOURCE_FILES "logger/*.*") +file(GLOB_RECURSE UCMINFRA_LOGGER_DETAIL_SOURCE_FILES "logger/${LOGGER_BACKEND}/*.cc") +add_library(infra_logger OBJECT ${UCMINFRA_LOGGER_SOURCE_FILES} ${UCMINFRA_LOGGER_DETAIL_SOURCE_FILES}) +target_include_directories(infra_logger PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) 
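For clarity, the proxy's disaggregated branch above boils down to "prefill once with `max_tokens=1`, then stream everything from a decoder". A self-contained sketch of that flow with httpx is shown below; the hosts, ports, and `/completions` path are placeholders, not values from this repository.

```python
# Standalone sketch of the PD-disaggregation flow in _handle_completions:
# the prefiller only populates the KV cache, the decoder produces the stream.
import httpx

async def disaggregated_completion(req_data: dict, request_id: str):
    prefill_req = dict(req_data, stream=False, max_tokens=1)
    prefill_req.pop("stream_options", None)
    headers = {"X-Request-Id": request_id}
    async with httpx.AsyncClient(base_url="http://prefill-0:8100/v1") as prefill, \
               httpx.AsyncClient(base_url="http://decode-0:8200/v1") as decode:
        resp = await prefill.post("/completions", json=prefill_req, headers=headers)
        resp.raise_for_status()                       # KV cache is now populated
        async with decode.stream("POST", "/completions",
                                 json=req_data, headers=headers) as stream:
            async for chunk in stream.aiter_bytes():  # forward decoder output as-is
                yield chunk
```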
+target_link_libraries(infra_logger PUBLIC fmt spdlog) + +file(GLOB_RECURSE UCMINFRA_TEMPLATE_SOURCE_FILES "template/*.*") +add_library(infra_template OBJECT ${UCMINFRA_TEMPLATE_SOURCE_FILES}) +target_include_directories(infra_template PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +file(GLOB_RECURSE UCMINFRA_THREAD_SOURCE_FILES "thread/*.*") +add_library(infra_thread OBJECT ${UCMINFRA_THREAD_SOURCE_FILES}) +target_include_directories(infra_thread PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +file(GLOB_RECURSE UCMINFRA_TIME_SOURCE_FILES "time/*.*") +add_library(infra_time OBJECT ${UCMINFRA_TIME_SOURCE_FILES}) +target_include_directories(infra_time PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/ucm/store/infra/logger/flux/flux_logger.cc b/ucm/shared/infra/logger/flux/flux_logger.cc similarity index 100% rename from ucm/store/infra/logger/flux/flux_logger.cc rename to ucm/shared/infra/logger/flux/flux_logger.cc diff --git a/ucm/store/infra/logger/logger.h b/ucm/shared/infra/logger/logger.h similarity index 97% rename from ucm/store/infra/logger/logger.h rename to ucm/shared/infra/logger/logger.h index f27dd23df..516b9e663 100644 --- a/ucm/store/infra/logger/logger.h +++ b/ucm/shared/infra/logger/logger.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_LOGGER_H -#define UNIFIEDCACHE_LOGGER_H +#ifndef UNIFIEDCACHE_INFRA_LOGGER_H +#define UNIFIEDCACHE_INFRA_LOGGER_H #include #include diff --git a/ucm/store/infra/logger/spdlog/spdlog_logger.cc b/ucm/shared/infra/logger/spdlog/spdlog_logger.cc similarity index 100% rename from ucm/store/infra/logger/spdlog/spdlog_logger.cc rename to ucm/shared/infra/logger/spdlog/spdlog_logger.cc diff --git a/ucm/shared/infra/status/status.h b/ucm/shared/infra/status/status.h new file mode 100644 index 000000000..3711de842 --- /dev/null +++ b/ucm/shared/infra/status/status.h @@ -0,0 +1,90 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#ifndef UNIFIEDCACHE_INFRA_STATUS_H +#define UNIFIEDCACHE_INFRA_STATUS_H + +#include +#include +#include + +namespace UC { + +template +static inline constexpr int32_t __MakeStatusCode() +{ + return -50000 - i; +} + +class Status { + static constexpr int32_t OK_ = 0; + static constexpr int32_t ERROR_ = -1; + static constexpr int32_t EPARAM_ = __MakeStatusCode<0>(); + static constexpr int32_t EOOM_ = __MakeStatusCode<1>(); + static constexpr int32_t EOSERROR_ = __MakeStatusCode<2>(); + static constexpr int32_t EDUPLICATE_ = __MakeStatusCode<3>(); + static constexpr int32_t ERETRY_ = __MakeStatusCode<4>(); + static constexpr int32_t ENOOBJ_ = __MakeStatusCode<5>(); + static constexpr int32_t ESERIALIZE_ = __MakeStatusCode<6>(); + static constexpr int32_t EDESERIALIZE_ = __MakeStatusCode<7>(); + static constexpr int32_t EUNSUPPORTED_ = __MakeStatusCode<8>(); + static constexpr int32_t ENOSPACE_ = __MakeStatusCode<9>(); + int32_t code_; + std::string message_; + explicit Status(int32_t code) : code_(code) {} + +public: + bool operator==(const Status& other) const noexcept { return code_ == other.code_; } + bool operator!=(const Status& other) const noexcept { return !(*this == other); } + int32_t Underlying() const { return code_; } + std::string ToString() const + { + auto str = std::to_string(code_); + if (message_.empty()) { return str; } + return fmt::format("{}, {}", str, message_); + } + constexpr bool Success() const noexcept { return code_ == OK_; } + constexpr bool Failure() const noexcept { return !Success(); } + +public: + Status(int32_t code, std::string message) : code_{code}, message_{std::move(message)} {} + static Status OK() { return Status{OK_}; } + static Status Error(std::string message) { return {ERROR_, std::move(message)}; } + static Status Error() { return Status{ERROR_}; } + static Status InvalidParam() { return Status{EPARAM_}; } + static Status OutOfMemory() { return Status{EOOM_}; } + static Status OsApiError() { return Status{EOSERROR_}; } + static Status DuplicateKey() { return Status{EDUPLICATE_}; } + static Status Retry() { return Status{ERETRY_}; } + static Status NotFound() { return Status{ENOOBJ_}; } + static Status SerializeFailed() { return Status{ESERIALIZE_}; } + static Status DeserializeFailed() { return Status{EDESERIALIZE_}; } + static Status Unsupported() { return Status{EUNSUPPORTED_}; } + static Status NoSpace() { return Status{ENOSPACE_}; } +}; + +inline std::string format_as(const Status& status) { return status.ToString(); } + +} // namespace UC + +#endif diff --git a/ucm/store/infra/template/hashset.h b/ucm/shared/infra/template/hashset.h similarity index 98% rename from ucm/store/infra/template/hashset.h rename to ucm/shared/infra/template/hashset.h index b09692bc1..102f69b62 100644 --- a/ucm/store/infra/template/hashset.h +++ b/ucm/shared/infra/template/hashset.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
* */ -#ifndef UNIFIEDCACHE_HASHSET_H -#define UNIFIEDCACHE_HASHSET_H +#ifndef UNIFIEDCACHE_INFRA_HASHSET_H +#define UNIFIEDCACHE_INFRA_HASHSET_H #include #include diff --git a/ucm/store/infra/template/singleton.h b/ucm/shared/infra/template/singleton.h similarity index 94% rename from ucm/store/infra/template/singleton.h rename to ucm/shared/infra/template/singleton.h index fda4957b5..f667288ee 100644 --- a/ucm/store/infra/template/singleton.h +++ b/ucm/shared/infra/template/singleton.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_SINGLETON_H -#define UNIFIEDCACHE_SINGLETON_H +#ifndef UNIFIEDCACHE_INFRA_SINGLETON_H +#define UNIFIEDCACHE_INFRA_SINGLETON_H namespace UC { diff --git a/ucm/shared/infra/template/timer.h b/ucm/shared/infra/template/timer.h new file mode 100644 index 000000000..0c9db149d --- /dev/null +++ b/ucm/shared/infra/template/timer.h @@ -0,0 +1,90 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#ifndef UNIFIEDCACHE_INFRA_TIMER_H +#define UNIFIEDCACHE_INFRA_TIMER_H + +#include +#include +#include +#include +#include + +namespace UC { + +template +class Timer { +public: + Timer(const std::chrono::seconds& interval, Callable&& callable) + : interval_(interval), callable_(callable), running_(false) + { + } + ~Timer() + { + { + std::lock_guard lg(this->mutex_); + this->running_ = false; + this->cv_.notify_one(); + } + if (this->thread_.joinable()) { this->thread_.join(); } + } + bool Start() + { + { + std::lock_guard lg(this->mutex_); + if (this->running_) { return true; } + } + try { + this->running_ = true; + this->thread_ = std::thread(&Timer::Runner, this); + return true; + } catch (...) 
{ + return false; + } + } + +private: + void Runner() + { + while (this->running_) { + { + std::unique_lock lg(this->mutex_); + this->cv_.wait_for(lg, this->interval_, [this] { return !this->running_; }); + if (!this->running_) { break; } + } + this->callable_(); + } + } + +private: + std::chrono::seconds interval_; + Callable callable_; + std::thread thread_; + std::mutex mutex_; + std::condition_variable cv_; + std::atomic running_; +}; + +} // namespace UC + +#endif diff --git a/ucm/shared/infra/template/topn_heap.h b/ucm/shared/infra/template/topn_heap.h new file mode 100644 index 000000000..737d0b19a --- /dev/null +++ b/ucm/shared/infra/template/topn_heap.h @@ -0,0 +1,121 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ + +#ifndef UNIFIEDCACHE_INFRA_TOP_N_HEAP_H +#define UNIFIEDCACHE_INFRA_TOP_N_HEAP_H + +#include +#include +#include + +namespace UC { + +template > +class TopNHeap { +public: + using ValueType = T; + using SizeType = uint32_t; + using ConstRef = const T&; + +private: + using IndexType = uint32_t; + std::vector val_{}; + std::vector idx_{}; + SizeType capacity_{0}; + SizeType size_{0}; + Compare cmp_{}; + +public: + explicit TopNHeap(const SizeType capacity) noexcept( + std::is_nothrow_default_constructible_v) + : capacity_{capacity} { + val_.resize(capacity); + idx_.resize(capacity); + } + TopNHeap(const TopNHeap&) = delete; + TopNHeap(const TopNHeap&&) = delete; + TopNHeap& operator=(const TopNHeap&) = delete; + TopNHeap& operator=(const TopNHeap&&) = delete; + ~TopNHeap() { Clear(); } + + SizeType Size() const noexcept { return size_; } + SizeType Capacity() const noexcept { return capacity_; } + bool Empty() const noexcept { return size_ == 0; } + + void Push(ConstRef value) noexcept { + if (size_ < capacity_) { + val_[size_] = value; + idx_[size_] = size_; + SiftUp(size_); + size_++; + return; + } + if (cmp_(val_[idx_.front()], value)) { + val_[idx_.front()] = value; + SiftDown(0); + } + } + ConstRef Top() const noexcept { return val_[idx_.front()]; } + void Pop() noexcept { + idx_[0] = idx_[--size_]; + if (size_) { SiftDown(0); } + } + void Clear() noexcept { size_ = 0; } +private: + static IndexType Parent(IndexType i) noexcept { return (i - 1) / 2; } + static IndexType Left(IndexType i) noexcept { return 2 * i + 1; } + static IndexType Right(IndexType i) noexcept { return 2 * i + 2; } + void SiftUp(IndexType i) noexcept { + auto pos = i; + while (pos > 0) { + auto p = Parent(pos); + if (!cmp_(val_[idx_[pos]], val_[idx_[p]])) { break; } + std::swap(idx_[pos], idx_[p]); + pos = p; + } + } + void SiftDown(IndexType i) noexcept { + auto pos = i; + for (;;) { + auto l = Left(pos); + auto r = Right(pos); + auto best = pos; + if (l < size_ && cmp_(val_[idx_[l]], val_[idx_[best]])) { best = l; } + if (r < size_ && cmp_(val_[idx_[r]], val_[idx_[best]])) { best = r; } + if (best == pos) { break; } + std::swap(idx_[pos], idx_[best]); + pos = best; + } + } +}; + +template > +class TopNFixedHeap : public TopNHeap { +public: + TopNFixedHeap() : TopNHeap{N} {} +}; + +} // namespace UC + +#endif diff --git a/ucm/store/infra/thread/index_pool.h b/ucm/shared/infra/thread/index_pool.h similarity index 97% rename from ucm/store/infra/thread/index_pool.h rename to ucm/shared/infra/thread/index_pool.h index 225ee8842..4217b7a0e 100644 --- a/ucm/store/infra/thread/index_pool.h +++ b/ucm/shared/infra/thread/index_pool.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_INDEX_POOL_H -#define UNIFIEDCACHE_INDEX_POOL_H +#ifndef UNIFIEDCACHE_INFRA_INDEX_POOL_H +#define UNIFIEDCACHE_INFRA_INDEX_POOL_H #include #include diff --git a/ucm/store/infra/thread/latch.h b/ucm/shared/infra/thread/latch.h similarity index 68% rename from ucm/store/infra/thread/latch.h rename to ucm/shared/infra/thread/latch.h index 5d1ccf2bc..fb1dcf583 100644 --- a/ucm/store/infra/thread/latch.h +++ b/ucm/shared/infra/thread/latch.h @@ -21,11 +21,12 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
* */ -#ifndef UNIFIEDCACHE_LATCH_H -#define UNIFIEDCACHE_LATCH_H +#ifndef UNIFIEDCACHE_INFRA_LATCH_H +#define UNIFIEDCACHE_INFRA_LATCH_H #include #include +#include #include namespace UC { @@ -34,8 +35,22 @@ class Latch { public: explicit Latch(const size_t expected = 0) : counter_{expected} {} void Up() { ++this->counter_; } - size_t Done() { return --this->counter_; } - void Notify() { this->cv_.notify_all(); } + void Done(std::function finish) noexcept + { + auto counter = this->counter_.load(std::memory_order_acquire); + while (counter > 0) { + auto desired = counter - 1; + if (this->counter_.compare_exchange_weak(counter, desired, std::memory_order_acq_rel)) { + if (desired == 0) { + if (finish) { finish(); } + std::lock_guard lg(this->mutex_); + this->cv_.notify_all(); + } + return; + } + counter = this->counter_.load(std::memory_order_acquire); + } + } void Wait() { std::unique_lock lk(this->mutex_); @@ -51,4 +66,4 @@ class Latch { } // namespace UC -#endif // UNIFIEDCACHE_LATCH_H +#endif // UNIFIEDCACHE_INFRA_LATCH_H diff --git a/ucm/store/infra/thread/thread_pool.h b/ucm/shared/infra/thread/thread_pool.h similarity index 70% rename from ucm/store/infra/thread/thread_pool.h rename to ucm/shared/infra/thread/thread_pool.h index b212f9e0f..baa514ed7 100644 --- a/ucm/store/infra/thread/thread_pool.h +++ b/ucm/shared/infra/thread/thread_pool.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_THREAD_POOL_H -#define UNIFIEDCACHE_THREAD_POOL_H +#ifndef UNIFIEDCACHE_INFRA_THREAD_POOL_H +#define UNIFIEDCACHE_INFRA_THREAD_POOL_H #include #include @@ -33,11 +33,11 @@ namespace UC { -template +template class ThreadPool { - using WorkerInitFn = std::function; - using WorkerFn = std::function; - using WorkerExitFn = std::function; + using WorkerInitFn = std::function; + using WorkerFn = std::function; + using WorkerExitFn = std::function; public: ThreadPool() = default; @@ -54,14 +54,31 @@ class ThreadPool { if (w.joinable()) { w.join(); } } } - bool Setup( - WorkerFn&& fn, WorkerInitFn&& initFn = [] { return true; }, WorkerExitFn&& exitFn = [] {}, - const size_t nWorker = 1) noexcept + ThreadPool& SetWorkerFn(WorkerFn&& fn) { - this->initFn_ = initFn; - this->fn_ = fn; - this->exitFn_ = exitFn; - std::list> start(nWorker); + this->fn_ = std::move(fn); + return *this; + } + ThreadPool& SetWorkerInitFn(WorkerInitFn&& fn) + { + this->initFn_ = std::move(fn); + return *this; + } + ThreadPool& SetWorkerExitFn(WorkerExitFn&& fn) + { + this->exitFn_ = std::move(fn); + return *this; + } + ThreadPool& SetNWorker(const size_t nWorker) + { + this->nWorker_ = nWorker; + return *this; + } + bool Run() + { + if (this->nWorker_ == 0) { return false; } + if (!this->fn_) { return false; } + std::list> start(this->nWorker_); std::list> fut; for (auto& s : start) { fut.push_back(s.get_future()); @@ -82,31 +99,33 @@ class ThreadPool { void Push(Task&& task) noexcept { std::unique_lock lk(this->mtx_); - this->taskQ_.push_back(task); + this->taskQ_.push_back(std::move(task)); this->cv_.notify_one(); } private: void Worker(std::promise& started) noexcept { - auto success = this->initFn_(); + WorkerArgs args = nullptr; + auto success = true; + if (this->initFn_) { success = this->initFn_(args); } started.set_value(success); while (success) { - Task task; std::unique_lock lk(this->mtx_); this->cv_.wait(lk, [this] { return this->stop_ || !this->taskQ_.empty(); }); if (this->stop_) { break; } if (this->taskQ_.empty()) { continue; } - 
task = std::move(this->taskQ_.front()); + auto task = std::make_shared(std::move(this->taskQ_.front())); this->taskQ_.pop_front(); lk.unlock(); - this->fn_(task); + this->fn_(*task, args); } - this->exitFn_(); + if (this->exitFn_) { this->exitFn_(args); } } private: bool stop_{false}; + size_t nWorker_{1}; std::list workers_; WorkerInitFn initFn_; WorkerFn fn_; diff --git a/ucm/store/infra/time/stopwatch.h b/ucm/shared/infra/time/stopwatch.h similarity index 95% rename from ucm/store/infra/time/stopwatch.h rename to ucm/shared/infra/time/stopwatch.h index 2386f394b..c2a5bb331 100644 --- a/ucm/store/infra/time/stopwatch.h +++ b/ucm/shared/infra/time/stopwatch.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_STOPWATCH_H -#define UNIFIEDCACHE_STOPWATCH_H +#ifndef UNIFIEDCACHE_INFRA_STOPWATCH_H +#define UNIFIEDCACHE_INFRA_STOPWATCH_H #include diff --git a/ucm/shared/metrics/CMakeLists.txt b/ucm/shared/metrics/CMakeLists.txt new file mode 100644 index 000000000..3933b9f0b --- /dev/null +++ b/ucm/shared/metrics/CMakeLists.txt @@ -0,0 +1,15 @@ +file(GLOB_RECURSE CORE_SRCS CONFIGURE_DEPENDS + "${CMAKE_CURRENT_SOURCE_DIR}/cc/stats/*.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/cc/*.cc") +add_library(monitor_static STATIC ${CORE_SRCS}) +set_property(TARGET monitor_static PROPERTY POSITION_INDEPENDENT_CODE ON) +target_include_directories(monitor_static PUBLIC + $ + $) +set_target_properties(monitor_static PROPERTIES OUTPUT_NAME monitor) + +file(GLOB_RECURSE BINDINGS_SRCS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/cpy/*.cc") +pybind11_add_module(ucmmonitor ${BINDINGS_SRCS}) +target_link_libraries(ucmmonitor PRIVATE -Wl,--whole-archive monitor_static -Wl,--no-whole-archive) +target_include_directories(ucmmonitor PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cc) +set_target_properties(ucmmonitor PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) \ No newline at end of file diff --git a/ucm/shared/metrics/cc/stats/conn_stats.cc b/ucm/shared/metrics/cc/stats/conn_stats.cc new file mode 100644 index 000000000..edf18ac2e --- /dev/null +++ b/ucm/shared/metrics/cc/stats/conn_stats.cc @@ -0,0 +1,83 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include "conn_stats.h" + +namespace UC::Metrics { + +ConnStats::ConnStats() = default; + +std::string ConnStats::Name() const { return "ConnStats"; } + +void ConnStats::Reset() +{ + for (auto& v : data_) v.clear(); +} + +void ConnStats::Update(const std::unordered_map& params) +{ + for (const auto& [k, v] : params) { + Key id = KeyFromString(k); + if (id == Key::COUNT) continue; + EmplaceBack(id, v); + } +} + +std::unordered_map> ConnStats::Data() +{ + std::unordered_map> result; + result["save_requests_num"] = data_[static_cast(Key::save_requests_num)]; + result["save_blocks_num"] = data_[static_cast(Key::save_blocks_num)]; + result["save_duration"] = data_[static_cast(Key::save_duration)]; + result["save_speed"] = data_[static_cast(Key::save_speed)]; + result["load_requests_num"] = data_[static_cast(Key::load_requests_num)]; + result["load_blocks_num"] = data_[static_cast(Key::load_blocks_num)]; + result["load_duration"] = data_[static_cast(Key::load_duration)]; + result["load_speed"] = data_[static_cast(Key::load_speed)]; + result["interval_lookup_hit_rates"] = + data_[static_cast(Key::interval_lookup_hit_rates)]; + return result; +} + +Key ConnStats::KeyFromString(const std::string& k) +{ + if (k == "save_requests_num") return Key::save_requests_num; + if (k == "save_blocks_num") return Key::save_blocks_num; + if (k == "save_duration") return Key::save_duration; + if (k == "save_speed") return Key::save_speed; + if (k == "load_requests_num") return Key::load_requests_num; + if (k == "load_blocks_num") return Key::load_blocks_num; + if (k == "load_duration") return Key::load_duration; + if (k == "load_speed") return Key::load_speed; + if (k == "interval_lookup_hit_rates") return Key::interval_lookup_hit_rates; + return Key::COUNT; +} + +void ConnStats::EmplaceBack(Key id, double value) +{ + data_[static_cast(id)].push_back(value); +} + +static Registrar registrar; + +} // namespace UC::Metrics \ No newline at end of file diff --git a/ucm/shared/metrics/cc/stats/conn_stats.h b/ucm/shared/metrics/cc/stats/conn_stats.h new file mode 100644 index 000000000..e8cc94559 --- /dev/null +++ b/ucm/shared/metrics/cc/stats/conn_stats.h @@ -0,0 +1,78 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#ifndef UNIFIEDCACHE_CONNSTATS_H +#define UNIFIEDCACHE_CONNSTATS_H + +#include +#include +#include +#include +#include +#include "istats.h" +#include "stats_registry.h" + +namespace UC::Metrics { + +enum class Key : uint8_t { + interval_lookup_hit_rates = 0, + save_requests_num, + save_blocks_num, + save_duration, + save_speed, + load_requests_num, + load_blocks_num, + load_duration, + load_speed, + COUNT +}; + +class ConnStats : public IStats { +public: + ConnStats(); + ~ConnStats() = default; + + std::string Name() const override; + void Reset() override; + void Update(const std::unordered_map& params) override; + std::unordered_map> Data() override; + +private: + static constexpr std::size_t N = static_cast(Key::COUNT); + std::array, N> data_; + + static Key KeyFromString(const std::string& k); + void EmplaceBack(Key id, double value); +}; + +struct Registrar { + Registrar() + { + StatsRegistry::RegisterStats( + "ConnStats", []() -> std::unique_ptr { return std::make_unique(); }); + } +}; + +} // namespace UC::Metrics + +#endif // UNIFIEDCACHE_CONNSTATS_H \ No newline at end of file diff --git a/ucm/shared/metrics/cc/stats/istats.h b/ucm/shared/metrics/cc/stats/istats.h new file mode 100644 index 000000000..6e8de7b32 --- /dev/null +++ b/ucm/shared/metrics/cc/stats/istats.h @@ -0,0 +1,45 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#ifndef UNIFIEDCACHE_ISTATS_H +#define UNIFIEDCACHE_ISTATS_H + +#include +#include +#include +#include + +namespace UC::Metrics { + +class IStats { +public: + virtual ~IStats() = default; + virtual std::string Name() const = 0; + virtual void Update(const std::unordered_map& params) = 0; + virtual void Reset() = 0; + virtual std::unordered_map> Data() = 0; +}; + +} // namespace UC::Metrics + +#endif \ No newline at end of file diff --git a/ucm/shared/metrics/cc/stats_monitor.cc b/ucm/shared/metrics/cc/stats_monitor.cc new file mode 100644 index 000000000..2d3d80266 --- /dev/null +++ b/ucm/shared/metrics/cc/stats_monitor.cc @@ -0,0 +1,82 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#include "stats_monitor.h" +#include +#include +#include "stats/istats.h" +#include "stats_registry.h" + +namespace UC::Metrics { + +StatsMonitor::StatsMonitor() +{ + auto& registry = StatsRegistry::GetInstance(); + for (const auto& name : registry.GetRegisteredStatsNames()) { + stats_map_[name] = registry.CreateStats(name); + } +} + +void StatsMonitor::CreateStats(const std::string& name) +{ + std::lock_guard lock(mutex_); + auto& registry = StatsRegistry::GetInstance(); + stats_map_[name] = registry.CreateStats(name); +} + +std::unordered_map> StatsMonitor::GetStats(const std::string& name) +{ + std::lock_guard lock(mutex_); + return stats_map_[name]->Data(); +} + +void StatsMonitor::ResetStats(const std::string& name) +{ + std::lock_guard lock(mutex_); + stats_map_[name]->Reset(); +} + +std::unordered_map> +StatsMonitor::GetStatsAndClear(const std::string& name) +{ + std::lock_guard lock(mutex_); + auto result = stats_map_[name]->Data(); + stats_map_[name]->Reset(); + return result; +} + +void StatsMonitor::UpdateStats(const std::string& name, + const std::unordered_map& params) +{ + std::lock_guard lock(mutex_); + auto it = stats_map_.find(name); + if (it != stats_map_.end()) { it->second->Update(params); } +} + +void StatsMonitor::ResetAllStats() +{ + std::lock_guard lock(mutex_); + for (auto& [n, ptr] : stats_map_) { ptr->Reset(); } +} + +} // namespace UC::Metrics \ No newline at end of file diff --git a/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_manager.h b/ucm/shared/metrics/cc/stats_monitor.h similarity index 55% rename from ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_manager.h rename to ucm/shared/metrics/cc/stats_monitor.h index d9d6a1976..1545d4b5c 100644 --- a/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_manager.h +++ b/ucm/shared/metrics/cc/stats_monitor.h @@ -21,39 +21,50 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
* */ -#ifndef UNIFIEDCACHE_TSF_TASK_MANAGER_H -#define UNIFIEDCACHE_TSF_TASK_MANAGER_H +#ifndef UNIFIEDCACHE_MONITOR_H +#define UNIFIEDCACHE_MONITOR_H #include +#include +#include #include #include -#include "tsf_task_queue.h" +#include "stats/istats.h" -namespace UC { +namespace UC::Metrics { -class TsfTaskManager { +class StatsMonitor { public: - Status Setup(const int32_t deviceId, const size_t streamNumber, const size_t bufferSize, - const size_t bufferNumber, const size_t timeoutMs, const SpaceLayout* layout); - Status Submit(std::list& tasks, const size_t size, const size_t number, - const std::string& brief, size_t& taskId); - Status Wait(const size_t taskId); - Status Check(const size_t taskId, bool& finish); + static StatsMonitor& GetInstance() + { + static StatsMonitor inst; + return inst; + } -private: - void Dispatch(std::list& tasks, std::vector>& targets, - const size_t taskId, std::shared_ptr waiter) const; + ~StatsMonitor() = default; + + void CreateStats(const std::string& name); + + std::unordered_map> GetStats(const std::string& name); + + void ResetStats(const std::string& name); + + std::unordered_map> GetStatsAndClear(const std::string& name); + + void UpdateStats(const std::string& name, + const std::unordered_map& params); + + void ResetAllStats(); private: - std::mutex _mutex; - TsfTaskSet _failureSet; - std::unordered_map> _waiters; - std::vector> _queues; - size_t _qIdx{0}; - size_t _taskIdSeed{0}; - size_t _timeoutMs{0}; + std::mutex mutex_; + std::unordered_map> stats_map_; + + StatsMonitor(); + StatsMonitor(const StatsMonitor&) = delete; + StatsMonitor& operator=(const StatsMonitor&) = delete; }; -} // namespace UC +} // namespace UC::Metrics -#endif +#endif // UNIFIEDCACHE_MONITOR_H \ No newline at end of file diff --git a/ucm/shared/metrics/cc/stats_registry.cc b/ucm/shared/metrics/cc/stats_registry.cc new file mode 100644 index 000000000..c2551d9ad --- /dev/null +++ b/ucm/shared/metrics/cc/stats_registry.cc @@ -0,0 +1,59 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include "stats_registry.h" + +namespace UC::Metrics { + +StatsRegistry& StatsRegistry::GetInstance() +{ + static StatsRegistry inst; + return inst; +} + +void StatsRegistry::RegisterStats(std::string name, Creator creator) +{ + auto& reg = GetInstance(); + std::lock_guard lk(reg.mutex_); + reg.registry_[name] = creator; +} + +std::unique_ptr StatsRegistry::CreateStats(const std::string& name) +{ + auto& reg = GetInstance(); + std::lock_guard lk(reg.mutex_); + if (auto it = reg.registry_.find(name); it != reg.registry_.end()) return it->second(); + return nullptr; +} + +std::vector StatsRegistry::GetRegisteredStatsNames() +{ + auto& reg = GetInstance(); + std::lock_guard lk(reg.mutex_); + std::vector names; + names.reserve(reg.registry_.size()); + for (auto& [n, _] : reg.registry_) names.push_back(n); + return names; +} + +} // namespace UC::Metrics \ No newline at end of file diff --git a/ucm/shared/metrics/cc/stats_registry.h b/ucm/shared/metrics/cc/stats_registry.h new file mode 100644 index 000000000..c22b6617c --- /dev/null +++ b/ucm/shared/metrics/cc/stats_registry.h @@ -0,0 +1,58 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#ifndef UNIFIEDCACHE_REGISTRY_H +#define UNIFIEDCACHE_REGISTRY_H + +#include +#include +#include +#include "stats/istats.h" + +namespace UC::Metrics { + +using Creator = std::unique_ptr (*)(); + +class StatsRegistry { +public: + static StatsRegistry& GetInstance(); + + static void RegisterStats(std::string name, Creator creator); + + std::unique_ptr CreateStats(const std::string& name); + + std::vector GetRegisteredStatsNames(); + +private: + StatsRegistry() = default; + ~StatsRegistry() = default; + StatsRegistry(const StatsRegistry&) = delete; + StatsRegistry& operator=(const StatsRegistry&) = delete; + + std::mutex mutex_; + std::unordered_map registry_; +}; + +} // namespace UC::Metrics + +#endif // UNIFIEDCACHE_REGISTRY_H \ No newline at end of file diff --git a/ucm/shared/metrics/cpy/metrics.py.cc b/ucm/shared/metrics/cpy/metrics.py.cc new file mode 100644 index 000000000..10bfc2f97 --- /dev/null +++ b/ucm/shared/metrics/cpy/metrics.py.cc @@ -0,0 +1,50 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#include +#include +#include "stats_monitor.h" + +namespace py = pybind11; +namespace UC::Metrics { + +void bind_monitor(py::module_& m) +{ + py::class_(m, "StatsMonitor") + .def_static("get_instance", &StatsMonitor::GetInstance, py::return_value_policy::reference) + .def("update_stats", &StatsMonitor::UpdateStats) + .def("reset_all", &StatsMonitor::ResetAllStats) + .def("get_stats", &StatsMonitor::GetStats) + .def("get_stats_and_clear", &StatsMonitor::GetStatsAndClear); +} + +} // namespace UC::Metrics + +PYBIND11_MODULE(ucmmonitor, module) +{ + module.attr("project") = UCM_PROJECT_NAME; + module.attr("version") = UCM_PROJECT_VERSION; + module.attr("commit_id") = UCM_COMMIT_ID; + module.attr("build_type") = UCM_BUILD_TYPE; + UC::Metrics::bind_monitor(module); +} \ No newline at end of file diff --git a/ucm/shared/metrics/observability.py b/ucm/shared/metrics/observability.py new file mode 100644 index 000000000..fb33400ce --- /dev/null +++ b/ucm/shared/metrics/observability.py @@ -0,0 +1,305 @@ +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# + + +import os +import threading +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import prometheus_client +import yaml + +# Third Party +from prometheus_client import REGISTRY +from vllm.distributed.parallel_state import get_world_group + +from ucm.logger import init_logger +from ucm.shared.metrics import ucmmonitor + +logger = init_logger(__name__) + + +@dataclass +class UCMEngineMetadata: + """Metadata for UCM engine""" + + model_name: str + worker_id: str + + +class PrometheusLogger: + _gauge_cls = prometheus_client.Gauge + _counter_cls = prometheus_client.Counter + _histogram_cls = prometheus_client.Histogram + + def __init__(self, metadata: UCMEngineMetadata, config: Dict[str, Any]): + # Ensure PROMETHEUS_MULTIPROC_DIR is set before any metric registration + prometheus_config = config.get("prometheus", {}) + multiproc_dir = prometheus_config.get("multiproc_dir", "/vllm-workspace") + if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: + os.environ["PROMETHEUS_MULTIPROC_DIR"] = multiproc_dir + if not os.path.exists(multiproc_dir): + os.makedirs(multiproc_dir, exist_ok=True) + + self.metadata = metadata + self.config = config + self.labels = self._metadata_to_labels(metadata) + labelnames = list(self.labels.keys()) + + # Initialize metrics based on configuration + self._init_metrics_from_config(labelnames, prometheus_config) + + def _init_metrics_from_config( + self, labelnames: List[str], prometheus_config: Dict[str, Any] + ): + """Initialize metrics based on configuration""" + enabled = prometheus_config.get("enabled_metrics", {}) + + # Get metric name prefix from config (e.g., "ucm:") + # If not specified, use empty string + metric_prefix = prometheus_config.get("metric_prefix", "ucm:") + + # Store metric mapping: metric_name -> (metric_type, attribute_name, stats_field_name) + # This mapping will be used in log_prometheus to dynamically log metrics + self.metric_mappings: Dict[str, Dict[str, str]] = {} + + # Initialize counters + if enabled.get("counters", True): + counters = prometheus_config.get("counters", []) + for counter_cfg in counters: + name = counter_cfg.get("name") + doc = counter_cfg.get("documentation", "") + # Prometheus metric name with prefix + prometheus_name = f"{metric_prefix}{name}" if metric_prefix else name + # Internal attribute name for storing the metric object + attr_name = f"counter_{name}" + + if not hasattr(self, attr_name): + setattr( + self, + attr_name, + self._counter_cls( + name=prometheus_name, + documentation=doc, + labelnames=labelnames, + ), + ) + # Store mapping for dynamic logging + self.metric_mappings[name] = { + "type": "counter", + "attr": attr_name, + } + + # Initialize gauges + if enabled.get("gauges", True): + gauges = prometheus_config.get("gauges", []) + for gauge_cfg in gauges: + name = gauge_cfg.get("name") + doc = gauge_cfg.get("documentation", "") + multiprocess_mode = gauge_cfg.get("multiprocess_mode", "live") + # Prometheus metric name with prefix + prometheus_name = f"{metric_prefix}{name}" if metric_prefix else name + # Internal attribute name + attr_name = f"gauge_{name}" + + if not hasattr(self, attr_name): + setattr( + self, + attr_name, + self._gauge_cls( + name=prometheus_name, + documentation=doc, + labelnames=labelnames, + multiprocess_mode=multiprocess_mode, + ), + ) + # Store mapping for dynamic logging + self.metric_mappings[name] = { + "type": "gauge", + "attr": attr_name, + } + + # Initialize histograms + if enabled.get("histograms", 
True): + histograms = prometheus_config.get("histograms", []) + for hist_cfg in histograms: + name = hist_cfg.get("name") + doc = hist_cfg.get("documentation", "") + buckets = hist_cfg.get("buckets", []) + # Prometheus metric name with prefix + prometheus_name = f"{metric_prefix}{name}" if metric_prefix else name + # Internal attribute name + attr_name = f"histogram_{name}" + + if not hasattr(self, attr_name): + setattr( + self, + attr_name, + self._histogram_cls( + name=prometheus_name, + documentation=doc, + labelnames=labelnames, + buckets=buckets, + ), + ) + # Store mapping for dynamic logging + self.metric_mappings[name] = { + "type": "histogram", + "attr": attr_name, + } + + def _log_gauge(self, gauge, data: Union[int, float]) -> None: + # Convenience function for logging to gauge. + gauge.labels(**self.labels).set(data) + + def _log_counter(self, counter, data: Union[int, float]) -> None: + # Convenience function for logging to counter. + # Prevent ValueError from negative increment + if data < 0: + return + counter.labels(**self.labels).inc(data) + + def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None: + # Convenience function for logging to histogram. + for value in data: + histogram.labels(**self.labels).observe(value) + + def log_prometheus(self, stats: Any): + """Log metrics to Prometheus based on configuration file""" + # Dynamically log metrics based on what's configured in YAML + for stat_name, value in stats.items(): + try: + metric_mapped = self.metric_mappings[stat_name] + if metric_mapped is None: + logger.warning(f"Stat {stat_name} not initialized.") + continue + metric_obj = getattr(self, metric_mapped["attr"], None) + metric_type = metric_mapped["type"] + + # Log based on metric type + if metric_type == "counter": + self._log_counter(metric_obj, value) + elif metric_type == "gauge": + self._log_gauge(metric_obj, value) + elif metric_type == "histogram": + # Histograms expect a list + if not isinstance(value, list): + if value: + value = [value] + else: + value = [] + self._log_histogram(metric_obj, value) + except Exception as e: + logger.warning(f"Failed to log metric {stat_name}: {e}") + + @staticmethod + def _metadata_to_labels(metadata: UCMEngineMetadata): + return { + "model_name": metadata.model_name, + "worker_id": metadata.worker_id, + } + + _instance = None + + @staticmethod + def GetOrCreate( + metadata: UCMEngineMetadata, + config_path: str = "", + ) -> "PrometheusLogger": + if PrometheusLogger._instance is None: + PrometheusLogger._instance = PrometheusLogger(metadata, config_path) + # assert PrometheusLogger._instance.metadata == metadata, \ + # "PrometheusLogger instance already created with different metadata" + if PrometheusLogger._instance.metadata != metadata: + logger.error( + "PrometheusLogger instance already created with" + "different metadata. This should not happen except " + "in test" + ) + return PrometheusLogger._instance + + @staticmethod + def GetInstance() -> "PrometheusLogger": + assert ( + PrometheusLogger._instance is not None + ), "PrometheusLogger instance not created yet" + return PrometheusLogger._instance + + @staticmethod + def GetInstanceOrNone() -> Optional["PrometheusLogger"]: + """ + Returns the singleton instance of PrometheusLogger if it exists, + otherwise returns None. 
+ """ + return PrometheusLogger._instance + + +class UCMStatsLogger: + def __init__(self, model_name: str, rank: int, config_path: str = ""): + # Create metadata + self.metadata = UCMEngineMetadata( + model_name=str(model_name), worker_id=str(rank) + ) + # Load configuration + config = self._load_config(config_path) + self.log_interval = config.get("log_interval", 10) + + self.monitor = ucmmonitor.StatsMonitor.get_instance() + self.prometheus_logger = PrometheusLogger.GetOrCreate(self.metadata, config) + self.is_running = True + + self.thread = threading.Thread(target=self.log_worker, daemon=True) + self.thread.start() + + def _load_config(self, config_path: str) -> Dict[str, Any]: + """Load configuration from YAML file""" + try: + with open(config_path, "r") as f: + config = yaml.safe_load(f) + if config is None: + logger.warning( + f"Config file {config_path} is empty, using defaults" + ) + return {} + return config + except FileNotFoundError: + logger.warning(f"Config file {config_path} not found, using defaults") + return {} + except yaml.YAMLError as e: + logger.error(f"Error parsing YAML config file {config_path}: {e}") + return {} + + def log_worker(self): + while self.is_running: + # Use UCMStatsMonitor.get_states_and_clear() from external import + stats = self.monitor.get_stats_and_clear("ConnStats") + self.prometheus_logger.log_prometheus(stats) + time.sleep(self.log_interval) + + def shutdown(self): + self.is_running = False + self.thread.join() diff --git a/ucm/shared/metrics/test/test.py b/ucm/shared/metrics/test/test.py new file mode 100644 index 000000000..246e6f880 --- /dev/null +++ b/ucm/shared/metrics/test/test.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# + + +import os +import sys + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from ucm.shared.metrics import ucmmonitor + +# import monitor + +mon = ucmmonitor.StatsMonitor.get_instance() +mon.update_stats( + "ConnStats", + { + "save_duration": 1.2, + "save_speed": 300.5, + "load_duration": 0.8, + "load_speed": 450.0, + "interval_lookup_hit_rates": 0.95, + }, +) +mon.update_stats( + "ConnStats", + { + "save_duration": 1.2, + "save_speed": 300.5, + "load_duration": 0.8, + "load_speed": 450.0, + "interval_lookup_hit_rates": 0.95, + }, +) + +data = mon.get_stats("ConnStats") +print(data) diff --git a/ucm/shared/test/CMakeLists.txt b/ucm/shared/test/CMakeLists.txt new file mode 100644 index 000000000..07241d814 --- /dev/null +++ b/ucm/shared/test/CMakeLists.txt @@ -0,0 +1,11 @@ +if(BUILD_UNIT_TESTS) + include(GoogleTest) + file(GLOB_RECURSE UCMSHARED_TEST_SOURCE_FILES "./case/*.cc") + add_executable(ucmshared.test ${UCMSHARED_TEST_SOURCE_FILES}) + target_include_directories(ucmshared.test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/case) + target_link_libraries(ucmshared.test PRIVATE + trans + gtest_main gtest mockcpp + ) + gtest_discover_tests(ucmshared.test) +endif() diff --git a/ucm/store/test/case/infra/hashset_test.cc b/ucm/shared/test/case/infra/hashset_test.cc similarity index 100% rename from ucm/store/test/case/infra/hashset_test.cc rename to ucm/shared/test/case/infra/hashset_test.cc diff --git a/ucm/shared/test/case/trans/trans_test.cc b/ucm/shared/test/case/trans/trans_test.cc new file mode 100644 index 000000000..d8bef64b3 --- /dev/null +++ b/ucm/shared/test/case/trans/trans_test.cc @@ -0,0 +1,141 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include +#include "trans/device.h" + +class UCTransUnitTest : public ::testing::Test {}; + +TEST_F(UCTransUnitTest, CopyDataWithCE) +{ + const auto ok = UC::Status::OK(); + constexpr int32_t deviceId = 0; + constexpr size_t size = 36 * 1024; + constexpr size_t number = 64 * 61; + UC::Trans::Device device; + ASSERT_EQ(device.Setup(deviceId), ok); + auto buffer = device.MakeBuffer(); + auto stream = device.MakeStream(); + auto hPtr1 = buffer->MakeHostBuffer(size * number); + ASSERT_NE(hPtr1, nullptr); + ASSERT_EQ(buffer->MakeDeviceBuffers(size, number), ok); + std::vector> ptrHolder; + ptrHolder.reserve(number); + void* dPtrArr[number]; + for (size_t i = 0; i < number; i++) { + *(size_t*)(((char*)hPtr1.get()) + size * i) = i; + auto ptr = buffer->GetDeviceBuffer(size); + dPtrArr[i] = ptr.get(); + ptrHolder.emplace_back(ptr); + } + auto hPtr2 = buffer->MakeHostBuffer(size * number); + ASSERT_NE(hPtr2, nullptr); + ASSERT_EQ(stream->HostToDeviceAsync(hPtr1.get(), dPtrArr, size, number), ok); + ASSERT_EQ(stream->DeviceToHostAsync(dPtrArr, hPtr2.get(), size, number), ok); + ASSERT_EQ(stream->Synchronized(), ok); + for (size_t i = 0; i < number; i++) { + ASSERT_EQ(*(size_t*)(((char*)hPtr2.get()) + size * i), i); + } +} + +TEST_F(UCTransUnitTest, CopyDataWithSM) +{ + const auto ok = UC::Status::OK(); + constexpr int32_t deviceId = 0; + constexpr size_t size = 36 * 1024; + constexpr size_t number = 64 * 61; + UC::Trans::Device device; + ASSERT_EQ(device.Setup(deviceId), ok); + auto buffer = device.MakeBuffer(); + auto stream = device.MakeSMStream(); + if (!stream) { return; } + auto hPtr1 = buffer->MakeHostBuffer(size * number); + ASSERT_NE(hPtr1, nullptr); + ASSERT_EQ(buffer->MakeDeviceBuffers(size, number), ok); + std::vector> ptrHolder; + ptrHolder.reserve(number); + void* dPtrArr[number]; + for (size_t i = 0; i < number; i++) { + *(size_t*)(((char*)hPtr1.get()) + size * i) = i; + auto ptr = buffer->GetDeviceBuffer(size); + dPtrArr[i] = ptr.get(); + ptrHolder.emplace_back(ptr); + } + auto dPtrArrOnDev = buffer->MakeDeviceBuffer(sizeof(dPtrArr)); + ASSERT_EQ(stream->HostToDevice((void*)dPtrArr, dPtrArrOnDev.get(), sizeof(dPtrArr)), ok); + auto hPtr2 = buffer->MakeHostBuffer(size * number); + ASSERT_NE(hPtr2, nullptr); + ASSERT_EQ(stream->HostToDeviceAsync(hPtr1.get(), (void**)dPtrArrOnDev.get(), size, number), ok); + ASSERT_EQ(stream->DeviceToHostAsync((void**)dPtrArrOnDev.get(), hPtr2.get(), size, number), ok); + ASSERT_EQ(stream->Synchronized(), ok); + for (size_t i = 0; i < number; i++) { + ASSERT_EQ(*(size_t*)(((char*)hPtr2.get()) + size * i), i); + } +} + +TEST_F(UCTransUnitTest, CopyDataBatchWithSM) +{ + const auto ok = UC::Status::OK(); + constexpr int32_t deviceId = 0; + constexpr size_t size = 36 * 1024; + constexpr size_t number = 64 * 61; + UC::Trans::Device device; + ASSERT_EQ(device.Setup(deviceId), ok); + auto stream = device.MakeSMStream(); + if (!stream) { return; } + auto bDev = device.MakeBuffer(); + auto bHost1 = device.MakeBuffer(); + auto bHost2 = device.MakeBuffer(); + ASSERT_EQ(bDev->MakeDeviceBuffers(size, number), ok); + ASSERT_EQ(bHost1->MakeHostBuffers(size, number), ok); + ASSERT_EQ(bHost2->MakeHostBuffers(size, number), ok); + std::vector> devPtrHolder, host1PtrHolder, host2PtrHolder; + void *dPtrArr[number], *h1PtrArr[number], *h2PtrArr[number]; + for (size_t i = 0; i < number; i++) { + auto d = bDev->GetDeviceBuffer(size); + auto h1 = bHost1->GetHostBuffer(size); + auto h2 = bHost2->GetHostBuffer(size); + dPtrArr[i] = d.get(); + h1PtrArr[i] = h1.get(); + 
*(size_t*)h1PtrArr[i] = i; + h2PtrArr[i] = h2.get(); + devPtrHolder.emplace_back(d); + host1PtrHolder.emplace_back(h1); + host2PtrHolder.emplace_back(h2); + } + constexpr const auto arrSize = sizeof(void*) * number; + auto dPtrArrOnDev = bDev->MakeDeviceBuffer(arrSize); + auto h1PtrArrOnDev = bHost1->MakeDeviceBuffer(arrSize); + auto h2PtrArrOnDev = bHost2->MakeDeviceBuffer(arrSize); + ASSERT_EQ(stream->HostToDeviceAsync((void*)dPtrArr, dPtrArrOnDev.get(), arrSize), ok); + ASSERT_EQ(stream->HostToDeviceAsync((void*)h1PtrArr, h1PtrArrOnDev.get(), arrSize), ok); + ASSERT_EQ(stream->HostToDeviceAsync((void*)h2PtrArr, h2PtrArrOnDev.get(), arrSize), ok); + auto src = (void**)h1PtrArrOnDev.get(); + auto dst = (void**)dPtrArrOnDev.get(); + ASSERT_EQ(stream->HostToDeviceAsync(src, dst, size, number), ok); + src = (void**)dPtrArrOnDev.get(); + dst = (void**)h2PtrArrOnDev.get(); + ASSERT_EQ(stream->DeviceToHostAsync(src, dst, size, number), ok); + ASSERT_EQ(stream->Synchronized().Underlying(), ok.Underlying()); + for (size_t i = 0; i < number; i++) { ASSERT_EQ(*(size_t*)h2PtrArr[i], i); } +} diff --git a/ucm/shared/test/example/trans/trans_on_cuda_example.py b/ucm/shared/test/example/trans/trans_on_cuda_example.py new file mode 100644 index 000000000..01ffd864d --- /dev/null +++ b/ucm/shared/test/example/trans/trans_on_cuda_example.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
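#
# Reading note for the example below: in the ucmtrans bindings, the Scatter calls
# copy one contiguous pinned host region into `number` device blocks of `size`
# bytes each, Gather copies those blocks back into one contiguous host region,
# and Batch copies block-for-block between two pointer arrays. A minimal
# round-trip sketch (this helper is hypothetical, not part of the module):
def roundtrip(stream, host_src_ptr, device_ptrs, host_dst_ptr, size, number):
    """Scatter one pinned host buffer to device blocks and gather it back."""
    stream.HostToDeviceScatterAsync(host_src_ptr, device_ptrs, size, number)
    stream.DeviceToHostGatherAsync(device_ptrs, host_dst_ptr, size, number)
    stream.Synchronized()  # block until both async copies have completed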
+# +import time +from functools import wraps + +import cupy +import numpy as np + +from ucm.shared.trans import ucmtrans + + +def test_wrap(func): + @wraps(func) + def wrapper(*args, **kwargs): + print(f"========>> Running in {func.__name__}:") + result = func(*args, **kwargs) + print() + return result + + return wrapper + + +def make_host_memory(size, number, dtype, fill=False): + element_size = np.dtype(dtype).itemsize + num_elements = size // element_size + host = cupy.cuda.alloc_pinned_memory(size * number) + host_np = np.frombuffer(host, dtype=dtype, count=num_elements) + if fill: + fixed_len = min(1024, number) + host_np[:fixed_len] = np.arange(fixed_len, dtype=dtype) + print("make:", host_np.shape, host_np.itemsize, host_np) + return host + + +def make_batch_host_memory(size, number, dtype, fill=False): + element_size = np.dtype(dtype).itemsize + num_elements = size // element_size + host = [] + for i in range(number): + pinned_mem = cupy.cuda.alloc_pinned_memory(size) + np_array = np.frombuffer(pinned_mem, dtype=dtype, count=num_elements) + if fill: + value = np.uint64(1023 + i) + np_array[0] = value + np_array[-1] = value + host.append(pinned_mem) + if i == 0: + print("make:", np_array.shape, np_array.itemsize, np_array) + return host + + +def compare(host1, host2, size, dtype, show_detail=True): + element_size = np.dtype(dtype).itemsize + num_elements = size // element_size + host1_np = np.frombuffer(host1, dtype=dtype, count=num_elements) + host2_np = np.frombuffer(host2, dtype=dtype, count=num_elements) + if show_detail: + print("compare[1]:", host1_np.shape, host1_np.itemsize, host1_np) + print("compare[2]:", host2_np.shape, host2_np.itemsize, host2_np) + return np.array_equal(host1_np, host2_np) + + +@test_wrap +def trans_with_ce(d, size, number, dtype): + s = d.MakeStream() + host1 = make_host_memory(size, number, dtype, True) + device = [cupy.empty(size, dtype=np.uint8) for _ in range(number)] + device_ptr = np.array([d.data.ptr for d in device], dtype=np.uint64) + host2 = make_host_memory(size, number, dtype) + tp = time.perf_counter() + s.HostToDeviceScatter(host1.ptr, device_ptr, size, number) + s.DeviceToHostGather(device_ptr, host2.ptr, size, number) + cost = time.perf_counter() - tp + print(f"cost: {cost}s") + print(f"bandwidth: {size * number / cost / 1e9}GB/s") + assert compare(host1, host2, size, dtype) + + +@test_wrap +def trans_with_sm(d, size, number, dtype): + s = d.MakeSMStream() + host1 = make_host_memory(size, number, dtype, True) + device = [cupy.empty(size, dtype=np.uint8) for _ in range(number)] + device_ptr = np.array([d.data.ptr for d in device], dtype=np.uint64) + device_ptr_cupy = cupy.empty(number, dtype=np.uint64) + device_ptr_cupy.set(device_ptr) + host2 = make_host_memory(size, number, dtype) + tp = time.perf_counter() + s.HostToDeviceScatter(host1.ptr, device_ptr_cupy.data.ptr, size, number) + s.DeviceToHostGather(device_ptr_cupy.data.ptr, host2.ptr, size, number) + cost = time.perf_counter() - tp + print(f"cost: {cost}s") + print(f"bandwidth: {size * number / cost / 1e9}GB/s") + assert compare(host1, host2, size, dtype) + + +@test_wrap +def trans_with_ce_async(d, size, number, dtype): + s = d.MakeStream() + host1 = make_host_memory(size, number, dtype, True) + device = [cupy.empty(size, dtype=np.uint8) for _ in range(number)] + device_ptr = np.array([d.data.ptr for d in device], dtype=np.uint64) + host2 = make_host_memory(size, number, dtype) + tp = time.perf_counter() + s.HostToDeviceScatterAsync(host1.ptr, device_ptr, size, number) + 
s.DeviceToHostGatherAsync(device_ptr, host2.ptr, size, number) + s.Synchronized() + cost = time.perf_counter() - tp + print(f"cost: {cost}s") + print(f"bandwidth: {size * number / cost / 1e9}GB/s") + assert compare(host1, host2, size, dtype) + + +@test_wrap +def trans_with_sm_async(d, size, number, dtype): + s = d.MakeSMStream() + host1 = make_host_memory(size, number, dtype, True) + device = [cupy.empty(size, dtype=np.uint8) for _ in range(number)] + device_ptr = np.array([d.data.ptr for d in device], dtype=np.uint64) + device_ptr_cupy = cupy.empty(number, dtype=np.uint64) + device_ptr_cupy.set(device_ptr) + host2 = make_host_memory(size, number, dtype) + tp = time.perf_counter() + s.HostToDeviceScatterAsync(host1.ptr, device_ptr_cupy.data.ptr, size, number) + s.DeviceToHostGatherAsync(device_ptr_cupy.data.ptr, host2.ptr, size, number) + s.Synchronized() + cost = time.perf_counter() - tp + print(f"cost: {cost}s") + print(f"bandwidth: {size * number / cost / 1e9}GB/s") + assert compare(host1, host2, size, dtype) + + +@test_wrap +def trans_batch_with_ce(d, size, number, dtype): + s = d.MakeStream() + host1 = make_batch_host_memory(size, number, dtype, True) + host1_ptr = np.array([h.ptr for h in host1], dtype=np.uint64) + device = [cupy.empty(size, dtype=np.uint8) for _ in range(number)] + device_ptr = np.array([d.data.ptr for d in device], dtype=np.uint64) + host2 = make_batch_host_memory(size, number, dtype) + host2_ptr = np.array([h.ptr for h in host2], dtype=np.uint64) + tp = time.perf_counter() + s.HostToDeviceBatch(host1_ptr, device_ptr, size, number) + s.DeviceToHostBatch(device_ptr, host2_ptr, size, number) + cost = time.perf_counter() - tp + print(f"cost: {cost}s") + print(f"bandwidth: {size * number / cost / 1e9}GB/s") + for h1, h2 in zip(host1, host2): + assert compare(h1, h2, size, dtype, False) + + +@test_wrap +def trans_batch_with_sm(dev, size, number, dtype): + s = dev.MakeSMStream() + h1 = make_batch_host_memory(size, number, dtype, True) + h1_ptr = np.array([h.ptr for h in h1], dtype=np.uint64) + h1_ptr_cupy = cupy.empty(number, dtype=np.uint64) + h1_ptr_cupy.set(h1_ptr) + d = [cupy.empty(size, dtype=np.uint8) for _ in range(number)] + d_ptr = np.array([d.data.ptr for d in d], dtype=np.uint64) + d_ptr_cupy = cupy.empty(number, dtype=np.uint64) + d_ptr_cupy.set(d_ptr) + h2 = make_batch_host_memory(size, number, dtype) + h2_ptr = np.array([h.ptr for h in h2], dtype=np.uint64) + h2_ptr_cupy = cupy.empty(number, dtype=np.uint64) + h2_ptr_cupy.set(h2_ptr) + tp = time.perf_counter() + s.HostToDeviceBatch(h1_ptr_cupy.data.ptr, d_ptr_cupy.data.ptr, size, number) + s.DeviceToHostBatch(d_ptr_cupy.data.ptr, h2_ptr_cupy.data.ptr, size, number) + cost = time.perf_counter() - tp + print(f"cost: {cost}s") + print(f"bandwidth: {size * number / cost / 1e9}GB/s") + for x, y in zip(h1, h2): + assert compare(x, y, size, dtype, False) + + +@test_wrap +def trans_batch_with_ce_async(d, size, number, dtype): + s = d.MakeStream() + host1 = make_batch_host_memory(size, number, dtype, True) + host1_ptr = np.array([h.ptr for h in host1], dtype=np.uint64) + device = [cupy.empty(size, dtype=np.uint8) for _ in range(number)] + device_ptr = np.array([d.data.ptr for d in device], dtype=np.uint64) + host2 = make_batch_host_memory(size, number, dtype) + host2_ptr = np.array([h.ptr for h in host2], dtype=np.uint64) + tp = time.perf_counter() + s.HostToDeviceBatchAsync(host1_ptr, device_ptr, size, number) + s.DeviceToHostBatchAsync(device_ptr, host2_ptr, size, number) + s.Synchronized() + cost = 
time.perf_counter() - tp + print(f"cost: {cost}s") + print(f"bandwidth: {size * number / cost / 1e9}GB/s") + for h1, h2 in zip(host1, host2): + assert compare(h1, h2, size, dtype, False) + + +@test_wrap +def trans_batch_with_sm_async(dev, size, number, dtype): + s = dev.MakeSMStream() + h1 = make_batch_host_memory(size, number, dtype, True) + h1_ptr = np.array([h.ptr for h in h1], dtype=np.uint64) + h1_ptr_cupy = cupy.empty(number, dtype=np.uint64) + h1_ptr_cupy.set(h1_ptr) + d = [cupy.empty(size, dtype=np.uint8) for _ in range(number)] + d_ptr = np.array([d.data.ptr for d in d], dtype=np.uint64) + d_ptr_cupy = cupy.empty(number, dtype=np.uint64) + d_ptr_cupy.set(d_ptr) + h2 = make_batch_host_memory(size, number, dtype) + h2_ptr = np.array([h.ptr for h in h2], dtype=np.uint64) + h2_ptr_cupy = cupy.empty(number, dtype=np.uint64) + h2_ptr_cupy.set(h2_ptr) + tp = time.perf_counter() + s.HostToDeviceBatchAsync(h1_ptr_cupy.data.ptr, d_ptr_cupy.data.ptr, size, number) + s.DeviceToHostBatchAsync(d_ptr_cupy.data.ptr, h2_ptr_cupy.data.ptr, size, number) + s.Synchronized() + cost = time.perf_counter() - tp + print(f"cost: {cost}s") + print(f"bandwidth: {size * number / cost / 1e9}GB/s") + for x, y in zip(h1, h2): + assert compare(x, y, size, dtype, False) + + +def main(): + device_id = 0 + size = 36 * 1024 + number = 61 * 64 + dtype = np.float16 + print(f"ucmtrans: {ucmtrans.commit_id}-{ucmtrans.build_type}") + cupy.cuda.Device(device_id).use() + d = ucmtrans.Device() + d.Setup(device_id) + trans_with_ce(d, size, number, dtype) + trans_with_sm(d, size, number, dtype) + trans_with_ce_async(d, size, number, dtype) + trans_with_sm_async(d, size, number, dtype) + trans_batch_with_ce(d, size, number, dtype) + trans_batch_with_sm(d, size, number, dtype) + trans_batch_with_ce_async(d, size, number, dtype) + trans_batch_with_sm_async(d, size, number, dtype) + + +if __name__ == "__main__": + main() diff --git a/ucm/shared/trans/CMakeLists.txt b/ucm/shared/trans/CMakeLists.txt new file mode 100644 index 000000000..57a1bd0aa --- /dev/null +++ b/ucm/shared/trans/CMakeLists.txt @@ -0,0 +1,16 @@ +if(RUNTIME_ENVIRONMENT STREQUAL "ascend") + add_subdirectory(ascend) +endif() +if(RUNTIME_ENVIRONMENT STREQUAL "cuda") + add_subdirectory(cuda) +endif() +if(RUNTIME_ENVIRONMENT STREQUAL "simu") + add_subdirectory(simu) +endif() +target_include_directories(trans PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..) 
+target_link_libraries(trans PUBLIC infra_status) + +file(GLOB_RECURSE UCMTRANS_CPY_SOURCE_FILES "./cpy/*.cc") +pybind11_add_module(ucmtrans ${UCMTRANS_CPY_SOURCE_FILES}) +target_link_libraries(ucmtrans PRIVATE trans) +set_target_properties(ucmtrans PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/ucm/shared/trans/__init__.py b/ucm/shared/trans/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ucm/shared/trans/ascend/CMakeLists.txt b/ucm/shared/trans/ascend/CMakeLists.txt new file mode 100644 index 000000000..5a1c3670c --- /dev/null +++ b/ucm/shared/trans/ascend/CMakeLists.txt @@ -0,0 +1,15 @@ +set(ASCEND_ROOT "/usr/local/Ascend/ascend-toolkit/latest" CACHE PATH "Path to Ascend root directory") +add_library(Ascend::ascendcl UNKNOWN IMPORTED) +set_target_properties(Ascend::ascendcl PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${ASCEND_ROOT}/include" + IMPORTED_LOCATION "${ASCEND_ROOT}/lib64/libascendcl.so" +) +add_library(trans STATIC + ascend_device.cc + ascend_buffer.cc + ascend_stream.cc +) +target_link_libraries(trans PUBLIC + fmt + Ascend::ascendcl +) diff --git a/ucm/store/test/case/nfsstore/tsf_task_waiter_test.cc b/ucm/shared/trans/ascend/ascend_buffer.cc similarity index 53% rename from ucm/store/test/case/nfsstore/tsf_task_waiter_test.cc rename to ucm/shared/trans/ascend/ascend_buffer.cc index bcf11a44a..cb748bda2 100644 --- a/ucm/store/test/case/nfsstore/tsf_task_waiter_test.cc +++ b/ucm/shared/trans/ascend/ascend_buffer.cc @@ -21,42 +21,36 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#include -#include -#include -#include "tsf_task/tsf_task_waiter.h" +#include "ascend_buffer.h" +#include -class UCTsfTaskWaiterTest : public ::testing::Test {}; +namespace UC::Trans { -TEST_F(UCTsfTaskWaiterTest, TaskTimeout) +std::shared_ptr Trans::AscendBuffer::MakeDeviceBuffer(size_t size) { - UC::TsfTaskWaiter waiter{1, 1024, 1, "xxx"}; - auto fut = std::async([&] { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - waiter.Done(); - }); - ASSERT_FALSE(waiter.Wait(1)); - fut.get(); - ASSERT_TRUE(waiter.Finish()); + void* device = nullptr; + auto ret = aclrtMalloc(&device, size, ACL_MEM_TYPE_HIGH_BAND_WIDTH); + if (ret == ACL_SUCCESS) { return std::shared_ptr(device, aclrtFree); } + return nullptr; } -TEST_F(UCTsfTaskWaiterTest, TaskSuccess) +std::shared_ptr Trans::AscendBuffer::MakeHostBuffer(size_t size) { - UC::TsfTaskWaiter waiter{1, 1024, 1, "xxx"}; - auto fut = std::async([&] { waiter.Done(); }); - ASSERT_TRUE(waiter.Wait(1000)); - ASSERT_TRUE(waiter.Finish()); - fut.get(); + void* host = nullptr; + auto ret = aclrtMallocHost(&host, size); + if (ret == ACL_SUCCESS) { return std::shared_ptr(host, aclrtFreeHost); } + return nullptr; } -TEST_F(UCTsfTaskWaiterTest, TaskTimeoutButSuccess) +Status Buffer::RegisterHostBuffer(void* host, size_t size, void** pDevice) { - UC::TsfTaskWaiter waiter{1, 1024, 1, "xxx"}; - auto fut = std::async([&] { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - waiter.Done(); - }); - fut.get(); - ASSERT_TRUE(waiter.Finish()); - ASSERT_TRUE(waiter.Wait(1)); + void* device = nullptr; + auto ret = aclrtHostRegister(host, size, ACL_HOST_REGISTER_MAPPED, &device); + if (ret != ACL_SUCCESS) [[unlikely]] { return Status{ret, std::to_string(ret)}; } + if (pDevice) { *pDevice = device; } + return Status::OK(); } + +void Buffer::UnregisterHostBuffer(void* host) { aclrtHostUnregister(host); } + +} // namespace UC::Trans diff --git 
a/ucm/shared/trans/ascend/ascend_buffer.h b/ucm/shared/trans/ascend/ascend_buffer.h new file mode 100644 index 000000000..3f64ce237 --- /dev/null +++ b/ucm/shared/trans/ascend/ascend_buffer.h @@ -0,0 +1,39 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#ifndef UNIFIEDCACHE_TRANS_ASCEND_BUFFER_H +#define UNIFIEDCACHE_TRANS_ASCEND_BUFFER_H + +#include "trans/detail/reserved_buffer.h" + +namespace UC::Trans { + +class AscendBuffer : public ReservedBuffer { +public: + std::shared_ptr MakeDeviceBuffer(size_t size) override; + std::shared_ptr MakeHostBuffer(size_t size) override; +}; + +} // namespace UC::Trans + +#endif diff --git a/ucm/shared/trans/ascend/ascend_device.cc b/ucm/shared/trans/ascend/ascend_device.cc new file mode 100644 index 000000000..bfc92f987 --- /dev/null +++ b/ucm/shared/trans/ascend/ascend_device.cc @@ -0,0 +1,62 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include +#include "ascend_buffer.h" +#include "ascend_stream.h" +#include "trans/device.h" + +namespace UC::Trans { + +Status Device::Setup(int32_t deviceId) +{ + if (deviceId < 0) { return Status::Error(fmt::format("invalid device id({})", deviceId)); } + auto ret = aclrtSetDevice(deviceId); + if (ret == ACL_SUCCESS) { return Status::OK(); } + return Status{ret, std::to_string(ret)}; +} + +std::unique_ptr Device::MakeStream() +{ + std::unique_ptr stream = nullptr; + try { + stream = std::make_unique(); + } catch (...) { + return nullptr; + } + if (stream->Setup().Success()) { return stream; } + return nullptr; +} + +std::unique_ptr Device::MakeSMStream() { return nullptr; } + +std::unique_ptr Device::MakeBuffer() +{ + try { + return std::make_unique(); + } catch (...) { + return nullptr; + } +} + +} // namespace UC::Trans diff --git a/ucm/shared/trans/ascend/ascend_stream.cc b/ucm/shared/trans/ascend/ascend_stream.cc new file mode 100644 index 000000000..104612716 --- /dev/null +++ b/ucm/shared/trans/ascend/ascend_stream.cc @@ -0,0 +1,177 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include "ascend_stream.h" + +namespace UC::Trans { + +AscendStream::~AscendStream() +{ + if (cbThread_.joinable()) { + auto tid = cbThread_.native_handle(); + (void)aclrtUnSubscribeReport(tid, stream_); + stop_ = true; + cbThread_.join(); + } + if (stream_) { + (void)aclrtDestroyStream(stream_); + stream_ = nullptr; + } +} + +Status AscendStream::Setup() +{ + auto ret = aclrtCreateStream(&stream_); + if (ret != ACL_SUCCESS) [[unlikely]] { return Status{ret, std::to_string(ret)}; } + cbThread_ = std::thread([this] { + while (!this->stop_) { (void)aclrtProcessReport(10); } + }); + auto tid = cbThread_.native_handle(); + ret = aclrtSubscribeReport(tid, stream_); + if (ret != ACL_SUCCESS) [[unlikely]] { return Status{ret, std::to_string(ret)}; } + return Status::OK(); +} + +Status AscendStream::DeviceToHost(void* device, void* host, size_t size) +{ + auto ret = aclrtMemcpy(host, size, device, size, ACL_MEMCPY_DEVICE_TO_HOST); + if (ret == ACL_SUCCESS) { return Status::OK(); } + return Status{ret, std::to_string(ret)}; +} + +Status AscendStream::DeviceToHost(void* device[], void* host[], size_t size, size_t number) +{ + auto s = DeviceToHostAsync(device, host, size, number); + if (s.Failure()) [[unlikely]] { return s; } + return Synchronized(); +} + +Status AscendStream::DeviceToHost(void* device[], void* host, size_t size, size_t number) +{ + auto s = DeviceToHostAsync(device, host, size, number); + if (s.Failure()) [[unlikely]] { return s; } + return Synchronized(); +} + +Status AscendStream::DeviceToHostAsync(void* device, void* host, size_t size) +{ + auto ret = aclrtMemcpyAsync(host, size, device, size, ACL_MEMCPY_DEVICE_TO_HOST, stream_); + if (ret == ACL_SUCCESS) { return Status::OK(); } + return Status{ret, std::to_string(ret)}; +} + +Status AscendStream::DeviceToHostAsync(void* device[], void* host[], size_t size, size_t number) +{ + for (size_t i = 0; i < number; i++) { + auto s = DeviceToHostAsync(device[i], host[i], size); + if (s.Failure()) [[unlikely]] { return s; } + } + return Status::OK(); +} + +Status AscendStream::DeviceToHostAsync(void* device[], void* host, size_t size, size_t number) +{ + for (size_t i = 0; i < number; i++) { + auto pHost = (void*)(((int8_t*)host) + size * i); + auto s = DeviceToHostAsync(device[i], pHost, size); + if (s.Failure()) [[unlikely]] { return s; } + } + return Status::OK(); +} + +Status AscendStream::HostToDevice(void* host, void* device, size_t size) +{ + auto ret = aclrtMemcpy(device, size, host, size, ACL_MEMCPY_HOST_TO_DEVICE); + if (ret == ACL_SUCCESS) { return Status::OK(); } + return Status{ret, std::to_string(ret)}; +} + +Status AscendStream::HostToDevice(void* host[], void* device[], size_t size, size_t number) +{ + auto s = HostToDeviceAsync(host, device, size, number); + if (s.Failure()) [[unlikely]] { return s; } + return Synchronized(); +} + +Status AscendStream::HostToDevice(void* host, void* device[], size_t size, size_t number) +{ + auto s = HostToDeviceAsync(host, device, size, number); + if (s.Failure()) [[unlikely]] { return s; } + return Synchronized(); +} + +Status AscendStream::HostToDeviceAsync(void* host, void* device, size_t size) +{ + auto ret = aclrtMemcpyAsync(device, size, host, size, ACL_MEMCPY_HOST_TO_DEVICE, stream_); + if (ret == ACL_SUCCESS) { return Status::OK(); } + return Status{ret, std::to_string(ret)}; +} + +Status AscendStream::HostToDeviceAsync(void* host[], void* device[], size_t size, size_t number) +{ + for (size_t i = 0; i < number; i++) { + auto s = HostToDeviceAsync(host[i], device[i], 
size); + if (s.Failure()) [[unlikely]] { return s; } + } + return Status::OK(); +} + +Status AscendStream::HostToDeviceAsync(void* host, void* device[], size_t size, size_t number) +{ + for (size_t i = 0; i < number; i++) { + auto pHost = (void*)(((int8_t*)host) + size * i); + auto s = HostToDeviceAsync(pHost, device[i], size); + if (s.Failure()) [[unlikely]] { return s; } + } + return Status::OK(); +} + +using Closure = std::function; + +static void Trampoline(void* data) +{ + auto c = static_cast(data); + (*c)(true); + delete c; +} + +Status Trans::AscendStream::AppendCallback(std::function cb) +{ + auto c = new (std::nothrow) Closure{std::move(cb)}; + if (!c) [[unlikely]] { return Status::Error("out of memory for appending callback"); } + auto ret = aclrtLaunchCallback(Trampoline, (void*)c, ACL_CALLBACK_NO_BLOCK, stream_); + if (ret != ACL_SUCCESS) [[unlikely]] { + delete c; + return Status{ret, std::to_string(ret)}; + } + return Status::OK(); +} + +Status AscendStream::Synchronized() +{ + auto ret = aclrtSynchronizeStream(stream_); + if (ret == ACL_SUCCESS) { return Status::OK(); } + return Status{ret, std::to_string(ret)}; +} + +} // namespace UC::Trans diff --git a/ucm/shared/trans/ascend/ascend_stream.h b/ucm/shared/trans/ascend/ascend_stream.h new file mode 100644 index 000000000..8ae53d595 --- /dev/null +++ b/ucm/shared/trans/ascend/ascend_stream.h @@ -0,0 +1,64 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#ifndef UNIFIEDCACHE_TRANS_ASCEND_STREAM_H +#define UNIFIEDCACHE_TRANS_ASCEND_STREAM_H + +#include +#include +#include +#include "trans/stream.h" + +namespace UC::Trans { + +class AscendStream : public Stream { +protected: + aclrtStream stream_{nullptr}; + std::atomic_bool stop_{false}; + std::thread cbThread_; + +public: + ~AscendStream() override; + Status Setup() override; + + Status DeviceToHost(void* device, void* host, size_t size) override; + Status DeviceToHost(void* device[], void* host[], size_t size, size_t number) override; + Status DeviceToHost(void* device[], void* host, size_t size, size_t number) override; + Status DeviceToHostAsync(void* device, void* host, size_t size) override; + Status DeviceToHostAsync(void* device[], void* host[], size_t size, size_t number) override; + Status DeviceToHostAsync(void* device[], void* host, size_t size, size_t number) override; + + Status HostToDevice(void* host, void* device, size_t size) override; + Status HostToDevice(void* host[], void* device[], size_t size, size_t number) override; + Status HostToDevice(void* host, void* device[], size_t size, size_t number) override; + Status HostToDeviceAsync(void* host, void* device, size_t size) override; + Status HostToDeviceAsync(void* host[], void* device[], size_t size, size_t number) override; + Status HostToDeviceAsync(void* host, void* device[], size_t size, size_t number) override; + + Status AppendCallback(std::function cb) override; + Status Synchronized() override; +}; + +} // namespace UC::Trans + +#endif diff --git a/ucm/shared/trans/buffer.h b/ucm/shared/trans/buffer.h new file mode 100644 index 000000000..a73752513 --- /dev/null +++ b/ucm/shared/trans/buffer.h @@ -0,0 +1,50 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#ifndef UNIFIEDCACHE_TRANS_BUFFER_H +#define UNIFIEDCACHE_TRANS_BUFFER_H + +#include +#include "status/status.h" + +namespace UC::Trans { + +class Buffer { +public: + virtual ~Buffer() = default; + + virtual std::shared_ptr MakeDeviceBuffer(size_t size) = 0; + virtual Status MakeDeviceBuffers(size_t size, size_t number) = 0; + virtual std::shared_ptr GetDeviceBuffer(size_t size) = 0; + + virtual std::shared_ptr MakeHostBuffer(size_t size) = 0; + virtual Status MakeHostBuffers(size_t size, size_t number) = 0; + virtual std::shared_ptr GetHostBuffer(size_t size) = 0; + + static Status RegisterHostBuffer(void* host, size_t size, void** pDevice = nullptr); + static void UnregisterHostBuffer(void* host); +}; + +} // namespace UC::Trans + +#endif diff --git a/ucm/shared/trans/cpy/trans.py.cc b/ucm/shared/trans/cpy/trans.py.cc new file mode 100644 index 000000000..952fc8e69 --- /dev/null +++ b/ucm/shared/trans/cpy/trans.py.cc @@ -0,0 +1,192 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include +#include +#include "trans/device.h" + +namespace py = pybind11; + +namespace UC::Trans { + +using Ptr = uintptr_t; +using PtrArray = py::array_t; + +inline void ThrowIfFailed(const Status& s) +{ + if (s.Failure()) [[unlikely]] { throw std::runtime_error{s.ToString()}; } +} + +inline void DeviceToHost(Stream& self, Ptr src, Ptr dst, size_t size) +{ + ThrowIfFailed(self.DeviceToHost((void*)src, (void*)dst, size)); +} + +inline void DeviceToHostBatch(Stream& self, py::object src, py::object dst, size_t size, + size_t number) +{ + if (py::isinstance(src)) { + auto device = static_cast(src.cast().request().ptr); + auto host = static_cast(dst.cast().request().ptr); + ThrowIfFailed(self.DeviceToHost(device, host, size, number)); + } else { + auto device = static_cast((void*)src.cast()); + auto host = static_cast((void*)dst.cast()); + ThrowIfFailed(self.DeviceToHost(device, host, size, number)); + } +} + +inline void DeviceToHostGather(Stream& self, py::object src, Ptr dst, size_t size, size_t number) +{ + if (py::isinstance(src)) { + auto device = static_cast(src.cast().request().ptr); + ThrowIfFailed(self.DeviceToHost(device, (void*)dst, size, number)); + } else { + auto device = static_cast((void*)src.cast()); + ThrowIfFailed(self.DeviceToHost(device, (void*)dst, size, number)); + } +} + +inline void DeviceToHostAsync(Stream& self, Ptr src, Ptr dst, size_t size) +{ + ThrowIfFailed(self.DeviceToHostAsync((void*)src, (void*)dst, size)); +} + +inline void DeviceToHostBatchAsync(Stream& self, py::object src, py::object dst, size_t size, + size_t number) +{ + if (py::isinstance(src)) { + auto device = static_cast(src.cast().request().ptr); + auto host = static_cast(dst.cast().request().ptr); + ThrowIfFailed(self.DeviceToHostAsync(device, host, size, number)); + } else { + auto device = static_cast((void*)src.cast()); + auto host = static_cast((void*)dst.cast()); + ThrowIfFailed(self.DeviceToHostAsync(device, host, size, number)); + } +} + +inline void DeviceToHostGatherAsync(Stream& self, py::object src, Ptr dst, size_t size, + size_t number) +{ + if (py::isinstance(src)) { + auto device = static_cast(src.cast().request().ptr); + ThrowIfFailed(self.DeviceToHostAsync(device, (void*)dst, size, number)); + } else { + auto device = static_cast((void*)src.cast()); + ThrowIfFailed(self.DeviceToHostAsync(device, (void*)dst, size, number)); + } +} + +inline void HostToDevice(Stream& self, Ptr src, Ptr dst, size_t size) +{ + ThrowIfFailed(self.HostToDevice((void*)src, (void*)dst, size)); +} + +inline void HostToDeviceBatch(Stream& self, py::object src, py::object dst, size_t size, + size_t number) +{ + if (py::isinstance(src)) { + auto host = static_cast(src.cast().request().ptr); + auto device = static_cast(dst.cast().request().ptr); + ThrowIfFailed(self.HostToDevice(host, device, size, number)); + } else { + auto host = static_cast((void*)src.cast()); + auto device = static_cast((void*)dst.cast()); + ThrowIfFailed(self.HostToDevice(host, device, size, number)); + } +} + +inline void HostToDeviceScatter(Stream& self, Ptr src, py::object dst, size_t size, size_t number) +{ + if (py::isinstance(dst)) { + auto device = static_cast(dst.cast().request().ptr); + ThrowIfFailed(self.HostToDevice((void*)src, device, size, number)); + } else { + auto device = static_cast((void*)dst.cast()); + ThrowIfFailed(self.HostToDevice((void*)src, device, size, number)); + } +} + +inline void HostToDeviceAsync(Stream& self, Ptr src, Ptr dst, size_t size) +{ + ThrowIfFailed(self.HostToDeviceAsync((void*)src, 
(void*)dst, size)); +} + +inline void HostToDeviceBatchAsync(Stream& self, py::object src, py::object dst, size_t size, + size_t number) +{ + if (py::isinstance(src)) { + auto host = static_cast(src.cast().request().ptr); + auto device = static_cast(dst.cast().request().ptr); + ThrowIfFailed(self.HostToDeviceAsync(host, device, size, number)); + } else { + auto host = static_cast((void*)src.cast()); + auto device = static_cast((void*)dst.cast()); + ThrowIfFailed(self.HostToDeviceAsync(host, device, size, number)); + } +} + +inline void HostToDeviceScatterAsync(Stream& self, Ptr src, py::object dst, size_t size, + size_t number) +{ + if (py::isinstance(dst)) { + auto device = static_cast(dst.cast().request().ptr); + ThrowIfFailed(self.HostToDeviceAsync((void*)src, device, size, number)); + } else { + auto device = static_cast((void*)dst.cast()); + ThrowIfFailed(self.HostToDeviceAsync((void*)src, device, size, number)); + } +} + +} // namespace UC::Trans + +PYBIND11_MODULE(ucmtrans, m) +{ + using namespace UC::Trans; + m.attr("project") = UCM_PROJECT_NAME; + m.attr("version") = UCM_PROJECT_VERSION; + m.attr("commit_id") = UCM_COMMIT_ID; + m.attr("build_type") = UCM_BUILD_TYPE; + + auto s = py::class_>(m, "Stream"); + s.def("DeviceToHost", &DeviceToHost); + s.def("DeviceToHostBatch", &DeviceToHostBatch); + s.def("DeviceToHostGather", &DeviceToHostGather); + s.def("DeviceToHostAsync", &DeviceToHostAsync); + s.def("DeviceToHostBatchAsync", &DeviceToHostBatchAsync); + s.def("DeviceToHostGatherAsync", &DeviceToHostGatherAsync); + s.def("HostToDevice", &HostToDevice); + s.def("HostToDeviceBatch", &HostToDeviceBatch); + s.def("HostToDeviceScatter", &HostToDeviceScatter); + s.def("HostToDeviceAsync", &HostToDeviceAsync); + s.def("HostToDeviceBatchAsync", &HostToDeviceBatchAsync); + s.def("HostToDeviceScatterAsync", &HostToDeviceScatterAsync); + s.def("Synchronized", [](Stream& self) { ThrowIfFailed(self.Synchronized()); }); + + auto d = py::class_(m, "Device"); + d.def(py::init<>()); + d.def("Setup", [](Device& self, int32_t deviceId) { ThrowIfFailed(self.Setup(deviceId)); }); + d.def("MakeStream", &Device::MakeStream); + d.def("MakeSMStream", &Device::MakeSMStream); +} diff --git a/ucm/shared/trans/cuda/CMakeLists.txt b/ucm/shared/trans/cuda/CMakeLists.txt new file mode 100644 index 000000000..b98de9985 --- /dev/null +++ b/ucm/shared/trans/cuda/CMakeLists.txt @@ -0,0 +1,23 @@ +set(CUDA_ROOT "/usr/local/cuda/" CACHE PATH "Path to CUDA root directory") +set(CMAKE_CUDA_COMPILER ${CUDA_ROOT}/bin/nvcc) +set(CMAKE_CUDA_ARCHITECTURES 75 80 86 89 90) +enable_language(CUDA) +add_library(kernel OBJECT cuda_sm_kernel.cu) +target_compile_options(kernel PRIVATE + --diag-suppress=128 --diag-suppress=2417 --diag-suppress=2597 + -Wall -fPIC +) +add_library(trans STATIC + cuda_device.cc + cuda_buffer.cc + cuda_stream.cc + cuda_sm_stream.cc +) +target_include_directories(trans PUBLIC ${CUDA_ROOT}/include) +target_link_directories(trans PUBLIC ${CUDA_ROOT}/lib64) +target_link_libraries(trans PUBLIC + fmt + cudart + nvidia-ml + kernel +) diff --git a/ucm/shared/trans/cuda/cuda_buffer.cc b/ucm/shared/trans/cuda/cuda_buffer.cc new file mode 100644 index 000000000..11f25a7e2 --- /dev/null +++ b/ucm/shared/trans/cuda/cuda_buffer.cc @@ -0,0 +1,58 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#include "cuda_buffer.h" +#include + +namespace UC::Trans { + +std::shared_ptr CudaBuffer::MakeDeviceBuffer(size_t size) +{ + void* device = nullptr; + auto ret = cudaMalloc(&device, size); + if (ret == cudaSuccess) { return std::shared_ptr(device, cudaFree); } + return nullptr; +} + +std::shared_ptr CudaBuffer::MakeHostBuffer(size_t size) +{ + void* host = nullptr; + auto ret = cudaMallocHost(&host, size); + if (ret == cudaSuccess) { return std::shared_ptr(host, cudaFreeHost); } + return nullptr; +} + +Status Buffer::RegisterHostBuffer(void* host, size_t size, void** pDevice) +{ + auto ret = cudaHostRegister(host, size, cudaHostRegisterDefault); + if (ret != cudaSuccess) [[unlikely]] { return Status{ret, cudaGetErrorString(ret)}; } + if (pDevice) { + ret = cudaHostGetDevicePointer(pDevice, host, 0); + if (ret != cudaSuccess) [[unlikely]] { return Status{ret, cudaGetErrorString(ret)}; } + } + return Status::OK(); +} + +void Buffer::UnregisterHostBuffer(void* host) { cudaHostUnregister(host); } + +} // namespace UC::Trans diff --git a/ucm/shared/trans/cuda/cuda_buffer.h b/ucm/shared/trans/cuda/cuda_buffer.h new file mode 100644 index 000000000..fb3b136b4 --- /dev/null +++ b/ucm/shared/trans/cuda/cuda_buffer.h @@ -0,0 +1,39 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#ifndef UNIFIEDCACHE_TRANS_CUDA_BUFFER_H +#define UNIFIEDCACHE_TRANS_CUDA_BUFFER_H + +#include "trans/detail/reserved_buffer.h" + +namespace UC::Trans { + +class CudaBuffer : public ReservedBuffer { +public: + std::shared_ptr MakeDeviceBuffer(size_t size) override; + std::shared_ptr MakeHostBuffer(size_t size) override; +}; + +} // namespace UC::Trans + +#endif diff --git a/ucm/shared/trans/cuda/cuda_device.cc b/ucm/shared/trans/cuda/cuda_device.cc new file mode 100644 index 000000000..86132e998 --- /dev/null +++ b/ucm/shared/trans/cuda/cuda_device.cc @@ -0,0 +1,82 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#include +#include +#include "cuda_buffer.h" +#include "cuda_sm_stream.h" +#include "cuda_stream.h" +#include "trans/device.h" + +namespace UC::Trans { + +static void SetCpuAffinity(int32_t deviceId) +{ + nvmlDevice_t device; + auto ret = nvmlDeviceGetHandleByIndex(deviceId, &device); + if (ret != NVML_SUCCESS) { return; } + nvmlDeviceSetCpuAffinity(device); +} + +Status Device::Setup(int32_t deviceId) +{ + auto ret = cudaSetDevice(deviceId); + if (ret != cudaSuccess) { return Status{ret, cudaGetErrorString(ret)}; } + SetCpuAffinity(deviceId); + return Status::OK(); +} + +std::unique_ptr Device::MakeStream() +{ + std::unique_ptr stream = nullptr; + try { + stream = std::make_unique(); + } catch (...) { + return nullptr; + } + if (stream->Setup().Success()) { return stream; } + return nullptr; +} + +std::unique_ptr Device::MakeSMStream() +{ + std::unique_ptr stream = nullptr; + try { + stream = std::make_unique(); + } catch (...) { + return nullptr; + } + if (stream->Setup().Success()) { return stream; } + return nullptr; +} + +std::unique_ptr Device::MakeBuffer() +{ + try { + return std::make_unique(); + } catch (...) { + return nullptr; + } +} + +} // namespace UC::Trans diff --git a/ucm/shared/trans/cuda/cuda_sm_kernel.cu b/ucm/shared/trans/cuda/cuda_sm_kernel.cu new file mode 100644 index 000000000..595092525 --- /dev/null +++ b/ucm/shared/trans/cuda/cuda_sm_kernel.cu @@ -0,0 +1,116 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#include +#include "cuda_sm_kernel.h" + +namespace UC::Trans { + +#define CUDA_TRANS_UNIT_SIZE (sizeof(uint4) * 2) +#define CUDA_TRANS_BLOCK_NUMBER (32) +#define CUDA_TRANS_BLOCK_SIZE (256) +#define CUDA_TRANS_THREAD_NUMBER (CUDA_TRANS_BLOCK_NUMBER * CUDA_TRANS_BLOCK_SIZE) + +inline __device__ void CudaCopyUnit(const uint8_t* __restrict__ src, + volatile uint8_t* __restrict__ dst) +{ + uint4 lo, hi; + asm volatile("ld.global.cs.v4.b32 {%0,%1,%2,%3}, [%4];" + : "=r"(lo.x), "=r"(lo.y), "=r"(lo.z), "=r"(lo.w) + : "l"(src)); + asm volatile("ld.global.cs.v4.b32 {%0,%1,%2,%3}, [%4+16];" + : "=r"(hi.x), "=r"(hi.y), "=r"(hi.z), "=r"(hi.w) + : "l"(src)); + asm volatile("st.volatile.global.v4.b32 [%0], {%1,%2,%3,%4};" + : + : "l"(dst), "r"(lo.x), "r"(lo.y), "r"(lo.z), "r"(lo.w)); + asm volatile("st.volatile.global.v4.b32 [%0+16], {%1,%2,%3,%4};" + : + : "l"(dst), "r"(hi.x), "r"(hi.y), "r"(hi.z), "r"(hi.w)); +} + +__global__ void CudaCopyKernel(const void** src, void** dst, size_t size, size_t num) +{ + auto length = size * num; + auto offset = (blockIdx.x * blockDim.x + threadIdx.x) * CUDA_TRANS_UNIT_SIZE; + while (offset + CUDA_TRANS_UNIT_SIZE <= length) { + auto idx = offset / size; + auto off = offset % size; + auto host = ((const uint8_t*)src[idx]) + off; + auto device = ((uint8_t*)dst[idx]) + off; + CudaCopyUnit(host, device); + offset += CUDA_TRANS_THREAD_NUMBER * CUDA_TRANS_UNIT_SIZE; + } +} + +__global__ void CudaCopyKernel(const void** src, void* dst, size_t size, size_t num) +{ + auto length = size * num; + auto offset = (blockIdx.x * blockDim.x + threadIdx.x) * CUDA_TRANS_UNIT_SIZE; + while (offset + CUDA_TRANS_UNIT_SIZE <= length) { + auto idx = offset / size; + auto off = offset % size; + auto host = ((const uint8_t*)src[idx]) + off; + auto device = ((uint8_t*)dst) + offset; + CudaCopyUnit(host, device); + offset += CUDA_TRANS_THREAD_NUMBER * CUDA_TRANS_UNIT_SIZE; + } +} + +__global__ void CudaCopyKernel(const void* src, void** dst, size_t size, size_t num) +{ + auto length = size * num; + auto offset = (blockIdx.x * blockDim.x + threadIdx.x) * CUDA_TRANS_UNIT_SIZE; + while (offset + CUDA_TRANS_UNIT_SIZE <= length) { + auto idx = offset / size; + auto off = offset % size; + auto host = ((const uint8_t*)src) + offset; + auto device = ((uint8_t*)dst[idx]) + off; + CudaCopyUnit(host, device); + offset += CUDA_TRANS_THREAD_NUMBER * CUDA_TRANS_UNIT_SIZE; + } +} + +cudaError_t CudaSMCopyAsync(void* src[], void* 
dst[], size_t size, size_t number, + cudaStream_t stream) +{ + CudaCopyKernel<<>>( + (const void**)src, dst, size, number); + return cudaGetLastError(); +} + +cudaError_t CudaSMCopyAsync(void* src[], void* dst, size_t size, size_t number, cudaStream_t stream) +{ + CudaCopyKernel<<>>( + (const void**)src, dst, size, number); + return cudaGetLastError(); +} + +cudaError_t CudaSMCopyAsync(void* src, void* dst[], size_t size, size_t number, cudaStream_t stream) +{ + CudaCopyKernel<<>>( + (const void*)src, dst, size, number); + return cudaGetLastError(); +} + +} // namespace UC::Trans diff --git a/ucm/shared/trans/cuda/cuda_sm_kernel.h b/ucm/shared/trans/cuda/cuda_sm_kernel.h new file mode 100644 index 000000000..a161c82e4 --- /dev/null +++ b/ucm/shared/trans/cuda/cuda_sm_kernel.h @@ -0,0 +1,41 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#ifndef UNIFIEDCACHE_TRANS_CUDA_SM_KERNEL_H +#define UNIFIEDCACHE_TRANS_CUDA_SM_KERNEL_H + +#include +#include + +namespace UC::Trans { + +cudaError_t CudaSMCopyAsync(void* src[], void* dst[], size_t size, size_t number, + cudaStream_t stream); +cudaError_t CudaSMCopyAsync(void* src[], void* dst, size_t size, size_t number, + cudaStream_t stream); +cudaError_t CudaSMCopyAsync(void* src, void* dst[], size_t size, size_t number, + cudaStream_t stream); + +} // namespace UC::Trans + +#endif diff --git a/ucm/shared/trans/cuda/cuda_sm_stream.cc b/ucm/shared/trans/cuda/cuda_sm_stream.cc new file mode 100644 index 000000000..f2f341722 --- /dev/null +++ b/ucm/shared/trans/cuda/cuda_sm_stream.cc @@ -0,0 +1,57 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
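The CudaCopyKernel overloads above are grid-stride copy loops: each of the CUDA_TRANS_THREAD_NUMBER (32 x 256) threads walks the flat size * number byte range in 32-byte units (sizeof(uint4) * 2), mapping every offset to a block index and an intra-block offset, so scattered KV blocks can be gathered into one contiguous staging buffer or scattered back without one cudaMemcpy per block; a tail smaller than one unit is skipped, so callers are expected to pass unit-aligned block sizes. A byte-wise host-side reference of the same index mapping, shown only to make the addressing explicit (hypothetical helper, not in the patch):

// Reference of the gather variant: scattered src blocks -> one contiguous dst buffer.
#include <cstddef>
#include <cstdint>

void GatherReference(const void* const src[], void* dst, size_t size, size_t number)
{
    const size_t length = size * number;
    for (size_t offset = 0; offset < length; offset++) {
        const size_t idx = offset / size; // which source block this byte belongs to
        const size_t off = offset % size; // byte offset inside that block
        static_cast<uint8_t*>(dst)[offset] = static_cast<const uint8_t*>(src[idx])[off];
    }
}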
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#include "cuda_sm_stream.h" +#include "cuda_sm_kernel.h" + +namespace UC::Trans { + +Status CudaSmStream::DeviceToHostAsync(void* device[], void* host[], size_t size, size_t number) +{ + auto ret = CudaSMCopyAsync(device, host, size, number, stream_); + if (ret != cudaSuccess) [[unlikely]] { return Status{ret, cudaGetErrorString(ret)}; } + return Status::OK(); +} + +Status CudaSmStream::DeviceToHostAsync(void* device[], void* host, size_t size, size_t number) +{ + auto ret = CudaSMCopyAsync(device, host, size, number, stream_); + if (ret != cudaSuccess) [[unlikely]] { return Status{ret, cudaGetErrorString(ret)}; } + return Status::OK(); +} + +Status CudaSmStream::HostToDeviceAsync(void* host[], void* device[], size_t size, size_t number) +{ + auto ret = CudaSMCopyAsync(host, device, size, number, stream_); + if (ret != cudaSuccess) [[unlikely]] { return Status{ret, cudaGetErrorString(ret)}; } + return Status::OK(); +} + +Status CudaSmStream::HostToDeviceAsync(void* host, void* device[], size_t size, size_t number) +{ + auto ret = CudaSMCopyAsync(host, device, size, number, stream_); + if (ret != cudaSuccess) [[unlikely]] { return Status{ret, cudaGetErrorString(ret)}; } + return Status::OK(); +} + +} // namespace UC::Trans diff --git a/ucm/shared/trans/cuda/cuda_sm_stream.h b/ucm/shared/trans/cuda/cuda_sm_stream.h new file mode 100644 index 000000000..ab9817dc7 --- /dev/null +++ b/ucm/shared/trans/cuda/cuda_sm_stream.h @@ -0,0 +1,41 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
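CudaSmStream reuses the CudaStream state but routes the batched copies through the SM copy kernels above instead of cudaMemcpyAsync, so the host side of those transfers needs to be GPU-addressable memory, for example buffers from MakeHostBuffer or ranges pinned via RegisterHostBuffer. A usage sketch, assuming device 0 and pinned host blocks (not part of the patch):

// Sketch: batched host-to-device copy on the kernel-driven (SM) stream, then a blocking sync.
#include <cstddef>
#include "trans/device.h"

UC::Trans::Status SmStreamSketch(void* hostBlocks[], void* deviceBlocks[], size_t blockSize, size_t number)
{
    UC::Trans::Device device;
    auto status = device.Setup(0);
    if (status.Failure()) { return status; }
    auto stream = device.MakeSMStream(); // CudaSmStream on the CUDA build
    if (!stream) { return UC::Trans::Status::Error("failed to create SM stream"); }
    status = stream->HostToDeviceAsync(hostBlocks, deviceBlocks, blockSize, number);
    if (status.Failure()) { return status; }
    return stream->Synchronized(); // wait for the copy kernel to drain
}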
+ * */ +#ifndef UNIFIEDCACHE_TRANS_CUDA_SM_STREAM_H +#define UNIFIEDCACHE_TRANS_CUDA_SM_STREAM_H + +#include "cuda_stream.h" + +namespace UC::Trans { + +class CudaSmStream : public CudaStream { +public: + Status DeviceToHostAsync(void* device[], void* host[], size_t size, size_t number) override; + Status DeviceToHostAsync(void* device[], void* host, size_t size, size_t number) override; + Status HostToDeviceAsync(void* host[], void* device[], size_t size, size_t number) override; + Status HostToDeviceAsync(void* host, void* device[], size_t size, size_t number) override; +}; + +} // namespace UC::Trans + +#endif diff --git a/ucm/shared/trans/cuda/cuda_stream.cc b/ucm/shared/trans/cuda/cuda_stream.cc new file mode 100644 index 000000000..103dee53c --- /dev/null +++ b/ucm/shared/trans/cuda/cuda_stream.cc @@ -0,0 +1,158 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include "cuda_stream.h" + +namespace UC::Trans { + +Status CudaStream::Setup() +{ + auto ret = cudaStreamCreate(&stream_); + if (ret != cudaSuccess) { return Status{ret, cudaGetErrorString(ret)}; } + return Status::OK(); +} + +Status CudaStream::DeviceToHost(void* device, void* host, size_t size) +{ + auto ret = cudaMemcpy(host, device, size, cudaMemcpyDeviceToHost); + if (ret != cudaSuccess) [[unlikely]] { return Status{ret, cudaGetErrorString(ret)}; } + return Status::OK(); +} + +Status CudaStream::DeviceToHost(void* device[], void* host[], size_t size, size_t number) +{ + auto s = DeviceToHostAsync(device, host, size, number); + if (s.Failure()) [[unlikely]] { return s; } + return Synchronized(); +} + +Status CudaStream::DeviceToHost(void* device[], void* host, size_t size, size_t number) +{ + auto s = DeviceToHostAsync(device, host, size, number); + if (s.Failure()) [[unlikely]] { return s; } + return Synchronized(); +} + +Status CudaStream::DeviceToHostAsync(void* device, void* host, size_t size) +{ + auto ret = cudaMemcpyAsync(host, device, size, cudaMemcpyDeviceToHost, stream_); + if (ret != cudaSuccess) [[unlikely]] { return Status{ret, cudaGetErrorString(ret)}; } + return Status::OK(); +} + +Status CudaStream::DeviceToHostAsync(void* device[], void* host[], size_t size, size_t number) +{ + for (size_t i = 0; i < number; i++) { + auto s = DeviceToHostAsync(device[i], host[i], size); + if (s.Failure()) [[unlikely]] { return s; } + } + return Status::OK(); +} + +Status CudaStream::DeviceToHostAsync(void* device[], void* host, size_t size, size_t number) +{ + for (size_t i = 0; i < number; i++) { + auto pHost = (void*)(((int8_t*)host) + size * i); + auto s = DeviceToHostAsync(device[i], pHost, size); + if (s.Failure()) [[unlikely]] { return s; } + } + return Status::OK(); +} + +Status CudaStream::HostToDevice(void* host, void* device, size_t size) +{ + auto ret = cudaMemcpy(device, host, size, cudaMemcpyHostToDevice); + if (ret != cudaSuccess) [[unlikely]] { return Status{ret, cudaGetErrorString(ret)}; } + return Status::OK(); +} + +Status CudaStream::HostToDevice(void* host[], void* device[], size_t size, size_t number) +{ + auto s = HostToDeviceAsync(host, device, size, number); + if (s.Failure()) [[unlikely]] { return s; } + return Synchronized(); +} + +Status CudaStream::HostToDevice(void* host, void* device[], size_t size, size_t number) +{ + auto s = HostToDeviceAsync(host, device, size, number); + if (s.Failure()) [[unlikely]] { return s; } + return Synchronized(); +} + +Status CudaStream::HostToDeviceAsync(void* host, void* device, size_t size) +{ + auto ret = cudaMemcpyAsync(device, host, size, cudaMemcpyHostToDevice, stream_); + if (ret != cudaSuccess) [[unlikely]] { return Status{ret, cudaGetErrorString(ret)}; } + return Status::OK(); +} + +Status CudaStream::HostToDeviceAsync(void* host[], void* device[], size_t size, size_t number) +{ + for (size_t i = 0; i < number; i++) { + auto s = HostToDeviceAsync(host[i], device[i], size); + if (s.Failure()) [[unlikely]] { return s; } + } + return Status::OK(); +} + +Status Trans::CudaStream::HostToDeviceAsync(void* host, void* device[], size_t size, size_t number) +{ + for (size_t i = 0; i < number; i++) { + auto pHost = (void*)(((int8_t*)host) + size * i); + auto s = HostToDeviceAsync(pHost, device[i], size); + if (s.Failure()) [[unlikely]] { return s; } + } + return Status::OK(); +} + +using Closure = std::function; + +static void Trampoline(cudaStream_t stream, cudaError_t err, void* data) +{ + (void)stream; + auto c = 
static_cast(data); + (*c)(err == cudaSuccess); + delete c; +} + +Status Trans::CudaStream::AppendCallback(std::function cb) +{ + auto c = new (std::nothrow) Closure{std::move(cb)}; + if (!c) [[unlikely]] { return Status::Error("out of memory for appending callback"); } + auto ret = cudaStreamAddCallback(stream_, Trampoline, c, 0); + if (ret != cudaSuccess) [[unlikely]] { + delete c; + return Status{ret, cudaGetErrorString(ret)}; + } + return Status::OK(); +} + +Status Trans::CudaStream::Synchronized() +{ + auto ret = cudaStreamSynchronize(stream_); + if (ret != cudaSuccess) [[unlikely]] { return Status{ret, cudaGetErrorString(ret)}; } + return Status::OK(); +} + +} // namespace UC::Trans diff --git a/ucm/shared/trans/cuda/cuda_stream.h b/ucm/shared/trans/cuda/cuda_stream.h new file mode 100644 index 000000000..d327285d0 --- /dev/null +++ b/ucm/shared/trans/cuda/cuda_stream.h @@ -0,0 +1,59 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
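AppendCallback heap-allocates a copy of the user closure, hands it to cudaStreamAddCallback through the Trampoline shim, invokes it with the stream's success flag once previously queued work has drained, and then deletes it, so a capturing lambda can safely outlive the caller. CUDA forbids issuing CUDA API calls from such callbacks, so they are best kept to lightweight notification. A sketch (not part of the patch):

// Sketch: async D2H copy with a completion callback instead of a blocking Synchronized().
#include <cstdio>
#include "trans/stream.h"

UC::Trans::Status NotifySketch(UC::Trans::Stream& stream, void* device, void* host, size_t size)
{
    auto status = stream.DeviceToHostAsync(device, host, size);
    if (status.Failure()) { return status; }
    // Runs after all work queued on the stream so far; ok is false if the stream reported an error.
    return stream.AppendCallback([](bool ok) { std::printf("transfer %s\n", ok ? "done" : "failed"); });
}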
+ * */
+#ifndef UNIFIEDCACHE_TRANS_CUDA_STREAM_H
+#define UNIFIEDCACHE_TRANS_CUDA_STREAM_H
+
+#include <cuda_runtime.h>
+#include "trans/stream.h"
+
+namespace UC::Trans {
+
+class CudaStream : public Stream {
+protected:
+    cudaStream_t stream_;
+
+public:
+    Status Setup() override;
+
+    Status DeviceToHost(void* device, void* host, size_t size) override;
+    Status DeviceToHost(void* device[], void* host[], size_t size, size_t number) override;
+    Status DeviceToHost(void* device[], void* host, size_t size, size_t number) override;
+    Status DeviceToHostAsync(void* device, void* host, size_t size) override;
+    Status DeviceToHostAsync(void* device[], void* host[], size_t size, size_t number) override;
+    Status DeviceToHostAsync(void* device[], void* host, size_t size, size_t number) override;
+
+    Status HostToDevice(void* host, void* device, size_t size) override;
+    Status HostToDevice(void* host[], void* device[], size_t size, size_t number) override;
+    Status HostToDevice(void* host, void* device[], size_t size, size_t number) override;
+    Status HostToDeviceAsync(void* host, void* device, size_t size) override;
+    Status HostToDeviceAsync(void* host[], void* device[], size_t size, size_t number) override;
+    Status HostToDeviceAsync(void* host, void* device[], size_t size, size_t number) override;
+
+    Status AppendCallback(std::function<void(bool)> cb) override;
+    Status Synchronized() override;
+};
+
+} // namespace UC::Trans
+
+#endif
diff --git a/ucm/shared/trans/detail/indexer.h b/ucm/shared/trans/detail/indexer.h
new file mode 100644
index 000000000..dc6833df4
--- /dev/null
+++ b/ucm/shared/trans/detail/indexer.h
@@ -0,0 +1,98 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */ +#ifndef UNIFIEDCACHE_TRANS_INDEXER_H +#define UNIFIEDCACHE_TRANS_INDEXER_H + +#include +#include +#include +#include + +namespace UC::Trans { + +class Indexer { +public: + using Index = uint32_t; + static constexpr Index npos = std::numeric_limits::max(); + +private: + struct Node { + Index idx; + Index next; + }; + struct Pointer { + Index slot; + uint32_t ver; + }; + static_assert(sizeof(Pointer) == 8, "Pointer must be 64-bit"); + +public: + void Setup(const Index capacity) noexcept + { + this->capacity_ = capacity; + this->nodes_.resize(capacity + 1); + for (Index slot = 1; slot <= capacity; slot++) { + this->nodes_[slot].idx = slot - 1; + this->nodes_[slot].next = slot + 1; + } + this->nodes_[capacity].next = 0; + this->pointer_.store({1, 0}); + } + Index Acquire() noexcept + { + for (;;) { + auto ptr = this->pointer_.load(std::memory_order_acquire); + if (ptr.slot == 0) { return npos; } + auto next = this->nodes_[ptr.slot].next; + Pointer desired{next, ptr.ver + 1}; + if (this->pointer_.compare_exchange_weak(ptr, desired, std::memory_order_release, + std::memory_order_relaxed)) { + return this->nodes_[ptr.slot].idx; + } + } + } + void Release(const Index idx) noexcept + { + if (idx >= this->capacity_) { return; } + auto slot = idx + 1; + for (;;) { + auto ptr = this->pointer_.load(std::memory_order_acquire); + this->nodes_[slot].next = ptr.slot; + Pointer desired{slot, ptr.ver + 1}; + if (this->pointer_.compare_exchange_weak(ptr, desired, std::memory_order_release, + std::memory_order_relaxed)) { + return; + } + } + } + +private: + Index capacity_; + std::vector nodes_; + alignas(64) std::atomic pointer_; +}; + +} // namespace UC::Trans + +#endif diff --git a/ucm/shared/trans/detail/reserved_buffer.h b/ucm/shared/trans/detail/reserved_buffer.h new file mode 100644 index 000000000..98eba3201 --- /dev/null +++ b/ucm/shared/trans/detail/reserved_buffer.h @@ -0,0 +1,99 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
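The Indexer above is a fixed-capacity, lock-free free list of slot indices: the list head is a packed 64-bit {slot, version} pair updated with compare-and-swap, and the version counter protects concurrent Acquire/Release against the ABA problem. A small usage sketch (not part of the patch):

// Sketch: hand out pool slots from several threads and return them when done.
#include <thread>
#include <vector>
#include "trans/detail/indexer.h"

void IndexerSketch()
{
    UC::Trans::Indexer indexer;
    indexer.Setup(8); // slots 0..7 start out free
    std::vector<std::thread> workers;
    for (int t = 0; t < 4; t++) {
        workers.emplace_back([&indexer] {
            auto idx = indexer.Acquire(); // npos when the pool is exhausted
            if (idx == UC::Trans::Indexer::npos) { return; }
            // ... use slot idx ...
            indexer.Release(idx);
        });
    }
    for (auto& w : workers) { w.join(); }
}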
+ * */ +#ifndef UNIFIEDCACHE_TRANS_RESERVED_BUFFER_H +#define UNIFIEDCACHE_TRANS_RESERVED_BUFFER_H + +#include +#include "indexer.h" +#include "trans/buffer.h" + +namespace UC::Trans { + +class ReservedBuffer : public Buffer { + struct { + Indexer indexer; + std::shared_ptr buffers; + size_t size; + } hostBuffers_, deviceBuffers_; + + template + static std::shared_ptr GetBufferFrom(Buffers& buffers) + { + auto pos = buffers.indexer.Acquire(); + if (pos != buffers.indexer.npos) { + auto addr = static_cast(buffers.buffers.get()); + auto ptr = static_cast(addr + buffers.size * pos); + return std::shared_ptr(ptr, + [&buffers, pos](void*) { buffers.indexer.Release(pos); }); + } + return nullptr; + } + +public: + Status MakeDeviceBuffers(size_t size, size_t number) override + { + auto totalSize = size * number; + auto buffers = this->MakeDeviceBuffer(totalSize); + if (!buffers) { + return Status::Error(fmt::format("out of memory({}) on device", totalSize)); + } + this->deviceBuffers_.size = size; + this->deviceBuffers_.buffers = buffers; + this->deviceBuffers_.indexer.Setup(number); + return Status::OK(); + } + + std::shared_ptr GetDeviceBuffer(size_t size) override + { + if (size <= this->deviceBuffers_.size) { + auto buffer = GetBufferFrom(this->deviceBuffers_); + if (buffer) { return buffer; } + } + return this->MakeDeviceBuffer(size); + } + + Status MakeHostBuffers(size_t size, size_t number) override + { + auto totalSize = size * number; + auto buffers = this->MakeHostBuffer(totalSize); + if (!buffers) { return Status::Error(fmt::format("out of memory({}) on host", totalSize)); } + this->hostBuffers_.size = size; + this->hostBuffers_.buffers = buffers; + this->hostBuffers_.indexer.Setup(number); + return Status::OK(); + } + + std::shared_ptr GetHostBuffer(size_t size) override + { + if (size <= this->hostBuffers_.size) { + auto buffer = GetBufferFrom(this->hostBuffers_); + if (buffer) { return buffer; } + } + return this->MakeDeviceBuffer(size); + } +}; + +} // namespace UC::Trans + +#endif diff --git a/ucm/shared/trans/device.h b/ucm/shared/trans/device.h new file mode 100644 index 000000000..a6801c8a0 --- /dev/null +++ b/ucm/shared/trans/device.h @@ -0,0 +1,42 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
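ReservedBuffer carves one large allocation into fixed-size slots tracked by an Indexer; GetDeviceBuffer/GetHostBuffer hand out a pooled slot when the requested size fits and fall back to a fresh allocation otherwise, and the returned shared_ptr's deleter puts the slot back into the pool rather than freeing memory. A pooling sketch through the CUDA subclass, though any ReservedBuffer backend behaves the same way (not part of the patch):

// Sketch: reserve a pool of pinned host staging buffers up front, then borrow from it per transfer.
#include "cuda_buffer.h"

UC::Trans::Status PoolSketch()
{
    UC::Trans::CudaBuffer buffer;
    auto status = buffer.MakeHostBuffers(256 * 1024, 64); // 64 slots of 256 KiB pinned memory
    if (status.Failure()) { return status; }
    auto staging = buffer.GetHostBuffer(128 * 1024);      // fits a slot, so no new allocation
    if (!staging) { return UC::Trans::Status::Error("host pool exhausted"); }
    // ... stage data through staging ...
    return UC::Trans::Status::OK(); // releasing staging returns the slot, it does not free the pool
}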
+ * */
+#ifndef UNIFIEDCACHE_TRANS_DEVICE_H
+#define UNIFIEDCACHE_TRANS_DEVICE_H
+
+#include "buffer.h"
+#include "stream.h"
+
+namespace UC::Trans {
+
+class Device {
+public:
+    Status Setup(int32_t deviceId);
+    std::unique_ptr<Stream> MakeStream();
+    std::unique_ptr<Stream> MakeSMStream();
+    std::unique_ptr<Buffer> MakeBuffer();
+};
+
+} // namespace UC::Trans
+
+#endif
diff --git a/ucm/shared/trans/simu/CMakeLists.txt b/ucm/shared/trans/simu/CMakeLists.txt
new file mode 100644
index 000000000..9404eead1
--- /dev/null
+++ b/ucm/shared/trans/simu/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_library(trans STATIC
+    simu_device.cc
+    simu_buffer.cc
+    simu_stream.cc
+)
+target_link_libraries(trans PUBLIC
+    fmt
+)
diff --git a/ucm/shared/trans/simu/simu_buffer.cc b/ucm/shared/trans/simu/simu_buffer.cc
new file mode 100644
index 000000000..4af607d8f
--- /dev/null
+++ b/ucm/shared/trans/simu/simu_buffer.cc
@@ -0,0 +1,78 @@
+/**
+ * MIT License
+ *
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * */ +#include "simu_buffer.h" +#include +#include +#include +#include "trans/buffer.h" + +namespace UC::Trans { + +static void* AllocMemory(size_t size, int8_t initVal) +{ + auto ptr = malloc(size); + if (!ptr) { return nullptr; } + std::memset(ptr, initVal, size); + return ptr; +} + +static void FreeMemory(void* ptr) { free(ptr); } + +template +static std::shared_ptr GetBuffer(Buffers& buffers) +{ + auto pos = buffers.indexer.Acquire(); + if (pos != buffers.indexer.npos) { + auto addr = static_cast(buffers.buffers.get()); + auto ptr = static_cast(addr + buffers.size * pos); + return std::shared_ptr(ptr, [&buffers, pos](void*) { buffers.indexer.Release(pos); }); + } + return nullptr; +} + +std::shared_ptr SimuBuffer::MakeDeviceBuffer(size_t size) +{ + constexpr int8_t deviceInitVal = 0xd; + auto device = AllocMemory(size, deviceInitVal); + if (!device) { return nullptr; } + return std::shared_ptr(device, FreeMemory); +} + +std::shared_ptr SimuBuffer::MakeHostBuffer(size_t size) +{ + constexpr int8_t hostInitVal = 0xa; + auto device = AllocMemory(size, hostInitVal); + if (!device) { return nullptr; } + return std::shared_ptr(device, FreeMemory); +} + +Status Buffer::RegisterHostBuffer(void* host, size_t size, void** pDevice) +{ + if (pDevice) { *pDevice = host; } + return Status::OK(); +} + +void Buffer::UnregisterHostBuffer(void* host) {} + +} // namespace UC::Trans diff --git a/ucm/shared/trans/simu/simu_buffer.h b/ucm/shared/trans/simu/simu_buffer.h new file mode 100644 index 000000000..269e4d5aa --- /dev/null +++ b/ucm/shared/trans/simu/simu_buffer.h @@ -0,0 +1,39 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#ifndef UNIFIEDCACHE_TRANS_SIMU_BUFFER_H +#define UNIFIEDCACHE_TRANS_SIMU_BUFFER_H + +#include "trans/detail/reserved_buffer.h" + +namespace UC::Trans { + +class SimuBuffer : public ReservedBuffer { +public: + std::shared_ptr MakeDeviceBuffer(size_t size) override; + std::shared_ptr MakeHostBuffer(size_t size) override; +}; + +} // namespace UC::Trans + +#endif diff --git a/ucm/shared/trans/simu/simu_device.cc b/ucm/shared/trans/simu/simu_device.cc new file mode 100644 index 000000000..351be42e6 --- /dev/null +++ b/ucm/shared/trans/simu/simu_device.cc @@ -0,0 +1,60 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#include +#include "simu_buffer.h" +#include "simu_stream.h" +#include "trans/device.h" + +namespace UC::Trans { + +Status Device::Setup(int32_t deviceId) +{ + if (deviceId < 0) { return Status::Error(fmt::format("invalid device id({})", deviceId)); } + return Status::OK(); +} + +std::unique_ptr Device::MakeStream() +{ + std::unique_ptr stream = nullptr; + try { + stream = std::make_unique(); + } catch (...) { + return nullptr; + } + if (stream->Setup().Success()) { return stream; } + return nullptr; +} + +std::unique_ptr Device::MakeSMStream() { return MakeStream(); } + +std::unique_ptr Device::MakeBuffer() +{ + try { + return std::make_unique(); + } catch (...) { + return nullptr; + } +} + +} // namespace UC::Trans diff --git a/ucm/shared/trans/simu/simu_stream.cc b/ucm/shared/trans/simu/simu_stream.cc new file mode 100644 index 000000000..0d6efaa52 --- /dev/null +++ b/ucm/shared/trans/simu/simu_stream.cc @@ -0,0 +1,175 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
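The simu backend above implements the same Device, Stream and Buffer interfaces with plain malloc and memcpy, so the transfer layer and its unit tests can run without an accelerator; which backend gets linked is decided at build time by the RUNTIME_ENVIRONMENT option. A backend-neutral round-trip sketch, assuming the Buffer base class exposes the single-shot MakeHostBuffer/MakeDeviceBuffer factories seen in both backends (not part of the patch):

// Sketch: copy a buffer host -> "device" -> host and verify it, with no backend-specific code.
#include <cstring>
#include "trans/device.h"

bool RoundTripSketch()
{
    UC::Trans::Device device;
    if (device.Setup(0).Failure()) { return false; }
    auto buffer = device.MakeBuffer();
    auto stream = device.MakeStream();
    if (!buffer || !stream) { return false; }
    constexpr size_t size = 4096;
    auto src = buffer->MakeHostBuffer(size);
    auto dst = buffer->MakeHostBuffer(size);
    auto dev = buffer->MakeDeviceBuffer(size);
    if (!src || !dst || !dev) { return false; }
    std::memset(src.get(), 0x5a, size);
    if (stream->HostToDevice(src.get(), dev.get(), size).Failure()) { return false; }
    if (stream->DeviceToHost(dev.get(), dst.get(), size).Failure()) { return false; }
    return std::memcmp(src.get(), dst.get(), size) == 0;
}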
+ * */ +#include "simu_stream.h" +#include + +namespace UC::Trans { + +void SimuStream::AsyncWorker() +{ + for (;;) { + std::unique_lock lock{this->mutex_}; + this->condition_.wait(lock, [this] { return this->stop_ || !this->tasks_.empty(); }); + if (this->stop_) { return; } + if (this->tasks_.empty()) { continue; } + auto task = std::move(this->tasks_.front()); + this->tasks_.pop_front(); + lock.unlock(); + task(); + } +} + +void SimuStream::EnqueueTask(std::function task) +{ + std::lock_guard lock{this->mutex_}; + this->tasks_.emplace_back(std::move(task)); + this->condition_.notify_one(); +} + +SimuStream::~SimuStream() +{ + { + std::lock_guard lock{this->mutex_}; + this->stop_ = true; + this->condition_.notify_all(); + } + if (this->thread_.joinable()) { this->thread_.join(); } +} + +Status SimuStream::Setup() +{ + this->thread_ = std::thread{&SimuStream::AsyncWorker, this}; + return Status::OK(); +} + +Status SimuStream::DeviceToHost(void* device, void* host, size_t size) +{ + std::memcpy(host, device, size); + return Status::OK(); +} + +Status SimuStream::DeviceToHost(void* device[], void* host[], size_t size, size_t number) +{ + for (size_t i = 0; i < number; i++) { + auto s = this->DeviceToHost(device[i], host[i], size); + if (s.Failure()) { return s; } + } + return Status::OK(); +} + +Status SimuStream::DeviceToHost(void* device[], void* host, size_t size, size_t number) +{ + for (size_t i = 0; i < number; i++) { + auto pDevice = device[i]; + auto pHost = (void*)(((int8_t*)host) + size * i); + auto s = this->DeviceToHost(pDevice, pHost, size); + if (s.Failure()) { return s; } + } + return Status::OK(); +} + +Status SimuStream::DeviceToHostAsync(void* device, void* host, size_t size) +{ + this->EnqueueTask([=] { this->DeviceToHost(device, host, size); }); + return Status::OK(); +} + +Status SimuStream::DeviceToHostAsync(void* device[], void* host[], size_t size, size_t number) +{ + this->EnqueueTask([=] { this->DeviceToHost(device, host, size, number); }); + return Status::OK(); +} + +Status SimuStream::DeviceToHostAsync(void* device[], void* host, size_t size, size_t number) +{ + this->EnqueueTask([=] { this->DeviceToHost(device, host, size, number); }); + return Status::OK(); +} + +Status SimuStream::HostToDevice(void* host, void* device, size_t size) +{ + std::memcpy(device, host, size); + return Status::OK(); +} + +Status SimuStream::HostToDevice(void* host[], void* device[], size_t size, size_t number) +{ + for (size_t i = 0; i < number; i++) { + auto s = this->HostToDevice(host[i], device[i], size); + if (s.Failure()) { return s; } + } + return Status::OK(); +} + +Status SimuStream::HostToDevice(void* host, void* device[], size_t size, size_t number) +{ + for (size_t i = 0; i < number; i++) { + auto pHost = (void*)(((int8_t*)host) + size * i); + auto pDevice = device[i]; + auto s = this->HostToDevice(pHost, pDevice, size); + if (s.Failure()) { return s; } + } + return Status::OK(); +} + +Status SimuStream::HostToDeviceAsync(void* host, void* device, size_t size) +{ + this->EnqueueTask([=] { this->HostToDevice(host, device, size); }); + return Status::OK(); +} + +Status SimuStream::HostToDeviceAsync(void* host[], void* device[], size_t size, size_t number) +{ + this->EnqueueTask([=] { this->HostToDevice(host, device, size, number); }); + return Status::OK(); +} + +Status SimuStream::HostToDeviceAsync(void* host, void* device[], size_t size, size_t number) +{ + this->EnqueueTask([=] { this->HostToDevice(host, device, size, number); }); + return Status::OK(); +} + +Status 
SimuStream::AppendCallback(std::function cb) +{ + this->EnqueueTask([=] { cb(true); }); + return Status::OK(); +} + +Status SimuStream::Synchronized() +{ + std::mutex mutex; + std::condition_variable cv; + bool finish = false; + this->EnqueueTask([&] { + std::lock_guard lock{mutex}; + finish = true; + cv.notify_one(); + }); + std::unique_lock lock{mutex}; + cv.wait(lock, [&] { return finish; }); + return Status::OK(); +} + +} // namespace UC::Trans diff --git a/ucm/shared/trans/simu/simu_stream.h b/ucm/shared/trans/simu/simu_stream.h new file mode 100644 index 000000000..57028d978 --- /dev/null +++ b/ucm/shared/trans/simu/simu_stream.h @@ -0,0 +1,70 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
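SimuStream serializes every *Async call onto one worker thread so completion order matches submission order, and Synchronized() simply enqueues a marker task and waits on a condition variable until the worker reaches it. That makes it a convenient stand-in for stream-dependent logic in tests, as in this sketch (not part of the patch):

// Sketch: verify async copy ordering against the simulated stream, no GPU required.
#include <cstring>
#include "simu_stream.h"

bool SimuStreamSketch()
{
    UC::Trans::SimuStream stream;
    if (stream.Setup().Failure()) { return false; }
    char src[64];
    char dst[64] = {};
    std::memset(src, 0x42, sizeof(src));
    if (stream.HostToDeviceAsync(src, dst, sizeof(src)).Failure()) { return false; } // queued on the worker
    if (stream.Synchronized().Failure()) { return false; }                           // drains the queue
    return std::memcmp(src, dst, sizeof(src)) == 0;
}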
+ * */ +#ifndef UNIFIEDCACHE_TRANS_SIMU_STREAM_H +#define UNIFIEDCACHE_TRANS_SIMU_STREAM_H + +#include +#include +#include +#include +#include +#include "trans/stream.h" + +namespace UC::Trans { + +class SimuStream : public Stream { + std::thread thread_; + std::list> tasks_; + std::mutex mutex_; + std::condition_variable condition_; + bool stop_{false}; + + void AsyncWorker(); + void EnqueueTask(std::function task); + +public: + ~SimuStream() override; + Status Setup() override; + + Status DeviceToHost(void* device, void* host, size_t size) override; + Status DeviceToHost(void* device[], void* host[], size_t size, size_t number) override; + Status DeviceToHost(void* device[], void* host, size_t size, size_t number) override; + Status DeviceToHostAsync(void* device, void* host, size_t size) override; + Status DeviceToHostAsync(void* device[], void* host[], size_t size, size_t number) override; + Status DeviceToHostAsync(void* device[], void* host, size_t size, size_t number) override; + + Status HostToDevice(void* host, void* device, size_t size) override; + Status HostToDevice(void* host[], void* device[], size_t size, size_t number) override; + Status HostToDevice(void* host, void* device[], size_t size, size_t number) override; + Status HostToDeviceAsync(void* host, void* device, size_t size) override; + Status HostToDeviceAsync(void* host[], void* device[], size_t size, size_t number) override; + Status HostToDeviceAsync(void* host, void* device[], size_t size, size_t number) override; + + Status AppendCallback(std::function cb) override; + Status Synchronized() override; +}; + +} // namespace UC::Trans + +#endif diff --git a/ucm/shared/trans/stream.h b/ucm/shared/trans/stream.h new file mode 100644 index 000000000..425617968 --- /dev/null +++ b/ucm/shared/trans/stream.h @@ -0,0 +1,57 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#ifndef UNIFIEDCACHE_TRANS_STREAM_H +#define UNIFIEDCACHE_TRANS_STREAM_H + +#include +#include "status/status.h" + +namespace UC::Trans { + +class Stream { +public: + virtual ~Stream() = default; + virtual Status Setup() = 0; + + virtual Status DeviceToHost(void* device, void* host, size_t size) = 0; + virtual Status DeviceToHost(void* device[], void* host[], size_t size, size_t number) = 0; + virtual Status DeviceToHost(void* device[], void* host, size_t size, size_t number) = 0; + virtual Status DeviceToHostAsync(void* device, void* host, size_t size) = 0; + virtual Status DeviceToHostAsync(void* device[], void* host[], size_t size, size_t number) = 0; + virtual Status DeviceToHostAsync(void* device[], void* host, size_t size, size_t number) = 0; + + virtual Status HostToDevice(void* host, void* device, size_t size) = 0; + virtual Status HostToDevice(void* host[], void* device[], size_t size, size_t number) = 0; + virtual Status HostToDevice(void* host, void* device[], size_t size, size_t number) = 0; + virtual Status HostToDeviceAsync(void* host, void* device, size_t size) = 0; + virtual Status HostToDeviceAsync(void* host[], void* device[], size_t size, size_t number) = 0; + virtual Status HostToDeviceAsync(void* host, void* device[], size_t size, size_t number) = 0; + + virtual Status AppendCallback(std::function cb) = 0; + virtual Status Synchronized() = 0; +}; + +} // namespace UC::Trans + +#endif diff --git a/ucm/sparse/vendor/CMakeLists.txt b/ucm/shared/vendor/CMakeLists.txt similarity index 85% rename from ucm/sparse/vendor/CMakeLists.txt rename to ucm/shared/vendor/CMakeLists.txt index 67b7f4935..10d813cc6 100644 --- a/ucm/sparse/vendor/CMakeLists.txt +++ b/ucm/shared/vendor/CMakeLists.txt @@ -8,6 +8,6 @@ function(EnableDept name url tag) endfunction() include(FetchContent) -EnableDept(pybind11 https://github.com/pybind/pybind11.git v2.13.6) EnableDept(fmt https://github.com/fmtlib/fmt.git 11.2.0) EnableDept(spdlog https://github.com/gabime/spdlog.git v1.15.3) +EnableDept(pybind11 https://github.com/pybind/pybind11.git v3.0.1) diff --git a/ucm/sparse/CMakeLists.txt b/ucm/sparse/CMakeLists.txt index a0033323b..8e39e3589 100644 --- a/ucm/sparse/CMakeLists.txt +++ b/ucm/sparse/CMakeLists.txt @@ -1,4 +1,3 @@ -add_subdirectory(vendor) add_subdirectory(esa) add_subdirectory(gsa) add_subdirectory(kvcomp) diff --git a/ucm/sparse/base.py b/ucm/sparse/base.py index 918c7e71a..ed62ab30c 100644 --- a/ucm/sparse/base.py +++ b/ucm/sparse/base.py @@ -130,6 +130,7 @@ def attention_begin( value: torch.Tensor, layer_name: str, forward_context: ForwardContext, + phase: Optional[str] = None, ) -> None: """ This is called at the beginning of "unified_attention". @@ -146,6 +147,7 @@ def attention_finished( attn_output: torch.Tensor, layer_name: str, forward_context: ForwardContext, + phase: Optional[str] = None, ) -> None: """ This is called at the end of "unified_attention". 
@@ -196,7 +198,5 @@ def build_sparse_meta( """ pass - def allocate_slots( - self, request, num_slots_sparsed, coordinator, block_pool, kv_cache_groups - ): + def allocate_slots(self, kv_cache_manager, request, num_slots_sparsed): pass diff --git a/ucm/sparse/esa/CMakeLists.txt b/ucm/sparse/esa/CMakeLists.txt index a7bdc9456..9d2afcdb4 100644 --- a/ucm/sparse/esa/CMakeLists.txt +++ b/ucm/sparse/esa/CMakeLists.txt @@ -1 +1,50 @@ +if(BUILD_NUMA) + message(STATUS "Building numactl library...") + + set(NUMA_INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/numa_install) + FetchContent_Declare( + numactl + URL https://github.com/numactl/numactl/releases/download/v2.0.16/numactl-2.0.16.tar.gz + TLS_VERIFY OFF + ) + FetchContent_MakeAvailable(numactl) + if(NOT EXISTS "${NUMA_INSTALL_DIR}/lib/libnuma.so") + message(STATUS "Configuring numactl...") + execute_process( + COMMAND ./configure --prefix=${NUMA_INSTALL_DIR} + WORKING_DIRECTORY ${numactl_SOURCE_DIR} + RESULT_VARIABLE numa_configure_result + OUTPUT_VARIABLE numa_configure_output + ERROR_VARIABLE numa_configure_error + ) + if(NOT numa_configure_result EQUAL 0) + message(FATAL_ERROR "Failed to configure numactl. \n" + "Result: ${numa_configure_result}\n" + "STDOUT: ${numa_configure_output}\n" + "STDERR: ${numa_configure_error}\n") + endif() + + message(STATUS "Building and installing numactl...") + execute_process( + COMMAND make install -j8 + WORKING_DIRECTORY ${numactl_SOURCE_DIR} + RESULT_VARIABLE numa_install_result + OUTPUT_VARIABLE numa_install_output + ERROR_VARIABLE numa_install_error + ) + if(NOT numa_install_result EQUAL 0) + message(FATAL_ERROR "Failed to build and install numactl. \n" + "Result: ${numa_install_result}\n" + "STDOUT: ${numa_install_output}\n" + "STDERR: ${numa_install_error}\n") + endif() + else() + message(STATUS "Found already built libnuma. 
Skipping build.") + endif() + + add_definitions(-DNUMA_ENABLED) +else() + message(STATUS "Skipping numactl build...") +endif() + add_subdirectory(retrieval) diff --git a/ucm/sparse/esa/esa.py b/ucm/sparse/esa/esa.py index 52756a220..ac36e54c5 100644 --- a/ucm/sparse/esa/esa.py +++ b/ucm/sparse/esa/esa.py @@ -1,9 +1,10 @@ import hashlib import math import pickle +from collections import defaultdict from dataclasses import dataclass from functools import cache -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import numpy as np import torch @@ -13,8 +14,9 @@ from vllm.forward_context import ForwardContext from vllm.sequence import SequenceStage from vllm.v1.core.kv_cache_manager import KVCacheBlocks -from vllm.v1.request import Request +from vllm.v1.request import Request, RequestStatus +from ucm.integration.vllm.ucm_connector import RequestHasher from ucm.sparse.base import ( INVALID_SLOT, UcmSparseBase, @@ -23,7 +25,9 @@ ) from ucm.sparse.esa.retrieval import retrieval_backend from ucm.sparse.esa.retrieval.retrieval_worker import RetrievalWorker +from ucm.sparse.kvstar.utils import get_bind_cpus_for_rank from ucm.store.ucmstore import Task, UcmKVStoreBase +from ucm.utils import Config ReqType = Union[str, int] HashType = Union[str, int] @@ -58,10 +62,8 @@ class ReqMeta: query_start_loc: int prompt_token_ids: list[int] output_token_ids: list[int] - - @property - def step(self) -> int: - return self.num_output_tokens + is_preempt: bool + ucm_block_hashes: list[str] @property def num_prompt_tokens(self) -> int: @@ -72,19 +74,13 @@ def num_output_tokens(self) -> int: return len(self.output_token_ids) @property - def stage(self) -> SequenceStage: - return ( - SequenceStage.DECODE - if self.num_output_tokens > 0 - else SequenceStage.PREFILL - ) + def num_tokens(self) -> int: + return self.num_prompt_tokens + self.num_output_tokens @property def is_last_chunk(self) -> bool: - return ( - self.num_computed_tokens + self.num_scheduled_tokens - >= self.num_prompt_tokens - ) + # NOTE: both decode and last chunk-prefill meet `self.num_computed_tokens + self.num_scheduled_tokens >= self.num_tokens` + return self.num_computed_tokens + self.num_scheduled_tokens >= self.num_tokens @dataclass @@ -106,6 +102,8 @@ def add_request( query_start_loc: int, prompt_token_ids: list[int], output_token_ids: list[int], + is_preempt: bool, + ucm_block_hashes: list[str], ) -> None: meta = ReqMeta( @@ -117,6 +115,8 @@ def add_request( query_start_loc=query_start_loc, prompt_token_ids=prompt_token_ids, output_token_ids=output_token_ids, + is_preempt=is_preempt, + ucm_block_hashes=ucm_block_hashes, ) self.requests.append(meta) @@ -126,7 +126,9 @@ def get_offset(block_shape, rank, tp_size, precision, layer_id, is_v, is_mla) -> block_size, num_key_heads_per_tp, head_size = block_shape k_min_data_block_size = block_size * num_key_heads_per_tp * head_size * precision v_min_data_block_size = k_min_data_block_size if not is_mla else 0 - layer_size = (k_min_data_block_size + v_min_data_block_size) * tp_size + layer_size = (k_min_data_block_size + v_min_data_block_size) * ( + tp_size if not is_mla else 1 + ) if is_mla: k_offset = layer_size * layer_id else: @@ -136,30 +138,64 @@ def get_offset(block_shape, rank, tp_size, precision, layer_id, is_v, is_mla) -> @cache -def md5(input) -> int: - input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) - md5_bytes = hashlib.md5(input_bytes).digest() - return int.from_bytes(md5_bytes, byteorder="big") +def 
get_sparse_range(init_window_sz, local_window_sz, prompt_len, block_size): + num_blocks_upper_bound = math.ceil(prompt_len / block_size) + sparse_range = num_blocks_upper_bound - init_window_sz - local_window_sz + return sparse_range @cache -def block_hash_func(parent_block_hash, curr_block_token_ids): - if not parent_block_hash: - parent_block_hash = md5("UCMHASHSEED") - curr_block_token_ids_tuple = tuple(curr_block_token_ids) - return md5((parent_block_hash, curr_block_token_ids_tuple)) +def compute_parent_block_hash(model_name, world_size, dtype, seed_rank=0) -> int: + meta = f"{model_name}:{world_size}:{dtype}:{seed_rank}" + meta_bytes = meta.encode("utf-8") + h_seed = hashlib.md5(meta_bytes + b"UCM_HASH_SEED").digest() + return int.from_bytes(h_seed, byteorder="big") + + +@cache +def compute_layer_offset( + block_data_size: int, + layer_id: int, + is_v: bool, + is_mla: bool, +) -> int: + layer_data_size = block_data_size if is_mla else block_data_size * 2 + + k_offset = layer_data_size * layer_id + + if is_mla: + return k_offset + + v_offset = k_offset + block_data_size + return v_offset if is_v else k_offset def task_hash_func(block_ids, store_type, tensor_type): return hash((tuple(block_ids), store_type, tensor_type)) +def diff_two_map(map1: dict, map2: dict): + keys2 = map2.keys() + values2 = map2.values() + keys2_set = set(keys2) + values2_set = set(values2) + diff_map = {} + updated_map = {} + for k1, v1 in map1.items(): + if k1 in keys2 and v1 in values2: + updated_map[k1] = v1 + keys2_set.remove(k1) + values2_set.remove(v1) + for k2, v2 in zip(keys2_set, values2_set): + diff_map[k2] = v2 + updated_map[k2] = v2 + return updated_map, diff_map + + class ReqStatePerLayer: # handle single request per layer - def __init__( self, - req_meta: ReqMeta, layer_name: str, rank: int, tp_size: int, @@ -167,6 +203,7 @@ def __init__( vllm_config: VllmConfig, retrieval_worker: Optional[RetrievalWorker] = None, repre_pool: Optional[ReprePool] = None, + esa_cfg: Optional[Dict[str, Any]] = None, ): self.layer_name = layer_name self.layer_id = int(layer_name.split(".")[2]) @@ -176,7 +213,7 @@ def __init__( self.store_instance = store_instance self.retrieval_worker: Optional[RetrievalWorker] = retrieval_worker self.retrieval_task = None - self.req_meta = req_meta + self.req_meta = None self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size self.k_cache = None @@ -184,37 +221,21 @@ def __init__( self.rank = rank self.tp_size = tp_size self.tasks: Dict[str, Task] = {} - self.esa_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_sparse_config" - ]["ESA"] + self.esa_cfg = esa_cfg self.indexes: Optional[NDArray[np.int64]] = None self.block_hashes = None self.pre_topk_block_hashes: Dict[int, str] = {} self.sparse_range: int = 0 self.init_static_flag = False + self.init_window = None + self.local_window = None self.num_key_heads = vllm_config.model_config.get_num_kv_heads( vllm_config.parallel_config ) self.head_size = vllm_config.model_config.get_head_size() - self.sparse_range = self.get_sparse_prefill_range() - - def set_block_hashes(self, token_ids): - if self.block_hashes is not None: - return - self.block_hashes = [] - parent_block_hash_value = None - for start in range(0, len(token_ids), self.block_size): - end = start + self.block_size - block_token_ids = token_ids[start:end] - if len(block_token_ids) < self.block_size: - break - curr_block_token_ids_tuple = tuple(block_token_ids) - block_hash = block_hash_func( - parent_block_hash_value, 
curr_block_token_ids_tuple - ) - self.block_hashes.append(str(block_hash)) - parent_block_hash_value = block_hash + self.is_mla = self.vllm_config.model_config.is_deepseek_mla + self.step = 0 def update_meta(self, req_meta: ReqMeta): self.req_meta = req_meta @@ -222,45 +243,34 @@ def update_meta(self, req_meta: ReqMeta): def launch_transfer_task(self, transfer_type, block_hashes, vllm_block_ids): fn = getattr(self.store_instance, transfer_type) length = len(block_hashes) - block_shape = (self.block_size, self.num_key_heads, self.head_size) - precision = self.k_cache.storage().element_size() - # TODO: consider is_mla here - is_mla = False - - block_shape = tuple(block_shape) - offsets_k = [ - get_offset( - block_shape, - self.rank, - self.tp_size, - precision, - self.layer_id, - is_v=False, - is_mla=is_mla, - ) - ] * length - offsets_v = [ - get_offset( - block_shape, - self.rank, - self.tp_size, - precision, - self.layer_id, - is_v=True, - is_mla=is_mla, - ) - ] * length + precision = self.vllm_config.model_config.dtype.itemsize + block_data_size = self.k_cache[0].numel() * precision + + offset_k = compute_layer_offset( + block_data_size, + self.layer_id, + is_v=False, + is_mla=self.is_mla, + ) + offsets_k = [offset_k] * length key_src_tensors = [self.k_cache[id_] for id_ in vllm_block_ids] - value_src_tensors = [self.v_cache[id_] for id_ in vllm_block_ids] - task_k = fn(block_hashes, offsets_k, key_src_tensors) - task_v = fn(block_hashes, offsets_v, value_src_tensors) - task_k_hash = task_hash_func(block_hashes, transfer_type, "key") self.tasks[task_k_hash] = task_k - task_v_hash = task_hash_func(block_hashes, transfer_type, "value") - self.tasks[task_v_hash] = task_v + + if not self.is_mla: + offset_v = compute_layer_offset( + block_data_size, + self.layer_id, + is_v=True, + is_mla=self.is_mla, + ) + offsets_v = [offset_v] * length + value_src_tensors = [self.v_cache[id_] for id_ in vllm_block_ids] + task_v = fn(block_hashes, offsets_v, value_src_tensors) + task_v_hash = task_hash_func(block_hashes, transfer_type, "value") + self.tasks[task_v_hash] = task_v def extract_block_repre(self, vllm_block_ids): return self.k_cache[vllm_block_ids].mean(1) @@ -271,13 +281,16 @@ def maybe_register_static_data(self, forward_context: ForwardContext): attn = forward_context.no_compile_layers[self.layer_name] kv_cache = attn.kv_cache[forward_context.virtual_engine] # TODO not mla - self.k_cache = kv_cache[0] - self.v_cache = kv_cache[1] - self.set_block_hashes(self.req_meta.prompt_token_ids) + if self.is_mla: + self.k_cache = kv_cache + else: + self.k_cache = kv_cache[0] + self.v_cache = kv_cache[1] + self.block_hashes = self.req_meta.ucm_block_hashes self.init_static_flag = True def wait_transfer_task_done(self): - assert len(self.tasks) > 0 + # assert len(self.tasks) > 0 for task_hash, task in self.tasks.items(): # TODO: handle exceptions ret = self.store_instance.wait(task) @@ -289,7 +302,12 @@ def start_retrieval(self, batch_query, forward_context): query = batch_query[query_start_loc : query_start_loc + query_len] ntokens, num_q_heads, _ = query.shape if num_q_heads > self.num_key_heads: - query = query.view(ntokens, self.num_key_heads, -1, self.head_size) + query = query.view( + ntokens, + self.num_key_heads, + num_q_heads // self.num_key_heads, + self.head_size, + ) query = query.mean(2) elif num_q_heads < self.num_key_heads: query = torch.repeat_interleave(query, self.num_key_heads // num_q_heads, 1) @@ -307,52 +325,40 @@ def wait_retrieval_and_start_load(self): rel_block_ids = 
[self.slots_to_relative_indexes[int(e)] for e in choosed_slots] block_hashes = [self.block_hashes[id_] for id_ in rel_block_ids] top_k = int(self.sparse_range * self.esa_cfg["sparse_ratio"]) - sparse_vllm_block_ids = self.req_meta.vllm_block_ids[:top_k] - - # load delta - diff_vllm_block_ids = set(sparse_vllm_block_ids) - diff_block_hashes = set(block_hashes) - if len(self.pre_topk_block_hashes) == 0: - self.pre_topk_block_hashes = { - blk_id: blk_hash - for (blk_id, blk_hash) in zip(sparse_vllm_block_ids, block_hashes) - } - else: - matched = {} - for k in sparse_vllm_block_ids: - if ( - k in self.pre_topk_block_hashes - and self.pre_topk_block_hashes[k] in diff_block_hashes - ): - matched[k] = self.pre_topk_block_hashes[k] - diff_vllm_block_ids.remove(k) - diff_block_hashes.remove(matched[k]) - self.pre_topk_block_hashes = matched - for diff_blk_id, diff_blk_hash in zip( - diff_vllm_block_ids, diff_block_hashes - ): - self.pre_topk_block_hashes[diff_blk_id] = diff_blk_hash - - self.launch_transfer_task( - "load", list(diff_block_hashes), list(diff_vllm_block_ids) + vllm_block_ids = self.req_meta.vllm_block_ids[ + self.esa_cfg["init_window_sz"] : self.esa_cfg["init_window_sz"] + top_k + ] + ## 1. load delta + target_map = { + b_id: b_hash for b_id, b_hash in zip(vllm_block_ids, block_hashes) + } + self.pre_topk_block_hashes, diff_blocks = diff_two_map( + self.pre_topk_block_hashes, target_map ) - self.retrieval_task = None - - def get_sparse_prefill_range(self): - if (self.req_meta.num_prompt_tokens % self.block_size) == 0: - sparse_range = ( - self.req_meta.num_prompt_tokens // self.block_size - - self.esa_cfg["local_window_sz"] + if diff_blocks: + self.launch_transfer_task( + "load", list(diff_blocks.values()), list(diff_blocks.keys()) ) - else: - sparse_range = math.floor( - self.req_meta.num_prompt_tokens / self.block_size - ) - (self.esa_cfg["local_window_sz"] - 1) - return sparse_range + + ## 2. 
load all + # self.launch_transfer_task( + # "load", block_hashes, vllm_block_ids + # ) + + self.retrieval_task = None def block_repre_data(self): + self.sparse_range = get_sparse_range( + self.esa_cfg["init_window_sz"], + self.esa_cfg["local_window_sz"], + self.req_meta.num_prompt_tokens, + self.block_size, + ) vllm_block_ids = self.req_meta.vllm_block_ids - vllm_block_ids_dump = vllm_block_ids[: self.sparse_range] + vllm_block_ids_dump = vllm_block_ids[ + self.esa_cfg["init_window_sz"] : self.esa_cfg["init_window_sz"] + + self.sparse_range + ] repre = self.extract_block_repre(vllm_block_ids_dump) repre_flat = repre.reshape(repre.shape[0], -1) new_slots = self.repre_pool.allocate(self.sparse_range) @@ -360,8 +366,27 @@ def block_repre_data(self): for i, slot in enumerate(new_slots): self.slots_to_relative_indexes[slot] = og_len + i self.slots.extend(new_slots) - vals = repre_flat.to("cpu", non_blocking=True, dtype=torch.float32) + vals = repre_flat.to("cpu", dtype=torch.float32) data[self.layer_id][new_slots] = vals + # NOTE: in Preemption, local_window_start != -self.esa_cfg['local_window_sz'] + local_window_start = self.esa_cfg["init_window_sz"] + self.sparse_range + + if not self.is_mla: + self.init_window = ( + self.k_cache[vllm_block_ids[: self.esa_cfg["init_window_sz"]]].clone(), + self.v_cache[vllm_block_ids[: self.esa_cfg["init_window_sz"]]].clone(), + ) + self.local_window = ( + self.k_cache[vllm_block_ids[local_window_start:]].clone(), + self.v_cache[vllm_block_ids[local_window_start:]].clone(), + ) + else: + self.init_window = self.k_cache[ + vllm_block_ids[: self.esa_cfg["init_window_sz"]] + ].clone() + self.local_window = self.k_cache[ + vllm_block_ids[local_window_start:] + ].clone() def attention_begin( self, @@ -371,11 +396,34 @@ def attention_begin( forward_context: ForwardContext, ) -> None: self.maybe_register_static_data(forward_context) - if self.req_meta.step % self.esa_cfg["retrieval_stride"] == 1: - if self.req_meta.step == 1: + if self.step % self.esa_cfg["retrieval_stride"] == 1: + if self.step == 1: + vllm_block_ids = self.req_meta.vllm_block_ids + # NOTE: in Preemption, local_window_start != -self.esa_cfg['local_window_sz'] + if not self.is_mla: + local_window_sz = self.local_window[0].shape[0] + self.k_cache[vllm_block_ids[: self.esa_cfg["init_window_sz"]]] = ( + self.init_window[0] + ) + self.v_cache[vllm_block_ids[: self.esa_cfg["init_window_sz"]]] = ( + self.init_window[1] + ) + self.k_cache[vllm_block_ids[-local_window_sz:]] = self.local_window[ + 0 + ] + self.v_cache[vllm_block_ids[-local_window_sz:]] = self.local_window[ + 1 + ] + else: + local_window_sz = self.local_window.shape[0] + self.k_cache[vllm_block_ids[: self.esa_cfg["init_window_sz"]]] = ( + self.init_window + ) + self.k_cache[vllm_block_ids[-local_window_sz:]] = self.local_window self.start_retrieval(query, forward_context) self.wait_retrieval_and_start_load() - self.wait_transfer_task_done() + if len(self.tasks) > 0: + self.wait_transfer_task_done() def attention_finished( self, @@ -385,18 +433,16 @@ def attention_finished( attn_output: torch.Tensor, forward_context: ForwardContext, ) -> None: - should_save = ( - self.req_meta.stage == SequenceStage.PREFILL and self.req_meta.is_last_chunk - ) - if should_save: - self.block_repre_data() + if self.step == 0: + if self.req_meta.is_last_chunk: + self.block_repre_data() + self.step += 1 else: - if self.req_meta.step == 0: - return - if self.req_meta.step % self.esa_cfg["retrieval_stride"] == 2: + if self.step % self.esa_cfg["retrieval_stride"] == 
2: self.start_retrieval(query, forward_context) - if self.req_meta.step % self.esa_cfg["retrieval_stride"] == 0: + if self.step % self.esa_cfg["retrieval_stride"] == 0: self.wait_retrieval_and_start_load() + self.step += 1 class ESA(UcmSparseBase): @@ -407,16 +453,25 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): self.rank = vllm_config.parallel_config.rank self.tp_size = vllm_config.parallel_config.tensor_parallel_size if role == UcmSparseRole.WORKER: - self.connector = get_kv_transfer_group().connector + self.connector = get_kv_transfer_group().connector.store else: self.connector = None - self.esa_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_sparse_config" - ]["ESA"] + self.esa_cfg = ( + Config(vllm_config.kv_transfer_config) + .get_config() + .get("ucm_sparse_config") + .get("ESA") + ) self.total_num_hidden_layers = ( vllm_config.model_config.hf_config.num_hidden_layers ) - + self.is_mla = vllm_config.model_config.is_deepseek_mla + self._sparse_metadata_prefill: ESASparseMetaData = ESASparseMetaData() + self._sparse_metadata_decode: ESASparseMetaData = ESASparseMetaData() + self._sparse_metadata: ESASparseMetaData = ESASparseMetaData() + self.request_hasher = RequestHasher(vllm_config, 0) + self.block_size = vllm_config.cache_config.block_size + self.block_hashes: dict[int, dict[int, list[str]]] = {} global data if data is None: @@ -438,14 +493,35 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): ReprePool(num_slots) for _ in range(self.total_num_hidden_layers) ] + self.local_tp_rank = vllm_config.parallel_config.rank + self.total_tp_size = vllm_config.parallel_config.tensor_parallel_size + ratio = 0.75 + + bind_info_list, alloc_numa_ids = get_bind_cpus_for_rank( + self.total_tp_size, self.local_tp_rank, ratio=ratio + ) + + bind_info_dict = defaultdict(list) + for item in bind_info_list: + bind_info_dict[item[1]].append(item[0]) + bind_info_dict = dict(bind_info_dict) + self.retrieval_workers: List[RetrievalWorker] = [] for i in range(self.total_num_hidden_layers): backend_src = data[i] - backend = retrieval_backend.RetrievalWorkerBackend(backend_src) + backend = retrieval_backend.RetrievalWorkerBackend( + backend_src, bind_info_dict + ) self.retrieval_workers.append(RetrievalWorker(backend)) - def create_layerwise_req_state(self, req_meta, layer_name): + self.preempt_req_output_tokens: Dict[ReqType, int] = {} + + def get_or_create_layerwise_req_state(self, req_meta, layer_name): layer_id = int(layer_name.split(".")[2]) + if req_meta.is_preempt: + layer_state = self.req_states[req_meta.request_id][layer_id] + layer_state.repre_pool.free(layer_state.slots) + self.req_states[req_meta.request_id][layer_id] = None if req_meta.request_id not in self.req_states: if self.req_states.get(req_meta.request_id) is None: self.req_states[req_meta.request_id] = [ @@ -453,7 +529,6 @@ def create_layerwise_req_state(self, req_meta, layer_name): ] * self.total_num_hidden_layers if self.req_states[req_meta.request_id][layer_id] is None: self.req_states[req_meta.request_id][layer_id] = ReqStatePerLayer( - req_meta, layer_name, self.rank, self.tp_size, @@ -461,9 +536,17 @@ def create_layerwise_req_state(self, req_meta, layer_name): self._vllm_config, self.retrieval_workers[layer_id], self.layer_pools[layer_id], + self.esa_cfg, ) return self.req_states[req_meta.request_id][layer_id] + def create_req_state_attention_begin( + self, req_meta, layer_name, query, key, value, forward_context + ): + req_state = 
self.get_or_create_layerwise_req_state(req_meta, layer_name) + req_state.update_meta(req_meta) + req_state.attention_begin(query, key, value, forward_context) + def attention_begin( self, query: torch.Tensor, @@ -471,11 +554,32 @@ def attention_begin( value: torch.Tensor, layer_name: str, forward_context: ForwardContext, + phase: Optional[str] = None, ) -> None: - for req_meta in self._sparse_metadata.requests: - req_state = self.create_layerwise_req_state(req_meta, layer_name) - req_state.update_meta(req_meta) - req_state.attention_begin(query, key, value, forward_context) + if not self.is_mla: + for req_meta in self._sparse_metadata.requests: + self.create_req_state_attention_begin( + req_meta, layer_name, query, key, value, forward_context + ) + else: + if phase == "prefill": + for req_meta in self._sparse_metadata_prefill.requests: + self.create_req_state_attention_begin( + req_meta, layer_name, query, key, value, forward_context + ) + if phase == "decode": + for req_meta in self._sparse_metadata_decode.requests: + self.create_req_state_attention_begin( + req_meta, layer_name, query, key, value, forward_context + ) + + def update_req_state_attention_end( + self, req_meta, layer_name, query, key, value, attn_output, forward_context + ): + layer_id = int(layer_name.split(".")[2]) + req_state = self.req_states[req_meta.request_id][layer_id] + req_state.update_meta(req_meta) + req_state.attention_finished(query, key, value, attn_output, forward_context) def attention_finished( self, @@ -485,13 +589,42 @@ def attention_finished( attn_output: torch.Tensor, layer_name: str, forward_context: ForwardContext, + phase: Optional[str] = None, ) -> None: - for req_meta in self._sparse_metadata.requests: - req_state = self.create_layerwise_req_state(req_meta, layer_name) - req_state.update_meta(req_meta) - req_state.attention_finished( - query, key, value, attn_output, forward_context - ) + if not self.is_mla: + for req_meta in self._sparse_metadata.requests: + self.update_req_state_attention_end( + req_meta, + layer_name, + query, + key, + value, + attn_output, + forward_context, + ) + else: + if phase == "prefill": + for req_meta in self._sparse_metadata_prefill.requests: + self.update_req_state_attention_end( + req_meta, + layer_name, + query, + key, + value, + attn_output, + forward_context, + ) + if phase == "decode": + for req_meta in self._sparse_metadata_decode.requests: + self.update_req_state_attention_end( + req_meta, + layer_name, + query, + key, + value, + attn_output, + forward_context, + ) def is_sparsed_request(self, req): return ( @@ -499,10 +632,66 @@ def is_sparsed_request(self, req): >= self._vllm_config.cache_config.block_size * self.esa_cfg["min_blocks"] ) + def set_block_hashes(self, req_id, token_ids): + if req_id not in self.block_hashes: + self.block_hashes[req_id] = {} + + if self.rank in self.block_hashes[req_id]: + return + + self.block_hashes[req_id][self.rank] = [] + + parent_block_hash_value = compute_parent_block_hash( + self._vllm_config.model_config.model, + self._vllm_config.parallel_config.world_size, + self._vllm_config.model_config.dtype, + seed_rank=0, + ) + + num_total_blocks = math.ceil(len(token_ids) / self.block_size) + for start in range(0, len(token_ids), self.block_size): + end = start + self.block_size + block_idx = start // self.block_size + if block_idx >= num_total_blocks - self.esa_cfg["local_window_sz"]: + continue + block_token_ids = token_ids[start:end] + if len(block_token_ids) < self.block_size: + break + curr_block_token_ids_tuple = 
tuple(block_token_ids) + hash_value = self.request_hasher( + (parent_block_hash_value, curr_block_token_ids_tuple) + ) + if block_idx >= self.esa_cfg["init_window_sz"]: + self.block_hashes[req_id][self.rank].append(str(hash_value)) + + parent_block_hash_value = hash_value + + if self.rank != 0 and not self.is_mla: + self.newqrequest_hasher = RequestHasher(self._vllm_config, self.rank) + for i, ucm_block_id in enumerate(self.block_hashes[req_id][self.rank]): + self.block_hashes[req_id][self.rank][i] = str( + self.newqrequest_hasher(ucm_block_id) + ) + def build_sparse_meta( self, scheduler_output, requests, input_batch, attn_metadata ) -> UcmSparseMetadata: - sparse_meta = ESASparseMetaData() + self._sparse_metadata_prefill = ESASparseMetaData() + self._sparse_metadata_decode = ESASparseMetaData() + self._sparse_metadata = ESASparseMetaData() + + num_sched = scheduler_output.num_scheduled_tokens + req_ids = list(getattr(input_batch, "req_ids", [])) + decode_ids = [rid for rid in req_ids if num_sched.get(rid, 0) == 1] + decode_set = set(decode_ids) + cached_reqs = scheduler_output.scheduled_cached_reqs + preempt_reqs = set() + if cached_reqs: + for req, is_preempt in zip( + cached_reqs.req_ids, cached_reqs.resumed_from_preemption + ): + if is_preempt: + preempt_reqs.add(req) for ( req_id, num_scheduled_tokens, @@ -510,19 +699,63 @@ def build_sparse_meta( req = requests[req_id] if not self.is_sparsed_request(req): continue + self.set_block_hashes(int(req_id), req.prompt_token_ids) if isinstance(attn_metadata, dict): attn_metadata = next(iter(attn_metadata.values())) - sparse_meta.add_request( - req_id, - input_batch.req_id_to_index[req_id], - num_scheduled_tokens, - req.num_computed_tokens, - req.block_ids[0], - attn_metadata.query_start_loc[input_batch.req_id_to_index[req_id]], - req.prompt_token_ids, - req.output_token_ids, - ) - self._sparse_metadata = sparse_meta + + if not self.is_mla: + self._sparse_metadata.add_request( + req_id, + input_batch.req_id_to_index[req_id], + num_scheduled_tokens, + req.num_computed_tokens, + req.block_ids[0], + attn_metadata.query_start_loc[input_batch.req_id_to_index[req_id]], + req.prompt_token_ids, + req.output_token_ids, + req_id in preempt_reqs, + self.block_hashes[int(req_id)][self.rank], + ) + + else: + attn_metadata_prefill = getattr(attn_metadata, "prefill", None) + attn_metadata_decode = getattr(attn_metadata, "decode", None) + + # determine whether this req is in the decode phase or prefill + if req_id in decode_set: + if attn_metadata_decode: + req_id_to_index_decode = input_batch.req_id_to_index[req_id] + self._sparse_metadata_decode.add_request( + req_id, + req_id_to_index_decode, + num_scheduled_tokens, + req.num_computed_tokens, + req.block_ids[0], + attn_metadata.query_start_loc[req_id_to_index_decode], + req.prompt_token_ids, + req.output_token_ids, + req_id in preempt_reqs, + self.block_hashes[int(req_id)][self.rank], + ) + + else: + req_id_to_index_prefill = ( + input_batch.req_id_to_index[req_id] - attn_metadata.num_decodes + ) + self._sparse_metadata_prefill.add_request( + req_id, + req_id_to_index_prefill, + num_scheduled_tokens, + req.num_computed_tokens, + req.block_ids[0], + attn_metadata_prefill.query_start_loc[req_id_to_index_prefill], + req.prompt_token_ids, + req.output_token_ids, + req_id in preempt_reqs, + self.block_hashes[int(req_id)][self.rank], + ) + + # self._sparse_metadata = sparse_meta def request_begin(self, request_id: ReqType, prompt_token_ids: List[int]): pass @@ -534,46 +767,71 @@ def request_finished_in_worker(self, request_id: ReqType):
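Note on the chained hashing above (a minimal sketch, not part of the patch): each full block's hash folds in the previous block's hash and a model-level seed, so identical prompt prefixes yield identical store keys; the ESA variant additionally skips the init/local window blocks. The hasher argument below stands in for RequestHasher, which maps a picklable object to an int:

    def chain_block_hashes(token_ids, block_size, seed, hasher):
        hashes, parent = [], seed  # seed from compute_parent_block_hash(...)
        for start in range(0, len(token_ids), block_size):
            block = tuple(token_ids[start:start + block_size])
            if len(block) < block_size:  # a trailing partial block is never hashed
                break
            parent = hasher((parent, block))
            hashes.append(str(parent))
        return hashes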
layer_state.repre_pool.free(layer_state.slots) del self.req_states[request_id] + def request_finished_in_scheduler(self, request_id: Union[int, str]): + """ + This is called inside "Scheduler->finish_requests" function. + Generate the metadata required by UcmSparse instance at worker-side. + """ + pass + def estimate_num_slots_sparsed(self, request: Request) -> int: - if request.num_output_tokens == 0 or not self.is_sparsed_request(request): + if request.status == RequestStatus.PREEMPTED: + self.preempt_req_output_tokens[request.request_id] = ( + request.num_output_tokens + ) + + if request.request_id in self.preempt_req_output_tokens: + num_output_tokens = ( + request.num_output_tokens + - self.preempt_req_output_tokens[request.request_id] + ) + else: + num_output_tokens = request.num_output_tokens + + if ( + request.num_computed_tokens == 0 + or num_output_tokens == 0 + or not self.is_sparsed_request(request) + ): return INVALID_SLOT prompt_len = request.num_prompt_tokens output_len = request.num_output_tokens block_size = self._vllm_config.cache_config.block_size + sparse_range = get_sparse_range( + self.esa_cfg["init_window_sz"], + self.esa_cfg["local_window_sz"], + prompt_len, + block_size, + ) if (flaw := prompt_len % block_size) == 0: - sparse_range = prompt_len // block_size - self.esa_cfg["local_window_sz"] - local_window = block_size * self.esa_cfg["local_window_sz"] + output_len + local_window_tokens = block_size * self.esa_cfg["local_window_sz"] else: - sparse_range = math.floor(prompt_len / block_size) - ( + local_window_tokens = flaw + block_size * ( self.esa_cfg["local_window_sz"] - 1 ) - local_window = ( - flaw + block_size * (self.esa_cfg["local_window_sz"] - 1) + output_len - ) - return ( - int(sparse_range * self.esa_cfg["sparse_ratio"]) * block_size + local_window + compressed_prompt_len = ( + self.esa_cfg["init_window_sz"] * block_size + + int(sparse_range * self.esa_cfg["sparse_ratio"]) * block_size + + local_window_tokens ) + return compressed_prompt_len + output_len + + def allocate_slots(self, kv_cache_manager, request, num_slots_sparsed): + coordinator = kv_cache_manager.coordinator + block_pool = kv_cache_manager.block_pool + kv_cache_groups = kv_cache_manager.kv_cache_config.kv_cache_groups + + if request.request_id in self.preempt_req_output_tokens: + # handle preempt: get the TRUE output_len + num_output_tokens = ( + request.num_output_tokens + - self.preempt_req_output_tokens[request.request_id] + ) + else: + num_output_tokens = request.num_output_tokens - def allocate_slots( - self, request, num_slots_sparsed, coordinator, block_pool, kv_cache_groups - ): - block_size = self._vllm_config.cache_config.block_size - num_blocks_need = math.ceil(num_slots_sparsed / block_size) - allocated_blocks = coordinator.get_blocks(request.request_id)[0] - returned_blocks = [] - kept_blocks = [] - num_blocks_original = len(allocated_blocks) - for i, block in enumerate(allocated_blocks): - if i >= num_blocks_original - num_blocks_need: - kept_blocks.append(block) - else: - returned_blocks.append(block) - block_pool._maybe_evict_cached_block(block) - block_pool.free_blocks(returned_blocks) - - coordinator.single_type_managers[0].req_to_blocks[ - request.request_id - ] = kept_blocks + if num_output_tokens == 1: + kv_cache_manager.free(request) new_computed_block_list = tuple([] for _ in range(len(kv_cache_groups))) num_blocks_to_allocate = coordinator.get_num_blocks_to_allocate( @@ -581,7 +839,10 @@ def allocate_slots( num_tokens=num_slots_sparsed, 
new_computed_blocks=new_computed_block_list, ) - if num_blocks_to_allocate > block_pool.get_num_free_blocks(): + manual_preempt = False + # manual_preempt = (request.num_output_tokens % 10) == 0 + if manual_preempt or num_blocks_to_allocate > block_pool.get_num_free_blocks(): return None coordinator.allocate_new_blocks(request.request_id, num_slots_sparsed) - return KVCacheBlocks(tuple([kept_blocks])) + blocks = coordinator.single_type_managers[0].req_to_blocks[request.request_id] + return KVCacheBlocks(tuple([blocks])) diff --git a/ucm/sparse/esa/retrieval/CMakeLists.txt b/ucm/sparse/esa/retrieval/CMakeLists.txt index 3ef08d07f..aabb02e58 100644 --- a/ucm/sparse/esa/retrieval/CMakeLists.txt +++ b/ucm/sparse/esa/retrieval/CMakeLists.txt @@ -1,2 +1,17 @@ +# add the build target pybind11_add_module(retrieval_backend cpy/retrieval_backend.cpp) + +# set the output directory for the library set_target_properties(retrieval_backend PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + +# set the include directories so that numaf.h can be found +target_include_directories(retrieval_backend PUBLIC + ${NUMA_INSTALL_DIR}/include + ${Torch_INCLUDE_DIRS} +) + +# link the required libraries +target_link_libraries(retrieval_backend PUBLIC + $<$:${NUMA_INSTALL_DIR}/lib/libnuma.so> + ${Torch_LIBRARIES} +) diff --git a/ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp b/ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp index 6a81af73e..17d1e92af 100644 --- a/ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp +++ b/ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp @@ -1,22 +1,26 @@ // retrieval_backend.cpp -#include -#include +#include #include #include #include +#include +#include #include +#include #include #include #include -#include -#include +#ifdef NUMA_ENABLED +#include +#endif +#include namespace py = pybind11; class RetrievalWorkerBackend { public: - RetrievalWorkerBackend(py::array_t data) + RetrievalWorkerBackend(py::array_t data, py::dict cpu_idx_tbl) : data_array_(data), stop_workers_(false), next_req_id_(0) { py::buffer_info info = data_array_.request(); @@ -25,26 +29,55 @@ class RetrievalWorkerBackend { data_ = static_cast(info.ptr); // Start worker threads - int n_workers = std::thread::hardware_concurrency(); - for (int i = 0; i < n_workers; ++i) { - worker_threads_.emplace_back(&RetrievalWorkerBackend::worker_loop, this); + for (auto cpu_idx : cpu_idx_tbl) { + py::list core_ids = cpu_idx.second.cast(); + + for (size_t i = 0; i < core_ids.size(); ++i) { + int core_id = core_ids[i].cast(); + worker_threads_.emplace_back(&RetrievalWorkerBackend::worker_loop, this); + + // core-binding code + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(core_id, &cpuset); // bind each thread to its assigned core + + pthread_t thread = worker_threads_.back().native_handle(); + + // set CPU affinity + int rc = pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset); + if (rc != 0) { + std::cerr << "Error binding thread " << i << " to CPU core " << core_id + << std::endl; + } + +#ifdef NUMA_ENABLED + int numaId = cpu_idx.first.cast(); + // set memory affinity + unsigned long nodeMask = 1UL << numaId; + rc = set_mempolicy(MPOL_BIND, &nodeMask, sizeof(nodeMask) * 8); + if (rc != 0) { + std::cerr << "Error binding memory to NUMA node " << numaId << std::endl; + } +#endif + } } } - ~RetrievalWorkerBackend() { + ~RetrievalWorkerBackend() + { { std::lock_guard lock(mutex_); stop_workers_ = true; cond_.notify_all(); } - for (auto& t: worker_threads_) t.join(); + for (auto& t : worker_threads_) t.join(); } - int submit(py::array_t query, int topk, py::array_t indexes) { + int submit(py::array_t query, int topk, py::array_t indexes) + {
py::buffer_info qinfo = query.request(); py::buffer_info iinfo = indexes.request(); - if (qinfo.shape[1] != dim_) - throw std::runtime_error("Query dim mismatch"); + if (qinfo.shape[1] != dim_) throw std::runtime_error("Query dim mismatch"); if ((size_t)iinfo.shape[0] != (size_t)qinfo.shape[0]) throw std::runtime_error("Query and indexes batch mismatch"); @@ -73,12 +106,14 @@ class RetrievalWorkerBackend { return req_id; } - bool poll(int req_id) { + bool poll(int req_id) + { std::lock_guard lock(mutex_); return results_.find(req_id) != results_.end(); } - void wait(int req_id) { + void wait(int req_id) + { std::shared_ptr s; { std::lock_guard lock(mutex_); @@ -90,7 +125,8 @@ class RetrievalWorkerBackend { s->cv.wait(lk2, [&] { return s->done; }); } - py::dict get_result(int req_id) { + py::dict get_result(int req_id) + { std::lock_guard lock(mutex_); auto it = results_.find(req_id); if (it == results_.end()) throw std::runtime_error("Result not ready"); @@ -132,12 +168,13 @@ class RetrievalWorkerBackend { bool done = false; }; - void worker_loop() { + void worker_loop() + { while (true) { Request req; { std::unique_lock lock(mutex_); - cond_.wait(lock, [&]{ return stop_workers_ || !requests_.empty(); }); + cond_.wait(lock, [&] { return stop_workers_ || !requests_.empty(); }); if (stop_workers_ && requests_.empty()) return; req = std::move(requests_.front()); requests_.pop(); @@ -181,7 +218,7 @@ class RetrievalWorkerBackend { } int curr_topk = std::min((int)heap.size(), req.topk); std::partial_sort(heap.begin(), heap.begin() + curr_topk, heap.end(), - [](const auto& a, const auto& b){ return a.first > b.first; }); + [](const auto& a, const auto& b) { return a.first > b.first; }); for (int k = 0; k < curr_topk; ++k) { res.scores[b].push_back(heap[k].first); @@ -215,9 +252,10 @@ class RetrievalWorkerBackend { std::atomic next_req_id_; }; -PYBIND11_MODULE(retrieval_backend, m) { +PYBIND11_MODULE(retrieval_backend, m) +{ py::class_(m, "RetrievalWorkerBackend") - .def(py::init>()) + .def(py::init, py::dict>()) .def("submit", &RetrievalWorkerBackend::submit) .def("poll", &RetrievalWorkerBackend::poll) .def("get_result", &RetrievalWorkerBackend::get_result) diff --git a/ucm/sparse/esa/retrieval/retrieval_worker.py b/ucm/sparse/esa/retrieval/retrieval_worker.py index ebed1ed1c..7209d604e 100644 --- a/ucm/sparse/esa/retrieval/retrieval_worker.py +++ b/ucm/sparse/esa/retrieval/retrieval_worker.py @@ -1,10 +1,11 @@ import time +from collections import defaultdict import numpy as np import torch -# import retrieval_backend from ucm.sparse.esa.retrieval import retrieval_backend +from ucm.sparse.kvstar.utils import get_bind_cpus_for_rank class RetrievalWorker: @@ -42,7 +43,19 @@ def wait(self, req_id): data = torch.rand(kv_cache_blocks, dim).to(torch.float32) print("data created", data.shape) - backend = retrieval_backend.RetrievalWorkerBackend(data) + ratio = 0.75 + total_tp_size = 4 + local_tp_rank = 0 + bind_info_list, alloc_numa_ids = get_bind_cpus_for_rank( + total_tp_size, local_tp_rank, ratio=ratio + ) + + bind_info_dict = defaultdict(list) + for item in bind_info_list: + bind_info_dict[item[1]].append(item[0]) + bind_info_dict = dict(bind_info_dict) + + backend = retrieval_backend.RetrievalWorkerBackend(data, bind_info_dict) worker = RetrievalWorker(backend) topk = 3000 search_blocks_range = 8000 diff --git a/ucm/sparse/factory.py b/ucm/sparse/factory.py index cb1b43ae7..d5b49cf37 100644 --- a/ucm/sparse/factory.py +++ b/ucm/sparse/factory.py @@ -5,6 +5,7 @@ from ucm.logger import init_logger 
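For context on the Config change in factory.py below: create_sparse_method selects the sparse method from the first key of "ucm_sparse_config". An illustrative (not normative) shape of that entry, using the keys the ESA path reads; the values here are placeholders only:

    ucm_sparse_config = {
        "ESA": {
            "init_window_sz": 1,    # leading blocks always kept in HBM
            "local_window_sz": 2,   # trailing blocks always kept in HBM
            "sparse_ratio": 0.3,    # fraction of the sparse range retrieved per step
            "retrieval_stride": 4,  # decode steps between retrievals
            "min_blocks": 4,        # shorter prompts are not sparsified
        }
    }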
from ucm.sparse.base import UcmSparseBase, UcmSparseRole +from ucm.utils import Config logger = init_logger(__name__) @@ -30,9 +31,9 @@ def loader() -> type[UcmSparseBase]: def create_sparse_method( cls, config: "VllmConfig", role: UcmSparseRole ) -> UcmSparseBase: - ucm_cfg = config.kv_transfer_config.kv_connector_extra_config.get( - "ucm_sparse_config" - ) + ucm_config = Config(config.kv_transfer_config) + ucm_cfg = ucm_config.get_config().get("ucm_sparse_config") + sparse_method_name, _ = next(iter(ucm_cfg.items())) if sparse_method_name in cls._registry: sparse_method_cls = cls._registry[sparse_method_name]() diff --git a/ucm/sparse/gsa/gsa.py b/ucm/sparse/gsa/gsa.py index 3d4f4978a..b1bf1e5c1 100644 --- a/ucm/sparse/gsa/gsa.py +++ b/ucm/sparse/gsa/gsa.py @@ -1,23 +1,28 @@ import copy +import hashlib import math +import pickle import time from dataclasses import dataclass -from functools import wraps +from functools import cache, wraps from itertools import accumulate -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union import torch from vllm.config import VllmConfig +from vllm.distributed.kv_transfer import get_kv_transfer_group from vllm.forward_context import ( ForwardContext, get_forward_context, - set_forward_context, ) from vllm.sequence import SequenceStage from vllm.utils import make_tensor_with_pad, sha256 +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.kv_cache_utils import NONE_HASH from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import Request +from ucm.integration.vllm.ucm_connector import RequestHasher from ucm.sparse.base import ( INVALID_SLOT, UcmSparseBase, @@ -28,46 +33,17 @@ from ucm.sparse.gsa.prefetch.prefetch_engine import GSAPrefetchBase from ucm.sparse.utils import ( CUDA_TOPK, - LOCAL_WINDOW_SZ, MAX_BS, - MAX_TOPK_LEN, PTOPK_PREFETCH_ENABLE, SEG_PREFILL_THRESHOLD, - compute_topk_len, + gsa_config, ) -from ucm.store.factory import UcmConnectorFactory - - -def stat(func): - @wraps(func) - def wrapper(*args, **kwargs): - wrapper.call_count += 1 - start = time.perf_counter_ns() - result = func(*args, **kwargs) - end = time.perf_counter_ns() - cost = end - start - wrapper.time_costs.append(cost) - return result - - wrapper.call_count = 0 - wrapper.time_costs = [] - return wrapper - ReqType = Union[str, int] -HashType = Union[str, int] - -# TODO: add ESA specific config in kv_transfer_config -> extra_config -INIT_WINDOW_SZ = 1 -SPARSE_RATIO = 0.3 -RETRIEVAL_STRIDE = 4 class GSAReqStat: - def __init__( - self, - req_id, - ) -> None: + def __init__(self, req_id, vllm_config: VllmConfig) -> None: self.req_id = req_id self.repre_slot_mapping = [] self.calc_block_table = [] @@ -87,6 +63,15 @@ def __init__( self.init_window_kv = None self.local_window_kv = [] self.sparse_len = 0 + self.block_size = vllm_config.cache_config.block_size + self.block_hashes = None + self.num_prompt_blocks = 0 + self.reamin_map = None + self.prefetch_map = None + self._vllm_config = vllm_config + self.rank = vllm_config.parallel_config.rank + self.use_mla = vllm_config.model_config.use_mla + self.request_hasher = RequestHasher(vllm_config, 0) def step(self) -> int: return self.num_output_tokens @@ -107,12 +92,40 @@ def is_gsa(self) -> bool: def is_last_chunk(self) -> bool: return ( self.num_computed_tokens + self.num_scheduled_tokens - >= self.num_prompt_tokens + == self.num_prompt_tokens ) def get_seq_len(self) -> int: return self.num_computed_tokens + self.num_scheduled_tokens + def set_block_hashes(self, 
token_ids): + if self.block_hashes is not None: + return + self.block_hashes = [] + + parent_block_hash_value = compute_parent_block_hash( + self._vllm_config.model_config.model, + self._vllm_config.parallel_config.world_size, + self._vllm_config.model_config.dtype, + seed_rank=0, + ) + + for start in range(0, len(token_ids), self.block_size): + end = start + self.block_size + block_token_ids = token_ids[start:end] + if len(block_token_ids) < self.block_size: + break + curr_block_token_ids_tuple = tuple(block_token_ids) + hash_value = self.request_hasher( + (parent_block_hash_value, curr_block_token_ids_tuple) + ) + parent_block_hash_value = hash_value + + if self.rank != 0 and not self.use_mla: + self.newqrequest_hasher = RequestHasher(self._vllm_config, self.rank) + for i, ucm_block_id in enumerate(self.block_hashes): + self.block_hashes[i] = str(self.newqrequest_hasher(ucm_block_id)) + def add_req_new( self, num_scheduled_tokens, add_req_state, index_in_batch, offset ) -> None: @@ -122,12 +135,14 @@ def add_req_new( self.num_scheduled_tokens = num_scheduled_tokens self.num_prompt_tokens = len(add_req_state.prompt_token_ids) self.num_output_tokens = len(add_req_state.output_token_ids) + self.num_prompt_blocks = math.ceil(self.num_prompt_tokens / self.block_size) self.is_use_gsa = ( True if self.num_prompt_tokens > SEG_PREFILL_THRESHOLD else False ) self._init_slot(offset) if len(self.repre_slot_mapping) > len(self.blocks): self.repre_slot_mapping = self.repre_slot_mapping[: len(self.blocks)] + self.set_block_hashes(add_req_state.prompt_token_ids) def updata_req_state( self, num_scheduled_tokens, add_req_state, index_in_batch @@ -155,35 +170,36 @@ def updata_req_state( self.repre_slot_mapping = self.repre_slot_mapping[: len(self.blocks)] def _get_sparse_and_free_block(self): - if self.num_prompt_tokens == self.num_computed_tokens: - blocks_len = len(self.blocks) - if self.num_prompt_tokens > SEG_PREFILL_THRESHOLD and PTOPK_PREFETCH_ENABLE: - remain_len = compute_topk_len(blocks_len) - if remain_len > MAX_TOPK_LEN: - prefetch_len = 0 - remain_blocks_idx = list(range(remain_len)) - else: - prefetch_len = MAX_TOPK_LEN - remain_len + 1 - remain_blocks_idx = list(range(MAX_TOPK_LEN + 1)) - self.remain_idx = [] - self.prefetch_idx = [] - assert LOCAL_WINDOW_SZ < remain_len + if self.num_prompt_tokens != self.num_computed_tokens: + self.remain_idx = None + self.prefetch_idx = None + return + + blocks_len = len(self.blocks) + if self.num_prompt_tokens > SEG_PREFILL_THRESHOLD and PTOPK_PREFETCH_ENABLE: + remain_len = gsa_config.compute_topk_len(blocks_len) + if remain_len < blocks_len: + prefetch_len = min( + gsa_config.num_prefetch_blocks, blocks_len - remain_len + ) + req_idx_list = list(range(blocks_len)) + init_windows_size = gsa_config.init_windows_size self.remain_idx = ( - remain_blocks_idx[: remain_len - LOCAL_WINDOW_SZ] - + remain_blocks_idx[-LOCAL_WINDOW_SZ:] + req_idx_list[:init_windows_size] + + req_idx_list[init_windows_size - remain_len :] ) - self.prefetch_idx = remain_blocks_idx[ - remain_len - LOCAL_WINDOW_SZ : -LOCAL_WINDOW_SZ + self.prefetch_idx = req_idx_list[ + init_windows_size + - remain_len + - prefetch_len : init_windows_size + - remain_len ] self.sparse_len = remain_len + prefetch_len - else: - self.remain_idx = list(range(blocks_len)) - self.prefetch_idx = [] - self.sparse_len = blocks_len - return - else: - self.remain_idx = None - self.prefetch_idx = None + return + + self.remain_idx = list(range(blocks_len)) + self.prefetch_idx = [] + self.sparse_len = blocks_len def 
_init_slot(self, offset: int) -> None: self.repre_slot_mapping = list(range(len(self.blocks))) @@ -240,14 +256,12 @@ def _update_slot( class GSAMetaData(UcmSparseMetadata): - def __init__( - self, - block_size, - device, - ): + def __init__(self, vllm_config: VllmConfig): self.gsa_stats = {} - self.block_size = block_size - self.device = device + self.block_size = vllm_config.cache_config.block_size + self.device = vllm_config.device_config.device_type + self.use_mla = vllm_config.model_config.use_mla + self._vllm_config = vllm_config def get_model_input( self, @@ -256,18 +270,32 @@ def get_model_input( max_block_len, requests, input_batch, + prefetch_engine, ) -> None: - for req_id in scheduler_output.scheduled_cached_reqs.req_ids: + for index, req_id in enumerate(scheduler_output.scheduled_cached_reqs.req_ids): assert req_id in self.gsa_stats - self.gsa_stats[req_id].updata_req_state( - scheduler_output.num_scheduled_tokens[req_id], - requests[req_id], - input_batch.req_id_to_index[req_id], - ) + if scheduler_output.scheduled_cached_reqs.resumed_from_preemption[index]: + del self.gsa_stats[req_id] + prefetch_engine.del_finish_meta(req_id, False) + self.gsa_stats[req_id] = GSAReqStat(req_id, self._vllm_config) + self.gsa_stats[req_id].add_req_new( + scheduler_output.num_scheduled_tokens[req_id], + requests[req_id], + input_batch.req_id_to_index[req_id], + max_block_len * topk_kpre_map[req_id], + ) + else: + self.gsa_stats[req_id].updata_req_state( + scheduler_output.num_scheduled_tokens[req_id], + requests[req_id], + input_batch.req_id_to_index[req_id], + ) for new_req in scheduler_output.scheduled_new_reqs: if new_req.req_id in self.gsa_stats: del self.gsa_stats[new_req.req_id] - self.gsa_stats[new_req.req_id] = GSAReqStat(new_req.req_id) + self.gsa_stats[new_req.req_id] = GSAReqStat( + new_req.req_id, self._vllm_config + ) self.gsa_stats[new_req.req_id].add_req_new( scheduler_output.num_scheduled_tokens[new_req.req_id], requests[new_req.req_id], @@ -282,41 +310,29 @@ def trans_input_tensor(self, scheduler_output: SchedulerOutput): calc_repre_slot_mappings = [] batch_size = len(scheduler_output.num_scheduled_tokens.items()) query_locals = [0] * (batch_size + 1) - if CUDA_TOPK: - repre_slot_mapping = [0] * batch_size - include_mask = [0] * batch_size - exclude_mask = [0] * batch_size - for req_id, _ in scheduler_output.num_scheduled_tokens.items(): + if self.use_mla: + query_locals_prefill = [0] * (batch_size + 1) + for req_id, num_tokens in scheduler_output.num_scheduled_tokens.items(): req_in_batch = self.gsa_stats[req_id].index_in_batch calc_block_table += self.gsa_stats[req_id].calc_block_table calc_repre_slot_mappings += self.gsa_stats[req_id].calc_repre_slot_mapping - if CUDA_TOPK: - repre_slot_mapping[req_in_batch] = self.gsa_stats[ - req_id - ].repre_slot_mapping - include_mask[req_in_batch] = self.gsa_stats[req_id].include_mask - exclude_mask[req_in_batch] = self.gsa_stats[req_id].exclude_mask query_locals[req_in_batch + 1] = scheduler_output.num_scheduled_tokens[ req_id ] + if self.use_mla and self.gsa_stats[req_id].stage() == SequenceStage.PREFILL: + query_locals_prefill[req_in_batch + 1] = num_tokens query_locals = list(accumulate(query_locals)) + if self.use_mla: + query_locals_prefill = list(accumulate(query_locals_prefill)) model_input["calc_block_table"] = torch.tensor( calc_block_table, dtype=torch.int32, device="cpu" ) model_input["calc_repre_slot_mapping"] = torch.tensor( calc_repre_slot_mappings, dtype=torch.int32, device="cpu" ) - if CUDA_TOPK: - 
model_input["repre_slot_mapping"] = make_tensor_with_pad( - repre_slot_mapping, pad=0, dtype=torch.int32, device="cpu" - ) - model_input["include_mask"] = make_tensor_with_pad( - include_mask, pad=False, dtype=torch.uint8, device=self.device - ) - model_input["exclude_mask"] = make_tensor_with_pad( - exclude_mask, pad=False, dtype=torch.uint8, device=self.device - ) model_input["query_locals"] = query_locals + if self.use_mla: + model_input["query_locals_prefill"] = query_locals_prefill return model_input @@ -352,26 +368,29 @@ def is_exist(self, req_id: ReqType) -> bool: return False -def get_offset(block_shape, rank, tp_size, precision, layer_id, is_v, is_mla) -> int: - block_size, num_key_heads_per_tp, head_size = block_shape - k_min_data_block_size = block_size * num_key_heads_per_tp * head_size * precision - v_min_data_block_size = k_min_data_block_size if not is_mla else 0 - layer_size = (k_min_data_block_size + v_min_data_block_size) * tp_size - if is_mla: - k_offset = layer_size * layer_id - else: - k_offset = layer_size * layer_id + layer_size // tp_size * rank - v_offset = k_offset + k_min_data_block_size - return v_offset if is_v else k_offset +@cache +def md5(input) -> int: + input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) + md5_bytes = hashlib.md5(input_bytes).digest() + return int.from_bytes(md5_bytes, byteorder="big") + + +@cache +def block_hash_func(parent_block_hash, curr_block_token_ids): + if not parent_block_hash: + parent_block_hash = md5("UCMHASHSEED") + curr_block_token_ids_tuple = tuple(curr_block_token_ids) + return md5((parent_block_hash, curr_block_token_ids_tuple)) class TopkCal: - def __init__(self, att_num_heads, kv_num_heads, head_size, kpre_caches): + def __init__(self, att_num_heads, kv_num_heads, head_size, kpre_caches, use_mla): self.att_num_heads = att_num_heads self.kv_num_heads = kv_num_heads self.head_size = head_size self.kpre_caches = kpre_caches self.topk_ratio = 0.3 + self.use_mla = use_mla def set_topk_param(self, repre_slot_mapping, include_mask, exclude_mask): self.repre_slot_mapping = repre_slot_mapping @@ -385,10 +404,9 @@ def set_topk_caches(self, cal_topk_id, topk_caches, topk_len_list): def cal_topk(self, intermediate_q, current_layer_id): bs = len(self.cal_topk_id) - scale = self.head_size**-0.5 head_group_num = self.att_num_heads // self.kv_num_heads q_decode = intermediate_q[self.cal_topk_id] - kpre_index = self.repre_slot_mapping[self.cal_topk_id].flatten() + kpre_index = self.repre_slot_mapping.flatten() kpre_need = self.kpre_caches[current_layer_id][kpre_index] max_norm_num = kpre_need.shape[1] kpre_out = kpre_need.unsqueeze(2).expand(-1, -1, head_group_num, -1, -1) @@ -399,12 +417,8 @@ def cal_topk(self, intermediate_q, current_layer_id): qk.reshape(bs, self.att_num_heads, blk_num, max_norm_num), dim=-1 ) dot_product_weights = attention_weights_without_norm.mean(1) - dot_product_weights.masked_fill_( - self.include_mask[self.cal_topk_id] == 1, float("inf") - ) - dot_product_weights.masked_fill_( - self.exclude_mask[self.cal_topk_id] == 1, float("-inf") - ) + dot_product_weights.masked_fill_(self.include_mask == 1, float("inf")) + dot_product_weights.masked_fill_(self.exclude_mask == 1, float("-inf")) selected_block_nums = self.topk_len_list[0] _, top_indices = torch.topk( dot_product_weights, selected_block_nums, dim=-1, sorted=False @@ -412,8 +426,53 @@ def cal_topk(self, intermediate_q, current_layer_id): self.topk_caches[current_layer_id][self.cal_topk_id] = top_indices +@cache +def get_offset(block_shape, 
rank, tp_size, precision, layer_id, is_v, is_mla) -> int: + block_size, num_key_heads_per_tp, head_size = block_shape + k_min_data_block_size = block_size * num_key_heads_per_tp * head_size * precision + v_min_data_block_size = k_min_data_block_size if not is_mla else 0 + layer_size = (k_min_data_block_size + v_min_data_block_size) * ( + tp_size if not is_mla else 1 + ) + if is_mla: + k_offset = layer_size * layer_id + else: + k_offset = layer_size * layer_id + layer_size // tp_size * rank + v_offset = k_offset + k_min_data_block_size + return v_offset if is_v else k_offset + + +@cache +def compute_parent_block_hash(model_name, world_size, dtype, seed_rank=0) -> int: + meta = f"{model_name}:{world_size}:{dtype}:{seed_rank}" + meta_bytes = meta.encode("utf-8") + h_seed = hashlib.md5(meta_bytes + b"UCM_HASH_SEED").digest() + return int.from_bytes(h_seed, byteorder="big") + + +@cache +def compute_layer_offset( + block_data_size: int, + layer_id: int, + is_v: bool, + is_mla: bool, +) -> int: + layer_data_size = block_data_size if is_mla else block_data_size * 2 + + k_offset = layer_data_size * layer_id + + if is_mla: + return k_offset + + v_offset = k_offset + block_data_size + return v_offset if is_v else k_offset + + +def task_hash_func(block_ids, store_type, tensor_type): + return hash((tuple(block_ids), store_type, tensor_type)) + + class GSA(UcmSparseBase): - # handle batch def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): super().__init__(vllm_config, role) self.rank = vllm_config.parallel_config.rank @@ -433,53 +492,33 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): self.layer_num = vllm_config.model_config.get_num_layers( vllm_config.parallel_config ) + self.att_num_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config + ) + self.dtype = vllm_config.model_config.dtype if PTOPK_PREFETCH_ENABLE: - config_base = self.block_size * self.element_size * self.head_size - kv_block_size = ( - config_base - * self.layer_num - * (1 if self.use_mla else self.num_head * self.total_tp_size * 2) - ) - io_size = config_base * (1 if self.use_mla else self.num_head) - nfs_config = { - "storage_backends": "./ucm/data/" + str(self.rank), - "kv_block_size": kv_block_size, - "device": self.rank, - "role": "worker", - "io_size": io_size, - } - self.connector = UcmConnectorFactory.create_connector( - "UcmNfsStore", nfs_config - ) + if role == UcmSparseRole.WORKER: + self.connector = get_kv_transfer_group().connector.store + else: + self.connector = None + self.is_python_load = not torch.cuda.is_available() if CUDA_TOPK: self.prefetch_engine = GSAPrefetchBase( - vllm_config, 16, True, False, False, 1 + vllm_config, 16, True, False, False, 1, self.is_python_load ) else: self.prefetch_engine = GSAPrefetchBase( - vllm_config, 16, True, True, False, 1 + vllm_config, 16, True, True, False, 1, self.is_python_load ) self.topk_kpre_manger = TopKAndKpreManger(MAX_BS) - self.k_cache = {} - self.v_cache = {} - self.tasks_dump = {} - self.tasks_load = {} self.gsa_metadata = None self.model_input = None self.gsa_stats = {} self.init_topk_cal(vllm_config, self.prefetch_engine) - - @classmethod - def req_state_hash(cls, req_id, layer_name): - return hash((req_id, layer_name)) - - @classmethod - def block_hash(cls, request_id, block_id): - return sha256(f"req_{request_id}_blk_{block_id}") - - @classmethod - def task_hash(cls, block_ids, store_type, tensor_type): - return hash((tuple(block_ids), store_type, tensor_type)) + self.decode_index = [] + 
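The compute_layer_offset helper above reduces to simple per-layer strides; a worked check with an assumed 64 KiB key block per layer (numbers are illustrative only):

    # non-MLA: each layer stores a K block then a V block (2 * 64 KiB stride);
    # MLA: a single block per layer (64 KiB stride), so is_v is ignored.
    assert compute_layer_offset(65536, layer_id=1, is_v=False, is_mla=False) == 131072
    assert compute_layer_offset(65536, layer_id=1, is_v=True, is_mla=False) == 196608
    assert compute_layer_offset(65536, layer_id=1, is_v=False, is_mla=True) == 65536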
self.copy_k_flag = [False] * self.layer_num + gsa_config.set_config(self.block_size) + self.task_load = {} def init_topk_cal( self, @@ -493,7 +532,6 @@ def init_topk_cal( ) kv_num_heads = vllm_config.model_config.get_num_kv_heads(parallel_config) head_size = vllm_config.model_config.get_head_size() - max_model_len = vllm_config.model_config.max_model_len self.gsa_offload_ops = gsa_offload_ops.CalKpreAndTopk( self.layer_num, block_size, MAX_BS, att_num_heads, head_size ) @@ -512,120 +550,39 @@ def init_topk_cal( ) if CUDA_TOPK: self.gsa_cuda_topk = TopkCal( - att_num_heads, kv_num_heads, head_size, prefetch_engine.kpre_caches - ) - - def launch_transfer_task( - self, transfer_type, block_hashes, vllm_block_ids, layer_id - ): - fn = getattr(self.connector, transfer_type) - length = len(block_hashes) - block_shape = (self.block_size, self.num_key_heads, self.head_size) - precision = self.k_cache[layer_id].untyped_storage().element_size() - # TODO: consider is_mla here - is_mla = self.use_mla - offsets_k = [ - get_offset( - block_shape, - self.rank, - self.tp_size, - precision, - layer_id, - is_v=False, - is_mla=is_mla, - ) - ] * length - offsets_v = [ - get_offset( - block_shape, - self.rank, - self.tp_size, - precision, - layer_id, - is_v=True, - is_mla=is_mla, + att_num_heads, + kv_num_heads, + head_size, + prefetch_engine.kpre_caches, + self.use_mla, ) - ] * length - key_src_tensors = [self.k_cache[layer_id][id_] for id_ in vllm_block_ids] - value_src_tensors = [self.v_cache[layer_id][id_] for id_ in vllm_block_ids] - task_k = fn(block_hashes, offsets_k, key_src_tensors) - task_v = fn(block_hashes, offsets_v, value_src_tensors) - task_k_hash = self.task_hash(block_hashes, transfer_type, "key") - task_v_hash = self.task_hash(block_hashes, transfer_type, "value") - if transfer_type == "dump": - self.tasks_dump[task_k_hash] = task_k - self.tasks_dump[task_v_hash] = task_v - if transfer_type == "load": - self.tasks_load[task_k_hash] = task_k - self.tasks_load[task_v_hash] = task_v - - def launch_transfer_task_all(self, transfer_type, block_hashes, vllm_block_ids): - fn = getattr(self.connector, transfer_type) - - block_shape = (self.block_size, self.num_key_heads, self.head_size) - precision = self.k_cache[0].untyped_storage().element_size() - # TODO: consider is_mla here - is_mla = self.use_mla - offsets_k = [] - offsets_v = [] - block_hashes_all = [] - key_src_tensors = [] - value_src_tensors = [] - for layer_id in range(self.layer_num): - length = len(block_hashes[layer_id]) - offsets_k += [ - get_offset( - block_shape, - self.rank, - self.tp_size, - precision, - layer_id, - is_v=False, - is_mla=is_mla, - ) - ] * length - offsets_v += [ - get_offset( - block_shape, - self.rank, - self.tp_size, - precision, - layer_id, - is_v=True, - is_mla=is_mla, - ) - ] * length - key_src_tensors += [ - self.k_cache[layer_id][id_] for id_ in vllm_block_ids[layer_id] - ] - value_src_tensors += [ - self.v_cache[layer_id][id_] for id_ in vllm_block_ids[layer_id] - ] - block_hashes_all += block_hashes[layer_id] - task_k = fn(block_hashes_all, offsets_k, key_src_tensors) - task_v = fn(block_hashes_all, offsets_v, value_src_tensors) - task_k_hash = self.task_hash(block_hashes_all, transfer_type, "key") - task_v_hash = self.task_hash(block_hashes_all, transfer_type, "value") - if transfer_type == "dump": - self.tasks_dump[task_k_hash] = task_k - self.tasks_dump[task_v_hash] = task_v - if transfer_type == "load": - self.tasks_load[task_k_hash] = task_k - self.tasks_load[task_v_hash] = task_v def copy_q(self, 
query: torch.Tensor, current_layer_id: int) -> None: ids = [-1] * len(self.prefetch_engine.req_ids_bs) for req_id in self.prefetch_engine.req_ids_bs: req_meta = self.gsa_metadata.gsa_stats[req_id] - if req_meta.is_gsa(): - index_in_batch = req_meta.index_in_batch - ids[index_in_batch] = ( - self.model_input["query_locals"][index_in_batch + 1] - 1 - ) + if not self.use_mla: + if req_meta.is_gsa(): + index_in_batch = req_meta.index_in_batch + ids[index_in_batch] = ( + self.model_input["query_locals"][index_in_batch + 1] - 1 + ) + else: + if req_meta.is_gsa(): + index_in_batch = req_meta.index_in_batch + ids[index_in_batch] = 1 if CUDA_TOPK: - self.gsa_cuda_topk.cal_topk(query[ids], current_layer_id) + if not self.use_mla: + self.gsa_cuda_topk.cal_topk( + query[ids], current_layer_id + ) ##### todo: ids to compute + else: + self.gsa_cuda_topk.cal_topk(query, current_layer_id) else: - self.gsa_q_cache[current_layer_id][: len(ids)].copy_(query[ids]) + if not self.use_mla: + self.gsa_q_cache[current_layer_id][: len(ids)].copy_(query[ids]) + else: + self.gsa_q_cache[current_layer_id][self.decode_index].copy_(query) is_cal_kpre = len(self.model_input["calc_block_table"]) > 0 self.gsa_offload_ops.add_copy_req( is_cal_kpre, current_layer_id, ids, self.gsa_q_cache[current_layer_id] @@ -637,11 +594,20 @@ def copy_k(self, layer_name: str, forward_context: ForwardContext) -> None: calc_repre_slot_mappings = self.model_input["calc_repre_slot_mapping"] if len(block_ids) > 0: attn = forward_context.no_compile_layers - key_cache_mean_out = ( - attn[layer_name] - .kv_cache[forward_context.virtual_engine][0][block_ids] - .mean(dim=1, keepdim=True) - ) + if not self.use_mla: + key_cache_mean_out = ( + attn[layer_name] + .kv_cache[forward_context.virtual_engine][0][block_ids] + .mean(dim=1, keepdim=True) + ) + else: + key_cache_mean_out = ( + attn[layer_name] + .kv_cache[forward_context.virtual_engine][block_ids] + .mean(dim=1, keepdim=True) + ) + if torch.cuda.is_available(): + key_cache_mean_out = torch.unsqueeze(key_cache_mean_out, 1) if CUDA_TOPK: self.prefetch_engine.kpre_caches[current_layer_id][ calc_repre_slot_mappings @@ -650,8 +616,13 @@ def copy_k(self, layer_name: str, forward_context: ForwardContext) -> None: ] = key_cache_mean_out.to(dtype=torch.float32, device="cpu") - k_needed = attn[layer_name].kv_cache[forward_context.virtual_engine][0] - self.gsa_offload_ops.add_copy_req(True, current_layer_id, [], k_needed) + if not self.use_mla: + k_needed = attn[layer_name].kv_cache[forward_context.virtual_engine][0] + else: + k_needed = attn[layer_name].kv_cache[forward_context.virtual_engine] + self.gsa_offload_ops.add_copy_req( + True, current_layer_id, [], k_needed + ) ##### todo: adapt to the k-cache shape def attention_begin( self, @@ -660,30 +631,57 @@ def attention_begin( value: torch.Tensor, layer_name: str, forward_context: ForwardContext, + phase: Optional[str] = None, ) -> None: current_layer_id = int(layer_name.split(".")[2]) if self.prefetch_engine.atb_gsa_enable and self.prefetch_engine.is_topk_cal: - self.copy_q(query, current_layer_id) - + if not self.use_mla: + self.copy_q(query, current_layer_id) + else: + if phase == "decode": + self.copy_q(query, current_layer_id) if isinstance(forward_context.attn_metadata, dict): attn_metadata = forward_context.attn_metadata[layer_name] else: attn_metadata = forward_context.attn_metadata if self.prefetch_engine.atb_gsa_enable: - if torch.cuda.is_available(): - attn_metadata.block_table = 
self.model_input["block_tables_mp"][ - current_layer_id - ] - attn_metadata.seq_lens = self.model_input["gsa_seq_len"][ - current_layer_id - ] + if not self.use_mla: + if torch.cuda.is_available(): + attn_metadata.block_table = self.model_input["block_tables_mp"][ + current_layer_id + ] + attn_metadata.seq_lens = self.model_input["gsa_seq_len"][ + current_layer_id + ] + else: + attn_metadata.block_tables[ + : len(self.prefetch_engine.req_ids_bs) + ].copy_(self.model_input["block_tables_mp"][current_layer_id]) + attn_metadata.seq_lens.copy_( + self.model_input["gsa_seq_len"][current_layer_id] + ) else: - attn_metadata.block_tables[ - : len(self.prefetch_engine.req_ids_bs) - ].copy_(self.model_input["block_tables_mp"][current_layer_id]) - attn_metadata.seq_lens.copy_( - self.model_input["gsa_seq_len"][current_layer_id] - ) + if phase == "decode": + if torch.cuda.is_available(): + attn_metadata.decode.block_table = self.model_input[ + "block_tables_mp" + ][current_layer_id][self.decode_index] + attn_metadata.decode.seq_lens = self.model_input["gsa_seq_len"][ + current_layer_id + ][self.decode_index] + else: + attn_metadata.decode.block_table[ + : len(self.prefetch_engine.req_ids_bs) + ].copy_( + self.model_input["block_tables_mp"][current_layer_id][ + self.decode_index + ] + ) + attn_metadata.decode.seq_lens.copy_( + self.model_input["gsa_seq_len"][current_layer_id][ + self.decode_index + ] + ) def attention_finished( self, @@ -693,89 +691,171 @@ def attention_finished( attn_output: torch.Tensor, layer_name: str, forward_context: ForwardContext, + phase: Optional[str] = None, ) -> None: - self.maybe_register_kv_cache(forward_context, layer_name) current_layer_id = int(layer_name.split(".")[2]) - self.copy_k(layer_name, forward_context) - """ - block_ids = self.model_input["calc_block_table"] - if len(block_ids) > 0: - attn = forward_context.no_compile_layers - k_needed = attn[layer_name].kv_cache[forward_context.virtual_engine][0][block_ids].cpu() - temp_k_cache = k_needed.to(torch.float32).permute(0, 2, 1, 3) - self.gsa_offload_ops.k_cache[current_layer_id][:len(block_ids)] = temp_k_cache - result = self.gsa_offload_ops.set_kpre_data_ready(current_layer_id) - if not result: - self.is_cal_kpre[current_layer_id] = True - elif self.is_cal_kpre[current_layer_id]: - result = self.gsa_offload_ops.set_kpre_data_ready(current_layer_id) - if result: - self.is_cal_kpre[current_layer_id] = False - """ - if not PTOPK_PREFETCH_ENABLE: + if not self.copy_k_flag[current_layer_id]: + self.copy_k(layer_name, forward_context) + self.copy_k_flag[current_layer_id] = True + if self.use_mla and torch.cuda.is_available(): return - block_hashes = [] - block_ids = [] for req_id in self.prefetch_engine.req_ids_bs: - offset = ( - self.prefetch_engine.max_block_len - * self.topk_kpre_manger.cache_map[req_id] - ) - block_hashes += [ - f"{self.block_hash(req_id, id_ - offset)}" - for id_ in self.gsa_metadata.gsa_stats[req_id].calc_repre_slot_mapping - ] - block_ids += self.gsa_metadata.gsa_stats[req_id].calc_block_table - if len(block_hashes) > 0: - if torch.cuda.is_available(): - torch.cuda.current_stream().synchronize() - else: - torch.npu.current_stream().synchronize() - if current_layer_id == 0: - ret = self.connector.create(block_hashes) - self.launch_transfer_task("dump", block_hashes, block_ids, current_layer_id) - self.wait_all_task_done("dump") - if current_layer_id == self.layer_num - 1: - self.connector.commit(block_hashes, True) - - def wait_all_task_done(self, transfer_type): - if transfer_type == "dump": - 
for _, task in self.tasks_dump.items(): - ret = self.connector.wait(task) - self.tasks_dump.clear() + assert req_id in self.gsa_metadata.gsa_stats + req_meta = self.gsa_metadata.gsa_stats[req_id] + if ( + req_meta.is_last_chunk() + and req_meta.num_prompt_tokens > SEG_PREFILL_THRESHOLD + and PTOPK_PREFETCH_ENABLE + ): + blocks_len = len(self.gsa_metadata.gsa_stats[req_id].blocks) + remain_len = gsa_config.compute_topk_len(blocks_len) + prefetch_len = min( + gsa_config.num_prefetch_blocks, blocks_len - remain_len + ) + topk_value = self.last_chunk_topk_cal( + req_meta, query, current_layer_id, remain_len + prefetch_len + ) + + if self.gsa_metadata.gsa_stats[req_id].reamin_map == None: + self.gsa_metadata.gsa_stats[req_id].reamin_map = [ + None + ] * self.layer_num + self.gsa_metadata.gsa_stats[req_id].prefetch_map = [ + None + ] * self.layer_num + + self.kvcache_init_last_chunk( + forward_context, layer_name, topk_value, req_id + ) + + if self.gsa_metadata.gsa_stats[req_id].topk_buf_tmp == None: + self.gsa_metadata.gsa_stats[req_id].topk_buf_tmp = torch.zeros( + (self.layer_num, len(topk_value)), + dtype=torch.int32, + device="cpu", + ) + self.gsa_metadata.gsa_stats[req_id].topk_buf_tmp[ + current_layer_id + ] = topk_value + + def last_chunk_topk_cal(self, req_meta, query, current_layer_id, first_topk_len): + index_in_batch = req_meta.index_in_batch + bs = 1 + if not self.use_mla: + cal_topk_id = [self.model_input["query_locals"][index_in_batch + 1] - 1] else: - for _, task in self.tasks_load.items(): - ret = self.connector.wait(task) - self.tasks_load.clear() - - def check_all_task_is_done(self, transfer_type): - if transfer_type == "dump": - for _, task in self.tasks_dump.items(): - ret = self.connector.check(task) - if ret == -1: - return False - self.tasks_dump.clear() - return True + cal_topk_id = [ + self.model_input["query_locals_prefill"][index_in_batch + 1] - 1 + ] + head_group_num = self.att_num_heads // self.num_key_heads + q_decode = query[cal_topk_id] + + include_mask = torch.tensor( + req_meta.include_mask, dtype=torch.uint8, device=self.device + ) + exclude_mask = torch.tensor( + req_meta.exclude_mask, dtype=torch.uint8, device=self.device + ) + if CUDA_TOPK: + kpre_index = torch.tensor( + req_meta.repre_slot_mapping, dtype=torch.int32, device=self.device + ) + kpre_need = self.prefetch_engine.kpre_caches[current_layer_id][kpre_index] else: - for _, task in self.tasks_load.items(): - ret = self.connector.check(task) - if ret == -1: - return False - self.tasks_load.clear() - return True + kpre_index = torch.tensor( + req_meta.repre_slot_mapping, dtype=torch.int32, device="cpu" + ) + kpre_need = self.prefetch_engine.kpre_caches[current_layer_id][ + kpre_index + ].to(device=self.device, dtype=self.dtype) - def maybe_register_kv_cache( - self, forward_context: ForwardContext, layer_name - ) -> None: + max_norm_num = kpre_need.shape[1] + kpre_out = kpre_need.unsqueeze(2).expand(-1, -1, head_group_num, -1, -1) + kpre_out = kpre_out.reshape(bs, -1, self.att_num_heads, self.head_size) + blk_num = kpre_out.shape[1] // max_norm_num + qk = torch.einsum("bij,bmij->bim", q_decode, kpre_out) + attention_weights_without_norm, _ = torch.max( + qk.reshape(bs, self.att_num_heads, blk_num, max_norm_num), dim=-1 + ) + dot_product_weights = attention_weights_without_norm.mean(1) + dot_product_weights.masked_fill_(include_mask == 1, float("inf")) + dot_product_weights.masked_fill_(exclude_mask == 1, float("-inf")) + _, top_indices = torch.topk(dot_product_weights, first_topk_len, dim=-1) + return 
top_indices[0].cpu() + + def kvcache_init_last_chunk( + self, forward_context: ForwardContext, layer_name, topk_value, req_id + ): current_layer_id = int(layer_name.split(".")[2]) - attn = forward_context.no_compile_layers[layer_name] - kv_cache = attn.kv_cache[forward_context.virtual_engine] - # TODO: consider is_mla here - self.k_cache[current_layer_id] = kv_cache[0] - self.v_cache[current_layer_id] = kv_cache[1] - self.block_size = self.k_cache[current_layer_id].shape[1] - self.num_key_heads = self.k_cache[current_layer_id].shape[2] - self.head_size = self.k_cache[current_layer_id].shape[3] + blocks_len = len(self.gsa_metadata.gsa_stats[req_id].blocks) + remain_len = gsa_config.compute_topk_len(blocks_len) + prefetch_len = min(gsa_config.num_prefetch_blocks, blocks_len - remain_len) + req_idx_list = list(range(blocks_len)) + init_windows_size = gsa_config.init_windows_size + remain_idx = ( + req_idx_list[:init_windows_size] + + req_idx_list[init_windows_size - remain_len - prefetch_len :] + ) + assert len(remain_idx) == len(topk_value) + mv_map, reamin_map, prefetch_map = self.get_mv_map( + self.gsa_metadata.gsa_stats[req_id].blocks, + remain_idx, + topk_value.tolist(), + remain_len, + ) + self.gsa_metadata.gsa_stats[req_id].reamin_map[current_layer_id] = reamin_map + self.gsa_metadata.gsa_stats[req_id].prefetch_map[ + current_layer_id + ] = prefetch_map + if not self.use_mla: + layer_k_cache = forward_context.no_compile_layers[layer_name].kv_cache[ + forward_context.virtual_engine + ][0] + layer_v_cache = forward_context.no_compile_layers[layer_name].kv_cache[ + forward_context.virtual_engine + ][1] + else: + layer_k_cache = forward_context.no_compile_layers[layer_name].kv_cache[ + forward_context.virtual_engine + ] + for block_id in mv_map: + layer_k_cache[mv_map[block_id]].copy_(layer_k_cache[block_id]) + if not self.use_mla: + layer_v_cache[mv_map[block_id]].copy_(layer_v_cache[block_id]) + + def get_mv_map(self, blocks, remain_idxs, topk_values, remain_len): + mv_map = {} + free_block = [] + hit_block = [] + miss_block = [] + remain_map = {} + prefetch_map = {} + new_block = [None] * len(topk_values) + for index, idx in enumerate(topk_values): + if idx in remain_idxs: + new_block[index] = blocks[idx] + hit_block.append(idx) + else: + miss_block.append(idx) + + for idx in remain_idxs: + if idx not in hit_block: + free_block.append(idx) + + for index in range(len(new_block)): + if new_block[index] == None: + one_free_idx = free_block.pop(0) + new_block[index] = blocks[one_free_idx] + idx = topk_values[index] + mv_map[blocks[idx]] = blocks[one_free_idx] + + for index in range(len(new_block)): + idx = topk_values[index] + if index < remain_len: + remain_map[idx] = new_block[index] + else: + prefetch_map[idx] = new_block[index] + return mv_map, remain_map, prefetch_map def build_gsa_metadata( self, scheduler_output: SchedulerOutput, requests, input_batch @@ -784,7 +864,7 @@ def build_gsa_metadata( if not self.topk_kpre_manger.is_exist(req_id): index = self.topk_kpre_manger.alloc(req_id) assert index != None - gsa_meta = GSAMetaData(self.block_size, self.device) + gsa_meta = GSAMetaData(self._vllm_config) gsa_meta.gsa_stats = self.gsa_stats self.model_input = gsa_meta.get_model_input( scheduler_output, @@ -792,11 +872,13 @@ def build_gsa_metadata( self.prefetch_engine.max_block_len, requests, input_batch, + self.prefetch_engine, ) self.gsa_stats = gsa_meta.gsa_stats return gsa_meta def execute_begin(self, scheduler_output: SchedulerOutput): + self.copy_k_flag = [False] * self.layer_num 
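For reference, a minimal standalone sketch of the representative-key scoring that last_chunk_topk_cal above performs: each KV block is summarised by a few representative keys, a block's score is the mean over query heads of its best q·k_repre match, and the include/exclude masks pin the initial window in and padding out before top-k selection. The function name score_blocks and the exact tensor layout are illustrative assumptions, not part of the patch.

import torch

def score_blocks(q_last, kpre, head_group_num, include_mask, exclude_mask, topk_len):
    # q_last:  [num_q_heads, head_size]                        query of the last prompt token
    # kpre:    [num_blocks, num_repre, num_kv_heads, head_size] representative keys per block
    # masks:   [num_blocks] uint8; include forces a block in, exclude forces it out
    k = kpre.repeat_interleave(head_group_num, dim=2)    # GQA: repeat each KV head for its query-head group
    qk = torch.einsum("hd,brhd->hbr", q_last, k)         # [heads, blocks, repre]
    scores = qk.max(dim=-1).values.mean(dim=0)           # best representative per block, averaged over heads
    scores = scores.masked_fill(include_mask.bool(), float("inf"))
    scores = scores.masked_fill(exclude_mask.bool(), float("-inf"))
    return torch.topk(scores, topk_len).indices          # indices of the blocks to keep resident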
batch_size = len(scheduler_output.num_scheduled_tokens.items()) req_ids = [0] * batch_size block_table_ori = [0] * batch_size @@ -806,8 +888,9 @@ def execute_begin(self, scheduler_output: SchedulerOutput): req_ids[req_in_batch] = req_id block_table_ori[req_in_batch] = self.gsa_metadata.gsa_stats[req_id].blocks topk_kpre_maps[req_in_batch] = self.topk_kpre_manger.cache_map[req_id] + is_topk_done = self.gsa_offload_ops.is_calculate_finish() - self.prefetch_engine.model_input_del( + self.prefetch_engine.model_input_deal( req_ids, block_table_ori, topk_kpre_maps, @@ -819,51 +902,106 @@ def execute_begin(self, scheduler_output: SchedulerOutput): self._start_topk_cal() def execute_finished(self): - self.prefetch_engine.deal_async_prefetch(self.rank, self.gsa_metadata) - if not PTOPK_PREFETCH_ENABLE: - return + kv_caches = [None] * self.layer_num forward_context = get_forward_context() attn = forward_context.no_compile_layers - is_load_done = self.check_all_task_is_done("load") - self.gsa_stats = self.gsa_metadata.gsa_stats - self._gsa_sparse_local_kv() - if ( - is_load_done - and self.prefetch_engine.is_prefetch_flag - and self.prefetch_engine.prefetch_engine_c.get_prefetch_status() - ): - self.prefetch_engine.is_prefetch_flag = False - all_need_load_block = ( - self.prefetch_engine.prefetch_engine_c.obtain_load_blocks() + for layer_name in attn.keys(): + if self.use_mla and "mlp.experts" in layer_name: + continue + kv_cache = attn[layer_name].kv_cache[forward_context.virtual_engine] + layer_id = int(layer_name.split(".")[2]) + kv_caches[layer_id] = kv_cache + if PTOPK_PREFETCH_ENABLE: + if self.is_python_load: + is_prefetch_done = self.check_transfer_task_done() + else: + is_prefetch_done = ( + self.prefetch_engine.prefetch_engine_c.get_prefetch_status() + ) + all_free_block_ids, all_miss_ids = self.prefetch_engine.deal_async_prefetch( + is_prefetch_done, + self.gsa_metadata, + kv_caches, + self.connector.cc_store(), ) - all_miss_idx = self.prefetch_engine.prefetch_engine_c.obtain_miss_idxs() - block_hashes_load_all = {} - block_ids_load_all = {} - num_load_blocks = 0 - for layer_name in attn.keys(): - layer_id = int(layer_name.split(".")[2]) - self.k_cache[layer_id] = attn[layer_name].kv_cache[ - forward_context.virtual_engine - ][0] - self.v_cache[layer_id] = attn[layer_name].kv_cache[ - forward_context.virtual_engine - ][1] - block_hashes_load = [] - block_ids_load = [] - for index, req_id in enumerate(self.prefetch_engine.req_ids_bs): - load_len = len(all_need_load_block[index][layer_id]) - block_hashes_load += [ - f"{self.block_hash(req_id, id_)}" - for id_ in all_miss_idx[index][layer_id][:load_len] - ] - block_ids_load += all_need_load_block[index][layer_id] - num_load_blocks += len(block_hashes_load) - block_hashes_load_all[layer_id] = block_hashes_load - block_ids_load_all[layer_id] = block_ids_load - if num_load_blocks > 0: - self.launch_transfer_task_all( - "load", block_hashes_load_all, block_ids_load_all + if self.is_python_load: + self.launch_transfer_task(all_free_block_ids, all_miss_ids, kv_caches) + else: + self.prefetch_engine.deal_async_prefetch( + False, self.gsa_metadata, kv_caches, None + ) + + def launch_transfer_task(self, all_free_block_ids, all_miss_ids, kv_caches): + if all_free_block_ids == None: + return + fn = getattr(self.connector, "load") + precision = self.element_size + if self.use_mla: + block_data_size = kv_caches[0].numel() * precision + else: + block_data_size = kv_caches[0][0].numel() * precision + + offsets_k = [] + key_src_tensors = [] + block_hashes = 
[] + + for req_id in all_free_block_ids.keys(): + req_block_hash = self.gsa_metadata.gsa_stats[req_id].block_hashes + for layer_id in range(self.layer_num): + length = len(all_free_block_ids[req_id][layer_id]) + if length == 0: + continue + + offset_k = compute_layer_offset( + block_data_size, + layer_id, + is_v=False, + is_mla=self.use_mla, ) + offsets_k += [offset_k] * length + block_hashes += [ + req_block_hash[i] for i in all_miss_ids[req_id][layer_id] + ] + + if not self.use_mla: + key_src_tensors += [ + kv_caches[layer_id][0][_id] + for _id in all_free_block_ids[req_id][layer_id] + ] + offset_v = compute_layer_offset( + block_data_size, + layer_id, + is_v=True, + is_mla=self.use_mla, + ) + offsets_k += [offset_v] * length + block_hashes += [ + req_block_hash[i] for i in all_miss_ids[req_id][layer_id] + ] + key_src_tensors += [ + kv_caches[layer_id][1][_id] + for _id in all_free_block_ids[req_id][layer_id] + ] + else: + key_src_tensors += [ + kv_caches[layer_id][_id] + for _id in all_free_block_ids[req_id][layer_id] + ] + + task_all = fn(block_hashes, offsets_k, key_src_tensors) + task_all_hash = task_hash_func(block_hashes, "load", "value") + self.task_load[task_all_hash] = task_all + + def check_transfer_task_done(self) -> bool: + if len(self.task_load) == 0: + return True + + for task_hash, task in self.task_load.items(): + ret = self.connector.check(task) + if not ret: + return False + self.task_load.clear() + return True def build_sparse_meta( self, scheduler_output: SchedulerOutput, requests, input_batch, attn_metadata @@ -871,8 +1009,13 @@ def build_sparse_meta( self.gsa_metadata = self.build_gsa_metadata( scheduler_output, requests, input_batch ) - if PTOPK_PREFETCH_ENABLE: - self._init_sparse_local_kv(scheduler_output, requests) + num_sched = scheduler_output.num_scheduled_tokens + req_ids = list(getattr(input_batch, "req_ids", [])) + self.decode_index = [ + input_batch.req_id_to_index[rid] + for rid in req_ids + if num_sched.get(rid, 0) == 1 + ] def request_begin(self, request_id: ReqType, prompt_token_ids: List[int]): pass @@ -881,7 +1024,8 @@ def request_finished_in_scheduler(self, request_id: ReqType): pass def request_finished_in_worker(self, request_id: ReqType): - self.topk_kpre_manger.free(request_id) + if self.topk_kpre_manger.is_exist(request_id): + self.topk_kpre_manger.free(request_id) if request_id in self.gsa_stats: del self.gsa_stats[request_id] self.prefetch_engine.del_finish_meta(request_id) @@ -892,18 +1036,18 @@ def update_state_after_alloc(self, request: Request, num_blocks: int): def estimate_num_slots_sparsed(self, request: Request) -> int: if not PTOPK_PREFETCH_ENABLE: return INVALID_SLOT - if request.num_output_tokens == 0: + if ( + request.num_output_tokens == 0 + or request.num_prompt_tokens < self.block_size + ): return INVALID_SLOT if request.num_prompt_tokens <= SEG_PREFILL_THRESHOLD: return INVALID_SLOT block_size = self._vllm_config.cache_config.block_size num_prompt_blocks = math.ceil(request.num_prompt_tokens / block_size) num_all_blocks = math.ceil(request.num_tokens / block_size) - topk_len = compute_topk_len(num_prompt_blocks) - if topk_len > MAX_TOPK_LEN: - prefetch_len = 0 - else: - prefetch_len = MAX_TOPK_LEN - topk_len + 1 + topk_len = gsa_config.compute_topk_len(num_prompt_blocks) + prefetch_len = min(gsa_config.num_prefetch_blocks, num_prompt_blocks - topk_len) num_sparse_blocks = num_all_blocks - num_prompt_blocks + topk_len + prefetch_len flaw = request.num_tokens % block_size if flaw: @@ -912,120 +1056,103 @@ def 
estimate_num_slots_sparsed(self, request: Request) -> int: return num_tokens_sparsed def _start_topk_cal(self) -> None: - cal_topk_id = [] - is_decode = [] - topk_len_list = [] - repre_slot_mappings = [] - calc_block_tables = [] - calc_repre_slot_mappings = [] - for req_id in self.prefetch_engine.req_ids_bs: - req_meta = self.gsa_metadata.gsa_stats[req_id] - if req_meta.is_gsa(): - cal_topk_id.append(req_meta.index_in_batch) - is_decode.append(True) - one_topk_len = compute_topk_len(len(req_meta.blocks)) - topk_len_list.append(one_topk_len) - else: - is_decode.append(False) - repre_slot_mappings.append(req_meta.repre_slot_mapping) - calc_block_tables = self.model_input["calc_block_table"] - calc_repre_slot_mappings += req_meta.calc_repre_slot_mapping - if CUDA_TOPK and len(topk_len_list) != 0: - topk_len_list = [max(topk_len_list)] * len(topk_len_list) - self.gsa_offload_ops.set_common_param(cal_topk_id, is_decode) - if len(calc_block_tables) != 0: - self.gsa_offload_ops.set_kpre_param( - calc_block_tables, calc_repre_slot_mappings - ) if self.prefetch_engine.atb_gsa_enable and self.prefetch_engine.is_topk_cal: + cal_topk_id = [] + is_decode = [] + topk_len_list = [] + repre_slot_mappings = [] + repre_slot_mappings_all = [] + include_masks = [] + exclude_masks = [] + for req_id in self.prefetch_engine.req_ids_bs: + req_meta = self.gsa_metadata.gsa_stats[req_id] + if req_meta.is_gsa(): + cal_topk_id.append(req_meta.index_in_batch) + is_decode.append(True) + one_topk_len = ( + gsa_config.compute_topk_len(len(req_meta.blocks)) + + gsa_config.num_prefetch_blocks + ) + topk_len_list.append(one_topk_len) + if CUDA_TOPK: + include_masks.append(req_meta.include_mask) + exclude_masks.append(req_meta.exclude_mask) + repre_slot_mappings.append(req_meta.repre_slot_mapping) + else: + is_decode.append(False) + repre_slot_mappings_all.append(req_meta.repre_slot_mapping) + + if CUDA_TOPK and len(topk_len_list) != 0: + topk_len_list = [max(topk_len_list)] * len(topk_len_list) + repre_slot_mappings = make_tensor_with_pad( + repre_slot_mappings, pad=0, dtype=torch.int32, device=self.device + ) + include_masks = make_tensor_with_pad( + include_masks, pad=False, dtype=torch.uint8, device=self.device + ) + exclude_masks = make_tensor_with_pad( + exclude_masks, pad=True, dtype=torch.uint8, device=self.device + ) + self.gsa_offload_ops.set_common_param(cal_topk_id, is_decode) + if len(self.model_input["calc_block_table"]) != 0: + self.gsa_offload_ops.set_kpre_param( + self.model_input["calc_block_table"], [] + ) + if CUDA_TOPK: self.gsa_cuda_topk.set_topk_param( - self.model_input["repre_slot_mapping"], - self.model_input["include_mask"], - self.model_input["exclude_mask"], + repre_slot_mappings, + include_masks, + exclude_masks, ) self.gsa_cuda_topk.set_topk_caches( cal_topk_id, self.model_input["topk_caches"], topk_len_list ) else: - self.gsa_offload_ops.set_topk_param(repre_slot_mappings) + self.gsa_offload_ops.set_topk_param(repre_slot_mappings_all) self.gsa_offload_ops.set_topk_cache( self.model_input["topk_caches"], topk_len_list ) - def _init_sparse_local_kv( - self, scheduler_output: SchedulerOutput, requests - ) -> None: - forward_context = get_forward_context() - attn = forward_context.no_compile_layers - for req_id, _ in scheduler_output.num_scheduled_tokens.items(): - if ( - self.gsa_metadata.gsa_stats[req_id].num_prompt_tokens - <= SEG_PREFILL_THRESHOLD - ): - return - if ( - req_id in self.gsa_metadata.gsa_stats - and self.gsa_metadata.gsa_stats[req_id].num_computed_tokens - == 
self.gsa_metadata.gsa_stats[req_id].num_prompt_tokens - ): - assert self.gsa_metadata.gsa_stats[req_id].remain_idx != None - local_window = self.gsa_metadata.gsa_stats[req_id].remain_idx[ - LOCAL_WINDOW_SZ * -1 : - ] - req_blocks = requests[req_id].block_ids[0] - local_blocks = [req_blocks[x] for x in local_window] - for layer_name in attn.keys(): - for index, block in enumerate(local_blocks): - attn[layer_name].kv_cache[forward_context.virtual_engine][0][ - block - ].copy_( - self.gsa_metadata.gsa_stats[req_id].local_window_kv[0][ - layer_name - ][index] - ) - attn[layer_name].kv_cache[forward_context.virtual_engine][1][ - block - ].copy_( - self.gsa_metadata.gsa_stats[req_id].local_window_kv[1][ - layer_name - ][index] - ) - - def _gsa_sparse_local_kv( - self, - ) -> None: - forward_context = get_forward_context() - attn = forward_context.no_compile_layers - for req_id in self.prefetch_engine.req_ids_bs: - assert req_id in self.gsa_metadata.gsa_stats + def allocate_slots(self, kv_cache_manager, request, num_slots_sparsed): + coordinator = kv_cache_manager.coordinator + block_pool = kv_cache_manager.block_pool + kv_cache_groups = kv_cache_manager.kv_cache_config.kv_cache_groups + if ( + request.num_prompt_tokens + 1 == request.num_tokens + and request.num_tokens % self.block_size == 1 + ): + num_blocks_need = math.ceil(num_slots_sparsed / self.block_size) - 1 + else: + num_blocks_need = math.ceil(num_slots_sparsed / self.block_size) + allocated_blocks = coordinator.get_blocks(request.request_id)[0] + returned_blocks = [] + kept_blocks = [] + num_blocks_original = len(allocated_blocks) + init_windows_size = gsa_config.init_windows_size + for i, block in enumerate(allocated_blocks): if ( - self.gsa_metadata.gsa_stats[req_id].stage() == SequenceStage.PREFILL - and self.gsa_metadata.gsa_stats[req_id].is_last_chunk() + i >= num_blocks_original - num_blocks_need + init_windows_size + or i < init_windows_size ): - if ( - self.gsa_metadata.gsa_stats[req_id].num_prompt_tokens - <= SEG_PREFILL_THRESHOLD - ): - return - local_blocks = self.gsa_metadata.gsa_stats[req_id].blocks[ - LOCAL_WINDOW_SZ * -1 : - ] - k_cache = {} - v_cache = {} - for layer_name in attn.keys(): - k_cache[layer_name] = [] - v_cache[layer_name] = [] - for block in local_blocks: - k_cache[layer_name].append( - attn[layer_name] - .kv_cache[forward_context.virtual_engine][0][block] - .clone() - ) - v_cache[layer_name].append( - attn[layer_name] - .kv_cache[forward_context.virtual_engine][1][block] - .clone() - ) - self.gsa_metadata.gsa_stats[req_id].local_window_kv.append(k_cache) - self.gsa_metadata.gsa_stats[req_id].local_window_kv.append(v_cache) + kept_blocks.append(block) + else: + returned_blocks.append(block) + block.ref_cnt = 1 + block_pool._maybe_evict_cached_block(block) + block_pool.free_blocks(returned_blocks) + + coordinator.single_type_managers[0].req_to_blocks[ + request.request_id + ] = kept_blocks + + new_computed_block_list = tuple([] for _ in range(len(kv_cache_groups))) + num_blocks_to_allocate = coordinator.get_num_blocks_to_allocate( + request_id=request.request_id, + num_tokens=num_slots_sparsed, + new_computed_blocks=new_computed_block_list, + ) + if num_blocks_to_allocate > block_pool.get_num_free_blocks(): + return None + coordinator.allocate_new_blocks(request.request_id, num_slots_sparsed) + return KVCacheBlocks(tuple([kept_blocks])) diff --git a/ucm/sparse/gsa/offload_ops/src/select_topk_block.cpp b/ucm/sparse/gsa/offload_ops/src/select_topk_block.cpp index e5e87ae21..658520192 100644 --- 
a/ucm/sparse/gsa/offload_ops/src/select_topk_block.cpp +++ b/ucm/sparse/gsa/offload_ops/src/select_topk_block.cpp @@ -1,34 +1,31 @@ +#include "select_topk_block.h" #include -#include +#include #include #include -#include -#include "select_topk_block.h" +#include namespace SelectTopkBlock { #define OMP_THREAD_NUM 16u -bool TopkBlockSelector::ValidateParameters(float* q, const float* kRepre, - uint32_t numBlock, uint32_t kHead, uint32_t qHead, - uint32_t numKrepre, uint32_t headSize) +bool TopkBlockSelector::ValidateParameters(float* q, const float* kRepre, uint32_t numBlock, + uint32_t kHead, uint32_t qHead, uint32_t numKrepre, + uint32_t headSize) { - return (q != nullptr) && (kRepre != nullptr) && - (numBlock > 0) && (kHead > 0) && (qHead > 0) && + return (q != nullptr) && (kRepre != nullptr) && (numBlock > 0) && (kHead > 0) && (qHead > 0) && (numKrepre > 0) && (headSize > 0); } -void TopkBlockSelector::TopKImpl(const float* scores, uint32_t numScores, uint32_t k, int32_t* topkIndices) +void TopkBlockSelector::TopKImpl(const float* scores, uint32_t numScores, uint32_t k, + int32_t* topkIndices) { if (startWindow_ + endWindow_ >= numScores || k >= numScores || k == 0) { - for (uint32_t i = 0; i < numScores; ++i) { - topkIndices[i] = i; - } + for (uint32_t i = 0; i < numScores; ++i) { topkIndices[i] = i; } return; } uint32_t idx = 0; - for (uint32_t i = 0; i < startWindow_; ++i) { - topkIndices[idx++] = i; - } + for (uint32_t i = 0; i < startWindow_; ++i) { topkIndices[idx++] = i; } + for (uint32_t i = 0; i < endWindow_; ++i) { topkIndices[idx++] = numScores - endWindow_ + i; } int32_t midCount = k - startWindow_ - endWindow_; if (midCount > 0) { std::vector middleIndices; @@ -36,23 +33,16 @@ void TopkBlockSelector::TopKImpl(const float* scores, uint32_t numScores, uint32 for (uint32_t i = startWindow_; i < numScores - endWindow_; ++i) { middleIndices.push_back(i); } - std::stable_sort(middleIndices.begin(), middleIndices.end(), - [scores](uint32_t lhs, uint32_t rhs) { - return scores[lhs] > scores[rhs]; - }); - for (int32_t i = 0; i < midCount; ++i) { - topkIndices[idx++] = middleIndices[i]; - } - } - for (uint32_t i = 0; i < endWindow_; ++i) { - topkIndices[idx++] = numScores - endWindow_ + i; + std::stable_sort( + middleIndices.begin(), middleIndices.end(), + [scores](uint32_t lhs, uint32_t rhs) { return scores[lhs] > scores[rhs]; }); + for (int32_t i = 0; i < midCount; ++i) { topkIndices[idx++] = middleIndices[i]; } } - std::sort(topkIndices, topkIndices + k); } -float TopkBlockSelector::ComputeBlockScore(float* qMean, const float* blockBase, - uint32_t kHead, uint32_t numKrepre, - uint32_t headSize, const VecProductClass& vecProduct) +float TopkBlockSelector::ComputeBlockScore(float* qMean, const float* blockBase, uint32_t kHead, + uint32_t numKrepre, uint32_t headSize, + const VecProductClass& vecProduct) { const size_t headOffset = headSize; const size_t normOffset = headSize; @@ -81,8 +71,10 @@ const VecProductClass& TopkBlockSelector::ThreadLocalVecProduct::GetInstance() return instance; } -std::vector TopkBlockSelector::ComputeKQDotScores(const float* __restrict qMean, const float* __restrict kRepre, - uint32_t numBlock, uint32_t kHead, uint32_t numKrepre, uint32_t headSize) +std::vector TopkBlockSelector::ComputeKQDotScores(const float* __restrict qMean, + const float* __restrict kRepre, + uint32_t numBlock, uint32_t kHead, + uint32_t numKrepre, uint32_t headSize) { std::vector blockScores(numBlock, 0.0f); const size_t blockOffset = static_cast(kHead * numKrepre * headSize); @@ 
-93,16 +85,16 @@ std::vector TopkBlockSelector::ComputeKQDotScores(const float* __restrict if (idxBlock + 1 < numBlock) { __builtin_prefetch(kRepre + (idxBlock + 1) * blockOffset, 0, 1); } - blockScores[idxBlock] = ComputeBlockScore(const_cast(qMean), blockBase, kHead, numKrepre, headSize, vecProduct); + blockScores[idxBlock] = ComputeBlockScore(const_cast(qMean), blockBase, kHead, + numKrepre, headSize, vecProduct); } return blockScores; } -void TopkBlockSelector::ComputeQHeadMean(float* __restrict q, uint32_t kHead, uint32_t qHead, uint32_t headSize) +void TopkBlockSelector::ComputeQHeadMean(float* __restrict q, uint32_t kHead, uint32_t qHead, + uint32_t headSize) { - if (kHead == qHead) { - return; - } + if (kHead == qHead) { return; } const VecProductClass& vecProduct = ThreadLocalVecProduct::GetInstance(); const uint32_t groupSize = qHead / kHead; for (uint32_t kIdx = 0; kIdx < kHead; ++kIdx) { @@ -113,28 +105,25 @@ void TopkBlockSelector::ComputeQHeadMean(float* __restrict q, uint32_t kHead, ui } } -void TopkBlockSelector::SelectTopK(float* q, const float* kRepre, - uint32_t numBlock, uint32_t kHead, uint32_t qHead, - uint32_t numKrepre, uint32_t headSize, +void TopkBlockSelector::SelectTopK(float* q, const float* kRepre, uint32_t numBlock, uint32_t kHead, + uint32_t qHead, uint32_t numKrepre, uint32_t headSize, uint32_t topkLength, int32_t* topkResult) { if (!ValidateParameters(q, kRepre, numBlock, kHead, qHead, numKrepre, headSize) || topkResult == nullptr || topkLength == 0) { - return; + return; } ComputeQHeadMean(q, kHead, qHead, headSize); - const std::vector scores = ComputeKQDotScores(q, kRepre, numBlock, - kHead, numKrepre, headSize); + const std::vector scores = + ComputeKQDotScores(q, kRepre, numBlock, kHead, numKrepre, headSize); TopKImpl(scores.data(), numBlock, topkLength, topkResult); } void TopkBlockSelector::SelectTopKBS(const std::vector& qCacheVec, const std::vector& kfCacheVec, - const std::vector& topkCacheVec, - uint32_t numBatch, - const std::vector& numBlockVec, - uint32_t kHead, uint32_t qHead, - uint32_t numKrepre, uint32_t headSize, + const std::vector& topkCacheVec, uint32_t numBatch, + const std::vector& numBlockVec, uint32_t kHead, + uint32_t qHead, uint32_t numKrepre, uint32_t headSize, const std::vector& topkLengthVec) { for (uint32_t bs = 0; bs < numBatch; ++bs) { @@ -143,9 +132,8 @@ void TopkBlockSelector::SelectTopKBS(const std::vector& qCacheVec, float* q = qCacheVec[bs]; const float* kRepre = kfCacheVec[bs]; int32_t* topkResult = topkCacheVec[bs]; - SelectTopK(q, kRepre, numBlock, kHead, qHead, - numKrepre, headSize, topkLength, topkResult); + SelectTopK(q, kRepre, numBlock, kHead, qHead, numKrepre, headSize, topkLength, topkResult); } } -} \ No newline at end of file +} // namespace SelectTopkBlock \ No newline at end of file diff --git a/ucm/sparse/gsa/prefetch/CMakeLists.txt b/ucm/sparse/gsa/prefetch/CMakeLists.txt index 4c50f843c..73a56a129 100644 --- a/ucm/sparse/gsa/prefetch/CMakeLists.txt +++ b/ucm/sparse/gsa/prefetch/CMakeLists.txt @@ -24,6 +24,7 @@ set(INCLUDE_DIRS ${PYTORCH_PATH}/include/torch/csrc/api/include ${PYTORCH_PATH}/include ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_SOURCE_DIR}/ucm/store ) set(LIBRARY_DIRS @@ -38,6 +39,7 @@ set(LIBRARIES torch_python gomp pthread + storetask ) # NPU特殊配置 diff --git a/ucm/sparse/gsa/prefetch/include/kvcache_log.h b/ucm/sparse/gsa/prefetch/include/kvcache_log.h index 38caee9df..7d446ca39 100644 --- a/ucm/sparse/gsa/prefetch/include/kvcache_log.h +++ 
b/ucm/sparse/gsa/prefetch/include/kvcache_log.h @@ -1,22 +1,17 @@ #ifndef ATB_KV_LOG_H #define ATB_KV_LOG_H -#include -#include -#include #include -#include -#include +#include #include +#include +#include #include -enum class LogLevel { - DEBUG, - INFO, - WARNING, - ERROR -}; +#include +#include +#include +enum class LogLevel { DEBUG, INFO, WARNING, ERROR }; -class Logger -{ +class Logger { private: std::ofstream mLogFile; LogLevel mMinLevel; @@ -25,13 +20,12 @@ class Logger static std::string LevelToString(LogLevel level) { - switch (level) - { - case LogLevel::DEBUG: return "DEBUG"; - case LogLevel::INFO: return "INFO"; - case LogLevel::WARNING: return "WARNING"; - case LogLevel::ERROR: return "ERROR"; - default: return "UNKNOWN"; + switch (level) { + case LogLevel::DEBUG: return "DEBUG"; + case LogLevel::INFO: return "INFO"; + case LogLevel::WARNING: return "WARNING"; + case LogLevel::ERROR: return "ERROR"; + default: return "UNKNOWN"; } } @@ -39,8 +33,8 @@ class Logger { auto now = std::chrono::system_clock::now(); auto nowC = std::chrono::system_clock::to_time_t(now); - auto ms = std::chrono::duration_cast( - now.time_since_epoch()) % 1000; + auto ms = + std::chrono::duration_cast(now.time_since_epoch()) % 1000; std::stringstream oss; oss << std::put_time(std::localtime(&nowC), "%Y-%m-%d %H:%M:%S"); oss << '.' << std::setfill('0') << std::setw(3) << ms.count(); @@ -48,8 +42,8 @@ class Logger } public: - Logger(const std::string &fileName, LogLevel level = LogLevel::INFO, bool enable = true) - :mMinLevel(level), mEnable(enable) + Logger(const std::string& fileName, LogLevel level = LogLevel::INFO, bool enable = true) + : mMinLevel(level), mEnable(enable) { if (enable) { mLogFile.open(fileName, std::ios::app); @@ -59,43 +53,37 @@ class Logger } } - Logger(){} + Logger() {} ~Logger() { - if (mLogFile.is_open()) { - mLogFile.close(); - } + if (mLogFile.is_open()) { mLogFile.close(); } } - void SetLevel(LogLevel level) - { - mMinLevel = level; - } + void SetLevel(LogLevel level) { mMinLevel = level; } void log(LogLevel level, const char* format, ...) 
{ - if (level < mMinLevel || !mLogFile.is_open() || !mEnable) { - return; - } + if (level < mMinLevel || !mLogFile.is_open() || !mEnable) { return; } std::lock_guard lock(mMutex); auto now = std::chrono::system_clock::now(); auto nowC = std::chrono::system_clock::to_time_t(now); auto duration = now.time_since_epoch(); - auto millis = std::chrono::duration_cast(duration).count() % 1000; - auto micros = std::chrono::duration_cast(duration).count() % 1000; + auto millis = + std::chrono::duration_cast(duration).count() % 1000; + auto micros = + std::chrono::duration_cast(duration).count() % 1000; std::tm localTime = *std::localtime(&nowC); char timeBuffer[26]; std::strftime(timeBuffer, sizeof(timeBuffer), "%Y-%m-%d %H:%M:%S", &localTime); - const char *levelStr = ""; - switch (level) - { - case LogLevel::DEBUG: levelStr = "DEBUG"; break; - case LogLevel::INFO: levelStr = "INFO"; break; - case LogLevel::WARNING: levelStr = "WARNING"; break; - case LogLevel::ERROR: levelStr = "ERROR"; break; - default: levelStr = "UNKNOWN"; break; + const char* levelStr = ""; + switch (level) { + case LogLevel::DEBUG: levelStr = "DEBUG"; break; + case LogLevel::INFO: levelStr = "INFO"; break; + case LogLevel::WARNING: levelStr = "WARNING"; break; + case LogLevel::ERROR: levelStr = "ERROR"; break; + default: levelStr = "UNKNOWN"; break; } char messageBuffer[4096]; va_list args; @@ -103,18 +91,14 @@ class Logger vsnprintf(messageBuffer, sizeof(messageBuffer), format, args); va_end(args); - mLogFile << timeBuffer << "." - << std::setfill('0') << std::setw(3) << millis << std::setw(3) - << micros << " " << "[" << levelStr << "]" - << messageBuffer; + mLogFile << timeBuffer << "." << std::setfill('0') << std::setw(3) << millis << std::setw(3) + << micros << " " << "[" << levelStr << "]" << messageBuffer; mLogFile.flush(); } void LogWOPrefix(LogLevel level, const char* format, ...) { - if (level < mMinLevel || !mLogFile.is_open() || !mEnable) { - return; - } + if (level < mMinLevel || !mLogFile.is_open() || !mEnable) { return; } std::lock_guard lock(mMutex); char messageBuffer[2048]; va_list args; diff --git a/ucm/sparse/gsa/prefetch/include/kvcache_pre.h b/ucm/sparse/gsa/prefetch/include/kvcache_pre.h index e76218e32..1ee9952fe 100644 --- a/ucm/sparse/gsa/prefetch/include/kvcache_pre.h +++ b/ucm/sparse/gsa/prefetch/include/kvcache_pre.h @@ -1,156 +1,179 @@ #ifndef ATB_KV_CACHE_PRE_H #define ATB_KV_CACHE_PRE_H -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include -#include #include -#include -#include -#include +#include #include -#include -#include +#include +#include +#include #include -#include #include +#include +#include +#include +#include +#include +#include #include +#include +#include +#include +#include +#include +#include +#include +#include "../../../../store/ucmstore.h" namespace py = pybind11; -namespace ucmprefetch -{ - typedef struct { - int topkLen; - int reqID; - int layerID; - int topkIndex; - int bsIndex; - } PrefetchReqInfo; - - class ThreadPool - { - public: - static ThreadPool *GetInst() - { - static ThreadPool pool(1); - return &pool; - } - - ~ThreadPool(); - - template - auto enqueue(F&& f, Args&&... 
args) -> std::future::type>; - - size_t GetActiveThreads() const; - - private: - explicit ThreadPool(size_t threadCount); - std::vector workers; - std::queue> tasks; - mutable std::mutex queueMutex; - bool stop; - std::condition_variable condition; - std::atomic activeThreads{0}; - size_t maxThreads; - }; - - void MutliBSThreadFun(void *args); - - class __attribute__((visibility("hidden"))) GSAPrefetchEngineC +namespace ucmprefetch { +typedef struct { + int topkLen; + std::string reqID; + int layerID; + int topkIndex; + int bsIndex; +} PrefetchReqInfo; + +class ThreadPool { +public: + static ThreadPool* GetInst() { - private: - std::map>> mDocsTables; - std::map>> mBlocksMap; - torch::Tensor mLoadSuccessBlocks; - torch::Tensor mFreeBlock; - torch::Tensor mFreeBlockLen; - torch::Tensor mSuccessTableLen; - torch::Tensor mUseTopkIdxs; - int mLayerNum; - int mRank = -1; - uint32_t mMaxBs = 30; - int *mReqIdList = NULL; - int *mTopkLenList = NULL; - int *mBsIndexList = NULL; - uint32_t runBsLen = 0; - bool mIsLog = false; - bool mIsPrefetchDone = true; - Logger mLogger; - ThreadPool *mThreadPool; - uint32_t mDecodeStep = 0; - uint32_t mMaxTopkLen = 0; - uint32_t mMaxBlocksLen = 0; - std::unordered_set mDelSeqIds; - std::vector>> allNeedLoadBlock; - std::vector>> allMissIdxs; - - private: - void LoadKVToHBM(std::vector loadNPUBlockIDs, - std::vector missIdxs, int layerID, int reqID); - - void GetHitAndMissBlock(PrefetchReqInfo oneBsInfo, - std::unordered_set &hitBlocks, - std::map &hitBlocksIdx, - std::vector &missIdxs); - - void RunPrefetchH2D(PrefetchReqInfo oneBsInfo, - std::unordered_set &hitBlocks, - std::map &hitBlocksIdx, - std::vector &missIdxs); - - void RunOneBsPrefetch(int reqID, int topkLen, - int bsIndex, int topkIndex); - - public: - ~GSAPrefetchEngineC(); - - GSAPrefetchEngineC(torch::Tensor &freeBlock, - torch::Tensor &loadSuccessBlocks, - torch::Tensor & freeBlockLen, - torch::Tensor &successTableLen, - bool isLog); - - void SetBlocksMap(int reqID, std::vector &blockTableList, - std::vector &selectIndex); - - void CheckInputIndex(uint32_t maxLen, uint32_t index); - - void AddBlocksMap(int reqID, int idx, int blockID); - - void DelBlocksMap(int reqID); - - void SetBlockTableInfo(torch::Tensor &blockTables, - torch::Tensor &blockLengths, - torch::Tensor &inputTopkBuf, int step); - - void RunAsyncPrefetchBs(std::vector &reqIDsInput, - std::vector &topkLensInput, - std::vector &bsIndexInput, int rank); - - int CallPrefetchProcessFun(); - - void PrintMap(int reqID, int i); - - bool GetPrefetchStatus(); - - void SetPrefetchStatus(bool flag); - - std::vector>> ObtainLoadBlocks(); - - std::vector>> ObtainMissIdxs(); - - std::map>> ObtainBlocksMap(); - }; - -} // namespace uc + static ThreadPool pool(1); + return &pool; + } + + ~ThreadPool(); + + template + auto Enqueue(F&& f, Args&&... 
args) -> std::future::type>; + + size_t GetActiveThreads() const; + +private: + explicit ThreadPool(size_t threadCount); + std::vector workers; + std::queue> tasks; + mutable std::mutex queueMutex; + bool stop; + std::condition_variable condition; + std::atomic activeThreads{0}; + size_t maxThreads; +}; + +void MutliBSThreadFun(void* args); + +class __attribute__((visibility("hidden"))) GSAPrefetchEngineC { +private: + std::map>> mDocsTables; + std::map>> mBlocksMap; + torch::Tensor mLoadSuccessBlocks; + torch::Tensor mFreeBlock; + torch::Tensor mFreeBlockLen; + torch::Tensor mSuccessTableLen; + torch::Tensor mUseTopkIdxs; + int mLayerNum; + int mRank = -1; + uint32_t mMaxBs = 30; + std::vector mReqIdList; + int* mTopkLenList = NULL; + int* mBsIndexList = NULL; + uint32_t runBsLen = 0; + bool mIsLog = false; + bool mIsPrefetchDone = true; + bool mUseMla = false; + Logger mLogger; + ThreadPool* mThreadPool; + uint32_t mDecodeStep = 0; + uint32_t mMaxTopkLen = 0; + uint32_t mMaxBlocksLen = 0; + std::unordered_set mDelSeqIds; + std::map>> allNeedLoadBlock; + std::map>> allMissIdxs; + std::map mPromptLen; + UC::CCStore<>* mStore = nullptr; + std::vector mKvCaches; + uint32_t mBlockSize = 128; + uint32_t mTensorElemSize = 2; // fp16 + uint32_t mHeadNum = 40; + uint32_t mHeadSzie = 128; + uint32_t mTPSize = 2; + std::map> mAllBlcoksHash; + uint32_t mKVSzieBytes = 0; + uint32_t mExtraTopkLen = 16; + bool mIsPythonLoad = false; + +public: + std::mutex mMutex; + bool mStopPrefetch = false; + +private: + void LoadKVToHBM(std::vector loadNPUBlockIDs, std::vector missIdxs, int layerID, + std::string reqID); + + void GetHitAndMissBlock(PrefetchReqInfo oneBsInfo, std::unordered_set& hitBlocks, + std::map& hitBlocksIdx, std::vector& missIdxs); + + void RunPrefetchH2D(PrefetchReqInfo oneBsInfo, std::unordered_set& hitBlocks, + std::map& hitBlocksIdx, std::vector& missIdxs); + + void RunOneBsPrefetch(std::string reqID, int topkLen, int bsIndex, int topkIndex); + +public: + ~GSAPrefetchEngineC(); + + GSAPrefetchEngineC(torch::Tensor& freeBlock, torch::Tensor& loadSuccessBlocks, + torch::Tensor& freeBlockLen, torch::Tensor& successTableLen, + std::vector& kvShape, bool useMla, bool isLog, int tpSize, + int rank, int extraTopkLen, bool isPythonLoad); + + void SetBlocksMap(std::string reqID, std::vector& blockTableList, + std::vector& selectIndex, std::vector& blocksHash, + int maxIdx); + + void SetBlocksMapMultiLayer(std::string reqID, std::vector>& remainMap, + std::vector>& prefetchMap, + std::vector& blocksHash, int maxIdx); + + void CheckInputIndex(uint32_t maxLen, uint32_t index); + + void AddBlocksMap(std::string reqID, int idx, int blockID); + + void DelBlocksMap(std::string reqID); + + void DelReqIDRun(); + + void SetBlockTableInfo(torch::Tensor& blockTables, torch::Tensor& blockLengths, + torch::Tensor& inputTopkBuf, int step); + + void RunAsyncPrefetchBs(std::vector& reqIDsInput, std::vector& topkLensInput, + std::vector& bsIndexInput, std::vector& kvCaches, + void* storePtr); + + int CallPrefetchProcessFun(); + + void PrintMap(std::string reqID, int i); + + bool GetPrefetchStatus(); + + void SetPrefetchStatus(bool flag); + + void SetModelRunningStatus(bool flag); + + size_t GetOffset(uint32_t layerID, bool isV); + + size_t GetOffsetNew(uint32_t layerID, bool isV); + + std::map>> ObtainLoadBlocks(); + + std::map>> ObtainMissIdxs(); + + std::map>> ObtainBlocksMap(); + + std::map>> ObtainDocsMap(); +}; + +} // namespace ucmprefetch #endif diff --git a/ucm/sparse/gsa/prefetch/prefetch_engine.py 
b/ucm/sparse/gsa/prefetch/prefetch_engine.py index 38e2569ba..c38324606 100644 --- a/ucm/sparse/gsa/prefetch/prefetch_engine.py +++ b/ucm/sparse/gsa/prefetch/prefetch_engine.py @@ -11,14 +11,11 @@ from ucm.sparse.gsa.prefetch import gsa_prefetch from ucm.sparse.utils import ( - LOCAL_WINDOW_SZ, MAX_BS, - MAX_TOPK_LEN, PTOPK_PREFETCH_ENABLE, - SEG_PREFILL_THRESHOLD, VLLM_CUDA_MEM_ALIGN_KV_CACHE, align_to_256bytes, - compute_topk_len, + gsa_config, ) @@ -31,10 +28,12 @@ def __init__( is_cpu_topk: bool = False, is_max_norm: bool = False, max_norm_num: int = 1, + is_python_load: bool = False, is_prefetch: Optional[bool] = True, head_num: Optional[int] = None, is_mutli_head: Optional[bool] = None, ) -> None: + self.rank = vllm_config.parallel_config.rank self.is_cpu_topk = is_cpu_topk self.is_max_norm = is_max_norm self.async_thread = async_thread @@ -85,18 +84,24 @@ def __init__( self.device_config.device, self.dtype, torch.int64 ) self._init_tensor() + kv_shape = [self.block_size, self.num_kv_heads, self.head_size] + self.is_python_load = is_python_load self.prefetch_engine_c = gsa_prefetch.GSAPrefetchEngineC( self.prefetch_blocks, self.m_load_success_list, self.prefetch_block_len, self.block_table_len, + kv_shape, + self.use_mla, self.is_log, + self.tp_size, + self.rank, + gsa_config.num_prefetch_blocks, + self.is_python_load, ) - self.prefetch_space = 0 - self.num_token = 0 + self.topk_space = 0 self.step_time = 0 - self.is_prefetch_flag = False self.is_topk_cal = False self.select_bs_index = None self.open_gsa = True @@ -112,12 +117,12 @@ def __init__( self.atten_score = [] self.is_gsa_req_id = {} - self.min_gsa_len = math.ceil(SEG_PREFILL_THRESHOLD / self.block_size) self.topk_buf_tmp = None self.topk_bs = [] + self.is_topk_update = False - def model_input_del( + def model_input_deal( self, req_ids, block_table_ori, @@ -135,12 +140,16 @@ def model_input_del( if self.atb_gsa_enable: block_table_index = torch.tensor(self.select_bs_index, device="cpu") - self.topk_len = compute_topk_len(self._get_max_block_len()) - topk_buf_tmp = self.use_topk_caches[:, block_table_index.cpu(), :] + self.topk_len = ( + gsa_config.compute_topk_len(self._get_max_block_len(gsa_metadata)) + + gsa_config.num_prefetch_blocks + ) + topk_buf_tmp = self.use_topk_caches[:, block_table_index, :] topk_buf_tmp = topk_buf_tmp[:, :, : self.topk_len] - self.is_topk_cal = is_topk_done and self.num_token % 3 == 0 + self.is_topk_cal = is_topk_done and self.topk_space % 3 == 0 if self.is_topk_cal: self._topk_tmp_deal(gsa_metadata, topk_buf_tmp) + self.is_topk_update = True self._topk_insert_last_idx(gsa_metadata) if self.ptopk_prefetch_enable: @@ -151,9 +160,12 @@ def model_input_del( block_table_tmp = self.use_block_table[:, block_table_index, :].to( self.device_config.device ) - gen_len_tmp = self.gsa_seq_len[:, self.select_bs_index].to( - self.device_config.device - ) + if torch.cuda.is_available(): + gen_len_tmp = self.gsa_seq_len[:, self.select_bs_index].to( + self.device_config.device + ) + else: + gen_len_tmp = self.gsa_seq_len[:, self.select_bs_index] list_topk_buf = list(topk_buf_tmp.unbind(dim=0)) list_block_table = list(block_table_tmp.unbind(dim=0)) @@ -167,20 +179,21 @@ def model_input_del( def _topk_tmp_deal(self, gsa_metadata, topk_buf_tmp): for index, topk_info in enumerate(self.topk_bs): - if topk_info[1]: - if topk_info[0] in gsa_metadata.gsa_stats: - if not self.is_cpu_topk: - gsa_metadata.gsa_stats[topk_info[0]].topk_buf_tmp = ( - self.topk_buf_tmp[:, index, : topk_info[2]].cpu() - ) - else: - 
gsa_metadata.gsa_stats[topk_info[0]].topk_buf_tmp = ( - self.topk_buf_tmp[:, index, : topk_info[2]].clone() - ) + if topk_info[1] and topk_info[0] in gsa_metadata.gsa_stats: + if not self.is_cpu_topk: + gsa_metadata.gsa_stats[topk_info[0]].topk_buf_tmp = ( + self.topk_buf_tmp[:, index, : topk_info[2]].cpu() + ) + else: + gsa_metadata.gsa_stats[topk_info[0]].topk_buf_tmp = ( + self.topk_buf_tmp[:, index, : topk_info[2]].clone() + ) self.topk_bs = [] for index, req_id in enumerate(self.req_ids_bs): - one_block_len = len(gsa_metadata.gsa_stats[req_id].blocks) - one_topk_len = compute_topk_len(one_block_len) + one_topk_len = ( + gsa_config.compute_topk_len(len(gsa_metadata.gsa_stats[req_id].blocks)) + + gsa_config.num_prefetch_blocks + ) self.topk_bs.append( [ req_id, @@ -190,62 +203,60 @@ def _topk_tmp_deal(self, gsa_metadata, topk_buf_tmp): ) self.topk_buf_tmp = topk_buf_tmp - def deal_async_prefetch( - self, - rank, - gsa_metadata, - ) -> None: - if self.atb_gsa_enable: - if self.ptopk_prefetch_enable: - if self.prefetch_space >= 5: - tmp = self.use_block_table - self.use_block_table = self.m_load_success_list - self.m_load_success_list = tmp - - tmp = self.use_block_table_len - self.use_block_table_len = self.block_table_len - self.block_table_len = tmp - - self._swap_block_table_tensor(self.select_bs_index, gsa_metadata) - self.prefetch_engine_c.set_blocks_table_info( - self.m_load_success_list, - self.block_table_len, - self.prefetch_topk_buf[:, : len(self.select_bs_index), :], - self.step_time, - ) - - topk_len_list = [] - req_id_list = [] - for req_id in self.req_ids_bs: - req_id_list.append(int(req_id)) - if not self.is_gsa_req_id[req_id]: - topk_len_list.append(0) - continue - else: - if gsa_metadata.gsa_stats[req_id].topk_buf_tmp != None: - topk_len_list.append( - len(gsa_metadata.gsa_stats[req_id].topk_buf_tmp[0]) - ) - else: - topk_len_list.append(0) - self.prefetch_engine_c.run_async_prefetch_bs( - req_id_list, topk_len_list, self.select_bs_index, rank - ) - self.is_prefetch_flag = True - self.prefetch_space = 0 + def deal_async_prefetch(self, is_prefetch_done, gsa_metadata, kvcache, store_ptr): + self.topk_space += 1 + all_free_block_ids = None + all_miss_ids = None + if not self.atb_gsa_enable: + return all_free_block_ids, all_miss_ids + if is_prefetch_done and self.ptopk_prefetch_enable and self.is_topk_update: + tmp = self.use_block_table + self.use_block_table = self.m_load_success_list + self.m_load_success_list = tmp + + tmp = self.use_block_table_len + self.use_block_table_len = self.block_table_len + self.block_table_len = tmp + + self._swap_block_table_tensor(self.select_bs_index, gsa_metadata) + self.prefetch_engine_c.set_blocks_table_info( + self.m_load_success_list, + self.block_table_len, + self.prefetch_topk_buf[:, : len(self.select_bs_index), :], + self.step_time, + ) + topk_len_list = [] + req_id_list = [] + for req_id in self.req_ids_bs: + req_id_list.append(req_id) + if not self.is_gsa_req_id[req_id]: + topk_len_list.append(0) + continue else: - self.prefetch_space += 1 - self.num_token += 1 + if gsa_metadata.gsa_stats[req_id].topk_buf_tmp != None: + topk_len_list.append( + len(gsa_metadata.gsa_stats[req_id].topk_buf_tmp[0]) + ) + else: + topk_len_list.append(0) + self.prefetch_engine_c.run_async_prefetch_bs( + req_id_list, topk_len_list, self.select_bs_index, kvcache, store_ptr + ) + self.is_topk_update = False + if self.is_python_load: + all_free_block_ids = self.prefetch_engine_c.obtain_load_blocks() + all_miss_ids = self.prefetch_engine_c.obtain_miss_idxs() 
+ return all_free_block_ids, all_miss_ids - def del_finish_meta(self, del_req) -> None: + def del_finish_meta(self, del_req, flag: bool = True) -> None: if del_req in self.block_map_flag: del self.block_map_flag[del_req] if del_req in self.block_table_flag: del self.block_table_flag[del_req] if del_req in self.is_gsa_req_id: del self.is_gsa_req_id[del_req] - if self.ptopk_prefetch_enable: - self.prefetch_engine_c.del_blocks_map(int(del_req)) + if PTOPK_PREFETCH_ENABLE and flag: + self.prefetch_engine_c.del_blocks_map(del_req) def _init_tensor(self): device = "cpu" @@ -326,56 +337,76 @@ def _init_kpre_and_topk_cache( def _first_topk_deal(self, gsa_metadata) -> None: for index, req_id in enumerate(self.req_ids_bs): - if gsa_metadata.gsa_stats[req_id].remain_idx != None: - bs_index = self.select_bs_index[index] + if gsa_metadata.gsa_stats[req_id].remain_idx == None: + continue + + bs_index = self.select_bs_index[index] + if gsa_metadata.gsa_stats[req_id].reamin_map != None: + topk_block_list_all = [] + prefetch_blocks_list_all = [] + for layer_id in range(self.num_attention_layers): + topk_block_list = sorted( + list( + gsa_metadata.gsa_stats[req_id].reamin_map[layer_id].values() + ) + ) + prefetch_blocks_list = list( + gsa_metadata.gsa_stats[req_id].prefetch_map[layer_id].values() + ) + topk_block_list_all.append(topk_block_list) + prefetch_blocks_list_all.append(prefetch_blocks_list) + topk_block_tensor = torch.tensor( + topk_block_list_all, dtype=torch.int32, device="cpu" + ) + prefetch_block_tensor = torch.tensor( + prefetch_blocks_list_all, dtype=torch.int32 + ) + else: real_length = len(gsa_metadata.gsa_stats[req_id].blocks) block_table_list = self.block_table_list_bs[index][:real_length] remain_index = gsa_metadata.gsa_stats[req_id].remain_idx prefetch_idx = gsa_metadata.gsa_stats[req_id].prefetch_idx assert len(remain_index) < self.sp_max_len - self.prefetch_block_len[:, bs_index] = len(prefetch_idx) - self.block_table_len[:, bs_index] = len(remain_index) - self.use_block_table_len[:, bs_index] = len(remain_index) - prefetch_blocks_list = [block_table_list[x] for x in prefetch_idx] - self.prefetch_blocks[:, bs_index, : len(prefetch_blocks_list)] = ( - torch.tensor(prefetch_blocks_list, dtype=torch.int32) - ) topk_block_list = [block_table_list[x] for x in remain_index] topk_block_tensor = torch.tensor( topk_block_list, dtype=torch.int32, device="cpu" ) + prefetch_block_tensor = torch.tensor( + prefetch_blocks_list, dtype=torch.int32 + ) - if ( - gsa_metadata.gsa_stats[req_id].num_prompt_tokens - <= SEG_PREFILL_THRESHOLD - ): - block_table_list_input = block_table_list + self.prefetch_block_len[:, bs_index] = len(prefetch_blocks_list) + self.block_table_len[:, bs_index] = len(topk_block_list) + self.use_block_table_len[:, bs_index] = len(topk_block_list) + + self.prefetch_blocks[:, bs_index, : len(prefetch_blocks_list)] = ( + prefetch_block_tensor + ) + self.use_block_table[:, bs_index, : len(topk_block_list)] = ( + topk_block_tensor + ) + self.m_load_success_list[:, bs_index, : len(topk_block_list)] = ( + topk_block_tensor + ) + max_idx = len(gsa_metadata.gsa_stats[req_id].block_hashes) + if self.is_gsa_req_id[req_id]: + if gsa_metadata.gsa_stats[req_id].reamin_map != None: + self.prefetch_engine_c.set_blocks_map_multilayer( + req_id, + gsa_metadata.gsa_stats[req_id].reamin_map, + gsa_metadata.gsa_stats[req_id].prefetch_map, + gsa_metadata.gsa_stats[req_id].block_hashes, + max_idx, + ) else: - block_table_list_input = [x for x in block_table_list] - remain_all_len = len(prefetch_idx) 
+ len(remain_index) - block_table_list_input[-1 * LOCAL_WINDOW_SZ :] = block_table_list[ - remain_all_len - LOCAL_WINDOW_SZ : remain_all_len - ] - - block_table_list_input[ - remain_all_len - LOCAL_WINDOW_SZ : real_length - LOCAL_WINDOW_SZ - ] = block_table_list[remain_all_len - real_length :] - - remain_index[-1 * LOCAL_WINDOW_SZ :] = list(range(real_length))[ - -1 * LOCAL_WINDOW_SZ : - ] - input_idxs = prefetch_idx + remain_index - self.use_block_table[:, bs_index, : len(topk_block_list)] = ( - topk_block_tensor - ) - self.m_load_success_list[:, bs_index, : len(topk_block_list)] = ( - topk_block_tensor - ) - if self.is_gsa_req_id[req_id]: self.prefetch_engine_c.set_blocks_map( - int(req_id), block_table_list_input, input_idxs + req_id, + block_table_list, + prefetch_idx + remain_index, + gsa_metadata.gsa_stats[req_id].block_hashes, + max_idx, ) def _gsa_block_len_pre( @@ -427,16 +458,19 @@ def _gsa_block_len_pre( def _topk_insert_last_idx(self, gsa_metadata) -> None: for index in range(len(self.req_ids_bs)): req_id = self.req_ids_bs[index] - if gsa_metadata.gsa_stats[req_id].topk_buf_tmp != None: - last_idx = len(gsa_metadata.gsa_stats[req_id].blocks) - 1 - if last_idx not in gsa_metadata.gsa_stats[req_id].topk_buf_tmp: - gsa_metadata.gsa_stats[req_id].topk_buf_tmp = ( - torch.nn.functional.pad( - gsa_metadata.gsa_stats[req_id].topk_buf_tmp, - (0, 1), - value=last_idx, - ) - ) + if gsa_metadata.gsa_stats[req_id].topk_buf_tmp == None: + continue + + last_idx = len(gsa_metadata.gsa_stats[req_id].blocks) - 1 + + if last_idx in gsa_metadata.gsa_stats[req_id].topk_buf_tmp: + continue + + gsa_metadata.gsa_stats[req_id].topk_buf_tmp = torch.nn.functional.pad( + gsa_metadata.gsa_stats[req_id].topk_buf_tmp, + (0, 1), + value=last_idx, + ) def _swap_block_table_tensor( self, @@ -448,7 +482,7 @@ def _swap_block_table_tensor( if req_id in self.block_map_flag: for block_mp_add in self.block_map_flag[req_id]: self.prefetch_engine_c.add_blocks_map( - int(req_id), block_mp_add[0], block_mp_add[1] + req_id, block_mp_add[0], block_mp_add[1] ) self.block_map_flag[req_id].clear() @@ -462,6 +496,7 @@ def _swap_block_table_tensor( ] = block_table_add self.use_block_table_len[layer_id][bs_index].add_(1) self.block_table_flag[req_id].clear() + if gsa_metadata.gsa_stats[req_id].topk_buf_tmp != None: self.prefetch_topk_buf[ :, index, : len(gsa_metadata.gsa_stats[req_id].topk_buf_tmp[0]) @@ -498,10 +533,10 @@ def _set_req_stat( else: self.is_gsa_req_id[req_id] = False - def _get_max_block_len(self) -> int: + def _get_max_block_len(self, gsa_metadata) -> int: max_len = 0 - for blocks in self.block_table_list_bs: - max_len = max(max_len, len(blocks)) + for req_id in self.req_ids_bs: + max_len = max(max_len, len(gsa_metadata.gsa_stats[req_id].blocks)) return max_len def _no_gsa_input_deal( @@ -523,6 +558,7 @@ def _no_gsa_input_deal( self.gsa_seq_len[:, bs_index] = gsa_metadata.gsa_stats[ req_id ].get_seq_len() + self.use_block_table[:, bs_index, :].fill_(0) self.use_block_table[ :, bs_index, : len(gsa_metadata.gsa_stats[req_id].blocks) ] = one_block_table diff --git a/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp b/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp index a98e69486..4a9ac7100 100644 --- a/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp +++ b/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp @@ -1,410 +1,602 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ #include "kvcache_pre.h" #include -#include #include +#include -namespace ucmprefetch +namespace ucmprefetch { +ThreadPool::ThreadPool(size_t threadCount) : stop(false), maxThreads(threadCount) { - ThreadPool::ThreadPool(size_t threadCount) - :stop(false), maxThreads(threadCount) - { - for (size_t i = 0; i < maxThreads; i++) { - workers.emplace_back([this] { - while(true) { - std::function task; - { - std::unique_lock lock(this->queueMutex); - this->condition.wait(lock, [this] { - return this->stop || !this->tasks.empty(); - }); - - if (this->stop && this->tasks.empty()) { - return; - } - - task = std::move(this->tasks.front()); - this->tasks.pop(); - ++activeThreads; - } - - task(); - { - std::unique_lock lock(this->queueMutex); - --activeThreads; - condition.notify_all(); - } + for (size_t i = 0; i < maxThreads; i++) { + workers.emplace_back([this] { + while (true) { + std::function task; + { + std::unique_lock lock(this->queueMutex); + this->condition.wait(lock, + [this] { return this->stop || !this->tasks.empty(); }); + + if (this->stop && this->tasks.empty()) { return; } + + task = std::move(this->tasks.front()); + this->tasks.pop(); + ++activeThreads; } - }); - } + + task(); + { + std::unique_lock lock(this->queueMutex); + --activeThreads; + condition.notify_all(); + } + } + }); } - ThreadPool::~ThreadPool() +} +ThreadPool::~ThreadPool() +{ { - { - std::unique_lock lock(queueMutex); - stop = true; - } - condition.notify_all(); - for (std::thread &worker : workers) { - worker.join(); - } + std::unique_lock lock(queueMutex); + stop = true; } + condition.notify_all(); + for (std::thread& worker : workers) { worker.join(); } +} - template - auto ThreadPool::enqueue(F&& f, Args&&... args) - -> std::future::type> - { - using return_type = typename std::result_of::type; +template +auto ThreadPool::Enqueue(F&& f, Args&&... args) + -> std::future::type> +{ + using return_type = typename std::result_of::type; - auto task = std::make_shared>( - std::bind(std::forward(f), std::forward(args)...) 
- ); + auto task = std::make_shared>( + std::bind(std::forward(f), std::forward(args)...)); - std::future res = task->get_future(); - { - std::unique_lock lock(queueMutex); + std::future res = task->get_future(); + { + std::unique_lock lock(queueMutex); - condition.wait(lock, [this] { - if (!(activeThreads < maxThreads || tasks.size() < maxThreads * 2)) { - std::cout << "Need wait: " << activeThreads << " " << tasks.size() << std::endl; - } - return (activeThreads < maxThreads || tasks.size() < maxThreads * 2); - }); - // don't allow enqueueing after stopping the pool - if(stop) { - throw std::runtime_error("enqueue on stopped ThreadPool"); + condition.wait(lock, [this] { + if (!(activeThreads < maxThreads || tasks.size() < maxThreads * 2)) { + std::cout << "Need wait: " << activeThreads << " " << tasks.size() << std::endl; } + return (activeThreads < maxThreads || tasks.size() < maxThreads * 2); + }); + // don't allow enqueueing after stopping the pool + if (stop) { throw std::runtime_error("enqueue on stopped ThreadPool"); } - tasks.emplace([task](){ (*task)(); }); - } - condition.notify_one(); - return res; + tasks.emplace([task]() { (*task)(); }); } + condition.notify_one(); + return res; +} - size_t ThreadPool::GetActiveThreads() const - { - return activeThreads; - } +size_t ThreadPool::GetActiveThreads() const { return activeThreads; } - void MutliBSThreadFun(void *args) - { - GSAPrefetchEngineC *engine = static_cast(args); - int ret = engine->CallPrefetchProcessFun(); - if (ret == 0) { - engine->SetPrefetchStatus(true); - } +void MutliBSThreadFun(void* args) +{ + GSAPrefetchEngineC* engine = static_cast(args); + int ret = engine->CallPrefetchProcessFun(); + engine->mMutex.lock(); + engine->DelReqIDRun(); + engine->mMutex.unlock(); + if (ret == 0) { engine->SetPrefetchStatus(true); } +} + +GSAPrefetchEngineC::GSAPrefetchEngineC(torch::Tensor& freeBlock, torch::Tensor& loadSuccessBlocks, + torch::Tensor& freeBlockLen, torch::Tensor& successTableLen, + std::vector& kvShape, bool useMla, bool isLog, + int tpSize, int rank, int extraTopkLen, bool isPythonLoad) + : mLogger("./log/kvcache_pre_log.txt", LogLevel::INFO, isLog) +{ + mLoadSuccessBlocks = loadSuccessBlocks; + mLayerNum = mLoadSuccessBlocks.sizes()[0]; + mMaxBs = mLoadSuccessBlocks.sizes()[1]; + mMaxTopkLen = mLoadSuccessBlocks.sizes()[2]; + mFreeBlock = freeBlock; + mFreeBlockLen = freeBlockLen; + mSuccessTableLen = successTableLen; + mIsLog = isLog; + mBsIndexList = (int*)malloc(sizeof(int) * mMaxBs); + mTopkLenList = (int*)malloc(sizeof(int) * mMaxBs); + mIsPrefetchDone = true; + mThreadPool = ThreadPool::GetInst(); + mUseMla = useMla; + mHeadSzie = kvShape[2]; + mHeadNum = kvShape[1]; + mBlockSize = kvShape[0]; + mTPSize = tpSize; + mRank = rank; + mIsPythonLoad = isPythonLoad; + if (mRank != 0) { + mLogger.SetLevel(LogLevel::WARNING); + mIsLog = false; + } + mExtraTopkLen = extraTopkLen; + mLogger.log(LogLevel::INFO, + "GSAPrefetchEngineC Init mLayerNum %d mMaxBs %u, mUseMla %d, mHeadSzie %u, mTPSize " + "%u mBlockSize %u mHeadNum %u\n", + mLayerNum, mMaxBs, mUseMla, mHeadSzie, mTPSize, mBlockSize, mHeadNum); +} + +size_t GSAPrefetchEngineC::GetOffset(uint32_t layerID, bool isV) +{ + size_t kMinDataBlockSize = + static_cast(mBlockSize) * mHeadNum * mHeadSzie * mTensorElemSize; + size_t vMinDataBlockSize = kMinDataBlockSize; + size_t layerSize = (kMinDataBlockSize + vMinDataBlockSize) * mTPSize; + if (mUseMla) { + vMinDataBlockSize = 0; + layerSize = kMinDataBlockSize; + } + size_t kOffset = 0; + if (mUseMla) { + kOffset = 
layerSize * layerID; + } else { + kOffset = layerSize * layerID + layerSize / mTPSize * mRank; } + size_t vOffset = kOffset + kMinDataBlockSize; + if (isV) { + return vOffset; + } else { + return kOffset; + } +} - GSAPrefetchEngineC::GSAPrefetchEngineC(torch::Tensor &freeBlock, - torch::Tensor &loadSuccessBlocks, - torch::Tensor &freeBlockLen, - torch::Tensor &successTableLen, - bool isLog) - :mLogger("./log/kvcache_pre_log.txt", LogLevel::INFO, isLog) - { - mLoadSuccessBlocks = loadSuccessBlocks; - mLayerNum = mLoadSuccessBlocks.sizes()[0]; - mMaxBs = mLoadSuccessBlocks.sizes()[1]; - mMaxTopkLen = mLoadSuccessBlocks.sizes()[2]; - mFreeBlock = freeBlock; - mFreeBlockLen = freeBlockLen; - mSuccessTableLen = successTableLen; - mIsLog = isLog; - mReqIdList = (int *)malloc(sizeof(int) * mMaxBs); - mBsIndexList = (int *)malloc(sizeof(int) * mMaxBs); - mTopkLenList = (int *)malloc(sizeof(int) * mMaxBs); - mIsPrefetchDone = true; - mThreadPool = ThreadPool::GetInst(); +size_t GSAPrefetchEngineC::GetOffsetNew(uint32_t layerID, bool isV) +{ + size_t kMinDataBlockSize = + static_cast(mBlockSize) * mHeadNum * mHeadSzie * mTensorElemSize; + size_t layerSize = kMinDataBlockSize * 2; + size_t kOffset = layerSize * layerID; + if (mUseMla) { + layerSize = kMinDataBlockSize; + kOffset = layerSize * layerID; + return kOffset; } + size_t vOffset = kOffset + kMinDataBlockSize; - void GSAPrefetchEngineC::CheckInputIndex(uint32_t maxLen, uint32_t index) - { - if (index >= maxLen) { - mLogger.log(LogLevel::ERROR, - "Decode step: %u, |KVCache Prefetch| index error! index: %u, maxLen: %u\n", - mDecodeStep, index, maxLen); - std::abort(); - } + if (isV) { + return vOffset; + } else { + return kOffset; } +} - GSAPrefetchEngineC::~GSAPrefetchEngineC() - { - free(mReqIdList); - free(mBsIndexList); - free(mTopkLenList); +void GSAPrefetchEngineC::CheckInputIndex(uint32_t maxLen, uint32_t index) +{ + if (index >= maxLen) { + mLogger.log(LogLevel::ERROR, + "Decode step: %u, |KVCache Prefetch| index error! 
index: %u, maxLen: %u\n", + mDecodeStep, index, maxLen); + std::abort(); } +} - void GSAPrefetchEngineC::SetBlocksMap(int reqID, std::vector &blockTableList, - std::vector &selectIndex) - { - if (mBlocksMap.find(reqID) != mBlocksMap.end()) { - mBlocksMap[reqID].clear(); - mDocsTables[reqID].clear(); +GSAPrefetchEngineC::~GSAPrefetchEngineC() +{ + free(mBsIndexList); + free(mTopkLenList); +} + +void GSAPrefetchEngineC::SetBlocksMap(std::string reqID, std::vector& blockTableList, + std::vector& selectIndex, + std::vector& blocksHash, int maxIdx) +{ + if (mBlocksMap.find(reqID) != mBlocksMap.end()) { + mBlocksMap[reqID].clear(); + mDocsTables[reqID].clear(); + mAllBlcoksHash[reqID].clear(); + } + mAllBlcoksHash[reqID] = blocksHash; + for (int i = 0; i < mLayerNum; i++) { + std::map oneDocTable; + std::map oneBlockMap; + for (auto idx : selectIndex) { + oneDocTable[idx] = blockTableList[idx]; + oneBlockMap[blockTableList[idx]] = idx; } - for (int i = 0; i < mLayerNum; i++) { + mDocsTables[reqID].push_back(oneDocTable); + mBlocksMap[reqID].push_back(oneBlockMap); + } + mPromptLen[reqID] = maxIdx; + PrintMap(reqID, 0); +} + +void GSAPrefetchEngineC::SetBlocksMapMultiLayer(std::string reqID, + std::vector>& remainMap, + std::vector>& prefetchMap, + std::vector& blocksHash, int maxIdx) +{ + if (mBlocksMap.find(reqID) != mBlocksMap.end()) { + mBlocksMap[reqID].clear(); + mDocsTables[reqID].clear(); + mAllBlcoksHash[reqID].clear(); + } + mAllBlcoksHash[reqID] = blocksHash; + for (int i = 0; i < mLayerNum; i++) { + std::map oneDocTable; + std::map oneBlockMap; + for (auto it = remainMap[i].begin(); it != remainMap[i].end(); it++) { + oneDocTable[it->first] = it->second; + oneBlockMap[it->second] = it->first; + } + for (auto it = prefetchMap[i].begin(); it != prefetchMap[i].end(); it++) { + oneDocTable[it->first] = it->second; + oneBlockMap[it->second] = it->first; + } + mDocsTables[reqID].push_back(oneDocTable); + mBlocksMap[reqID].push_back(oneBlockMap); + } + mPromptLen[reqID] = maxIdx; +} + +void GSAPrefetchEngineC::AddBlocksMap(std::string reqID, int idx, int blockID) +{ + if (mBlocksMap.find(reqID) == mBlocksMap.end()) { + for (int i = 0; i < mLayerNum; ++i) { std::map oneDocTable; std::map oneBlockMap; - for (auto idx:selectIndex) { - oneDocTable[idx] = blockTableList[idx]; - oneBlockMap[blockTableList[idx]] = idx; - } + oneDocTable[idx] = blockID; + oneBlockMap[blockID] = idx; mDocsTables[reqID].push_back(oneDocTable); mBlocksMap[reqID].push_back(oneBlockMap); } + } else { + for (int i = 0; i < mLayerNum; i++) { + mDocsTables[reqID][i][idx] = blockID; + mBlocksMap[reqID][i][blockID] = idx; + } } +} - void GSAPrefetchEngineC::AddBlocksMap(int reqID, int idx, int blockID) - { - if (mBlocksMap.find(reqID) == mBlocksMap.end()) { - for (int i = 0; i < mLayerNum; ++i) { - std::map oneDocTable; - std::map oneBlockMap; - oneDocTable[idx] = blockID; - oneBlockMap[blockID] = idx; - mDocsTables[reqID].push_back(oneDocTable); - mBlocksMap[reqID].push_back(oneBlockMap); - } +void GSAPrefetchEngineC::DelBlocksMap(std::string reqID) +{ + mMutex.lock(); + mDelSeqIds.insert(reqID); + if (mIsPrefetchDone) { DelReqIDRun(); } + mMutex.unlock(); +} + +void GSAPrefetchEngineC::DelReqIDRun() +{ + for (auto it = mDelSeqIds.begin(); it != mDelSeqIds.end(); it++) { + if (mBlocksMap.find(*it) == mBlocksMap.end()) { + continue; } else { - for (int i = 0; i < mLayerNum; i++) { - mDocsTables[reqID][i][idx] = blockID; - mBlocksMap[reqID][i][blockID] = idx; - } + mBlocksMap.erase(*it); + mDocsTables.erase(*it); + 
mAllBlcoksHash.erase(*it); + mPromptLen.erase(*it); + std::cout << "Del reqID: " << *it << std::endl; } - } - - void GSAPrefetchEngineC::DelBlocksMap(int reqID) - { - if (mBlocksMap.find(reqID) == mBlocksMap.end()) { - return; + if (mPromptLen.find(*it) == mPromptLen.end()) { + continue; } else { - mBlocksMap.erase(reqID); - mDocsTables.erase(reqID); - std::cout << "Del reqID: " << reqID << std::endl; + mPromptLen.erase(*it); } } + mDelSeqIds.clear(); +} - void GSAPrefetchEngineC::PrintMap(int reqID, int i) - { - std::ostringstream oss; - oss << "Decode step: " << mDecodeStep << " Rnak: " << mRank << " reqID: " - << reqID << " layerID: " << i << "mDocsTables"; - for (auto it : mDocsTables[reqID][i]) { - oss << "(" << it.first << ", " << it.second << ")"; +void GSAPrefetchEngineC::PrintMap(std::string reqID, int i) +{ + std::ostringstream oss; + oss << "Decode step: " << mDecodeStep << " Rnak: " << mRank << " reqID: " << reqID + << " layerID: " << i << "mDocsTables"; + for (auto it : mDocsTables[reqID][i]) { oss << "(" << it.first << ", " << it.second << ")"; } + oss << "------\n"; + mLogger.log(LogLevel::INFO, oss.str().c_str()); + oss.str(""); + oss << "Decode step: " << mDecodeStep << " Rnak: " << mRank << " reqID: " << reqID + << " layerID: " << i << "mBlocksMap"; + for (auto it : mBlocksMap[reqID][i]) { oss << "(" << it.first << ", " << it.second << ")"; } + oss << "------\n"; + mLogger.log(LogLevel::INFO, oss.str().c_str()); + oss.str(""); +} + +void GSAPrefetchEngineC::GetHitAndMissBlock(PrefetchReqInfo oneBsInfo, + std::unordered_set& hitBlocks, + std::map& hitBlocksIdx, + std::vector& missIdxs) +{ + int topkLen = oneBsInfo.topkLen; + int layerID = oneBsInfo.layerID; + std::string reqID = oneBsInfo.reqID; + int topkIndex = oneBsInfo.topkIndex; + + std::ostringstream oss; + oss << "Decode step: " << mDecodeStep << " Rnak: " << mRank << " reqID: " << reqID + << " layerID: " << layerID << " topk len: " << topkLen << " topk: "; + for (int j = 0; j < topkLen; j++) { + int64_t item = 0; + if (mUseTopkIdxs.scalar_type() == torch::kInt32) { + item = mUseTopkIdxs[layerID][topkIndex][j].item(); + } else { + item = mUseTopkIdxs[layerID][topkIndex][j].item(); } - oss << "------\n"; - mLogger.log(LogLevel::INFO, oss.str().c_str()); - oss.str(""); - oss << "Decode step: " << mDecodeStep << " Rnak: " << mRank << " reqID: " - << reqID << " layerID: " << i << "mBlocksMap"; - for (auto it : mBlocksMap[reqID][i]) { - oss << "(" << it.first << ", " << it.second << ")"; + oss << item << " "; + if (mDocsTables[reqID][layerID].find(item) != mDocsTables[reqID][layerID].end()) { + int blockID = mDocsTables[reqID][layerID][item]; + hitBlocks.insert(blockID); + hitBlocksIdx.insert(std::make_pair(item, blockID)); + if (hitBlocks.size() == (topkLen - mExtraTopkLen)) { break; } + } else { + missIdxs.push_back(item); } - oss << "------\n"; - mLogger.log(LogLevel::INFO, oss.str().c_str()); - oss.str(""); } + oss << "------\n"; + mLogger.log(LogLevel::DEBUG, oss.str().c_str()); + oss.str(""); + if ((hitBlocks.size() + missIdxs.size()) != (uint32_t)topkLen && + hitBlocks.size() != (topkLen - mExtraTopkLen)) { + mLogger.log(LogLevel::ERROR, + "|KVCache Prefetch| Decode step: %u, Rank: %d, reqID: %s, layer: %d, hit size: " + "%lu, miss size: %lu , topkLen: %d, not equal error\n", + mDecodeStep, mRank, reqID, layerID, hitBlocks.size(), missIdxs.size(), topkLen); + PrintMap(reqID, layerID); + } +} - void GSAPrefetchEngineC::GetHitAndMissBlock(PrefetchReqInfo oneBsInfo, - std::unordered_set &hitBlocks, - std::map 
&hitBlocksIdx, - std::vector &missIdxs) - { - int topkLen = oneBsInfo.topkLen; - int layerID = oneBsInfo.layerID; - int reqID = oneBsInfo.reqID; - int topkIndex = oneBsInfo.topkIndex; - - for (int j = 0; j < topkLen; j++) { - int64_t item = 0; - if (mUseTopkIdxs.scalar_type() == torch::kInt32) { - item = mUseTopkIdxs[layerID][topkIndex][j].item(); - } else { - item = mUseTopkIdxs[layerID][topkIndex][j].item(); - } - if (mDocsTables[reqID][layerID].find(item) != mDocsTables[reqID][layerID].end()) { - int blockID = mDocsTables[reqID][layerID][item]; - hitBlocks.insert(blockID); - hitBlocksIdx.insert(std::make_pair(item, blockID)); - } else { - missIdxs.push_back(item); - } - } - if ((hitBlocks.size() + missIdxs.size()) != (uint32_t)topkLen) { - mLogger.log(LogLevel::ERROR, - "|KVCache Prefetch| Decode step: %u, Rank: %d, reqID: %d, layer: %d, not equal error\n", - mDecodeStep, mRank, reqID, layerID); - PrintMap(reqID, layerID); +void GSAPrefetchEngineC::RunPrefetchH2D(PrefetchReqInfo oneBsInfo, + std::unordered_set& hitBlocks, + std::map& hitBlocksIdx, + std::vector& missIdxs) +{ + int layerID = oneBsInfo.layerID; + std::string reqID = oneBsInfo.reqID; + uint32_t topkLen = oneBsInfo.topkLen; + int bsIndex = oneBsInfo.bsIndex; + + int oneFreeBlockLen = mFreeBlockLen[layerID][bsIndex].item(); + int* freeBlockPtr = mFreeBlock[layerID][bsIndex].data_ptr(); + std::vector oneFreeBlockTable; + + uint32_t index = 0; + int oneFreeBlockIndex = 0; + while (oneFreeBlockIndex < oneFreeBlockLen && index < missIdxs.size() && + hitBlocks.size() < (topkLen - mExtraTopkLen)) { + int oneFreeBlockID = freeBlockPtr[oneFreeBlockIndex]; + if (hitBlocks.find(oneFreeBlockID) != hitBlocks.end()) { + oneFreeBlockIndex += 1; + continue; + } else { + oneFreeBlockTable.push_back(oneFreeBlockID); + hitBlocks.insert(oneFreeBlockID); + hitBlocksIdx.insert(std::make_pair(missIdxs[index], oneFreeBlockID)); + index += 1; + oneFreeBlockIndex += 1; } - } - - void GSAPrefetchEngineC::RunPrefetchH2D(PrefetchReqInfo oneBsInfo, - std::unordered_set &hitBlocks, - std::map &hitBlocksIdx, - std::vector &missIdxs) - { - int layerID = oneBsInfo.layerID; - int reqID = oneBsInfo.reqID; - int topkIndex = oneBsInfo.topkIndex; - int bsIndex = oneBsInfo.bsIndex; - - int oneFreeBlockLen = mFreeBlockLen[layerID][bsIndex].item(); - int *freeBlockPtr = mFreeBlock[layerID][bsIndex].data_ptr(); - std::vector oneFreeBlockTable; - - uint32_t index = 0; + uint32_t loadLen = oneFreeBlockTable.size(); + missIdxs.erase(missIdxs.begin() + loadLen, missIdxs.end()); + allNeedLoadBlock[reqID][layerID] = oneFreeBlockTable; + allMissIdxs[reqID][layerID] = missIdxs; + LoadKVToHBM(oneFreeBlockTable, missIdxs, layerID, reqID); +} + +void GSAPrefetchEngineC::RunOneBsPrefetch(std::string reqID, int topkLen, int bsIndex, + int topkIndex) +{ +#pragma omp parallel for num_threads(16) proc_bind(master) + for (int i = 0; i < mLayerNum; i++) { + mLoadSuccessBlocks[i][bsIndex].fill_(0); + int* freeBlockPtr = mFreeBlock[i][bsIndex].data_ptr(); + std::unordered_set hitBlocks; + std::map hitBlocksIdx; + std::vector missIdxs; + PrefetchReqInfo oneBsInfo; + oneBsInfo.topkLen = topkLen; + oneBsInfo.reqID = reqID; + oneBsInfo.topkIndex = topkIndex; + oneBsInfo.bsIndex = bsIndex; + oneBsInfo.layerID = i; + GetHitAndMissBlock(oneBsInfo, hitBlocks, hitBlocksIdx, missIdxs); + if (missIdxs.size() != 0 && hitBlocksIdx.size() < (topkLen - mExtraTopkLen)) { + RunPrefetchH2D(oneBsInfo, hitBlocks, hitBlocksIdx, missIdxs); + } + int successIndex = 0; + for (auto it = hitBlocksIdx.begin(); 
it != hitBlocksIdx.end(); it++) { + mLoadSuccessBlocks[i][bsIndex][successIndex] = it->second; + successIndex += 1; + } int oneFreeBlockIndex = 0; - while(oneFreeBlockIndex < oneFreeBlockLen && index < missIdxs.size()) { - int oneFreeBlockID = freeBlockPtr[oneFreeBlockIndex]; - if (hitBlocks.find(oneFreeBlockID) != hitBlocks.end()) { - oneFreeBlockIndex += 1; + for (auto it = mDocsTables[reqID][i].begin(); it != mDocsTables[reqID][i].end(); it++) { + if (it->first >= mPromptLen[reqID]) { break; } + if (hitBlocksIdx.find(it->first) != hitBlocksIdx.end()) { continue; } else { - oneFreeBlockTable.push_back(oneFreeBlockID); - hitBlocks.insert(oneFreeBlockID); - hitBlocksIdx.insert(std::make_pair(missIdxs[index], oneFreeBlockID)); - index += 1; + freeBlockPtr[oneFreeBlockIndex] = it->second; oneFreeBlockIndex += 1; } } - allNeedLoadBlock[topkIndex][layerID] = oneFreeBlockTable; - allMissIdxs[topkIndex][layerID] = missIdxs; - LoadKVToHBM(oneFreeBlockTable, missIdxs, layerID, reqID); + mFreeBlockLen[i][bsIndex] = oneFreeBlockIndex; + mSuccessTableLen[i][bsIndex] = (int)(hitBlocks.size()); } +} - void GSAPrefetchEngineC::RunOneBsPrefetch(int reqID, - int topkLen, int bsIndex, int topkIndex) - { -#pragma omp parallel for num_threads(16) proc_bind(master) - for (int i = 0; i < mLayerNum; i++) { - mLoadSuccessBlocks[i][bsIndex].fill_(0); - int *freeBlockPtr = mFreeBlock[i][bsIndex].data_ptr(); - std::unordered_set hitBlocks; - std::map hitBlocksIdx; - std::vector missIdxs; - PrefetchReqInfo oneBsInfo; - oneBsInfo.topkLen = topkLen; - oneBsInfo.reqID = reqID; - oneBsInfo.topkIndex = topkIndex; - oneBsInfo.bsIndex = bsIndex; - oneBsInfo.layerID = i; - GetHitAndMissBlock(oneBsInfo, hitBlocks,hitBlocksIdx, missIdxs); - if (missIdxs.size() != 0) { - RunPrefetchH2D(oneBsInfo, hitBlocks,hitBlocksIdx, missIdxs); +void GSAPrefetchEngineC::LoadKVToHBM(std::vector loadNPUBlockIDs, std::vector missIdxs, + int layerID, std::string reqID) +{ + for (size_t i = 0; i < loadNPUBlockIDs.size(); i++) { + if (!mIsPythonLoad) { + if (mDelSeqIds.find(reqID) != mDelSeqIds.end()) { + mLogger.log(LogLevel::INFO, + "Decode step: %u, Rank: %d, reqID: %s, layer: %d, stop prefetch\n", + mDecodeStep, mRank, reqID.c_str(), layerID); + return; } - int successIndex = 0; - for (auto it = hitBlocksIdx.begin(); it != hitBlocksIdx.end(); it++) { - mLoadSuccessBlocks[i][bsIndex][successIndex] = it->second; - successIndex += 1; + while (mStopPrefetch) { std::this_thread::sleep_for(std::chrono::microseconds(2)); } + UC::Task task{UC::Task::Type::LOAD, UC::Task::Location::DEVICE, "NFS::S2D"}; + std::string blockId = mAllBlcoksHash[reqID][missIdxs[i]]; + size_t kOffset = GetOffsetNew(layerID, false); + size_t vOffset = GetOffsetNew(layerID, true); + if (!mUseMla) { + task.Append(blockId, kOffset, + reinterpret_cast( + mKvCaches[layerID][0][loadNPUBlockIDs[i]].data_ptr()), + mKVSzieBytes); + task.Append(blockId, vOffset, + reinterpret_cast( + mKvCaches[layerID][1][loadNPUBlockIDs[i]].data_ptr()), + mKVSzieBytes); + } else { + task.Append( + blockId, kOffset, + reinterpret_cast(mKvCaches[layerID][loadNPUBlockIDs[i]].data_ptr()), + mKVSzieBytes); } - int oneFreeBlockIndex = 0; - for (auto it = mBlocksMap[reqID][i].begin(); it != mBlocksMap[reqID][i].end(); it++) { - if (hitBlocks.find(it->first) != hitBlocks.end()) { - continue; - } else { - freeBlockPtr[oneFreeBlockIndex] = it->first; - oneFreeBlockIndex += 1; - } + size_t taskID = mStore->Submit(std::move(task)); + auto ret = mStore->Wait(taskID); + if (ret != 0) { + 
mLogger.log(LogLevel::ERROR, + "Decode step: %u, Rank: %d, reqID: %s, layer: %d, blockID: %lu, miss " + "idx: %u, load blockid: %u load k error\n", + mDecodeStep, mRank, reqID.c_str(), layerID, blockId, missIdxs[i], + loadNPUBlockIDs[i]); + return; } - mFreeBlockLen[i][bsIndex] = oneFreeBlockIndex; - mSuccessTableLen[i][bsIndex] = (int)(hitBlocks.size()); } - } - void GSAPrefetchEngineC::LoadKVToHBM(std::vector loadNPUBlockIDs, - std::vector missIdxs, int layerID, int reqID) - { - for (size_t i = 0; i < loadNPUBlockIDs.size(); i++) { - int oriIdx = mBlocksMap[reqID][layerID][loadNPUBlockIDs[i]]; - mBlocksMap[reqID][layerID][loadNPUBlockIDs[i]] = missIdxs[i]; - mDocsTables[reqID][layerID].erase(oriIdx); - mDocsTables[reqID][layerID][missIdxs[i]] = loadNPUBlockIDs[i]; - } + int oriIdx = mBlocksMap[reqID][layerID][loadNPUBlockIDs[i]]; + mBlocksMap[reqID][layerID][loadNPUBlockIDs[i]] = missIdxs[i]; + mDocsTables[reqID][layerID].erase(oriIdx); + mDocsTables[reqID][layerID][missIdxs[i]] = loadNPUBlockIDs[i]; } +} - void GSAPrefetchEngineC::RunAsyncPrefetchBs(std::vector &reqIDsInput, - std::vector &topkLensInput, - std::vector &bsIndexInput, int rank) - { - if (mRank == -1) { - mRank = rank; - } - if(mRank != 0) { - mLogger.SetLevel(LogLevel::WARNING); - mIsLog = false; +void GSAPrefetchEngineC::RunAsyncPrefetchBs(std::vector& reqIDsInput, + std::vector& topkLensInput, + std::vector& bsIndexInput, + std::vector& kvCaches, void* storePtr) +{ + if (mKVSzieBytes == 0) { + mTensorElemSize = kvCaches[0].element_size(); + if (mUseMla) { + mKVSzieBytes = kvCaches[0].element_size() * kvCaches[0][0].numel(); + } else { + mKVSzieBytes = kvCaches[0].element_size() * kvCaches[0][0][0].numel(); } - mLogger.log(LogLevel::INFO, - "Decode step: %u, |KVCache Prefetch| start async pretch batch size: %lu\n", - mDecodeStep, reqIDsInput.size()); - runBsLen = reqIDsInput.size(); - if (runBsLen > mMaxBs) { + if (storePtr == nullptr) { mLogger.log(LogLevel::ERROR, - "Decode step: %u, |KVCache Prefetch| runBsLen %u, maxBs: %d\n", - mDecodeStep, runBsLen, mMaxBs); + "Decode step: %u, |KVCache Prefetch| storePtr is nullptr error\n", + mDecodeStep); std::abort(); } - memcpy(mReqIdList, reqIDsInput.data(), sizeof(int) * runBsLen); - memcpy(mTopkLenList, topkLensInput.data(), sizeof(int) * runBsLen); - memcpy(mBsIndexList, bsIndexInput.data(), sizeof(int) * runBsLen); - mIsPrefetchDone = false; - mThreadPool->enqueue(MutliBSThreadFun, this); + mStore = static_cast*>(storePtr); + mLogger.log(LogLevel::INFO, + "Decode step: %u, |KVCache Prefetch| start mKVSzieBytes: %u, mTensorElemSize " + "%u, store %p\n", + mDecodeStep, mKVSzieBytes, mTensorElemSize, mStore); } - - void GSAPrefetchEngineC::SetBlockTableInfo(torch::Tensor &blockTables, torch::Tensor &blockLengths, - torch::Tensor &inputTopkBuf, int step) - { - mLoadSuccessBlocks = blockTables; - mSuccessTableLen = blockLengths; - mUseTopkIdxs = inputTopkBuf.clone(); - mDecodeStep = step; + mKvCaches = kvCaches; + mLogger.log(LogLevel::INFO, + "Decode step: %u, |KVCache Prefetch| start async pretch batch size: %lu\n", + mDecodeStep, reqIDsInput.size()); + runBsLen = reqIDsInput.size(); + if (runBsLen > mMaxBs) { + mLogger.log(LogLevel::ERROR, "Decode step: %u, |KVCache Prefetch| runBsLen %u, maxBs: %d\n", + mDecodeStep, runBsLen, mMaxBs); + std::abort(); } + mReqIdList.clear(); + mReqIdList.assign(reqIDsInput.begin(), reqIDsInput.end()); + memcpy(mTopkLenList, topkLensInput.data(), sizeof(int) * runBsLen); + memcpy(mBsIndexList, bsIndexInput.data(), sizeof(int) * runBsLen); + 
mMutex.lock(); + mIsPrefetchDone = false; + mMutex.unlock(); + if (mIsPythonLoad) { + MutliBSThreadFun(this); + } else { + mThreadPool->Enqueue(MutliBSThreadFun, this); + } +} +void GSAPrefetchEngineC::SetBlockTableInfo(torch::Tensor& blockTables, torch::Tensor& blockLengths, + torch::Tensor& inputTopkBuf, int step) +{ + mLoadSuccessBlocks = blockTables; + mSuccessTableLen = blockLengths; + mUseTopkIdxs = inputTopkBuf.clone(); + mDecodeStep = step; +} - int GSAPrefetchEngineC::CallPrefetchProcessFun() - { - auto start = std::chrono::high_resolution_clock::now(); - allNeedLoadBlock.clear(); - allNeedLoadBlock.resize(runBsLen, std::vector>(mLayerNum)); - allMissIdxs.clear(); - allMissIdxs.resize(runBsLen, std::vector>(mLayerNum)); - for (size_t i = 0; i < runBsLen; i++) { - if (mDocsTables.find(mReqIdList[i]) == mDocsTables.end() || mTopkLenList[i] <= 0) { - mLogger.log(LogLevel::ERROR, - "Decode step: %u, |KVCache Prefetch| topk len is zero: %d\n", - mDecodeStep, mTopkLenList[i]); - continue; - } - RunOneBsPrefetch(mReqIdList[i], mTopkLenList[i], mBsIndexList[i], i); +int GSAPrefetchEngineC::CallPrefetchProcessFun() +{ + auto start = std::chrono::high_resolution_clock::now(); + allNeedLoadBlock.clear(); + allMissIdxs.clear(); + for (size_t i = 0; i < runBsLen; i++) { + if (mDocsTables.find(mReqIdList[i]) == mDocsTables.end() || mTopkLenList[i] <= 0) { + mLogger.log(LogLevel::ERROR, + "Decode step: %u, |KVCache Prefetch| topk len is zero: %d\n", mDecodeStep, + mTopkLenList[i]); + continue; } - auto end = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(end - start); - mLogger.log(LogLevel::INFO, - "Decode step: %u, |KVCache Prefetch| Finish async pretch cost: %lu\n", - mDecodeStep, duration.count()); - return 0; + allMissIdxs.insert({mReqIdList[i], std::vector>(mLayerNum)}); + allNeedLoadBlock.insert({mReqIdList[i], std::vector>(mLayerNum)}); + RunOneBsPrefetch(mReqIdList[i], mTopkLenList[i], mBsIndexList[i], i); } + auto end = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(end - start); + mLogger.log(LogLevel::INFO, + "Decode step: %u, |KVCache Prefetch| Finish async pretch cost: %lu\n", mDecodeStep, + duration.count()); + return 0; +} - bool GSAPrefetchEngineC::GetPrefetchStatus() - { - return mIsPrefetchDone; - } +bool GSAPrefetchEngineC::GetPrefetchStatus() { return mIsPrefetchDone; } - void GSAPrefetchEngineC::SetPrefetchStatus(bool flag) - { - mIsPrefetchDone = flag; - } +void GSAPrefetchEngineC::SetPrefetchStatus(bool flag) +{ + mMutex.lock(); + mIsPrefetchDone = flag; + mMutex.unlock(); +} - std::vector>> GSAPrefetchEngineC::ObtainLoadBlocks() - { - return allNeedLoadBlock; - } +void GSAPrefetchEngineC::SetModelRunningStatus(bool flag) { mStopPrefetch = flag; } - std::vector>> GSAPrefetchEngineC::ObtainMissIdxs() - { - return allMissIdxs; - } +std::map>> GSAPrefetchEngineC::ObtainLoadBlocks() +{ + return allNeedLoadBlock; +} - std::map>> GSAPrefetchEngineC::ObtainBlocksMap() - { - return mBlocksMap; - } -} // namespace uc +std::map>> GSAPrefetchEngineC::ObtainMissIdxs() +{ + return allMissIdxs; +} + +std::map>> GSAPrefetchEngineC::ObtainBlocksMap() +{ + return mBlocksMap; +} + +std::map>> GSAPrefetchEngineC::ObtainDocsMap() +{ + return mDocsTables; +} +} // namespace ucmprefetch diff --git a/ucm/sparse/gsa/prefetch/src/pybinds.cpp b/ucm/sparse/gsa/prefetch/src/pybinds.cpp index b6d0e6260..25a1f5d5f 100644 --- a/ucm/sparse/gsa/prefetch/src/pybinds.cpp +++ b/ucm/sparse/gsa/prefetch/src/pybinds.cpp @@ -1,26 
+1,29 @@ #pragma GCC diagnostic push -#include -#include #include +#include #include +#include #pragma GCC diagnostic pop #include "kvcache_pre.h" -namespace ucmprefetch{ - PYBIND11_MODULE(gsa_prefetch, m) - { - pybind11::class_(m, "GSAPrefetchEngineC") - .def(pybind11::init()) - .def("set_blocks_map", &ucmprefetch::GSAPrefetchEngineC::SetBlocksMap) - .def("add_blocks_map", &ucmprefetch::GSAPrefetchEngineC::AddBlocksMap) - .def("del_blocks_map", &ucmprefetch::GSAPrefetchEngineC::DelBlocksMap) - .def("run_async_prefetch_bs", &ucmprefetch::GSAPrefetchEngineC::RunAsyncPrefetchBs) - .def("set_blocks_table_info", &ucmprefetch::GSAPrefetchEngineC::SetBlockTableInfo) - .def("get_prefetch_status", &ucmprefetch::GSAPrefetchEngineC::GetPrefetchStatus) - .def("set_prefetch_status", &ucmprefetch::GSAPrefetchEngineC::SetPrefetchStatus) - .def("obtain_load_blocks", &ucmprefetch::GSAPrefetchEngineC::ObtainLoadBlocks) - .def("obtain_miss_idxs", &ucmprefetch::GSAPrefetchEngineC::ObtainMissIdxs) - .def("obtain_blocks_map", &ucmprefetch::GSAPrefetchEngineC::ObtainBlocksMap); - } +namespace ucmprefetch { +PYBIND11_MODULE(gsa_prefetch, m) +{ + pybind11::class_(m, "GSAPrefetchEngineC") + .def(pybind11::init&, bool, bool, int, int, int, bool>()) + .def("set_blocks_map", &ucmprefetch::GSAPrefetchEngineC::SetBlocksMap) + .def("set_blocks_map_multilayer", &ucmprefetch::GSAPrefetchEngineC::SetBlocksMapMultiLayer) + .def("add_blocks_map", &ucmprefetch::GSAPrefetchEngineC::AddBlocksMap) + .def("del_blocks_map", &ucmprefetch::GSAPrefetchEngineC::DelBlocksMap) + .def("run_async_prefetch_bs", &ucmprefetch::GSAPrefetchEngineC::RunAsyncPrefetchBs) + .def("set_blocks_table_info", &ucmprefetch::GSAPrefetchEngineC::SetBlockTableInfo) + .def("get_prefetch_status", &ucmprefetch::GSAPrefetchEngineC::GetPrefetchStatus) + .def("set_prefetch_status", &ucmprefetch::GSAPrefetchEngineC::SetPrefetchStatus) + .def("set_modelrunning_status", &ucmprefetch::GSAPrefetchEngineC::SetModelRunningStatus) + .def("obtain_load_blocks", &ucmprefetch::GSAPrefetchEngineC::ObtainLoadBlocks) + .def("obtain_miss_idxs", &ucmprefetch::GSAPrefetchEngineC::ObtainMissIdxs) + .def("obtain_docs_map", &ucmprefetch::GSAPrefetchEngineC::ObtainDocsMap) + .def("obtain_blocks_map", &ucmprefetch::GSAPrefetchEngineC::ObtainBlocksMap); } +} // namespace ucmprefetch diff --git a/ucm/sparse/kvcomp/CMakeLists.txt b/ucm/sparse/kvcomp/CMakeLists.txt index e69de29bb..fb32a3ec8 100644 --- a/ucm/sparse/kvcomp/CMakeLists.txt +++ b/ucm/sparse/kvcomp/CMakeLists.txt @@ -0,0 +1,50 @@ +if(BUILD_NUMA) + message(STATUS "Building numactl library...") + + set(NUMA_INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/numa_install) + FetchContent_Declare( + numactl + URL https://github.com/numactl/numactl/releases/download/v2.0.16/numactl-2.0.16.tar.gz + TLS_VERIFY OFF + ) + FetchContent_MakeAvailable(numactl) + if(NOT EXISTS "${NUMA_INSTALL_DIR}/lib/libnuma.so") + message(STATUS "Configuring numactl...") + execute_process( + COMMAND ./configure --prefix=${NUMA_INSTALL_DIR} + WORKING_DIRECTORY ${numactl_SOURCE_DIR} + RESULT_VARIABLE numa_configure_result + OUTPUT_VARIABLE numa_configure_output + ERROR_VARIABLE numa_configure_error + ) + if(NOT numa_configure_result EQUAL 0) + message(FATAL_ERROR "Failed to configure numactl. 
\n" + "Result: ${numa_configure_result}\n" + "STDOUT: ${numa_configure_output}\n" + "STDERR: ${numa_configure_error}\n") + endif() + + message(STATUS "Building and installing numactl...") + execute_process( + COMMAND make install -j8 + WORKING_DIRECTORY ${numactl_SOURCE_DIR} + RESULT_VARIABLE numa_install_result + OUTPUT_VARIABLE numa_install_output + ERROR_VARIABLE numa_install_error + ) + if(NOT numa_install_result EQUAL 0) + message(FATAL_ERROR "Failed to build and install numactl. \n" + "Result: ${numa_install_result}\n" + "STDOUT: ${numa_install_output}\n" + "STDERR: ${numa_install_error}\n") + endif() + else() + message(STATUS "Found already built libnuma. Skipping build.") + endif() + + add_definitions(-DNUMA_ENABLED) +else() + message(STATUS "Skipping numactl build...") +endif() + +add_subdirectory(hash_retrieval) diff --git a/ucm/sparse/kvcomp/README.md b/ucm/sparse/kvcomp/README.md index b010e7d91..76283551c 100644 --- a/ucm/sparse/kvcomp/README.md +++ b/ucm/sparse/kvcomp/README.md @@ -97,6 +97,7 @@ This design ensures both **efficiency** and **accuracy** by preserving essential KVComp is part of the UCM Sparse Attention module. For installation instructions, please refer to the [UCM's top-level README](../../../../README.md). Once UCM is installed, KVComp is naturally supported by running the following example python scripts. ```bash +export ENABLE_SPARSE=TRUE python ucm/sandbox/sparse/kvcomp/offline_inference_kvcomp.py ``` diff --git a/ucm/sparse/kvcomp/configs/kvcomp_deepseek_v2_lite_config.json b/ucm/sparse/kvcomp/configs/kvcomp_deepseek_v2_lite_config.json new file mode 100644 index 000000000..0ada4477d --- /dev/null +++ b/ucm/sparse/kvcomp/configs/kvcomp_deepseek_v2_lite_config.json @@ -0,0 +1,81 @@ +{ + "model_name": "DeepSeek/DeepSeek-V2-Lite-Chat", + "is_mla": true, + "hash_weight_type": "random", + "num_hidden_layers": 27, + "seq_len_threshhold": 2048, + "chunk_size": 128, + "chunk_repre_method": "max", + "head_dim": 576, + "hash_bits": 128, + "top_k_ratio_per_layer": [ + 1, + 1, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 1, + 1, + 1 + ], + "top_k_index_reuse": [ + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1 + ], + "must_select_blocks": [ + 0, + -2, + -1 + ], + "hash_weight": null, + "kv_lora_rank": 512, + "qk_rope_head_dim": 64, + "hash_bits_kv_lora": 512, + "hash_bits_qk_rope": 64, + "hash_weight_kv_lora": null, + "hash_weight_qk_rope": null +} \ No newline at end of file diff --git a/ucm/sparse/kvcomp/hash_encoder.py b/ucm/sparse/kvcomp/hash_encoder.py index 798704544..7546aa71e 100644 --- a/ucm/sparse/kvcomp/hash_encoder.py +++ b/ucm/sparse/kvcomp/hash_encoder.py @@ -31,6 +31,124 @@ logger = init_logger(__name__) +if hasattr(torch, "cuda") and torch.cuda.is_available(): + from vllm.triton_utils import tl, triton + + @triton.jit + def triton_hash_code_kernel( + x_ptr, + code_ptr, + pack_w_ptr, + hash_out_ptr, + M, + K, + N, + stride_xm, + stride_xk, + stride_codek, + stride_coden, + stride_pack_w, + stride_om, + stride_on, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # sample dimension + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # hash_rbits dimension + offs_k = tl.arange(0, BLOCK_K) # input_dim 
dimension + + # Matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x = tl.load( + x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk, + mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), + other=0.0, + ) + code = tl.load( + code_ptr + + offs_k[:, None] * stride_codek + + offs_n[None, :] * stride_coden, + mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), + other=0.0, + ) + acc += tl.dot(x, code) + offs_k += BLOCK_K + + # Binarize and pack + bits = (acc > 0).to(tl.uint8) # Binarize + bits = tl.reshape(bits, (BLOCK_M, BLOCK_N // 8, 8)) # Reshape for packing + + # Load the packing weights (ensure it has the correct shape) + pack_w = tl.load(pack_w_ptr + tl.arange(0, 8) * stride_pack_w) + packed = tl.sum(bits * pack_w[None, None, :], axis=-1).to(tl.uint8) + + # Store results + offs_n = pid_n * (BLOCK_N // 8) + tl.arange(0, BLOCK_N // 8) + hash_out_ptrs = ( + hash_out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + ) + tl.store( + hash_out_ptrs, + packed, + mask=(offs_m[:, None] < M) & (offs_n[None, :] < (N // 8)), + ) + + def triton_hash_code(x, code, pack_weight): + input_dim = x.shape[-1] + samples = x.shape[0] + hash_bits = code.shape[-1] + assert (pack_weight.shape[0] == 8) and (hash_bits % 8 == 0) + hash_out = torch.empty( + (samples, hash_bits // 8), dtype=pack_weight.dtype, device=x.device + ) + + grid = lambda opts: ( + triton.cdiv(samples, opts["BLOCK_M"]), + triton.cdiv(input_dim, opts["BLOCK_N"]), + ) + + triton_hash_code_kernel[grid]( + x, + code, + pack_weight, + hash_out, + samples, + input_dim, + hash_bits, + x.stride(0), + x.stride(1), + code.stride(0), + code.stride(1), + pack_weight.stride(0), + hash_out.stride(0), + hash_out.stride(1), + BLOCK_M=32, + BLOCK_K=64, + BLOCK_N=16, + ) + + return hash_out.view(-1) # [samples * hash_numbers] + + +@torch.compile() +def torch_hash_code(x, code, pack_weight): + # [N, hash_bits] + x = x @ code + m = x.shape[:-1] + # [N, hash_bits] -- > [N, hash_bits // 8, 8] + x = (x > 0).to(torch.uint8).view(*m, -1, 8) + # 8bit -> 1bit + # binary_codes * self.bit_masks [N, hash_numbers, 8] * [1, 1, 8] -> [N, hash_numbers, 8] + # then sum along the last dimension to get [N, hash_numbers] + x = torch.sum(x * pack_weight, dim=-1, dtype=torch.uint8) + x = x.view(-1) # [N * hash_numbers] + return x + class HashEncoder: """ @@ -62,13 +180,25 @@ def __init__( logger.warning("automatically using float16 for hash_weights now") self.dtype = torch.float16 - self.hash_weights = torch.normal( + if self.device.type == "cuda" and dtype == torch.bfloat16: + logger.warning("geqrf_cuda not implemented for BFloat16") + logger.warning("automatically using float32 for hash_weights now") + self.dtype = torch.float32 + + # Step 1: 随机高斯矩阵 + random_weights = torch.normal( mean=0, std=2, size=(self.input_dim, self.hash_bits), dtype=self.dtype, device=self.device, ) + # Step 2: QR分解 + Q, R = torch.linalg.qr(random_weights) + + # Step 3: 调整符号,保证Haar 分布 + d = torch.sign(torch.diag(R)) + self.hash_weights = Q * d if self.device.type == "cuda" or self.device.type == "cpu": self._init_bit_masks() @@ -93,8 +223,6 @@ def _init_bit_masks(self) -> None: self.bit_masks = torch.pow( 2, torch.arange(8, dtype=torch.uint8, device=self.device) ) - # shape (1, 1, 8) - self.bit_masks = self.bit_masks.unsqueeze(0).unsqueeze(0) def compute_hash(self, x: torch.Tensor) -> torch.Tensor: """ @@ -124,29 +252,24 @@ def compute_hash(self, x: torch.Tensor) -> torch.Tensor: if x_flat.dtype != 
self.dtype: x_flat = x_flat.to(self.dtype) - # [N, hash_bits] - xW = torch.matmul(x_flat, self.hash_weights) - - # [N * hash_bits] - xW_flat = xW.view(-1) - if self.device.type == "npu": + # [N, hash_bits] + xW = torch.matmul(x_flat, self.hash_weights) + # [N * hash_bits] + xW_flat = xW.view(-1) # [N*hash_numbers], where hash_numbers = hash_bits // 8 packed_codes_flat = torch_npu.npu_sign_bits_pack(xW_flat, size=1) - elif self.device.type == "cuda" or self.device.type == "cpu": - # (TODO) improve performance later on CUDA ops and CPU SIMD instructions - # [N, hash_bits] - projected = (xW > 0).to(torch.uint8) - # [N, hash_numbers, 8] - binary_codes = projected.view(-1, self.hash_numbers, 8) - - # binary_codes * self.bit_masks [N, hash_numbers, 8] * [1, 1, 8] -> [N, hash_numbers, 8] - # then sum along the last dimension to get [N, hash_numbers] - packed_codes_flat = torch.sum( - binary_codes * self.bit_masks, dim=-1, dtype=torch.uint8 - ) # [N, hash_numbers] - packed_codes_flat = packed_codes_flat.view(-1) # [N * hash_numbers] + elif self.device.type == "cuda": + packed_codes_flat = triton_hash_code( + x_flat, self.hash_weights, self.bit_masks + ) # [N * hash_numbers] + + elif self.device.type == "cpu": + packed_codes_flat = torch_hash_code( + x_flat, self.hash_weights, self.bit_masks + ) # [N * hash_numbers] + else: raise ValueError(f"Unsupported device type: {self.device.type}") @@ -201,7 +324,7 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor: ) # expand last dim to 8 # (expanded & self.bit_masks) > 0 -> [N, hash_numbers, 8] - unpacked_bits = (expanded & self.bit_masks) > 0 + unpacked_bits = (expanded & self.bit_masks.unsqueeze(0).unsqueeze(0)) > 0 # 0 -> -1, 1 -> 1 unpacked_bits = unpacked_bits * 2 - 1 @@ -220,20 +343,22 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor: if __name__ == "__main__": + torch.manual_seed(42) + + print("test HashEncoder...") + dtype = torch.float16 if hasattr(torch, "npu") and torch.npu.is_available(): device = torch.device("npu:0") elif hasattr(torch, "cuda") and torch.cuda.is_available(): device = torch.device("cuda:0") + dtype = torch.float32 else: device = torch.device("cpu") print("Using device:", device) + encoder = HashEncoder(input_dim=8, hash_bits=8, dtype=dtype, device=device) - torch.manual_seed(42) - - encoder = HashEncoder(input_dim=8, hash_bits=8, dtype=torch.float16, device=device) - - x = torch.randn(2, 8, device=device, dtype=torch.float16) + x = torch.randn(2, 8, device=device, dtype=dtype) print("x:", x) hash_codes = encoder.compute_hash(x) @@ -250,3 +375,31 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor: print( f"hash_codes[1].item()={hash_codes[1].item()}, 8-bit binary form:{hash_codes[1].item():08b}" ) + + if hasattr(torch, "cuda") and torch.cuda.is_available(): + print("test cuda triton and torch hash code functions...") + x = torch.randn((1024, 512), device="cuda:0", dtype=torch.bfloat16) + code = torch.randn((512, 512), device="cuda:0", dtype=torch.bfloat16) + pack_weight = torch.tensor( + [128, 64, 32, 16, 8, 4, 2, 1], device="cuda:0", dtype=torch.uint8 + ) + + torch_output = torch_hash_code(x, code, pack_weight) + triton_output = triton_hash_code(x, code, pack_weight) + assert torch_output.shape == triton_output.shape + print(f"x_shape: {x.shape} code_shape: {code.shape}") + print("torch_output", torch_output) + print("triton_output", triton_output) + print( + f"The maximum difference between Torch and Triton is" + f" {torch.max(torch.abs(torch_output.to(torch.int32) - 
triton_output.to(torch.int32)))}" + ) + # benchmark + print( + "torch:", + triton.testing.do_bench(lambda: torch_hash_code(x, code, pack_weight)), + ) + print( + "triton:", + triton.testing.do_bench(lambda: triton_hash_code(x, code, pack_weight)), + ) diff --git a/ucm/sparse/kvcomp/hash_retrieval/CMakeLists.txt b/ucm/sparse/kvcomp/hash_retrieval/CMakeLists.txt new file mode 100644 index 000000000..5be9a6fe4 --- /dev/null +++ b/ucm/sparse/kvcomp/hash_retrieval/CMakeLists.txt @@ -0,0 +1,17 @@ +# Add the build target +pybind11_add_module(hash_retrieval_backend cpy/hash_retrieval_backend.cpp) + +# Set the output directory for the built library +set_target_properties(hash_retrieval_backend PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + +# Set include directories so that numaf.h can be found +target_include_directories(hash_retrieval_backend PUBLIC + ${NUMA_INSTALL_DIR}/include + ${Torch_INCLUDE_DIRS} +) + +# Link the required libraries +target_link_libraries(hash_retrieval_backend PUBLIC + $<$:${NUMA_INSTALL_DIR}/lib/libnuma.so> + ${Torch_LIBRARIES} +) \ No newline at end of file diff --git a/ucm/sparse/kvcomp/hash_retrieval/__init__.py b/ucm/sparse/kvcomp/hash_retrieval/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ucm/sparse/kvcomp/hash_retrieval/cpy/hash_retrieval_backend.cpp b/ucm/sparse/kvcomp/hash_retrieval/cpy/hash_retrieval_backend.cpp new file mode 100644 index 000000000..fdbe9d88e --- /dev/null +++ b/ucm/sparse/kvcomp/hash_retrieval/cpy/hash_retrieval_backend.cpp @@ -0,0 +1,444 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE.
+ * */ +#include +#include +#include // 用于UINT16_MAX +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef NUMA_ENABLED +#include +#include +#endif + +#if defined(__ARM_NEON) || defined(__ARM_NEON__) +#include +#elif defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86) +#include // SSE/AVX +#include // POPCNT (SSE4.2) +#endif + +#define VEC_SIZE 16 + +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + +using vec16u = uint8x16_t; + +static inline vec16u vec_loadu16(const uint8_t* p) { return vld1q_u8(p); } + +static inline vec16u vec_xor(vec16u a, vec16u b) { return veorq_u8(a, b); } + +static inline uint16_t vec_sum_u8(vec16u v) +{ +#if defined(__aarch64__) || defined(_M_ARM64) + return vaddvq_u8(v); +#else + uint16x8_t s16 = vpaddlq_u8(v); + uint32x4_t s32 = vpaddlq_u16(s16); + uint64x2_t s64 = vpaddlq_u32(s32); + return (uint16_t)(vgetq_lane_u64(s64, 0) + vgetq_lane_u64(s64, 1)); +#endif +} + +static inline uint16_t vec_popcnt_xor_sum16(const uint8_t* a, const uint8_t* b) +{ + vec16u va = vec_loadu16(a); + vec16u vb = vec_loadu16(b); + vec16u vx = vec_xor(va, vb); + vec16u pc = vcntq_u8(vx); + return vec_sum_u8(pc); +} + +static inline uint16_t vec_popcnt_xor_sum16_vec(vec16u qa, const uint8_t* b) +{ + vec16u vb = vec_loadu16(b); + vec16u vx = vec_xor(qa, vb); + vec16u pc = vcntq_u8(vx); + return vec_sum_u8(pc); +} + +void print_uint8x16(uint8x16_t vec) +{ + uint8_t array[16]; + vst1q_u8(array, vec); + for (int i = 0; i < 16; ++i) { std::cout << static_cast(array[i]) << " "; } + std::cout << std::endl; +} + +#elif defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86) + +using vec16u = __m128i; + +static inline vec16u vec_loadu16(const uint8_t* p) +{ + return _mm_loadu_si128(reinterpret_cast(p)); +} + +static inline vec16u vec_xor(vec16u a, vec16u b) { return _mm_xor_si128(a, b); } + +static inline uint16_t vec_popcnt_xor_sum16(const uint8_t* a, const uint8_t* b) +{ + __m128i va = _mm_loadu_si128(reinterpret_cast(a)); + __m128i vb = _mm_loadu_si128(reinterpret_cast(b)); + __m128i vx = _mm_xor_si128(va, vb); + + uint64_t lo, hi; +#if defined(__SSE4_1__) + lo = static_cast(_mm_extract_epi64(vx, 0)); + hi = static_cast(_mm_extract_epi64(vx, 1)); +#else + alignas(16) uint64_t tmp[2]; + _mm_storeu_si128(reinterpret_cast<__m128i*>(tmp), vx); + lo = tmp[0]; + hi = tmp[1]; +#endif + return (uint16_t)(__builtin_popcountll(lo) + __builtin_popcountll(hi)); +} + +static inline uint16_t vec_popcnt_xor_sum16_vec(vec16u qa, const uint8_t* b) +{ + __m128i vb = _mm_loadu_si128(reinterpret_cast(b)); + __m128i vx = _mm_xor_si128(qa, vb); + + uint64_t lo, hi; +#if defined(__SSE4_1__) + lo = static_cast(_mm_extract_epi64(vx, 0)); + hi = static_cast(_mm_extract_epi64(vx, 1)); +#else + alignas(16) uint64_t tmp[2]; + _mm_storeu_si128(reinterpret_cast<__m128i*>(tmp), vx); + lo = tmp[0]; + hi = tmp[1]; +#endif + return (uint16_t)(__builtin_popcountll(lo) + __builtin_popcountll(hi)); +} + +#else + +static inline uint16_t vec_popcnt_xor_sum16(const uint8_t* a, const uint8_t* b) +{ + uint16_t s = 0; + for (int i = 0; i < 16; ++i) s += __builtin_popcount((unsigned)(a[i] ^ b[i])); + return s; +} + +#endif + +namespace py = pybind11; + +class HashRetrievalWorkerBackend { +public: + HashRetrievalWorkerBackend(py::array_t data, py::dict cpu_idx_tbl) + : data_array_(data), stop_workers_(false), next_req_id_(0) + { + py::buffer_info info = 
data_array_.request(); + num_blocks_ = info.shape[0]; + block_size_ = info.shape[1]; + dim_ = info.shape[2]; + vec_per_dim_ = dim_ / VEC_SIZE; // each element of data_ is uint8_t; groups of 16 bytes are processed with SIMD + tail_dim_ = dim_ % VEC_SIZE; + tail_start_ = vec_per_dim_ * VEC_SIZE; + data_ = static_cast(info.ptr); + + // Start worker threads + for (auto cpu_idx : cpu_idx_tbl) { + py::list core_ids = cpu_idx.second.cast(); + + for (size_t i = 0; i < core_ids.size(); ++i) { + int core_id = core_ids[i].cast(); + worker_threads_.emplace_back(&HashRetrievalWorkerBackend::worker_loop, this); + + // CPU core binding + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(core_id, &cpuset); // pin each worker thread to its assigned core + + pthread_t thread = worker_threads_.back().native_handle(); + + // set CPU affinity + int rc = pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset); + if (rc != 0) { + std::cerr << "Error binding thread " << i << " to CPU core " << core_id + << std::endl; + } + +#ifdef NUMA_ENABLED + int numaId = cpu_idx.first.cast(); + // set memory affinity to the matching NUMA node + unsigned long nodeMask = 1UL << numaId; + rc = set_mempolicy(MPOL_BIND, &nodeMask, sizeof(nodeMask) * 8); + if (rc != 0) { + std::cerr << "Error binding memory to NUMA node " << numaId << std::endl; + } +#endif + } + } + } + + ~HashRetrievalWorkerBackend() + { + { + std::lock_guard lock(mutex_); + stop_workers_ = true; + cond_.notify_all(); + } + for (auto& t : worker_threads_) t.join(); + } + + int submit(py::array_t query, int topk, py::array_t indexes) + { + py::buffer_info qinfo = query.request(); + py::buffer_info iinfo = indexes.request(); + if (qinfo.shape[1] != dim_) throw std::runtime_error("Query dim mismatch"); + if ((size_t)iinfo.shape[0] != (size_t)qinfo.shape[0]) + throw std::runtime_error("Query and indexes batch mismatch"); + + int req_id = next_req_id_.fetch_add(1); + + auto q = + std::vector((uint8_t*)qinfo.ptr, (uint8_t*)qinfo.ptr + qinfo.shape[0] * dim_); + + // Parse indexes to vector> + size_t n_requests = iinfo.shape[0], max_index_number = iinfo.shape[1]; + const int* idx_ptr = static_cast(iinfo.ptr); + std::vector> idxvec(n_requests); + for (size_t i = 0; i < n_requests; ++i) { + for (size_t j = 0; j < max_index_number; ++j) { + int index = idx_ptr[i * max_index_number + j]; + if (index != -1) idxvec[i].push_back(index); + } + } + + auto status = std::make_shared(); + { + std::lock_guard lock(mutex_); + requests_.emplace(Request{req_id, std::move(q), n_requests, topk, std::move(idxvec)}); + request_status_[req_id] = status; + } + cond_.notify_one(); + return req_id; + } + + bool poll(int req_id) + { + std::lock_guard lock(mutex_); + return results_.find(req_id) != results_.end(); + } + + void wait(int req_id) + { + std::shared_ptr s; + { + std::lock_guard lock(mutex_); + auto it = request_status_.find(req_id); + if (it == request_status_.end()) throw std::runtime_error("Bad req_id"); + s = it->second; + } + std::unique_lock lk2(s->m); + s->cv.wait(lk2, [&] { return s->done; }); + } + + py::dict get_result(int req_id) + { + std::lock_guard lock(mutex_); + auto it = results_.find(req_id); + if (it == results_.end()) throw std::runtime_error("Result not ready"); + + size_t batch_size = it->second.indices.size(); + size_t topk = batch_size > 0 ?
it->second.indices[0].size() : 0; + py::array_t indices({batch_size, topk}); + py::array_t scores({batch_size, topk}); + + auto indices_ptr = static_cast(indices.request().ptr); + auto scores_ptr = static_cast(scores.request().ptr); + + for (size_t i = 0; i < batch_size; ++i) { + memcpy(indices_ptr + i * topk, it->second.indices[i].data(), topk * sizeof(int)); + memcpy(scores_ptr + i * topk, it->second.scores[i].data(), topk * sizeof(int)); + } + py::dict result; + result["indices"] = indices; + result["scores"] = scores; + results_.erase(it); + return result; + } + +private: + struct Request { + int req_id; + std::vector query; // Flattened [batch, dim] + size_t batch; + int topk; + std::vector> indexes; // Per-request index subset + }; + struct Result { + std::vector> indices; + std::vector> scores; + }; + + struct RequestStatus { + std::mutex m; + std::condition_variable cv; + bool done = false; + }; + + void worker_loop() + { + while (true) { + Request req; + { + std::unique_lock lock(mutex_); + cond_.wait(lock, [&] { return stop_workers_ || !requests_.empty(); }); + if (stop_workers_ && requests_.empty()) return; + req = std::move(requests_.front()); + requests_.pop(); + } + + Result res; + res.indices.resize(req.batch); + res.scores.resize(req.batch); + + // #pragma omp parallel for schedule(dynamic) + for (size_t b = 0; b < req.batch; ++b) { + const uint8_t* q_ptr = req.query.data() + b * dim_; + const auto& allowed = req.indexes[b]; + std::vector> heap; + heap.reserve(allowed.size()); + +#if defined(__ARM_NEON) || defined(__ARM_NEON__) || defined(__x86_64__) || defined(_M_X64) || \ + defined(__i386) || defined(_M_IX86) + // 1.预加载 query 向量 + vec16u q_vecs[vec_per_dim_]; // 存储query向量 + for (size_t v = 0; v < vec_per_dim_; ++v) { + q_vecs[v] = vec_loadu16(q_ptr + v * VEC_SIZE); + } +#endif + + // 2.遍历允许的索引 + for (auto idx : allowed) { + const uint8_t* base_idx_ptr = data_ + idx * block_size_ * dim_; + + int score = UINT16_MAX; // 初始化为最大值 + + // 3.内层向量化计算 + // #pragma omp parallel for + for (size_t t_idx = 0; t_idx < block_size_; ++t_idx) { + int sum = 0; + const uint8_t* k_base = base_idx_ptr + t_idx * dim_; + + // 计算每个向量的相似度 +#if defined(__ARM_NEON) || defined(__ARM_NEON__) || defined(__x86_64__) || defined(_M_X64) || \ + defined(__i386) || defined(_M_IX86) + for (size_t v = 0; v < vec_per_dim_; ++v) { + sum += vec_popcnt_xor_sum16_vec(q_vecs[v], k_base + v * VEC_SIZE); + } +#else + for (size_t v = 0; v < vec_per_dim_; ++v) { + sum += + vec_popcnt_xor_sum16(q_ptr + v * VEC_SIZE, k_base + v * VEC_SIZE); + } +#endif + if (tail_dim_ != 0) { + for (size_t t = 0; t < tail_dim_; ++t) { + uint8_t x = q_ptr[tail_start_ + t] ^ k_base[tail_start_ + t]; + sum += __builtin_popcount((unsigned)x); + } + } + + // 如果得分为0,则跳出循环 + if (sum < score) { + score = sum; + if (score == 0) { break; } + } + } + + // 将结果加入堆中 + heap.emplace_back(score, idx); + } + + // 获取当前TopK + int curr_topk = std::min((int)heap.size(), req.topk); + + // 对堆进行部分排序,获取TopK + std::partial_sort(heap.begin(), heap.begin() + curr_topk, heap.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }); + + // 保存TopK结果 + for (int k = 0; k < curr_topk; ++k) { + res.scores[b].push_back(heap[k].first); + res.indices[b].push_back(heap[k].second); + } + } + + { + std::lock_guard lock(mutex_); + results_[req.req_id] = std::move(res); + auto s = request_status_[req.req_id]; + { + std::lock_guard lk2(s->m); + s->done = true; + } + s->cv.notify_all(); + } + } + } + + py::array_t data_array_; + const uint8_t* data_ = nullptr; + 
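For readers following the SIMD path in worker_loop above: the backend scores a candidate block by XOR-ing the packed query code against each token's packed key code, popcounting the differing bits, and keeping the block's minimum Hamming distance; the blocks with the smallest distances win. The snippet below is a rough NumPy restatement of that scoring, not the backend itself — function and variable names are illustrative, and it ignores the threading, CPU-pinning, and NUMA details handled by HashRetrievalWorkerBackend.

```python
import numpy as np

def hamming_topk(query_code, key_codes, allowed_blocks, topk):
    """Reference for worker_loop: packed uint8 hash codes, lower distance is better."""
    # key_codes: [num_blocks, block_size, dim] packed bits; query_code: [dim]
    diff_bits = np.unpackbits(np.bitwise_xor(key_codes[allowed_blocks], query_code), axis=-1)
    token_dist = diff_bits.sum(axis=-1)      # Hamming distance per token
    block_score = token_dist.min(axis=-1)    # best-matching token represents its block
    order = np.argsort(block_score)[:topk]   # smallest distances first
    return [allowed_blocks[i] for i in order], block_score[order].tolist()

# toy shapes: 32 candidate blocks, 4 tokens per block, 128-bit (16-byte) codes
keys = np.random.randint(0, 256, size=(32, 4, 16), dtype=np.uint8)
query = np.random.randint(0, 256, size=(16,), dtype=np.uint8)
blocks, scores = hamming_topk(query, keys, allowed_blocks=list(range(32)), topk=5)
```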
ssize_t dim_; + size_t num_blocks_, block_size_, vec_per_dim_, tail_dim_, tail_start_; + std::queue requests_; + std::unordered_map results_; + std::vector worker_threads_; + std::mutex mutex_; + std::condition_variable cond_; + std::unordered_map> request_status_; + bool stop_workers_; + std::atomic next_req_id_; +}; + +PYBIND11_MODULE(hash_retrieval_backend, m) +{ + py::class_(m, "HashRetrievalWorkerBackend") + .def(py::init, py::dict>()) + .def("submit", &HashRetrievalWorkerBackend::submit) + .def("poll", &HashRetrievalWorkerBackend::poll) + .def("get_result", &HashRetrievalWorkerBackend::get_result) + .def("wait", &HashRetrievalWorkerBackend::wait); +} diff --git a/ucm/sparse/kvcomp/hash_retrieval/hash_retrieval_worker.py b/ucm/sparse/kvcomp/hash_retrieval/hash_retrieval_worker.py new file mode 100644 index 000000000..5faf83dcc --- /dev/null +++ b/ucm/sparse/kvcomp/hash_retrieval/hash_retrieval_worker.py @@ -0,0 +1,119 @@ +import time +from collections import defaultdict + +import numpy as np +import torch + +from ucm.sparse.kvcomp.hash_encoder import HashEncoder +from ucm.sparse.kvcomp.hash_retrieval import hash_retrieval_backend +from ucm.sparse.kvstar.utils import get_bind_cpus_for_rank + + +class HashRetrievalWorker: + # handle torch -> numpy && float16/bfloat16 -> float32. + def __init__(self, cpp_worker): + self.cpp_worker = cpp_worker + + @classmethod + def handle_input(cls, input): + if input.dtype != torch.uint8: + input = input.to(torch.uint8) + input = input.to("cpu", non_blocking=True) + return input + + def submit(self, query, topk, indexes): + q = self.handle_input(query) + req_id = self.cpp_worker.submit(q, topk, indexes) + return req_id + + def poll(self, req_id): + return self.cpp_worker.poll(req_id) # Returns True if ready + + def get_result(self, req_id): + return self.cpp_worker.get_result(req_id) + + def wait(self, req_id): + return self.cpp_worker.wait(req_id) + + +if __name__ == "__main__": + ################# data + batch_size = 2 + block_size = 2 + head_dim = 128 + head_num = 1 + dim = head_dim * head_num + kv_cache_blocks = 2560 + data = torch.rand(kv_cache_blocks, block_size, dim).to(torch.float32) + print("data created", data.shape) + + topk = 10 + search_blocks_range = 100 + tpot = 30 / 1000 + + indexes = np.arange(batch_size * search_blocks_range).reshape( + batch_size, search_blocks_range + ) + + query = torch.rand(batch_size, dim).to(torch.float32) + + hash_encoder = HashEncoder( + input_dim=dim, + hash_bits=dim, + dtype=torch.float32, + device=torch.device("cpu"), + ) + + hash_query = hash_encoder.compute_hash(query) + hash_key_cache = hash_encoder.compute_hash(data) + + ratio = 0.75 + total_tp_size = 4 + local_tp_rank = 0 + bind_info_list, alloc_numa_ids = get_bind_cpus_for_rank( + total_tp_size, local_tp_rank, ratio=ratio + ) + + bind_info_dict = defaultdict(list) + for item in bind_info_list: + bind_info_dict[item[1]].append(item[0]) + bind_info_dict = dict(bind_info_dict) + + backend = hash_retrieval_backend.HashRetrievalWorkerBackend( + hash_key_cache, bind_info_dict + ) + worker = HashRetrievalWorker(backend) + + #################### cpp async version + req_id = worker.submit(hash_query, topk=topk, indexes=indexes) + + #################### LLM decode begin + time.sleep(tpot * 3) + #################### LLM decode done + + # Poll and get result (in a real program, you'd likely use asyncio or threading) + begin = time.time() + worker.wait(req_id) + result = worker.get_result(req_id) + print("cpp spent:", time.time() - begin) + cpp_indices = 
np.sort(result["indices"], 1) + print(f"cpp indices={cpp_indices}") + + ################### numpy version + unpacked_hash_query = hash_encoder._unpack_hash(hash_query) + unpacked_hash_key_cache = hash_encoder._unpack_hash(hash_key_cache) + begin = time.time() + data_indexed = unpacked_hash_key_cache[indexes.flatten()].reshape( + indexes.shape[0], indexes.shape[1], block_size, dim + ) + scores = torch.einsum("td, tnjd->tnj", unpacked_hash_query, data_indexed) + + block_scores_ret = torch.max(scores, dim=-1) + blocks_scores = block_scores_ret.values + + topk_ret = torch.topk(blocks_scores, topk, dim=-1) + topk_index = topk_ret.indices + topk_index = topk_index.sort(dim=-1).values + topk_index = indexes[np.arange(indexes.shape[0])[:, None], topk_index] + print("numpy spent: ", time.time() - begin) + print(f"numpy indices={topk_index}") diff --git a/ucm/sparse/kvcomp/kvcomp.py b/ucm/sparse/kvcomp/kvcomp.py index c1884b2ec..27fbe67d9 100644 --- a/ucm/sparse/kvcomp/kvcomp.py +++ b/ucm/sparse/kvcomp/kvcomp.py @@ -1,587 +1,299 @@ -""" -The MIT License - -Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. 
-""" - import math -import time +from collections import defaultdict from dataclasses import dataclass -from functools import wraps -from typing import Dict, List, Union +from typing import Any, Dict, List, Optional, Union +import numpy as np import torch from vllm.config import VllmConfig +from vllm.distributed.kv_transfer import get_kv_transfer_group from vllm.forward_context import ForwardContext -from vllm.sequence import SequenceStage -from vllm.v1.request import Request +from vllm.v1.request import Request, RequestStatus +from ucm.integration.vllm.ucm_connector import RequestHasher from ucm.logger import init_logger -from ucm.sandbox.sparse.kvcomp.hash_encoder import HashEncoder -from ucm.sandbox.sparse.kvcomp.kvcomp_config import KvCompConfig -from ucm.sparse.state import get_ucm_sparse - -logger = init_logger(__name__) - from ucm.sparse.base import ( INVALID_SLOT, UcmSparseBase, - UcmSparseMetadata, UcmSparseRole, ) +from ucm.sparse.esa.esa import ( + ESA, + ESASparseMetaData, + ReprePool, + ReqStatePerLayer, + get_sparse_range, +) +from ucm.sparse.kvcomp.hash_encoder import HashEncoder +from ucm.sparse.kvcomp.hash_retrieval import hash_retrieval_backend +from ucm.sparse.kvcomp.hash_retrieval.hash_retrieval_worker import HashRetrievalWorker +from ucm.sparse.kvcomp.kvcomp_config import KvCompConfig +from ucm.sparse.kvstar.utils import get_bind_cpus_for_rank from ucm.sparse.state import get_ucm_sparse -from ucm.store.factory import UcmConnectorFactory from ucm.store.ucmstore import Task, UcmKVStoreBase +from ucm.utils import Config - -def stat(func): - @wraps(func) - def wrapper(*args, **kwargs): - wrapper.call_count += 1 - start = time.perf_counter_ns() - result = func(*args, **kwargs) - end = time.perf_counter_ns() - cost = end - start - wrapper.time_costs.append(cost) - return result - - wrapper.call_count = 0 - wrapper.time_costs = [] - return wrapper - +logger = init_logger(__name__) ReqType = Union[str, int] -HashType = Union[str, int] - -# TODO: add ESA specific config in kv_transfer_config -> extra_config -INIT_WINDOW_SZ = 1 -LOCAL_WINDOW_SZ = 2 -SPARSE_RATIO = 0.3 -RETRIEVAL_STRIDE = 4 - - -@dataclass -class ReqMeta: - request_id: ReqType - index_in_batch: int - num_prompt_tokens: int - num_output_tokens: int - num_scheduled_tokens: int - num_computed_tokens: int - num_sparsed_tokens: int - vllm_block_ids: list[int] - - @property - def step(self) -> int: - return self.num_output_tokens - - @property - def stage(self) -> SequenceStage: - return ( - SequenceStage.DECODE - if self.num_output_tokens > 0 - else SequenceStage.PREFILL - ) - @property - def is_last_chunk(self) -> bool: - return ( - self.num_computed_tokens + self.num_scheduled_tokens - >= self.num_prompt_tokens - ) +data = None -@dataclass -class KvCompSparseMetaData(UcmSparseMetadata): - requests: list[ReqMeta] - finished_req_ids: List[ReqType] - - def __init__(self): - self.requests = [] - self.finished_req_ids = [] - - def add_request( - self, - request_id: ReqType, - index_in_batch: int, - num_prompt_tokens: int, - num_output_tokens: int, - num_scheduled_tokens: int, - num_computed_tokens: int, - num_sparsed_tokens: int, - vllm_block_ids: list[int], - ) -> None: - meta = ReqMeta( - request_id=request_id, - index_in_batch=index_in_batch, - num_prompt_tokens=num_prompt_tokens, - num_output_tokens=num_output_tokens, - num_scheduled_tokens=num_scheduled_tokens, - num_computed_tokens=num_computed_tokens, - num_sparsed_tokens=num_sparsed_tokens, - vllm_block_ids=vllm_block_ids, - ) - self.requests.append(meta) - - 
-def get_offset(block_shape, rank, tp_size, precision, layer_id, is_v, is_mla) -> int: - block_size, num_key_heads_per_tp, head_size = block_shape - k_min_data_block_size = block_size * num_key_heads_per_tp * head_size * precision - v_min_data_block_size = k_min_data_block_size if not is_mla else 0 - layer_size = (k_min_data_block_size + v_min_data_block_size) * tp_size - if is_mla: - k_offset = layer_size * layer_id - else: - k_offset = layer_size * layer_id + layer_size // tp_size * rank - v_offset = k_offset + k_min_data_block_size - return v_offset if is_v else k_offset - - -class ReqStatePerLayer: +class ReqStatePerLayerKvComp(ReqStatePerLayer): # handle single request per layer def __init__( self, - req_meta: ReqMeta, layer_name: str, rank: int, tp_size: int, store_instance: UcmKVStoreBase, + vllm_config: VllmConfig, + retrieval_worker: Optional[HashRetrievalWorker] = None, + repre_pool: Optional[ReprePool] = None, + esa_cfg: Optional[Dict[str, Any]] = None, ): - self.layer_name = layer_name - self.layer_id = int(layer_name.split(".")[2]) - self.block_repre: torch.Tensor = ( - None ## shape: blks, num_key_heads_per_tp, head_size + super().__init__( + layer_name, + rank, + tp_size, + store_instance, + vllm_config, + retrieval_worker, + repre_pool, ) - self.init_window: tuple[torch.Tensor, torch.Tensor] = None - self.local_window: tuple[torch.Tensor, torch.Tensor] = None - self.store_instance = store_instance - self.req_meta = req_meta - self.block_size = None - self.k_cache = None - self.v_cache = None - self.rank = rank - self.tp_size = tp_size - self.tasks: Dict[str, Task] = {} - self.init_window_sz = INIT_WINDOW_SZ - self.local_window_sz = LOCAL_WINDOW_SZ - - @classmethod - def req_state_hash(cls, req_id, layer_name): - return hash((req_id, layer_name)) - - @classmethod - def block_hash(cls, request_id, block_id): - return f"req_{request_id}_blk_{block_id}" - - @classmethod - def task_hash(cls, block_ids, store_type, tensor_type): - return hash((tuple(block_ids), store_type, tensor_type)) - - def update_meta(self, req_meta: ReqMeta, forward_context: ForwardContext): - self.req_meta = req_meta - - def retrieval(self, query: torch.Tensor, top_k: int): - if top_k >= self.block_repre.shape[0]: - n_blocks = self.block_repre.shape[0] - block_ids = list( - range(self.init_window_sz, n_blocks - self.local_window_sz + 1) - ) - block_hashes = [ - f"{self.block_hash(self.req_meta.request_id, id_)}" for id_ in block_ids - ] - return block_hashes + + self.esa_cfg = esa_cfg + # `retrieval_worker` is of type HashRetrievalWorker + self.retrieval_worker = retrieval_worker + + def extract_block_repre(self, vllm_block_ids): + ucm_sparse = get_ucm_sparse() + hash_encoder = ucm_sparse.hash_encoder + hashk_cache = hash_encoder.compute_hash(self.k_cache[vllm_block_ids]) + if self.is_mla: + hashk_cache = hashk_cache.unsqueeze(-2) + return hashk_cache + + def start_retrieval(self, batch_query, forward_context): + query_start_loc = self.req_meta.query_start_loc + query_len = self.req_meta.num_scheduled_tokens + query = batch_query[query_start_loc : query_start_loc + query_len] ntokens, num_q_heads, _ = query.shape if num_q_heads > self.num_key_heads: query = query.view(ntokens, self.num_key_heads, -1, self.head_size) query = query.mean(2) elif num_q_heads < self.num_key_heads: query = torch.repeat_interleave(query, self.num_key_heads // num_q_heads, 1) - - retrieval_start = self.init_window_sz - retrieval_end = self.block_repre.shape[0] - self.local_window_sz + 1 - block_repre_ = 
self.block_repre[retrieval_start:retrieval_end] - - if block_repre_.shape[0] == 0: - scores = torch.empty( - (block_repre_.shape[0]), dtype=query.dtype, device=query.device - ) - else: - ucm_sparse = get_ucm_sparse() - hash_encoder = ucm_sparse.hash_encoder - # query.shape [ntokens/BS, num_heads, head_size] - - # hash_query.shape [ntokens/BS, num_heads, hash_bits//8] - hash_query = hash_encoder.compute_hash(query) - # unpack_hash_query.shape [ntokens/BS, num_heads, hash_bits//8, 8] - unpack_hash_query = hash_encoder._unpack_hash(hash_query) - - # block_repre_.shape [n_blocks, block_size, num_kv_heads, head_size] - # unpack_hash_key_cache.shape [n_blocks, block_size, num_kv_heads, hash_bits//8, 8] - unpack_hash_key_cache = hash_encoder._unpack_hash(block_repre_) - - scores = torch.einsum( - "tid,njsd->tijsn", unpack_hash_query, unpack_hash_key_cache - ) - dims = tuple(range(scores.ndim - 1)) - - # [ntokens/BS, n_blocks] - scores = scores.sum(dim=dims) - - topk_ret = torch.topk(scores, top_k) - topk_index = topk_ret.indices - topk_index = ( - topk_index.sort().values - ) # TODO: remove this, don't need to sort in decode - block_ids = [id.item() + self.init_window_sz for id in topk_index] - block_hashes = [ - f"{self.block_hash(self.req_meta.request_id, id_)}" for id_ in block_ids - ] - return block_hashes - - def construct_init_and_local_window(self): - vllm_block_ids = self.req_meta.vllm_block_ids - # TODO: make sure we don't need to clone() - self.init_window = ( - self.k_cache[vllm_block_ids[: self.init_window_sz]], - self.v_cache[vllm_block_ids[: self.init_window_sz]], - ) - local_window_sz = min( - self.local_window_sz, len(vllm_block_ids[self.init_window_sz :]) - ) - if local_window_sz > 0: - self.local_window = ( - self.k_cache[vllm_block_ids[-local_window_sz:]], - self.v_cache[vllm_block_ids[-local_window_sz:]], - ) - - def launch_transfer_task(self, transfer_type, block_hashes, vllm_block_ids): - fn = getattr(self.store_instance, transfer_type) - length = len(block_hashes) - block_shape = (self.block_size, self.num_key_heads, self.head_size) - precision = self.k_cache.untyped_storage().element_size() - # TODO: consider is_mla here - is_mla = False - offsets_k = [ - get_offset( - block_shape, - self.rank, - self.tp_size, - precision, - self.layer_id, - is_v=False, - is_mla=is_mla, - ) - ] * length - offsets_v = [ - get_offset( - block_shape, - self.rank, - self.tp_size, - precision, - self.layer_id, - is_v=True, - is_mla=is_mla, - ) - ] * length - key_src_tensors = [self.k_cache[id_] for id_ in vllm_block_ids] - value_src_tensors = [self.v_cache[id_] for id_ in vllm_block_ids] - task_k = fn(block_hashes, offsets_k, key_src_tensors) - task_v = fn(block_hashes, offsets_v, value_src_tensors) - task_k_hash = self.task_hash(block_hashes, transfer_type, "key") - self.tasks[task_k_hash] = task_k - task_v_hash = self.task_hash(block_hashes, transfer_type, "value") - self.tasks[task_v_hash] = task_v - - def extract_block_repre(self, vllm_block_ids): ucm_sparse = get_ucm_sparse() hash_encoder = ucm_sparse.hash_encoder - hashk_cache = hash_encoder.compute_hash(self.k_cache[vllm_block_ids]) - return hashk_cache + hash_query = hash_encoder.compute_hash(query) + query_flat = hash_query.reshape(query.shape[0], -1) + top_k = int(self.sparse_range * self.esa_cfg["sparse_ratio"]) + indexes = [self.slots] + self.retrieval_task = self.retrieval_worker.submit( + query_flat, topk=top_k, indexes=indexes + ) - def save_blocks(self, num_blocks_need_dump): - if num_blocks_need_dump <= 0: - return - 
vllm_block_ids = self.req_meta.vllm_block_ids - num_blocks_dumped = 0 if self.block_repre is None else self.block_repre.shape[0] - block_ids = list( - range(num_blocks_dumped, num_blocks_dumped + num_blocks_need_dump) + def block_repre_data(self): + self.sparse_range = get_sparse_range( + self.esa_cfg["init_window_sz"], + self.esa_cfg["local_window_sz"], + self.req_meta.num_prompt_tokens, + self.block_size, ) - block_hashes = [ - f"{self.block_hash(self.req_meta.request_id, id_)}" for id_ in block_ids + vllm_block_ids = self.req_meta.vllm_block_ids + # torch.save({"k": self.k_cache[vllm_block_ids].cpu(), "v": self.v_cache[vllm_block_ids].cpu()}, + # f"/home/heke/debug/{self.layer_id}.pkl") + vllm_block_ids_dump = vllm_block_ids[ + self.esa_cfg["init_window_sz"] : self.esa_cfg["init_window_sz"] + + self.sparse_range ] - if self.req_meta.stage == SequenceStage.PREFILL: - vllm_block_ids_dump = vllm_block_ids[ - num_blocks_dumped : num_blocks_dumped + num_blocks_need_dump - ] - else: - # TODO: handle spec_decode here - vllm_block_ids_dump = vllm_block_ids[-1:] - self.launch_transfer_task("dump", block_hashes, vllm_block_ids_dump) - repre = self.extract_block_repre(vllm_block_ids_dump) - # [n_blocks, num_kv_heads, block_size, hash_bits//8] - repre = repre.transpose(1, 2).contiguous() - # TODO: pre-allocate can speed up here - if self.block_repre is None: - self.block_repre = repre - else: - self.block_repre = torch.cat([self.block_repre, repre], dim=0) - - def maybe_register_kv_cache(self, forward_context: ForwardContext): - if self.block_size: - return - attn = forward_context.no_compile_layers[self.layer_name] - kv_cache = attn.kv_cache[forward_context.virtual_engine] - # TODO: consider is_mla here - self.k_cache = kv_cache[0] - self.v_cache = kv_cache[1] - self.block_size = self.k_cache.shape[1] - self.num_key_heads = self.k_cache.shape[2] - self.head_size = self.k_cache.shape[3] - - def attention_begin( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - forward_context: ForwardContext, - ) -> None: - if self.req_meta.step % RETRIEVAL_STRIDE != 1: - return - index_in_batch = self.req_meta.index_in_batch - if isinstance(forward_context.attn_metadata, dict): - attn_md = forward_context.attn_metadata[self.layer_name] + ######## update block representations + repre = self.extract_block_repre(vllm_block_ids_dump) + repre_flat = repre.reshape(repre.shape[0], repre.shape[1], -1) + new_slots = self.repre_pool.allocate(self.sparse_range) + og_len = len(self.slots) + for i, slot in enumerate(new_slots): + self.slots_to_relative_indexes[slot] = og_len + i + self.slots.extend(new_slots) + vals = repre_flat.to("cpu", non_blocking=True, dtype=torch.uint8) + data[self.layer_id][new_slots] = vals + ############## + + # NOTE: in Preemption, local_window_start != -self.esa_cfg['local_window_sz'] + local_window_start = self.esa_cfg["init_window_sz"] + self.sparse_range + + if not self.is_mla: + self.init_window = ( + self.k_cache[vllm_block_ids[: self.esa_cfg["init_window_sz"]]].clone(), + self.v_cache[vllm_block_ids[: self.esa_cfg["init_window_sz"]]].clone(), + ) + self.local_window = ( + self.k_cache[vllm_block_ids[local_window_start:]].clone(), + self.v_cache[vllm_block_ids[local_window_start:]].clone(), + ) else: - attn_md = forward_context.attn_metadata - query_start_loc = attn_md.query_start_loc[index_in_batch] - query_len = self.req_meta.num_scheduled_tokens - current_query = query[query_start_loc : query_start_loc + query_len] - - vllm_block_ids = self.req_meta.vllm_block_ids[ - self.init_window_sz : 
-self.local_window_sz - ] - self.wait_for_task_done() - self.prepare_init_and_local_window() # last dump task(possible) - # NOTE: sync style - topk_block_hashes = self.retrieval(current_query, len(vllm_block_ids)) - self.launch_transfer_task("load", topk_block_hashes, vllm_block_ids) - - self.wait_for_task_done() + self.init_window = self.k_cache[ + vllm_block_ids[: self.esa_cfg["init_window_sz"]] + ].clone() + self.local_window = self.k_cache[ + vllm_block_ids[local_window_start:] + ].clone() - # NOTE: Some sparse attention algorithms need to modify attn_metadata here - def prepare_init_and_local_window(self): - if self.req_meta.step != 1: - return - - vllm_block_ids = self.req_meta.vllm_block_ids - self.k_cache[vllm_block_ids[: self.init_window_sz]] = self.init_window[0] - self.v_cache[vllm_block_ids[: self.init_window_sz]] = self.init_window[1] - - if self.local_window is None: - return - - self.k_cache[vllm_block_ids[-self.local_window_sz :]] = self.local_window[0] - self.v_cache[vllm_block_ids[-self.local_window_sz :]] = self.local_window[1] - - def wait_for_task_done(self): - for task_hash, task in self.tasks.items(): - # TODO: handle exceptions here, refer to UcmKVConnector - ret = self.store_instance.wait(task) - self.tasks.clear() - - def attention_finished( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_output: torch.Tensor, - forward_context: ForwardContext, - ) -> None: - self.maybe_register_kv_cache(forward_context) - num_tokens_updated = ( - self.req_meta.num_computed_tokens + self.req_meta.num_scheduled_tokens - ) - num_blocks_dumped = 0 if self.block_repre is None else self.block_repre.shape[0] - num_full_blocks = num_tokens_updated // self.block_size - num_blocks_need_dump = num_full_blocks - num_blocks_dumped - self.save_blocks(num_blocks_need_dump) - if self.req_meta.stage == SequenceStage.PREFILL and self.req_meta.is_last_chunk: - self.construct_init_and_local_window() - self.wait_for_task_done() - - -class KvComp(UcmSparseBase): +class KvComp(ESA): # handle batch def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): - super().__init__(vllm_config, role) - self.req_states: dict[str, ReqStatePerLayer] = {} + UcmSparseBase.__init__(self, vllm_config, role) + self.req_states: dict[str, List[ReqStatePerLayerKvComp]] = {} self.rank = vllm_config.parallel_config.rank self.tp_size = vllm_config.parallel_config.tensor_parallel_size - self.block_size = vllm_config.cache_config.block_size + if role == UcmSparseRole.WORKER: + self.connector = get_kv_transfer_group().connector.store + else: + self.connector = None + self.total_num_hidden_layers = ( + vllm_config.model_config.hf_config.num_hidden_layers + ) + self.is_mla = vllm_config.model_config.is_deepseek_mla + self._sparse_metadata_prefill: ESASparseMetaData = ESASparseMetaData() + self._sparse_metadata_decode: ESASparseMetaData = ESASparseMetaData() + self._sparse_metadata: ESASparseMetaData = ESASparseMetaData() + self.esa_cfg = ( + Config(vllm_config.kv_transfer_config) + .get_config() + .get("ucm_sparse_config") + .get("KvComp") + ) - max_cache_size = vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_connector_config" - ]["max_cache_size"] - config = { - "max_cache_size": max_cache_size, - "device": self.rank, - "role": "worker", - } - self.connector = UcmConnectorFactory.create_connector("UcmDramStore", config) + self.block_size = vllm_config.cache_config.block_size + self.block_hashes: dict[int, dict[int, list[str]]] = {} + self.request_hasher = 
RequestHasher(vllm_config, 0) + self.num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config + ) + self.hashk_cache = None kvcomp_config_path = vllm_config.kv_transfer_config.kv_connector_extra_config[ "kvcomp_config_path" ] + self.kvcomp_config = KvCompConfig.from_json(kvcomp_config_path) - logger.info(f"read kvcomp config file: {kvcomp_config_path} ") + logger.info(f"read kvcomp config file : {kvcomp_config_path} ") assert ( - self.kvcomp_config.num_hidden_layers - == vllm_config.model_config.hf_text_config.num_hidden_layers + self.kvcomp_config.num_hidden_layers == self.total_num_hidden_layers ), f"kvcomp_config.num_hidden_layers {self.kvcomp_config.num_hidden_layers} \ - != vllm_config.model_config.hf_text_config.num_hidden_layers \ - {vllm_config.model_config.hf_text_config.num_hidden_layers}" - - dtype = vllm_config.model_config.dtype + != vllm_config.model_config.hf_text_config.num_hidden_layers \ + {self.total_num_hidden_layers}" if hasattr(torch, "npu") and torch.npu.is_available(): device = torch.device(f"npu:{self.rank}") - elif torch.cuda.is_available(): + elif hasattr(torch, "cuda") and torch.cuda.is_available(): device = torch.device(f"cuda:{self.rank}") else: device = torch.device("cpu") + self.hash_encoder = HashEncoder( input_dim=self.kvcomp_config.head_dim, hash_bits=self.kvcomp_config.hash_bits, - dtype=dtype, + dtype=vllm_config.model_config.dtype, device=device, ) + self.device = device - # TODO: consider init self.is_mla here + global data - def attention_begin( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - layer_name: str, - forward_context: ForwardContext, - ) -> None: - for req_meta in self._sparse_metadata.requests: - req_state_hash = ReqStatePerLayer.req_state_hash( - req_meta.request_id, layer_name + if data is None: + parallel_config = vllm_config.parallel_config + num_slots = ( + vllm_config.model_config.max_model_len + * vllm_config.scheduler_config.max_num_seqs + // vllm_config.cache_config.block_size ) - if req_state_hash not in self.req_states: - self.req_states[req_state_hash] = ReqStatePerLayer( - req_meta, layer_name, self.rank, self.tp_size, self.connector - ) - req_state = self.req_states[req_state_hash] - req_state.update_meta(req_meta, forward_context) - req_state.attention_begin(query, key, value, forward_context) - - def attention_finished( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_output: torch.Tensor, - layer_name: str, - forward_context: ForwardContext, - ) -> None: - for req_meta in self._sparse_metadata.requests: - req_state_hash = ReqStatePerLayer.req_state_hash( - req_meta.request_id, layer_name + block_size = vllm_config.cache_config.block_size + dim = ( + vllm_config.model_config.get_num_kv_heads(parallel_config) + * self.kvcomp_config.hash_bits # use hash_bits instead of vllm_config.model_config.get_head_size() + // 8 ) - if req_state_hash not in self.req_states: - self.req_states[req_state_hash] = ReqStatePerLayer( - req_meta, layer_name, self.rank, self.tp_size, self.connector - ) - req_state = self.req_states[req_state_hash] - req_state.update_meta(req_meta, forward_context) - req_state.attention_finished( - query, key, value, attn_output, forward_context - ) - - def wait_all_task_done(self): - pass - - def execute_finished(self): - pass + data = [ + torch.empty((num_slots, block_size, dim), dtype=torch.uint8) + for _ in range(self.total_num_hidden_layers) + ] + self.layer_pools: list[ReprePool] = [ + ReprePool(num_slots) for _ in 
range(self.total_num_hidden_layers) + ] - def execute_finished(self): - pass + self.local_tp_rank = vllm_config.parallel_config.rank + self.total_tp_size = vllm_config.parallel_config.tensor_parallel_size + ratio = 0.75 - def build_sparse_meta( - self, - scheduler_output, - requests, - input_batch, - ) -> UcmSparseMetadata: - sparse_meta = KvCompSparseMetaData() - for ( - req_id, - num_scheduled_tokens, - ) in scheduler_output.num_scheduled_tokens.items(): - req_state = requests[req_id] - if len(req_state.prompt_token_ids) > self.block_size: - sparse_meta.add_request( - req_id, - input_batch.req_id_to_index[req_id], - len(req_state.prompt_token_ids), - len(req_state.output_token_ids), - num_scheduled_tokens, - req_state.num_computed_tokens, - scheduler_output.req_sparsed_slots[req_id], - req_state.block_ids[0], - ) - self._sparse_metadata = sparse_meta + bind_info_list, alloc_numa_ids = get_bind_cpus_for_rank( + self.total_tp_size, self.local_tp_rank, ratio=ratio + ) - def request_begin(self, request_id: ReqType, prompt_token_ids: List[int]): - pass + bind_info_dict = defaultdict(list) + for item in bind_info_list: + bind_info_dict[item[1]].append(item[0]) + bind_info_dict = dict(bind_info_dict) - def request_finished_in_scheduler(self, request_id: ReqType): - pass + self.retrieval_workers: List[HashRetrievalWorker] = [] + for i in range(self.total_num_hidden_layers): + backend_src = data[i] + backend = hash_retrieval_backend.HashRetrievalWorkerBackend( + backend_src, bind_info_dict + ) + self.retrieval_workers.append(HashRetrievalWorker(backend)) - def request_finished_in_worker(self, request_id: ReqType): - pass + self.preempt_req_output_tokens: Dict[ReqType, int] = {} - def update_state_after_alloc(self, request: Request, num_blocks: int): - pass + def get_or_create_layerwise_req_state(self, req_meta, layer_name): + layer_id = int(layer_name.split(".")[2]) + if req_meta.is_preempt: + print( + f"preempt {req_meta.request_id}, layer_id: {layer_id}, {req_meta.num_output_tokens}" + ) + layer_state = self.req_states[req_meta.request_id][layer_id] + layer_state.repre_pool.free(layer_state.slots) + self.req_states[req_meta.request_id][layer_id] = None + if req_meta.request_id not in self.req_states: + if self.req_states.get(req_meta.request_id) is None: + self.req_states[req_meta.request_id] = [ + None + ] * self.total_num_hidden_layers + if self.req_states[req_meta.request_id][layer_id] is None: + self.req_states[req_meta.request_id][layer_id] = ReqStatePerLayerKvComp( + layer_name, + self.rank, + self.tp_size, + self.connector, + self._vllm_config, + self.retrieval_workers[layer_id], + self.layer_pools[layer_id], + self.esa_cfg, + ) + return self.req_states[req_meta.request_id][layer_id] - def estimate_num_slots_sparsed(self, request: Request) -> int: - if ( - request.num_output_tokens == 0 - or request.num_prompt_tokens < self.block_size - ): - return INVALID_SLOT - num_blocks = math.ceil(request.num_tokens / self.block_size) - mid_window_sz = int( - (num_blocks - INIT_WINDOW_SZ - LOCAL_WINDOW_SZ) * SPARSE_RATIO - ) - flaw = request.num_tokens % self.block_size - if flaw: - flaw = self.block_size - flaw - num_tokens_sparsed = ( - INIT_WINDOW_SZ + mid_window_sz + LOCAL_WINDOW_SZ - ) * self.block_size - flaw - return num_tokens_sparsed + def execute_begin(self, scheduler_output): + if self.hashk_cache is None: + print( + " ========================== initialize hashk cache ========================== " + ) + num_blocks = self._vllm_config.cache_config.num_gpu_blocks + self.hashk_cache = [ + 
torch.empty( + ( + num_blocks, + self.num_kv_heads, + self.block_size, + self.hash_encoder.hash_bits // 8, + ), + dtype=torch.uint8, + device=self.device, + ) + for _ in range(self.total_num_hidden_layers) + ] diff --git a/ucm/sparse/kvstar/multistep.py b/ucm/sparse/kvstar/multistep.py index 570d2defe..18ed4cb87 100644 --- a/ucm/sparse/kvstar/multistep.py +++ b/ucm/sparse/kvstar/multistep.py @@ -1,7 +1,7 @@ import enum import math from dataclasses import dataclass, field -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union import torch from vllm.config import VllmConfig @@ -10,6 +10,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request +from ucm.integration.vllm.ucm_connector import RequestHasher from ucm.sparse.base import ( INVALID_SLOT, UcmSparseBase, @@ -17,8 +18,14 @@ UcmSparseRole, ) from ucm.sparse.kvstar.retrieve import kvstar_retrieve -from ucm.sparse.kvstar.utils import bind_cpus, block_hash_func, get_offset +from ucm.sparse.kvstar.utils import ( + block_hash_func, + compute_layer_offset, + compute_parent_block_hash, + get_bind_cpus_for_rank, +) from ucm.store.ucmstore import Task, UcmKVStoreBase +from ucm.utils import Config """ -------------------------------------------------------------------------------------- @@ -57,28 +64,28 @@ class ReqMeta: retrieval_stride: int = 8 block_hashes: list[str] = field(default_factory=list) - def set_block_hashes(self, token_ids): - block_hashes = [] - parent_block_hash_value = None - for start in range(0, len(token_ids), self.token_blk_size): - end = start + self.token_blk_size - block_token_ids = token_ids[start:end] - if len(block_token_ids) < self.token_blk_size: - break - curr_block_token_ids_tuple = tuple(block_token_ids) - block_hash = block_hash_func( - parent_block_hash_value, curr_block_token_ids_tuple - ) - block_hashes.append(str(block_hash)) - parent_block_hash_value = block_hash - return block_hashes - - @property - def req_block_hashes(self) -> list[str]: - if self.block_hashes: - return self.block_hashes - self.block_hashes = self.set_block_hashes(self.prompt_token_ids) - return self.block_hashes + # def set_block_hashes(self, token_ids): + # block_hashes = [] + # parent_block_hash_value = None + # for start in range(0, len(token_ids), self.token_blk_size): + # end = start + self.token_blk_size + # block_token_ids = token_ids[start:end] + # if len(block_token_ids) < self.token_blk_size: + # break + # curr_block_token_ids_tuple = tuple(block_token_ids) + # block_hash = block_hash_func( + # parent_block_hash_value, curr_block_token_ids_tuple + # ) + # block_hashes.append(str(block_hash)) + # parent_block_hash_value = block_hash + # return block_hashes + + # @property + # def req_block_hashes(self) -> list[str]: + # if self.block_hashes: + # return self.block_hashes + # self.block_hashes = self.set_block_hashes(self.prompt_token_ids) + # return self.block_hashes @property def step(self) -> int: @@ -153,6 +160,7 @@ def add_request( query_len: int, retrieval_stride: int, prompt_token_ids: list[int], + ucm_block_hashes: list[str], ) -> None: meta = ReqMeta( request_id=request_id, @@ -168,6 +176,7 @@ def add_request( query_start_loc=query_start_loc, query_len=query_len, retrieval_stride=retrieval_stride, + block_hashes=ucm_block_hashes, ) self.requests.append(meta) @@ -181,7 +190,6 @@ def __init__( rank: int, tp_size: int, store_instance: UcmKVStoreBase, - store_name: str, sparse_cfg, ): self.sparse_cfg = sparse_cfg @@ -193,7 +201,6 @@ def __init__( 
self.num_tokens = 0 # the number of all_tokens, prompt+output self.store_instance = store_instance - self.store_name = store_name self.req_meta = req_meta self.init_window: tuple[torch.Tensor, torch.Tensor] = None self.local_window: tuple[torch.Tensor, torch.Tensor] = None @@ -217,6 +224,10 @@ def __init__( self.num_blocks_dumped = 0 + self.layer_wise_pre_swap_area_block_hashes: Dict[int, str] = ( + {} + ) # key: block id, value: block hash id + @classmethod def block_hash(cls, request_id, block_id): return f"req_{request_id}_blk_{block_id}" @@ -348,6 +359,7 @@ def attention_begin( key: torch.Tensor, value: torch.Tensor, forward_context: ForwardContext, + phase: Optional[str] = None, ) -> None: index_in_batch = self.req_meta.index_in_batch query_start_loc = self.req_meta.query_start_loc @@ -441,8 +453,40 @@ def load_retrieve_result_async(self, load_step, candidate_swap_vllm_block_ids): assert 0 retrieve_result_hash_list = self.step_group_retrieve_result.get( need_retrieve_record + ).copy() + fixed_origin_candidate_swap_vllm_block_ids = ( + candidate_swap_vllm_block_ids.copy() ) if need_retrieve_record != "prefill" or load_step == 1: + if len(self.layer_wise_pre_swap_area_block_hashes) == 0: + self.layer_wise_pre_swap_area_block_hashes = { + blk_id: blk_hash + for (blk_id, blk_hash) in zip( + candidate_swap_vllm_block_ids, retrieve_result_hash_list + ) + } + else: + already_matched_record = {} + for logic_blk_id in fixed_origin_candidate_swap_vllm_block_ids: + if ( + logic_blk_id in self.layer_wise_pre_swap_area_block_hashes + and self.layer_wise_pre_swap_area_block_hashes[logic_blk_id] + in retrieve_result_hash_list + ): + already_matched_record[logic_blk_id] = ( + self.layer_wise_pre_swap_area_block_hashes[logic_blk_id] + ) + candidate_swap_vllm_block_ids.remove(logic_blk_id) + retrieve_result_hash_list.remove( + already_matched_record[logic_blk_id] + ) + self.layer_wise_pre_swap_area_block_hashes = already_matched_record + for diff_blk_id, diff_blk_hash in zip( + candidate_swap_vllm_block_ids, retrieve_result_hash_list + ): + self.layer_wise_pre_swap_area_block_hashes[diff_blk_id] = ( + diff_blk_hash + ) if len(retrieve_result_hash_list) > 0: self.launch_transfer_task( "load", retrieve_result_hash_list, candidate_swap_vllm_block_ids @@ -507,6 +551,7 @@ def attention_finished( value: torch.Tensor, attn_output: torch.Tensor, forward_context: ForwardContext, + phase: Optional[str] = None, ) -> None: if self.req_meta.stage != ReqStage.PREFILL: if ( @@ -539,7 +584,7 @@ def maybe_register_kv_cache(self, forward_context: ForwardContext): self.v_cache = kv_cache[1] self.block_size = self.k_cache.shape[1] self.num_key_heads = self.k_cache.shape[2] - self.block_hashes = self.req_meta.req_block_hashes + self.block_hashes = self.req_meta.block_hashes self.head_size = self.k_cache.shape[3] @classmethod @@ -556,29 +601,22 @@ def update_meta(self, req_meta: ReqMeta, forward_context: ForwardContext): def launch_transfer_task(self, transfer_type, block_hashes, vllm_block_ids): fn = getattr(self.store_instance, transfer_type) length = len(block_hashes) - block_shape = (self.block_size, self.num_key_heads, self.head_size) precision = self.k_cache.storage().element_size() is_mla = False - block_shape = tuple(block_shape) + block_data_size = self.k_cache[0].numel() * precision offsets_k = [ - get_offset( - block_shape, - self.local_tp_rank, - self.total_tp_size, - precision, + compute_layer_offset( + block_data_size, self.layer_id, is_v=False, is_mla=is_mla, ) ] * length offsets_v = [ - get_offset( - 
block_shape, - self.local_tp_rank, - self.total_tp_size, - precision, + compute_layer_offset( + block_data_size, self.layer_id, is_v=True, is_mla=is_mla, @@ -614,38 +652,42 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): self.total_num_hidden_layers = ( vllm_config.model_config.hf_config.num_hidden_layers ) + self.block_size = vllm_config.cache_config.block_size + self.block_hashes: dict[int, dict[int, list[str]]] = {} + self.rank = vllm_config.parallel_config.rank + self.is_mla = vllm_config.model_config.is_deepseek_mla + self.request_hasher = RequestHasher(vllm_config, 0) if self.role == UcmSparseRole.WORKER: ratio = 0.75 - numa_nodes_num, alloc_numa_ids, phy_cpu_core_per_numa = bind_cpus( + bind_info_list, alloc_numa_ids = get_bind_cpus_for_rank( self.total_tp_size, self.local_tp_rank, ratio=ratio ) cpu_device = kvstar_retrieve.CPU param = kvstar_retrieve.SetupParam( cpuNumaIds=alloc_numa_ids, - physicalCorePerNuma=phy_cpu_core_per_numa, - allocRatio=ratio, - blkRepreSize=4096, + bindInfo=bind_info_list, deviceType=cpu_device, totalTpSize=self.total_tp_size, localRankId=self.local_tp_rank, ) kvstar_retrieve.Setup(param) - self.connector_name = ( - self._vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_connector_name" - ] - ) - self.connector = get_kv_transfer_group().connector + # self.connector_name = ( + # self._vllm_config.kv_transfer_config.kv_connector_extra_config[ + # "ucm_connector_name" + # ] + # ) + self.connector = get_kv_transfer_group().connector.store else: self.connector = None assert self._vllm_config.kv_transfer_config is not None self.kvstar_multistep_cfg = ( - vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_sparse_config" - ]["KVStarMultiStep"] + Config(vllm_config.kv_transfer_config) + .get_config() + .get("ucm_sparse_config") + .get("KVStarMultiStep") ) print(f"kvstar_multistep_cfg: {self.kvstar_multistep_cfg}") @@ -665,7 +707,6 @@ def create_layerwise_req_state(self, req_meta, layer_name): self.local_tp_rank, self.total_tp_size, self.connector, - self.connector_name, self.kvstar_multistep_cfg, ) return self.req_states[req_meta.request_id][layer_id] @@ -690,6 +731,7 @@ def attention_begin( value: torch.Tensor, layer_name: str, forward_context: ForwardContext, + phase: Optional[str] = None, ) -> None: """ This is called at the beginning of "unified_attention". @@ -712,6 +754,7 @@ def attention_finished( attn_output: torch.Tensor, layer_name: str, forward_context: ForwardContext, + phase: Optional[str] = None, ) -> None: """ This is called at the end of "unified_attention". 
@@ -723,6 +766,44 @@ def attention_finished( query, key, value, attn_output, forward_context ) + def set_block_hashes(self, req_id, token_ids): + if req_id not in self.block_hashes: + self.block_hashes[req_id] = {} + + if self.rank in self.block_hashes[req_id]: + return + + self.block_hashes[req_id][self.rank] = [] + + parent_block_hash_value = compute_parent_block_hash( + self._vllm_config.model_config.model, + self._vllm_config.parallel_config.world_size, + self._vllm_config.model_config.dtype, + seed_rank=0, + ) + + for start in range(0, len(token_ids), self.block_size): + end = start + self.block_size + + block_token_ids = token_ids[start:end] + if len(block_token_ids) < self.block_size: + break + curr_block_token_ids_tuple = tuple(block_token_ids) + hash_value = self.request_hasher( + (parent_block_hash_value, curr_block_token_ids_tuple) + ) + + self.block_hashes[req_id][self.rank].append(str(hash_value)) + + parent_block_hash_value = hash_value + + if self.rank != 0 and not self.is_mla: + self.newqrequest_hasher = RequestHasher(self._vllm_config, self.rank) + for i, ucm_block_id in enumerate(self.block_hashes[req_id][self.rank]): + self.block_hashes[req_id][self.rank][i] = str( + self.newqrequest_hasher(ucm_block_id) + ) + def build_sparse_meta( self, scheduler_output, requests, input_batch, attn_metadata ) -> None: @@ -742,7 +823,7 @@ def build_sparse_meta( num_scheduled_tokens, ) in scheduler_output.num_scheduled_tokens.items(): req_state = requests[req_id] - + self.set_block_hashes(int(req_id), req_state.prompt_token_ids) q_start_loc = query_start_locs[input_batch.req_id_to_index[req_id]].item() q_len = ( query_start_locs[input_batch.req_id_to_index[req_id] + 1].item() @@ -764,6 +845,7 @@ def build_sparse_meta( q_len, self.kvstar_multistep_cfg["retrieval_stride"], req_state.prompt_token_ids, + self.block_hashes[int(req_id)][self.rank], ) self._sparse_metadata = sparse_meta @@ -813,9 +895,11 @@ def estimate_num_slots_sparsed(self, request: Request) -> int: estimate_num_slots_budget = num_blocks_this_step_budget * block_size return estimate_num_slots_budget - def allocate_slots( - self, request, num_slots_sparsed, coordinator, block_pool, kv_cache_groups - ): + def allocate_slots(self, kv_cache_manager, request, num_slots_sparsed): + coordinator = kv_cache_manager.coordinator + block_pool = kv_cache_manager.block_pool + kv_cache_groups = kv_cache_manager.kv_cache_config.kv_cache_groups + block_size = self._vllm_config.cache_config.block_size num_blocks_need = math.ceil(num_slots_sparsed / block_size) allocated_blocks = coordinator.get_blocks(request.request_id)[0] diff --git a/ucm/sparse/kvstar/retrieve/CMakeLists.txt b/ucm/sparse/kvstar/retrieve/CMakeLists.txt index 1e4584ff3..3f6777760 100644 --- a/ucm/sparse/kvstar/retrieve/CMakeLists.txt +++ b/ucm/sparse/kvstar/retrieve/CMakeLists.txt @@ -1,14 +1,8 @@ # auto detect cuda ------------------------ if($ENV{PLATFORM} STREQUAL "cuda") if(NOT DEFINED CMAKE_CUDA_COMPILER) - set(CUDA_HOME "$ENV{CUDA_HOME}") - if(NOT CUDA_HOME) - set(CUDA_HOME "/usr/local/cuda") - endif() - if(NOT EXISTS "${CUDA_HOME}") - message(FATAL_ERROR "CUDA_HOME directory does not exist: ${CUDA_HOME}") - endif() - set(CMAKE_CUDA_COMPILER "${CUDA_HOME}/bin/nvcc" CACHE FILEPATH "CUDA compiler" FORCE) + set(CUDA_ROOT "/usr/local/cuda/" CACHE PATH "Path to CUDA root directory") + set(CMAKE_CUDA_COMPILER ${CUDA_ROOT}/bin/nvcc) endif() enable_language(CUDA) set(CMAKE_CUDA_ARCHITECTURES 75 80 86 89 90) @@ -60,45 +54,53 @@ set(Torch_DIR ${PYTORCH_PATH}/share/cmake/Torch/) 
find_package(Torch REQUIRED) include_directories(${TORCH_INCLUDE_DIRS}) -set(NUMA_INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/numa_install) -FetchContent_Declare( - numactl - URL https://github.com/numactl/numactl/releases/download/v2.0.16/numactl-2.0.16.tar.gz - TLS_VERIFY OFF -) -FetchContent_MakeAvailable(numactl) -if(NOT EXISTS "${NUMA_INSTALL_DIR}/lib/libnuma.so") - message(STATUS "Configuring numactl...") - execute_process( - COMMAND ./configure --prefix=${NUMA_INSTALL_DIR} - WORKING_DIRECTORY ${numactl_SOURCE_DIR} - RESULT_VARIABLE numa_configure_result - OUTPUT_VARIABLE numa_configure_output - ERROR_VARIABLE numa_configure_error - ) - if(NOT numa_configure_result EQUAL 0) - message(FATAL_ERROR "Failed to configure numactl. \n" - "Result: ${numa_configure_result}\n" - "STDOUT: ${numa_configure_output}\n" - "STDERR: ${numa_configure_error}\n") - endif() +if(BUILD_NUMA) + message(STATUS "Building numactl library...") - message(STATUS "Building and installing numactl...") - execute_process( - COMMAND make install -j8 - WORKING_DIRECTORY ${numactl_SOURCE_DIR} - RESULT_VARIABLE numa_install_result - OUTPUT_VARIABLE numa_install_output - ERROR_VARIABLE numa_install_error + set(NUMA_INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/numa_install) + FetchContent_Declare( + numactl + URL https://github.com/numactl/numactl/releases/download/v2.0.16/numactl-2.0.16.tar.gz + TLS_VERIFY OFF ) - if(NOT numa_install_result EQUAL 0) - message(FATAL_ERROR "Failed to build and install numactl. \n" - "Result: ${numa_install_result}\n" - "STDOUT: ${numa_install_output}\n" - "STDERR: ${numa_install_error}\n") + FetchContent_MakeAvailable(numactl) + if(NOT EXISTS "${NUMA_INSTALL_DIR}/lib/libnuma.so") + message(STATUS "Configuring numactl...") + execute_process( + COMMAND ./configure --prefix=${NUMA_INSTALL_DIR} + WORKING_DIRECTORY ${numactl_SOURCE_DIR} + RESULT_VARIABLE numa_configure_result + OUTPUT_VARIABLE numa_configure_output + ERROR_VARIABLE numa_configure_error + ) + if(NOT numa_configure_result EQUAL 0) + message(FATAL_ERROR "Failed to configure numactl. \n" + "Result: ${numa_configure_result}\n" + "STDOUT: ${numa_configure_output}\n" + "STDERR: ${numa_configure_error}\n") + endif() + + message(STATUS "Building and installing numactl...") + execute_process( + COMMAND make install -j8 + WORKING_DIRECTORY ${numactl_SOURCE_DIR} + RESULT_VARIABLE numa_install_result + OUTPUT_VARIABLE numa_install_output + ERROR_VARIABLE numa_install_error + ) + if(NOT numa_install_result EQUAL 0) + message(FATAL_ERROR "Failed to build and install numactl. \n" + "Result: ${numa_install_result}\n" + "STDOUT: ${numa_install_output}\n" + "STDERR: ${numa_install_error}\n") + endif() + else() + message(STATUS "Found already built libnuma. Skipping build.") endif() + + add_definitions(-DNUMA_ENABLED) else() - message(STATUS "Found already built libnuma. 
Skipping build.") + message(STATUS "Skipping numactl build...") endif() add_subdirectory(core) diff --git a/ucm/sparse/kvstar/retrieve/core/CMakeLists.txt b/ucm/sparse/kvstar/retrieve/core/CMakeLists.txt index dc6b54129..876166e81 100644 --- a/ucm/sparse/kvstar/retrieve/core/CMakeLists.txt +++ b/ucm/sparse/kvstar/retrieve/core/CMakeLists.txt @@ -11,6 +11,6 @@ target_include_directories(kvstar_retrieve.core PUBLIC target_link_libraries(kvstar_retrieve.core PUBLIC spdlog::spdlog fmt::fmt - ${NUMA_INSTALL_DIR}/lib/libnuma.so + $<$:${NUMA_INSTALL_DIR}/lib/libnuma.so> ${Torch_LIBRARIES} ) diff --git a/ucm/sparse/kvstar/retrieve/core/api/kvstar_retrieve/kvstar_retrieve.cpp b/ucm/sparse/kvstar/retrieve/core/api/kvstar_retrieve/kvstar_retrieve.cpp index 847d923f7..cbc5ea6e0 100644 --- a/ucm/sparse/kvstar/retrieve/core/api/kvstar_retrieve/kvstar_retrieve.cpp +++ b/ucm/sparse/kvstar/retrieve/core/api/kvstar_retrieve/kvstar_retrieve.cpp @@ -7,30 +7,11 @@ #include "retrieve_task/retrieve_task_manager.h" namespace KVStar { -SetupParam::SetupParam(const std::vector& cpuNumaIds, const int physicalCorePerNuma, const float allocRatio, const size_t blkRepreSize, - const DeviceType deviceType, const int totalTpSize, const int localRankId) - : cpuNumaIds{cpuNumaIds}, physicalCorePerNuma{physicalCorePerNuma}, allocRatio{allocRatio}, blkRepreSize{blkRepreSize}, deviceType{deviceType}, +SetupParam::SetupParam(const std::vector& cpuNumaIds, const std::vector>& bindInfo, const DeviceType deviceType, const int totalTpSize, const int localRankId) + : cpuNumaIds{cpuNumaIds}, bindInfo{bindInfo}, deviceType{deviceType}, totalTpSize{totalTpSize}, localRankId{localRankId} { - - int coreNumPerNumaAlloc = static_cast(this->physicalCorePerNuma * this->allocRatio); - - this->perNumaCoreIds.clear(); - this->perNumaCoreIds.reserve(this->cpuNumaIds.size()); - - for (const int numaId : this->cpuNumaIds) { - int startCoreId = numaId * this->physicalCorePerNuma; - - std::vector curNumaCoreIdAlloc(coreNumPerNumaAlloc); - - std::iota(curNumaCoreIdAlloc.begin(), curNumaCoreIdAlloc.end(), startCoreId); - - this->perNumaCoreIds.push_back(curNumaCoreIdAlloc); - - KVSTAR_DEBUG("Alloc core ids {} in numa {}.", curNumaCoreIdAlloc, numaId); - } - - this->threadNum = static_cast(coreNumPerNumaAlloc * this->cpuNumaIds.size()); + this->threadNum = this->bindInfo.size(); KVSTAR_DEBUG("Successfully configured. 
Total threads = {}.", this->threadNum); } @@ -38,7 +19,7 @@ SetupParam::SetupParam(const std::vector& cpuNumaIds, const int physicalCor int32_t Setup(const SetupParam& param) { - auto status = Singleton::Instance()->Setup(param.threadNum, param.cpuNumaIds, param.perNumaCoreIds); + auto status = Singleton::Instance()->Setup(param.threadNum, param.bindInfo); if (status.Failure()) { KVSTAR_ERROR("Failed({}) to setup RetrieveTaskManager.", status); return status.Underlying(); @@ -53,4 +34,4 @@ int32_t Wait(const size_t taskId) { } -} \ No newline at end of file +} diff --git a/ucm/sparse/kvstar/retrieve/core/api/kvstar_retrieve/kvstar_retrieve.h b/ucm/sparse/kvstar/retrieve/core/api/kvstar_retrieve/kvstar_retrieve.h index be91bba55..cf28a9dfc 100644 --- a/ucm/sparse/kvstar/retrieve/core/api/kvstar_retrieve/kvstar_retrieve.h +++ b/ucm/sparse/kvstar/retrieve/core/api/kvstar_retrieve/kvstar_retrieve.h @@ -13,16 +13,13 @@ namespace KVStar { struct SetupParam { std::vector cpuNumaIds; - int physicalCorePerNuma; - float allocRatio; - size_t blkRepreSize; + std::vector> bindInfo; // coreId, numaId DeviceType deviceType; int totalTpSize; int localRankId; - std::vector> perNumaCoreIds; int threadNum; - SetupParam(const std::vector& cpuNumaIds, const int physicalCorePerNuma, const float allocRatio, const size_t blkRepreSize, + SetupParam(const std::vector& cpuNumaIds, const std::vector>& bindInfo, const DeviceType deviceType, const int totalTpSize, const int localRankId); }; @@ -36,4 +33,4 @@ int32_t Wait(const size_t taskId); -#endif //KVSTAR_RETRIEVE_CLIB_KVSTAR_RETRIEVE_H \ No newline at end of file +#endif //KVSTAR_RETRIEVE_CLIB_KVSTAR_RETRIEVE_H diff --git a/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_manager.cpp b/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_manager.cpp index 7af79e83e..36b86bbe0 100644 --- a/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_manager.cpp +++ b/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_manager.cpp @@ -2,47 +2,25 @@ #include "retrieve_task_manager.h" namespace KVStar { -Status RetrieveTaskManager::Setup(const size_t threadNum, const std::vector& cpuNumaIds, const std::vector>& bindCoreId) { +Status RetrieveTaskManager::Setup(const size_t threadNum, const std::vector>& bindInfo) { - const size_t numaNodeCount = cpuNumaIds.size(); - if (numaNodeCount == 0) { - KVSTAR_ERROR("Retrieve task manager get error numa id info {}.", cpuNumaIds); + if (threadNum != bindInfo.size()) { + KVSTAR_ERROR("Thread count ({}) does not match the size of bind-core-ID list ({}).", threadNum, bindInfo.size()); return Status::InvalidParam(); } - if (threadNum % numaNodeCount != 0) { - KVSTAR_ERROR("Retrieve task manager can not split threads into each numa, thread num {}, numa id info {}.", threadNum, cpuNumaIds); - return Status::InvalidParam(); - } - - if (bindCoreId.size() != numaNodeCount) { - KVSTAR_ERROR("Bind core ids {} can not match numa id info {}.", bindCoreId, cpuNumaIds); - return Status::InvalidParam(); - } - - const size_t threadsPerNuma = threadNum / numaNodeCount; - this->_queues.reserve(threadNum); for (size_t i = 0; i < threadNum; ++i) { - const size_t numaListIndex = i / threadsPerNuma; - - const size_t coreListIndex = i % threadsPerNuma; - - if (coreListIndex >= bindCoreId[numaListIndex].size()) { - KVSTAR_ERROR("Bind core ids {} can not alloc per numa need alloc threads num {}.", bindCoreId, threadsPerNuma); - return Status::InvalidParam(); - } - - const int targetNumaId = 
cpuNumaIds[numaListIndex]; - const int targetCoreId = bindCoreId[numaListIndex][coreListIndex]; + const int targetCoreId = bindInfo[i].first; + const int targetNumaId = bindInfo[i].second; auto& queue = this->_queues.emplace_back(std::make_unique()); auto status = queue->Setup(targetNumaId, targetCoreId, &this->_failureSet); if (status.Failure()) { - KVSTAR_ERROR("Init and setup thread id {} in pool failed.", i); + KVSTAR_ERROR("Init and setup thread id {} (to core {}) in pool failed.", i, targetCoreId); return status; } - KVSTAR_DEBUG("Init and setup thread id {} in pool success.", i); + KVSTAR_DEBUG("Init and setup thread id {} in pool to core {} success.", i, targetCoreId); } return Status::OK(); } @@ -106,4 +84,4 @@ Status RetrieveTaskManager::GetResult(size_t taskId, std::shared_ptr } -} \ No newline at end of file +} diff --git a/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_manager.h b/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_manager.h index a6bd0a254..67d2d0fef 100644 --- a/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_manager.h +++ b/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_manager.h @@ -10,7 +10,7 @@ namespace KVStar { class RetrieveTaskManager { public: - Status Setup(const size_t threadNum, const std::vector& cpuNumaIds, const std::vector>& bindCoreId); // 重要, 线程池拉起的入口 + Status Setup(const size_t threadNum, const std::vector>& bindInfo); Status SubmitSingleTask(RetrieveTask&&task, size_t &taskId); Status GetResult(size_t taskId, std::shared_ptr& result); @@ -36,4 +36,4 @@ class RetrieveTaskManager { -#endif //UCM_SPARSE_KVSTAR_RETRIEVE_RETRIEVE_TASK_MANAGER_H \ No newline at end of file +#endif //UCM_SPARSE_KVSTAR_RETRIEVE_RETRIEVE_TASK_MANAGER_H diff --git a/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_queue.cpp b/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_queue.cpp index 6ded55f0f..b504b1062 100644 --- a/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_queue.cpp +++ b/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_queue.cpp @@ -1,21 +1,26 @@ +#ifdef NUMA_ENABLED #include +#endif #include "retrieve_task_queue.h" #include "retrieve_task_runner.h" namespace KVStar { -RetrieveTaskQueue::~RetrieveTaskQueue() { +RetrieveTaskQueue::~RetrieveTaskQueue() +{ { std::unique_lock lk(this->_mutex); if (!this->_running) { return; } this->_running = false; } - if (this->_worker.joinable()){ + if (this->_worker.joinable()) { this->_cv.notify_all(); this->_worker.join(); } } -void RetrieveTaskQueue::Worker(const int numaId, const int bindCoreId, std::promise& started) { +void RetrieveTaskQueue::Worker(const int numaId, const int bindCoreId, + std::promise& started) +{ cpu_set_t cpuset; CPU_ZERO(&cpuset); CPU_SET(bindCoreId, &cpuset); @@ -27,6 +32,7 @@ void RetrieveTaskQueue::Worker(const int numaId, const int bindCoreId, std::prom return; } +#ifdef NUMA_ENABLED unsigned long nodemask = 1UL << numaId; rc = set_mempolicy(MPOL_BIND, &nodemask, sizeof(nodemask) * 8); if (rc != 0) { @@ -34,15 +40,17 @@ void RetrieveTaskQueue::Worker(const int numaId, const int bindCoreId, std::prom started.set_value(Status::OsApiError()); return; } +#endif - KVSTAR_DEBUG("Bind current thread {} to numa {} core {} and set memory affinity success.", thread, numaId, bindCoreId); + KVSTAR_DEBUG("Bind current thread {} to numa {} core {} and set memory affinity success.", + thread, numaId, bindCoreId); RetrieveTaskRunner runner; 
started.set_value(Status::OK()); Status status = Status::OK(); - for(;;){ + for (;;) { std::unique_lock lk(this->_mutex); this->_cv.wait(lk, [this] { return !this->_taskQ.empty() || !this->_running; }); if (!this->_running) { return; } @@ -56,22 +64,23 @@ void RetrieveTaskQueue::Worker(const int numaId, const int bindCoreId, std::prom if (!_failureSet->Exist(workItem.task.allocTaskId)) { if ((status = runner.Run(workItem.task, *workItem.result)).Failure()) { - KVSTAR_ERROR("Failed({}) to run retrieve task({}).", status, workItem.task.allocTaskId); + KVSTAR_ERROR("Failed({}) to run retrieve task({}).", status, + workItem.task.allocTaskId); this->_failureSet->Insert(workItem.task.allocTaskId); workItem.result->status = TaskStatus::FAILURE; } else { - KVSTAR_DEBUG("Process current task success, task id: {}.", workItem.task.allocTaskId); + KVSTAR_DEBUG("Process current task success, task id: {}.", + workItem.task.allocTaskId); workItem.result->status = TaskStatus::SUCCESS; } } workItem.task.waiter->Done(); } - } - -Status RetrieveTaskQueue::Setup(const int numaId, const int bindCoreId, RetrieveTaskSet* failureSet) { +Status RetrieveTaskQueue::Setup(const int numaId, const int bindCoreId, RetrieveTaskSet* failureSet) +{ this->_failureSet = failureSet; { std::unique_lock lk(this->_mutex); @@ -79,11 +88,12 @@ Status RetrieveTaskQueue::Setup(const int numaId, const int bindCoreId, Retrieve } std::promise started; auto fut = started.get_future(); - this->_worker = std::thread([&]{ this->Worker(numaId, bindCoreId, started); }); + this->_worker = std::thread([&] { this->Worker(numaId, bindCoreId, started); }); return fut.get(); } -void RetrieveTaskQueue::Push(WorkItem&& item) { +void RetrieveTaskQueue::Push(WorkItem&& item) +{ { std::unique_lock lk(this->_mutex); this->_taskQ.push_back(std::move(item)); @@ -91,5 +101,4 @@ void RetrieveTaskQueue::Push(WorkItem&& item) { this->_cv.notify_one(); } - -} \ No newline at end of file +} // namespace KVStar \ No newline at end of file diff --git a/ucm/sparse/kvstar/retrieve/py_intf/py_intf.cpp b/ucm/sparse/kvstar/retrieve/py_intf/py_intf.cpp index 4019262a0..151237bd2 100644 --- a/ucm/sparse/kvstar/retrieve/py_intf/py_intf.cpp +++ b/ucm/sparse/kvstar/retrieve/py_intf/py_intf.cpp @@ -106,23 +106,17 @@ PYBIND11_MODULE(kvstar_retrieve, module) py::class_(module, "SetupParam") .def(py::init&, - const int, - const float, - const size_t, + const std::vector>&, const KVStar::DeviceType, const int, const int>(), py::arg("cpuNumaIds"), - py::arg("physicalCorePerNuma"), - py::arg("allocRatio"), - py::arg("blkRepreSize"), + py::arg("bindInfo"), py::arg("deviceType"), py::arg("totalTpSize"), py::arg("localRankId")) .def_readwrite("cpuNumaIds", &KVStar::SetupParam::cpuNumaIds) - .def_readwrite("physicalCorePerNuma", &KVStar::SetupParam::physicalCorePerNuma) - .def_readwrite("allocRatio", &KVStar::SetupParam::allocRatio) - .def_readwrite("blkRepreSize", &KVStar::SetupParam::blkRepreSize) + .def_readwrite("bindInfo", &KVStar::SetupParam::bindInfo) .def_readwrite("deviceType", &KVStar::SetupParam::deviceType) .def_readwrite("totalTpSize", &KVStar::SetupParam::totalTpSize) .def_readwrite("localRankId", &KVStar::SetupParam::localRankId); @@ -131,4 +125,4 @@ PYBIND11_MODULE(kvstar_retrieve, module) module.def("AsyncRetrieveByCPU", &KVStar::AsyncRetrieveByCPU); module.def("Wait", &KVStar::Wait); module.def("GetTaskResult", &KVStar::GetTaskResult); -} \ No newline at end of file +} diff --git a/ucm/sparse/kvstar/utils.py b/ucm/sparse/kvstar/utils.py index 92f82b218..198b45fa7 
100644 --- a/ucm/sparse/kvstar/utils.py +++ b/ucm/sparse/kvstar/utils.py @@ -1,3 +1,4 @@ +import collections import hashlib import pickle import subprocess @@ -18,6 +19,24 @@ def get_offset(block_shape, rank, tp_size, precision, layer_id, is_v, is_mla) -> return v_offset if is_v else k_offset +@cache +def compute_layer_offset( + block_data_size: int, + layer_id: int, + is_v: bool, + is_mla: bool, +) -> int: + layer_data_size = block_data_size if is_mla else block_data_size * 2 + + k_offset = layer_data_size * layer_id + + if is_mla: + return k_offset + + v_offset = k_offset + block_data_size + return v_offset if is_v else k_offset + + @cache def md5(input) -> int: input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) @@ -33,6 +52,14 @@ def block_hash_func(parent_block_hash, curr_block_token_ids): return md5((parent_block_hash, curr_block_token_ids_tuple)) +@cache +def compute_parent_block_hash(model_name, world_size, dtype, seed_rank=0) -> int: + meta = f"{model_name}:{world_size}:{dtype}:{seed_rank}" + meta_bytes = meta.encode("utf-8") + h_seed = hashlib.md5(meta_bytes + b"UCM_HASH_SEED").digest() + return int.from_bytes(h_seed, byteorder="big") + + def execute_command(cmd_list): with subprocess.Popen( cmd_list, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE @@ -103,3 +130,114 @@ def bind_cpus(world_size, rank_id, ratio=0.5): print(f"cpu_core_alloc: {cpu_core_alloc}") return numa_nodes_num, alloc_numa_ids, phy_cpu_core_per_numa + + +def get_physical_core_topology(): + """ + use lscpu -e parse accurate cpu topology + return a dict, key: numa_id, value: physical core ids in this numa + """ + # topology[numa_id][core_id] = logical_cpu_id + # make sure each physical core only record once + topology = collections.defaultdict(dict) + + # execute lscpu -e, split as line + # e.g.: 36 0 0 0 0:0:0:0 yes 3700.0000 1000.0000 + lscpu_output = execute_command(["lscpu", "-e"]).strip().split("\n") + + # skip title + for line in lscpu_output[1:]: + parts = line.split() + if len(parts) < 4: + continue + + logical_cpu_id = int(parts[0]) + numa_id = int(parts[1]) + core_id = int(parts[3]) # physical core id + + if core_id not in topology[numa_id]: + topology[numa_id][core_id] = logical_cpu_id + + final_mapping = { + numa_id: list(sorted(cores.values())) for numa_id, cores in topology.items() + } + return final_mapping + + +def get_bind_cpus_for_rank(world_size, rank_id, ratio=1.0): + """ + for each rank, compute alloc numa id + + scenario: + 1. numa_num >= world_size, equal division numa for each rank + 2. numa_num < world_size, equal division total cores for each rank + """ + physical_core_map = get_physical_core_topology() + if not physical_core_map: + print("Could not determine CPU topology. Aborting bind.") + return [], [] + + print(f"Detected Physical Core Topology: {physical_core_map}") + + numa_nodes_num = len(physical_core_map) + sorted_numa_ids = sorted(physical_core_map.keys()) + + bind_info_list = [] + alloc_numa_ids = [] + + numas_per_rank = numa_nodes_num // world_size + + if numas_per_rank > 0: + print(f"Strategy: NUMA-level discard binding.") + + discarded_numa_count = numa_nodes_num % world_size + if discarded_numa_count > 0: + print( + f"Note: {discarded_numa_count} NUMA node(s) (IDs: {sorted_numa_ids[-discarded_numa_count:]}) will be unused to ensure fair distribution." 
+ ) + + start_numa_idx = rank_id * numas_per_rank + end_numa_idx = start_numa_idx + numas_per_rank + + alloc_numa_ids = sorted_numa_ids[start_numa_idx:end_numa_idx] + + print(f"Rank {rank_id} allocated to NUMA nodes: {alloc_numa_ids}") + + for numa_id in alloc_numa_ids: + physical_cores_on_numa = physical_core_map.get(numa_id, []) + cores_to_take = int(len(physical_cores_on_numa) * ratio) + for core_id in physical_cores_on_numa[:cores_to_take]: + bind_info_list.append((core_id, numa_id)) + + else: + print( + f"Strategy: Fallback to uniform core distribution ({world_size} ranks > {numa_nodes_num} NUMA nodes)." + ) + + all_physical_cores_with_numa = [] + for numa_id in sorted_numa_ids: + for core_id in physical_core_map[numa_id]: + all_physical_cores_with_numa.append((core_id, numa_id)) + + total_physical_cores = len(all_physical_cores_with_numa) + cores_per_rank = total_physical_cores // world_size + if cores_per_rank == 0: + print( + f"Warning: Not enough physical cores ({total_physical_cores}) to assign at least one to each of the {world_size} ranks. Rank {rank_id} will not be bound to any core." + ) + return [], sorted_numa_ids + + start_core_idx = rank_id * cores_per_rank + end_core_idx = start_core_idx + cores_per_rank + + rank_core_share = all_physical_cores_with_numa[start_core_idx:end_core_idx] + cores_to_take = int(len(rank_core_share) * ratio) + bind_info_list = rank_core_share[:cores_to_take] + + alloc_numa_ids = sorted_numa_ids + + bind_info_list.sort() + print( + f"Rank {rank_id} will bind to {len(bind_info_list)} (CPU, NUMA) pairs: {bind_info_list}" + ) + return bind_info_list, alloc_numa_ids diff --git a/ucm/sparse/state.py b/ucm/sparse/state.py index a4e93c8db..a0f77a53b 100644 --- a/ucm/sparse/state.py +++ b/ucm/sparse/state.py @@ -11,6 +11,7 @@ from ucm.logger import init_logger from ucm.sparse.base import UcmSparseBase, UcmSparseRole from ucm.sparse.factory import UcmSparseFactory +from ucm.utils import Config if TYPE_CHECKING: from vllm.config import VllmConfig @@ -37,15 +38,12 @@ def ensure_ucm_sparse_initialized( return # Check if UCM sparse is enabled - if ( - "ucm_sparse_config" - not in vllm_config.kv_transfer_config.kv_connector_extra_config - ): + ucm_config = Config(vllm_config.kv_transfer_config) + ucm_sparse_config = ucm_config.get_config().get("ucm_sparse_config") + if not ucm_sparse_config: return - sparse_method_name = vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_sparse_config" - ] + sparse_method_name = ucm_sparse_config if _UCM_SPARSE_AGENT is None: logger.info("Initializing UCM sparse agent with method: %s", sparse_method_name) diff --git a/ucm/sparse/utils.py b/ucm/sparse/utils.py index 168ae4777..358f71a37 100644 --- a/ucm/sparse/utils.py +++ b/ucm/sparse/utils.py @@ -10,7 +10,44 @@ CUDA_TOPK = False PTOPK_PREFETCH_ENABLE = False VLLM_CUDA_MEM_ALIGN_KV_CACHE = False -LOCAL_WINDOW_SZ = MIN_TOPK_LEN - 1 +INIT_WINDOW_SZ = 1 +NUM_PREFETCH_BLOCKS = 1 +NUM_GSA_BLOCKS = 1 + + +class GSAConfig: + def __init__(self): + self.block_size = DEFAULT_BLOCK_SIZE + self.init_windows_size = INIT_WINDOW_SZ + self.num_prefetch_blocks = NUM_PREFETCH_BLOCKS + self.min_topk_len = MIN_TOPK_LEN + self.max_topk_len = MAX_TOPK_LEN + + def set_config(self, block_szie): + self.block_size = block_szie + self.min_topk_len = math.ceil(MIN_TOPK_LEN * DEFAULT_BLOCK_SIZE / block_szie) + self.max_topk_len = math.ceil(MAX_TOPK_LEN * DEFAULT_BLOCK_SIZE / block_szie) + self.num_prefetch_blocks = math.ceil( + NUM_PREFETCH_BLOCKS * DEFAULT_BLOCK_SIZE / block_szie + ) + 
self.init_windows_size = math.ceil( + INIT_WINDOW_SZ * DEFAULT_BLOCK_SIZE / block_szie + ) + self.num_gsa_blocks = math.ceil( + NUM_GSA_BLOCKS * DEFAULT_BLOCK_SIZE / block_szie + ) + + def compute_topk_len(self, raw_seq_len): + topk_len = math.ceil(raw_seq_len * 0.3) + # topk_len = max(1, topk_len) + if topk_len < self.min_topk_len: + topk_len = min(self.min_topk_len, raw_seq_len) + elif topk_len > self.max_topk_len: + topk_len = self.max_topk_len + return topk_len + + +gsa_config = GSAConfig() def round_up(x: int, y: int) -> int: @@ -25,12 +62,3 @@ def align_to_256bytes(extent: int, dtype: torch.dtype) -> int: dtype_szie = get_type_size(dtype) eles_per_256bytes = 256 // dtype_szie return round_up(extent, eles_per_256bytes) - - -def compute_topk_len(raw_seq_len): - topk_len = int(raw_seq_len * 0.3) - if topk_len < MIN_TOPK_LEN: - topk_len = min(MIN_TOPK_LEN, raw_seq_len) - elif topk_len > MAX_TOPK_LEN: - topk_len = MAX_TOPK_LEN - return topk_len diff --git a/ucm/store/CMakeLists.txt b/ucm/store/CMakeLists.txt index 194b49b9a..c3825360d 100644 --- a/ucm/store/CMakeLists.txt +++ b/ucm/store/CMakeLists.txt @@ -1,12 +1,10 @@ -option(DOWNLOAD_DEPENDENCE "download dependence by cmake." ON) -set(LOGGER_BACKEND "spdlog" CACHE STRING "backend: spdlog or flux.") - include_directories(.) -add_subdirectory(vendor) add_subdirectory(infra) add_subdirectory(device) add_subdirectory(nfsstore) +add_subdirectory(pcstore) add_subdirectory(dramstore) add_subdirectory(localstore) add_subdirectory(mooncakestore) +add_subdirectory(task) add_subdirectory(test) diff --git a/ucm/store/device/CMakeLists.txt b/ucm/store/device/CMakeLists.txt index eef792ae9..e184705a9 100644 --- a/ucm/store/device/CMakeLists.txt +++ b/ucm/store/device/CMakeLists.txt @@ -1,10 +1,17 @@ if(RUNTIME_ENVIRONMENT STREQUAL "ascend") add_subdirectory(ascend) -endif() -if(RUNTIME_ENVIRONMENT STREQUAL "cuda") +elseif(RUNTIME_ENVIRONMENT STREQUAL "musa") + add_subdirectory(musa) +elseif(RUNTIME_ENVIRONMENT STREQUAL "cuda") add_subdirectory(cuda) -endif() -if(RUNTIME_ENVIRONMENT STREQUAL "simu") +elseif(RUNTIME_ENVIRONMENT STREQUAL "simu") add_subdirectory(simu) +else() + message(FATAL_ERROR "RUNTIME_ENVIRONMENT must be one of: ascend, musa, cuda, simu. Current value: ${RUNTIME_ENVIRONMENT}") +endif() + +if(TARGET storedevice) + target_include_directories(storedevice PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +else() + message(FATAL_ERROR "storedevice target was not created. 
Check RUNTIME_ENVIRONMENT setting and subdirectory CMakeLists.txt files.") endif() -target_include_directories(storedevice PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/ucm/store/device/ascend/ascend_device.cc b/ucm/store/device/ascend/ascend_device.cc index 7a40fa0fa..36062c596 100644 --- a/ucm/store/device/ascend/ascend_device.cc +++ b/ucm/store/device/ascend/ascend_device.cc @@ -91,6 +91,14 @@ class AscendDevice : public IBufferedDevice { } return Status::OK(); } + Status H2DSync(std::byte* dst, const std::byte* src, const size_t count) override + { + return ASCEND_API(aclrtMemcpy, dst, count, src, count, ACL_MEMCPY_HOST_TO_DEVICE); + } + Status D2HSync(std::byte* dst, const std::byte* src, const size_t count) override + { + return ASCEND_API(aclrtMemcpy, dst, count, src, count, ACL_MEMCPY_DEVICE_TO_HOST); + } Status H2DAsync(std::byte* dst, const std::byte* src, const size_t count) override { return ASCEND_API(aclrtMemcpyAsync, dst, count, src, count, ACL_MEMCPY_HOST_TO_DEVICE, @@ -111,6 +119,25 @@ class AscendDevice : public IBufferedDevice { return ASCEND_API(aclrtLaunchCallback, Trampoline, (void*)c, ACL_CALLBACK_NO_BLOCK, this->stream_); } + Status Synchronized() override { return ASCEND_API(aclrtSynchronizeStream, this->stream_); } + Status H2DBatchSync(std::byte* dArr[], const std::byte* hArr[], const size_t number, + const size_t count) override + { + for (size_t i = 0; i < number; i++) { + auto status = this->H2DSync(dArr[i], hArr[i], count); + if (status.Failure()) { return status; } + } + return Status::OK(); + } + Status D2HBatchSync(std::byte* hArr[], const std::byte* dArr[], const size_t number, + const size_t count) override + { + for (size_t i = 0; i < number; i++) { + auto status = this->D2HSync(hArr[i], dArr[i], count); + if (status.Failure()) { return status; } + } + return Status::OK(); + } protected: std::shared_ptr MakeBuffer(const size_t size) override diff --git a/ucm/store/device/cuda/CMakeLists.txt b/ucm/store/device/cuda/CMakeLists.txt index bf43524db..fa0db292d 100644 --- a/ucm/store/device/cuda/CMakeLists.txt +++ b/ucm/store/device/cuda/CMakeLists.txt @@ -1,9 +1,10 @@ set(CUDA_ROOT "/usr/local/cuda/" CACHE PATH "Path to CUDA root directory") -add_library(Cuda::cudart UNKNOWN IMPORTED) -set_target_properties(Cuda::cudart PROPERTIES - INTERFACE_INCLUDE_DIRECTORIES "${CUDA_ROOT}/include" - IMPORTED_LOCATION "${CUDA_ROOT}/lib64/libcudart.so" +set(CMAKE_CUDA_COMPILER ${CUDA_ROOT}/bin/nvcc) +set(CMAKE_CUDA_ARCHITECTURES 75 80 86 89 90) +enable_language(CUDA) +add_library(storedevice STATIC cuda_device.cu) +target_link_libraries(storedevice PUBLIC storeinfra) +target_compile_options(storedevice PRIVATE + --diag-suppress=128 --diag-suppress=2417 --diag-suppress=2597 + -Wall -fPIC ) - -add_library(storedevice STATIC cuda_device.cc) -target_link_libraries(storedevice PUBLIC storeinfra Cuda::cudart) diff --git a/ucm/store/device/cuda/cuda_device.cu b/ucm/store/device/cuda/cuda_device.cu new file mode 100644 index 000000000..235b860cb --- /dev/null +++ b/ucm/store/device/cuda/cuda_device.cu @@ -0,0 +1,241 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#include +#include "ibuffered_device.h" +#include "logger/logger.h" + +#define CUDA_TRANS_UNIT_SIZE (sizeof(uint64_t) * 2) +#define CUDA_TRANS_BLOCK_NUMBER (32) +#define CUDA_TRANS_BLOCK_SIZE (256) +#define CUDA_TRANS_THREAD_NUMBER (CUDA_TRANS_BLOCK_NUMBER * CUDA_TRANS_BLOCK_SIZE) + +inline __device__ void H2DUnit(uint8_t* __restrict__ dst, const volatile uint8_t* __restrict__ src) +{ + uint64_t a, b; + asm volatile("ld.global.cs.v2.u64 {%0, %1}, [%2];" : "=l"(a), "=l"(b) : "l"(src)); + asm volatile("st.global.cg.v2.u64 [%0], {%1, %2};" ::"l"(dst), "l"(a), "l"(b)); +} + +inline __device__ void D2HUnit(volatile uint8_t* __restrict__ dst, const uint8_t* __restrict__ src) +{ + uint64_t a, b; + asm volatile("ld.global.cs.v2.u64 {%0, %1}, [%2];" : "=l"(a), "=l"(b) : "l"(src)); + asm volatile("st.volatile.global.v2.u64 [%0], {%1, %2};" ::"l"(dst), "l"(a), "l"(b)); +} + +__global__ void H2DKernel(uintptr_t* dst, const volatile uintptr_t* src, size_t num, size_t size) +{ + auto length = num * size; + auto offset = (blockIdx.x * blockDim.x + threadIdx.x) * CUDA_TRANS_UNIT_SIZE; + while (offset + CUDA_TRANS_UNIT_SIZE <= length) { + auto idx = offset / size; + auto off = offset % size; + H2DUnit(((uint8_t*)dst[idx]) + off, ((const uint8_t*)src[idx]) + off); + offset += CUDA_TRANS_THREAD_NUMBER * CUDA_TRANS_UNIT_SIZE; + } +} + +__global__ void D2HKernel(volatile uintptr_t* dst, const uintptr_t* src, size_t num, size_t size) +{ + auto length = num * size; + auto offset = (blockIdx.x * blockDim.x + threadIdx.x) * CUDA_TRANS_UNIT_SIZE; + while (offset + CUDA_TRANS_UNIT_SIZE <= length) { + auto idx = offset / size; + auto off = offset % size; + D2HUnit(((uint8_t*)dst[idx]) + off, ((const uint8_t*)src[idx]) + off); + offset += CUDA_TRANS_THREAD_NUMBER * CUDA_TRANS_UNIT_SIZE; + } +} + +inline __host__ void H2DBatch(uintptr_t* dst, const volatile uintptr_t* src, size_t num, + size_t size, cudaStream_t stream) +{ + H2DKernel<<>>(dst, src, num, size); +} + +inline __host__ void D2HBatch(volatile uintptr_t* dst, const uintptr_t* src, size_t num, + size_t size, cudaStream_t stream) +{ + D2HKernel<<>>(dst, src, num, size); +} + +template <> +struct fmt::formatter : formatter { + auto format(cudaError_t err, format_context& ctx) const -> format_context::iterator + { + return formatter::format(err, ctx); + } +}; + +namespace UC { + +template +Status CudaApi(const char* caller, const char* file, const size_t line, const char* name, 
Api&& api, + Args&&... args) +{ + auto ret = std::invoke(api, args...); + if (ret != cudaSuccess) { + UC_ERROR("CUDA ERROR: api={}, code={}, err={}, caller={},{}:{}.", name, ret, + cudaGetErrorString(ret), caller, basename(file), line); + return Status::OsApiError(); + } + return Status::OK(); +} +#define CUDA_API(api, ...) CudaApi(__FUNCTION__, __FILE__, __LINE__, #api, api, __VA_ARGS__) + +class CudaDevice : public IBufferedDevice { + struct Closure { + std::function cb; + explicit Closure(std::function cb) : cb{cb} {} + }; + + static void Trampoline(cudaStream_t stream, cudaError_t ret, void* data) + { + (void)stream; + auto c = (Closure*)data; + c->cb(ret == cudaSuccess); + delete c; + } + static void* MakeDeviceArray(const void* hostArray[], const size_t number) + { + auto size = sizeof(void*) * number; + void* deviceArray = nullptr; + auto ret = cudaMalloc(&deviceArray, size); + if (ret != cudaSuccess) { + UC_ERROR("Failed({},{}) to alloc({}) on device.", ret, cudaGetErrorString(ret), size); + return nullptr; + } + if (CUDA_API(cudaMemcpy, deviceArray, hostArray, size, cudaMemcpyHostToDevice).Success()) { + return deviceArray; + } + ReleaseDeviceArray(deviceArray); + return nullptr; + } + static void ReleaseDeviceArray(void* deviceArray) { CUDA_API(cudaFree, deviceArray); } + +public: + CudaDevice(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber) + : IBufferedDevice{deviceId, bufferSize, bufferNumber}, stream_{nullptr} + { + } + Status Setup() override + { + auto status = Status::OK(); + if ((status = CUDA_API(cudaSetDevice, this->deviceId)).Failure()) { return status; } + if ((status = IBufferedDevice::Setup()).Failure()) { return status; } + if ((status = CUDA_API(cudaStreamCreate, (cudaStream_t*)&this->stream_)).Failure()) { + return status; + } + return status; + } + virtual Status H2DSync(std::byte* dst, const std::byte* src, const size_t count) override + { + return CUDA_API(cudaMemcpy, dst, src, count, cudaMemcpyHostToDevice); + } + virtual Status D2HSync(std::byte* dst, const std::byte* src, const size_t count) override + { + return CUDA_API(cudaMemcpy, dst, src, count, cudaMemcpyDeviceToHost); + } + Status H2DAsync(std::byte* dst, const std::byte* src, const size_t count) override + { + return CUDA_API(cudaMemcpyAsync, dst, src, count, cudaMemcpyHostToDevice, this->stream_); + } + Status D2HAsync(std::byte* dst, const std::byte* src, const size_t count) override + { + return CUDA_API(cudaMemcpyAsync, dst, src, count, cudaMemcpyDeviceToHost, this->stream_); + } + Status AppendCallback(std::function cb) override + { + auto* c = new (std::nothrow) Closure(cb); + if (!c) { + UC_ERROR("Failed to make closure for append cb."); + return Status::OutOfMemory(); + } + auto status = CUDA_API(cudaStreamAddCallback, this->stream_, Trampoline, (void*)c, 0); + if (status.Failure()) { delete c; } + return status; + } + Status Synchronized() override { return CUDA_API(cudaStreamSynchronize, this->stream_); } + Status H2DBatchSync(std::byte* dArr[], const std::byte* hArr[], const size_t number, + const size_t count) override + { + auto src = MakeDeviceArray((const void**)hArr, number); + if (!src) { return Status::OutOfMemory(); } + auto dst = MakeDeviceArray((const void**)dArr, number); + if (!dst) { + ReleaseDeviceArray(src); + return Status::OutOfMemory(); + } + H2DBatch((uintptr_t*)dst, (const volatile uintptr_t*)src, number, count, this->stream_); + auto status = this->Synchronized(); + ReleaseDeviceArray(src); + ReleaseDeviceArray(dst); + return status; + } + 
Status D2HBatchSync(std::byte* hArr[], const std::byte* dArr[], const size_t number, + const size_t count) override + { + auto src = MakeDeviceArray((const void**)dArr, number); + if (!src) { return Status::OutOfMemory(); } + auto dst = MakeDeviceArray((const void**)hArr, number); + if (!dst) { + ReleaseDeviceArray(src); + return Status::OutOfMemory(); + } + D2HBatch((volatile uintptr_t*)dst, (const uintptr_t*)src, number, count, this->stream_); + auto status = this->Synchronized(); + ReleaseDeviceArray(src); + ReleaseDeviceArray(dst); + return status; + } + +protected: + std::shared_ptr MakeBuffer(const size_t size) override + { + std::byte* host = nullptr; + auto ret = cudaMallocHost((void**)&host, size); + if (ret != cudaSuccess) { + UC_ERROR("CUDA ERROR: api=cudaMallocHost, code={}.", ret); + return nullptr; + } + return std::shared_ptr(host, cudaFreeHost); + } + +private: + cudaStream_t stream_; +}; + +std::unique_ptr DeviceFactory::Make(const int32_t deviceId, const size_t bufferSize, + const size_t bufferNumber) +{ + try { + return std::make_unique(deviceId, bufferSize, bufferNumber); + } catch (const std::exception& e) { + UC_ERROR("Failed({}) to make cuda device({},{},{}).", e.what(), deviceId, bufferSize, + bufferNumber); + return nullptr; + } +} + +} // namespace UC diff --git a/ucm/store/device/ibuffered_device.h b/ucm/store/device/ibuffered_device.h index 532817056..a56ce67ac 100644 --- a/ucm/store/device/ibuffered_device.h +++ b/ucm/store/device/ibuffered_device.h @@ -25,11 +25,37 @@ #define UNIFIEDCACHE_IBUFFERED_DEVICE_H #include "idevice.h" -#include "thread/index_pool.h" namespace UC { class IBufferedDevice : public IDevice { + class LinearBuffer { + std::shared_ptr addr_{nullptr}; + size_t index_{0}; + size_t number_{0}; + size_t size_{0}; + + public: + void Setup(std::shared_ptr addr, const size_t number, const size_t size) + { + this->addr_ = addr; + this->number_ = number; + this->size_ = size; + this->Reset(); + } + void Reset() noexcept { this->index_ = 0; } + bool Full() const noexcept { return this->index_ == this->number_; } + bool Available(const size_t size) const noexcept { return this->size_ >= size; } + std::shared_ptr Get() noexcept + { + auto addr = this->addr_.get(); + auto buffer = addr + this->size_ * this->index_; + ++this->index_; + return std::shared_ptr(buffer, [](auto) {}); + } + }; + LinearBuffer buffer_; + public: IBufferedDevice(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber) : IDevice{deviceId, bufferSize, bufferNumber} @@ -38,29 +64,21 @@ class IBufferedDevice : public IDevice { Status Setup() override { auto totalSize = this->bufferSize * this->bufferNumber; - this->_addr = this->MakeBuffer(totalSize); - if (!this->_addr) { return Status::OutOfMemory(); } - this->_indexPool.Setup(this->bufferNumber); + if (totalSize == 0) { return Status::OK(); } + auto addr = this->MakeBuffer(totalSize); + if (!addr) { return Status::OutOfMemory(); } + this->buffer_.Setup(addr, this->bufferNumber, this->bufferSize); return Status::OK(); } virtual std::shared_ptr GetBuffer(const size_t size) override { - auto idx = IndexPool::npos; - if (size <= this->bufferSize && (idx = this->_indexPool.Acquire()) != IndexPool::npos) { - auto ptr = this->_addr.get() + this->bufferSize * idx; - return std::shared_ptr( - ptr, [this, idx](std::byte*) { this->_indexPool.Release(idx); }); + if (this->buffer_.Full()) { + auto status = this->Synchronized(); + if (status.Failure()) { return nullptr; } + this->buffer_.Reset(); } - auto buffer = 
this->MakeBuffer(size); - if (buffer) { return buffer; } - auto host = (std::byte*)malloc(size); - if (host) { return std::shared_ptr(host, free); } - return nullptr; + return this->buffer_.Available(size) ? this->buffer_.Get() : this->MakeBuffer(size); } - -private: - std::shared_ptr _addr; - IndexPool _indexPool; }; } // namespace UC diff --git a/ucm/store/device/idevice.h b/ucm/store/device/idevice.h index 993fb5479..8670df3bf 100644 --- a/ucm/store/device/idevice.h +++ b/ucm/store/device/idevice.h @@ -39,9 +39,16 @@ class IDevice { virtual ~IDevice() = default; virtual Status Setup() = 0; virtual std::shared_ptr GetBuffer(const size_t size) = 0; + virtual Status H2DSync(std::byte* dst, const std::byte* src, const size_t count) = 0; + virtual Status D2HSync(std::byte* dst, const std::byte* src, const size_t count) = 0; virtual Status H2DAsync(std::byte* dst, const std::byte* src, const size_t count) = 0; virtual Status D2HAsync(std::byte* dst, const std::byte* src, const size_t count) = 0; virtual Status AppendCallback(std::function cb) = 0; + virtual Status Synchronized() = 0; + virtual Status H2DBatchSync(std::byte* dArr[], const std::byte* hArr[], const size_t number, + const size_t count) = 0; + virtual Status D2HBatchSync(std::byte* hArr[], const std::byte* dArr[], const size_t number, + const size_t count) = 0; protected: virtual std::shared_ptr MakeBuffer(const size_t size) = 0; diff --git a/ucm/store/device/musa/CMakeLists.txt b/ucm/store/device/musa/CMakeLists.txt new file mode 100644 index 000000000..2e1ff3a75 --- /dev/null +++ b/ucm/store/device/musa/CMakeLists.txt @@ -0,0 +1,9 @@ +set(MUSA_ROOT "/usr/local/musa/" CACHE PATH "Path to MUSA root directory") +add_library(Musa::musart UNKNOWN IMPORTED) +set_target_properties(Musa::musart PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${MUSA_ROOT}/include" + IMPORTED_LOCATION "${MUSA_ROOT}/lib/libmusart.so" +) + +add_library(storedevice STATIC musa_device.cc) +target_link_libraries(storedevice PUBLIC storeinfra Musa::musart) diff --git a/ucm/store/device/cuda/cuda_device.cc b/ucm/store/device/musa/musa_device.cc similarity index 55% rename from ucm/store/device/cuda/cuda_device.cc rename to ucm/store/device/musa/musa_device.cc index 5e0c32865..66ad88d5c 100644 --- a/ucm/store/device/cuda/cuda_device.cc +++ b/ucm/store/device/musa/musa_device.cc @@ -1,7 +1,7 @@ /** * MIT License * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (c) 2025 MThreads Technologies Co., Ltd. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -21,13 +21,13 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#include +#include #include "ibuffered_device.h" #include "logger/logger.h" template <> -struct fmt::formatter : formatter { - auto format(cudaError_t err, format_context& ctx) const -> format_context::iterator +struct fmt::formatter : formatter { + auto format(musaError_t err, format_context& ctx) const -> format_context::iterator { return formatter::format(err, ctx); } @@ -36,57 +36,66 @@ struct fmt::formatter : formatter { namespace UC { template -Status CudaApi(const char* caller, const char* file, const size_t line, const char* name, Api&& api, +Status MusaApi(const char* caller, const char* file, const size_t line, const char* name, Api&& api, Args&&... 
args) { auto ret = api(args...); - if (ret != cudaSuccess) { - UC_ERROR("CUDA ERROR: api={}, code={}, err={}, caller={},{}:{}.", name, ret, - cudaGetErrorString(ret), caller, basename(file), line); + if (ret != musaSuccess) { + UC_ERROR("MUSA ERROR: api={}, code={}, err={}, caller={},{}:{}.", name, ret, + musaGetErrorString(ret), caller, basename(file), line); return Status::OsApiError(); } return Status::OK(); } -#define CUDA_API(api, ...) CudaApi(__FUNCTION__, __FILE__, __LINE__, #api, api, __VA_ARGS__) +#define MUSA_API(api, ...) MusaApi(__FUNCTION__, __FILE__, __LINE__, #api, api, __VA_ARGS__) -class CudaDevice : public IBufferedDevice { +class MusaDevice : public IBufferedDevice { struct Closure { std::function cb; explicit Closure(std::function cb) : cb{cb} {} }; - static void Trampoline(cudaStream_t stream, cudaError_t ret, void* data) + static void Trampoline(musaStream_t stream, musaError_t ret, void* data) { (void)stream; auto c = (Closure*)data; - c->cb(ret == cudaSuccess); + c->cb(ret == musaSuccess); delete c; } public: - CudaDevice(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber) + MusaDevice(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber) : IBufferedDevice{deviceId, bufferSize, bufferNumber}, stream_{nullptr} { } Status Setup() override { auto status = Status::OK(); - if ((status = CUDA_API(cudaSetDevice, this->deviceId)).Failure()) { return status; } + if ((status = MUSA_API(musaSetDevice, this->deviceId)).Failure()) { return status; } if ((status = IBufferedDevice::Setup()).Failure()) { return status; } - if ((status = CUDA_API(cudaStreamCreate, (cudaStream_t*)&this->stream_)).Failure()) { + if ((status = MUSA_API(musaStreamCreate, (musaStream_t*)&this->stream_)).Failure()) { return status; } return status; } + Status H2DSync(std::byte* dst, const std::byte* src, const size_t count) override + { + return MUSA_API(musaMemcpy, dst, src, count, musaMemcpyHostToDevice); + } + Status D2HSync(std::byte* dst, const std::byte* src, const size_t count) override + { + return MUSA_API(musaMemcpy, dst, src, count, musaMemcpyDeviceToHost); + } + Status H2DAsync(std::byte* dst, const std::byte* src, const size_t count) override { - return CUDA_API(cudaMemcpyAsync, dst, src, count, cudaMemcpyHostToDevice, - (cudaStream_t)this->stream_); + return MUSA_API(musaMemcpyAsync, dst, src, count, musaMemcpyHostToDevice, + (musaStream_t)this->stream_); } Status D2HAsync(std::byte* dst, const std::byte* src, const size_t count) override { - return CUDA_API(cudaMemcpyAsync, dst, src, count, cudaMemcpyDeviceToHost, - (cudaStream_t)this->stream_); + return MUSA_API(musaMemcpyAsync, dst, src, count, musaMemcpyDeviceToHost, + (musaStream_t)this->stream_); } Status AppendCallback(std::function cb) override { @@ -96,34 +105,56 @@ class CudaDevice : public IBufferedDevice { return Status::OutOfMemory(); } auto status = - CUDA_API(cudaStreamAddCallback, (cudaStream_t)this->stream_, Trampoline, (void*)c, 0); + MUSA_API(musaStreamAddCallback, (musaStream_t)this->stream_, Trampoline, (void*)c, 0); if (status.Failure()) { delete c; } return status; } + Status Synchronized() override { return MUSA_API(musaStreamSynchronize, this->stream_); } + + Status H2DBatchSync(std::byte* dArr[], const std::byte* hArr[], const size_t number, + const size_t count) override + { + for (size_t i = 0; i < number; i++) { + auto status = this->H2DSync(dArr[i], hArr[i], count); + if (status.Failure()) { return status; } + } + return Status::OK(); + } + Status 
D2HBatchSync(std::byte* hArr[], const std::byte* dArr[], const size_t number, + const size_t count) override + { + for (size_t i = 0; i < number; i++) { + auto status = this->D2HSync(hArr[i], dArr[i], count); + if (status.Failure()) { return status; } + } + return Status::OK(); + } + protected: std::shared_ptr MakeBuffer(const size_t size) override { std::byte* host = nullptr; - auto ret = cudaMallocHost((void**)&host, size); - if (ret != cudaSuccess) { - UC_ERROR("CUDA ERROR: api=cudaMallocHost, code={}.", ret); + auto ret = musaMallocHost((void**)&host, size); + if (ret != musaSuccess) { + UC_ERROR("MUSA ERROR: api=musaMallocHost, code={}.", ret); return nullptr; } - return std::shared_ptr(host, cudaFreeHost); + return std::shared_ptr(host, musaFreeHost); } private: - void* stream_; + musaStream_t stream_; + }; std::unique_ptr DeviceFactory::Make(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber) { try { - return std::make_unique(deviceId, bufferSize, bufferNumber); + return std::make_unique(deviceId, bufferSize, bufferNumber); } catch (const std::exception& e) { - UC_ERROR("Failed({}) to make cuda device({},{},{}).", e.what(), deviceId, bufferSize, + UC_ERROR("Failed({}) to make musa device({},{},{}).", e.what(), deviceId, bufferSize, bufferNumber); return nullptr; } diff --git a/ucm/store/device/musa/musa_device.mu b/ucm/store/device/musa/musa_device.mu new file mode 100644 index 000000000..81d46245a --- /dev/null +++ b/ucm/store/device/musa/musa_device.mu @@ -0,0 +1,241 @@ +/** + * MIT License + * + * Copyright (c) 2025 Mthreads Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include +#include "ibuffered_device.h" +#include "logger/logger.h" + +#define MUSA_TRANS_UNIT_SIZE (sizeof(uint64_t) * 2) +#define MUSA_TRANS_BLOCK_NUMBER (32) +#define MUSA_TRANS_BLOCK_SIZE (256) +#define MUSA_TRANS_THREAD_NUMBER (MUSA_TRANS_BLOCK_NUMBER * MUSA_TRANS_BLOCK_SIZE) + +inline __device__ void H2DUnit(uint8_t* __restrict__ dst, const volatile uint8_t* __restrict__ src) +{ + uint64_t a, b; + asm volatile("ld.global.cs.v2.u64 {%0, %1}, [%2];" : "=l"(a), "=l"(b) : "l"(src)); + asm volatile("st.global.cg.v2.u64 [%0], {%1, %2};" ::"l"(dst), "l"(a), "l"(b)); +} + +inline __device__ void D2HUnit(volatile uint8_t* __restrict__ dst, const uint8_t* __restrict__ src) +{ + uint64_t a, b; + asm volatile("ld.global.cs.v2.u64 {%0, %1}, [%2];" : "=l"(a), "=l"(b) : "l"(src)); + asm volatile("st.volatile.global.v2.u64 [%0], {%1, %2};" ::"l"(dst), "l"(a), "l"(b)); +} + +__global__ void H2DKernel(uintptr_t* dst, const volatile uintptr_t* src, size_t num, size_t size) +{ + auto length = num * size; + auto offset = (blockIdx.x * blockDim.x + threadIdx.x) * MUSA_TRANS_UNIT_SIZE; + while (offset + MUSA_TRANS_UNIT_SIZE <= length) { + auto idx = offset / size; + auto off = offset % size; + H2DUnit(((uint8_t*)dst[idx]) + off, ((const uint8_t*)src[idx]) + off); + offset += MUSA_TRANS_THREAD_NUMBER * MUSA_TRANS_UNIT_SIZE; + } +} + +__global__ void D2HKernel(volatile uintptr_t* dst, const uintptr_t* src, size_t num, size_t size) +{ + auto length = num * size; + auto offset = (blockIdx.x * blockDim.x + threadIdx.x) * MUSA_TRANS_UNIT_SIZE; + while (offset + MUSA_TRANS_UNIT_SIZE <= length) { + auto idx = offset / size; + auto off = offset % size; + D2HUnit(((uint8_t*)dst[idx]) + off, ((const uint8_t*)src[idx]) + off); + offset += MUSA_TRANS_THREAD_NUMBER * MUSA_TRANS_UNIT_SIZE; + } +} + +inline __host__ void H2DBatch(uintptr_t* dst, const volatile uintptr_t* src, size_t num, + size_t size, musaStream_t stream) +{ + H2DKernel<<>>(dst, src, num, size); +} + +inline __host__ void D2HBatch(volatile uintptr_t* dst, const uintptr_t* src, size_t num, + size_t size, musaStream_t stream) +{ + D2HKernel<<>>(dst, src, num, size); +} + +template <> +struct fmt::formatter : formatter { + auto format(musaError_t err, format_context& ctx) const -> format_context::iterator + { + return formatter::format(err, ctx); + } +}; + +namespace UC { + +template +Status MusaApi(const char* caller, const char* file, const size_t line, const char* name, Api&& api, + Args&&... args) +{ + auto ret = std::invoke(api, args...); + if (ret != musaSuccess) { + UC_ERROR("MUSA ERROR: api={}, code={}, err={}, caller={},{}:{}.", name, ret, + musaGetErrorString(ret), caller, basename(file), line); + return Status::OsApiError(); + } + return Status::OK(); +} +#define MUSA_API(api, ...) 
MusaApi(__FUNCTION__, __FILE__, __LINE__, #api, api, __VA_ARGS__) + +class MusaDevice : public IBufferedDevice { + struct Closure { + std::function cb; + explicit Closure(std::function cb) : cb{cb} {} + }; + + static void Trampoline(musaStream_t stream, musaError_t ret, void* data) + { + (void)stream; + auto c = (Closure*)data; + c->cb(ret == musaSuccess); + delete c; + } + static void* MakeDeviceArray(const void* hostArray[], const size_t number) + { + auto size = sizeof(void*) * number; + void* deviceArray = nullptr; + auto ret = musaMalloc(&deviceArray, size); + if (ret != musaSuccess) { + UC_ERROR("Failed({},{}) to alloc({}) on device.", ret, musaGetErrorString(ret), size); + return nullptr; + } + if (MUSA_API(musaMemcpy, deviceArray, hostArray, size, musaMemcpyHostToDevice).Success()) { + return deviceArray; + } + ReleaseDeviceArray(deviceArray); + return nullptr; + } + static void ReleaseDeviceArray(void* deviceArray) { MUSA_API(musaFree, deviceArray); } + +public: + MusaDevice(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber) + : IBufferedDevice{deviceId, bufferSize, bufferNumber}, stream_{nullptr} + { + } + Status Setup() override + { + auto status = Status::OK(); + if ((status = MUSA_API(musaSetDevice, this->deviceId)).Failure()) { return status; } + if ((status = IBufferedDevice::Setup()).Failure()) { return status; } + if ((status = MUSA_API(musaStreamCreate, (musaStream_t*)&this->stream_)).Failure()) { + return status; + } + return status; + } + virtual Status H2DSync(std::byte* dst, const std::byte* src, const size_t count) override + { + return MUSA_API(musaMemcpy, dst, src, count, musaMemcpyHostToDevice); + } + virtual Status D2HSync(std::byte* dst, const std::byte* src, const size_t count) override + { + return MUSA_API(musaMemcpy, dst, src, count, musaMemcpyDeviceToHost); + } + Status H2DAsync(std::byte* dst, const std::byte* src, const size_t count) override + { + return MUSA_API(musaMemcpyAsync, dst, src, count, musaMemcpyHostToDevice, this->stream_); + } + Status D2HAsync(std::byte* dst, const std::byte* src, const size_t count) override + { + return MUSA_API(musaMemcpyAsync, dst, src, count, musaMemcpyDeviceToHost, this->stream_); + } + Status AppendCallback(std::function cb) override + { + auto* c = new (std::nothrow) Closure(cb); + if (!c) { + UC_ERROR("Failed to make closure for append cb."); + return Status::OutOfMemory(); + } + auto status = MUSA_API(musaStreamAddCallback, this->stream_, Trampoline, (void*)c, 0); + if (status.Failure()) { delete c; } + return status; + } + Status Synchronized() override { return MUSA_API(musaStreamSynchronize, this->stream_); } + Status H2DBatchSync(std::byte* dArr[], const std::byte* hArr[], const size_t number, + const size_t count) override + { + auto src = MakeDeviceArray((const void**)hArr, number); + if (!src) { return Status::OutOfMemory(); } + auto dst = MakeDeviceArray((const void**)dArr, number); + if (!dst) { + ReleaseDeviceArray(src); + return Status::OutOfMemory(); + } + H2DBatch((uintptr_t*)dst, (const volatile uintptr_t*)src, number, count, this->stream_); + auto status = this->Synchronized(); + ReleaseDeviceArray(src); + ReleaseDeviceArray(dst); + return status; + } + Status D2HBatchSync(std::byte* hArr[], const std::byte* dArr[], const size_t number, + const size_t count) override + { + auto src = MakeDeviceArray((const void**)dArr, number); + if (!src) { return Status::OutOfMemory(); } + auto dst = MakeDeviceArray((const void**)hArr, number); + if (!dst) { + ReleaseDeviceArray(src); + 
return Status::OutOfMemory(); + } + D2HBatch((volatile uintptr_t*)dst, (const uintptr_t*)src, number, count, this->stream_); + auto status = this->Synchronized(); + ReleaseDeviceArray(src); + ReleaseDeviceArray(dst); + return status; + } + +protected: + std::shared_ptr MakeBuffer(const size_t size) override + { + std::byte* host = nullptr; + auto ret = musaMallocHost((void**)&host, size); + if (ret != musaSuccess) { + UC_ERROR("MUSA ERROR: api=musaMallocHost, code={}.", ret); + return nullptr; + } + return std::shared_ptr(host, musaFreeHost); + } + +private: + musaStream_t stream_; +}; + +std::unique_ptr DeviceFactory::Make(const int32_t deviceId, const size_t bufferSize, + const size_t bufferNumber) +{ + try { + return std::make_unique(deviceId, bufferSize, bufferNumber); + } catch (const std::exception& e) { + UC_ERROR("Failed({}) to make musa device({},{},{}).", e.what(), deviceId, bufferSize, + bufferNumber); + return nullptr; + } +} + +} // namespace UC diff --git a/ucm/store/device/simu/simu_device.cc b/ucm/store/device/simu/simu_device.cc index 5f3a4369d..a26b71744 100644 --- a/ucm/store/device/simu/simu_device.cc +++ b/ucm/store/device/simu/simu_device.cc @@ -23,6 +23,7 @@ * */ #include "ibuffered_device.h" #include "logger/logger.h" +#include "thread/latch.h" #include "thread/thread_pool.h" namespace UC { @@ -38,7 +39,19 @@ class SimuDevice : public IBufferedDevice { { auto status = IBufferedDevice::Setup(); if (status.Failure()) { return status; } - if (!this->backend_.Setup([](auto& task) { task(); })) { return Status::Error(); } + if (!this->backend_.SetWorkerFn([](auto& task, const auto&) { task(); }).Run()) { + return Status::Error(); + } + return Status::OK(); + } + Status H2DSync(std::byte* dst, const std::byte* src, const size_t count) override + { + std::copy(src, src + count, dst); + return Status::OK(); + } + Status D2HSync(std::byte* dst, const std::byte* src, const size_t count) override + { + std::copy(src, src + count, dst); return Status::OK(); } Status H2DAsync(std::byte* dst, const std::byte* src, const size_t count) override @@ -64,6 +77,31 @@ class SimuDevice : public IBufferedDevice { this->backend_.Push([=] { cb(true); }); return Status::OK(); } + Status Synchronized() override + { + Latch waiter{1}; + this->backend_.Push([&] { waiter.Done(nullptr); }); + waiter.Wait(); + return Status::OK(); + } + Status H2DBatchSync(std::byte* dArr[], const std::byte* hArr[], const size_t number, + const size_t count) override + { + for (size_t i = 0; i < number; i++) { + auto status = this->H2DSync(dArr[i], hArr[i], count); + if (status.Failure()) { return status; } + } + return Status::OK(); + } + Status D2HBatchSync(std::byte* hArr[], const std::byte* dArr[], const size_t number, + const size_t count) override + { + for (size_t i = 0; i < number; i++) { + auto status = this->D2HSync(hArr[i], dArr[i], count); + if (status.Failure()) { return status; } + } + return Status::OK(); + } protected: std::shared_ptr MakeBuffer(const size_t size) override diff --git a/ucm/store/dramstore/CMakeLists.txt b/ucm/store/dramstore/CMakeLists.txt index 53c9bce17..e69de29bb 100644 --- a/ucm/store/dramstore/CMakeLists.txt +++ b/ucm/store/dramstore/CMakeLists.txt @@ -1,9 +0,0 @@ -file(GLOB_RECURSE UCMSTORE_DRAM_CC_SOURCE_FILES "./cc/*.cc") -add_library(dramstore STATIC ${UCMSTORE_DRAM_CC_SOURCE_FILES}) -target_include_directories(dramstore PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cc) -target_link_libraries(dramstore PUBLIC storeinfra) - -file(GLOB_RECURSE UCMSTORE_DRAM_CPY_SOURCE_FILES 
"./cpy/*.cc") -pybind11_add_module(ucmdramstore ${UCMSTORE_DRAM_CPY_SOURCE_FILES}) -target_link_libraries(ucmdramstore PRIVATE dramstore) -set_target_properties(ucmdramstore PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/ucm/store/dramstore/cpy/dramstore.py.cc b/ucm/store/dramstore/cpy/dramstore.py.cc deleted file mode 100644 index 635e9144a..000000000 --- a/ucm/store/dramstore/cpy/dramstore.py.cc +++ /dev/null @@ -1,121 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ -#include "api/dramstore.h" -#include - -namespace py = pybind11; - -namespace UC { - -class DRAMStorePy : public DRAMStore { -public: - void* CCStoreImpl() { return this; } - py::list AllocBatch(const py::list& blocks) - { - py::list results; - for (auto& block : blocks) { results.append(this->Alloc(block.cast())); } - return results; - } - py::list LookupBatch(const py::list& blocks) - { - py::list founds; - for (auto& block : blocks) { founds.append(this->Lookup(block.cast())); } - return founds; - } - void CommitBatch(const py::list& blocks, const bool success) - { - for (auto& block : blocks) { this->Commit(block.cast(), success); } - } - py::tuple CheckPy(const size_t task) - { - auto finish = false; - auto ret = this->Check(task, finish); - return py::make_tuple(ret, finish); - } - size_t Load(const py::list& blockIds, const py::list& offsets, const py::list& addresses, - const py::list& lengths) - { - return this->SubmitPy(blockIds, offsets, addresses, lengths, CCStore::Task::Type::LOAD, - CCStore::Task::Location::DEVICE, "DRAM::H2D"); - } - size_t Dump(const py::list& blockIds, const py::list& offsets, const py::list& addresses, - const py::list& lengths) - { - return this->SubmitPy(blockIds, offsets, addresses, lengths, CCStore::Task::Type::DUMP, - CCStore::Task::Location::DEVICE, "DRAM::D2H"); - } - -private: - size_t SubmitPy(const py::list& blockIds, const py::list& offsets, const py::list& addresses, - const py::list& lengths, const CCStore::Task::Type type, - const CCStore::Task::Location location, const std::string& brief) - { - CCStore::Task task{type, location, brief}; - auto blockId = blockIds.begin(); - auto offset = offsets.begin(); - auto address = addresses.begin(); - auto length = lengths.begin(); - while ((blockId != blockIds.end()) && (offset != offsets.end()) && - (address != addresses.end()) && (length != lengths.end())) { - auto 
ret = task.Append(blockId->cast(), offset->cast(), - address->cast(), length->cast()); - if (ret != 0) { return CCStore::invalidTaskId; } - blockId++; - offset++; - address++; - length++; - } - return this->Submit(std::move(task)); - } -}; - -} // namespace UC - -PYBIND11_MODULE(ucmdramstore, module) -{ - module.attr("project") = UCM_PROJECT_NAME; - module.attr("version") = UCM_PROJECT_VERSION; - module.attr("commit_id") = UCM_COMMIT_ID; - module.attr("build_type") = UCM_BUILD_TYPE; - auto store = py::class_(module, "DRAMStore"); - auto config = py::class_(store, "Config"); - config.def(py::init(), py::arg("ioSize"), py::arg("capacity")); - config.def_readwrite("ioSize", &UC::DRAMStorePy::Config::ioSize); - config.def_readwrite("capacity", &UC::DRAMStorePy::Config::capacity); - config.def_readwrite("deviceId", &UC::DRAMStorePy::Config::deviceId); - store.def(py::init<>()); - store.def("CCStoreImpl", &UC::DRAMStorePy::CCStoreImpl); - store.def("Setup", &UC::DRAMStorePy::Setup); - store.def("Alloc", py::overload_cast(&UC::DRAMStorePy::Alloc)); - store.def("AllocBatch", &UC::DRAMStorePy::AllocBatch); - store.def("Lookup", py::overload_cast(&UC::DRAMStorePy::Lookup)); - store.def("LookupBatch", &UC::DRAMStorePy::LookupBatch); - store.def("Load", &UC::DRAMStorePy::Load); - store.def("Dump", &UC::DRAMStorePy::Dump); - store.def("Wait", &UC::DRAMStorePy::Wait); - store.def("Check", &UC::DRAMStorePy::Check); - store.def("Commit", - py::overload_cast(&UC::DRAMStorePy::Commit)); - store.def("CommitBatch", &UC::DRAMStorePy::CommitBatch); -} diff --git a/ucm/store/dramstore/dramstore_connector.py b/ucm/store/dramstore/dramstore_connector.py index 17d174137..24f4306bc 100644 --- a/ucm/store/dramstore/dramstore_connector.py +++ b/ucm/store/dramstore/dramstore_connector.py @@ -37,6 +37,8 @@ if torch.cuda.is_available(): device = torch.cuda +elif hasattr(torch, "musa") and torch.musa.is_available(): + device = torch.musa elif hasattr(torch, "npu") and torch.npu.is_available(): device = torch.npu else: @@ -154,7 +156,6 @@ def dump( task.task_id = "-1" return task else: - device.current_stream().synchronize() stream = device.Stream() task.event = device.Event(enable_timing=True) with device.stream(stream): @@ -165,6 +166,46 @@ def dump( logger.debug(f"dump block {block_ids} finished.") return task + def fetch_data( + self, + block_ids: List[str], + offset: List[int], + dst_addr: List[int], + size: List[int], + ) -> Task: + """ + load kv cache data to device. + + Args: + block_ids (List[str]): vLLM block hash. + offset(List[int]): tp > 1 scene + dst_addr: List[int]: device tensor addr ptr. + size: List[int]: device tensor size. + Returns: + task(Task). + """ + pass + + def dump_data( + self, + block_ids: List[str], + offset: List[int], + src_addr: List[int], + size: List[int], + ) -> Task: + """ + dump kv cache data from device. + + Args: + block_ids (List[str]): vLLM block hash. + offset(List[int]): tp > 1 scene + src_addr: List[int]: device tensor addr ptr. + size: List[int]: device tensor size. + Returns: + task(Task). + """ + pass + def wait(self, task: DramTask) -> int: """ wait kv cache kv transfer task finished. 
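Note on the connector changes above: the DRAM-store connector only stubs the new `fetch_data`/`dump_data` entry points, so the snippet below is a toy, self-contained sketch of the call contract implied by their docstrings (parallel lists of block ids, per-rank offsets, raw buffer addresses, and sizes). It is not the real DRAM store; the `ToyStore`/`ToyTask` names and the in-memory dict backing are illustrative assumptions only.

```python
# Toy sketch of the fetch_data/dump_data contract (NOT the real DRAM store):
# payloads are keyed by (block_id, offset) and copied to/from raw addresses
# with ctypes, mirroring the List[str]/List[int] signatures stubbed above.
import ctypes
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class ToyTask:
    task_id: str = "0"  # stand-in for the connector's Task handle


@dataclass
class ToyStore:
    blocks: Dict[Tuple[str, int], bytes] = field(default_factory=dict)

    def dump_data(self, block_ids: List[str], offset: List[int],
                  src_addr: List[int], size: List[int]) -> ToyTask:
        # "dump": copy bytes out of the caller-provided addresses into the store
        for bid, off, addr, sz in zip(block_ids, offset, src_addr, size):
            self.blocks[(bid, off)] = ctypes.string_at(addr, sz)
        return ToyTask()

    def fetch_data(self, block_ids: List[str], offset: List[int],
                   dst_addr: List[int], size: List[int]) -> ToyTask:
        # "fetch": copy stored bytes back into the caller-provided addresses
        for bid, off, addr, sz in zip(block_ids, offset, dst_addr, size):
            payload = self.blocks[(bid, off)]
            ctypes.memmove(addr, payload, min(sz, len(payload)))
        return ToyTask()


if __name__ == "__main__":
    store = ToyStore()
    src = ctypes.create_string_buffer(b"kv-block-payload")
    dst = ctypes.create_string_buffer(len(src.raw))
    store.dump_data(["block-0"], [0], [ctypes.addressof(src)], [len(src.raw)])
    store.fetch_data(["block-0"], [0], [ctypes.addressof(dst)], [len(dst.raw)])
    assert dst.raw == src.raw
```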
diff --git a/ucm/store/factory.py b/ucm/store/factory.py index ac4c0569a..8b893cda3 100644 --- a/ucm/store/factory.py +++ b/ucm/store/factory.py @@ -63,6 +63,9 @@ def create_connector(cls, connector_name: str, config: dict) -> UcmKVStoreBase: UcmConnectorFactory.register_connector( "UcmNfsStore", "ucm.store.nfsstore.nfsstore_connector", "UcmNfsStore" ) +UcmConnectorFactory.register_connector( + "UcmPcStore", "ucm.store.pcstore.pcstore_connector", "UcmPcStore" +) UcmConnectorFactory.register_connector( "UcmMooncakeStore", "ucm.store.mooncakestore.mooncake_connector", diff --git a/ucm/store/infra/CMakeLists.txt b/ucm/store/infra/CMakeLists.txt index f3e0ce727..6bc8dc4a4 100644 --- a/ucm/store/infra/CMakeLists.txt +++ b/ucm/store/infra/CMakeLists.txt @@ -1,21 +1,11 @@ add_library(storeinfra STATIC) target_include_directories(storeinfra PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) file(GLOB_RECURSE UCMSTORE_COMMON_FILE_SOURCE_FILES "file/*.cc") -if(LOGGER_BACKEND STREQUAL "spdlog") - file(GLOB_RECURSE UCMSTORE_COMMON_LOGGER_SOURCE_FILES "logger/spdlog/*.cc") -endif() -if(LOGGER_BACKEND STREQUAL "flux") - file(GLOB_RECURSE UCMSTORE_COMMON_LOGGER_SOURCE_FILES "logger/flux/*.cc") -endif() -file(GLOB_RECURSE UCMSTORE_COMMON_STATUS_SOURCE_FILES "status/*.cc") -file(GLOB_RECURSE UCMSTORE_COMMON_TEMPLATE_SOURCE_FILES "template/*.cc") -file(GLOB_RECURSE UCMSTORE_COMMON_THREAD_SOURCE_FILES "thread/*.cc") target_sources(storeinfra PRIVATE ${UCMSTORE_COMMON_FILE_SOURCE_FILES}) -target_sources(storeinfra PRIVATE ${UCMSTORE_COMMON_LOGGER_SOURCE_FILES}) -target_sources(storeinfra PRIVATE ${UCMSTORE_COMMON_STATUS_SOURCE_FILES}) -target_sources(storeinfra PRIVATE ${UCMSTORE_COMMON_TEMPLATE_SOURCE_FILES}) -target_sources(storeinfra PRIVATE ${UCMSTORE_COMMON_THREAD_SOURCE_FILES}) -target_link_libraries(storeinfra PUBLIC fmt) -if(LOGGER_BACKEND STREQUAL "spdlog") - target_link_libraries(storeinfra PUBLIC spdlog) -endif() +target_link_libraries(storeinfra PUBLIC + infra_status + infra_logger + infra_template + infra_thread + infra_time +) diff --git a/ucm/store/infra/file/file.cc b/ucm/store/infra/file/file.cc index af3f5d725..8a52074f0 100644 --- a/ucm/store/infra/file/file.cc +++ b/ucm/store/infra/file/file.cc @@ -53,22 +53,37 @@ Status File::Access(const std::string& path, const int32_t mode) return FileImpl{path}.Access(mode); } +Status File::Stat(const std::string& path, IFile::FileStat& st) +{ + FileImpl file{path}; + auto status = file.Open(IFile::OpenFlag::READ_ONLY); + if (status.Failure()) { return status; } + status = file.Stat(st); + file.Close(); + return status; +} + Status File::Read(const std::string& path, const size_t offset, const size_t length, - uintptr_t address) + uintptr_t address, const bool directIo) { FileImpl file{path}; Status status = Status::OK(); - if ((status = file.Open(IFile::OpenFlag::READ_ONLY)).Failure()) { return status; } + auto flags = directIo ? 
IFile::OpenFlag::READ_ONLY | IFile::OpenFlag::DIRECT + : IFile::OpenFlag::READ_ONLY; + if ((status = file.Open(flags)).Failure()) { return status; } if ((status = file.Read((void*)address, length, offset)).Failure()) { return status; } return status; } Status File::Write(const std::string& path, const size_t offset, const size_t length, - const uintptr_t address) + const uintptr_t address, const bool directIo, const bool create) { FileImpl file{path}; Status status = Status::OK(); - if ((status = file.Open(IFile::OpenFlag::WRITE_ONLY)).Failure()) { return status; } + auto flags = IFile::OpenFlag::WRITE_ONLY; + if (directIo) { flags |= IFile::OpenFlag::DIRECT; } + if (create) { flags |= IFile::OpenFlag::CREATE; } + if ((status = file.Open(flags)).Failure()) { return status; } if ((status = file.Write((const void*)address, length, offset)).Failure()) { return status; } return status; } @@ -77,4 +92,6 @@ void File::MUnmap(void* addr, size_t size) { FileImpl{{}}.MUnmap(addr, size); } void File::ShmUnlink(const std::string& path) { FileImpl{path}.ShmUnlink(); } +void File::Remove(const std::string& path) { FileImpl{path}.Remove(); } + } // namespace UC diff --git a/ucm/store/infra/file/file.h b/ucm/store/infra/file/file.h index 70cf24cc3..086518e3a 100644 --- a/ucm/store/infra/file/file.h +++ b/ucm/store/infra/file/file.h @@ -36,12 +36,15 @@ class File { static Status RmDir(const std::string& path); static Status Rename(const std::string& path, const std::string& newName); static Status Access(const std::string& path, const int32_t mode); + static Status Stat(const std::string& path, IFile::FileStat& st); static Status Read(const std::string& path, const size_t offset, const size_t length, - uintptr_t address); + uintptr_t address, const bool directIo = false); static Status Write(const std::string& path, const size_t offset, const size_t length, - const uintptr_t address); + const uintptr_t address, const bool directIo = false, + const bool create = false); static void MUnmap(void* addr, size_t size); static void ShmUnlink(const std::string& path); + static void Remove(const std::string& path); }; } // namespace UC diff --git a/ucm/store/infra/file/ifile.h b/ucm/store/infra/file/ifile.h index 4ff05a9d7..74b77cba6 100644 --- a/ucm/store/infra/file/ifile.h +++ b/ucm/store/infra/file/ifile.h @@ -70,6 +70,7 @@ class IFile { virtual Status MMap(void*& addr, size_t size, bool write, bool read, bool shared) = 0; virtual void MUnmap(void* addr, size_t size) = 0; virtual void ShmUnlink() = 0; + virtual Status UpdateTime() = 0; private: std::string path_; diff --git a/ucm/store/infra/file/posix_file.cc b/ucm/store/infra/file/posix_file.cc index c24034646..bc697f392 100644 --- a/ucm/store/infra/file/posix_file.cc +++ b/ucm/store/infra/file/posix_file.cc @@ -25,12 +25,13 @@ #include #include #include +#include #include #include "logger/logger.h" namespace UC { -static constexpr auto NewFilePerm = (S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH); +static constexpr auto NewFilePerm = (S_IREAD | S_IWRITE | S_IRGRP | S_IROTH); PosixFile::~PosixFile() { this->Close(); } @@ -121,9 +122,7 @@ void PosixFile::Remove() auto ret = remove(this->Path().c_str()); auto eno = errno; if (ret != 0) { - if (eno == ENOENT) { - UC_WARN("Failed to remove file, path: {}, file not found.", this->Path()); - } + if (eno != ENOENT) { UC_WARN("Failed({},{}) to remove file({}).", ret, eno, this->Path()); } } } @@ -232,4 +231,15 @@ void PosixFile::ShmUnlink() } } +Status PosixFile::UpdateTime() +{ + auto ret = 
utime(this->Path().c_str(), nullptr); + auto eno = errno; + if (ret != 0) { + UC_ERROR("Failed({},{}) to update time file({}).", ret, eno, this->Path()); + return Status::OsApiError(); + } + return Status::OK(); +} + } // namespace UC diff --git a/ucm/store/infra/file/posix_file.h b/ucm/store/infra/file/posix_file.h index a401c0c5b..becbd28a3 100644 --- a/ucm/store/infra/file/posix_file.h +++ b/ucm/store/infra/file/posix_file.h @@ -47,6 +47,7 @@ class PosixFile : public IFile { Status MMap(void*& addr, size_t size, bool write, bool read, bool shared) override; void MUnmap(void* addr, size_t size) override; void ShmUnlink() override; + Status UpdateTime() override; private: int32_t handle_; diff --git a/ucm/store/infra/status/status.h b/ucm/store/infra/status/status.h deleted file mode 100644 index 8c373e124..000000000 --- a/ucm/store/infra/status/status.h +++ /dev/null @@ -1,129 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * */ -#ifndef UNIFIEDCACHE_STATUS_H -#define UNIFIEDCACHE_STATUS_H - -#include - -namespace UC { - -class Status { - enum class Code { -#define UC_MAKE_STATUS_CODE(i) (-50000 - (i)) - OK = 0, - ERROR = -1, - EPARAM = UC_MAKE_STATUS_CODE(0), - EOOM = UC_MAKE_STATUS_CODE(1), - EOSERROR = UC_MAKE_STATUS_CODE(2), - EDUPLICATE = UC_MAKE_STATUS_CODE(3), - ERETRY = UC_MAKE_STATUS_CODE(4), - ENOOBJ = UC_MAKE_STATUS_CODE(5), - ESERIALIZE = UC_MAKE_STATUS_CODE(6), - EDESERIALIZE = UC_MAKE_STATUS_CODE(7), - EUNSUPPORTED = UC_MAKE_STATUS_CODE(8), -#undef UC_MAKE_STATUS_CODE - }; - -public: - static Status& OK() - { - static Status s{Code::OK}; - return s; - } - static Status& Error() - { - static Status s{Code::ERROR}; - return s; - } - static Status& InvalidParam() - { - static Status s{Code::EPARAM}; - return s; - } - static Status& OutOfMemory() - { - static Status s{Code::EOOM}; - return s; - } - static Status& OsApiError() - { - static Status s{Code::EOSERROR}; - return s; - } - static Status& DuplicateKey() - { - static Status s{Code::EDUPLICATE}; - return s; - } - static Status& Retry() - { - static Status s{Code::ERETRY}; - return s; - } - static Status& NotFound() - { - static Status s{Code::ENOOBJ}; - return s; - } - static Status& SerializeFailed() - { - static Status s{Code::ESERIALIZE}; - return s; - } - static Status& DeserializeFailed() - { - static Status s{Code::EDESERIALIZE}; - return s; - } - static Status& Unsupported() - { - static Status s{Code::EUNSUPPORTED}; - return s; - } - -public: - Status(const Status& status) { this->code_ = status.code_; } - Status& operator=(const Status& status) - { - if (this != &status) { this->code_ = status.code_; } - return *this; - } - bool operator==(const Status& status) const { return this->code_ == status.code_; } - bool operator!=(const Status& status) const { return this->code_ != status.code_; } - int32_t Underlying() const { return static_cast(this->code_); } - bool Success() const { return this->code_ == Code::OK; } - bool Failure() const { return this->code_ != Code::OK; } - -private: - Status(const Code code) : code_{code} {} - -private: - Code code_; -}; - -inline int32_t format_as(const Status& status) { return status.Underlying(); } - -} // namespace UC - -#endif diff --git a/ucm/store/localstore/CMakeLists.txt b/ucm/store/localstore/CMakeLists.txt index 9b4993fcf..b6112e090 100644 --- a/ucm/store/localstore/CMakeLists.txt +++ b/ucm/store/localstore/CMakeLists.txt @@ -4,7 +4,7 @@ target_include_directories(localstore PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cc/api ${CMAKE_CURRENT_SOURCE_DIR}/cc/domain ) -target_link_libraries(localstore PUBLIC storeinfra) +target_link_libraries(localstore PUBLIC storeinfra storetask) file(GLOB_RECURSE UCMSTORE_LOCAL_CPY_SOURCE_FILES "./cpy/*.cc") pybind11_add_module(ucmlocalstore ${UCMSTORE_LOCAL_CPY_SOURCE_FILES}) diff --git a/ucm/store/localstore/cc/api/localstore.h b/ucm/store/localstore/cc/api/localstore.h index 25f0f0af4..edca1e0b1 100644 --- a/ucm/store/localstore/cc/api/localstore.h +++ b/ucm/store/localstore/cc/api/localstore.h @@ -28,7 +28,7 @@ namespace UC { -class LocalStore : public CCStore { +class LocalStore : public CCStore<> { public: struct Config { size_t ioSize; diff --git a/ucm/store/localstore/cpy/localstore.py.cc b/ucm/store/localstore/cpy/localstore.py.cc index 6bb3947ba..c067df231 100644 --- a/ucm/store/localstore/cpy/localstore.py.cc +++ b/ucm/store/localstore/cpy/localstore.py.cc @@ -56,31 +56,30 @@ class LocalStorePy : public LocalStore { size_t Load(const py::list& blockIds, 
const py::list& offsets, const py::list& addresses, const py::list& lengths) { - return this->SubmitPy(blockIds, offsets, addresses, lengths, CCStore::Task::Type::LOAD, - CCStore::Task::Location::DEVICE, "LOCAL::S2D"); + return this->SubmitPy(blockIds, offsets, addresses, lengths, Task::Type::LOAD, + Task::Location::DEVICE, "LOCAL::S2D"); } size_t Dump(const py::list& blockIds, const py::list& offsets, const py::list& addresses, const py::list& lengths) { - return this->SubmitPy(blockIds, offsets, addresses, lengths, CCStore::Task::Type::DUMP, - CCStore::Task::Location::DEVICE, "LOCAL::D2S"); + return this->SubmitPy(blockIds, offsets, addresses, lengths, Task::Type::DUMP, + Task::Location::DEVICE, "LOCAL::D2S"); } private: size_t SubmitPy(const py::list& blockIds, const py::list& offsets, const py::list& addresses, - const py::list& lengths, const CCStore::Task::Type type, - const CCStore::Task::Location location, const std::string& brief) + const py::list& lengths, Task::Type&& type, Task::Location&& location, + std::string&& brief) { - CCStore::Task task{type, location, brief}; + Task task{std::move(type), std::move(location), std::move(brief)}; auto blockId = blockIds.begin(); auto offset = offsets.begin(); auto address = addresses.begin(); auto length = lengths.begin(); while ((blockId != blockIds.end()) && (offset != offsets.end()) && (address != addresses.end()) && (length != lengths.end())) { - auto ret = task.Append(blockId->cast(), offset->cast(), - address->cast(), length->cast()); - if (ret != 0) { return CCStore::invalidTaskId; } + task.Append(blockId->cast(), offset->cast(), + address->cast(), length->cast()); blockId++; offset++; address++; diff --git a/ucm/store/mooncakestore/mooncake_connector.py b/ucm/store/mooncakestore/mooncake_connector.py index f595d93b2..706063245 100644 --- a/ucm/store/mooncakestore/mooncake_connector.py +++ b/ucm/store/mooncakestore/mooncake_connector.py @@ -259,6 +259,28 @@ async def _dump_impl( raise TypeError("Mooncake Store Put Type Error.") from err return 0 + def fetch_data( + self, + block_ids: List[str], + offset: List[int], + dst_addr: List[int], + size: List[int], + ) -> Task: + raise NotImplementedError( + "Method(fetch_data) not yet implemented in this version" + ) + + def dump_data( + self, + block_ids: List[str], + offset: List[int], + src_addr: List[int], + size: List[int], + ) -> Task: + raise NotImplementedError( + "Method(dump_data) not yet implemented in this version" + ) + def wait(self, task: Task) -> int: """ wait kv cache kv transfer task finished. 
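The Mooncake connector above raises NotImplementedError from the new `fetch_data`/`dump_data` methods ("not yet implemented in this version"). A hedged sketch of how calling code could degrade gracefully is shown below; the `connector` object and the fallback `load()` call with this argument order are assumptions for illustration, not the actual UCM integration code.

```python
# Hedged sketch only: prefer the optional fetch_data path and fall back when a
# connector (e.g. the Mooncake connector above) raises NotImplementedError.
# The legacy load() call and its argument order are assumptions, not the
# documented UCM connector API.
from typing import Any, List


def fetch_with_fallback(connector: Any, block_ids: List[str], offset: List[int],
                        dst_addr: List[int], size: List[int]):
    """Prefer fetch_data(); fall back to an assumed older-style load()."""
    try:
        return connector.fetch_data(block_ids, offset, dst_addr, size)
    except NotImplementedError:
        # Assumed legacy entry point; adapt the arguments to whatever the
        # concrete connector actually exposes.
        return connector.load(block_ids, offset, dst_addr, size)
```

The only grounded part of this sketch is the signal itself: connectors that have not implemented the new methods advertise it via NotImplementedError, so callers can branch on that exception.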
diff --git a/ucm/store/nfsstore/CMakeLists.txt b/ucm/store/nfsstore/CMakeLists.txt index a4671e21f..62600ffbf 100644 --- a/ucm/store/nfsstore/CMakeLists.txt +++ b/ucm/store/nfsstore/CMakeLists.txt @@ -4,7 +4,7 @@ target_include_directories(nfsstore PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cc/api ${CMAKE_CURRENT_SOURCE_DIR}/cc/domain ) -target_link_libraries(nfsstore PUBLIC storeinfra storedevice) +target_link_libraries(nfsstore PUBLIC storeinfra storedevice storetask) file(GLOB_RECURSE UCMSTORE_NFS_CPY_SOURCE_FILES "./cpy/*.cc") pybind11_add_module(ucmnfsstore ${UCMSTORE_NFS_CPY_SOURCE_FILES}) diff --git a/ucm/store/nfsstore/cc/api/nfsstore.cc b/ucm/store/nfsstore/cc/api/nfsstore.cc index ee4ca55c1..47706fffa 100644 --- a/ucm/store/nfsstore/cc/api/nfsstore.cc +++ b/ucm/store/nfsstore/cc/api/nfsstore.cc @@ -25,7 +25,8 @@ #include #include "logger/logger.h" #include "space/space_manager.h" -#include "tsf_task/tsf_task_manager.h" +#include "trans/trans_manager.h" +#include "hotness/hotness_manager.h" namespace UC { @@ -33,7 +34,9 @@ class NFSStoreImpl : public NFSStore { public: int32_t Setup(const Config& config) { - auto status = this->spaceMgr_.Setup(config.storageBackends, config.kvcacheBlockSize); + auto status = this->spaceMgr_.Setup(config.storageBackends, config.kvcacheBlockSize, + config.tempDumpDirEnable, config.storageCapacity, + config.recycleEnable, config.recycleThresholdRatio); if (status.Failure()) { UC_ERROR("Failed({}) to setup SpaceManager.", status); return status.Underlying(); @@ -42,12 +45,19 @@ class NFSStoreImpl : public NFSStore { status = this->transMgr_.Setup(config.transferDeviceId, config.transferStreamNumber, config.transferIoSize, config.transferBufferNumber, - config.transferTimeoutMs, this->spaceMgr_.GetSpaceLayout()); + this->spaceMgr_.GetSpaceLayout(), config.transferTimeoutMs, config.transferIoDirect); if (status.Failure()) { UC_ERROR("Failed({}) to setup TsfTaskManager.", status); return status.Underlying(); } } + if (config.hotnessEnable) { + status = this->hotnessMgr_.Setup(config.hotnessInterval, this->spaceMgr_.GetSpaceLayout()); + if (status.Failure()) { + UC_ERROR("Failed({}) to setup HotnessManager.", status); + return status.Underlying(); + } + } this->ShowConfig(config); return Status::OK().Underlying(); } @@ -55,7 +65,12 @@ class NFSStoreImpl : public NFSStore { { return this->spaceMgr_.NewBlock(block).Underlying(); } - bool Lookup(const std::string& block) override { return this->spaceMgr_.LookupBlock(block); } + bool Lookup(const std::string& block) override + { + auto found = this->spaceMgr_.LookupBlock(block); + if (found) { this->hotnessMgr_.Visit(block); } + return found; + } void Commit(const std::string& block, const bool success) override { this->spaceMgr_.CommitBlock(block, success); @@ -78,17 +93,10 @@ class NFSStoreImpl : public NFSStore { } size_t Submit(Task&& task) override { - std::list tasks; - for (auto& shard : task.shards) { - tasks.push_back( - {task.type, task.location, shard.block, shard.offset, shard.address, task.size}); - } - size_t taskId; - return this->transMgr_ - .Submit(tasks, task.number * task.size, task.number, task.brief, taskId) - .Success() - ? 
taskId - : CCStore::invalidTaskId; + auto taskId = Task::invalid; + auto status = this->transMgr_.Submit(std::move(task), taskId); + if (status.Failure()) { taskId = Task::invalid; } + return taskId; } int32_t Wait(const size_t task) override { return this->transMgr_.Wait(task).Underlying(); } int32_t Check(const size_t task, bool& finish) override @@ -110,11 +118,19 @@ class NFSStoreImpl : public NFSStore { UC_INFO("Set UC::IOSize to {}.", config.transferIoSize); UC_INFO("Set UC::BufferNumber to {}.", config.transferBufferNumber); UC_INFO("Set UC::TimeoutMs to {}.", config.transferTimeoutMs); + UC_INFO("Set UC::TempDumpDirEnable to {}.", config.tempDumpDirEnable); + UC_INFO("Set UC::HotnessInterval to {}.", config.hotnessInterval); + UC_INFO("Set UC::HotnessEnable to {}.", config.hotnessEnable); + UC_INFO("Set UC::storageCapacity to {}.", config.storageCapacity); + UC_INFO("Set UC::RecycleEnable to {}.", config.recycleEnable); + UC_INFO("Set UC::RecycleThreshold to {}.", config.recycleThresholdRatio); + UC_INFO("Set UC::IoDirect to {}.", config.transferIoDirect); } private: SpaceManager spaceMgr_; - TsfTaskManager transMgr_; + TransManager transMgr_; + HotnessManager hotnessMgr_; }; int32_t NFSStore::Setup(const Config& config) diff --git a/ucm/store/nfsstore/cc/api/nfsstore.h b/ucm/store/nfsstore/cc/api/nfsstore.h index b6bd4ba6c..356aff2bb 100644 --- a/ucm/store/nfsstore/cc/api/nfsstore.h +++ b/ucm/store/nfsstore/cc/api/nfsstore.h @@ -29,7 +29,7 @@ namespace UC { -class NFSStore : public CCStore { +class NFSStore : public CCStore<> { public: struct Config { std::vector storageBackends; @@ -40,12 +40,22 @@ class NFSStore : public CCStore { size_t transferIoSize; size_t transferBufferNumber; size_t transferTimeoutMs; + bool tempDumpDirEnable; + bool hotnessEnable; + size_t hotnessInterval; + size_t storageCapacity; + bool recycleEnable; + float recycleThresholdRatio; + bool transferIoDirect; Config(const std::vector& storageBackends, const size_t kvcacheBlockSize, const bool transferEnable) : storageBackends{storageBackends}, kvcacheBlockSize{kvcacheBlockSize}, transferEnable{transferEnable}, transferDeviceId{-1}, transferStreamNumber{32}, - transferIoSize{262144}, transferBufferNumber{512}, transferTimeoutMs{30000} + transferIoSize{262144}, transferBufferNumber{512}, transferTimeoutMs{30000}, + tempDumpDirEnable{false}, hotnessEnable{true}, hotnessInterval{60}, + storageCapacity{0}, recycleEnable{true}, recycleThresholdRatio{0.7f}, + transferIoDirect{false} { } }; diff --git a/ucm/store/nfsstore/cc/domain/hotness/hotness_manager.h b/ucm/store/nfsstore/cc/domain/hotness/hotness_manager.h new file mode 100644 index 000000000..6572ee483 --- /dev/null +++ b/ucm/store/nfsstore/cc/domain/hotness/hotness_manager.h @@ -0,0 +1,75 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ + +#ifndef UNIFIEDCACHE_HOTNESS_MANAGER_H +#define UNIFIEDCACHE_HOTNESS_MANAGER_H + +#include +#include +#include "hotness_set.h" +#include "hotness_timer.h" +#include "logger/logger.h" + +namespace UC { + +class HotnessManager { +public: + Status Setup(const size_t interval, const SpaceLayout* spaceLayout) + { + this->hotnessTimer_.SetInterval(interval); + this->layout_ = spaceLayout; + this->setupSuccess_ = true; + return Status::OK(); + } + + void Visit(const std::string& blockId) + { + if (!this->setupSuccess_) { + return; + } + + this->hotnessSet_.Insert(blockId); + auto old = this->serviceRunning_.load(std::memory_order_acquire); + if (old) { return; } + if (this->serviceRunning_.compare_exchange_weak(old, true, std::memory_order_acq_rel)) { + auto updater = std::bind(&HotnessSet::UpdateHotness, &this->hotnessSet_, this->layout_); + if (this->hotnessTimer_.Start(std::move(updater)).Success()) { + UC_INFO("Space hotness service started."); + return; + } + this->serviceRunning_ = old; + } + } + +private: + bool setupSuccess_{false}; + std::atomic_bool serviceRunning_{false}; + const SpaceLayout* layout_; + HotnessSet hotnessSet_; + HotnessTimer hotnessTimer_; +}; + +} // namespace UC + +#endif \ No newline at end of file diff --git a/ucm/store/nfsstore/cc/domain/hotness/hotness_set.cc b/ucm/store/nfsstore/cc/domain/hotness/hotness_set.cc new file mode 100644 index 000000000..8bc739e99 --- /dev/null +++ b/ucm/store/nfsstore/cc/domain/hotness/hotness_set.cc @@ -0,0 +1,70 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ + +#include "hotness_set.h" +#include "logger/logger.h" +#include "file/file.h" +#include "template/singleton.h" +namespace UC { + +void HotnessSet::Insert(const std::string& blockId) +{ + std::lock_guard lg(this->mutex_); + this->pendingBlocks_.insert(blockId); +} + +void HotnessSet::UpdateHotness(const SpaceLayout* spaceLayout) +{ + std::unordered_set blocksToUpdate; + { + std::lock_guard lg(this->mutex_); + if (this->pendingBlocks_.empty()) { + return; + } + blocksToUpdate.swap(this->pendingBlocks_); + } + + size_t number = 0; + for (const std::string& blockId : blocksToUpdate) { + auto blockPath = spaceLayout->DataFilePath(blockId, false); + auto file = File::Make(blockPath); + if (!file) { + UC_WARN("Failed to make file({}), blockId({}).", blockPath, blockId); + continue; + } + auto status = file->UpdateTime(); + if (status.Failure()) { + UC_WARN("Failed({}) to update time({}), blockId({}).", status, blockPath, blockId); + continue; + } + number++; + } + if (blocksToUpdate.size() == number) { + UC_INFO("All blocks are hotness."); + } else { + UC_WARN("{} of {} blocks are hotness.", blocksToUpdate.size() - number, blocksToUpdate.size()); + } +} + +} // namespace UC \ No newline at end of file diff --git a/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_queue.h b/ucm/store/nfsstore/cc/domain/hotness/hotness_set.h similarity index 60% rename from ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_queue.h rename to ucm/store/nfsstore/cc/domain/hotness/hotness_set.h index 30bb9565c..4bb8beca4 100644 --- a/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_queue.h +++ b/ucm/store/nfsstore/cc/domain/hotness/hotness_set.h @@ -21,40 +21,27 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_TSF_TAKS_QUEUE_H -#define UNIFIEDCACHE_TSF_TAKS_QUEUE_H -#include "idevice.h" +#ifndef UNIFIEDCACHE_HOTNESS_SET_H +#define UNIFIEDCACHE_HOTNESS_SET_H + +#include +#include #include "space/space_layout.h" -#include "thread/thread_pool.h" -#include "tsf_task.h" -#include "tsf_task_set.h" namespace UC { -class TsfTaskQueue { +class HotnessSet { public: - Status Setup(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber, - TsfTaskSet* failureSet, const SpaceLayout* layout); - void Push(std::list& tasks); - -private: - void StreamOper(TsfTask& task); - void FileOper(TsfTask& task); - void H2D(TsfTask& task); - void D2H(TsfTask& task); - void H2S(TsfTask& task); - void S2H(TsfTask& task); - void Done(const TsfTask& task, bool success); + void Insert(const std::string& blockId); + void UpdateHotness(const SpaceLayout* spaceLayout); private: - ThreadPool _streamOper; - ThreadPool _fileOper; - std::unique_ptr _device; - TsfTaskSet* _failureSet; - const SpaceLayout* _layout; + std::mutex mutex_; + std::unordered_set pendingBlocks_; }; + } // namespace UC -#endif +#endif \ No newline at end of file diff --git a/ucm/store/nfsstore/cc/domain/hotness/hotness_timer.h b/ucm/store/nfsstore/cc/domain/hotness/hotness_timer.h new file mode 100644 index 000000000..d549788ec --- /dev/null +++ b/ucm/store/nfsstore/cc/domain/hotness/hotness_timer.h @@ -0,0 +1,56 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ + +#ifndef UNIFIEDCACHE_HOTNESS_TIMER_H +#define UNIFIEDCACHE_HOTNESS_TIMER_H +#include +#include +#include "logger/logger.h" +#include "template/timer.h" + +namespace UC { + +class HotnessTimer { +public: + void SetInterval(const size_t interval) { this->interval_ = std::chrono::seconds(interval); } + Status Start(std::function callable) + { + try { + this->timer_ = std::make_unique>>(this->interval_, + std::move(callable)); + } catch (const std::exception& e) { + UC_ERROR("Failed({}) to start hotness timer.", e.what()); + return Status::OutOfMemory(); + } + return this->timer_->Start() ? Status::OK() : Status::Error(); + } + +private: + std::chrono::seconds interval_; + std::unique_ptr>> timer_; +}; + +} // namespace UC + +#endif diff --git a/ucm/store/nfsstore/cc/domain/space/space_layout.h b/ucm/store/nfsstore/cc/domain/space/space_layout.h index 46ad5196a..3712abec0 100644 --- a/ucm/store/nfsstore/cc/domain/space/space_layout.h +++ b/ucm/store/nfsstore/cc/domain/space/space_layout.h @@ -24,6 +24,7 @@ #ifndef UNIFIEDCACHE_SPACE_LAYOUT_H #define UNIFIEDCACHE_SPACE_LAYOUT_H +#include #include #include #include "status/status.h" @@ -32,20 +33,16 @@ namespace UC { class SpaceLayout { public: - Status Setup(const std::vector& storageBackends); - std::string DataFileParent(const std::string& blockId) const; - std::string DataFilePath(const std::string& blockId, bool activated) const; - -private: - Status AddStorageBackend(const std::string& path); - Status AddFirstStorageBackend(const std::string& path); - Status AddSecondaryStorageBackend(const std::string& path); - std::string StorageBackend(const std::string& blockId) const; - std::vector RelativeRoots() const; - std::string DataFileRoot() const; - -private: - std::vector storageBackends_; + struct DataIterator; +public: + virtual ~SpaceLayout() = default; + virtual Status Setup(const std::vector& storageBackends) = 0; + virtual std::string DataFileParent(const std::string& blockId, bool activated) const = 0; + virtual std::string DataFilePath(const std::string& blockId, bool activated) const = 0; + virtual std::string ClusterPropertyFilePath() const = 0; + virtual std::shared_ptr CreateFilePathIterator() const = 0; + virtual std::string NextDataFilePath(std::shared_ptr iter) const = 0; + virtual bool IsActivatedFile(const std::string& filePath) const = 0; }; } // namespace UC diff --git a/ucm/store/nfsstore/cc/domain/space/space_manager.cc 
b/ucm/store/nfsstore/cc/domain/space/space_manager.cc index c74af531f..7b21041be 100644 --- a/ucm/store/nfsstore/cc/domain/space/space_manager.cc +++ b/ucm/store/nfsstore/cc/domain/space/space_manager.cc @@ -24,36 +24,71 @@ #include "space_manager.h" #include "file/file.h" #include "logger/logger.h" +#include "space_shard_temp_layout.h" +#include + +constexpr auto MIN_REUSE_BLOCK_AGE = 300; // 5 minutes namespace UC { -Status SpaceManager::Setup(const std::vector& storageBackends, const size_t blockSize) +std::unique_ptr MakeSpaceLayout(const bool tempDumpDirEnable) +{ + try { + if (tempDumpDirEnable) { return std::make_unique(); } + return std::make_unique(); + } catch (const std::exception& e) { + UC_ERROR("Failed({}) to make space layout object.", e.what()); + } + return nullptr; +} + +Status SpaceManager::Setup(const std::vector& storageBackends, const size_t blockSize, + const bool tempDumpDirEnable, const size_t storageCapacity, + const bool recycleEnable, const float recycleThresholdRatio) { if (blockSize == 0) { UC_ERROR("Invalid block size({}).", blockSize); return Status::InvalidParam(); } - auto status = this->layout_.Setup(storageBackends); + this->layout_ = MakeSpaceLayout(tempDumpDirEnable); + if (!this->layout_) { return Status::OutOfMemory(); } + auto status = this->layout_->Setup(storageBackends); + if (status.Failure()) { return status; } + status = this->property_.Setup(this->layout_->ClusterPropertyFilePath()); if (status.Failure()) { return status; } + if (recycleEnable && storageCapacity > 0) { + auto totalBlocks = storageCapacity / blockSize; + status = this->recycle_.Setup(this->GetSpaceLayout(), totalBlocks, [this] { + this->property_.DecreaseCapacity(this->blockSize_); + }); + if (status.Failure()) { return status; } + } + this->blockSize_ = blockSize; + this->capacity_ = storageCapacity; + this->recycleEnable_ = recycleEnable; + this->capacityRecycleThreshold_ = static_cast(storageCapacity * recycleThresholdRatio); return Status::OK(); } -Status SpaceManager::NewBlock(const std::string& blockId) const +Status SpaceManager::NewBlock(const std::string& blockId) { - auto parent = File::Make(this->layout_.DataFileParent(blockId)); - auto file = File::Make(this->layout_.DataFilePath(blockId, true)); + Status status = this->CapacityCheck(); + if (status.Failure()) { return status; } + constexpr auto activated = true; + auto parent = File::Make(this->layout_->DataFileParent(blockId, activated)); + auto file = File::Make(this->layout_->DataFilePath(blockId, activated)); if (!parent || !file) { UC_ERROR("Failed to new block({}).", blockId); return Status::OutOfMemory(); } - auto status = parent->MkDir(); + status = parent->MkDir(); if (status == Status::DuplicateKey()) { status = Status::OK(); } if (status.Failure()) { UC_ERROR("Failed({}) to new block({}).", status, blockId); return status; } - if ((File::Access(this->layout_.DataFilePath(blockId, false), IFile::AccessMode::EXIST)) + if ((File::Access(this->layout_->DataFilePath(blockId, false), IFile::AccessMode::EXIST)) .Success()) { status = Status::DuplicateKey(); UC_ERROR("Failed({}) to new block({}).", status, blockId); @@ -64,42 +99,65 @@ Status SpaceManager::NewBlock(const std::string& blockId) const if (status.Failure()) { if (status != Status::DuplicateKey()) { UC_ERROR("Failed({}) to new block({}).", status, blockId); + return status; + } + // Reuse the active block if it is not accessed within the last 5 minutes + status = file->Open(IFile::OpenFlag::READ_WRITE); + if (status.Failure()) { + 
UC_ERROR("Failed({}) to open file({}).", status, file->Path()); + return status; + } + IFile::FileStat st{}; + status = file->Stat(st); + if (status.Failure()) { + UC_ERROR("Failed({}) to stat file({}).", status, file->Path()); + return status; + } + const auto now = std::chrono::system_clock::now(); + const auto lastAccess = std::chrono::system_clock::from_time_t(st.st_atime); + if (now - lastAccess <= std::chrono::seconds(MIN_REUSE_BLOCK_AGE)) { + UC_ERROR("Block({}) is active, cannot reuse it.", blockId); + return Status::DuplicateKey(); } - return status; } + status = file->Truncate(this->blockSize_); if (status.Failure()) { UC_ERROR("Failed({}) to new block({}).", status, blockId); return status; } + this->property_.IncreaseCapacity(this->blockSize_); return Status::OK(); } -Status SpaceManager::CommitBlock(const std::string& blockId, bool success) const +Status SpaceManager::CommitBlock(const std::string& blockId, bool success) { - auto file = File::Make(this->layout_.DataFilePath(blockId, true)); - if (!file) { - UC_ERROR("Failed to {} block({}).", success ? "commit" : "cancel", blockId); - return Status::OutOfMemory(); - } - if (success) { - auto status = file->Rename(this->layout_.DataFilePath(blockId, false)); - if (status.Failure()) { UC_ERROR("Failed({}) to commit block({}).", status, blockId); } - return status; - } - auto parent = File::Make(this->layout_.DataFileParent(blockId)); - if (!parent) { - UC_ERROR("Failed to cancel block({}).", blockId); - return Status::OutOfMemory(); + const auto activatedParent = this->layout_->DataFileParent(blockId, true); + const auto activatedFile = this->layout_->DataFilePath(blockId, true); + const auto archivedParent = this->layout_->DataFileParent(blockId, false); + auto status = Status::OK(); + do { + if (!success) { break; } + if (archivedParent != activatedParent) { + status = File::MkDir(archivedParent); + if (status == Status::DuplicateKey()) { status = Status::OK(); } + if (status.Failure()) { break; } + } + const auto archivedFile = this->layout_->DataFilePath(blockId, false); + status = File::Rename(activatedFile, archivedFile); + } while (0); + File::Remove(activatedFile); + if (!success || archivedParent != activatedParent) { File::RmDir(activatedParent); } + if (status.Failure()) { + UC_ERROR("Failed({}) to {} block({}).", status, success ? 
"commit" : "cancel", blockId); } - file->Remove(); - parent->RmDir(); - return Status::OK(); + this->property_.DecreaseCapacity(this->blockSize_); + return status; } bool SpaceManager::LookupBlock(const std::string& blockId) const { - auto path = this->layout_.DataFilePath(blockId, false); + auto path = this->layout_->DataFilePath(blockId, false); auto file = File::Make(path); if (!file) { UC_ERROR("Failed to make file smart pointer, path: {}.", path); @@ -116,6 +174,22 @@ bool SpaceManager::LookupBlock(const std::string& blockId) const return true; } -const SpaceLayout* SpaceManager::GetSpaceLayout() const { return &this->layout_; } +const SpaceLayout* SpaceManager::GetSpaceLayout() const { return this->layout_.get(); } + +Status SpaceManager::CapacityCheck() +{ + if (this->capacity_ == 0) { return Status::OK(); } + + const size_t used = this->property_.GetCapacity(); + if (this->recycleEnable_ && used >= this->capacityRecycleThreshold_) { + this->recycle_.Trigger(); + } + if (used > this->capacity_ - this->blockSize_) { + UC_ERROR("Capacity is not enough, capacity: {}, current: {}, block size: {}.", + this->capacity_, used, this->blockSize_); + return Status::NoSpace(); + } + return Status::OK(); +} } // namespace UC diff --git a/ucm/store/nfsstore/cc/domain/space/space_manager.h b/ucm/store/nfsstore/cc/domain/space/space_manager.h index c63eb5108..e6690a95c 100644 --- a/ucm/store/nfsstore/cc/domain/space/space_manager.h +++ b/ucm/store/nfsstore/cc/domain/space/space_manager.h @@ -24,22 +24,34 @@ #ifndef UNIFIEDCACHE_SPACE_MANAGER_H #define UNIFIEDCACHE_SPACE_MANAGER_H +#include #include "space_layout.h" +#include "space_property.h" #include "status/status.h" +#include "space_recycle.h" namespace UC { class SpaceManager { public: - Status Setup(const std::vector& storageBackends, const size_t blockSize); - Status NewBlock(const std::string& blockId) const; - Status CommitBlock(const std::string& blockId, bool success = true) const; + Status Setup(const std::vector& storageBackends, const size_t blockSize, + const bool tempDumpDirEnable, const size_t storageCapacity = 0, + const bool recycleEnable = false, const float recycleThresholdRatio = 0.7f); + Status NewBlock(const std::string& blockId); + Status CommitBlock(const std::string& blockId, bool success = true); bool LookupBlock(const std::string& blockId) const; const SpaceLayout* GetSpaceLayout() const; private: - SpaceLayout layout_; + Status CapacityCheck(); +private: + std::unique_ptr layout_; + SpaceProperty property_; + SpaceRecycle recycle_; size_t blockSize_; + size_t capacity_; + bool recycleEnable_; + size_t capacityRecycleThreshold_; }; } // namespace UC diff --git a/ucm/store/nfsstore/cc/domain/space/space_property.cc b/ucm/store/nfsstore/cc/domain/space/space_property.cc new file mode 100644 index 000000000..2db7b71ec --- /dev/null +++ b/ucm/store/nfsstore/cc/domain/space/space_property.cc @@ -0,0 +1,136 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#include "space_property.h" +#include +#include +#include +#include "logger/logger.h" +#include "file/file.h" + +namespace UC { + +static constexpr uint32_t Magic = (('S' << 16) | ('p' << 8) | 1); +static constexpr size_t PropertySize = 256; + +struct Property { + std::atomic magic; + uint32_t padding; + std::atomic capacity; +}; +static_assert(sizeof(Property) <= PropertySize, "Property take too much space"); +static_assert(std::atomic::is_always_lock_free, "magic must be lock-free"); +static_assert(std::atomic::is_always_lock_free, "capacity must be lock-free"); + +inline auto PropertyPtr(void* addr) { return (Property*)addr; } + +SpaceProperty::~SpaceProperty() +{ + if (this->addr_) { + File::MUnmap(this->addr_, PropertySize); + } + this->addr_ = nullptr; +} + +Status SpaceProperty::Setup(const std::string& PropertyFilePath) +{ + auto file = File::Make(PropertyFilePath); + if (!file) { return Status::OutOfMemory(); } + auto flags = IFile::OpenFlag::CREATE | IFile::OpenFlag::EXCL | IFile::OpenFlag::READ_WRITE; + auto status = file->Open(flags); + if (status.Success()) { return this->InitShmProperty(file.get()); } + if (status == Status::DuplicateKey()) { return this->LoadShmProperty(file.get()); } + return status; +} + +void SpaceProperty::IncreaseCapacity(const size_t delta) +{ + PropertyPtr(this->addr_)->capacity += delta; +} + +void SpaceProperty::DecreaseCapacity(const size_t delta) +{ + auto property = PropertyPtr(this->addr_); + auto capacity = property->capacity.load(std::memory_order_acquire); + while (capacity > delta) { + if (property->capacity.compare_exchange_weak(capacity, capacity - delta, std::memory_order_acq_rel)) { + return; + } + capacity = property->capacity.load(std::memory_order_acquire); + } +} + +size_t SpaceProperty::GetCapacity() const +{ + return PropertyPtr(this->addr_)->capacity.load(std::memory_order_relaxed); +} + +Status SpaceProperty::InitShmProperty(IFile* shmPropertyFile) +{ + auto status = shmPropertyFile->Truncate(PropertySize); + if (status.Failure()) { return status; } + status = shmPropertyFile->MMap(this->addr_, PropertySize, true, true, true); + if (status.Failure()) { return status; } + std::fill_n((uint8_t*)this->addr_, PropertySize, 0); + auto property = PropertyPtr(this->addr_); + property->padding = 0; + property->capacity = 0; + property->magic = Magic; + return Status::OK(); +} + +Status SpaceProperty::LoadShmProperty(IFile* shmPropertyFile) +{ + auto status = 
shmPropertyFile->Open(IFile::OpenFlag::READ_WRITE); + if (status.Failure()) { return status; } + constexpr auto retryInterval = std::chrono::milliseconds(100); + constexpr auto maxTryTime = 100; + auto tryTime = 0; + IFile::FileStat stat; + do { + if (tryTime > maxTryTime) { + UC_ERROR("Shm file({}) not ready.", shmPropertyFile->Path()); + return Status::Retry(); + } + std::this_thread::sleep_for(retryInterval); + status = shmPropertyFile->Stat(stat); + if (status.Failure()) { return status; } + tryTime++; + } while (static_cast(stat.st_size) != PropertySize); + status = shmPropertyFile->MMap(this->addr_, PropertySize, true, true, true); + if (status.Failure()) { return status; } + auto property = PropertyPtr(this->addr_); + tryTime = 0; + do { + if (property->magic == Magic) { break; } + if (tryTime > maxTryTime) { + UC_ERROR("Shm file({}) not ready.", shmPropertyFile->Path()); + return Status::Retry(); + } + std::this_thread::sleep_for(retryInterval); + tryTime++; + } while (true); + return Status::OK(); +} + +} // namespace UC \ No newline at end of file diff --git a/ucm/store/nfsstore/cc/domain/space/space_property.h b/ucm/store/nfsstore/cc/domain/space/space_property.h new file mode 100644 index 000000000..b4e999081 --- /dev/null +++ b/ucm/store/nfsstore/cc/domain/space/space_property.h @@ -0,0 +1,51 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ + +#ifndef UNIFIEDCACHE_SPACE_PROPERTY_H +#define UNIFIEDCACHE_SPACE_PROPERTY_H + +#include "file/ifile.h" +#include "status/status.h" + +namespace UC { + +class SpaceProperty { +public: + ~SpaceProperty(); + Status Setup(const std::string& propertyFilePath); + void IncreaseCapacity(const size_t delta); + void DecreaseCapacity(const size_t delta); + size_t GetCapacity() const; + +private: + Status InitShmProperty(IFile* shmPropertyFile); + Status LoadShmProperty(IFile* shmPropertyFile); + +private: + void* addr_{nullptr}; +}; + +} // namespace UC + +#endif \ No newline at end of file diff --git a/ucm/store/nfsstore/cc/domain/space/space_recycle.cc b/ucm/store/nfsstore/cc/domain/space/space_recycle.cc new file mode 100644 index 000000000..0e47a7f7c --- /dev/null +++ b/ucm/store/nfsstore/cc/domain/space/space_recycle.cc @@ -0,0 +1,130 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ + +#include "space_recycle.h" +#include "logger/logger.h" +#include "file/file.h" +#include "template/topn_heap.h" + +namespace UC { + +constexpr float recyclePercent = 0.1f; /* recycle 10% of the capacity */ +constexpr uint32_t maxRecycleNum = 10240; /* max recycle num */ + +struct BlockInfo { + std::string path; + size_t timestamp; +}; + +struct CmpTimestamp { + bool operator()(const BlockInfo& lhs, const BlockInfo& rhs) const { + return lhs.timestamp > rhs.timestamp; + } +}; + +size_t GetFileTimestamp(const std::string& path) { + IFile::FileStat st; + if (File::Stat(path, st).Failure()) { return 0; } + return st.st_mtim.tv_sec; +} + +void RemoveBlockFile(const std::string& path) { + File::Remove(path); + auto pos = path.rfind('/'); + if (pos == std::string::npos) { return; } + auto parent = path.substr(0, pos); + File::RmDir(parent); +} + +void DoRecycle(const SpaceLayout* layout, const uint32_t recycleNum, + SpaceRecycle::RecycleOneBlockDone done) { + auto earliestHeap = std::make_unique>(recycleNum); + auto it = layout->CreateFilePathIterator(); + while (it) { + auto filePath = layout->NextDataFilePath(it); + if (filePath.empty()) { break; } + auto timestamp = GetFileTimestamp(filePath); + if (timestamp == 0) { continue; } + earliestHeap->Push({filePath, timestamp}); + } + while (!earliestHeap->Empty()) { + RemoveBlockFile(earliestHeap->Top().path); + if (done) { done(); } + earliestHeap->Pop(); + } +} +SpaceRecycle::~SpaceRecycle() { + { + std::lock_guard lock(this->mtx_); + this->stop_ = true; + this->cv_.notify_all(); + } + if (this->worker_.joinable()) { + this->worker_.join(); + } +} +Status SpaceRecycle::Setup(const SpaceLayout* layout, const size_t totalNumber, + RecycleOneBlockDone done) { + this->layout_ = layout; + this->recycleNum_ = totalNumber * recyclePercent; + if (this->recycleNum_ == 0) { + this->recycleNum_ = 1; + } + this->recycleOneBlockDone_ = done; + if (this->recycleNum_ > maxRecycleNum) { + this->recycleNum_ = maxRecycleNum; + } + return Status::OK(); +} + +void SpaceRecycle::Trigger() +{ + if (!this->serviceRunning_) { + this->worker_ = std::thread(&SpaceRecycle::Recycler, this); + } + std::lock_guard lock(this->mtx_); + if (!this->recycling_) { + this->recycling_ = true; + this->cv_.notify_all(); + } +} + +void SpaceRecycle::Recycler() +{ + this->serviceRunning_ = true; + UC_INFO("Space Recycle service start successfully."); + while (true) { + { + std::unique_lock 
lock(this->mtx_); + this->cv_.wait(lock, [this] { return this->stop_ || this->recycling_; }); + if (this->stop_) { break; } + } + DoRecycle(this->layout_, this->recycleNum_, this->recycleOneBlockDone_); + { + std::lock_guard lock(this->mtx_); + this->recycling_ = false; + } + } +} +} // namespace UC diff --git a/ucm/store/nfsstore/cc/domain/space/space_recycle.h b/ucm/store/nfsstore/cc/domain/space/space_recycle.h new file mode 100644 index 000000000..31e89d31b --- /dev/null +++ b/ucm/store/nfsstore/cc/domain/space/space_recycle.h @@ -0,0 +1,61 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#ifndef UNIFIEDCACHE_SPACE_RECYCLE_H +#define UNIFIEDCACHE_SPACE_RECYCLE_H + +#include +#include +#include +#include +#include +#include "space_layout.h" + +namespace UC { + +class SpaceRecycle { +public: + using RecycleOneBlockDone = std::function; + SpaceRecycle() = default; + SpaceRecycle(const SpaceRecycle&) = delete; + SpaceRecycle& operator=(const SpaceRecycle&) = delete; + ~SpaceRecycle(); + Status Setup(const SpaceLayout* layout, const size_t totalNumber, + RecycleOneBlockDone done); + void Trigger(); +private: + void Recycler(); +private: + bool stop_{false}; + bool recycling_{false}; + std::atomic_bool serviceRunning_{false}; + uint32_t recycleNum_{0}; + RecycleOneBlockDone recycleOneBlockDone_; + const SpaceLayout* layout_{nullptr}; + std::mutex mtx_; + std::condition_variable cv_; + std::thread worker_; +}; + +} // namespace UC +#endif \ No newline at end of file diff --git a/ucm/store/nfsstore/cc/domain/space/space_shard_layout.cc b/ucm/store/nfsstore/cc/domain/space/space_shard_layout.cc new file mode 100644 index 000000000..88f145665 --- /dev/null +++ b/ucm/store/nfsstore/cc/domain/space/space_shard_layout.cc @@ -0,0 +1,230 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#include "space_shard_layout.h" +#include +#include +#include +#include +#include "file/file.h" +#include "logger/logger.h" + +namespace UC { + +constexpr size_t blockIdSize = 16; +constexpr size_t nU64PerBlock = blockIdSize / sizeof(uint64_t); +using BlockId = std::array; +static_assert(sizeof(BlockId) == blockIdSize); + +const std::string activatedFileSuffix = "act"; +const std::string archivedFileSuffix = "dat"; + +inline auto OpenDir(const std::string& path) +{ + auto dir = ::opendir(path.c_str()); + auto eno = errno; + if (!dir) { UC_ERROR("Failed({}) to open dir({}).", eno, path); } + return dir; +} + +// Define SpaceLayout::DataIterator as an empty base class +struct SpaceLayout::DataIterator { + virtual ~DataIterator() = default; +}; + +struct SpaceShardLayout::DataIterator : public SpaceLayout::DataIterator { + const SpaceLayout* layout{nullptr}; + std::string root; + std::string current; + std::stack> stk; + ~DataIterator() + { + while (!this->stk.empty()) { + ::closedir(this->stk.top().first); + this->stk.pop(); + } + } + Status Setup(const SpaceLayout* layout, const std::string& root) { + this->layout = layout; + this->root = root; + auto dir = OpenDir(root); + if (!dir) { return Status::OsApiError(); } + this->stk.emplace(dir, root); + return Status::OK(); + } + Status Next() { + this->current.clear(); + while (!this->stk.empty()) { + auto entry = ::readdir64(this->stk.top().first); + if (entry == nullptr) { + ::closedir(this->stk.top().first); + this->stk.pop(); + continue; + } + std::string name{entry->d_name}; + if (name.front() == '.') { continue; } + if (this->layout->IsActivatedFile(name)) { continue; } + const auto& dir = this->stk.top().second; + auto fullpath = this->stk.top().second + "/" + name; + if (dir == this->root) { + auto sub = OpenDir(fullpath); + if (!sub) { return Status::OsApiError(); } + this->stk.emplace(sub, fullpath); + continue; + } + this->current = std::move(fullpath); + return Status::OK(); + } + return Status::NotFound(); + } +}; + +Status SpaceShardLayout::Setup(const std::vector& storageBackends) +{ + if (storageBackends.empty()) { + UC_ERROR("Empty backend list."); + return Status::InvalidParam(); + } + auto status = Status::OK(); + for (auto& path : storageBackends) { + if ((status = this->AddStorageBackend(path)).Failure()) { return status; } + } + return status; +} + +std::string SpaceShardLayout::DataFileParent(const std::string& blockId, bool activated) const 
+{ + uint64_t front, back; + this->ShardBlockId(blockId, front, back); + return fmt::format("{}{}/{:016x}", this->StorageBackend(blockId), this->DataFileRoot(), front); +} + +std::string SpaceShardLayout::DataFilePath(const std::string& blockId, bool activated) const +{ + uint64_t front, back; + this->ShardBlockId(blockId, front, back); + return fmt::format("{}{}/{:016x}/{:016x}.{}", this->StorageBackend(blockId), + this->DataFileRoot(), front, back, activated ? activatedFileSuffix : archivedFileSuffix); +} + +Status SpaceShardLayout::AddStorageBackend(const std::string& path) +{ + auto normalizedPath = path; + if (normalizedPath.back() != '/') { normalizedPath += '/'; } + auto status = Status::OK(); + if (this->storageBackends_.empty()) { + status = this->AddFirstStorageBackend(normalizedPath); + } else { + status = this->AddSecondaryStorageBackend(normalizedPath); + } + if (status.Failure()) { + UC_ERROR("Failed({}) to add storage backend({}).", status, normalizedPath); + } + return status; +} + +Status SpaceShardLayout::AddFirstStorageBackend(const std::string& path) +{ + for (const auto& root : this->RelativeRoots()) { + auto dir = File::Make(path + root); + if (!dir) { return Status::OutOfMemory(); } + auto status = dir->MkDir(); + if (status == Status::DuplicateKey()) { status = Status::OK(); } + if (status.Failure()) { return status; } + } + this->storageBackends_.emplace_back(path); + return Status::OK(); +} + +Status SpaceShardLayout::AddSecondaryStorageBackend(const std::string& path) +{ + auto iter = std::find(this->storageBackends_.begin(), this->storageBackends_.end(), path); + if (iter != this->storageBackends_.end()) { return Status::OK(); } + constexpr auto accessMode = IFile::AccessMode::READ | IFile::AccessMode::WRITE; + for (const auto& root : this->RelativeRoots()) { + auto dir = File::Make(path + root); + if (!dir) { return Status::OutOfMemory(); } + if (dir->Access(accessMode).Failure()) { return Status::InvalidParam(); } + } + this->storageBackends_.emplace_back(path); + return Status::OK(); +} + +std::string SpaceShardLayout::StorageBackend(const std::string& blockId) const +{ + static std::hash hasher; + return this->storageBackends_[hasher(blockId) % this->storageBackends_.size()]; +} + +std::vector SpaceShardLayout::RelativeRoots() const { + return { + this->DataFileRoot(), + this->ClusterFileRoot(), + }; +} + +std::string SpaceShardLayout::DataFileRoot() const { return "data"; } +std::string SpaceShardLayout::ClusterFileRoot() const { return "cluster"; } +void SpaceShardLayout::ShardBlockId(const std::string& blockId, uint64_t& front, + uint64_t& back) const +{ + auto id = static_cast(static_cast(blockId.data())); + front = id->front(); + back = id->back(); +} + +std::string SpaceShardLayout::StorageBackend() const { return this->storageBackends_.front(); } +std::string SpaceShardLayout::ClusterPropertyFilePath() const +{ + return fmt::format("{}{}/{}.bin", this->StorageBackend(), this->ClusterFileRoot(), "uc_property"); +} + +std::shared_ptr SpaceShardLayout::CreateFilePathIterator() const +{ + auto dataRoot = this->StorageBackend() + this->DataFileRoot(); + std::shared_ptr iter = nullptr; + try { + iter = std::make_shared(); + } catch (const std::exception& e) { + UC_ERROR("Failed to create data iterator: {}", e.what()); + return nullptr; + } + if (iter->Setup(this, dataRoot).Failure()) { + return nullptr; + } + return std::dynamic_pointer_cast(iter); +} + +std::string SpaceShardLayout::NextDataFilePath(std::shared_ptr iter) const +{ + auto shard_iter = 
std::dynamic_pointer_cast(iter); + if (!shard_iter) { return std::string{}; } + if (shard_iter->Next().Failure()) { return std::string{}; } + return shard_iter->current; +} + +bool SpaceShardLayout::IsActivatedFile(const std::string& filePath) const +{ + return std::equal(activatedFileSuffix.rbegin(), activatedFileSuffix.rend(), filePath.rbegin()); +} +} // namespace UC diff --git a/ucm/store/nfsstore/cc/domain/space/space_shard_layout.h b/ucm/store/nfsstore/cc/domain/space/space_shard_layout.h new file mode 100644 index 000000000..60dd4820d --- /dev/null +++ b/ucm/store/nfsstore/cc/domain/space/space_shard_layout.h @@ -0,0 +1,58 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#ifndef UNIFIEDCACHE_SPACE_SHARD_LAYOUT_H +#define UNIFIEDCACHE_SPACE_SHARD_LAYOUT_H + +#include "space_layout.h" + +namespace UC { + +class SpaceShardLayout : public SpaceLayout { +public: + struct DataIterator; +public: + Status Setup(const std::vector& storageBackends) override; + std::string DataFileParent(const std::string& blockId, bool activated) const override; + std::string DataFilePath(const std::string& blockId, bool activated) const override; + std::string ClusterPropertyFilePath() const override; + std::shared_ptr CreateFilePathIterator() const override; + std::string NextDataFilePath(std::shared_ptr iter) const override; + bool IsActivatedFile(const std::string& filePath) const override; + +protected: + virtual std::vector RelativeRoots() const; + virtual Status AddStorageBackend(const std::string& path); + virtual Status AddFirstStorageBackend(const std::string& path); + virtual Status AddSecondaryStorageBackend(const std::string& path); + virtual std::string StorageBackend(const std::string& blockId) const; + virtual std::string DataFileRoot() const; + virtual std::string ClusterFileRoot() const; + virtual std::string StorageBackend() const; + virtual void ShardBlockId(const std::string& blockId, uint64_t& front, uint64_t& back) const; + std::vector storageBackends_; +}; + +} // namespace UC + +#endif diff --git a/ucm/store/nfsstore/cc/domain/space/space_shard_temp_layout.cc b/ucm/store/nfsstore/cc/domain/space/space_shard_temp_layout.cc new file mode 100644 index 000000000..f42e1eb5a --- /dev/null +++ b/ucm/store/nfsstore/cc/domain/space/space_shard_temp_layout.cc @@ -0,0 +1,56 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. 
All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#include "space_shard_temp_layout.h" +#include + +namespace UC { + +std::string SpaceShardTempLayout::DataFileParent(const std::string& blockId, bool activated) const +{ + if (!activated) { return SpaceShardLayout::DataFileParent(blockId, activated); } + uint64_t front, back; + this->ShardBlockId(blockId, front, back); + return fmt::format("{}{}/{:016x}", this->StorageBackend(blockId), this->TempDataFileRoot(), + front); +} + +std::string SpaceShardTempLayout::DataFilePath(const std::string& blockId, bool activated) const +{ + if (!activated) { return SpaceShardLayout::DataFilePath(blockId, activated); } + uint64_t front, back; + this->ShardBlockId(blockId, front, back); + return fmt::format("{}{}/{:016x}/{:016x}.dat", this->StorageBackend(blockId), + this->TempDataFileRoot(), front, back); +} + +std::vector SpaceShardTempLayout::RelativeRoots() const +{ + auto roots = SpaceShardLayout::RelativeRoots(); + roots.push_back(this->TempDataFileRoot()); + return roots; +} + +std::string SpaceShardTempLayout::TempDataFileRoot() const { return "temp"; } + +} // namespace UC diff --git a/ucm/store/nfsstore/cc/domain/space/space_shard_temp_layout.h b/ucm/store/nfsstore/cc/domain/space/space_shard_temp_layout.h new file mode 100644 index 000000000..a63bd468e --- /dev/null +++ b/ucm/store/nfsstore/cc/domain/space/space_shard_temp_layout.h @@ -0,0 +1,43 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#ifndef UNIFIEDCACHE_SPACE_SHARD_TEMP_LAYOUT_H +#define UNIFIEDCACHE_SPACE_SHARD_TEMP_LAYOUT_H + +#include "space_shard_layout.h" + +namespace UC { + +class SpaceShardTempLayout : public SpaceShardLayout { +public: + std::string DataFileParent(const std::string& blockId, bool activated) const override; + std::string DataFilePath(const std::string& blockId, bool activated) const override; + +protected: + std::vector RelativeRoots() const override; + virtual std::string TempDataFileRoot() const; +}; + +} // namespace UC + +#endif diff --git a/ucm/store/nfsstore/cc/domain/trans/posix_queue.cc b/ucm/store/nfsstore/cc/domain/trans/posix_queue.cc new file mode 100644 index 000000000..0b525cb13 --- /dev/null +++ b/ucm/store/nfsstore/cc/domain/trans/posix_queue.cc @@ -0,0 +1,141 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#include "posix_queue.h" +#include "file/file.h" + +namespace UC { + +template +bool IsAligned(const T value) +{ + static constexpr size_t alignment = 4096; + static constexpr size_t alignMask = alignment - 1; + return (value & alignMask) == 0; +} + +Status PosixQueue::Setup(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber, + TaskSet* failureSet, const SpaceLayout* layout, const size_t timeoutMs, bool useDirect) +{ + this->deviceId_ = deviceId; + this->bufferSize_ = bufferSize; + this->bufferNumber_ = bufferNumber; + this->failureSet_ = failureSet; + this->layout_ = layout; + this->useDirect_ = useDirect; + auto success = + this->backend_.SetWorkerInitFn([this](auto& device) { return this->Init(device); }) + .SetWorkerFn([this](auto& shard, const auto& device) { this->Work(shard, device); }) + .SetWorkerExitFn([this](auto& device) { this->Exit(device); }) + .Run(); + return success ? 
Status::OK() : Status::Error(); +} + +void PosixQueue::Push(std::list& shards) noexcept { this->backend_.Push(shards); } + +bool PosixQueue::Init(Device& device) +{ + if (this->deviceId_ < 0) { return true; } + device = DeviceFactory::Make(this->deviceId_, this->bufferSize_, this->bufferNumber_); + if (!device) { return false; } + return device->Setup().Success(); +} + +void PosixQueue::Exit(Device& device) { device.reset(); } + +void PosixQueue::Work(Task::Shard& shard, const Device& device) +{ + if (this->failureSet_->Contains(shard.owner)) { + this->Done(shard, device, true); + return; + } + auto status = Status::OK(); + if (shard.location == Task::Location::DEVICE) { + if (shard.type == Task::Type::DUMP) { + status = this->D2S(shard, device); + } else { + status = this->S2D(shard, device); + } + } else { + if (shard.type == Task::Type::DUMP) { + status = this->H2S(shard); + } else { + status = this->S2H(shard); + } + } + this->Done(shard, device, status.Success()); +} + +void PosixQueue::Done(Task::Shard& shard, const Device& device, const bool success) +{ + if (!success) { this->failureSet_->Insert(shard.owner); } + if (!shard.done) { return; } + if (device) { + if (device->Synchronized().Failure()) { this->failureSet_->Insert(shard.owner); } + } + shard.done(); +} + +Status PosixQueue::D2S(Task::Shard& shard, const Device& device) +{ + shard.buffer = device->GetBuffer(shard.length); + if (!shard.buffer) { + UC_ERROR("Out of memory({}).", shard.length); + return Status::OutOfMemory(); + } + auto hub = shard.buffer.get(); + auto status = device->D2HSync((std::byte*)hub, (std::byte*)shard.address, shard.length); + if (status.Failure()) { return status; } + auto path = this->layout_->DataFilePath(shard.block, true); + return File::Write(path, shard.offset, shard.length, (uintptr_t)hub, useDirect_); +} + +Status PosixQueue::S2D(Task::Shard& shard, const Device& device) +{ + shard.buffer = device->GetBuffer(shard.length); + if (!shard.buffer) { + UC_ERROR("Out of memory({}).", shard.length); + return Status::OutOfMemory(); + } + auto hub = shard.buffer.get(); + auto path = this->layout_->DataFilePath(shard.block, false); + auto status = File::Read(path, shard.offset, shard.length, (uintptr_t)hub, useDirect_); + if (status.Failure()) { return status; } + return device->H2DAsync((std::byte*)shard.address, (std::byte*)hub, shard.length); +} + +Status PosixQueue::H2S(Task::Shard& shard) +{ + auto path = this->layout_->DataFilePath(shard.block, true); + auto aligned = IsAligned(shard.offset) && IsAligned(shard.length) && IsAligned(shard.address); + return File::Write(path, shard.offset, shard.length, shard.address, aligned); +} + +Status PosixQueue::S2H(Task::Shard& shard) +{ + auto path = this->layout_->DataFilePath(shard.block, false); + auto aligned = IsAligned(shard.offset) && IsAligned(shard.length) && IsAligned(shard.address); + return File::Read(path, shard.offset, shard.length, shard.address, aligned); +} + +} // namespace UC diff --git a/ucm/store/dramstore/cc/api/dramstore.cc b/ucm/store/nfsstore/cc/domain/trans/posix_queue.h similarity index 50% rename from ucm/store/dramstore/cc/api/dramstore.cc rename to ucm/store/nfsstore/cc/domain/trans/posix_queue.h index 56b4350f3..0e2dc17f8 100644 --- a/ucm/store/dramstore/cc/api/dramstore.cc +++ b/ucm/store/nfsstore/cc/domain/trans/posix_queue.h @@ -21,41 +21,44 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
* */ -#include "dramstore.h" -#include "logger/logger.h" +#ifndef UNIFIEDCACHE_POSIX_QUEUE_H +#define UNIFIEDCACHE_POSIX_QUEUE_H + +#include "device/idevice.h" +#include "space/space_layout.h" #include "status/status.h" +#include "task_queue.h" +#include "task_set.h" +#include "thread/thread_pool.h" namespace UC { -class DRAMStoreImpl : public DRAMStore { +class PosixQueue : public TaskQueue { + using Device = std::unique_ptr; + int32_t deviceId_{-1}; + size_t bufferSize_{0}; + size_t bufferNumber_{0}; + TaskSet* failureSet_{nullptr}; + const SpaceLayout* layout_{nullptr}; + bool useDirect_{false}; + ThreadPool backend_{}; + public: - int32_t Setup(const size_t ioSize, const size_t capacity, const int32_t deviceId) { return -1; } - int32_t Alloc(const std::string& block) override { return -1; } - bool Lookup(const std::string& block) override { return false; } - void Commit(const std::string& block, const bool success) override {} - std::list Alloc(const std::list& blocks) override - { - return std::list(); - } - std::list Lookup(const std::list& blocks) override - { - return std::list(); - } - void Commit(const std::list& blocks, const bool success) override {} - size_t Submit(Task&& task) override { return 0; } - int32_t Wait(const size_t task) override { return -1; } - int32_t Check(const size_t task, bool& finish) override { return -1; } -}; + Status Setup(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber, + TaskSet* failureSet, const SpaceLayout* layout, const size_t timeoutMs, bool useDirect = false); + void Push(std::list& shards) noexcept override; -int32_t DRAMStore::Setup(const Config& config) -{ - auto impl = new (std::nothrow) DRAMStoreImpl(); - if (!impl) { - UC_ERROR("Out of memory."); - return Status::OutOfMemory().Underlying(); - } - this->impl_ = impl; - return impl->Setup(config.ioSize, config.capacity, config.deviceId); -} +private: + bool Init(Device& device); + void Exit(Device& device); + void Work(Task::Shard& shard, const Device& device); + void Done(Task::Shard& shard, const Device& device, const bool success); + Status D2S(Task::Shard& shard, const Device& device); + Status S2D(Task::Shard& shard, const Device& device); + Status H2S(Task::Shard& shard); + Status S2H(Task::Shard& shard); +}; } // namespace UC + +#endif diff --git a/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task.h b/ucm/store/nfsstore/cc/domain/trans/trans_manager.h similarity index 59% rename from ucm/store/nfsstore/cc/domain/tsf_task/tsf_task.h rename to ucm/store/nfsstore/cc/domain/trans/trans_manager.h index b3f8cba40..06748042e 100644 --- a/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task.h +++ b/ucm/store/nfsstore/cc/domain/trans/trans_manager.h @@ -21,39 +21,30 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
* */ -#ifndef UNIFIEDCACHE_TSF_TASK_H -#define UNIFIEDCACHE_TSF_TASK_H +#ifndef UNIFIEDCACHE_TRANS_MANAGER_H +#define UNIFIEDCACHE_TRANS_MANAGER_H -#include "ucmstore.h" -#include "tsf_task_waiter.h" +#include "posix_queue.h" +#include "task_manager.h" namespace UC { -class TsfTask { +class TransManager : public TaskManager { public: - using Type = CCStore::Task::Type; - using Location = CCStore::Task::Location; - -public: - TsfTask(const Type type, const Location location, const std::string& blockId, - const size_t offset, const uintptr_t address, const size_t length) - : type{type}, location{location}, blockId{blockId}, offset{offset}, address{address}, - length{length}, owner{0}, waiter{nullptr}, hub{nullptr} + Status Setup(const int32_t deviceId, const size_t streamNumber, const size_t ioSize, + const size_t bufferNumber, const SpaceLayout* layout, const size_t timeoutMs, bool useDirect = false) { + this->timeoutMs_ = timeoutMs; + auto status = Status::OK(); + for (size_t i = 0; i < streamNumber; i++) { + auto q = std::make_shared(); + status = + q->Setup(deviceId, ioSize, bufferNumber, &this->failureSet_, layout, timeoutMs, useDirect); + if (status.Failure()) { break; } + this->queues_.emplace_back(std::move(q)); + } + return status; } - TsfTask() : TsfTask{Type::DUMP, Location::HOST, {}, 0, 0, 0} {} - -public: - Type type; - Location location; - std::string blockId; - size_t offset; - uintptr_t address; - size_t length; - - size_t owner; - std::shared_ptr waiter; - std::shared_ptr hub; }; } // namespace UC diff --git a/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_manager.cc b/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_manager.cc deleted file mode 100644 index a74aad05e..000000000 --- a/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_manager.cc +++ /dev/null @@ -1,107 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * */ -#include "tsf_task_manager.h" - -namespace UC { - -Status TsfTaskManager::Setup(const int32_t deviceId, const size_t streamNumber, - const size_t bufferSize, const size_t bufferNumber, - const size_t timeoutMs, const SpaceLayout* layout) -{ - this->_queues.reserve(streamNumber); - for (size_t i = 0; i < streamNumber; ++i) { - auto& queue = this->_queues.emplace_back(std::make_unique()); - auto status = queue->Setup(deviceId, bufferSize, bufferNumber, &this->_failureSet, layout); - if (status.Failure()) { return status; } - } - this->_timeoutMs = timeoutMs; - return Status::OK(); -} - -Status TsfTaskManager::Submit(std::list& tasks, const size_t size, const size_t number, - const std::string& brief, size_t& taskId) -{ - std::unique_lock lk(this->_mutex); - taskId = ++this->_taskIdSeed; - auto [iter, success] = this->_waiters.emplace( - taskId, std::make_shared(taskId, size, number, brief)); - if (!success) { return Status::OutOfMemory(); } - std::vector> lists; - this->Dispatch(tasks, lists, taskId, iter->second); - for (size_t i = 0; i < lists.size(); i++) { - if (lists[i].empty()) { continue; } - this->_queues[this->_qIdx]->Push(lists[i]); - this->_qIdx = (this->_qIdx + 1) % this->_queues.size(); - } - return Status::OK(); -} - -Status TsfTaskManager::Wait(const size_t taskId) -{ - std::shared_ptr waiter = nullptr; - { - std::unique_lock lk(this->_mutex); - auto iter = this->_waiters.find(taskId); - if (iter == this->_waiters.end()) { return Status::NotFound(); } - waiter = iter->second; - this->_waiters.erase(iter); - } - if (!waiter->Wait(this->_timeoutMs)) { - this->_failureSet.Insert(taskId); - waiter->Wait(); - } - bool failure = this->_failureSet.Contains(taskId); - this->_failureSet.Remove(taskId); - if (failure) { UC_ERROR("Transfer task({}) failed.", taskId); } - return failure ? Status::Error() : Status::OK(); -} - -Status TsfTaskManager::Check(const size_t taskId, bool& finish) -{ - std::lock_guard lk(this->_mutex); - auto iter = this->_waiters.find(taskId); - if (iter == this->_waiters.end()) { return Status::NotFound(); } - finish = iter->second->Finish(); - return Status::OK(); -} - -void TsfTaskManager::Dispatch(std::list& tasks, std::vector>& targets, - const size_t taskId, std::shared_ptr waiter) const -{ - auto qNumber = this->_queues.size(); - auto index = size_t(0); - targets.resize(qNumber); - auto it = tasks.begin(); - while (it != tasks.end()) { - auto next = std::next(it); - it->owner = taskId; - it->waiter = waiter; - auto& target = targets[index % qNumber]; - target.splice(target.end(), tasks, it); - index++; - it = next; - } -} - -} // namespace UC diff --git a/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_queue.cc b/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_queue.cc deleted file mode 100644 index b51760567..000000000 --- a/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_queue.cc +++ /dev/null @@ -1,202 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ -#include "tsf_task_queue.h" -#include "file/file.h" - -namespace UC { - -#define UC_TASK_ERROR(s, t) \ - do { \ - UC_ERROR("Failed({}) to run task({},{},{},{}).", (s), (t).owner, (t).blockId, (t).offset, \ - (t).length); \ - } while (0) - -Status TsfTaskQueue::Setup(const int32_t deviceId, const size_t bufferSize, - const size_t bufferNumber, TsfTaskSet* failureSet, - const SpaceLayout* layout) -{ - this->_failureSet = failureSet; - this->_layout = layout; - if (deviceId >= 0) { - this->_device = DeviceFactory::Make(deviceId, bufferSize, bufferNumber); - if (!this->_device) { return Status::OutOfMemory(); } - if (!this->_streamOper.Setup([this](TsfTask& task) { this->StreamOper(task); }, - [this] { return this->_device->Setup().Success(); })) { - return Status::Error(); - } - } - if (!this->_fileOper.Setup([this](TsfTask& task) { this->FileOper(task); })) { - return Status::Error(); - } - return Status::OK(); -} - -void TsfTaskQueue::Push(std::list& tasks) -{ - auto& front = tasks.front(); - if (front.location == TsfTask::Location::HOST || front.type == TsfTask::Type::LOAD) { - this->_fileOper.Push(tasks); - } else { - this->_streamOper.Push(tasks); - } -} - -void TsfTaskQueue::StreamOper(TsfTask& task) -{ - if (this->_failureSet->Contains(task.owner)) { - this->Done(task, false); - return; - } - if (task.type == TsfTask::Type::LOAD) { - this->H2D(task); - } else { - this->D2H(task); - } -} - -void TsfTaskQueue::FileOper(TsfTask& task) -{ - if (this->_failureSet->Contains(task.owner)) { - this->Done(task, false); - return; - } - if (task.type == TsfTask::Type::LOAD) { - this->S2H(task); - } else { - this->H2S(task); - } -} - -void TsfTaskQueue::H2D(TsfTask& task) -{ - auto status = this->_device->H2DAsync((std::byte*)task.address, task.hub.get(), task.length); - if (status.Failure()) { - UC_TASK_ERROR(status, task); - this->Done(task, false); - return; - } - status = this->_device->AppendCallback([this, task](bool success) mutable { - if (!success) { UC_TASK_ERROR(Status::Error(), task); } - this->Done(task, success); - }); - if (status.Failure()) { - UC_TASK_ERROR(status, task); - this->Done(task, false); - return; - } -} - -void TsfTaskQueue::D2H(TsfTask& task) -{ - task.hub = this->_device->GetBuffer(task.length); - if (!task.hub) { - UC_TASK_ERROR(Status::OutOfMemory(), task); - this->Done(task, false); - return; - } - auto status = this->_device->D2HAsync(task.hub.get(), (std::byte*)task.address, 
task.length); - if (status.Failure()) { - UC_TASK_ERROR(status, task); - this->Done(task, false); - return; - } - status = this->_device->AppendCallback([this, task](bool success) mutable { - if (success) { - this->_fileOper.Push(std::move(task)); - } else { - UC_TASK_ERROR(Status::Error(), task); - this->Done(task, false); - return; - } - }); - if (status.Failure()) { - UC_TASK_ERROR(status, task); - this->Done(task, false); - return; - } -} - -void TsfTaskQueue::H2S(TsfTask& task) -{ - const void* src = - task.location == TsfTask::Location::HOST ? (void*)task.address : task.hub.get(); - auto path = this->_layout->DataFilePath(task.blockId, true); - auto status = Status::OK(); - do { - auto file = File::Make(path); - if (!file) { - status = Status::OutOfMemory(); - break; - } - if ((status = file->Open(IFile::OpenFlag::WRITE_ONLY)).Failure()) { break; } - if ((status = file->Write(src, task.length, task.offset)).Failure()) { break; } - } while (0); - if (status.Failure()) { - UC_TASK_ERROR(status, task); - this->Done(task, false); - return; - } - this->Done(task, true); -} - -void TsfTaskQueue::S2H(TsfTask& task) -{ - auto path = this->_layout->DataFilePath(task.blockId, false); - auto status = Status::OK(); - do { - auto dst = (void*)task.address; - if (task.location == TsfTask::Location::DEVICE) { - if (!(task.hub = this->_device->GetBuffer(task.length))) { - status = Status::OutOfMemory(); - break; - } - dst = task.hub.get(); - } - auto file = File::Make(path); - if (!file) { - status = Status::OutOfMemory(); - break; - } - if ((status = file->Open(IFile::OpenFlag::READ_ONLY)).Failure()) { break; } - if ((status = file->Read(dst, task.length, task.offset)).Failure()) { break; } - } while (0); - if (status.Failure()) { - UC_TASK_ERROR(status, task); - this->Done(task, false); - return; - } - if (task.location == TsfTask::Location::HOST) { - this->Done(task, true); - return; - } - this->_streamOper.Push(std::move(task)); -} - -void TsfTaskQueue::Done(const TsfTask& task, bool success) -{ - if (!success) { this->_failureSet->Insert(task.owner); } - task.waiter->Done(); -} - -} // namespace UC diff --git a/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_waiter.h b/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_waiter.h deleted file mode 100644 index 00614187c..000000000 --- a/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_waiter.h +++ /dev/null @@ -1,84 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ -#ifndef UNIFIEDCACHE_TSF_TASK_WAITER_H -#define UNIFIEDCACHE_TSF_TASK_WAITER_H - -#include "logger/logger.h" -#include "thread/latch.h" -#include "time/stopwatch.h" - -namespace UC { - -class TsfTaskWaiter : public Latch { -public: - TsfTaskWaiter(const size_t id, const size_t size, const size_t number, const std::string& brief) - : Latch{number}, id_{id}, size_{size}, number_{number}, brief_{brief} - { - } - void Done() - { - if (Latch::Done() == 0) { - auto elapsed = this->sw_.Elapsed().count(); - UC_DEBUG("Task({},{},{},{}) finished, elapsed={:.06f}s, bw={:.06f}GB/s.", this->id_, - this->brief_, this->number_, this->size_, elapsed, - this->size_ / elapsed / (1ULL << 30)); - this->Notify(); - } - } - using Latch::Wait; - bool Wait(const size_t timeoutMs) - { - if (timeoutMs == 0) { - this->Wait(); - return true; - } - auto finish = false; - { - std::unique_lock lk(this->mutex_); - if (this->counter_ == 0) { return true; } - auto elapsed = (size_t)this->sw_.ElapsedMs().count(); - if (elapsed < timeoutMs) { - finish = this->cv_.wait_for(lk, std::chrono::milliseconds(timeoutMs - elapsed), - [this] { return this->counter_ == 0; }); - } - } - if (!finish) { - UC_WARN("Task({},{},{},{}) timeout, elapsed={:.06f}s.", this->id_, this->brief_, - this->number_, this->size_, this->sw_.Elapsed().count()); - } - return finish; - } - bool Finish() { return this->counter_ == 0; } - -private: - size_t id_; - size_t size_; - size_t number_; - std::string brief_; - StopWatch sw_; -}; - -} // namespace UC - -#endif // UNIFIEDCACHE_TSF_TASK_WAITER_H diff --git a/ucm/store/nfsstore/cpy/nfsstore.py.cc b/ucm/store/nfsstore/cpy/nfsstore.py.cc index 1471f798b..0d24dc105 100644 --- a/ucm/store/nfsstore/cpy/nfsstore.py.cc +++ b/ucm/store/nfsstore/cpy/nfsstore.py.cc @@ -57,43 +57,42 @@ class NFSStorePy : public NFSStore { size_t LoadToDevice(const py::list& blockIds, const py::list& offsets, const py::list& addresses, const py::list& lengths) { - return this->SubmitPy(blockIds, offsets, addresses, lengths, CCStore::Task::Type::LOAD, - CCStore::Task::Location::DEVICE, "NFS::S2D"); + return this->SubmitPy(blockIds, offsets, addresses, lengths, Task::Type::LOAD, + Task::Location::DEVICE, "NFS::S2D"); } size_t LoadToHost(const py::list& blockIds, const py::list& offsets, const py::list& addresses, const py::list& lengths) { - return this->SubmitPy(blockIds, offsets, addresses, lengths, CCStore::Task::Type::LOAD, - CCStore::Task::Location::HOST, "NFS::S2H"); + return this->SubmitPy(blockIds, offsets, addresses, lengths, Task::Type::LOAD, + Task::Location::HOST, "NFS::S2H"); } size_t DumpFromDevice(const py::list& blockIds, const py::list& offsets, const py::list& addresses, const py::list& lengths) { - return this->SubmitPy(blockIds, offsets, addresses, lengths, CCStore::Task::Type::DUMP, - CCStore::Task::Location::DEVICE, "NFS::D2S"); + return this->SubmitPy(blockIds, offsets, addresses, lengths, Task::Type::DUMP, + Task::Location::DEVICE, "NFS::D2S"); } size_t DumpFromHost(const py::list& blockIds, const py::list& offsets, const py::list& addresses, const py::list& lengths) { - return this->SubmitPy(blockIds, offsets, addresses, lengths, CCStore::Task::Type::DUMP, - CCStore::Task::Location::HOST, "NFS::H2S"); + return 
this->SubmitPy(blockIds, offsets, addresses, lengths, Task::Type::DUMP, + Task::Location::HOST, "NFS::H2S"); } private: size_t SubmitPy(const py::list& blockIds, const py::list& offsets, const py::list& addresses, - const py::list& lengths, const CCStore::Task::Type type, - const CCStore::Task::Location location, const std::string& brief) + const py::list& lengths, Task::Type&& type, Task::Location&& location, + std::string&& brief) { - CCStore::Task task{type, location, brief}; + Task task{std::move(type), std::move(location), std::move(brief)}; auto blockId = blockIds.begin(); auto offset = offsets.begin(); auto address = addresses.begin(); auto length = lengths.begin(); while ((blockId != blockIds.end()) && (offset != offsets.end()) && (address != addresses.end()) && (length != lengths.end())) { - auto ret = task.Append(blockId->cast(), offset->cast(), - address->cast(), length->cast()); - if (ret != 0) { return CCStore::invalidTaskId; } + task.Append(blockId->cast(), offset->cast(), + address->cast(), length->cast()); blockId++; offset++; address++; @@ -121,7 +120,15 @@ PYBIND11_MODULE(ucmnfsstore, module) config.def_readwrite("transferDeviceId", &UC::NFSStorePy::Config::transferDeviceId); config.def_readwrite("transferStreamNumber", &UC::NFSStorePy::Config::transferStreamNumber); config.def_readwrite("transferIoSize", &UC::NFSStorePy::Config::transferIoSize); + config.def_readwrite("transferIoDirect", &UC::NFSStorePy::Config::transferIoDirect); config.def_readwrite("transferBufferNumber", &UC::NFSStorePy::Config::transferBufferNumber); + config.def_readwrite("transferTimeoutMs", &UC::NFSStorePy::Config::transferTimeoutMs); + config.def_readwrite("tempDumpDirEnable", &UC::NFSStorePy::Config::tempDumpDirEnable); + config.def_readwrite("hotnessEnable", &UC::NFSStorePy::Config::hotnessEnable); + config.def_readwrite("hotnessInterval", &UC::NFSStorePy::Config::hotnessInterval); + config.def_readwrite("storageCapacity", &UC::NFSStorePy::Config::storageCapacity); + config.def_readwrite("recycleEnable", &UC::NFSStorePy::Config::recycleEnable); + config.def_readwrite("recycleThresholdRatio", &UC::NFSStorePy::Config::recycleThresholdRatio); store.def(py::init<>()); store.def("CCStoreImpl", &UC::NFSStorePy::CCStoreImpl); store.def("Setup", &UC::NFSStorePy::Setup); diff --git a/ucm/store/nfsstore/nfsstore_connector.py b/ucm/store/nfsstore/nfsstore_connector.py index b1176c1bc..c21c686a7 100644 --- a/ucm/store/nfsstore/nfsstore_connector.py +++ b/ucm/store/nfsstore/nfsstore_connector.py @@ -51,6 +51,19 @@ def __init__(self, config: Dict): if transfer_enable: param.transferDeviceId = config["device"] param.transferIoSize = config["io_size"] + param.transferIoDirect = config.get("use_direct", False) + param.transferStreamNumber = config.get("stream_number", 32) + param.transferBufferNumber = config.get("buffer_number", 512) + # NOTE: compatible with legacy nfsstore lib + if hasattr(param, "storage_capacity"): + param.storageCapacity = config.get("storage_capacity", 0) + if hasattr(param, "recycle_enable"): + param.recycleEnable = ( + True if config.get("recycle_enable", 0) == 1 else False + ) + if param.recycleEnable: + param.recycleThresholdRatio = config.get("recycle_threshold_ratio", 0.7) + ret = self.store.Setup(param) if ret != 0: msg = f"Failed to initialize ucmnfsstore, errcode: {ret}." 
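(Aside, not part of the patch: the hunk above reads several new keys from the connector's plain config dict. A minimal sketch of such a dict follows, under assumptions — the constructor name UcmNfsStore and the example values are hypothetical; only the key names and defaults are taken from the diff itself.)

    config = {
        "device": 0,                      # transferDeviceId
        "io_size": 262144,                # transferIoSize (example value only)
        "use_direct": False,              # transferIoDirect, defaults to False
        "stream_number": 32,              # transferStreamNumber, defaults to 32
        "buffer_number": 512,             # transferBufferNumber, defaults to 512
        "storage_capacity": 0,            # applied only when the lib exposes storageCapacity
        "recycle_enable": 1,              # 1 maps to recycleEnable = True
        "recycle_threshold_ratio": 0.7,   # read only when recycling is enabled
    }
    # store = UcmNfsStore(config)         # hypothetical constructor name
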
@@ -88,6 +101,26 @@ def dump( ) return NfsTask(task_id=task_id) + def fetch_data( + self, + block_ids: List[str], + offset: List[int], + dst_addr: List[int], + size: List[int], + ) -> Task: + task_id = self.store.LoadToDevice(block_ids, offset, dst_addr, size) + return NfsTask(task_id=task_id) + + def dump_data( + self, + block_ids: List[str], + offset: List[int], + src_addr: List[int], + size: List[int], + ) -> Task: + task_id = self.store.DumpFromDevice(block_ids, offset, src_addr, size) + return NfsTask(task_id=task_id) + def wait(self, task: Task) -> int: return self.store.Wait(task.task_id) diff --git a/ucm/store/pcstore/CMakeLists.txt b/ucm/store/pcstore/CMakeLists.txt new file mode 100644 index 000000000..3511bd72b --- /dev/null +++ b/ucm/store/pcstore/CMakeLists.txt @@ -0,0 +1,12 @@ +file(GLOB_RECURSE UCMSTORE_PC_CC_SOURCE_FILES "./cc/*.cc") +add_library(pcstore STATIC ${UCMSTORE_PC_CC_SOURCE_FILES}) +target_include_directories(pcstore PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/cc/api + ${CMAKE_CURRENT_SOURCE_DIR}/cc/domain +) +target_link_libraries(pcstore PUBLIC trans storeinfra) + +file(GLOB_RECURSE UCMSTORE_PC_CPY_SOURCE_FILES "./cpy/*.cc") +pybind11_add_module(ucmpcstore ${UCMSTORE_PC_CPY_SOURCE_FILES}) +target_link_libraries(ucmpcstore PRIVATE pcstore) +set_target_properties(ucmpcstore PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/ucm/store/pcstore/__init__.py b/ucm/store/pcstore/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ucm/store/pcstore/cc/api/pcstore.cc b/ucm/store/pcstore/cc/api/pcstore.cc new file mode 100644 index 000000000..a58b29325 --- /dev/null +++ b/ucm/store/pcstore/cc/api/pcstore.cc @@ -0,0 +1,114 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include "pcstore.h" +#include +#include "logger/logger.h" +#include "space/space_manager.h" +#include "status/status.h" +#include "trans/trans_manager.h" + +namespace UC { + +class PcStoreImpl : public PcStore { +public: + int32_t Setup(const Config& config) + { + auto status = this->spaceMgr_.Setup(config.storageBackends, config.kvcacheBlockSize); + if (status.Failure()) { return status.Underlying(); } + if (config.transferEnable) { + status = this->transMgr_.Setup( + config.transferLocalRankSize, config.transferDeviceId, config.transferStreamNumber, + config.kvcacheBlockSize, config.transferIoSize, config.transferIoDirect, + config.transferBufferNumber, this->spaceMgr_.GetSpaceLayout(), + config.transferTimeoutMs); + if (status.Failure()) { return status.Underlying(); } + } + this->ShowConfig(config); + return Status::OK().Underlying(); + } + int32_t Alloc(const std::string& block) override { return Status::OK().Underlying(); } + bool Lookup(const std::string& block) override { return this->spaceMgr_.LookupBlock(block); } + void Commit(const std::string& block, const bool success) override {} + std::list Alloc(const std::list& blocks) override + { + std::list results; + for (const auto& block : blocks) { results.emplace_back(this->Alloc(block)); } + return results; + } + std::list Lookup(const std::list& blocks) override + { + std::list founds; + for (const auto& block : blocks) { founds.emplace_back(this->Lookup(block)); } + return founds; + } + void Commit(const std::list& blocks, const bool success) override {} + size_t Submit(TransTask&& task) override + { + auto taskId = TransTask::invalid; + auto status = this->transMgr_.Submit(std::move(task), taskId); + if (status.Failure()) { taskId = TransTask::invalid; } + return taskId; + } + int32_t Wait(const size_t task) override { return this->transMgr_.Wait(task).Underlying(); } + int32_t Check(const size_t task, bool& finish) override + { + return this->transMgr_.Check(task, finish).Underlying(); + } + +private: + void ShowConfig(const Config& config) + { + std::string buildType = UCM_BUILD_TYPE; + if (buildType.empty()) { buildType = "Release"; } + UC_INFO("PcStore-{}({}).", UCM_COMMIT_ID, buildType); + UC_INFO("Set UC::StorageBackends to {}.", config.storageBackends); + UC_INFO("Set UC::BlockSize to {}.", config.kvcacheBlockSize); + UC_INFO("Set UC::TransferEnable to {}.", config.transferEnable); + if (!config.transferEnable) { return; } + UC_INFO("Set UC::IoSize to {}.", config.transferIoSize); + UC_INFO("Set UC::IoDirect to {}.", config.transferIoDirect); + UC_INFO("Set UC::LocalRankSize to {}.", config.transferLocalRankSize); + UC_INFO("Set UC::DeviceId to {}.", config.transferDeviceId); + UC_INFO("Set UC::StreamNumber to {}.", config.transferStreamNumber); + UC_INFO("Set UC::BufferNumber to {}.", config.transferBufferNumber); + UC_INFO("Set UC::TimeoutMs to {}.", config.transferTimeoutMs); + } + +private: + SpaceManager spaceMgr_; + TransManager transMgr_; +}; + +int32_t PcStore::Setup(const Config& config) +{ + auto impl = new (std::nothrow) PcStoreImpl(); + if (!impl) { + UC_ERROR("Out of memory."); + return Status::OutOfMemory().Underlying(); + } + this->impl_ = impl; + return impl->Setup(config); +} + +} // namespace UC diff --git a/ucm/store/dramstore/cc/api/dramstore.h b/ucm/store/pcstore/cc/api/pcstore.h similarity index 71% rename from ucm/store/dramstore/cc/api/dramstore.h rename to ucm/store/pcstore/cc/api/pcstore.h index 1dc975736..d655160f1 100644 --- a/ucm/store/dramstore/cc/api/dramstore.h +++ 
b/ucm/store/pcstore/cc/api/pcstore.h @@ -21,28 +21,38 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_DRAMSTORE_H -#define UNIFIEDCACHE_DRAMSTORE_H +#ifndef UNIFIEDCACHE_PCSTORE_H +#define UNIFIEDCACHE_PCSTORE_H +#include "trans/trans_task.h" #include "ucmstore.h" namespace UC { -class DRAMStore : public CCStore { +class PcStore : CCStore { public: struct Config { - size_t ioSize; - size_t capacity; - int32_t deviceId; - Config(const size_t ioSize, const size_t capacity) - : ioSize{ioSize}, capacity{capacity}, deviceId{-1} + std::vector storageBackends; + size_t kvcacheBlockSize; + bool transferEnable; + size_t transferIoSize{262144}; + bool transferIoDirect{false}; + size_t transferLocalRankSize{1}; + int32_t transferDeviceId{-1}; + size_t transferStreamNumber{8}; + size_t transferBufferNumber{4096}; + size_t transferTimeoutMs{30000}; + + Config(const std::vector& storageBackends, const size_t kvcacheBlockSize, + const bool transferEnable) + : storageBackends{storageBackends}, kvcacheBlockSize{kvcacheBlockSize}, + transferEnable{transferEnable} { } }; public: - DRAMStore() : impl_{nullptr} {} - ~DRAMStore() override + ~PcStore() override { if (this->impl_) { delete this->impl_; } } @@ -65,7 +75,7 @@ class DRAMStore : public CCStore { { this->impl_->Commit(blocks, success); } - size_t Submit(Task&& task) override { return this->impl_->Submit(std::move(task)); } + size_t Submit(TransTask&& task) override { return this->impl_->Submit(std::move(task)); } int32_t Wait(const size_t task) override { return this->impl_->Wait(task); } int32_t Check(const size_t task, bool& finish) override { @@ -73,7 +83,7 @@ class DRAMStore : public CCStore { } private: - DRAMStore* impl_; + PcStore* impl_{nullptr}; }; } // namespace UC diff --git a/ucm/store/nfsstore/cc/domain/space/space_layout.cc b/ucm/store/pcstore/cc/domain/space/space_layout.cc similarity index 75% rename from ucm/store/nfsstore/cc/domain/space/space_layout.cc rename to ucm/store/pcstore/cc/domain/space/space_layout.cc index dd4ff1ac2..06cee30d0 100644 --- a/ucm/store/nfsstore/cc/domain/space/space_layout.cc +++ b/ucm/store/pcstore/cc/domain/space/space_layout.cc @@ -29,11 +29,6 @@ namespace UC { -constexpr size_t blockIdSize = 16; -constexpr size_t nU64PerBlock = blockIdSize / sizeof(uint64_t); -using BlockId = std::array; -static_assert(sizeof(BlockId) == blockIdSize); - Status SpaceLayout::Setup(const std::vector& storageBackends) { if (storageBackends.empty()) { @@ -47,18 +42,27 @@ Status SpaceLayout::Setup(const std::vector& storageBackends) return status; } -std::string SpaceLayout::DataFileParent(const std::string& blockId) const +std::string SpaceLayout::DataFilePath(const std::string& blockId, bool activated) const { - auto id = static_cast(static_cast(blockId.data())); - return fmt::format("{}{}/{:016x}", this->StorageBackend(blockId), this->DataFileRoot(), - id->front()); + const auto& backend = StorageBackend(blockId); + const auto& dir = activated ? 
TempFileRoot() : DataFileRoot(); + uint64_t front, back; + ShardBlockId(blockId, front, back); + return fmt::format("{}{}{:016x}{:016x}", backend, dir, front, back); } -std::string SpaceLayout::DataFilePath(const std::string& blockId, bool activated) const +Status SpaceLayout::Commit(const std::string& blockId, bool success) const { - auto id = static_cast(static_cast(blockId.data())); - return fmt::format("{}{}/{:016x}/{:016x}.{}", this->StorageBackend(blockId), - this->DataFileRoot(), id->front(), id->back(), activated ? "act" : "dat"); + const auto& activated = this->DataFilePath(blockId, true); + const auto& archived = this->DataFilePath(blockId, false); + if (success) { return File::Rename(activated, archived); } + File::Remove(activated); + return Status::OK(); +} + +std::vector SpaceLayout::RelativeRoots() const +{ + return {DataFileRoot(), TempFileRoot()}; } Status SpaceLayout::AddStorageBackend(const std::string& path) @@ -110,8 +114,19 @@ std::string SpaceLayout::StorageBackend(const std::string& blockId) const return this->storageBackends_[hasher(blockId) % this->storageBackends_.size()]; } -std::vector SpaceLayout::RelativeRoots() const { return {this->DataFileRoot()}; } +std::string SpaceLayout::DataFileRoot() const { return "data/"; } + +std::string SpaceLayout::TempFileRoot() const { return "temp/"; } -std::string SpaceLayout::DataFileRoot() const { return "data"; } +void SpaceLayout::ShardBlockId(const std::string& blockId, uint64_t& front, uint64_t& back) const +{ + constexpr size_t blockIdSize = 16; + constexpr size_t nU64PerBlock = blockIdSize / sizeof(uint64_t); + using BlockId = std::array; + static_assert(sizeof(BlockId) == blockIdSize); + auto id = static_cast(static_cast(blockId.data())); + front = id->front(); + back = id->back(); +} } // namespace UC diff --git a/ucm/store/pcstore/cc/domain/space/space_layout.h b/ucm/store/pcstore/cc/domain/space/space_layout.h new file mode 100644 index 000000000..a9a7a4d79 --- /dev/null +++ b/ucm/store/pcstore/cc/domain/space/space_layout.h @@ -0,0 +1,55 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#ifndef UNIFIEDCACHE_SPACE_LAYOUT_H +#define UNIFIEDCACHE_SPACE_LAYOUT_H + +#include +#include +#include "status/status.h" + +namespace UC { + +class SpaceLayout { +public: + Status Setup(const std::vector& storageBackends); + std::string DataFilePath(const std::string& blockId, bool activated) const; + Status Commit(const std::string& blockId, bool success) const; + +private: + std::vector RelativeRoots() const; + Status AddStorageBackend(const std::string& path); + Status AddFirstStorageBackend(const std::string& path); + Status AddSecondaryStorageBackend(const std::string& path); + std::string StorageBackend(const std::string& blockId) const; + std::string DataFileRoot() const; + std::string TempFileRoot() const; + void ShardBlockId(const std::string& blockId, uint64_t& front, uint64_t& back) const; + +private: + std::vector storageBackends_; +}; + +} // namespace UC + +#endif diff --git a/ucm/store/pcstore/cc/domain/space/space_manager.cc b/ucm/store/pcstore/cc/domain/space/space_manager.cc new file mode 100644 index 000000000..556646f46 --- /dev/null +++ b/ucm/store/pcstore/cc/domain/space/space_manager.cc @@ -0,0 +1,86 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include "space_manager.h" +#include +#include "file/file.h" +#include "logger/logger.h" + +namespace UC { + +Status SpaceManager::Setup(const std::vector& storageBackends, const size_t blockSize) +{ + if (blockSize == 0) { + UC_ERROR("Invalid block size({}).", blockSize); + return Status::InvalidParam(); + } + auto status = this->layout_.Setup(storageBackends); + if (status.Failure()) { return status; } + this->blockSize_ = blockSize; + return Status::OK(); +} + +Status SpaceManager::NewBlock(const std::string& blockId) +{ + const auto& activated = this->layout_.DataFilePath(blockId, true); + const auto& archived = this->layout_.DataFilePath(blockId, false); + if (File::Access(archived, IFile::AccessMode::EXIST).Success()) { + return Status::DuplicateKey(); + } + auto file = File::Make(activated); + if (!file) { return Status::OutOfMemory(); } + auto mode = IFile::OpenFlag::CREATE | IFile::OpenFlag::EXCL | IFile::OpenFlag::READ_WRITE; + auto s = file->Open(mode); + if (s.Failure()) { + if (s != Status::DuplicateKey()) { return s; } + mode = IFile::OpenFlag::READ_WRITE; + if ((s = file->Open(mode)).Failure()) { return s; } + IFile::FileStat st; + if ((s = file->Stat(st)).Failure()) { return s; } + const auto now = std::chrono::system_clock::now(); + const auto mtime = std::chrono::system_clock::from_time_t(st.st_mtime); + constexpr auto reuseBlockAge = std::chrono::seconds(300); + if (now - mtime <= reuseBlockAge) { return Status::DuplicateKey(); } + } + return file->Truncate(this->blockSize_); +} + +Status SpaceManager::CommitBlock(const std::string& blockId, bool success) +{ + return this->layout_.Commit(blockId, success); +} + +bool SpaceManager::LookupBlock(const std::string& blockId) const +{ + const auto& path = this->layout_.DataFilePath(blockId, false); + constexpr auto mode = + IFile::AccessMode::EXIST | IFile::AccessMode::READ | IFile::AccessMode::WRITE; + auto s = File::Access(path, mode); + if (s.Failure()) { + if (s != Status::NotFound()) { UC_ERROR("Failed({}) to access file({}).", s, path); } + return false; + } + return true; +} + +} // namespace UC diff --git a/ucm/store/pcstore/cc/domain/space/space_manager.h b/ucm/store/pcstore/cc/domain/space/space_manager.h new file mode 100644 index 000000000..4656c952f --- /dev/null +++ b/ucm/store/pcstore/cc/domain/space/space_manager.h @@ -0,0 +1,46 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#ifndef UNIFIEDCACHE_SPACE_MANAGER_H +#define UNIFIEDCACHE_SPACE_MANAGER_H + +#include "space_layout.h" + +namespace UC { + +class SpaceManager { +public: + Status Setup(const std::vector& storageBackends, const size_t blockSize); + Status NewBlock(const std::string& blockId); + Status CommitBlock(const std::string& blockId, bool success); + bool LookupBlock(const std::string& blockId) const; + const SpaceLayout* GetSpaceLayout() const { return &this->layout_; } + +private: + SpaceLayout layout_; + size_t blockSize_; +}; + +} // namespace UC + +#endif diff --git a/ucm/store/pcstore/cc/domain/trans/share_buffer.cc b/ucm/store/pcstore/cc/domain/trans/share_buffer.cc new file mode 100644 index 000000000..35e9ae3da --- /dev/null +++ b/ucm/store/pcstore/cc/domain/trans/share_buffer.cc @@ -0,0 +1,308 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include "share_buffer.h" +#include +#include +#include +#include +#include "file/file.h" +#include "logger/logger.h" +#include "trans/buffer.h" + +namespace UC { + +static constexpr int32_t SHARE_BUFFER_MAGIC = (('S' << 16) | ('b' << 8) | 1); + +struct ShareMutex { + pthread_mutex_t mutex; + ~ShareMutex() = delete; + void Init() + { + pthread_mutexattr_t attr; + pthread_mutexattr_init(&attr); + pthread_mutexattr_setpshared(&attr, PTHREAD_PROCESS_SHARED); + pthread_mutexattr_setrobust(&attr, PTHREAD_MUTEX_ROBUST); + pthread_mutexattr_settype(&attr, PTHREAD_MUTEX_ADAPTIVE_NP); + pthread_mutex_init(&mutex, &attr); + pthread_mutexattr_destroy(&attr); + } + void Lock() { pthread_mutex_lock(&mutex); } + void Unlock() { pthread_mutex_unlock(&mutex); } +}; + +struct ShareLock { + pthread_spinlock_t lock; + ~ShareLock() = delete; + void Init() { pthread_spin_init(&lock, PTHREAD_PROCESS_SHARED); } + void Lock() { pthread_spin_lock(&lock); } + void Unlock() { pthread_spin_unlock(&lock); } +}; + +struct ShareBlockId { + uint64_t lo{0}; + uint64_t hi{0}; + void Set(const std::string& block) + { + auto data = static_cast((const void*)block.data()); + lo = data[0]; + hi = data[1]; + } + void Reset() { lo = hi = 0; } + bool Used() const { return lo != 0 || hi != 0; } + bool operator==(const std::string& block) const + { + auto data = static_cast((const void*)block.data()); + return lo == data[0] && hi == data[1]; + } +}; + +enum class ShareBlockStatus { INIT, LOADING, LOADED, FAILURE }; + +struct ShareBlockHeader { + ShareBlockId id; + ShareLock mutex; + int32_t ref; + ShareBlockStatus status; + size_t offset; + void* Data() { return reinterpret_cast(this) + offset; } +}; + +struct ShareBufferHeader { + ShareMutex mutex; + std::atomic magic; + int32_t ref; + size_t blockSize; + size_t blockNumber; + ShareBlockHeader headers[0]; +}; + +inline std::string GenShareBufferName(const size_t blockSize, const size_t blockNumber, + const bool ioDirect, const size_t nSharer) +{ + return fmt::format("uc.buf-{}-{}-{}-{:04x}", blockSize, blockNumber, ioDirect, nSharer); +} + +Status ShareBuffer::Setup(const size_t blockSize, const size_t blockNumber, const bool ioDirect, + const size_t nSharer) +{ + this->blockSize_ = blockSize; + this->blockNumber_ = blockNumber; + this->ioDirect_ = ioDirect; + this->nSharer_ = nSharer; + this->addr_ = nullptr; + this->shmName_ = GenShareBufferName(blockSize, blockNumber, ioDirect, nSharer); + auto file = File::Make(this->shmName_); + if (!file) { return Status::OutOfMemory(); } + auto flags = IFile::OpenFlag::CREATE | IFile::OpenFlag::EXCL | IFile::OpenFlag::READ_WRITE; + auto s = file->ShmOpen(flags); + if (s.Success()) { return this->InitShmBuffer(file.get()); } + if (s == Status::DuplicateKey()) { return this->LoadShmBuffer(file.get()); } + return s; +} + +ShareBuffer::~ShareBuffer() +{ + if (!this->addr_) { return; } + auto bufferHeader = (ShareBufferHeader*)this->addr_; + bufferHeader->mutex.Lock(); + auto ref = (--bufferHeader->ref); + bufferHeader->mutex.Unlock(); + void* dataAddr = static_cast(this->addr_) + this->DataOffset(); + Trans::Buffer::UnregisterHostBuffer(dataAddr); + const auto shmSize = this->ShmSize(); + File::MUnmap(this->addr_, shmSize); + if (ref == 0) { File::ShmUnlink(this->shmName_); } +} + +std::shared_ptr ShareBuffer::MakeReader(const std::string& block, + const std::string& path) +{ + auto index = this->AcquireBlock(block); + try { + void* addr = this->BlockAt(index); + return std::shared_ptr( + new Reader{block, path, blockSize_, ioDirect_, 
nSharer_, addr}, + [this, index](auto) { this->ReleaseBlock(index); }); + } catch (...) { + this->ReleaseBlock(index); + UC_ERROR("Failed to create reader."); + return nullptr; + } +} + +size_t ShareBuffer::DataOffset() const +{ + static const auto pageSize = sysconf(_SC_PAGESIZE); + auto headerSize = sizeof(ShareBufferHeader) + sizeof(ShareBlockHeader) * this->blockNumber_; + return (headerSize + pageSize - 1) & ~(pageSize - 1); +} + +size_t ShareBuffer::ShmSize() const +{ + return this->DataOffset() + this->blockSize_ * this->blockNumber_; +} + +Status ShareBuffer::InitShmBuffer(IFile* file) +{ + const auto shmSize = this->ShmSize(); + auto s = file->Truncate(shmSize); + if (s.Failure()) { return s; } + s = file->MMap(this->addr_, shmSize, true, true, true); + if (s.Failure()) { return s; } + auto bufferHeader = (ShareBufferHeader*)this->addr_; + bufferHeader->magic = 1; + bufferHeader->mutex.Init(); + bufferHeader->ref = this->nSharer_; + bufferHeader->blockSize = this->blockSize_; + bufferHeader->blockNumber = this->blockNumber_; + const auto dataOffset = this->DataOffset(); + for (size_t i = 0; i < this->blockNumber_; i++) { + bufferHeader->headers[i].id.Reset(); + bufferHeader->headers[i].mutex.Init(); + bufferHeader->headers[i].ref = 0; + bufferHeader->headers[i].status = ShareBlockStatus::INIT; + const auto headerOffset = sizeof(ShareBufferHeader) + sizeof(ShareBlockHeader) * i; + bufferHeader->headers[i].offset = dataOffset + this->blockSize_ * i - headerOffset; + } + bufferHeader->magic = SHARE_BUFFER_MAGIC; + void* dataAddr = static_cast(this->addr_) + dataOffset; + auto dataSize = shmSize - dataOffset; + auto status = Trans::Buffer::RegisterHostBuffer(dataAddr, dataSize); + if (status.Success()) { return Status::OK(); } + UC_ERROR("Failed({}) to regitster host buffer({}).", status.ToString(), dataSize); + return Status::Error(); +} + +Status ShareBuffer::LoadShmBuffer(IFile* file) +{ + auto s = file->ShmOpen(IFile::OpenFlag::READ_WRITE); + if (s.Failure()) { return s; } + const auto shmSize = this->ShmSize(); + s = file->Truncate(shmSize); + if (s.Failure()) { return s; } + s = file->MMap(this->addr_, shmSize, true, true, true); + if (s.Failure()) { return s; } + auto bufferHeader = (ShareBufferHeader*)this->addr_; + constexpr auto retryInterval = std::chrono::milliseconds(100); + constexpr auto maxTryTime = 100; + auto tryTime = 0; + do { + if (bufferHeader->magic == SHARE_BUFFER_MAGIC) { break; } + if (tryTime > maxTryTime) { + UC_ERROR("Shm file({}) not ready.", file->Path()); + return Status::Retry(); + } + std::this_thread::sleep_for(retryInterval); + tryTime++; + } while (true); + const auto dataOffset = this->DataOffset(); + void* dataAddr = static_cast(this->addr_) + dataOffset; + auto dataSize = shmSize - dataOffset; + auto status = Trans::Buffer::RegisterHostBuffer(dataAddr, dataSize); + if (status.Success()) { return Status::OK(); } + UC_ERROR("Failed({}) to regitster host buffer({}).", status.ToString(), dataSize); + return Status::Error(); +} + +size_t ShareBuffer::AcquireBlock(const std::string& block) +{ + static std::hash hasher{}; + auto pos = hasher(block) % this->blockNumber_; + auto bufferHeader = (ShareBufferHeader*)this->addr_; + auto reusedIdx = this->blockNumber_; + bufferHeader->mutex.Lock(); + for (size_t i = 0;; i++) { + if (!bufferHeader->headers[pos].id.Used()) { + if (reusedIdx == this->blockNumber_) { reusedIdx = pos; } + break; + } + if (bufferHeader->headers[pos].id == block) { + reusedIdx = pos; + break; + } + if (bufferHeader->headers[pos].ref 
<= 0) { + if (reusedIdx == this->blockNumber_) { reusedIdx = pos; } + } + pos = (pos + 1) % this->blockNumber_; + if (i == this->blockNumber_) { + UC_WARN("Buffer({}) used out.", this->blockNumber_); + i = 0; + } + } + auto blockHeader = bufferHeader->headers + reusedIdx; + blockHeader->mutex.Lock(); + if (blockHeader->ref <= 0) { + blockHeader->id.Set(block); + blockHeader->ref = this->nSharer_; + blockHeader->status = ShareBlockStatus::INIT; + } + blockHeader->mutex.Unlock(); + bufferHeader->mutex.Unlock(); + return reusedIdx; +} + +void ShareBuffer::ReleaseBlock(const size_t index) +{ + auto bufferHeader = (ShareBufferHeader*)this->addr_; + bufferHeader->headers[index].mutex.Lock(); + bufferHeader->headers[index].ref--; + bufferHeader->headers[index].mutex.Unlock(); +} + +void* ShareBuffer::BlockAt(const size_t index) +{ + auto bufferHeader = (ShareBufferHeader*)this->addr_; + return bufferHeader->headers + index; +} + +Status ShareBuffer::Reader::Ready4Read() +{ + auto header = (ShareBlockHeader*)this->addr_; + if (header->status == ShareBlockStatus::LOADED) { return Status::OK(); } + if (header->status == ShareBlockStatus::FAILURE) { return Status::Error(); } + if (header->status == ShareBlockStatus::LOADING) { return Status::Retry(); } + auto loading = false; + header->mutex.Lock(); + if (header->status == ShareBlockStatus::INIT) { + header->status = ShareBlockStatus::LOADING; + loading = true; + } + header->mutex.Unlock(); + if (!loading) { return Status::Retry(); } + auto s = File::Read(this->path_, 0, this->length_, this->GetData(), this->ioDirect_); + if (s.Success()) { + header->status = ShareBlockStatus::LOADED; + return Status::OK(); + } + header->status = ShareBlockStatus::FAILURE; + return s; +} + +uintptr_t ShareBuffer::Reader::GetData() +{ + auto header = (ShareBlockHeader*)this->addr_; + return (uintptr_t)header->Data(); +} + +} // namespace UC diff --git a/ucm/store/pcstore/cc/domain/trans/share_buffer.h b/ucm/store/pcstore/cc/domain/trans/share_buffer.h new file mode 100644 index 000000000..3fce7a87c --- /dev/null +++ b/ucm/store/pcstore/cc/domain/trans/share_buffer.h @@ -0,0 +1,85 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#ifndef UNIFIEDCACHE_SHARE_BUFFER_H +#define UNIFIEDCACHE_SHARE_BUFFER_H + +#include +#include +#include +#include "file/ifile.h" +#include "status/status.h" + +namespace UC { + +class ShareBuffer { +public: + class Reader { + std::string block_; + std::string path_; + size_t length_; + bool ioDirect_; + size_t nSharer_; + void* addr_; + + public: + Status Ready4Read(); + uintptr_t GetData(); + + private: + Reader(const std::string& block, const std::string& path, const size_t length, + const bool ioDirect, const size_t nSharer, void* addr) + : block_{block}, path_{path}, length_{length}, ioDirect_{ioDirect}, nSharer_{nSharer}, + addr_{addr} + { + } + friend class ShareBuffer; + }; + +public: + Status Setup(const size_t blockSize, const size_t blockNumber, const bool ioDirect, + const size_t nSharer); + ~ShareBuffer(); + std::shared_ptr MakeReader(const std::string& block, const std::string& path); + +private: + size_t DataOffset() const; + size_t ShmSize() const; + Status InitShmBuffer(IFile* file); + Status LoadShmBuffer(IFile* file); + size_t AcquireBlock(const std::string& block); + void ReleaseBlock(const size_t index); + void* BlockAt(const size_t index); + +private: + size_t blockSize_; + size_t blockNumber_; + bool ioDirect_; + size_t nSharer_; + std::string shmName_; + void* addr_; +}; + +} // namespace UC + +#endif diff --git a/ucm/store/pcstore/cc/domain/trans/trans_manager.cc b/ucm/store/pcstore/cc/domain/trans/trans_manager.cc new file mode 100644 index 000000000..68f6a6ec7 --- /dev/null +++ b/ucm/store/pcstore/cc/domain/trans/trans_manager.cc @@ -0,0 +1,118 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include "trans_manager.h" +#include "logger/logger.h" + +namespace UC { + +Status TransManager::Setup(const size_t rankSize, const int32_t deviceId, const size_t streamNumber, + const size_t blockSize, const size_t ioSize, const bool ioDirect, + const size_t bufferNumber, const SpaceLayout* layout, + const size_t timeoutMs) +{ + auto s = Status::OK(); + if (rankSize > 1) { + s = this->shareQueue_.Setup(rankSize, deviceId, streamNumber, blockSize, ioSize, ioDirect, + bufferNumber, layout, &this->failureSet_); + if (s.Failure()) { return s; } + } + s = this->queue_.Setup(deviceId, streamNumber, blockSize, ioSize, ioDirect, bufferNumber, + layout, &this->failureSet_); + if (s.Failure()) { return s; } + this->rankSize_ = rankSize; + this->timeoutMs_ = timeoutMs; + return Status::OK(); +} + +Status TransManager::Submit(TransTask task, size_t& taskId) noexcept +{ + taskId = task.id; + const auto taskStr = task.Str(); + const auto blockNumber = task.GroupNumber(); + TaskPtr taskPtr = nullptr; + WaiterPtr waiterPtr = nullptr; + try { + taskPtr = std::make_shared(std::move(task)); + waiterPtr = std::make_shared(blockNumber, taskPtr->startTp); + } catch (const std::exception& e) { + UC_ERROR("Failed({}) to submit task({}).", e.what(), taskStr); + return Status::OutOfMemory(); + } + std::unique_lock lg(mutex_); + const auto& [iter, success] = tasks_.emplace(taskId, std::make_pair(taskPtr, waiterPtr)); + if (!success) { + UC_ERROR("Failed to submit task({}).", taskStr); + return Status::OutOfMemory(); + } + lg.unlock(); + if (this->rankSize_ > 1 && iter->second.first->type == TransTask::Type::LOAD) { + this->shareQueue_.Dispatch(iter->second.first, iter->second.second); + return Status::OK(); + } + this->queue_.Dispatch(iter->second.first, iter->second.second); + return Status::OK(); +} + +Status TransManager::Wait(const size_t taskId) noexcept +{ + TaskPtr task = nullptr; + WaiterPtr waiter = nullptr; + { + std::lock_guard lg(mutex_); + auto iter = tasks_.find(taskId); + if (iter == tasks_.end()) { + UC_ERROR("Not found task by id({}).", taskId); + return Status::NotFound(); + } + task = iter->second.first; + waiter = iter->second.second; + tasks_.erase(iter); + } + if (!waiter->Wait(timeoutMs_)) { + UC_ERROR("Task({}) timeout({}).", task->Str(), timeoutMs_); + failureSet_.Insert(taskId); + waiter->Wait(); + } + auto failure = failureSet_.Contains(taskId); + if (failure) { + failureSet_.Remove(taskId); + UC_ERROR("Task({}) failed.", task->Str()); + return Status::Error(); + } + return Status::OK(); +} + +Status TransManager::Check(const size_t taskId, bool& finish) noexcept +{ + std::lock_guard lg(mutex_); + auto iter = tasks_.find(taskId); + if (iter == tasks_.end()) { + UC_ERROR("Not found task by id({}).", taskId); + return Status::NotFound(); + } + finish = iter->second.second->Finish(); + return Status::OK(); +} + +} // namespace UC diff --git a/ucm/store/pcstore/cc/domain/trans/trans_manager.h b/ucm/store/pcstore/cc/domain/trans/trans_manager.h new file mode 100644 index 000000000..54a74bdb1 --- /dev/null +++ b/ucm/store/pcstore/cc/domain/trans/trans_manager.h @@ -0,0 +1,56 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#ifndef UNIFIEDCACHE_TRANS_MANAGER_H +#define UNIFIEDCACHE_TRANS_MANAGER_H + +#include "trans_queue.h" +#include "trans_share_queue.h" + +namespace UC { + +class TransManager { +public: + Status Setup(const size_t rankSize, const int32_t deviceId, const size_t streamNumber, + const size_t blockSize, const size_t ioSize, const bool ioDirect, + const size_t bufferNumber, const SpaceLayout* layout, const size_t timeoutMs); + Status Submit(TransTask task, size_t& taskId) noexcept; + Status Wait(const size_t taskId) noexcept; + Status Check(const size_t taskId, bool& finish) noexcept; + +private: + using TaskPtr = std::shared_ptr; + using WaiterPtr = std::shared_ptr; + using TaskPair = std::pair; + TransShareQueue shareQueue_; + TransQueue queue_; + size_t rankSize_; + size_t timeoutMs_; + std::mutex mutex_; + std::unordered_map tasks_; + TaskSet failureSet_; +}; + +} // namespace UC + +#endif diff --git a/ucm/store/pcstore/cc/domain/trans/trans_queue.cc b/ucm/store/pcstore/cc/domain/trans/trans_queue.cc new file mode 100644 index 000000000..83b8ce458 --- /dev/null +++ b/ucm/store/pcstore/cc/domain/trans/trans_queue.cc @@ -0,0 +1,174 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include "trans_queue.h" +#include "file/file.h" +#include "logger/logger.h" +#include "trans/device.h" + +namespace UC { + +void TransQueue::DeviceWorker(BlockTask&& task) +{ + if (this->failureSet_->Contains(task.owner)) { + task.done(false); + return; + } + auto number = task.shards.size(); + auto size = this->ioSize_; + auto done = task.done; + auto devPtrs = (void**)task.shards.data(); + auto hostPtr = task.buffer.get(); + auto s = Status::OK(); + if (task.type == TransTask::Type::LOAD) { + s = stream_->HostToDevice(hostPtr, devPtrs, size, number); + } else { + s = stream_->DeviceToHost(devPtrs, hostPtr, size, number); + if (s.Success()) { this->filePool_.Push(std::move(task)); } + } + if (s.Failure()) { this->failureSet_->Insert(task.owner); } + done(s.Success()); + return; +} + +void TransQueue::FileWorker(BlockTask&& task) +{ + if (this->failureSet_->Contains(task.owner)) { + task.done(false); + return; + } + auto hostPtr = (uintptr_t)task.buffer.get(); + auto length = this->ioSize_ * task.shards.size(); + if (task.type == TransTask::Type::DUMP) { + const auto& path = this->layout_->DataFilePath(task.block, true); + auto s = File::Write(path, 0, length, hostPtr, this->ioDirect_, true); + this->layout_->Commit(task.block, s.Success()); + return; + } + const auto& path = this->layout_->DataFilePath(task.block, false); + auto s = File::Read(path, 0, length, hostPtr, this->ioDirect_); + if (s.Success()) { + this->devPool_.Push(std::move(task)); + return; + } + this->failureSet_->Insert(task.owner); + task.done(false); +} + +Status TransQueue::Setup(const int32_t deviceId, const size_t streamNumber, const size_t blockSize, + const size_t ioSize, const bool ioDirect, const size_t bufferNumber, + const SpaceLayout* layout, TaskSet* failureSet_) +{ + Trans::Device device; + auto ts = device.Setup(deviceId); + if (ts.Failure()) { + UC_ERROR("Failed({}) to set context on device({}).", ts.ToString(), deviceId); + return Status::Error(); + } + buffer_ = device.MakeBuffer(); + stream_ = device.MakeStream(); + if (!buffer_ || !stream_) { + UC_ERROR("Failed to make buffer and stream on device({}).", deviceId); + return Status::Error(); + } + ts = buffer_->MakeHostBuffers(blockSize, bufferNumber); + if (ts.Failure()) { + UC_ERROR("Failed({}) to make host buffer({},{}).", ts.ToString(), blockSize, bufferNumber); + return Status::Error(); + } + auto success = + this->devPool_.SetWorkerFn([this](auto t, auto) { this->DeviceWorker(std::move(t)); }) + .Run(); + if (!success) { return Status::Error(); } + success = this->filePool_.SetWorkerFn([this](auto t, auto) { this->FileWorker(std::move(t)); }) + .SetNWorker(streamNumber) + .Run(); + if (!success) { return Status::Error(); } + this->layout_ = layout; + this->ioSize_ = ioSize; + this->ioDirect_ = ioDirect; + this->failureSet_ = failureSet_; + return Status::OK(); +} + +void TransQueue::Dispatch(TaskPtr task, WaiterPtr waiter) +{ + if (task->type == TransTask::Type::DUMP) { + this->DispatchDump(task, waiter); + return; + } + task->ForEachGroup( + [task, waiter, this](const std::string& block, std::vector& shards) { + BlockTask blockTask; + blockTask.owner = task->id; + blockTask.block = block; + blockTask.type = task->type; + auto bufferSize = this->ioSize_ * shards.size(); + std::swap(blockTask.shards, shards); + blockTask.buffer = buffer_->GetHostBuffer(bufferSize); + blockTask.done = [task, waiter, ioSize = this->ioSize_](bool success) { + if (!success) { + waiter->Done(nullptr); + } else { + waiter->Done([task, ioSize] { UC_DEBUG("{}", 
task->Epilog(ioSize)); }); + } + }; + if (task->type == TransTask::Type::DUMP) { + this->devPool_.Push(std::move(blockTask)); + } else { + this->filePool_.Push(std::move(blockTask)); + } + }); +} + +void TransQueue::DispatchDump(TaskPtr task, WaiterPtr waiter) +{ + std::vector blocks; + blocks.reserve(task->GroupNumber()); + task->ForEachGroup( + [task, &blocks, this](const std::string& block, std::vector& shards) { + BlockTask blockTask; + blockTask.owner = task->id; + blockTask.block = block; + blockTask.type = task->type; + auto bufferSize = this->ioSize_ * shards.size(); + blockTask.buffer = buffer_->GetHostBuffer(bufferSize); + std::swap(blockTask.shards, shards); + auto device = (void**)blockTask.shards.data(); + auto host = blockTask.buffer.get(); + stream_->DeviceToHostAsync(device, host, this->ioSize_, blockTask.shards.size()); + blocks.push_back(std::move(blockTask)); + }); + auto s = stream_->Synchronized(); + if (s.Failure()) { this->failureSet_->Insert(task->id); } + for (auto&& block : blocks) { + if (s.Failure()) { + waiter->Done(nullptr); + return; + } + this->filePool_.Push(std::move(block)); + waiter->Done([task, ioSize = this->ioSize_] { UC_DEBUG("{}", task->Epilog(ioSize)); }); + } +} + +} // namespace UC diff --git a/ucm/store/pcstore/cc/domain/trans/trans_queue.h b/ucm/store/pcstore/cc/domain/trans/trans_queue.h new file mode 100644 index 000000000..d61f1b9bf --- /dev/null +++ b/ucm/store/pcstore/cc/domain/trans/trans_queue.h @@ -0,0 +1,71 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#ifndef UNIFIEDCACHE_TRAN_QUEUE_H +#define UNIFIEDCACHE_TRAN_QUEUE_H + +#include "space/space_layout.h" +#include "task/task_set.h" +#include "task/task_waiter.h" +#include "thread/thread_pool.h" +#include "trans/buffer.h" +#include "trans/stream.h" +#include "trans_task.h" + +namespace UC { + +class TransQueue { + using TaskPtr = std::shared_ptr; + using WaiterPtr = std::shared_ptr; + struct BlockTask { + size_t owner; + std::string block; + TransTask::Type type; + std::vector shards; + std::shared_ptr buffer; + std::function done; + }; + void DeviceWorker(BlockTask&& task); + void FileWorker(BlockTask&& task); + +public: + Status Setup(const int32_t deviceId, const size_t streamNumber, const size_t blockSize, + const size_t ioSize, const bool ioDirect, const size_t bufferNumber, + const SpaceLayout* layout, TaskSet* failureSet_); + void Dispatch(TaskPtr task, WaiterPtr waiter); + void DispatchDump(TaskPtr task, WaiterPtr waiter); + +private: + std::unique_ptr buffer_{nullptr}; + std::unique_ptr stream_{nullptr}; + const SpaceLayout* layout_; + size_t ioSize_; + bool ioDirect_; + ThreadPool devPool_; + ThreadPool filePool_; + TaskSet* failureSet_; +}; + +} // namespace UC + +#endif diff --git a/ucm/store/pcstore/cc/domain/trans/trans_share_queue.cc b/ucm/store/pcstore/cc/domain/trans/trans_share_queue.cc new file mode 100644 index 000000000..c43d16a85 --- /dev/null +++ b/ucm/store/pcstore/cc/domain/trans/trans_share_queue.cc @@ -0,0 +1,169 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include "trans_share_queue.h" +#include "logger/logger.h" +#include "trans/device.h" + +namespace UC { + +TransShareQueue::~TransShareQueue() +{ + { + std::lock_guard lg(this->mutex_); + this->stop_ = true; + this->cv_.notify_all(); + } + for (auto& w : this->threads_) { + if (w.joinable()) { w.join(); } + } +} + +Status TransShareQueue::Setup(const size_t nSharer, const int32_t deviceId, + const size_t streamNumber, const size_t blockSize, + const size_t ioSize, const bool ioDirect, const size_t bufferNumber, + const SpaceLayout* layout, TaskSet* failureSet) +{ + this->deviceId_ = deviceId; + this->streamNumber_ = streamNumber; + this->ioSize_ = ioSize; + this->layout_ = layout; + this->failureSet_ = failureSet; + auto status = this->buffer_.Setup(blockSize, bufferNumber, ioDirect, nSharer); + if (status.Failure()) { return status; } + std::list> start(streamNumber); + std::list> fut; + for (auto& s : start) { + fut.push_back(s.get_future()); + this->threads_.emplace_back([&] { this->WorkerLoop(s); }); + } + for (auto& f : fut) { + if (status.Failure()) { break; } + status = f.get(); + } + return status; +} + +void TransShareQueue::Dispatch(TaskPtr task, WaiterPtr waiter) +{ + std::lock_guard lg(this->mutex_); + task->ForEachGroup( + [task, waiter, this](const std::string& block, std::vector& shards) { + BlockTask blockTask; + blockTask.reader = + this->buffer_.MakeReader(block, this->layout_->DataFilePath(block, false)); + blockTask.owner = task->id; + std::swap(blockTask.shards, shards); + blockTask.done = [task, waiter, ioSize = this->ioSize_](bool success) { + if (!success) { + waiter->Done(nullptr); + } else { + waiter->Done([task, ioSize] { UC_DEBUG("{}", task->Epilog(ioSize)); }); + } + }; + this->wait_.push_back(blockTask); + }); + this->cv_.notify_all(); +} + +void TransShareQueue::WorkerLoop(std::promise& status) +{ + Trans::Device device; + auto s = device.Setup(deviceId_); + if (s.Failure()) { + UC_ERROR("Failed({}) to set context on device({}).", s.ToString(), deviceId_); + status.set_value(Status::Error()); + return; + } + auto stream = device.MakeStream(); + if (!stream) { + UC_ERROR("Failed to create stream on device({}).", deviceId_); + status.set_value(Status::Error()); + return; + } + status.set_value(Status::OK()); + while (!stop_) { Worker(*stream); } +} + +void TransShareQueue::Worker(Trans::Stream& stream) +{ + std::unique_lock ul{this->mutex_}; + if (this->load_.empty() && this->wait_.empty()) { + this->cv_.wait( + ul, [this] { return this->stop_ || !this->load_.empty() || !this->wait_.empty(); }); + } + if (this->stop_) { return; } + for (auto iter = this->load_.begin(); iter != this->load_.end(); iter++) { + auto s = iter->reader->Ready4Read(); + if (s != Status::Retry()) { + auto task = std::move(*iter); + this->load_.erase(iter); + ul.unlock(); + this->HandleReadyTask(s, task, stream); + return; + } + } + if (this->load_.size() >= this->streamNumber_) { return; } + if (this->wait_.empty()) { return; } + auto task = std::move(this->wait_.front()); + this->wait_.pop_front(); + ul.unlock(); + this->HandleLoadTask(task, stream); +} + +void TransShareQueue::HandleReadyTask(Status s, BlockTask& task, Trans::Stream& stream) +{ + if (this->failureSet_->Contains(task.owner)) { + task.done(false); + return; + } + if (s.Success()) { + auto host = (void*)task.reader->GetData(); + auto device = (void**)task.shards.data(); + auto status = stream.HostToDeviceAsync(host, device, this->ioSize_, task.shards.size()); + if (status.Failure()) [[unlikely]] { + 
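+            // Async host-to-device copy failed: log it and downgrade `s` so the owner
+            // task is added to failureSet_ and reported via done(false) below.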
UC_ERROR("Failed({}) to copy data from host to device.", status.ToString()); + s = Status::Error(); + } + } + if (s.Failure()) { this->failureSet_->Insert(task.owner); } + task.done(s.Success()); +} + +void TransShareQueue::HandleLoadTask(BlockTask& task, Trans::Stream& stream) +{ + if (this->failureSet_->Contains(task.owner)) { + task.done(false); + return; + } + auto s = task.reader->Ready4Read(); + if (s == Status::Retry()) { + std::lock_guard lg{this->mutex_}; + this->load_.push_back(task); + this->cv_.notify_one(); + return; + } + this->HandleReadyTask(s, task, stream); +} + +} // namespace UC diff --git a/ucm/store/pcstore/cc/domain/trans/trans_share_queue.h b/ucm/store/pcstore/cc/domain/trans/trans_share_queue.h new file mode 100644 index 000000000..7c40b0542 --- /dev/null +++ b/ucm/store/pcstore/cc/domain/trans/trans_share_queue.h @@ -0,0 +1,78 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#ifndef UNIFIEDCACHE_TRANS_SHARE_QUEUE_H +#define UNIFIEDCACHE_TRANS_SHARE_QUEUE_H + +#include +#include +#include +#include +#include "share_buffer.h" +#include "space/space_layout.h" +#include "task/task_set.h" +#include "task/task_waiter.h" +#include "trans/stream.h" +#include "trans_task.h" + +namespace UC { + +class TransShareQueue { + using TaskPtr = std::shared_ptr; + using WaiterPtr = std::shared_ptr; + struct BlockTask { + std::shared_ptr reader; + size_t owner; + std::vector shards; + std::function done; + }; + int32_t deviceId_; + size_t streamNumber_; + size_t ioSize_; + ShareBuffer buffer_; + const SpaceLayout* layout_; + TaskSet* failureSet_; + std::atomic_bool stop_{false}; + std::mutex mutex_; + std::condition_variable cv_; + std::list load_; + std::list wait_; + std::list threads_; + +public: + ~TransShareQueue(); + Status Setup(const size_t nSharer, const int32_t deviceId, const size_t streamNumber, + const size_t blockSize, const size_t ioSize, const bool ioDirect, + const size_t bufferNumber, const SpaceLayout* layout, TaskSet* failureSet); + void Dispatch(TaskPtr task, WaiterPtr waiter); + +private: + void WorkerLoop(std::promise& status); + void Worker(Trans::Stream& stream); + void HandleReadyTask(Status s, BlockTask& task, Trans::Stream& stream); + void HandleLoadTask(BlockTask& task, Trans::Stream& stream); +}; + +} // namespace UC + +#endif diff --git a/ucm/store/pcstore/cc/domain/trans/trans_task.h b/ucm/store/pcstore/cc/domain/trans/trans_task.h new file mode 100644 index 000000000..8fcb48fba --- /dev/null +++ b/ucm/store/pcstore/cc/domain/trans/trans_task.h @@ -0,0 +1,89 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#ifndef UNIFIEDCACHE_TRANS_TASK_H +#define UNIFIEDCACHE_TRANS_TASK_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace UC { + +class TransTask { + static size_t NextId() noexcept + { + static std::atomic id{invalid + 1}; + return id.fetch_add(1, std::memory_order_relaxed); + }; + static double NowTp() noexcept + { + auto now = std::chrono::steady_clock::now().time_since_epoch(); + return std::chrono::duration(now).count(); + } + +public: + enum class Type { DUMP, LOAD }; + size_t id; + Type type; + double startTp{0}; + static constexpr auto invalid = std::numeric_limits::min(); + TransTask(Type&& type, std::string&& brief) + : id{NextId()}, type{std::move(type)}, startTp{NowTp()}, brief_{std::move(brief)} + { + } + void Append(const std::string& block, const uintptr_t address) + { + grouped_[block].push_back(address); + number_++; + } + auto Str() const noexcept { return fmt::format("{},{},{}", id, brief_, number_); } + size_t GroupNumber() const { return grouped_.size(); } + void ForEachGroup(std::function&)> fn) + { + for (auto& [block, shards] : grouped_) { fn(block, shards); } + } + auto Epilog(const size_t ioSize) const noexcept + { + auto total = ioSize * number_; + auto costs = NowTp() - startTp; + auto bw = double(total) / costs / 1e9; + return fmt::format("Task({},{},{},{}) finished, costs={:.06f}s, bw={:.06f}GB/s.", id, + brief_, number_, total, costs, bw); + } + +private: + std::string brief_; + size_t number_{0}; + std::unordered_map> grouped_; +}; + +} // namespace UC + +#endif diff --git a/ucm/store/pcstore/cpy/pcstore.py.cc b/ucm/store/pcstore/cpy/pcstore.py.cc new file mode 100644 index 000000000..6ed45f9a0 --- /dev/null +++ b/ucm/store/pcstore/cpy/pcstore.py.cc @@ -0,0 +1,117 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include "pcstore.h" +#include +#include + +namespace py = pybind11; + +namespace UC { + +class PcStorePy : public PcStore { +public: + void* CCStoreImpl() { return this; } + py::list AllocBatch(const py::list& blocks) + { + py::list results; + for (auto& block : blocks) { results.append(this->Alloc(block.cast())); } + return results; + } + py::list LookupBatch(const py::list& blocks) + { + py::list founds; + for (auto& block : blocks) { founds.append(this->Lookup(block.cast())); } + return founds; + } + void CommitBatch(const py::list& blocks, const bool success) + { + for (auto& block : blocks) { this->Commit(block.cast(), success); } + } + py::tuple CheckPy(const size_t task) + { + auto finish = false; + auto ret = this->Check(task, finish); + return py::make_tuple(ret, finish); + } + size_t LoadToDevice(const py::list& blockIds, const py::list& addresses) + { + return this->SubmitPy(blockIds, addresses, TransTask::Type::LOAD, "PC::S2D"); + } + size_t DumpFromDevice(const py::list& blockIds, const py::list& addresses) + { + return this->SubmitPy(blockIds, addresses, TransTask::Type::DUMP, "PC::D2S"); + } + +private: + size_t SubmitPy(const py::list& blockIds, const py::list& addresses, TransTask::Type&& type, + std::string&& brief) + { + TransTask task{std::move(type), std::move(brief)}; + auto blockId = blockIds.begin(); + auto address = addresses.begin(); + while ((blockId != blockIds.end()) && (address != addresses.end())) { + task.Append(blockId->cast(), address->cast()); + blockId++; + address++; + } + return this->Submit(std::move(task)); + } +}; + +} // namespace UC + +PYBIND11_MODULE(ucmpcstore, module) +{ + module.attr("project") = UCM_PROJECT_NAME; + module.attr("version") = UCM_PROJECT_VERSION; + module.attr("commit_id") = UCM_COMMIT_ID; + module.attr("build_type") = UCM_BUILD_TYPE; + auto store = py::class_(module, "PcStore"); + auto config = py::class_(store, "Config"); + config.def(py::init&, const size_t, const bool>(), + py::arg("storageBackends"), py::arg("kvcacheBlockSize"), py::arg("transferEnable")); + config.def_readwrite("storageBackends", &UC::PcStorePy::Config::storageBackends); + config.def_readwrite("kvcacheBlockSize", &UC::PcStorePy::Config::kvcacheBlockSize); + config.def_readwrite("transferEnable", &UC::PcStorePy::Config::transferEnable); + config.def_readwrite("transferIoDirect", &UC::PcStorePy::Config::transferIoDirect); + config.def_readwrite("transferLocalRankSize", &UC::PcStorePy::Config::transferLocalRankSize); + config.def_readwrite("transferDeviceId", &UC::PcStorePy::Config::transferDeviceId); + config.def_readwrite("transferStreamNumber", &UC::PcStorePy::Config::transferStreamNumber); + config.def_readwrite("transferIoSize", &UC::PcStorePy::Config::transferIoSize); + config.def_readwrite("transferBufferNumber", &UC::PcStorePy::Config::transferBufferNumber); + config.def_readwrite("transferTimeoutMs", &UC::PcStorePy::Config::transferTimeoutMs); + store.def(py::init<>()); + store.def("CCStoreImpl", &UC::PcStorePy::CCStoreImpl); + store.def("Setup", &UC::PcStorePy::Setup); + store.def("Alloc", py::overload_cast(&UC::PcStorePy::Alloc)); + store.def("AllocBatch", &UC::PcStorePy::AllocBatch); + store.def("Lookup", py::overload_cast(&UC::PcStorePy::Lookup)); + store.def("LookupBatch", &UC::PcStorePy::LookupBatch); + store.def("LoadToDevice", &UC::PcStorePy::LoadToDevice); + store.def("DumpFromDevice", &UC::PcStorePy::DumpFromDevice); + store.def("Wait", &UC::PcStorePy::Wait); + store.def("Check", &UC::PcStorePy::CheckPy); + store.def("Commit", 
py::overload_cast(&UC::PcStorePy::Commit)); + store.def("CommitBatch", &UC::PcStorePy::CommitBatch); +} diff --git a/ucm/store/pcstore/pcstore_connector.py b/ucm/store/pcstore/pcstore_connector.py new file mode 100644 index 000000000..e9e0d46dc --- /dev/null +++ b/ucm/store/pcstore/pcstore_connector.py @@ -0,0 +1,112 @@ +# -*- coding: utf-8 -*- +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import torch + +from ucm.store.pcstore import ucmpcstore +from ucm.store.ucmstore import Task, UcmKVStoreBase + + +@dataclass +class NfsTask(Task): + task_id: int + + +class UcmPcStore(UcmKVStoreBase): + def __init__(self, config: Dict): + super().__init__(config) + self.store = ucmpcstore.PcStore() + storage_backends = [ + path for path in config["storage_backends"].split(":") if path + ] + block_size = int(config["kv_block_size"]) + transfer_enable = True if config["role"] == "worker" else False + param = ucmpcstore.PcStore.Config(storage_backends, block_size, transfer_enable) + if transfer_enable: + param.transferDeviceId = config["device"] + param.transferIoSize = config["io_size"] + param.transferIoDirect = config.get("use_direct", False) + param.transferStreamNumber = config.get("stream_number", 8) + param.transferBufferNumber = config.get("buffer_number", 4096) + param.transferLocalRankSize = config.get("local_rank_size", 8) + ret = self.store.Setup(param) + if ret != 0: + msg = f"Failed to initialize ucmpcstore, errcode: {ret}." 
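+            # Setup() returns a non-zero error code when the underlying C++ store
+            # cannot be initialized; fail fast rather than continue half-configured.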
+ raise RuntimeError(msg) + + def cc_store(self) -> int: + return self.store.CCStoreImpl() + + def create(self, block_ids: List[str]) -> List[int]: + return self.store.AllocBatch(block_ids) + + def lookup(self, block_ids: List[str]) -> List[bool]: + return self.store.LookupBatch(block_ids) + + def prefetch(self, block_ids: List[str]) -> None: + pass + + def load( + self, block_ids: List[str], offset: List[int], dst_tensor: List[torch.Tensor] + ) -> Task: + dst_tensor_ptr = [t.data_ptr() for t in dst_tensor] + task_id = self.store.LoadToDevice(block_ids, dst_tensor_ptr) + return NfsTask(task_id=task_id) + + def dump( + self, block_ids: List[str], offset: List[int], src_tensor: List[torch.Tensor] + ) -> Task: + src_tensor_ptr = [t.data_ptr() for t in src_tensor] + task_id = self.store.DumpFromDevice(block_ids, src_tensor_ptr) + return NfsTask(task_id=task_id) + + def fetch_data( + self, + block_ids: List[str], + offset: List[int], + dst_addr: List[int], + size: List[int], + ) -> Task: + pass + + def dump_data( + self, + block_ids: List[str], + offset: List[int], + src_addr: List[int], + size: List[int], + ) -> Task: + pass + + def wait(self, task: Task) -> int: + return self.store.Wait(task.task_id) + + def commit(self, block_ids: List[str], is_success: bool = True) -> None: + self.store.CommitBatch(block_ids, is_success) + + def check(self, task: Task) -> Tuple[int, bool]: + return self.store.Check(task.task_id) diff --git a/ucm/store/task/CMakeLists.txt b/ucm/store/task/CMakeLists.txt new file mode 100644 index 000000000..543e78bc7 --- /dev/null +++ b/ucm/store/task/CMakeLists.txt @@ -0,0 +1,3 @@ +add_library(storetask STATIC task_manager.cc) +target_include_directories(storetask PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_libraries(storetask PUBLIC storeinfra) diff --git a/ucm/store/task/task_manager.cc b/ucm/store/task/task_manager.cc new file mode 100644 index 000000000..38c780ccb --- /dev/null +++ b/ucm/store/task/task_manager.cc @@ -0,0 +1,98 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include "task_manager.h" + +namespace UC { + +Status TaskManager::Submit(Task&& task, size_t& taskId) noexcept +{ + taskId = task.Id(); + const auto taskStr = task.Str(); + TaskPtr taskPtr = nullptr; + WaiterPtr waiterPtr = nullptr; + try { + taskPtr = std::make_shared(std::move(task)); + waiterPtr = std::make_shared(0, task.StartTp()); + } catch (const std::exception& e) { + UC_ERROR("Failed({}) to submit task({}).", e.what(), taskStr); + return Status::OutOfMemory(); + } + std::lock_guard lg(mutex_); + const auto& [iter, success] = + tasks_.emplace(taskId, std::make_pair(std::move(taskPtr), std::move(waiterPtr))); + if (!success) { + UC_ERROR("Failed to submit task({}).", taskStr); + return Status::OutOfMemory(); + } + auto shards = iter->second.first->Split(queues_.size(), iter->second.second); + for (auto& shard : shards) { + auto& q = queues_[qIndex_++]; + if (qIndex_ == queues_.size()) { qIndex_ = 0; } + q->Push(shard); + } + return Status::OK(); +} + +Status TaskManager::Wait(const size_t taskId) noexcept +{ + TaskPtr task = nullptr; + WaiterPtr waiter = nullptr; + { + std::lock_guard lg(mutex_); + auto iter = tasks_.find(taskId); + if (iter == tasks_.end()) { + UC_ERROR("Not found task by id({}).", taskId); + return Status::NotFound(); + } + task = iter->second.first; + waiter = iter->second.second; + tasks_.erase(iter); + } + if (!waiter->Wait(timeoutMs_)) { + UC_ERROR("Task({}) timeout({}).", task->Str(), timeoutMs_); + failureSet_.Insert(taskId); + waiter->Wait(); + } + auto failure = failureSet_.Contains(taskId); + if (failure) { + failureSet_.Remove(taskId); + UC_ERROR("Task({}) failed.", task->Str()); + return Status::Error(); + } + return Status::OK(); +} + +Status TaskManager::Check(const size_t taskId, bool& finish) noexcept +{ + std::lock_guard lg(mutex_); + auto iter = tasks_.find(taskId); + if (iter == tasks_.end()) { + UC_ERROR("Not found task by id({}).", taskId); + return Status::NotFound(); + } + finish = iter->second.second->Finish(); + return Status::OK(); +} + +} // namespace UC diff --git a/ucm/store/task/task_manager.h b/ucm/store/task/task_manager.h new file mode 100644 index 000000000..513e3d65c --- /dev/null +++ b/ucm/store/task/task_manager.h @@ -0,0 +1,57 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#ifndef UNIFIEDCACHE_TASK_MANAGER_H +#define UNIFIEDCACHE_TASK_MANAGER_H + +#include +#include "status/status.h" +#include "task_queue.h" +#include "task_set.h" + +namespace UC { + +class TaskManager { + using TaskPtr = std::shared_ptr; + using WaiterPtr = std::shared_ptr; + using TaskPair = std::pair; + using QueuePtr = std::shared_ptr; + +public: + virtual ~TaskManager() = default; + virtual Status Submit(Task&& task, size_t& taskId) noexcept; + virtual Status Wait(const size_t taskId) noexcept; + virtual Status Check(const size_t taskId, bool& finish) noexcept; + +protected: + std::mutex mutex_; + std::unordered_map tasks_; + size_t qIndex_{0}; + std::vector queues_; + size_t timeoutMs_{0}; + TaskSet failureSet_; +}; + +} // namespace UC + +#endif diff --git a/ucm/store/task/task_queue.h b/ucm/store/task/task_queue.h new file mode 100644 index 000000000..5e6ed467a --- /dev/null +++ b/ucm/store/task/task_queue.h @@ -0,0 +1,39 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#ifndef UNIFIEDCACHE_TASK_QUEUE_H +#define UNIFIEDCACHE_TASK_QUEUE_H + +#include "task_shard.h" + +namespace UC { + +class TaskQueue { +public: + virtual ~TaskQueue() = default; + virtual void Push(std::list& shards) noexcept = 0; +}; + +} // namespace UC + +#endif diff --git a/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_set.h b/ucm/store/task/task_set.h similarity index 91% rename from ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_set.h rename to ucm/store/task/task_set.h index 8c522c614..d2a79cb76 100644 --- a/ucm/store/nfsstore/cc/domain/tsf_task/tsf_task_set.h +++ b/ucm/store/task/task_set.h @@ -21,14 +21,14 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_TSF_TASK_SET_H -#define UNIFIEDCACHE_TSF_TASK_SET_H +#ifndef UNIFIEDCACHE_TASK_SET_H +#define UNIFIEDCACHE_TASK_SET_H #include "template/hashset.h" namespace UC { -class TsfTaskSet : public HashSet {}; +class TaskSet : public HashSet {}; } // namespace UC diff --git a/ucm/store/task/task_shard.h b/ucm/store/task/task_shard.h new file mode 100644 index 000000000..2f71738fa --- /dev/null +++ b/ucm/store/task/task_shard.h @@ -0,0 +1,149 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#ifndef UNIFIEDCACHE_TASK_SHARD_H +#define UNIFIEDCACHE_TASK_SHARD_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "logger/logger.h" +#include "task_waiter.h" + +namespace UC { + +class Task { +public: + enum class Type { DUMP, LOAD }; + enum class Location { HOST, DEVICE }; + struct Shard { + Type type; + Location location; + std::string block; + size_t offset; + uintptr_t address; + size_t length; + size_t owner; + std::shared_ptr buffer; + std::function done; + Shard(const Type type, const Location location, const std::string& block, + const size_t offset, const uintptr_t address, const size_t length, const size_t owner) + : type{type}, location{location}, block{block}, offset{offset}, address{address}, + length{length}, owner{owner}, buffer{nullptr}, done{nullptr} + { + } + Shard(const Shard&) = delete; + Shard& operator=(const Shard&) = delete; + Shard& operator=(Shard&& s) noexcept + { + if (this != &s) { + this->type = s.type; + this->location = s.location; + this->block = std::move(s.block); + this->offset = s.offset; + this->address = s.address; + this->length = s.length; + this->owner = s.owner; + this->buffer = std::move(s.buffer); + this->done = std::move(s.done); + } + return *this; + } + Shard(Shard&& s) noexcept { *this = std::move(s); } + }; + static constexpr auto invalid = std::numeric_limits::min(); + Task(Type&& type, Location&& location, std::string&& brief) + : id_{NextId()}, type_{type}, location_{location}, brief_{std::move(brief)}, number_{0}, + size_{0}, startTp_{NowTp()}, execTp_{0.f} + { + } + auto Id() const noexcept { return id_; } + auto StartTp() const noexcept { return startTp_; } + auto Str() const noexcept { return fmt::format("{},{},{},{}", id_, brief_, number_, size_); } + void Append(const std::string& block, const size_t offset, const uintptr_t address, + const size_t length) + { + shards_.emplace_back(type_, location_, block, offset, address, length, id_); + number_++; + size_ += length; + } + std::vector> Split(const size_t n, std::shared_ptr waiter) + { + auto num = std::min(n, number_); + std::vector> out(num); + waiter->Set(num); + auto base = number_ / num; + auto rem = number_ % num; + auto it = shards_.cbegin(); + for (size_t i = 0; i < num; i++) { + auto next = std::next(it, base + (i < rem ? 
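+            // Give the first `rem` shard lists one extra element so the remainder of
+            // number_ / num is spread evenly across the per-queue splits.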
1 : 0)); + out[i].splice(out[i].end(), shards_, it, next); + out[i].back().done = [waiter, this] { + waiter->Done([this] { UC_DEBUG("Task({}) finished, {}.", Str(), Stat()); }); + }; + it = next; + } + this->execTp_ = NowTp(); + return out; + } + +private: + static size_t NextId() noexcept + { + static std::atomic id{invalid + 1}; + return id.fetch_add(1, std::memory_order_relaxed); + }; + static double NowTp() noexcept + { + auto now = std::chrono::steady_clock::now().time_since_epoch(); + return std::chrono::duration(now).count(); + } + std::string Stat() const noexcept + { + auto wait = execTp_ - startTp_; + auto exec = NowTp() - execTp_; + auto bw = size_ / exec / 1024 / 1024 / 1024; + return fmt::format("wait={:.06f}s, exec={:.06f}s, bw={:.06f}GB/s", wait, exec, bw); + } + +private: + size_t id_; + Type type_; + Location location_; + std::string brief_; + std::list shards_; + size_t number_; + size_t size_; + double startTp_; + double execTp_; +}; + +} // namespace UC + +#endif diff --git a/ucm/store/task/task_waiter.h b/ucm/store/task/task_waiter.h new file mode 100644 index 000000000..96358b8db --- /dev/null +++ b/ucm/store/task/task_waiter.h @@ -0,0 +1,68 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#ifndef UNIFIEDCACHE_ITASK_WAITER_H +#define UNIFIEDCACHE_ITASK_WAITER_H + +#include +#include "thread/latch.h" + +namespace UC { + +class TaskWaiter : public Latch { +protected: + double startTp_; + +public: + TaskWaiter(const size_t expected, const double startTp) : Latch{expected}, startTp_{startTp} {} + virtual ~TaskWaiter() = default; + virtual void Set(const size_t expected) noexcept { this->counter_.store(expected); } + using Latch::Wait; + virtual bool Wait(const size_t timeoutMs) noexcept + { + if (timeoutMs == 0) { + this->Wait(); + return true; + } + std::unique_lock ul(this->mutex_); + if (this->counter_ == 0) { return true; } + auto elapsed = std::chrono::duration(NowTp() - startTp_); + auto elapsedMs = std::chrono::duration_cast(elapsed); + auto timeMs = std::chrono::milliseconds(timeoutMs); + if (timeMs <= elapsedMs) { return false; } + auto remainMs = timeMs - elapsedMs; + return this->cv_.wait_for(ul, remainMs, [this] { return this->counter_ == 0; }); + } + virtual bool Finish() noexcept { return this->counter_ == 0; } + +private: + static double NowTp() noexcept + { + auto now = std::chrono::steady_clock::now().time_since_epoch(); + return std::chrono::duration(now).count(); + } +}; + +} // namespace UC + +#endif diff --git a/ucm/store/test/CMakeLists.txt b/ucm/store/test/CMakeLists.txt index 859c185c4..0c4974efd 100644 --- a/ucm/store/test/CMakeLists.txt +++ b/ucm/store/test/CMakeLists.txt @@ -4,7 +4,7 @@ if(BUILD_UNIT_TESTS) add_executable(ucmstore.test ${UCMSTORE_TEST_SOURCE_FILES}) target_include_directories(ucmstore.test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/case) target_link_libraries(ucmstore.test PRIVATE - dramstore nfsstore localstore storeinfra storedevice + nfsstore localstore storeinfra storedevice gtest_main gtest mockcpp ) gtest_discover_tests(ucmstore.test) diff --git a/ucm/store/test/case/nfsstore/hotness_test.cc b/ucm/store/test/case/nfsstore/hotness_test.cc new file mode 100644 index 000000000..21cc83cdd --- /dev/null +++ b/ucm/store/test/case/nfsstore/hotness_test.cc @@ -0,0 +1,54 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ + +#include +#include +#include "hotness/hotness_set.h" +#include "hotness/hotness_timer.h" +#include "cmn/path_base.h" +#include "file/file.h" +#include "space/space_manager.h" + +class UCHotnessTest : public UC::PathBase {}; + +TEST_F(UCHotnessTest, UpdateHotness) +{ + UC::SpaceManager mgr; + ASSERT_EQ(mgr.Setup({this->Path()}, 1024 * 1024, false), UC::Status::OK()); + + std::string block1 = "a1b2c3d4e5f6789012345678901234ab"; + ASSERT_EQ(mgr.NewBlock(block1), UC::Status::OK()); + ASSERT_EQ(mgr.CommitBlock(block1), UC::Status::OK()); + + UC::HotnessSet hotness_set; + hotness_set.Insert(block1); + auto space_layout = mgr.GetSpaceLayout(); + auto path = space_layout->DataFilePath(block1, false); + auto currentTime = std::filesystem::last_write_time(path); + std::filesystem::last_write_time(path, currentTime - std::chrono::seconds(2)); + auto lastTime = std::filesystem::last_write_time(path); + hotness_set.UpdateHotness(space_layout); + auto newTime = std::filesystem::last_write_time(path); + ASSERT_GT(newTime, lastTime); +} \ No newline at end of file diff --git a/ucm/store/test/case/nfsstore/space_manager_test.cc b/ucm/store/test/case/nfsstore/space_manager_test.cc index cd68ac6f9..958f6464a 100644 --- a/ucm/store/test/case/nfsstore/space_manager_test.cc +++ b/ucm/store/test/case/nfsstore/space_manager_test.cc @@ -21,15 +21,17 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ +#include #include "cmn/path_base.h" #include "space/space_manager.h" +#include "file/file.h" class UCSpaceManagerTest : public UC::PathBase {}; TEST_F(UCSpaceManagerTest, NewBlockTwice) { UC::SpaceManager spaceMgr; - ASSERT_EQ(spaceMgr.Setup({this->Path()}, 1024 * 1024), UC::Status::OK()); + ASSERT_EQ(spaceMgr.Setup({this->Path()}, 1024 * 1024, false), UC::Status::OK()); const std::string block1 = "block1"; ASSERT_FALSE(spaceMgr.LookupBlock(block1)); ASSERT_EQ(spaceMgr.NewBlock(block1), UC::Status::OK()); @@ -39,3 +41,85 @@ TEST_F(UCSpaceManagerTest, NewBlockTwice) ASSERT_TRUE(spaceMgr.LookupBlock(block1)); ASSERT_EQ(spaceMgr.NewBlock(block1), UC::Status::DuplicateKey()); } + +TEST_F(UCSpaceManagerTest, NewBlockTwiceWithTempDir) +{ + UC::SpaceManager spaceMgr; + ASSERT_EQ(spaceMgr.Setup({this->Path()}, 1024 * 1024, true), UC::Status::OK()); + const std::string block1 = "block1"; + ASSERT_FALSE(spaceMgr.LookupBlock(block1)); + ASSERT_EQ(spaceMgr.NewBlock(block1), UC::Status::OK()); + ASSERT_FALSE(spaceMgr.LookupBlock(block1)); + ASSERT_EQ(spaceMgr.NewBlock(block1), UC::Status::DuplicateKey()); + ASSERT_EQ(spaceMgr.CommitBlock(block1), UC::Status::OK()); + ASSERT_TRUE(spaceMgr.LookupBlock(block1)); + ASSERT_EQ(spaceMgr.NewBlock(block1), UC::Status::DuplicateKey()); +} + +TEST_F(UCSpaceManagerTest, CreateBlockWhenNoSpace) +{ + UC::SpaceManager spaceMgr; + size_t blockSize = 1024 * 1024; + size_t capacity = blockSize; + ASSERT_EQ(spaceMgr.Setup({this->Path()}, blockSize, false, capacity), UC::Status::OK()); + ASSERT_EQ(spaceMgr.NewBlock("block3"), UC::Status::OK()); + ASSERT_EQ(spaceMgr.NewBlock("block4"), UC::Status::NoSpace()); +} + +TEST_F(UCSpaceManagerTest, IterAllBlockFile) +{ + constexpr size_t blockSize = 1024 * 1024; + constexpr size_t capacity = blockSize * 1024; + UC::SpaceManager spaceMgr; + ASSERT_EQ(spaceMgr.Setup({this->Path()}, blockSize, false, capacity), UC::Status::OK()); + const std::string block1 = "a1b2c3d4e5f6789012345678901234ab"; + const std::string block2 = "a2b2c3d4e5f6789012345678901234ab"; + const std::string block3 = 
"a3b2c3d4e5f6789012345678901234ab"; + ASSERT_EQ(spaceMgr.NewBlock(block1), UC::Status::OK()); + ASSERT_EQ(spaceMgr.NewBlock(block2), UC::Status::OK()); + ASSERT_EQ(spaceMgr.NewBlock(block3), UC::Status::OK()); + auto layout = spaceMgr.GetSpaceLayout(); + auto iter = layout->CreateFilePathIterator(); + size_t count = 0; + while (!layout->NextDataFilePath(iter).empty()) { count++; } + ASSERT_EQ(count, 0); + ASSERT_EQ(spaceMgr.CommitBlock(block1), UC::Status::OK()); + ASSERT_EQ(spaceMgr.CommitBlock(block2), UC::Status::OK()); + ASSERT_EQ(spaceMgr.CommitBlock(block3), UC::Status::OK()); + iter = layout->CreateFilePathIterator(); + count = 0; + while (!layout->NextDataFilePath(iter).empty()) { count++; } + ASSERT_EQ(count, 3); +} + +TEST_F(UCSpaceManagerTest, NewBlockReuseIfActiveAccessedLongAgo) +{ + UC::SpaceManager spaceMgr; + constexpr size_t blockSize = 1024 * 1024; + constexpr size_t capacity = blockSize * 1024; + ASSERT_EQ(spaceMgr.Setup({this->Path()}, blockSize, false, capacity), UC::Status::OK()); + const auto* layout = spaceMgr.GetSpaceLayout(); + ASSERT_NE(layout, nullptr); + + const std::string block1 = "a1b2c3d4e5f6789012345678901234ab"; + auto parent = UC::File::Make(layout->DataFileParent(block1, /*activated=*/true)); + ASSERT_NE(parent, nullptr); + ASSERT_EQ(parent->MkDir(), UC::Status::OK()); + + const auto activePath = layout->DataFilePath(block1, /*activated=*/true); + auto activeFile = UC::File::Make(activePath); + ASSERT_NE(activeFile, nullptr); + ASSERT_EQ(activeFile->Open(UC::IFile::OpenFlag::CREATE | UC::IFile::OpenFlag::READ_WRITE), UC::Status::OK()); + activeFile->Close(); + + // NewBlock should return DuplicateKey because the file is recent + ASSERT_EQ(spaceMgr.NewBlock(block1), UC::Status::DuplicateKey()); + + // Set atime to 10 minutes ago so it is not considered recent + struct utimbuf newTime; + auto tp = time(nullptr) - 600; + newTime.modtime = tp; + newTime.actime = tp; + utime(activePath.c_str(), &newTime); + ASSERT_EQ(spaceMgr.NewBlock(block1), UC::Status::OK()); +} \ No newline at end of file diff --git a/ucm/store/test/case/nfsstore/space_property_test.cc b/ucm/store/test/case/nfsstore/space_property_test.cc new file mode 100644 index 000000000..5b5beb619 --- /dev/null +++ b/ucm/store/test/case/nfsstore/space_property_test.cc @@ -0,0 +1,61 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ +#include "cmn/path_base.h" +#include "space/space_property.h" +#include "space/space_layout.h" +#include "space/space_shard_layout.h" +#include "space/space_manager.h" + +class UCSpacePropertyTest : public UC::PathBase {}; + +/* +* check the persistence of property +*/ +TEST_F(UCSpacePropertyTest, CapacityPersistence) +{ + size_t blocksize = 1024 * 1024; + UC::SpaceManager spaceMgr; + ASSERT_EQ(spaceMgr.Setup({this->Path()}, blocksize, false, blocksize * 5), UC::Status::OK()); + const UC::SpaceLayout* layout = spaceMgr.GetSpaceLayout(); + + const std::string path = layout->ClusterPropertyFilePath(); + + UC::SpaceProperty spaceProperty; + ASSERT_EQ(spaceProperty.Setup(path), UC::Status::OK()); + ASSERT_EQ(spaceProperty.GetCapacity(), 0); + + spaceProperty.IncreaseCapacity(blocksize * 2); + ASSERT_EQ(spaceProperty.GetCapacity(), blocksize*2); + + UC::SpaceProperty spaceProperty2; + ASSERT_EQ(spaceProperty2.Setup(path), UC::Status::OK()); + ASSERT_EQ(spaceProperty2.GetCapacity(), blocksize*2); + + spaceProperty2.DecreaseCapacity(blocksize); + ASSERT_EQ(spaceProperty2.GetCapacity(), blocksize); + + UC::SpaceProperty spaceProperty3; + ASSERT_EQ(spaceProperty3.Setup(path), UC::Status::OK()); + ASSERT_EQ(spaceProperty3.GetCapacity(), blocksize); + } \ No newline at end of file diff --git a/ucm/store/test/case/nfsstore/space_recycle_test.cc b/ucm/store/test/case/nfsstore/space_recycle_test.cc new file mode 100644 index 000000000..a9fb862f7 --- /dev/null +++ b/ucm/store/test/case/nfsstore/space_recycle_test.cc @@ -0,0 +1,124 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * */ + +#include +#include "cmn/path_base.h" +#include "file/file.h" +#include "space/space_recycle.h" +#include "space/space_manager.h" +#include "thread/latch.h" + +namespace UC { + +void DoRecycle(const SpaceLayout* layout, const uint32_t recycleNum, + SpaceRecycle::RecycleOneBlockDone done); +} + +class UCSpaceRecycleTest : public UC::PathBase { +protected: + using OpenFlag = UC::IFile::OpenFlag; + using AccessMode = UC::IFile::AccessMode; + + void NewBlock(const UC::SpaceLayout* layout, const std::string& id) + { + std::string parent = layout->DataFileParent(id, false); + UC::File::MkDir(parent); + std::string path = layout->DataFilePath(id, false); + auto f = UC::File::Make(path); + f->Open(OpenFlag::CREATE | OpenFlag::READ_WRITE); + } + + bool ExistBlock(const UC::SpaceLayout* layout, const std::string& id) + { + std::string path = layout->DataFilePath(id, false); + return UC::File::Access(path, AccessMode::EXIST).Success(); + } + + void UpdateBlock(const UC::SpaceLayout* layout, const std::string& id) + { + struct utimbuf newTime; + auto tp = time(nullptr) + 3600; + newTime.modtime = tp; + newTime.actime = tp; + std::string path = layout->DataFilePath(id, false); + utime(path.c_str(), &newTime); + } +}; + +TEST_F(UCSpaceRecycleTest, TriggerRecycle) +{ + size_t blocksize = 1024 * 1024; + UC::SpaceManager spaceMgr; + ASSERT_EQ(spaceMgr.Setup({this->Path()}, blocksize, false, blocksize * 5), UC::Status::OK()); + const UC::SpaceLayout* layout = spaceMgr.GetSpaceLayout(); + std::string block1 = "a1b2c3d4e5f6789012345678901234ab"; + NewBlock(layout, block1); + ASSERT_TRUE(ExistBlock(layout, block1)); + + std::string block2 = "a2b2c3d4e5f6789012345678901234ab"; + NewBlock(layout, block2); + ASSERT_TRUE(ExistBlock(layout, block2)); + + UpdateBlock(layout, block1); + UC::SpaceRecycle recycle; + UC::Latch waiter{1}; + + ASSERT_TRUE(recycle.Setup(layout, 10, [&waiter] { waiter.Done([]{}); }).Success()); + recycle.Trigger(); + waiter.Wait(); + EXPECT_TRUE(ExistBlock(layout, block1)); + EXPECT_FALSE(ExistBlock(layout, block2)); +} + +TEST_F(UCSpaceRecycleTest, DoRecycle) +{ + size_t blocksize = 1024 * 1024; + UC::SpaceManager spaceMgr; + ASSERT_EQ(spaceMgr.Setup({this->Path()}, blocksize, false, blocksize * 5), UC::Status::OK()); + const UC::SpaceLayout* layout = spaceMgr.GetSpaceLayout(); + std::string recycleBlocks[] = { + "a1b2c3d4e5f6789012345678901234ab", + "a2b2c3d4e5f6789012345678901234ab", + "a3b2c3d4e5f6789012345678901234ab" + }; + std::string remainBlocks[] = { + "b1b2c3d4e5f6789012345678901234ab", + "b2b2c3d4e5f6789012345678901234ab", + "b3b2c3d4e5f6789012345678901234ab" + }; + for (auto &id: remainBlocks) + { + NewBlock(layout, id); + ASSERT_TRUE(ExistBlock(layout, id)); + } + for (auto &id: recycleBlocks) + { + NewBlock(layout, id); + ASSERT_TRUE(ExistBlock(layout, id)); + } + for (auto &id: remainBlocks) { UpdateBlock(layout, id); } + UC::DoRecycle(layout, 3, nullptr); + for (auto &id: remainBlocks) { EXPECT_TRUE(ExistBlock(layout, id)); } + for (auto &id: recycleBlocks) { EXPECT_FALSE(ExistBlock(layout, id)); } +} \ No newline at end of file diff --git a/ucm/store/test/e2e/dramstore_embed_and_fetch.py b/ucm/store/test/e2e/dramstore_embed_and_fetch.py new file mode 100644 index 000000000..4f9acda19 --- /dev/null +++ b/ucm/store/test/e2e/dramstore_embed_and_fetch.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +import os +import secrets +from typing import List + +import torch + +from ucm.store.dramstore.dramstore_connector import UcmDramStore +from ucm.store.ucmstore import UcmKVStoreBase + + +def setup_store( + capacity, block_size, stream_number, device_id, timeout_ms +) -> UcmKVStoreBase: + config = {} + config["capacity"] = capacity + config["kv_block_size"] = block_size + config["stream_number"] = stream_number + config["device_id"] = device_id + config["timeout_ms"] = timeout_ms + return UcmDramStore(config) + + +def make_buffers( + block_number, device_id, batch_size, block_dim, block_len, block_layer +): + hashes = [secrets.token_hex(16) for _ in range(block_number)] + tensors = [ + [ + torch.rand( + [block_dim, block_len], + dtype=torch.bfloat16, + device="cuda:{}".format(device_id), + ) + for _ in range(block_layer) + ] + for _ in range(batch_size) + ] + return hashes, tensors + + +def embed(store: UcmKVStoreBase, hashes: List[str], tensors: List[List[torch.Tensor]]): + results = store.create(hashes) + assert sum(results) == 0 + block_ids = [] + offsets = [] + layers = [] + for hash_id, block in zip(hashes, tensors): + offset = 0 + for layer in block: + block_ids.append(hash_id) + offsets.append(offset) + layers.append(layer) + offset += layer.untyped_storage().size() + task = store.dump(block_ids, offsets, layers) + assert task.task_id > 0 + ret = store.wait(task) + assert ret == 0 + store.commit(hashes, True) + + +def fetch(store: UcmKVStoreBase, hashes: List[str], tensors: List[List[torch.Tensor]]): + founds = store.lookup(hashes) + for found in founds: + assert found + block_ids = [] + offsets = [] + layers = [] + for hash_id, block in zip(hashes, tensors): + offset = 0 + for layer in block: + block_ids.append(hash_id) + offsets.append(offset) + layers.append(layer) + offset += layer.untyped_storage().size() + task = store.load(block_ids, offsets, layers) + assert task.task_id > 0 + ret = store.wait(task) + assert ret == 0 + + +def main(): + block_number = 4096 + device_id = 1 + block_dim = 576 + block_len = 128 + block_elem_size = 2 + block_layer = 61 + io_size = block_dim * block_len * block_elem_size + block_size = io_size * block_layer + batch_size = 256 + stream_number = 10 + timeout_ms = 1000000 + capacity = block_number * block_size * 2 + batch_number = 64 + + store = setup_store(capacity, block_size, stream_number, device_id, timeout_ms) + hashes, tensors = make_buffers( + block_number, device_id, batch_size, 
block_dim, block_len, block_layer + ) + total_batches = (block_number + batch_size - 1) // batch_size + + for batch in range(total_batches): + start = batch_size * batch + end = min(start + batch_size, block_number) + embed(store, hashes[start:end], tensors) + + _, new_tensors = make_buffers( + block_number, device_id, batch_size, block_dim, block_len, block_layer + ) + for batch in range(total_batches): + start = batch_size * batch + end = start + batch_size + fetch(store, hashes[start:end], new_tensors) + + +if __name__ == "__main__": + os.environ["UC_LOGGER_LEVEL"] = "debug" + main() diff --git a/ucm/store/test/e2e/nfsstore_embed.py b/ucm/store/test/e2e/nfsstore_embed.py index 4f295811f..0b6e2fc57 100644 --- a/ucm/store/test/e2e/nfsstore_embed.py +++ b/ucm/store/test/e2e/nfsstore_embed.py @@ -27,7 +27,6 @@ from typing import List import torch -import torch_npu from ucm.store.nfsstore.nfsstore_connector import UcmNfsStore from ucm.store.ucmstore import UcmKVStoreBase @@ -52,7 +51,7 @@ def make_buffers( torch.rand( [block_dim, block_len], dtype=torch.bfloat16, - device="npu:{}".format(device_id), + device="cuda:{}".format(device_id), ) for _ in range(block_layer) ] @@ -81,6 +80,39 @@ def embed(store: UcmKVStoreBase, hashes: List[str], tensors: List[List[torch.Ten store.commit(hashes, True) +def fetch(store: UcmKVStoreBase, hashes: List[str], tensors: List[List[torch.Tensor]]): + founds = store.lookup(hashes) + for found in founds: + assert found + block_ids = [] + offsets = [] + layers = [] + for hash_id, block in zip(hashes, tensors): + offset = 0 + for layer in block: + block_ids.append(hash_id) + offsets.append(offset) + layers.append(layer) + offset += layer.untyped_storage().size() + task = store.load(block_ids, offsets, layers) + assert task.task_id > 0 + ret = store.wait(task) + assert ret == 0 + + +def cmp_and_print_diff(a, b, rtol=0.0, atol=0.0): + for r, (row_a, row_b) in enumerate(zip(a, b)): + for c, (ta, tb) in enumerate(zip(row_a, row_b)): + if not torch.allclose(ta, tb, rtol=rtol, atol=atol): + mask = ~torch.isclose(ta, tb, rtol=rtol, atol=atol) + diff_a = ta[mask].cpu() + diff_b = tb[mask].cpu() + print(f"DIFF at [{r}][{c}] total {mask.sum().item()} element(s)") + print(" a val:", diff_a.flatten()) + print(" b val:", diff_b.flatten()) + assert False + + def store_all_hashes(hashes): kvcache_block_hashes_file = "kvcache_block_hashes.txt" current_directory = os.path.dirname(__file__) @@ -109,7 +141,10 @@ def main(): for batch in range(total_batches): start = batch_size * batch end = min(start + batch_size, block_number) + tensors2 = [[torch.empty_like(t) for t in row] for row in tensors] embed(store, hashes[start:end], tensors) + fetch(store, hashes[start:end], tensors2) + cmp_and_print_diff(tensors, tensors2) store_all_hashes(hashes) diff --git a/ucm/store/test/e2e/nfsstore_embed_fetch.py b/ucm/store/test/e2e/nfsstore_embed_fetch.py new file mode 100644 index 000000000..1132afa50 --- /dev/null +++ b/ucm/store/test/e2e/nfsstore_embed_fetch.py @@ -0,0 +1,373 @@ +# -*- coding: utf-8 -*- +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +import csv +import math +import os +import secrets +import time +from typing import Dict, List, Tuple + +import torch + +from ucm.store.nfsstore.nfsstore_connector import UcmNfsStore +from ucm.store.ucmstore import UcmKVStoreBase + + +def setup( + storage_backends, + block_size, + device_id, + io_size, + transferStreamNumber, + transferIoDirect, +) -> UcmKVStoreBase: + config = { + "storage_backends": storage_backends, + "kv_block_size": block_size, + "role": "worker", + "device": device_id, + "io_size": io_size, + "transferStreamNumber": transferStreamNumber, + "transferIoDirect": transferIoDirect, + } + return UcmNfsStore(config) + + +def make_aligned_tensor(shape, dtype, device, alignment=4096): + numl = math.prod(shape) + dtype_size = torch.tensor(1, dtype=dtype).element_size() + total_byters = numl * dtype_size + + padded_bytes = total_byters + alignment + storage = torch.ByteTensor(padded_bytes).to(device) + + ptr = storage.data_ptr() + offset = ptr % alignment + if offset != 0: + aligned_ptr = ptr + (alignment - offset) + else: + aligned_ptr = ptr + + aligned_storage = storage[(aligned_ptr - ptr) :].view(dtype) + tensor = aligned_storage[:numl].view(shape) + tensor.storage_ref = storage + return tensor + + +def make_buffers( + block_number, device_id, batch_size, head_dim, block_len, block_layer, num_head, kv +): + hashes = [secrets.token_hex(16) for _ in range(block_number)] + kv_caches = {} + for i in range(block_layer): + kv_caches[i] = make_aligned_tensor( + [kv, block_number, block_len, num_head, head_dim], + dtype=torch.float16, + device=f"cuda:{device_id}", + ) + return hashes, kv_caches + + +def store_all_hashes(hashes: List[str]): + file_path = os.path.join(os.path.dirname(__file__), "kvcache_block_hashes.txt") + with open(file_path, "w", encoding="utf-8") as f: + for h in hashes: + f.write(h + "\n") + + +def load_hashes_from_file() -> List[str]: + file_path = os.path.join(os.path.dirname(__file__), "kvcache_block_hashes.txt") + if not os.path.exists(file_path): + return [] + with open(file_path, "r", encoding="utf-8") as f: + return [line.strip() for line in f.readlines()] + + +def embed( + store: UcmKVStoreBase, + hashes: List[str], + kvcaches: Dict[int, torch.Tensor], + mla: bool, +): + start_time = time.perf_counter() + + total_block_ids, total_offsets, total_tensors = [], [], [] + total_size = 0 + + for i, hash_val in enumerate(hashes): + offset = 0 + for layer_id, kv_layer in kvcaches.items(): + k_tensor = 
kv_layer[0][i] # kv=1 + total_tensors.append(k_tensor) + total_block_ids.append(hash_val) + total_offsets.append(offset) + sz = k_tensor.numel() * k_tensor.element_size() + offset += sz + total_size += sz + + if not mla: + v_tensor = kv_layer[1][i] + total_tensors.append(v_tensor) + total_block_ids.append(hash_val) + total_offsets.append(offset) + sz = v_tensor.numel() * v_tensor.element_size() + offset += sz + total_size += sz + + task = store.dump(total_block_ids, total_offsets, total_tensors) + store.wait(task) + + elapsed_time = time.perf_counter() - start_time + throughput_gbps = (total_size / (1024**3)) / elapsed_time if elapsed_time > 0 else 0 + + print( + f"WRITE: Data Size={(total_size / (1024 ** 3)):.4f} GB, Time={elapsed_time:.4f} s, " + f"Speed={throughput_gbps:.4f} GB/s" + ) + + return total_size, elapsed_time, throughput_gbps + + +def fetch( + store: UcmKVStoreBase, + hashes: List[str], + kvcaches: Dict[int, torch.Tensor], + mla: bool, +): + start_time = time.perf_counter() + + founds = store.lookup(hashes) + for f in founds: + assert f, "Cache block miss detected" + + block_ids, offsets, tensors = [], [], [] + total_size = 0 + + for i, hash_val in enumerate(hashes): + offset = 0 + for layer_id, kv_layer in kvcaches.items(): + k_tensor = kv_layer[0][i] # kv=1 + block_ids.append(hash_val) + offsets.append(offset) + tensors.append(k_tensor) + sz = k_tensor.numel() * k_tensor.element_size() + offset += sz + total_size += sz + + if not mla: + v_tensor = kv_layer[1][i] + block_ids.append(hash_val) + offsets.append(offset) + tensors.append(v_tensor) + sz = v_tensor.numel() * v_tensor.element_size() + offset += sz + total_size += sz + + task = store.load(block_ids, offsets, tensors) + ret = store.wait(task) + assert ret == 0, "Load operation failed" + + elapsed_time = time.perf_counter() - start_time + throughput_gbps = (total_size / (1024**3)) / elapsed_time if elapsed_time > 0 else 0 + + print( + f"READ: Data Size={(total_size / (1024 ** 3)):.4f} GB, Time={elapsed_time:.4f} s, " + f"Speed={throughput_gbps:.4f} GB/s" + ) + + return total_size, elapsed_time, throughput_gbps + + +def run( + storage_backends: str, + device_id: int, + repeat: int, + num_head: int, + block_len: int, + transferStreamNumber: int, + num_tokens: int, + block_layer: int, + head_size: int, + block_elem_size: int, + kv: int, + mla: bool, + transferIoDirect: bool, + operation_mode: str = "both", # "write_only", "read_only", or "both" +) -> Tuple[float, float, float, float, float, float]: + """ + Run a single test with given parameters and return performance metrics. 
+ + Returns: + Tuple of (avg_w_size, avg_w_time, avg_w_bw, avg_r_time, avg_r_bw, avg_r_size) + """ + + block_dim = head_size * num_head + io_size = block_dim * block_len * block_elem_size + block_size = io_size * block_layer + batch_size = int(num_tokens / block_len) + real_blocks = batch_size + 10 + + w_bw_list, r_bw_list = [], [] + w_time_list, r_time_list = [], [] + w_size_sum, r_size_sum = 0.0, 0.0 + + store = setup( + storage_backends, + block_size, + device_id, + io_size, + transferStreamNumber, + transferIoDirect, + ) + + for r in range(repeat): + print(f"\n--- Round {r+1} ---") + + if operation_mode in ["write_only", "both"]: + hashes, kvcaches = make_buffers( + real_blocks, + device_id, + batch_size, + head_size, + block_len, + block_layer, + num_head, + kv, + ) + + results = store.create(hashes[:batch_size]) + assert sum(results) == 0, "Create operation failed" + + w_size, w_time, w_bw = embed( + store, + hashes[:batch_size], + kvcaches, + mla, + ) + store.commit(hashes[:batch_size], True) + + if r == 0: + store_all_hashes(hashes[:batch_size]) + + if r != 0: + w_bw_list.append(w_bw) + w_time_list.append(w_time) + w_size_sum += w_size + + if operation_mode == "write_only": + del kvcaches, hashes + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif hasattr(torch, "npu") and torch.npu.is_available(): + torch.npu.empty_cache() + + if operation_mode in ["read_only", "both"]: + if operation_mode == "read_only": + saved_hashes = load_hashes_from_file() + if not saved_hashes: + raise RuntimeError("No saved hashes found for read operation") + + _, kvcaches = make_buffers( + real_blocks, + device_id, + batch_size, + head_size, + block_len, + block_layer, + num_head, + kv, + ) + + r_size, r_time, r_bw = fetch( + store, + saved_hashes[:batch_size], + kvcaches, + mla, + ) + else: + r_size, r_time, r_bw = fetch( + store, + hashes[:batch_size], + kvcaches, + mla, + ) + + if r != 0: + r_bw_list.append(r_bw) + r_time_list.append(r_time) + r_size_sum += r_size + + if operation_mode == "read_only": + del kvcaches + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif hasattr(torch, "npu") and torch.npu.is_available(): + torch.npu.empty_cache() + else: + del kvcaches, hashes + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif hasattr(torch, "npu") and torch.npu.is_available(): + torch.npu.empty_cache() + + del store + avg_w_bw = sum(w_bw_list) / len(w_bw_list) if w_bw_list else 0.0 + avg_r_bw = sum(r_bw_list) / len(r_bw_list) if r_bw_list else 0.0 + avg_w_time = sum(w_time_list) / len(w_time_list) if w_time_list else 0.0 + avg_r_time = sum(r_time_list) / len(r_time_list) if r_time_list else 0.0 + avg_w_size = w_size_sum / (1024**3) / len(w_time_list) if w_time_list else 0.0 + avg_r_size = r_size_sum / (1024**3) / len(r_time_list) if r_time_list else 0.0 + + return avg_w_size, avg_w_time, avg_w_bw, avg_r_time, avg_r_bw, avg_r_size + + +if __name__ == "__main__": + os.environ["UC_LOGGER_LEVEL"] = "debug" + + try: + result = run( + storage_backends=".", + device_id=1, + repeat=1, + num_head=1, + block_len=128, + transferStreamNumber=32, + num_tokens=4096, + block_layer=61, + head_size=576, + block_elem_size=2, + kv=1, + mla=True, + transferIoDirect=False, + operation_mode="both", + ) + + avg_w_size, avg_w_time, avg_w_bw, avg_r_time, avg_r_bw, avg_r_size = result + + except Exception as e: + print(f"Error: {e}") + import traceback + + traceback.print_exc() diff --git a/ucm/store/test/e2e/nfsstore_embed_fetch_run.py b/ucm/store/test/e2e/nfsstore_embed_fetch_run.py 
new file mode 100644 index 000000000..04415067f --- /dev/null +++ b/ucm/store/test/e2e/nfsstore_embed_fetch_run.py @@ -0,0 +1,215 @@ +# -*- coding: utf-8 -*- +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +import csv +import multiprocessing +import os +from typing import List + +from nfsstore_embed_fetch import run + + +def run_wrapper(result_queue, *args): + try: + result = run(*args) + result_queue.put(("success", result)) + except Exception as e: + result_queue.put(("error", str(e))) + + +def get_user_input(prompt, default=None): + if default is not None: + user_input = input(f"{prompt} (default: {default}): ").strip() + return user_input if user_input else default + else: + return input(f"{prompt}: ").strip() + + +def main(): + + try: + multiprocessing.set_start_method("spawn", force=True) + except RuntimeError: + pass + + storage_backends = "." + device_id = 1 + repeat = 3 # This parameter must be greater than 1; the results from the first round of testing are not included in the bandwidth calculation. + num_tokens_list = [2048, 4096, 8192, 16384, 32768] + transferStreamNumbers = [32, 64, 128] + + print("1. Model Selection:") + print(" 1 - QwQ-32B") + print(" 2 - deepseek-v3") + model_choice = get_user_input("Please select model", "1") + mla = True if model_choice == "2" else False + + print("\n2. GDS Transfer:") + print(" 1 - Disable IoDirect (default)") + print(" 2 - Enable IoDirect") + transferIoDirect = get_user_input("Please select Direct IO mode", "1") + transferIoDirect = False if transferIoDirect == "1" else True + + print("\n3. 
Operation Mode:") + print(" 1 - Read/Write Test (default)") + print(" 2 - Write Only Test") + print(" 3 - Read Only Test") + op_choice = get_user_input("Please select operation mode", "1") + operation_mode_map = {"1": "both", "2": "write_only", "3": "read_only"} + operation_mode = operation_mode_map.get(op_choice, "both") + + if mla: + block_lens = [64, 128] + block_layer = 61 + head_size = 576 + block_elem_size = 2 + kv = 1 + model_name = "deepseek-v3" + num_head_list = [1] + else: + block_lens = [128, 256] + block_layer = 64 + head_size = 128 + block_elem_size = 2 + kv = 2 + model_name = "QwQ-32B" + num_head_list = [1, 2, 4, 8] + + SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + csv_file = os.path.join(SCRIPT_DIR, "embed_fetch_result.csv") + need_header = not os.path.exists(csv_file) + + os.makedirs(SCRIPT_DIR, exist_ok=True) + + with open(csv_file, "a", newline="", encoding="utf-8") as csv_fp: + writer = csv.writer(csv_fp) + + if need_header: + writer.writerow( + [ + "Model", + "Sequence Length", + "Batch Size", + "Layers", + "Element Size", + "KV", + "Num Head", + "Block Size", + "Stream Number", + "IO Count", + "IO Size(B)", + "Total Size(GB)", + "Write Avg Time(s)", + "Write Avg Bandwidth(GB/s)", + "Read Avg Time(s)", + "Read Avg Bandwidth(GB/s)", + ] + ) + + for num_head in num_head_list: + for block_len in block_lens: + for transferStreamNumber in transferStreamNumbers: + block_dim = head_size * num_head + io_size = block_dim * block_len * block_elem_size + + for num_tokens in num_tokens_list: + sep = "=" * 60 + print( + f"\n{sep}\n= num_head={num_head} | num_tokens={num_tokens:>6} | Repeat {repeat} times =\n{sep}\n" + ) + + batch_size = int(num_tokens / block_len) + io_num = int(num_tokens / block_len * block_layer) + + result_queue = multiprocessing.Queue() + + process = multiprocessing.Process( + target=run_wrapper, + args=( + result_queue, + storage_backends, + device_id, + repeat, + num_head, + block_len, + transferStreamNumber, + num_tokens, + block_layer, + head_size, + block_elem_size, + kv, + mla, + transferIoDirect, + operation_mode, + ), + ) + + process.start() + process.join() + + status, result = result_queue.get() + if status == "error": + raise Exception(f"Error in subprocess: {result}") + + ( + avg_w_size, + avg_w_time, + avg_w_bw, + avg_r_time, + avg_r_bw, + avg_r_size, + ) = result + + writer.writerow( + [ + model_name, + num_tokens, + batch_size, + block_layer, + block_elem_size, + kv, + num_head, + block_len, + transferStreamNumber, + io_num, + io_size, + f"{avg_w_size:.4f}", + f"{avg_w_time:.4f}", + f"{avg_w_bw:.4f}", + f"{avg_r_time:.4f}", + f"{avg_r_bw:.4f}", + ] + ) + csv_fp.flush() + + print( + f"WRITE COMPLETE for num_head={num_head}, num_tokens={num_tokens}" + ) + + print("\n" + "=" * 60 + "\n= All combinations tested =\n" + "=" * 60 + "\n") + + +if __name__ == "__main__": + os.environ["UC_LOGGER_LEVEL"] = "debug" + main() diff --git a/ucm/store/test/e2e/nfsstore_fetch.py b/ucm/store/test/e2e/nfsstore_fetch.py index 11100ce7d..e18ddcb97 100644 --- a/ucm/store/test/e2e/nfsstore_fetch.py +++ b/ucm/store/test/e2e/nfsstore_fetch.py @@ -27,7 +27,6 @@ from typing import List import torch -import torch_npu from ucm.store.nfsstore.nfsstore_connector import UcmNfsStore from ucm.store.ucmstore import UcmKVStoreBase @@ -62,7 +61,7 @@ def make_buffers(device_id, batch_size, block_dim, block_len, block_layer): torch.rand( [block_dim, block_len], dtype=torch.bfloat16, - device="npu:{}".format(device_id), + device="cuda:{}".format(device_id), ) for _ in 
range(block_layer) ] diff --git a/ucm/store/test/e2e/pcstore_embed.py b/ucm/store/test/e2e/pcstore_embed.py new file mode 100644 index 000000000..da3e9de86 --- /dev/null +++ b/ucm/store/test/e2e/pcstore_embed.py @@ -0,0 +1,155 @@ +# -*- coding: utf-8 -*- +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +import os +import secrets +import time +from typing import List + +import torch + +from ucm.store.pcstore.pcstore_connector import UcmPcStore +from ucm.store.ucmstore import UcmKVStoreBase + + +def setup_store(storage_backends, block_size, device_id, io_size) -> UcmKVStoreBase: + config = {} + config["storage_backends"] = storage_backends + config["kv_block_size"] = block_size + config["role"] = "worker" + config["device"] = device_id + config["io_size"] = io_size + return UcmPcStore(config) + + +def make_buffers( + block_number, device_id, batch_size, block_dim, block_len, block_layer +): + hashes = [secrets.token_hex(16) for _ in range(block_number)] + tensors = [ + [ + torch.rand( + [block_dim, block_len], + dtype=torch.bfloat16, + device="cuda:{}".format(device_id), + ) + for _ in range(block_layer) + ] + for _ in range(batch_size) + ] + return hashes, tensors + + +def embed(store: UcmKVStoreBase, hashes: List[str], tensors: List[List[torch.Tensor]]): + results = store.create(hashes) + assert sum(results) == 0 + block_ids = [] + offsets = [] + layers = [] + for hash_id, block in zip(hashes, tensors): + offset = 0 + for layer in block: + block_ids.append(hash_id) + offsets.append(offset) + layers.append(layer) + offset += layer.untyped_storage().size() + task = store.dump(block_ids, offsets, layers) + assert task.task_id > 0 + ret = store.wait(task) + assert ret == 0 + store.commit(hashes, True) + + +def fetch(store: UcmKVStoreBase, hashes: List[str], tensors: List[List[torch.Tensor]]): + founds = store.lookup(hashes) + for found in founds: + assert found + block_ids = [] + offsets = [] + layers = [] + for hash_id, block in zip(hashes, tensors): + offset = 0 + for layer in block: + block_ids.append(hash_id) + offsets.append(offset) + layers.append(layer) + offset += layer.untyped_storage().size() + task = store.load(block_ids, offsets, layers) + assert task.task_id > 0 + ret = store.wait(task) + assert ret == 0 + + +def cmp_and_print_diff(a, b, rtol=0.0, atol=0.0): + for r, (row_a, row_b) in enumerate(zip(a, b)): + for c, (ta, tb) in enumerate(zip(row_a, row_b)): + if not 
torch.allclose(ta, tb, rtol=rtol, atol=atol): + mask = ~torch.isclose(ta, tb, rtol=rtol, atol=atol) + diff_a = ta[mask].cpu() + diff_b = tb[mask].cpu() + print(f"DIFF at [{r}][{c}] total {mask.sum().item()} element(s)") + print(" a val:", diff_a.flatten()) + print(" b val:", diff_b.flatten()) + assert False + + +def store_all_hashes(hashes): + kvcache_block_hashes_file = "kvcache_block_hashes.txt" + current_directory = os.path.dirname(__file__) + file_path = os.path.join(current_directory, kvcache_block_hashes_file) + with open(file_path, "w", encoding="utf-8") as file: + for hs in hashes: + file.write(hs + "\n") + + +def main(): + storage_backends = "." + block_number = 4096 + device_id = 1 + block_dim = 576 + block_len = 64 + block_elem_size = 2 + block_layer = 61 + io_size = block_dim * block_len * block_elem_size + block_size = io_size * block_layer + batch_size = 64 + store = setup_store(storage_backends, block_size, device_id, io_size) + hashes, tensors = make_buffers( + block_number, device_id, batch_size, block_dim, block_len, block_layer + ) + total_batches = (block_number + batch_size - 1) // batch_size + for batch in range(total_batches): + start = batch_size * batch + end = min(start + batch_size, block_number) + tensors2 = [[torch.empty_like(t) for t in row] for row in tensors] + embed(store, hashes[start:end], tensors) + time.sleep(1) + fetch(store, hashes[start:end], tensors2) + cmp_and_print_diff(tensors, tensors2) + store_all_hashes(hashes) + + +if __name__ == "__main__": + os.environ["UC_LOGGER_LEVEL"] = "debug" + main() diff --git a/ucm/store/test/e2e/pcstore_fetch.py b/ucm/store/test/e2e/pcstore_fetch.py new file mode 100644 index 000000000..6299d387d --- /dev/null +++ b/ucm/store/test/e2e/pcstore_fetch.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# +import os +import random +from typing import List + +import torch + +from ucm.store.pcstore.pcstore_connector import UcmPcStore +from ucm.store.ucmstore import UcmKVStoreBase + + +def setup_store(storage_backends, block_size, device_id, io_size) -> UcmKVStoreBase: + config = {} + config["storage_backends"] = storage_backends + config["kv_block_size"] = block_size + config["role"] = "worker" + config["device"] = device_id + config["io_size"] = io_size + return UcmPcStore(config) + + +def get_hashes(batch_size, batch_number): + kvcache_block_hashes_file = "kvcache_block_hashes.txt" + current_directory = os.path.dirname(__file__) + file_path = os.path.join(current_directory, kvcache_block_hashes_file) + with open(file_path, "r", encoding="utf-8") as file: + lines = file.readlines() + total = [line.strip() for line in lines] + hashes = [] + for _ in range(batch_number): + hashes.extend(random.sample(total, batch_size)) + return hashes + + +def make_buffers(device_id, batch_size, block_dim, block_len, block_layer): + tensors = [ + [ + torch.rand( + [block_dim, block_len], + dtype=torch.bfloat16, + device="cuda:{}".format(device_id), + ) + for _ in range(block_layer) + ] + for _ in range(batch_size) + ] + return tensors + + +def fetch(store: UcmKVStoreBase, hashes: List[str], tensors: List[List[torch.Tensor]]): + founds = store.lookup(hashes) + for found in founds: + assert found + block_ids = [] + offsets = [] + layers = [] + for hash_id, block in zip(hashes, tensors): + offset = 0 + for layer in block: + block_ids.append(hash_id) + offsets.append(offset) + layers.append(layer) + offset += layer.untyped_storage().size() + task = store.load(block_ids, offsets, layers) + assert task.task_id > 0 + ret = store.wait(task) + assert ret == 0 + + +def main(): + storage_backends = "." 
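+    # Note: the sizing below must match pcstore_embed.py, which produced the
+    # kvcache_block_hashes.txt consumed by get_hashes(): io_size is the
+    # per-layer shard in bytes (block_dim * block_len * block_elem_size =
+    # 576 * 64 * 2 = 73728 B) and block_size is the whole KV block across
+    # layers (io_size * block_layer, roughly 4.3 MiB here). These defaults
+    # mirror the MLA (deepseek-v3) layout used elsewhere in these e2e tests.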
+ device_id = 1 + block_dim = 576 + block_len = 64 + block_elem_size = 2 + block_layer = 61 + io_size = block_dim * block_len * block_elem_size + block_size = io_size * block_layer + batch_size = 64 + batch_number = 128 + store = setup_store(storage_backends, block_size, device_id, io_size) + hashes = get_hashes(batch_size, batch_number) + tensors = make_buffers(device_id, batch_size, block_dim, block_len, block_layer) + for batch in range(batch_number): + start = batch_size * batch + end = start + batch_size + fetch(store, hashes[start:end], tensors) + + +if __name__ == "__main__": + os.environ["UC_LOGGER_LEVEL"] = "debug" + main() diff --git a/ucm/store/ucmstore.h b/ucm/store/ucmstore.h index 625e83ec9..bd841de83 100644 --- a/ucm/store/ucmstore.h +++ b/ucm/store/ucmstore.h @@ -24,49 +24,15 @@ #ifndef UNIFIEDCACHE_STORE_H #define UNIFIEDCACHE_STORE_H -#include -#include -#include +#include "task/task_shard.h" namespace UC { +template class CCStore { using BlockId = std::string; using TaskId = size_t; -public: - static constexpr TaskId invalidTaskId = 0; - class Task { - public: - struct Shard { - size_t index; - BlockId block; - size_t offset; - uintptr_t address; - }; - enum class Type { DUMP, LOAD }; - enum class Location { HOST, DEVICE }; - Type type; - Location location; - std::string brief; - size_t number; - size_t size; - std::list shards; - Task(const Type type, const Location location, const std::string& brief) - : type{type}, location{location}, brief{brief}, number{0}, size{0} - { - } - int32_t Append(const BlockId& block, const size_t offset, const uintptr_t address, - const size_t length) - { - if (this->number == 0) { this->size = length; } - if (this->size != length) { return -1; } - this->shards.emplace_back({this->number, block, offset, address}); - this->number++; - return 0; - } - }; - public: virtual ~CCStore() = default; virtual int32_t Alloc(const BlockId& block) = 0; @@ -75,7 +41,7 @@ class CCStore { virtual std::list Alloc(const std::list& blocks) = 0; virtual std::list Lookup(const std::list& blocks) = 0; virtual void Commit(const std::list& blocks, const bool success) = 0; - virtual TaskId Submit(Task&& task) = 0; + virtual TaskId Submit(T&& task) = 0; virtual int32_t Wait(const TaskId task) = 0; virtual int32_t Check(const TaskId task, bool& finish) = 0; }; diff --git a/ucm/store/ucmstore.py b/ucm/store/ucmstore.py index 1b0f7e125..f473bab54 100644 --- a/ucm/store/ucmstore.py +++ b/ucm/store/ucmstore.py @@ -113,7 +113,7 @@ def dump( self, block_ids: List[str], offset: List[int], src_tensor: List[torch.Tensor] ) -> Task: """ - dump kv cache to device. + dump kv cache from device. Args: block_ids (List[str]): vLLM block hash. @@ -124,6 +124,48 @@ def dump( """ pass + @abstractmethod + def fetch_data( + self, + block_ids: List[str], + offset: List[int], + dst_addr: List[int], + size: List[int], + ) -> Task: + """ + load kv cache data to device. + + Args: + block_ids (List[str]): vLLM block hash. + offset(List[int]): tp > 1 scene + dst_addr: List[int]: device tensor addr ptr. + size: List[int]: device tensor size. + Returns: + task(Task). + """ + pass + + @abstractmethod + def dump_data( + self, + block_ids: List[str], + offset: List[int], + src_addr: List[int], + size: List[int], + ) -> Task: + """ + dump kv cache data from device. + + Args: + block_ids (List[str]): vLLM block hash. + offset(List[int]): tp > 1 scene + src_addr: List[int]: device tensor addr ptr. + size: List[int]: device tensor size. + Returns: + task(Task). 
+ """ + pass + @abstractmethod def wait(self, task: Task) -> int: """ diff --git a/ucm/store/vendor/CMakeLists.txt b/ucm/store/vendor/CMakeLists.txt deleted file mode 100644 index 15f11539b..000000000 --- a/ucm/store/vendor/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -function(EnableDept name url tag) - if(DOWNLOAD_DEPENDENCE) - FetchContent_Declare(${name} GIT_REPOSITORY ${url} GIT_TAG ${tag} GIT_SHALLOW TRUE) - FetchContent_MakeAvailable(${name}) - else() - add_subdirectory(${name}) - endif() -endfunction() - -include(FetchContent) -EnableDept(fmt https://github.com/fmtlib/fmt.git 11.2.0) -if(LOGGER_BACKEND STREQUAL "spdlog") - EnableDept(spdlog https://github.com/gabime/spdlog.git v1.15.3) -endif() -EnableDept(pybind11 https://github.com/pybind/pybind11.git v2.13.6) diff --git a/ucm/utils.py b/ucm/utils.py new file mode 100644 index 000000000..bf07f6b84 --- /dev/null +++ b/ucm/utils.py @@ -0,0 +1,90 @@ +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# + +from typing import Any, Dict + +import yaml + +from ucm.logger import init_logger + +logger = init_logger(__name__) + + +class Config: + def __init__(self, kv_transfer_config: Any): + self.kv_transfer_config = kv_transfer_config + self.config: Dict[str, Any] = {} + self._load_config() + + def load_ucm_config_from_yaml(self, file_path: str) -> Dict[str, Any]: + if not file_path: + logger.warning("No UCM config file path provided.") + return {} + + try: + with open(file_path, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) or {} + if not isinstance(config, dict): + logger.warning( + f"Config file {file_path} does not contain a dictionary. " + "Returning empty config." 
+ ) + return {} + logger.info(f"Loaded UCM config from {file_path}") + return config + except FileNotFoundError: + logger.error(f"UCM config file not found: {file_path}") + return {} + except yaml.YAMLError as e: + logger.error(f"Failed to parse YAML config file {file_path}: {e}") + return {} + + def _load_config(self) -> None: + has_extra_config = ( + self.kv_transfer_config is not None + and hasattr(self.kv_transfer_config, "kv_connector_extra_config") + and self.kv_transfer_config.kv_connector_extra_config is not None + ) + if not has_extra_config: + self.config = self._get_default_config() + else: + extra_config = self.kv_transfer_config.kv_connector_extra_config + if "UCM_CONFIG_FILE" in extra_config: + config_file = extra_config["UCM_CONFIG_FILE"] + self.config = self.load_ucm_config_from_yaml(config_file) + else: + if extra_config == {}: + self.config = self._get_default_config() + else: + self.config = dict(extra_config) + logger.info("Using kv_connector_extra_config from terminal input") + + def _get_default_config(self) -> Dict[str, Any]: + config = {"ucm_connector_name": "UcmDramStore"} + logger.warning(f"No UCM config provided, using default configuration {config}") + return config + + def get_config(self) -> Dict[str, Any]: + logger.info(f"Using UCM with config: {self.config}") + return self.config