diff --git a/.github/workflows/unifiedcache_test.yml b/.github/workflows/unifiedcache_test.yml
index aa0dee2ad..ead608ba8 100644
--- a/.github/workflows/unifiedcache_test.yml
+++ b/.github/workflows/unifiedcache_test.yml
@@ -18,40 +18,3 @@ jobs:
 
   call-lint:
     uses: ./.github/workflows/pre-commit.yml
-
-  unit-test:
-    needs: call-lint
-    name: Run Unittests
-    runs-on: ubuntu-latest
-    steps:
-      - name: Free disk space
-        run: |
-          sudo rm -rf /usr/share/dotnet
-          sudo rm -rf /opt/ghc
-          sudo rm -rf "/usr/local/share/boost"
-          sudo rm -rf "$AGENT_TOOLSDIRECTORY"
-          docker system prune -af
-          df -h
-
-      - name: Checkout unified-cache-management repo
-        uses: actions/checkout@v4
-
-      - name: Run unit test inside vLLM container
-        run: |
-          docker run --rm \
-            -e VLLM_USE_PRECOMPILED=1 \
-            -e PLATFORM=cuda \
-            -v ${{ github.workspace }}:/workspace/unified-cache-management \
-            -w /workspace/unified-cache-management \
-            --entrypoint /bin/bash \
-            vllm/vllm-openai:v0.9.2 \
-            -c "
-              set -euo pipefail
-              pip install -v -e . --no-build-isolation
-              cd \$(pip show vllm | grep Location | awk '{print \$2}') &&
-              git apply /workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt-pc.patch
-              git apply /workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt-aggre.patch
-              git apply /workspace/unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch
-              cd /workspace/unified-cache-management
-              python3 -m unittest discover -s test
-            "
diff --git a/docs/source/getting-started/installation_gpu.md b/docs/source/getting-started/installation_gpu.md
index eaf1d3d05..4a0cda2f7 100644
--- a/docs/source/getting-started/installation_gpu.md
+++ b/docs/source/getting-started/installation_gpu.md
@@ -51,7 +51,7 @@ export PLATFORM=cuda
 pip install -v -e . --no-build-isolation
 ```
 
-**Note:** Patches are now applied automatically via dynamic patching when you import the unified-cache-management package. You no longer need to manually apply patches using `git apply`. The patches are automatically applied when you use the `UnifiedCacheConnectorV1` connector.
+**Note:** Patches are now applied automatically via dynamic patching when you import the unified-cache-management package, so you no longer need to apply them manually with `git apply`. They take effect as soon as you use the `UCMConnector` connector.
 
 ## Setup from docker
 
diff --git a/docs/source/user-guide/pd-disaggregation/1p1d.md b/docs/source/user-guide/pd-disaggregation/1p1d.md
index fb3f4d056..c2bdb47e2 100644
--- a/docs/source/user-guide/pd-disaggregation/1p1d.md
+++ b/docs/source/user-guide/pd-disaggregation/1p1d.md
@@ -26,8 +26,8 @@ vllm serve /home/models/Qwen2.5-7B-Instruct \
     --block-size 128 \
     --kv-transfer-config \
     '{
-      "kv_connector": "UnifiedCacheConnectorV1",
-      "kv_connector_module_path": "ucm.integration.vllm.uc_connector",
+      "kv_connector": "UCMConnector",
+      "kv_connector_module_path": "ucm.integration.vllm.ucm_connector",
       "kv_role": "kv_producer",
       "kv_connector_extra_config": {
         "ucm_connector_name": "UcmNfsStore",
@@ -55,8 +55,8 @@ vllm serve /home/models/Qwen2.5-7B-Instruct \
     --block-size 128 \
     --kv-transfer-config \
     '{
-      "kv_connector": "UnifiedCacheConnectorV1",
-      "kv_connector_module_path": "ucm.integration.vllm.uc_connector",
+      "kv_connector": "UCMConnector",
+      "kv_connector_module_path": "ucm.integration.vllm.ucm_connector",
       "kv_role": "kv_consumer",
       "kv_connector_extra_config": {
         "ucm_connector_name": "UcmNfsStore",
diff --git a/docs/source/user-guide/pd-disaggregation/npgd.md b/docs/source/user-guide/pd-disaggregation/npgd.md
index c4919779a..05bbc0823 100644
--- a/docs/source/user-guide/pd-disaggregation/npgd.md
+++ b/docs/source/user-guide/pd-disaggregation/npgd.md
@@ -33,8 +33,8 @@ vllm serve /home/models/Qwen2.5-7B-Instruct \
     --dtype bfloat16 \
     --kv-transfer-config \
     '{
-      "kv_connector": "UnifiedCacheConnectorV1",
-      "kv_connector_module_path": "ucm.integration.vllm.uc_connector",
+      "kv_connector": "UCMConnector",
+      "kv_connector_module_path": "ucm.integration.vllm.ucm_connector",
       "kv_role": "kv_producer",
       "kv_connector_extra_config": {
         "ucm_connector_name": "UcmNfsStore",
@@ -63,8 +63,8 @@ vllm serve /home/models/Qwen2.5-7B-Instruct \
     --dtype bfloat16 \
     --kv-transfer-config \
     '{
-      "kv_connector": "UnifiedCacheConnectorV1",
-      "kv_connector_module_path": "ucm.integration.vllm.uc_connector",
+      "kv_connector": "UCMConnector",
+      "kv_connector_module_path": "ucm.integration.vllm.ucm_connector",
       "kv_role": "kv_consumer",
       "kv_connector_extra_config": {
         "ucm_connector_name": "UcmNfsStore",
diff --git a/docs/source/user-guide/pd-disaggregation/xpyd.md b/docs/source/user-guide/pd-disaggregation/xpyd.md
index a57ab5d2f..b21f19ada 100644
--- a/docs/source/user-guide/pd-disaggregation/xpyd.md
+++ b/docs/source/user-guide/pd-disaggregation/xpyd.md
@@ -26,8 +26,8 @@ vllm serve /home/models/Qwen2.5-7B-Instruct \
     --block-size 128 \
     --kv-transfer-config \
     '{
-      "kv_connector": "UnifiedCacheConnectorV1",
-      "kv_connector_module_path": "ucm.integration.vllm.uc_connector",
+      "kv_connector": "UCMConnector",
+      "kv_connector_module_path": "ucm.integration.vllm.ucm_connector",
       "kv_role": "kv_producer",
       "kv_connector_extra_config": {
         "ucm_connector_name": "UcmNfsStore",
@@ -54,8 +54,8 @@ vllm serve /home/models/Qwen2.5-7B-Instruct \
     --block-size 128 \
     --kv-transfer-config \
     '{
-      "kv_connector": "UnifiedCacheConnectorV1",
-      "kv_connector_module_path": "ucm.integration.vllm.uc_connector",
+      "kv_connector": "UCMConnector",
+      "kv_connector_module_path": "ucm.integration.vllm.ucm_connector",
       "kv_role": "kv_producer",
       "kv_connector_extra_config": {
         "ucm_connector_name": "UcmNfsStore",
@@ -83,8 +83,8 @@ vllm serve /home/models/Qwen2.5-7B-Instruct \
     --block-size 128 \
     --kv-transfer-config \
     '{
-      "kv_connector": "UnifiedCacheConnectorV1",
-      "kv_connector_module_path": "ucm.integration.vllm.uc_connector",
+      "kv_connector": "UCMConnector",
"kv_connector_module_path": "ucm.integration.vllm.ucm_connector", "kv_role": "kv_consumer", "kv_connector_extra_config": { "ucm_connector_name": "UcmNfsStore", @@ -110,8 +110,8 @@ vllm serve /home/models/Qwen2.5-7B-Instruct \ --block-size 128 \ --kv-transfer-config \ '{ - "kv_connector": "UnifiedCacheConnectorV1", - "kv_connector_module_path": "ucm.integration.vllm.uc_connector", + "kv_connector": "UCMConnector", + "kv_connector_module_path": "ucm.integration.vllm.ucm_connector", "kv_role": "kv_consumer", "kv_connector_extra_config": { "ucm_connector_name": "UcmNfsStore", diff --git a/docs/source/user-guide/prefix-cache/nfs_store.md b/docs/source/user-guide/prefix-cache/nfs_store.md index 741fcedf7..cd0800eeb 100644 --- a/docs/source/user-guide/prefix-cache/nfs_store.md +++ b/docs/source/user-guide/prefix-cache/nfs_store.md @@ -135,8 +135,8 @@ vllm serve /home/models/Qwen2.5-14B-Instruct \ --port 7800 \ --kv-transfer-config \ '{ - "kv_connector": "UnifiedCacheConnectorV1", - "kv_connector_module_path": "ucm.integration.vllm.uc_connector", + "kv_connector": "UCMConnector", + "kv_connector_module_path": "ucm.integration.vllm.ucm_connector", "kv_role": "kv_both", "kv_connector_extra_config": {"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"} }' diff --git a/docs/source/user-guide/sparse-attention/gsa.md b/docs/source/user-guide/sparse-attention/gsa.md index 5a96287a3..601046d13 100644 --- a/docs/source/user-guide/sparse-attention/gsa.md +++ b/docs/source/user-guide/sparse-attention/gsa.md @@ -88,7 +88,7 @@ Similar to UCM's `offline_inference_esa.py` examples. We only need to specify `u ... ktc = KVTransferConfig( kv_connector=name, - kv_connector_module_path="ucm.integration.vllm.uc_connector", + kv_connector_module_path="ucm.integration.vllm.ucm_connector", kv_role="kv_both", kv_connector_extra_config={ "ucm_connector_name": "UcmNfsStore", @@ -121,7 +121,7 @@ vllm serve /home/models/DeepSeek-R1-Distill-Qwen-32B \ --kv-transfer-config \ '{ "kv_connector": name, - "kv_connector_module_path": "ucm.integration.vllm.uc_connector", + "kv_connector_module_path": "ucm.integration.vllm.ucm_connector", "kv_role": "kv_both", "kv_connector_extra_config": { "ucm_connector_name": "UcmNfsStore", diff --git a/examples/ucm_config_example.yaml b/examples/ucm_config_example.yaml index 32c357e1b..c0c3534ae 100644 --- a/examples/ucm_config_example.yaml +++ b/examples/ucm_config_example.yaml @@ -32,7 +32,7 @@ load_only_first_rank: false # GSA: {} -# Whether to use layerwise loading/saving (optional, default: True for UnifiedCacheConnectorV1) +# Whether to use layerwise loading/saving (optional, default: True for UCMConnector) # use_layerwise: true # hit_ratio: 0.9 diff --git a/test/test_uc_connector.py b/test/test_uc_connector.py deleted file mode 100644 index 0c2261d87..000000000 --- a/test/test_uc_connector.py +++ /dev/null @@ -1,588 +0,0 @@ -# -# MIT License -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# - -import random -import secrets -import unittest -from collections import defaultdict -from typing import List, Union -from unittest.mock import MagicMock, Mock, patch - -import torch -from vllm.multimodal.inputs import MultiModalKwargs -from vllm.sampling_params import SamplingParams -from vllm.v1.core.kv_cache_manager import KVCacheBlocks -from vllm.v1.request import Request - -from ucm.integration.vllm.uc_connector import ( - BlockOperation, - ReqMeta, - RequestBlockInfo, - UCConnectorV1Metadata, - UnifiedCacheConnectorV1, -) -from ucm.store.ucmstore import Task, UcmKVStoreBase - - -def make_request( - request_id, prompt_token_ids, mm_positions=None, mm_hashes=None, cache_salt=None -): - if mm_positions is None: - multi_model_inputs = None - else: - multi_model_inputs = [MultiModalKwargs({})] * len(mm_positions) - - return Request( - request_id=request_id, - prompt_token_ids=prompt_token_ids, - multi_modal_inputs=multi_model_inputs, - multi_modal_hashes=mm_hashes, - multi_modal_placeholders=mm_positions, - sampling_params=SamplingParams(max_tokens=17), - pooling_params=None, - eos_token_id=100, - arrival_time=0, - lora_request=None, - cache_salt=cache_salt, - ) - - -class TestUCConnector(unittest.TestCase): - - @classmethod - def setUpClass(cls): - print("===> Before all tests (setUpClass)") - - @classmethod - def tearDownClass(cls): - print("===> Before all tests (tearDownClass)") - - def setUp(self): - self.block_number = 4 - self.block_size = 8 - self.num_layers = 48 - self.total_blocks_num = 40 - self.total_tp_size = 2 - self.kv_caches = {} - for i in range(self.num_layers): - layer_name = f"model.layers.{i}.self_attn.attn" - kv_tensor = torch.rand( - (2, self.total_blocks_num, self.block_size, 4, 8), dtype=torch.bfloat16 - ) - self.kv_caches[layer_name] = kv_tensor - - def init_uc( - self, mock_connector, metadata=Mock(), use_layerwise=True - ) -> UnifiedCacheConnectorV1: - with patch.object(UnifiedCacheConnectorV1, "__init__", return_value=None): - ucconnector = UnifiedCacheConnectorV1(None, None) - ucconnector.block_size = self.block_size - ucconnector.use_layerwise = use_layerwise - ucconnector.kv_caches = self.kv_caches - ucconnector.rank = 1 - ucconnector.is_mla = False - ucconnector.connector = mock_connector - ucconnector.request_block_infos: dict[str, RequestBlockInfo] = {} - ucconnector.dump_tasks: dict[str, dict[str, List[Task]]] = {} - ucconnector.total_tp_size = self.total_tp_size - 
ucconnector._connector_metadata = metadata - ucconnector.layerwise_load_tasks: dict[str, dict[str, Task]] = defaultdict( - dict - ) - ucconnector._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {} - ucconnector._load_failed_reqs: set[str] = set() - ucconnector._load_req_to_blocks: dict[str, set[int]] = {} - ucconnector.num_layers = 48 - ucconnector.is_mla = False - return ucconnector - - def test_get_num_new_matched_tokens_hit_all_on_storage(self): - mock_connector = Mock(spec=UcmKVStoreBase) - - def mock_lookup(block_hashes: List[str]) -> List[bool]: - return [True] * self.block_number - - mock_connector.lookup.side_effect = mock_lookup - ucconnector = self.init_uc(mock_connector) - - random.seed(20250704) - request1 = make_request( - request_id=1, - prompt_token_ids=random.sample( - range(0, 10000), self.block_number * self.block_size - ), - mm_positions=None, - mm_hashes=None, - ) - - # all block dumped in ssd, external_tokens equals to full tokens num - self.block_size - all_tokens_len = len(request1.all_token_ids) - external_tokens, _ = ucconnector.get_num_new_matched_tokens(request1, 0) - self.assertEqual(external_tokens, all_tokens_len - self.block_size) - self.assertEqual( - ucconnector.request_block_infos[request1.request_id].block_operations, - [ - BlockOperation.LOAD, - BlockOperation.LOAD, - BlockOperation.LOAD, - BlockOperation.NONE, - ], - ) - - def test_get_num_new_matched_tokens_partial_hit(self): - mock_connector = Mock(spec=UcmKVStoreBase) - - def mock_lookup(block_hashes: List[str]) -> List[bool]: - return [True, False, True, False] - - mock_connector.lookup.side_effect = mock_lookup - ucconnector = self.init_uc(mock_connector) - - random.seed(20250704) - request1 = make_request( - request_id=1, - prompt_token_ids=random.sample( - range(0, 10000), self.block_number * self.block_size - ), - mm_positions=None, - mm_hashes=None, - ) - - # all block dumped in ssd, external_tokens equals to full tokens num - self.block_size - all_tokens_len = len(request1.all_token_ids) - external_tokens, _ = ucconnector.get_num_new_matched_tokens(request1, 0) - self.assertEqual(external_tokens, self.block_size) - self.assertEqual( - ucconnector.request_block_infos[request1.request_id].block_operations, - [ - BlockOperation.LOAD, - BlockOperation.NONE, - BlockOperation.NONE, - BlockOperation.NONE, - ], - ) - - def test_get_num_new_matched_tokens_partial_hit_with_preftxcache(self): - mock_connector = Mock(spec=UcmKVStoreBase) - - def mock_lookup(block_hashes: List[str]) -> List[bool]: - return [True, True, False] - - mock_connector.lookup.side_effect = mock_lookup - ucconnector = self.init_uc(mock_connector) - - random.seed(20250704) - request1 = make_request( - request_id=1, - prompt_token_ids=random.sample( - range(0, 10000), self.block_number * self.block_size - ), - mm_positions=None, - mm_hashes=None, - ) - - # 1 block hit on hbm cache, 2 block hit on ssd, 1 blocks miss - external_tokens, _ = ucconnector.get_num_new_matched_tokens( - request1, self.block_size - ) - self.assertEqual(external_tokens, 2 * self.block_size) - self.assertEqual( - ucconnector.request_block_infos[request1.request_id].start_position, 1 - ) - self.assertEqual( - ucconnector.request_block_infos[request1.request_id].block_operations, - [ - BlockOperation.NONE, - BlockOperation.LOAD, - BlockOperation.LOAD, - BlockOperation.NONE, - ], - ) - - def test_get_num_new_matched_tokens_partial_hit_with_load_async(self): - mock_connector = Mock(spec=UcmKVStoreBase) - - def mock_lookup(block_hashes: List[str]) -> 
List[bool]: - return [True, True, False] - - mock_connector.lookup.side_effect = mock_lookup - ucconnector = self.init_uc(mock_connector) - ucconnector.kv_role = "kv_consumer" - - random.seed(20250704) - request1 = make_request( - request_id=1, - prompt_token_ids=random.sample( - range(0, 10000), self.block_number * self.block_size - ), - mm_positions=None, - mm_hashes=None, - ) - if request1.kv_transfer_params is None: - request1.kv_transfer_params = {} - request1.kv_transfer_params["load_async"] = True - - # 1 block hit on hbm cache, 2 block hit on ssd, 1 blocks miss - external_tokens, load_async = ucconnector.get_num_new_matched_tokens( - request1, self.block_size - ) - self.assertEqual(external_tokens, 2 * self.block_size) - self.assertEqual( - ucconnector.request_block_infos[request1.request_id].start_position, 1 - ) - self.assertEqual(request1.kv_transfer_params["load_async"], False) - self.assertEqual( - ucconnector.request_block_infos[request1.request_id].block_operations, - [ - BlockOperation.NONE, - BlockOperation.LOAD, - BlockOperation.LOAD, - BlockOperation.NONE, - ], - ) - - def test_update_state_after_alloc_create_success(self): - mock_connector = Mock(spec=UcmKVStoreBase) - - def mock_create(block_hashes: List[str]) -> List[int]: - return [0] - - mock_connector.create.side_effect = mock_create - ucconnector = self.init_uc(mock_connector) - request1 = make_request( - request_id=1, - prompt_token_ids=random.sample( - range(0, 10000), self.block_number * self.block_size - ), - mm_positions=None, - mm_hashes=None, - ) - ucconnector.request_block_infos[request1.request_id] = RequestBlockInfo( - block_hashes=[secrets.token_hex(8) for _ in range(self.block_number)], - block_operations=[ - BlockOperation.NONE, - BlockOperation.LOAD, - BlockOperation.LOAD, - BlockOperation.NONE, - ], - start_position=1, - ) - - vllm_blocks = Mock() - vllm_blocks.get_unhashed_block_ids.return_value = [0, 1, 2, 3] - ucconnector.update_state_after_alloc(request1, vllm_blocks, 2 * self.block_size) - self.assertEqual( - ucconnector.request_block_infos[request1.request_id].block_operations, - [ - BlockOperation.NONE, - BlockOperation.LOAD, - BlockOperation.LOAD, - BlockOperation.DUMP, - ], - ) - self.assertEqual( - ucconnector.request_block_infos[request1.request_id].start_position, 0 - ) - - def test_update_state_after_alloc_with_load_async(self): - mock_connector = Mock(spec=UcmKVStoreBase) - - ucconnector = self.init_uc(mock_connector) - request1 = make_request( - request_id=1, - prompt_token_ids=random.sample( - range(0, 10000), self.block_number * self.block_size - ), - mm_positions=None, - mm_hashes=None, - ) - ucconnector.request_block_infos[request1.request_id] = RequestBlockInfo( - block_hashes=[secrets.token_hex(8) for _ in range(self.block_number)], - block_operations=[ - BlockOperation.NONE, - BlockOperation.LOAD, - BlockOperation.LOAD, - BlockOperation.NONE, - ], - start_position=1, - ) - ucconnector._need_load_reqs[request1.request_id] = [] - vllm_blocks = Mock() - vllm_blocks.get_unhashed_block_ids.return_value = [1, 2, 3] - ucconnector.update_state_after_alloc(request1, vllm_blocks, 2 * self.block_size) - self.assertEqual( - ucconnector.request_block_infos[request1.request_id].block_operations, - [ - BlockOperation.NONE, - BlockOperation.LOAD, - BlockOperation.LOAD, - BlockOperation.NONE, - ], - ) - self.assertEqual( - ucconnector.request_block_infos[request1.request_id].start_position, 1 - ) - - def test_build_connector_meta(self): - mock_connector = Mock(spec=UcmKVStoreBase) - - 
ucconnector = self.init_uc(mock_connector) - request1 = make_request( - request_id=1, - prompt_token_ids=random.sample( - range(0, 10000), self.block_number * self.block_size - ), - mm_positions=None, - mm_hashes=None, - ) - request2 = make_request( - request_id=2, - prompt_token_ids=random.sample( - range(0, 10000), self.block_number * self.block_size - ), - mm_positions=None, - mm_hashes=None, - ) - ucconnector.request_block_infos[request1.request_id] = RequestBlockInfo( - block_hashes=[secrets.token_hex(8) for _ in range(self.block_number)], - block_operations=[ - BlockOperation.NONE, - BlockOperation.LOAD, - BlockOperation.LOAD, - BlockOperation.DUMP, - ], - start_position=0, - ) - ucconnector.request_block_infos[request2.request_id] = RequestBlockInfo( - block_hashes=[secrets.token_hex(8) for _ in range(self.block_number)], - block_operations=[ - BlockOperation.NONE, - BlockOperation.LOAD, - BlockOperation.LOAD, - BlockOperation.DUMP, - ], - start_position=2, - ) - - from types import SimpleNamespace - - scheduled_new_reqs = [ - SimpleNamespace(req_id=request1.request_id, block_ids=[[1, 2, 3, 4]]), - ] - - scheduled_cached_reqs = SimpleNamespace( - req_ids=[request2.request_id], - new_block_ids=[[[5, 6]]], - ) - - scheduler_output = SimpleNamespace( - scheduled_new_reqs=scheduled_new_reqs, - scheduled_cached_reqs=scheduled_cached_reqs, - ) - - meta = ucconnector.build_connector_meta(scheduler_output) - self.assertIsInstance(meta, UCConnectorV1Metadata) - new_req_meta = meta.requests[0] - self.assertEqual( - new_req_meta.load_blocks, - [ - ( - ucconnector.request_block_infos[request1.request_id].block_hashes[ - 1 - ], - 2, - ), - ( - ucconnector.request_block_infos[request1.request_id].block_hashes[ - 2 - ], - 3, - ), - ], - ) - self.assertEqual( - new_req_meta.dump_blocks, - [(ucconnector.request_block_infos[request1.request_id].block_hashes[3], 4)], - ) - cache_req_meta = meta.requests[1] - self.assertEqual( - cache_req_meta.load_blocks, - [(ucconnector.request_block_infos[request2.request_id].block_hashes[2], 5)], - ) - self.assertEqual( - cache_req_meta.dump_blocks, - [(ucconnector.request_block_infos[request2.request_id].block_hashes[3], 6)], - ) - - def test_get_num_new_matched_tokens_no_hit(self): - mock_connector = Mock(spec=UcmKVStoreBase) - - def mock_lookup(blocks: List[str]) -> List[bool]: - return [False] * self.block_number - - mock_connector.lookup.side_effect = mock_lookup - ucconnector = self.init_uc(mock_connector) - - random.seed(20250704) - request1 = make_request( - request_id=1, - prompt_token_ids=random.sample( - range(0, 10000), self.block_number * self.block_size - ), - mm_positions=None, - mm_hashes=None, - ) - - external_tokens, _ = ucconnector.get_num_new_matched_tokens(request1, 0) - self.assertEqual(external_tokens, 0) - - def test_get_num_new_matched_tokens_invalid_para(self): - with patch.object(UnifiedCacheConnectorV1, "__init__", return_value=None): - ucconnector = UnifiedCacheConnectorV1(None, None) - ucconnector.block_size = self.block_size - - request1 = make_request( - request_id=1, - prompt_token_ids=random.sample( - range(0, 10000), self.block_number * self.block_size - ), - mm_positions=None, - mm_hashes=None, - ) - - # passing invalid params - with self.assertRaises(AssertionError): - external_tokens, _ = ucconnector.get_num_new_matched_tokens( - request1, self.block_size + 1 - ) - - def test_wait_for_save_not_layerwise_success(self): - req_meta1 = MagicMock(spec=ReqMeta) - req_meta1.request_id = "req1" - req_meta1.dump_blocks = [ - 
(secrets.token_hex(8), i) for i in range(self.block_number) - ] - - metadata = UCConnectorV1Metadata() - metadata.requests = [req_meta1] - - mock_connector = Mock(spec=UcmKVStoreBase) - - def mock_dump( - block_ids: List[str], offset: List[int], src_tensor: List[torch.Tensor] - ) -> Task: - assert len(offset) == len(src_tensor) == len(block_ids) - return Task() - - def mock_wait(task: Task) -> int: - return 0 - - mock_connector.dump.side_effect = mock_dump - mock_connector.wait.side_effect = mock_wait - ucconnector = self.init_uc( - mock_connector, metadata=metadata, use_layerwise=False - ) - ucconnector.wait_for_save() - - def test_wait_for_save_not_layerwise_invalid_para(self): - with patch.object(UnifiedCacheConnectorV1, "__init__", return_value=None): - ucconnector = UnifiedCacheConnectorV1(None, None) - ucconnector.block_size = self.block_size - ucconnector.use_layerwise = False - ucconnector._connector_metadata = Mock() - ucconnector.is_mla = False - - with self.assertRaises(AssertionError): - ucconnector.wait_for_save() - - def test_start_load_kv_not_layerwise_success(self): - req_meta1 = MagicMock(spec=ReqMeta) - req_meta1.request_id = "req1" - req_meta1.load_blocks = [ - (secrets.token_hex(8), i) for i in range(self.block_number) - ] - req_meta1.load_async = False - - metadata = UCConnectorV1Metadata() - metadata.requests = [req_meta1] - - mock_connector = Mock(spec=UcmKVStoreBase) - - def mock_load( - block_ids: List[str], offset: List[int], dst_tensor: List[torch.Tensor] - ) -> Task: - assert len(offset) == len(dst_tensor) == len(block_ids) - return Task() - - def mock_wait(task: Task) -> int: - return 0 - - mock_connector.load.side_effect = mock_load - mock_connector.wait.side_effect = mock_wait - - ucconnector = self.init_uc( - mock_connector, metadata=metadata, use_layerwise=False - ) - forward_context = Mock() - ucconnector.start_load_kv(forward_context) - assert mock_connector.load.call_count == 1 - - def test_start_load_kv_invalid_para(self): - with patch.object(UnifiedCacheConnectorV1, "__init__", return_value=None): - ucconnector = UnifiedCacheConnectorV1(None, None) - ucconnector.block_size = self.block_size - ucconnector._connector_metadata = Mock() - - forward_context = Mock() - with self.assertRaises(AssertionError): - ucconnector.start_load_kv(forward_context) - - def test_start_load_kv_layerwise_success(self): - req_meta1 = MagicMock(spec=ReqMeta) - req_meta1.request_id = "req1" - req_meta1.load_blocks = [ - (secrets.token_hex(8), i) for i in range(self.block_number) - ] - req_meta1.load_async = False - - metadata = UCConnectorV1Metadata() - metadata.requests = [req_meta1] - - mock_connector = Mock(spec=UcmKVStoreBase) - - def mock_load( - block_ids: List[str], offset: List[int], dst_tensor: List[torch.Tensor] - ) -> Task: - assert len(offset) == len(dst_tensor) == len(block_ids) - return Task() - - mock_connector.load.side_effect = mock_load - ucconnector = self.init_uc(mock_connector, metadata=metadata) - forward_context = Mock() - ucconnector.start_load_kv(forward_context) - assert mock_connector.load.call_count == self.num_layers - - -if __name__ == "__main__": - unittest.main() diff --git a/ucm/__init__.py b/ucm/__init__.py index 8052a3998..12890cf92 100644 --- a/ucm/__init__.py +++ b/ucm/__init__.py @@ -1,4 +1,3 @@ -from ucm.integration.vllm.uc_connector import UnifiedCacheConnectorV1 from ucm.integration.vllm.ucm_connector import UCMConnector try: diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py deleted file mode 
100644 index c8317007b..000000000 --- a/ucm/integration/vllm/uc_connector.py +++ /dev/null @@ -1,804 +0,0 @@ -# -# MIT License -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# -# Adapted from lmcache/lmcache/integration/vllm/vllm_v1_adapter.py -# -import hashlib -import pickle -from collections import defaultdict -from dataclasses import dataclass, field -from enum import Enum -from typing import TYPE_CHECKING, Any, List, Optional, Union - -import torch -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, - KVConnectorMetadata, - KVConnectorRole, -) -from vllm.distributed.parallel_state import get_world_group -from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.request import Request, RequestStatus - -from ucm.logger import init_logger -from ucm.store.factory import UcmConnectorFactory -from ucm.store.ucmstore import Task -from ucm.utils import Config - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionMetadata - from vllm.forward_context import ForwardContext - from vllm.v1.core.kv_cache_manager import KVCacheBlocks - -logger = init_logger(__name__) - - -class BlockOperation(Enum): - NONE = "none" - LOAD = "load" - DUMP = "dump" - - -@dataclass -class RequestBlockInfo: - # Hash values for all blocks - block_hashes: list[str] = field(default_factory=list) - # Operation type for each block - block_operations: list[BlockOperation] = field(default_factory=list) - # Next block position to process - start_position: int = 0 - - -@dataclass -class ReqMeta: - request_id: str - # list[(block_hash, vllm_block_id)] - load_blocks: list[tuple[str, int]] = field(default_factory=list) - # list[(block_hash, vllm_block_id)] - dump_blocks: list[tuple[str, int]] = field(default_factory=list) - # Whether use load_async - load_async: bool = False - - -@dataclass -class UCConnectorV1Metadata(KVConnectorMetadata): - requests: list[ReqMeta] = field(default_factory=list) - - -class UnifiedCacheConnectorV1(KVConnectorBase_V1): - - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) - self.block_size = vllm_config.cache_config.block_size - self.use_layerwise = False - self.kv_caches: dict[str, torch.Tensor] = {} - self.total_tp_size = vllm_config.parallel_config.tensor_parallel_size - self.rank = ( - -1 if role == KVConnectorRole.SCHEDULER else 
get_world_group().local_rank - ) - self.request_block_infos: dict[str, RequestBlockInfo] = {} - # dump tasks record request -> block -> list[task] - self.dump_tasks: dict[str, dict[str, List[Task]]] = {} - self.layerwise_load_tasks: dict[str, dict[str, Task]] = defaultdict(dict) - self.is_mla = self._vllm_config.model_config.is_deepseek_mla - self.num_layers = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config - ) - self.element_size = vllm_config.model_config.dtype.itemsize - self.kv_role = vllm_config.kv_transfer_config.kv_role - self._need_load_reqs: dict[str, Union[list[int], Task]] = {} - self._load_failed_reqs: set[str] = set() - self._load_req_to_blocks: dict[str, set[int]] = {} - self.num_head = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config - ) - self.head_size = vllm_config.model_config.get_head_size() - ucm_config = Config(vllm_config.kv_transfer_config) - launch_config = ucm_config.get_config() - if "ucm_connector_name" in launch_config: - name = launch_config.get("ucm_connector_name") - config = launch_config.get("ucm_connector_config") or {} - config["device"] = self.rank - config["role"] = ( - "scheduler" if role == KVConnectorRole.SCHEDULER else "worker" - ) - config_base = self.block_size * self.element_size * self.head_size - config["kv_block_size"] = ( - config_base - * self.num_layers - * (1 if self.is_mla else self.num_head * self.total_tp_size * 2) - ) - config["io_size"] = config_base * (1 if self.is_mla else self.num_head) - logger.info( - "kv_block_size = %d, io_size = %d,", - config["kv_block_size"], - config["io_size"], - ) - logger.info("init UCConnectorImpl, connector: %s", name) - self.connector = UcmConnectorFactory.create_connector(name, config) - else: - raise TypeError(f"no storage connector.") - if ( - self._vllm_config.kv_transfer_config is not None - and "use_layerwise" - in self._vllm_config.kv_transfer_config.kv_connector_extra_config - ): - self.use_layerwise = ( - self._vllm_config.kv_transfer_config.kv_connector_extra_config[ - "use_layerwise" - ] - ) - - def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"): - for layer_name in forward_context.no_compile_layers: - attn_layer = forward_context.no_compile_layers[layer_name] - if not hasattr(attn_layer, "kv_cache"): - logger.debug("The layer %s does not have kv_cache, skip it", layer_name) - continue - - if layer_name not in self.kv_caches: - self.kv_caches[layer_name] = attn_layer.kv_cache[ - forward_context.virtual_engine - ] - - def DataOffset(self, kv_layer, rank, layer_id, is_v): - # Non-MLA scene: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size) - # MLA scene: one layer shape is (num_blocks, block_size, head_size) - # Element size - elem_size = kv_layer[0].element_size() - logger.debug( - f"total_tp_size = {self.total_tp_size},\n" f"element size = {elem_size}." 
- ) - # One block size - k_min_data_block_size = ( - kv_layer[0][0].numel() if not self.is_mla else kv_layer[0].numel() - ) * elem_size - v_min_data_block_size = ( - kv_layer[1][0].numel() if not self.is_mla else 0 - ) * elem_size - # When tp > 1 layer_size = (k_min_data_block_size + v_min_data_block_size) * tp_size - layer_size = (k_min_data_block_size + v_min_data_block_size) * ( - self.total_tp_size if not self.is_mla else 1 - ) - if is_v: - # Offset of v = Offset of k + k_min_data_block_size - return int( - self.DataOffset(kv_layer, rank, layer_id, False) + k_min_data_block_size - ) - if self.is_mla: - return int(layer_size * layer_id) - else: - # Offset of k = layer_size * layer_id + layer_size / tp_size * current rank - return int( - layer_size * layer_id + layer_size / self.total_tp_size * self.rank - ) - - def get_tensor_and_offset_layerwise( - self, vllm_block_ids: List[int], kv_layer: torch.Tensor, layer_name: str - ) -> tuple[List[torch.Tensor], List[int]]: - k_tensors = [] - k_offsets = [] - v_tensors = [] - v_offsets = [] - layer_id = self._extract_layer_index(layer_name) - - for blk_id in vllm_block_ids: - k_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, False) - if self.is_mla: - k_tensors.append(kv_layer[blk_id]) - else: - k_tensors.append(kv_layer[0][blk_id]) - k_offsets.append(k_data_offset) - if not self.is_mla: - v_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, True) - v_tensors.append(kv_layer[1][blk_id]) - v_offsets.append(v_data_offset) - return k_tensors + v_tensors, k_offsets + v_offsets - - # ============================== - # Worker-side methods - # ============================== - def clear_connector_metadata(self) -> None: - """Clear the connector metadata. - - This function should be called by the model runner every time - after the model execution. - """ - self._load_failed_reqs.clear() - self._load_req_to_blocks.clear() - super().clear_connector_metadata() - - def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: - """ - Start loading the KV cache from the connector to vLLM's paged - KV buffer. This is called from the forward context before the - forward pass to enable async loading during model execution. - - Args: - forward_context (ForwardContext): the forward context. - **kwargs: additional arguments for the load operation - - Note: - The number of elements in kv_caches and layer_names should be - the same. 
- - """ - metadata = self._get_connector_metadata() - assert isinstance(metadata, UCConnectorV1Metadata) - - if len(self.kv_caches) == 0: - self._init_kv_caches_from_forward_context(forward_context) - if len(list(self.kv_caches.values())[0]) == 2: - self.is_mla = False - - self.layerwise_load_tasks.clear() - self.current_layer = 0 - need_load_tasks: dict[str, Task] = {} - for request in metadata.requests: - if not request.load_blocks: - continue - - storage_block_ids = [block[0] for block in request.load_blocks] - vllm_block_ids = [block[1] for block in request.load_blocks] - self._load_req_to_blocks.setdefault(request.request_id, set()).update( - vllm_block_ids - ) - is_load_async = request.load_async - total_offsets = [] - total_tensors = [] - storage_block_ids = storage_block_ids * (1 if self.is_mla else 2) - for layer_name, kv_layer in self.kv_caches.items(): - tensors, offsets = self.get_tensor_and_offset_layerwise( - vllm_block_ids, kv_layer, layer_name - ) - if self.use_layerwise and not is_load_async: - task_id = self.connector.load(storage_block_ids, offsets, tensors) - self.layerwise_load_tasks[request.request_id][layer_name] = task_id - continue - else: - total_offsets.extend(offsets) - total_tensors.extend(tensors) - if total_offsets and total_tensors: - storage_block_ids = storage_block_ids * self.num_layers - task_id = self.connector.load( - storage_block_ids, total_offsets, total_tensors - ) - if is_load_async: - self._need_load_reqs[request.request_id] = task_id - else: - need_load_tasks[request.request_id] = task_id - for req_id, task_id in need_load_tasks.items(): - if self.connector.wait(task_id) != 0: - self._load_failed_reqs.add(req_id) - logger.error(f"Failed to load blocks for req {req_id}") - - def wait_for_layer_load(self, layer_name: str) -> None: - """ - Block until the KV for a specific layer is loaded into vLLM's - paged buffer. This is called from within attention layer to ensure - async copying from start_load_kv is complete. - - This interface will be useful for layer-by-layer pipelining. - - Args: - layer_name: the name of that layer - """ - if not self.use_layerwise: - return - if self.layerwise_load_tasks: - logger.debug(f"Waiting for layer {self.current_layer} to be loaded") - - if self.current_layer >= self.num_layers: - return - - for request_id, layer_to_task in self.layerwise_load_tasks.items(): - if request_id in self._load_failed_reqs: - continue - task = layer_to_task[layer_name] - if self.connector.wait(task) != 0: - self._load_failed_reqs.add(request_id) - logger.error( - f"Failed to load block for request {request_id} on layer {layer_name}" - ) - continue - logger.debug(f"Load tasks for {request_id} on layer {layer_name} finished.") - - def save_kv_layer( - self, - layer_name: str, - kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", - **kwargs, - ) -> None: - """ - Start saving the a layer of KV cache from vLLM's paged buffer - to the connector. This is called from within attention layer to - enable async copying during execution. - - Args: - layer_name (str): the name of the layer. - kv_layer (torch.Tensor): the paged KV buffer of the current - layer in vLLM. - attn_metadata (AttentionMetadata): the attention metadata. - **kwargs: additional arguments for the save operation. 
- """ - if self.is_mla and self.rank != 0: - return - self.current_layer += 1 - if hasattr(self, "kv_role") and self.kv_role == "kv_consumer": - return - - if not self.use_layerwise: - return - - if self.current_layer > self.num_layers: - return - - metadata = self._get_connector_metadata() - assert isinstance(metadata, UCConnectorV1Metadata) - - for request in metadata.requests: - if not request.dump_blocks or request.load_async: - continue - - # Extract storage block IDs and vLLM block IDs from dump_blocks, same for load_blocks - # dump_blocks format: [(block_hash, vllm_block_id), ...] - # Note: block_hash is the storage_block_id - # Example: [("hash_123", 5), ("hash_456", 8), ("hash_789", 12)] - # ["hash_123", "hash_456", "hash_789"] - storage_block_ids = [block[0] for block in request.dump_blocks] - vllm_block_ids = [block[1] for block in request.dump_blocks] # [5, 8, 12] - blocks_len = len(storage_block_ids) - tensors, offsets = self.get_tensor_and_offset_layerwise( - vllm_block_ids, kv_layer, layer_name - ) - - if kv_layer[0].device.type == "npu": - torch.npu.current_stream().synchronize() - elif kv_layer[0].device.type == "cuda": - torch.cuda.current_stream().synchronize() - elif kv_layer[0].device.type == "musa": - torch.musa.current_stream().synchronize() - - for block_id, offset, tensor in zip( - storage_block_ids, offsets[:blocks_len], tensors[:blocks_len] - ): - task = self.connector.dump([block_id], [offset], [tensor]) - self.dump_tasks.setdefault(request.request_id, {}).setdefault( - block_id, [] - ).append(task) - if not self.is_mla: - for block_id, offset, tensor in zip( - storage_block_ids, offsets[blocks_len:], tensors[blocks_len:] - ): - task = self.connector.dump([block_id], [offset], [tensor]) - self.dump_tasks.setdefault(request.request_id, {}).setdefault( - block_id, [] - ).append(task) - - def wait_for_save(self) -> Optional[dict[str, list[str]]]: - """ - Block until all the save operations is done. This is called - as the forward context exits to ensure that the async saving - from save_kv_layer is complete before finishing the forward. - - This prevents overwrites of paged KV buffer before saving done. 
- """ - if hasattr(self, "kv_role") and self.kv_role == "kv_consumer": - return - if self.is_mla and self.rank != 0: - return - # request id -> succeed dumped blocks - success_dumped_blocks: dict[str, list[str]] = {} - - def wait_for_tasks(): - for request_id, block_dump_tasks in self.dump_tasks.items(): - for block_id, dump_tasks in block_dump_tasks.items(): - if any(self.connector.wait(task) != 0 for task in dump_tasks): - continue - success_dumped_blocks.setdefault(request_id, []).append(block_id) - - metadata = self._get_connector_metadata() - assert isinstance(metadata, UCConnectorV1Metadata) - if self.use_layerwise: - wait_for_tasks() - # clear dump_tasks for all request - self.dump_tasks.clear() - return success_dumped_blocks if success_dumped_blocks else None - - req_to_dump_blocks: dict[str, list[str]] = {} - need_dump_tasks: dict[str, Task] = {} - for request in metadata.requests: - if not request.dump_blocks: - continue - - storage_block_ids = [block[0] for block in request.dump_blocks] - vllm_block_ids = [block[1] for block in request.dump_blocks] - req_to_dump_blocks[request.request_id] = storage_block_ids - total_offsets = [] - total_tensors = [] - total_block_ids = ( - storage_block_ids * (1 if self.is_mla else 2) * self.num_layers - ) - for layer_name, kv_layer in self.kv_caches.items(): - tensors, offsets = self.get_tensor_and_offset_layerwise( - vllm_block_ids, kv_layer, layer_name - ) - total_offsets.extend(offsets) - total_tensors.extend(tensors) - task_id = self.connector.dump(total_block_ids, total_offsets, total_tensors) - need_dump_tasks[request.request_id] = task_id - - for req_id, task_id in need_dump_tasks.items(): - if self.connector.wait(task_id) != 0: - logger.error(f"Failed to dump blocks for req {request.request_id}") - else: - success_dumped_blocks[req_id] = req_to_dump_blocks[req_id] - return success_dumped_blocks if success_dumped_blocks else None - - def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: - """Get the finished recving and sending requests.""" - done_recving: set[str] = set() - for req_id, task in self._need_load_reqs.items(): - if req_id in self._load_failed_reqs: - done_recving.add(req_id) - continue - ret, finish = self.connector.check(task) - if ret != 0: - logger.error( - f"Task {task} failed, check return {ret} for request {req_id}" - ) - self._load_failed_reqs.add(req_id) - elif not finish: - continue - elif (wret := self.connector.wait(task)) != 0: - logger.error( - f"Task {task} failed, wait return {wret} for request {req_id}" - ) - self._load_failed_reqs.add(req_id) - done_recving.add(req_id) - - # remove the finished requests - for req_id in list(done_recving): - self._need_load_reqs.pop(req_id, None) - - return None, done_recving - - # ============================== - # Scheduler-side methods - # ============================== - def get_num_new_matched_tokens( - self, - request: "Request", - num_computed_tokens: int, - ) -> tuple[int, bool]: - """ - Get number of new tokens that can be loaded from the - external KV cache beyond the num_computed_tokens. - - Args: - request (Request): the request object. - num_computed_tokens (int): the number of locally - computed tokens for this request - - Returns: - the number of tokens that can be loaded from the - external KV cache beyond what is already computed. 
- """ - logger.info(f"get_num_new_matched_tokens request {request.request_id}.") - - if request.status == RequestStatus.PREEMPTED: - logger.info(f"Handle preempted request {request.request_id}.") - self.request_finished(request, []) - - def md5(input) -> int: - input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) - md5_bytes = hashlib.md5(input_bytes).digest() - return int.from_bytes(md5_bytes, byteorder="big") - - def hash_request_tokens( - hash_function: Any, block_size: int, request: "Request" - ) -> list[str]: - token_ids = request.all_token_ids - - ret = [] - parent_block_hash_value = None - for start in range(0, len(token_ids), block_size): - end = start + block_size - block_token_ids = token_ids[start:end] - # Do not hash the block if it is not full. - if len(block_token_ids) < block_size: - break - - if not parent_block_hash_value: - parent_block_hash_value = md5("UCMHASHSEED") - - block_token_ids_tuple = tuple(block_token_ids) - hash_value = hash_function( - (parent_block_hash_value, block_token_ids_tuple) - ) - parent_block_hash_value = hash_value - ret.append(str(hash_value)) - - return ret - - assert num_computed_tokens % self.block_size == 0 - block_hashes = hash_request_tokens(md5, self.block_size, request) - if not block_hashes: - logger.debug("Maybe tokens too short to load.") - return 0, False - - # Calculate start position (exclude blocks already in HBM) - start_position = num_computed_tokens // self.block_size - - block_operations = [BlockOperation.NONE] * len(block_hashes) - - remain_hashes = block_hashes[start_position:] - if not remain_hashes: - # All blocks are in HBM - return 0, False - - lookup_results = self.connector.lookup(remain_hashes) - - # Find the longest continuous match from the beginning - num_lookup_hits = 0 - for i, hit in enumerate(lookup_results): - if hit: - num_lookup_hits += 1 - block_operations[start_position + i] = BlockOperation.LOAD - else: - # TODO we will fix hole match later - break - logger.info( - f"num_total_blocks: {len(block_hashes)}, " - f"num_lookup_hits on hbm: {start_position}, " - f"num_lookup_hits on storage except hbm: {num_lookup_hits}" - ) - - # Load async when Decode instance need to load - if hasattr(self, "kv_role") and self.kv_role == "kv_consumer": - # Only trigger 1 asynchronous KV transfer per request. - if ( - request.kv_transfer_params - and request.kv_transfer_params["load_async"] == False - ) or num_lookup_hits == 0: - return 0, False - request.kv_transfer_params = request.kv_transfer_params or {} - request.kv_transfer_params["load_async"] = False - self.request_block_infos[request.request_id] = RequestBlockInfo( - block_hashes=block_hashes, - block_operations=block_operations, - start_position=start_position, - ) - self._need_load_reqs[request.request_id] = [] - return num_lookup_hits * self.block_size, True - - # When all the tokens are cached in ssd or hbm, - # we need to recompute the last token. This if condition will be removed - # once vLLM's scheduler provides a better solution in the future. 
- if (num_lookup_hits + start_position) * self.block_size == len( - request.all_token_ids - ): - num_lookup_hits -= 1 - block_operations[-1] = BlockOperation.NONE - - self.request_block_infos[request.request_id] = RequestBlockInfo( - block_hashes=block_hashes, - block_operations=block_operations, - start_position=start_position, - ) - - return num_lookup_hits * self.block_size, False - - def update_state_after_alloc( - self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int - ): - """ - Update KVConnector state after block allocation. - """ - if request.request_id in self._need_load_reqs: - local_block_ids = ( - # since we use unhashed blocks, so we don't need to reset start_position - blocks.get_unhashed_block_ids() - if num_external_tokens > 0 - else [] - ) - self._need_load_reqs[request.request_id] = local_block_ids - return - - request_block_info = self.request_block_infos.get(request.request_id, None) - if request_block_info: - start_position = request_block_info.start_position - block_operations = request_block_info.block_operations - block_hashes = request_block_info.block_hashes - start_create_pos = start_position + num_external_tokens // self.block_size - remaining_hashes = block_hashes[start_create_pos:] - if remaining_hashes: - create_results = self.connector.create(remaining_hashes) - if any(ret != 0 for ret in create_results): - logger.warning(f"\ncreate_results on storage: {create_results}\n") - for j, ret in enumerate(create_results): - idx = start_create_pos + j - block_operations[idx] = ( - BlockOperation.DUMP if ret == 0 else BlockOperation.NONE - ) - # set start_position to 0, so that we can process from the beginning - request_block_info.start_position = 0 - - def build_connector_meta( - self, scheduler_output: SchedulerOutput - ) -> KVConnectorMetadata: - """ - Build the connector metadata for this step. - - This function should NOT modify fields in the scheduler_output. - Also, calling this function will reset the state of the connector. - - Args: - scheduler_output (SchedulerOutput): the scheduler output object. 
- """ - meta = UCConnectorV1Metadata() - - for req_id, block_ids in self._need_load_reqs.items(): - block_info = self.request_block_infos.get(req_id) - if block_info: - load_blocks, dump_blocks = self._extract_blocks(block_ids, block_info) - meta.requests.append( - ReqMeta( - request_id=req_id, - load_blocks=load_blocks, - dump_blocks=dump_blocks, - load_async=True, - ) - ) - self._need_load_reqs.clear() - - for new_req in scheduler_output.scheduled_new_reqs: - req_id = new_req.req_id - vllm_block_ids = new_req.block_ids[0] - - block_info = self.request_block_infos.get(req_id) - if block_info: - load_blocks, dump_blocks = self._extract_blocks( - vllm_block_ids, block_info - ) - if load_blocks or dump_blocks: - meta.requests.append( - ReqMeta( - request_id=req_id, - load_blocks=load_blocks, - dump_blocks=dump_blocks, - ) - ) - - # Process cached requests using iterator - cached_request_data = scheduler_output.scheduled_cached_reqs - - # Adapted for vllm 0.9.1, 0.9.2 and later versions - def get_requests(): - # 0.9.1 - if isinstance(cached_request_data, list): - return [ - ( - request_data.req_id, - request_data.new_block_ids, - ) - for request_data in cached_request_data - ] - # >= 0.9.2 - else: - return [ - ( - req_id, - cached_request_data.new_block_ids[i], - ) - for i, req_id in enumerate(cached_request_data.req_ids) - ] - - # When prompt tokens > max_num_batched_tokens, request of running requests may need to save - for req_id, new_block_ids in get_requests(): - block_info = self.request_block_infos.get(req_id) - if block_info: - load_blocks, dump_blocks = self._extract_blocks( - new_block_ids[0], block_info - ) - if load_blocks or dump_blocks: - meta.requests.append( - ReqMeta( - request_id=req_id, - load_blocks=load_blocks, - dump_blocks=dump_blocks, - ) - ) - - return meta - - def request_finished( - self, - request: "Request", - block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: - block_info = self.request_block_infos.pop(request.request_id, None) - if block_info is not None: - cancel_blocks = [ - block_info.block_hashes[i] - for i, op in enumerate(block_info.block_operations) - if op == BlockOperation.DUMP - and hasattr(request, "succeed_dumped_blocks") - and block_info.block_hashes[i] not in request.succeed_dumped_blocks - ] - if cancel_blocks: - logger.debug(f"commit {cancel_blocks} to False.") - self.connector.commit(cancel_blocks, False) - if hasattr(request, "succeed_dumped_blocks"): - request.succeed_dumped_blocks.clear() - return False, None - - def _extract_blocks( - self, vllm_block_ids: list[int], block_info: RequestBlockInfo - ) -> tuple[list[tuple[str, int]], list[tuple[str, int]]]: - """ - Extract blocks that need load and dump, block_info.start_position - is the next block position to process, only return blocks that need - processing, NONE blocks are ignored. 
- """ - start_pos = block_info.start_position - - if start_pos >= len(block_info.block_operations): - return [], [] - - process_length = min( - len(block_info.block_operations) - start_pos, len(vllm_block_ids) - ) - ops = block_info.block_operations[start_pos : start_pos + process_length] - hashes = block_info.block_hashes[start_pos : start_pos + process_length] - vllm_ids = vllm_block_ids[:process_length] - - load_blocks = [] - dump_blocks = [] - for op, hash, vllm_id in zip(ops, hashes, vllm_ids): - if op == BlockOperation.LOAD: - load_blocks.append((hash, vllm_id)) - elif op == BlockOperation.DUMP: - dump_blocks.append((hash, vllm_id)) - - block_info.start_position += process_length - return load_blocks, dump_blocks - - def get_block_ids_with_load_errors(self) -> set[int]: - invalid_block_ids: set[int] = set() - for req_id in self._load_failed_reqs: - if req_id in self._load_req_to_blocks: - invalid_block_ids.update(self._load_req_to_blocks[req_id]) - return invalid_block_ids - - @staticmethod - def _extract_layer_index(layer_name: str) -> Optional[int]: - """ - Extract the layer index from the layer name. - """ - for chunk in layer_name.split("."): - if chunk.isdigit(): - return int(chunk) - return None diff --git a/ucm/sparse/kvcomp/offline_inference_kvcomp.py b/ucm/sparse/kvcomp/offline_inference_kvcomp.py index ce08870c1..b722d1b71 100644 --- a/ucm/sparse/kvcomp/offline_inference_kvcomp.py +++ b/ucm/sparse/kvcomp/offline_inference_kvcomp.py @@ -76,8 +76,8 @@ def print_output( def main(): - module_path = "ucm.integration.vllm.uc_connector" - name = "UnifiedCacheConnectorV1" + module_path = "ucm.integration.vllm.ucm_connector" + name = "UCMConnector" model = os.getenv("MODEL_PATH", "/data/models/Qwen3-4B") setup_environment_variables()