diff --git a/docs/source/getting-started/quickstart_vllm.md b/docs/source/getting-started/quickstart_vllm.md index 729bce9b8..bb068c371 100644 --- a/docs/source/getting-started/quickstart_vllm.md +++ b/docs/source/getting-started/quickstart_vllm.md @@ -77,6 +77,33 @@ Download the pre-built `vllm/vllm-openai:v0.9.2` docker image and build unified- pip install -v -e . --no-build-isolation ``` +3. Apply vLLM Integration Patches (Required) + + To enable Unified Cache Management (UCM) integration with vLLM, you must **manually apply the corresponding vLLM patch**. + + First, navigate to the vLLM source directory: + ```bash + cd /path/to/vllm + ``` + Apply the patch that matches your development needs: + + - Full UCM integration (recommended): + ```bash + git apply unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch + ``` + + - Sparse attention only: + ```bash + git apply unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch + ``` + + - ReRoPE support only: + ```bash + git apply unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt-rerope.patch + ``` + + If you are working on **sparse attention** or **ReRoPE** independently, applying only the corresponding patch is sufficient. ### Option 3: Install by pip diff --git a/docs/source/getting-started/quickstart_vllm_ascend.md b/docs/source/getting-started/quickstart_vllm_ascend.md index 0538806a8..339a6f726 100644 --- a/docs/source/getting-started/quickstart_vllm_ascend.md +++ b/docs/source/getting-started/quickstart_vllm_ascend.md @@ -12,7 +12,7 @@ We offer 3 options to install UCM. ### Option 1: Build from source -Follow commands below to install unified-cache-management from source code: +1. Follow the commands below to install unified-cache-management from source code: **Note:** The sparse module was not compiled by default. To enable it, set the environment variable `export ENABLE_SPARSE=TRUE` before you build. ```bash # Replace with the branch or tag name needed @@ -23,6 +23,31 @@ pip install -v -e . --no-build-isolation cd .. ``` +2. Apply vLLM and vLLM-Ascend Integration Patches (Required) +To enable Unified Cache Management (UCM) integration, you need to apply patches to both the vLLM and vLLM-Ascend source trees. + +**Step 1:** Apply the vLLM Patch + +First, apply the standard vLLM integration patch in the vLLM source directory: + +```bash +cd /path/to/vllm +git apply unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch +``` + +**Step 2:** Apply the vLLM-Ascend Patch + +Then, switch to the vLLM-Ascend source directory and apply the Ascend-specific patch: + +```bash +cd /path/to/vllm-ascend +git apply unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch +``` + +**Note:** + The ReRoPE algorithm is not supported on Ascend at the moment. + Only the standard UCM integration is applicable for vLLM-Ascend. + ### Option 2: Install by pip Install by pip or find the pre-build wheels on [Pypi](https://pypi.org/project/uc-manager/). 
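Before applying any of the patches above, it can help to dry-run them first. The sketch below uses standard `git apply` flags and assumes you are in the vLLM source directory with the full-integration patch path from the docs as an example; substitute the sparse or ReRoPE patch as needed.

```bash
# Show a diffstat of what the patch would change (does not apply anything)
git apply --stat unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch

# Verify the patch applies cleanly against your checkout without modifying files
git apply --check unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch

# If you later need to undo an applied patch, reverse it
git apply -R unified-cache-management/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch
```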
diff --git a/examples/offline_inference_kvcomphbm.py b/examples/offline_inference_kvcomphbm.py index a860dd74a..fc5d5142d 100644 --- a/examples/offline_inference_kvcomphbm.py +++ b/examples/offline_inference_kvcomphbm.py @@ -77,7 +77,7 @@ def build_llm_with_uc(module_path: str, name: str, model: str): }, } ], - "ucm_sparse_config": {"GSA": {}}, + "ucm_sparse_config": {"KvCompOnDevice": {}}, }, ) diff --git a/examples/ucm_config_example.yaml b/examples/ucm_config_example.yaml index 77690b266..b89523861 100644 --- a/examples/ucm_config_example.yaml +++ b/examples/ucm_config_example.yaml @@ -31,8 +31,7 @@ load_only_first_rank: false # Or for GSA: # GSA: {} # Or for KvCompOnDevice: - # KvCompOnDevice: - # "kvcompOnDevice_config_path": "workspace/unified-cache-management/ucm/sparse/kvcomp/configs/kvcomp_qwen3_32B_config.json" + # KvCompOnDevice: {} # Whether to use layerwise loading/saving (optional, default: True for UCMConnector) diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-aggre.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-aggre.patch deleted file mode 100644 index 5f8df381c..000000000 --- a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-aggre.patch +++ /dev/null @@ -1,753 +0,0 @@ -From 6e2c814bb3b3a74ca56149b44d6a0b2017b91136 Mon Sep 17 00:00:00 2001 -From: harrisonyhq -Date: Tue, 4 Nov 2025 23:32:10 -0800 -Subject: [PATCH 2/3] [Patch1] Patch for load failure and aggregate - ---- - .../kv_transfer/kv_connector/utils.py | 113 +++++++++++ - .../kv_transfer/kv_connector/v1/base.py | 9 + - .../kv_connector/v1/multi_connector.py | 6 + - vllm/v1/core/block_pool.py | 2 +- - vllm/v1/core/sched/output.py | 2 + - vllm/v1/core/sched/scheduler.py | 184 ++++++++++++++++-- - vllm/v1/core/single_type_kv_cache_manager.py | 3 + - vllm/v1/executor/multiproc_executor.py | 30 ++- - vllm/v1/outputs.py | 6 +- - vllm/v1/worker/gpu_input_batch.py | 14 ++ - vllm/v1/worker/gpu_model_runner.py | 28 ++- - vllm/v1/worker/gpu_worker.py | 23 ++- - 12 files changed, 397 insertions(+), 23 deletions(-) - -diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py -index 5cbc8ca31..b63bf5965 100644 ---- a/vllm/distributed/kv_transfer/kv_connector/utils.py -+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py -@@ -5,10 +5,16 @@ KV cache helper for store. - """ - import torch - -+from collections import defaultdict -+from collections.abc import Sequence -+from concurrent.futures import CancelledError, Future -+from typing import Optional, cast -+ - import vllm.envs as envs - from vllm import _custom_ops as ops - from vllm.config import VllmConfig, get_current_vllm_config - from vllm.logger import init_logger -+from vllm.v1.outputs import ModelRunnerOutput - - logger = init_logger(__name__) - -@@ -107,3 +113,110 @@ def get_kv_connector_cache_layout(): - "layout to HND for better xfer performance.") - return "HND" - return "NHD" -+ -+ -+class KVOutputAggregator: -+ """Utility class to aggregate the output of all workers into a single -+ output corresponding to Rank 0 for scheduler.""" -+ -+ def __init__(self, world_size: int): -+ # Complete transfer tracker. 
Used by to track finished requests -+ # [req_id -> n_finished_workers] -+ self._recv_remaining_count = defaultdict[str, int](lambda: world_size) -+ self._send_remaining_count = defaultdict[str, int](lambda: world_size) -+ self._dump_remaining_count = defaultdict[str, int](lambda: world_size) -+ -+ def aggregate(self, -+ outputs: list[ModelRunnerOutput], -+ output_rank: int = 0) -> ModelRunnerOutput: -+ # aggregate finished_sending, finished_recving from all workers -+ -+ def update_finished_set(req_ids: Optional[set[str]], -+ remaining_count_dict: dict[str, int], -+ finished_set: set[str]) -> None: -+ for req_id in req_ids or (): -+ new_count = remaining_count_dict[req_id] - 1 -+ if new_count == 0: -+ finished_set.add(req_id) -+ del remaining_count_dict[req_id] -+ else: -+ remaining_count_dict[req_id] = new_count -+ -+ def update_finished_list(req_ids: Optional[dict[str, list[str]]], -+ remaining_count_dict: dict[str, int], -+ finished_list: dict[str, list[str]]) -> None: -+ for req_id, succeed_dump_blocks in (req_ids or {}).items(): -+ if req_id not in finished_list: -+ finished_list[req_id] = [] -+ for blk_id in succeed_dump_blocks: -+ new_count = remaining_count_dict[blk_id] - 1 -+ if new_count == 0: -+ finished_list[req_id].append(blk_id) -+ del remaining_count_dict[blk_id] -+ else: -+ remaining_count_dict[blk_id] = new_count -+ -+ finished_sending = set[str]() -+ finished_recving = set[str]() -+ invalid_block_ids = set[int]() -+ finished_dumping: dict[str, list[str]] = {} -+ for output in outputs: -+ update_finished_set(output.finished_sending, -+ self._send_remaining_count, finished_sending) -+ update_finished_set(output.finished_recving, -+ self._recv_remaining_count, finished_recving) -+ update_finished_list(output.finished_dumping, -+ self._dump_remaining_count, finished_dumping) -+ if output.invalid_block_ids: -+ invalid_block_ids |= output.invalid_block_ids -+ -+ # select output of the worker specified by output_rank -+ output = outputs[output_rank] -+ -+ # set the aggregated finished_sending / finished_recving -+ # if output.finished_sending/recving is not empty, but the other ranks -+ # still have unfinished send/recv, we want to set the aggregated -+ # finished_sending/recving to None until all ranks have finished -+ # send/recv -+ output.finished_sending = finished_sending if finished_sending else None -+ output.finished_recving = finished_recving if finished_recving else None -+ output.finished_dumping = finished_dumping if finished_dumping else None -+ output.invalid_block_ids = invalid_block_ids or None -+ -+ return output -+ -+ def async_aggregate(self, -+ output_futures: Sequence[Future[ModelRunnerOutput]], -+ output_rank: int = 0) -> Future[ModelRunnerOutput]: -+ """Takes a list of futures and returns a single future which resolves -+ to the respective list of outputs.""" -+ result_future: Future[ModelRunnerOutput] = Future() -+ -+ outputs: list[Optional[ModelRunnerOutput]] = [None -+ ] * len(output_futures) -+ -+ def make_callback(idx): -+ -+ def callback(fut): -+ if result_future.done(): -+ return -+ -+ try: -+ outputs[idx] = fut.result() -+ except CancelledError: -+ result_future.cancel() -+ except Exception as e: -+ result_future.set_exception(e) -+ -+ # this check assumes io_thread_pool uses a single thread -+ if all(outputs): -+ result_future.set_result( -+ self.aggregate(cast(list[ModelRunnerOutput], outputs), -+ output_rank)) -+ -+ return callback -+ -+ for i, output_future in enumerate(output_futures): -+ output_future.add_done_callback(make_callback(i)) -+ -+ 
return result_future -diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py -index f80b5eba2..39d8fa389 100644 ---- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py -+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py -@@ -201,6 +201,15 @@ class KVConnectorBase_V1(ABC): - """ - return None, None - -+ def get_block_ids_with_load_errors(self) -> set[int]: -+ """ -+ Get the set of block IDs that failed to load. -+ Returns: -+ Optional[set[int]]: A set of block IDs that encountered load errors. -+ Returns None if no errors occurred during load. -+ """ -+ return set() -+ - # ============================== - # Scheduler-side methods - # ============================== -diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py -index 5f92d69bd..4e1f45e7a 100644 ---- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py -+++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py -@@ -134,6 +134,12 @@ class MultiConnector(KVConnectorBase_V1): - - return finished_sending or None, finished_recving or None - -+ def get_block_ids_with_load_errors(self) -> set[int]: -+ agg_block_ids: set[int] = set() -+ for c in self._connectors: -+ agg_block_ids |= c.get_block_ids_with_load_errors() -+ return agg_block_ids -+ - # ============================== - # Scheduler-side methods - # ============================== -diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py -index d21f94727..1800665c7 100644 ---- a/vllm/v1/core/block_pool.py -+++ b/vllm/v1/core/block_pool.py -@@ -124,7 +124,7 @@ class BlockPool: - kv_cache_group_id: The id of the KV cache group. - hash_fn: The hash function to use for block hashes. - """ -- if num_cached_blocks == num_full_blocks: -+ if num_cached_blocks >= num_full_blocks: - return - new_full_blocks = blocks[num_cached_blocks:num_full_blocks] - assert len(block_hashes) >= num_cached_blocks -diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py -index d34f39327..c94e421c0 100644 ---- a/vllm/v1/core/sched/output.py -+++ b/vllm/v1/core/sched/output.py -@@ -93,6 +93,7 @@ class CachedRequestData: - new_token_ids: list[list[int]] - new_block_ids: list[tuple[list[int], ...]] - num_computed_tokens: list[int] -+ num_output_tokens: list[int] - - @property - def num_reqs(self) -> int: -@@ -106,6 +107,7 @@ class CachedRequestData: - new_token_ids=[], - new_block_ids=[], - num_computed_tokens=[], -+ num_output_tokens=[], - ) - - -diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py -index cd80f92a1..2d4fd4d59 100644 ---- a/vllm/v1/core/sched/scheduler.py -+++ b/vllm/v1/core/sched/scheduler.py -@@ -119,6 +119,7 @@ class Scheduler(SchedulerInterface): - - # KV Connector: requests in process of async KV loading or recving - self.finished_recving_kv_req_ids: set[str] = set() -+ self.failed_recving_kv_req_ids: set[str] = set() - - # Encoder-related. 
- # Calculate encoder cache size if applicable -@@ -621,6 +622,7 @@ class Scheduler(SchedulerInterface): - new_token_ids: list[list[int]] = [] - new_block_ids: list[tuple[list[int], ...]] = [] - num_computed_tokens: list[int] = [] -+ num_output_tokens: list[int] = [] - - for req in itertools.chain(running_reqs, resumed_reqs): - req_id = req.request_id -@@ -638,6 +640,7 @@ class Scheduler(SchedulerInterface): - new_token_ids.append(token_ids) - new_block_ids.append(req_to_new_block_ids[req_id]) - num_computed_tokens.append(req.num_computed_tokens) -+ num_output_tokens.append(len(req.output_token_ids)) - # Because resumed_reqs is usually empty, it is more efficient to do - # in-place appending so that we don't need to allocate a new list. - resumed_from_preemption = [False] * len(running_reqs) -@@ -649,6 +652,7 @@ class Scheduler(SchedulerInterface): - new_token_ids=new_token_ids, - new_block_ids=new_block_ids, - num_computed_tokens=num_computed_tokens, -+ num_output_tokens=num_output_tokens, - ) - - def _try_schedule_encoder_inputs( -@@ -746,16 +750,29 @@ class Scheduler(SchedulerInterface): - num_scheduled_tokens = scheduler_output.num_scheduled_tokens - pooler_outputs = model_runner_output.pooler_output - num_nans_in_logits = model_runner_output.num_nans_in_logits -+ invalid_block_ids = model_runner_output.invalid_block_ids - - new_running: list[Request] = [] - outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) - spec_decoding_stats: Optional[SpecDecodingStats] = None - -+ failed_kv_load_req_ids = None -+ if invalid_block_ids: -+ # These blocks contain externally computed tokens that failed to -+ # load. Identify affected requests and adjust their computed token -+ # count to trigger recomputation of the invalid blocks. -+ failed_kv_load_req_ids = self._handle_invalid_blocks(invalid_block_ids) -+ - # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below - # loop can be a performance bottleneck. We should do our best to avoid - # expensive operations inside the loop. - for request in self.running: - req_id = request.request_id -+ # self.req_meta.stage == SequenceStage.PREFILL and self.req_meta.is_last_chunk -+ if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids: -+ # Skip requests that were recovered from KV load failure -+ new_running.append(request) -+ continue - num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) - if num_tokens_scheduled == 0: - # The request was not scheduled in this step. -@@ -1089,18 +1106,31 @@ class Scheduler(SchedulerInterface): - if request.request_id not in self.finished_recving_kv_req_ids: - return False - -- # Now that the blocks are ready, actually cache them. -- (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) -- num_computed_tokens = len(block_ids) * self.block_size -- # Handle the case where num request tokens less then one block. -- num_computed_tokens = min(num_computed_tokens, request.num_tokens) -- if num_computed_tokens == request.num_tokens: -- num_computed_tokens -= 1 -- # This will cache the blocks iff caching is enabled. -- self.kv_cache_manager.cache_blocks(request, num_computed_tokens) -- -- # Update the request state for scheduling. -- request.num_computed_tokens = num_computed_tokens -+ if request.request_id in self.failed_recving_kv_req_ids: -+ # Request had KV load failures; num_computed_tokens was already -+ # updated in _update_requests_with_invalid_blocks -+ if request.num_computed_tokens: -+ # Cache any valid computed tokens. 
-+ self.kv_cache_manager.cache_blocks(request, -+ request.num_computed_tokens) -+ else: -+ # No valid computed tokens, release allocated blocks. -+ # There may be a local cache hit on retry. -+ self.kv_cache_manager.free(request) -+ self.failed_recving_kv_req_ids.remove(request.request_id) -+ else: -+ # Now that the blocks are ready, actually cache them. -+ (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) -+ num_computed_tokens = len(block_ids) * self.block_size -+ # Handle the case where num request tokens less then one block. -+ num_computed_tokens = min(num_computed_tokens, request.num_tokens) -+ if num_computed_tokens == request.num_tokens: -+ num_computed_tokens -= 1 -+ # This will cache the blocks iff caching is enabled. -+ self.kv_cache_manager.cache_blocks(request, num_computed_tokens) -+ -+ # Update the request state for scheduling. -+ request.num_computed_tokens = num_computed_tokens - - # Return that we are ready. - self.finished_recving_kv_req_ids.remove(request.request_id) -@@ -1124,3 +1154,133 @@ class Scheduler(SchedulerInterface): - for req_id in (model_runner_output.finished_sending or ()): - logger.debug("Finished sending KV transfer for request %s", req_id) - self._free_blocks(self.requests[req_id]) -+ -+ -+ def _update_requests_with_invalid_blocks( -+ self, requests: Iterable[Request], -+ invalid_block_ids: set[int]) -> tuple[set[str], int]: -+ """ -+ Identify and update requests affected by invalid KV cache blocks. -+ This method scans the given requests, detects those with invalid blocks -+ and adjusts their `num_computed_tokens` to the longest valid prefix. -+ For observability, it also accumulates the total number of tokens that -+ will need to be recomputed across all affected requests. -+ Args: -+ requests: The set of requests to scan for invalid blocks. -+ invalid_block_ids: IDs of invalid blocks. -+ Returns: -+ tuple: -+ - affected_req_ids (set[str]): IDs of requests impacted by -+ invalid blocks. -+ - total_affected_tokens (int): Total number of tokens that must -+ be recomputed across all affected requests (for observability). -+ """ -+ affected_req_ids: set[str] = set() -+ total_affected_tokens = 0 -+ # If a block is invalid and shared by multiple requests in the batch, -+ # these requests must be rescheduled, but only the first will recompute -+ # it. This set tracks blocks already marked for recomputation. -+ marked_invalid_block_ids: set[int] = set() -+ for request in requests: -+ is_affected = False -+ marked_invalid_block = False -+ req_id = request.request_id -+ # TODO (davidb): add support for hybrid memory allocator -+ (req_block_ids, ) = self.kv_cache_manager.get_block_ids(req_id) -+ # We iterate only over blocks that may contain externally computed -+ # tokens -+ if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: -+ # Async loading. If num_computed_tokens is set it implies we -+ # already processed some block failures for it in a prior step -+ req_num_computed_tokens = ( -+ request.num_computed_tokens if req_id -+ in self.failed_recving_kv_req_ids else len(req_block_ids) * -+ self.block_size) -+ else: -+ # Sync loading. 
num_computed_tokens includes new tokens -+ req_num_computed_tokens = request.num_cached_tokens -+ -+ req_num_computed_blocks = (req_num_computed_tokens + -+ self.block_size - 1) // self.block_size -+ for idx, block_id in zip(range(req_num_computed_blocks), -+ req_block_ids): -+ -+ if block_id not in invalid_block_ids: -+ continue -+ -+ is_affected = True -+ -+ if block_id in marked_invalid_block_ids: -+ # This invalid block is shared with a previous request -+ # and was already marked for recomputation. -+ # This means this request can still consider this block -+ # as computed when rescheduled. -+ # Currently this only applies to sync loading; Async -+ # loading does not yet support block sharing -+ continue -+ -+ marked_invalid_block_ids.add(block_id) -+ -+ if marked_invalid_block: -+ # This request has already marked an invalid block for -+ # recomputation and updated its num_computed_tokens. -+ continue -+ -+ marked_invalid_block = True -+ # Truncate the computed tokens at the first failed block -+ request.num_computed_tokens = idx * self.block_size -+ total_affected_tokens += (req_num_computed_tokens - -+ request.num_computed_tokens) -+ -+ if is_affected: -+ if not marked_invalid_block: -+ # All invalid blocks of this request are shared with -+ # previous requests and will be recomputed by them. -+ # Revert to considering only cached tokens as computed. -+ # Currently this only applies to sync loading; Async -+ # loading does not yet support block sharing -+ total_affected_tokens += (request.num_computed_tokens - -+ request.num_cached_tokens) -+ request.num_computed_tokens = request.num_cached_tokens -+ -+ affected_req_ids.add(request.request_id) -+ -+ return (affected_req_ids, total_affected_tokens) -+ -+ -+ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: -+ total_requests_to_reschedule = 0 -+ total_tokens_to_reschedule = 0 -+ -+ # --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) --- -+ async_load_reqs = ( -+ req for req in self.waiting -+ if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS) -+ async_affected_req_ids, num_tokens_to_reschedule = ( -+ self._update_requests_with_invalid_blocks(async_load_reqs, -+ invalid_block_ids)) -+ -+ total_requests_to_reschedule += len(async_affected_req_ids) -+ total_tokens_to_reschedule += num_tokens_to_reschedule -+ -+ # Mark requests with async KV load failures; they will be rescheduled -+ # once loading completes -+ self.failed_recving_kv_req_ids |= async_affected_req_ids -+ -+ # --- Handle sync KV loads (running requests) --- -+ sync_affected_req_ids, num_tokens_to_reschedule = ( -+ self._update_requests_with_invalid_blocks(self.running, -+ invalid_block_ids)) -+ -+ total_requests_to_reschedule += len(sync_affected_req_ids) -+ total_tokens_to_reschedule += num_tokens_to_reschedule -+ -+ if total_requests_to_reschedule: -+ logger.warning( -+ "Recovered from KV load failure: " -+ "%d request(s) rescheduled (%d tokens affected).", -+ total_requests_to_reschedule, total_tokens_to_reschedule) -+ -+ # Return the IDs of affected running requests to skip in -+ # update_from_output. 
-+ return sync_affected_req_ids -diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py -index 5b4718038..28bd4618a 100644 ---- a/vllm/v1/core/single_type_kv_cache_manager.py -+++ b/vllm/v1/core/single_type_kv_cache_manager.py -@@ -142,6 +142,9 @@ class SingleTypeKVCacheManager(ABC): - num_cached_blocks = self.num_cached_block[request.request_id] - num_full_blocks = num_tokens // self.block_size - -+ if num_cached_blocks >= num_full_blocks: -+ return -+ - self.block_pool.cache_full_blocks( - request=request, - blocks=self.req_to_blocks[request.request_id], -diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py -index b06b7cc80..61cd7110f 100644 ---- a/vllm/v1/executor/multiproc_executor.py -+++ b/vllm/v1/executor/multiproc_executor.py -@@ -26,6 +26,7 @@ from vllm.distributed import (destroy_distributed_environment, - destroy_model_parallel) - from vllm.distributed.device_communicators.shm_broadcast import (Handle, - MessageQueue) -+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator - from vllm.executor.multiproc_worker_utils import ( - _add_prefix, set_multiprocessing_worker_envs) - from vllm.logger import init_logger -@@ -111,10 +112,14 @@ class MultiprocExecutor(Executor): - if self.max_concurrent_batches > 1: - # Note: must use only 1 IO thread to keep dequeue sequence - # from the response queue -+ # _async_aggregate_workers_output also assumes a single IO thread - self.io_thread_pool = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="mp_exec_io") - - self.output_rank = self._get_output_rank() -+ self.has_connector = self.vllm_config.kv_transfer_config is not None -+ self.kv_output_aggregator = KVOutputAggregator( -+ self.parallel_config.world_size) - - def start_worker_monitor(self): - workers = self.workers -@@ -155,13 +160,30 @@ class MultiprocExecutor(Executor): - self, - scheduler_output, - ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: -- (output, ) = self.collective_rpc( -+ non_block = self.max_concurrent_batches > 1 -+ -+ if not self.has_connector or self.vllm_config.model_config.use_mla: -+ # get output only from a single worker (output_rank) -+ (output, ) = self.collective_rpc( -+ "execute_model", -+ args=(scheduler_output, ), -+ unique_reply_rank=self.output_rank, -+ non_block=non_block, -+ timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) -+ return output -+ -+ # get output from all workers -+ outputs = self.collective_rpc( - "execute_model", - args=(scheduler_output, ), -- unique_reply_rank=self.output_rank, -- non_block=self.max_concurrent_batches > 1, -+ non_block=non_block, - timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) -- return output -+ -+ # aggregate all workers output to a single output -+ if non_block: -+ return self.kv_output_aggregator.async_aggregate( -+ outputs, self.output_rank) -+ return self.kv_output_aggregator.aggregate(outputs, self.output_rank) - - def collective_rpc(self, - method: Union[str, Callable], -diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py -index c8388baed..16af8dbce 100644 ---- a/vllm/v1/outputs.py -+++ b/vllm/v1/outputs.py -@@ -1,7 +1,7 @@ - # SPDX-License-Identifier: Apache-2.0 - # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - --from dataclasses import dataclass -+from dataclasses import dataclass, field - from typing import NamedTuple, Optional - - import torch -@@ -109,6 +109,10 @@ class ModelRunnerOutput: - finished_recving: Optional[set[str]] = None - finished_dumping: 
Optional[dict[str, list[str]]] = None - -+ # IDs of externally computed KV blocks that failed to load. -+ # Requests referencing these blocks should be rescheduled to recompute them. -+ invalid_block_ids: set[int] = field(default_factory=set) -+ - # req_id -> num_nans_in_logits - num_nans_in_logits: Optional[dict[str, int]] = None - -diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py -index 1a79d72be..8819d7629 100644 ---- a/vllm/v1/worker/gpu_input_batch.py -+++ b/vllm/v1/worker/gpu_input_batch.py -@@ -96,6 +96,9 @@ class InputBatch: - pin_memory=False, - ) - self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() -+ self.is_token_ids = torch.zeros( -+ (max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False -+ ) - self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) - self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) - self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) -@@ -286,8 +289,14 @@ class InputBatch: - req_index, :num_prompt_tokens] = request.prompt_token_ids - start_idx = num_prompt_tokens - end_idx = start_idx + len(request.output_token_ids) -+ if request.prompt_token_ids is not None: -+ self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids -+ self.is_token_ids[req_index, :num_prompt_tokens] = True -+ else: -+ self.is_token_ids[req_index, :num_prompt_tokens] = False - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids -+ self.is_token_ids[req_index, start_idx:end_idx] = True - # Number of token ids in token_ids_cpu. - # NOTE(woosuk): This may include spec decode tokens. - self.num_tokens[req_index] = request.num_tokens -@@ -473,6 +482,8 @@ class InputBatch: - self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] - self.token_ids_cpu[i2, ...] = tmp - -+ self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...] -+ - swap_dict_values(self.generators, i1, i2) - swap_dict_values(self.bad_words_token_ids, i1, i2) - -@@ -542,6 +553,9 @@ class InputBatch: - num_tokens = self.num_tokens[last_req_index] - self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ - last_req_index, :num_tokens] -+ self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[ -+ last_req_index, :num_tokens -+ ] - self.num_tokens[empty_index] = num_tokens - self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ - last_req_index] -diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py -index 53ee8cfcd..c3df1d5d2 100644 ---- a/vllm/v1/worker/gpu_model_runner.py -+++ b/vllm/v1/worker/gpu_model_runner.py -@@ -473,6 +473,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): - num_computed_tokens = req_data.num_computed_tokens[i] - new_block_ids = req_data.new_block_ids[i] - resumed_from_preemption = req_data.resumed_from_preemption[i] -+ num_output_tokens = req_data.num_output_tokens[i] - - # Update the cached states. - req_state.num_computed_tokens = num_computed_tokens -@@ -492,6 +493,21 @@ class GPUModelRunner(LoRAModelRunnerMixin): - elif num_new_tokens > 0: - req_state.output_token_ids.extend( - new_token_ids[-num_new_tokens:]) -+ elif num_output_tokens < len(req_state.output_token_ids): -+ # Some output tokens were discarded due to a sync-KV-load -+ # failure. Align the cached state. 
-+ del req_state.output_token_ids[num_output_tokens:] -+ -+ req_index = self.input_batch.req_id_to_index.get(req_id) -+ if req_index is not None: -+ old_end_idx = self.input_batch.num_tokens_no_spec[ -+ req_index] -+ end_idx = self.input_batch.num_prompt_tokens[ -+ req_index] + num_output_tokens -+ self.input_batch.num_tokens[req_index] = end_idx -+ self.input_batch.num_tokens_no_spec[req_index] = end_idx -+ self.input_batch.is_token_ids[req_index, -+ end_idx:old_end_idx] = False - - # Update the block IDs. - if not resumed_from_preemption: -@@ -1381,6 +1397,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): - finished_dumping = self.maybe_wait_for_kv_save() - finished_sending, finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) -+ invalid_block_ids = self.get_block_ids_with_load_errors() - - if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = model_output -@@ -1564,6 +1581,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): - finished_recving=finished_recving, - finished_dumping=finished_dumping, - num_nans_in_logits=num_nans_in_logits, -+ invalid_block_ids = invalid_block_ids - ) - - def propose_draft_token_ids( -@@ -1694,13 +1712,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): - self.maybe_setup_kv_connector(scheduler_output) - finished_sending, finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) -+ invalid_block_ids = self.get_block_ids_with_load_errors() -+ get_kv_transfer_group().clear_connector_metadata() - -- if not finished_sending and not finished_recving: -+ if not finished_sending and not finished_recving and not invalid_block_ids: - return EMPTY_MODEL_RUNNER_OUTPUT - - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.finished_sending = finished_sending - output.finished_recving = finished_recving -+ output.invalid_block_ids = invalid_block_ids - return output - - @staticmethod -@@ -1733,6 +1754,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): - scheduler_output.finished_req_ids) - return None, None - -+ def get_block_ids_with_load_errors(self) -> Optional[set[int]]: -+ if has_kv_transfer_group(): -+ return get_kv_transfer_group().get_block_ids_with_load_errors() -+ return None -+ - def propose_ngram_draft_token_ids( - self, - sampled_token_ids: list[list[int]], -diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py -index 9e7e44d06..1b816b25b 100644 ---- a/vllm/v1/worker/gpu_worker.py -+++ b/vllm/v1/worker/gpu_worker.py -@@ -1,6 +1,7 @@ - # SPDX-License-Identifier: Apache-2.0 - # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - """A GPU worker class.""" -+import copy - import gc - import os - from typing import TYPE_CHECKING, Optional -@@ -15,7 +16,8 @@ from vllm.device_allocator.cumem import CuMemAllocator - from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment, - set_custom_all_reduce) --from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized -+from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, -+ has_kv_transfer_group) - from vllm.distributed.parallel_state import get_pp_group, get_tp_group - from vllm.logger import init_logger - from vllm.lora.request import LoRARequest -@@ -24,7 +26,7 @@ from vllm.platforms import current_platform - from vllm.sequence import IntermediateTensors - from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling - from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec --from vllm.v1.outputs import ModelRunnerOutput -+from vllm.v1.outputs 
import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput - from vllm.v1.utils import report_usage_stats - from vllm.v1.worker.gpu_model_runner import GPUModelRunner - from vllm.v1.worker.worker_base import WorkerBase -@@ -313,9 +315,22 @@ class Worker(WorkerBase): - assert isinstance(output, IntermediateTensors) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) -- return None -+ if not has_kv_transfer_group(): -+ return None -+ -+ # In case of PP with kv transfer, we need to pass through the -+ # finished_sending and finished_recving buffers. -+ new_output = EMPTY_MODEL_RUNNER_OUTPUT -+ if output.finished_sending or output.finished_recving or output.finished_dumping or output.invalid_block_ids: -+ new_output = copy.copy(new_output) -+ new_output.finished_sending = output.finished_sending -+ new_output.finished_recving = output.finished_recving -+ new_output.finished_dumping = output.finished_dumping -+ new_output.invalid_block_ids = output.invalid_block_ids -+ output = new_output -+ - assert isinstance(output, ModelRunnerOutput) -- return output if self.is_driver_worker else None -+ return output - - def profile(self, is_start: bool = True): - if self.profiler is None: --- -2.34.1 - diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-pc.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-pc.patch deleted file mode 100644 index bf0b7e19a..000000000 --- a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-pc.patch +++ /dev/null @@ -1,122 +0,0 @@ -From 26fdd2026cc3d1ed7da894907ae244a155a16566 Mon Sep 17 00:00:00 2001 -From: harrisonyhq -Date: Tue, 4 Nov 2025 19:36:36 -0800 -Subject: [PATCH 1/3] [Patch0] UCM PC adapt patch - ---- - .../kv_transfer/kv_connector/v1/multi_connector.py | 7 ++++++- - vllm/v1/core/sched/scheduler.py | 11 +++++++++++ - vllm/v1/outputs.py | 1 + - vllm/v1/request.py | 2 ++ - vllm/v1/worker/gpu_model_runner.py | 7 ++++--- - 5 files changed, 24 insertions(+), 4 deletions(-) - -diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py -index be3c23399..5f92d69bd 100644 ---- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py -+++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py -@@ -99,8 +99,13 @@ class MultiConnector(KVConnectorBase_V1): - c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs) - - def wait_for_save(self): -+ success_dumped_blocks = None - for c in self._connectors: -- c.wait_for_save() -+ uc_dump_blocks = c.wait_for_save() -+ if uc_dump_blocks: -+ success_dumped_blocks = uc_dump_blocks -+ -+ return success_dumped_blocks if success_dumped_blocks else None - - def get_finished( - self, finished_req_ids: set[str] -diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py -index fe552db74..cd80f92a1 100644 ---- a/vllm/v1/core/sched/scheduler.py -+++ b/vllm/v1/core/sched/scheduler.py -@@ -34,6 +34,7 @@ from vllm.v1.outputs import ModelRunnerOutput - from vllm.v1.request import Request, RequestStatus - from vllm.v1.spec_decode.metrics import SpecDecodingStats - from vllm.v1.structured_output import StructuredOutputManager -+from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import MultiConnector - - logger = init_logger(__name__) - -@@ -791,6 +792,16 @@ class Scheduler(SchedulerInterface): - new_logprobs = None - new_token_ids = generated_token_ids - kv_transfer_params = None -+ if model_runner_output.finished_dumping is not None: -+ 
request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, [])) -+ is_prefill = request.num_output_tokens == 0 -+ if is_prefill: -+ if isinstance(self.connector, MultiConnector): -+ for c in self.connector._connectors: -+ if hasattr(c, 'connector') and hasattr(c.connector, 'commit'): -+ c.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) -+ else: -+ self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) - - # Append generated tokens and check for stop. Note that if - # a request is still being prefilled, we expect the model runner -diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py -index f78623f57..c8388baed 100644 ---- a/vllm/v1/outputs.py -+++ b/vllm/v1/outputs.py -@@ -107,6 +107,7 @@ class ModelRunnerOutput: - # [req_ids] - finished_sending: Optional[set[str]] = None - finished_recving: Optional[set[str]] = None -+ finished_dumping: Optional[dict[str, list[str]]] = None - - # req_id -> num_nans_in_logits - num_nans_in_logits: Optional[dict[str, int]] = None -diff --git a/vllm/v1/request.py b/vllm/v1/request.py -index 9b96f4599..e70d1695b 100644 ---- a/vllm/v1/request.py -+++ b/vllm/v1/request.py -@@ -103,6 +103,8 @@ class Request: - # The number of tokens with prefix cache hits. - self.num_cached_tokens = -1 - -+ self.succeed_dumped_blocks: list[str] = [] -+ - # The number of NaNs in logits. A value greater than 0 - # indicates that the output is corrupted - self.num_nans_in_logits = 0 -diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py -index 5a26e88db..53ee8cfcd 100644 ---- a/vllm/v1/worker/gpu_model_runner.py -+++ b/vllm/v1/worker/gpu_model_runner.py -@@ -1378,7 +1378,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): - inputs_embeds=inputs_embeds, - ) - -- self.maybe_wait_for_kv_save() -+ finished_dumping = self.maybe_wait_for_kv_save() - finished_sending, finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) - -@@ -1562,6 +1562,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): - pooler_output=[], - finished_sending=finished_sending, - finished_recving=finished_recving, -+ finished_dumping=finished_dumping, - num_nans_in_logits=num_nans_in_logits, - ) - -@@ -1719,9 +1720,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): - kv_connector.start_load_kv(get_forward_context()) - - @staticmethod -- def maybe_wait_for_kv_save() -> None: -+ def maybe_wait_for_kv_save(): - if has_kv_transfer_group(): -- get_kv_transfer_group().wait_for_save() -+ return get_kv_transfer_group().wait_for_save() - - @staticmethod - def get_finished_kv_transfers( --- -2.34.1 - diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-rerope-adapt.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-rerope.patch similarity index 76% rename from ucm/integration/vllm/patch/0.9.2/vllm-rerope-adapt.patch rename to ucm/integration/vllm/patch/0.9.2/vllm-adapt-rerope.patch index b123b0d41..df38a29d9 100644 --- a/ucm/integration/vllm/patch/0.9.2/vllm-rerope-adapt.patch +++ b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-rerope.patch @@ -1,107 +1,201 @@ -From 364e665984b77c3a027a2c3f06456397ead47b1d Mon Sep 17 00:00:00 2001 +From 62c7e41ce88d1a1e7bd380cd1483526b4e909096 Mon Sep 17 00:00:00 2001 From: wangxin <1848802892@qq.com> -Date: Thu, 25 Dec 2025 23:34:33 -0800 -Subject: [PATCH] feature for triton rerope in vLLM +Date: Sun, 4 Jan 2026 04:46:10 -0800 +Subject: [PATCH] [feat] feature for triton rerope --- - vllm/attention/layer.py | 20 +- + vllm/attention/layer.py | 100 +- 
.../ops/triton_unified_attention_rerope.py | 863 ++++++++++++++++++ vllm/envs.py | 13 + - vllm/model_executor/models/qwen2.py | 26 +- + vllm/model_executor/models/qwen2.py | 30 +- vllm/model_executor/models/qwen3.py | 30 +- vllm/model_executor/models/qwen3_moe.py | 30 +- - vllm/v1/attention/backends/triton_attn.py | 98 +- + vllm/v1/attention/backends/triton_attn.py | 132 ++- vllm/v1/attention/backends/utils.py | 2 + - vllm/v1/kv_cache_interface.py | 9 +- - vllm/v1/worker/gpu_model_runner.py | 18 + - 10 files changed, 1078 insertions(+), 31 deletions(-) + vllm/v1/kv_cache_interface.py | 19 +- + vllm/v1/worker/gpu_model_runner.py | 20 +- + 10 files changed, 1165 insertions(+), 74 deletions(-) create mode 100644 vllm/attention/ops/triton_unified_attention_rerope.py diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py -index f0ad68b16..b73f9172f 100644 +index f0ad68b16..39dc4bf1d 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py -@@ -187,6 +187,8 @@ class Attention(nn.Module): - self, +@@ -188,6 +188,8 @@ class Attention(nn.Module): query: torch.Tensor, key: torch.Tensor, -+ query2: torch.Tensor, -+ key2: torch.Tensor, value: torch.Tensor, ++ query2: Optional[torch.Tensor] = None, ++ key2: Optional[torch.Tensor] = None, # For some alternate attention backends like MLA the attention output # shape does not match the query shape, so we optionally let the model + # definition specify the output tensor shape. @@ -224,6 +226,10 @@ class Attention(nn.Module): output = output.view(-1, self.num_heads, self.head_size) if key is not None: key = key.view(-1, self.num_kv_heads, self.head_size) -+ if query2 is not None: ++ if envs.VLLM_USE_REROPE and query2 is not None: + query2 = query2.view(-1, self.num_heads, self.head_size) -+ if key2 is not None: ++ if envs.VLLM_USE_REROPE and key2 is not None: + key2 = key2.view(-1, self.num_kv_heads, self.head_size) if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size) if self.use_direct_call: -@@ -235,13 +241,15 @@ class Attention(nn.Module): - self.impl.forward(self, - query, - key, -+ query2, -+ key2, - value, - self_kv_cache, - attn_metadata, - output=output) +@@ -232,16 +238,31 @@ class Attention(nn.Module): + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] +- self.impl.forward(self, +- query, +- key, +- value, +- self_kv_cache, +- attn_metadata, +- output=output) ++ if envs.VLLM_USE_REROPE: ++ self.impl.forward(self, ++ query, ++ key, ++ value, ++ self_kv_cache, ++ attn_metadata, ++ query2=query2, ++ key2=key2, ++ output=output) ++ else: ++ self.impl.forward(self, ++ query, ++ key, ++ value, ++ self_kv_cache, ++ attn_metadata, ++ output=output) else: - torch.ops.vllm.unified_attention_with_output( +- torch.ops.vllm.unified_attention_with_output( - query, key, value, output, self.layer_name) -+ query, key, query2, key2, value, output, self.layer_name) ++ if envs.VLLM_USE_REROPE: ++ torch.ops.vllm.unified_attention_with_output( ++ query, key, value, output, self.layer_name, query2=query2, key2=key2) ++ else: ++ torch.ops.vllm.unified_attention_with_output( ++ query, key, value, output, self.layer_name) return output.view(-1, hidden_size) else: if self.use_direct_call: -@@ -250,11 +258,11 @@ class Attention(nn.Module): +@@ -250,11 +271,19 @@ class Attention(nn.Module): if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] 
- return self.impl.forward(self, query, key, value, -+ return self.impl.forward(self, query, key, query2, key2, value, - self_kv_cache, attn_metadata) +- self_kv_cache, attn_metadata) ++ if envs.VLLM_USE_REROPE: ++ return self.impl.forward(self, query, key, value, ++ self_kv_cache, attn_metadata, query2=query2, key2=key2) ++ else: ++ return self.impl.forward(self, query, key, value, ++ self_kv_cache, attn_metadata) else: - return torch.ops.vllm.unified_attention( +- return torch.ops.vllm.unified_attention( - query, key, value, self.layer_name) -+ query, key, query2, key2, value, self.layer_name) ++ if envs.VLLM_USE_REROPE: ++ return torch.ops.vllm.unified_attention( ++ query, key, value, self.layer_name, query2=query2, key2=key2) ++ else: ++ return torch.ops.vllm.unified_attention( ++ query, key, value, self.layer_name) def calc_kv_scales(self, query, key, value): self._q_scale.copy_(torch.abs(query).max() / self.q_range) -@@ -437,6 +445,8 @@ direct_register_custom_op( - def unified_attention_with_output( - query: torch.Tensor, +@@ -400,6 +429,8 @@ def unified_attention( key: torch.Tensor, -+ query2: torch.Tensor, -+ key2: torch.Tensor, value: torch.Tensor, - output: torch.Tensor, layer_name: str, -@@ -452,6 +462,8 @@ def unified_attention_with_output( - self.impl.forward(self, - query, - key, -+ query2, -+ key2, - value, - kv_cache, - attn_metadata, -@@ -464,6 +476,8 @@ def unified_attention_with_output( - def unified_attention_with_output_fake( - query: torch.Tensor, ++ query2: Optional[torch.Tensor] = None, ++ key2: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + wait_for_kv_layer_from_connector(layer_name) + +@@ -409,8 +440,12 @@ def unified_attention( + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] +- output = self.impl.forward(self, query, key, value, kv_cache, +- attn_metadata) ++ if envs.VLLM_USE_REROPE: ++ output = self.impl.forward(self, query, key, value, kv_cache, ++ attn_metadata, query2=query2, key2=key2) ++ else: ++ output = self.impl.forward(self, query, key, value, kv_cache, ++ attn_metadata) + + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return output +@@ -421,6 +456,8 @@ def unified_attention_fake( key: torch.Tensor, -+ query2: torch.Tensor, -+ key2: torch.Tensor, + value: torch.Tensor, + layer_name: str, ++ query2: Optional[torch.Tensor] = None, ++ key2: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return torch.empty_like(query).contiguous() + +@@ -440,6 +477,8 @@ def unified_attention_with_output( value: torch.Tensor, output: torch.Tensor, layer_name: str, ++ query2: Optional[torch.Tensor] = None, ++ key2: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> None: + wait_for_kv_layer_from_connector(layer_name) +@@ -449,15 +488,26 @@ def unified_attention_with_output( + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] +- self.impl.forward(self, +- query, +- key, +- value, +- kv_cache, +- attn_metadata, +- output=output, +- output_scale=output_scale) +- ++ if envs.VLLM_USE_REROPE: ++ self.impl.forward(self, ++ query, ++ key, ++ value, ++ kv_cache, ++ attn_metadata, ++ query2=query2, ++ key2=key2, ++ output=output, ++ output_scale=output_scale) ++ else: ++ self.impl.forward(self, ++ query, ++ key, ++ value, ++ kv_cache, ++ attn_metadata, ++ output=output, ++ output_scale=output_scale) + 
maybe_save_kv_layer_to_connector(layer_name, kv_cache) + + +@@ -467,6 +517,8 @@ def unified_attention_with_output_fake( + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, ++ query2: Optional[torch.Tensor] = None, ++ key2: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> None: + return diff --git a/vllm/attention/ops/triton_unified_attention_rerope.py b/vllm/attention/ops/triton_unified_attention_rerope.py new file mode 100644 -index 000000000..4182abe21 +index 000000000..3028d2902 --- /dev/null +++ b/vllm/attention/ops/triton_unified_attention_rerope.py @@ -0,0 +1,863 @@ @@ -286,7 +380,7 @@ index 000000000..4182abe21 + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + offs_n[None, :] * stride_k_cache_1) -+ ++ + k2_offset = (physical_block_idx * stride_k_cache2_0 + + kv_head_idx * stride_k_cache2_2 + + offs_d[:, None] * stride_k_cache2_3 + @@ -556,7 +650,7 @@ index 000000000..4182abe21 + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + offs_n[None, :] * stride_k_cache_1) -+ ++ + k2_offset = (physical_block_idx * stride_k_cache2_0 + + kv_head_idx * stride_k_cache2_2 + + offs_d[:, None] * stride_k_cache2_3 + @@ -566,7 +660,7 @@ index 000000000..4182abe21 + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None], + other=0.0) -+ ++ + K2_load = tl.load(key_cache2_ptr + k2_offset, + mask=dim_mask[:, None], + other=0.0) @@ -616,7 +710,7 @@ index 000000000..4182abe21 + + query_pos_rerope = context_len + query_pos[:, None] + 1 + key_pos_rerope = seq_offset[None, :] -+ ++ + valid_query_mask = query_pos[:, None] < cur_batch_query_len + pos_diff = tl.abs(query_pos_rerope - key_pos_rerope) + rerope_mask = pos_diff < REROPE_WINDOW @@ -969,7 +1063,7 @@ index 000000000..4182abe21 + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + ) diff --git a/vllm/envs.py b/vllm/envs.py -index 0cc6792d7..767138d37 100644 +index 0cc6792d7..1b049c2c5 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -83,6 +83,9 @@ if TYPE_CHECKING: @@ -986,7 +1080,7 @@ index 0cc6792d7..767138d37 100644 "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), -+ # add REROPE ++ # add REROPE + "VLLM_USE_REROPE": + lambda: str(os.getenv("VLLM_USE_REROPE", "0")).lower() in {"1", "true", "yes", "on"}, + @@ -1000,7 +1094,7 @@ index 0cc6792d7..767138d37 100644 # Acts as a parent switch to enable the rest of the other operations. 
"VLLM_ROCM_USE_AITER": diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py -index 7ef9d248d..c74ec434b 100644 +index 7ef9d248d..2d75195eb 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -57,6 +57,10 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, @@ -1014,39 +1108,41 @@ index 7ef9d248d..c74ec434b 100644 class Qwen2MLP(nn.Module): -@@ -178,10 +182,28 @@ class Qwen2Attention(nn.Module): - positions: torch.Tensor, - hidden_states: torch.Tensor, +@@ -180,8 +184,30 @@ class Qwen2Attention(nn.Module): ) -> torch.Tensor: -+ attn_metadata = get_forward_context().attn_metadata -+ REROPE_WINDOW = envs.REROPE_WINDOW -+ TRAINING_LENGTH = envs.TRAINING_LENGTH -+ qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) + -+ if attn_metadata and next(iter(attn_metadata.values())).use_rerope: -+ q *= ((positions + 1)[:, None].log() / math.log(TRAINING_LENGTH)).clip(1).to(q.dtype) -+ q2 = q.clone() -+ k2 = k.clone() -+ k0 = k.clone() ++ if envs.VLLM_USE_REROPE: ++ attn_metadata = get_forward_context().attn_metadata ++ REROPE_WINDOW = envs.REROPE_WINDOW ++ TRAINING_LENGTH = envs.TRAINING_LENGTH ++ if attn_metadata and next(iter(attn_metadata.values())).use_rerope: ++ q *= ((positions + 1)[:, None].log() / math.log(TRAINING_LENGTH)).clip(1).to(q.dtype) ++ q2 = q.clone() ++ k2 = k.clone() ++ k0 = k.clone() ++ ++ q, k = self.rotary_emb(positions, q, k) ++ q2, _ = self.rotary_emb(positions * 0 + REROPE_WINDOW, q2, k2) ++ del k2 ++ else: ++ k0 = k ++ q, k = self.rotary_emb(positions, q, k) ++ q2 = q + -+ q, k = self.rotary_emb(positions, q, k) -+ q2, _ = self.rotary_emb(positions * 0 + REROPE_WINDOW, q2, k2) -+ del k2 ++ attn_output = self.attn(q, k, v, query2=q2, key2=k0) + else: -+ k0 = k.clone() + q, k = self.rotary_emb(positions, q, k) -+ q2 = q.clone() -+ -+ attn_output = self.attn(q, k, q2, k0, v) ++ attn_output = self.attn(q, k, v) ++ output, _ = self.o_proj(attn_output) return output diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py -index de99a76f2..07f77d1ef 100644 +index de99a76f2..03904a054 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -50,6 +50,10 @@ from .qwen2 import Qwen2MLP as Qwen3MLP @@ -1060,48 +1156,41 @@ index de99a76f2..07f77d1ef 100644 logger = init_logger(__name__) -@@ -131,6 +135,10 @@ class Qwen3Attention(nn.Module): - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> torch.Tensor: -+ attn_metadata = get_forward_context().attn_metadata -+ REROPE_WINDOW = envs.REROPE_WINDOW -+ TRAINING_LENGTH = envs.TRAINING_LENGTH -+ - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - # Add qk-norm -@@ -142,8 +150,26 @@ class Qwen3Attention(nn.Module): +@@ -142,8 +146,30 @@ class Qwen3Attention(nn.Module): self.head_dim) k_by_head = self.k_norm(k_by_head) k = k_by_head.view(k.shape) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) -+ -+ if attn_metadata and next(iter(attn_metadata.values())).use_rerope: -+ q *= ( -+ ((positions + 1)[:, None].log() / math.log(TRAINING_LENGTH)) -+ .clip(1) -+ .to(q.dtype) -+ ) -+ q2 = q.clone() -+ k2 = k.clone() -+ k0 = k.clone() + -+ q, k = self.rotary_emb(positions, q, k) -+ q2, _ = self.rotary_emb(positions * 0 + REROPE_WINDOW, q2, k2) -+ del k2 
++ if envs.VLLM_USE_REROPE: ++ attn_metadata = get_forward_context().attn_metadata ++ REROPE_WINDOW = envs.REROPE_WINDOW ++ TRAINING_LENGTH = envs.TRAINING_LENGTH ++ if attn_metadata and next(iter(attn_metadata.values())).use_rerope: ++ q *= ((positions + 1)[:, None].log() / math.log(TRAINING_LENGTH)).clip(1).to(q.dtype) ++ q2 = q.clone() ++ k2 = k.clone() ++ k0 = k.clone() ++ ++ q, k = self.rotary_emb(positions, q, k) ++ q2, _ = self.rotary_emb(positions * 0 + REROPE_WINDOW, q2, k2) ++ del k2 ++ else: ++ k0 = k ++ q, k = self.rotary_emb(positions, q, k) ++ q2 = q ++ ++ attn_output = self.attn(q, k, v, query2=q2, key2=k0) + else: -+ k0 = k.clone() + q, k = self.rotary_emb(positions, q, k) -+ q2 = q.clone() -+ -+ attn_output = self.attn(q, k, q2, k0, v) ++ attn_output = self.attn(q, k, v) ++ output, _ = self.o_proj(attn_output) return output diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py -index ff182aadf..6e5132d10 100644 +index ff182aadf..f7a787447 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -56,6 +56,10 @@ from .utils import (AutoWeightsLoader, extract_layer_index, @@ -1115,48 +1204,41 @@ index ff182aadf..6e5132d10 100644 logger = init_logger(__name__) -@@ -220,6 +224,10 @@ class Qwen3MoeAttention(nn.Module): - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> torch.Tensor: -+ attn_metadata = get_forward_context().attn_metadata -+ REROPE_WINDOW = envs.REROPE_WINDOW -+ TRAINING_LENGTH = envs.TRAINING_LENGTH -+ - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - # Add qk-norm -@@ -232,8 +240,26 @@ class Qwen3MoeAttention(nn.Module): +@@ -232,8 +236,30 @@ class Qwen3MoeAttention(nn.Module): self.head_dim) k_by_head = self.k_norm(k_by_head) k = k_by_head.view(k.shape) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) + -+ if attn_metadata and next(iter(attn_metadata.values())).use_rerope: -+ q *= ( -+ ((positions + 1)[:, None].log() / math.log(TRAINING_LENGTH)) -+ .clip(1) -+ .to(q.dtype) -+ ) -+ q2 = q.clone() -+ k2 = k.clone() -+ k0 = k.clone() ++ if envs.VLLM_USE_REROPE: ++ attn_metadata = get_forward_context().attn_metadata ++ REROPE_WINDOW = envs.REROPE_WINDOW ++ TRAINING_LENGTH = envs.TRAINING_LENGTH ++ if attn_metadata and next(iter(attn_metadata.values())).use_rerope: ++ q *= ((positions + 1)[:, None].log() / math.log(TRAINING_LENGTH)).clip(1).to(q.dtype) ++ q2 = q.clone() ++ k2 = k.clone() ++ k0 = k.clone() ++ ++ q, k = self.rotary_emb(positions, q, k) ++ q2, _ = self.rotary_emb(positions * 0 + REROPE_WINDOW, q2, k2) ++ del k2 ++ else: ++ k0 = k ++ q, k = self.rotary_emb(positions, q, k) ++ q2 = q + -+ q, k = self.rotary_emb(positions, q, k) -+ q2, _ = self.rotary_emb(positions * 0 + REROPE_WINDOW, q2, k2) -+ del k2 ++ attn_output = self.attn(q, k, v, query2=q2, key2=k0) + else: -+ k0 = k.clone() + q, k = self.rotary_emb(positions, q, k) -+ q2 = q.clone() -+ -+ attn_output = self.attn(q, k, q2, k0, v) ++ attn_output = self.attn(q, k, v) ++ output, _ = self.o_proj(attn_output) return output diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py -index cdaff2f6a..5b8202896 100644 +index cdaff2f6a..38a6aa509 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -23,6 +23,8 @@ from vllm.v1.attention.backends.utils import ( @@ -1203,15 +1285,15 @@ index cdaff2f6a..5b8202896 100644 return 
(2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod -@@ -296,6 +305,8 @@ class TritonAttentionImpl(AttentionImpl): - layer: torch.nn.Module, - query: torch.Tensor, - key: torch.Tensor, -+ query2: torch.Tensor, -+ key2: torch.Tensor, +@@ -299,6 +308,8 @@ class TritonAttentionImpl(AttentionImpl): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, ++ query2: Optional[torch.Tensor] = None, ++ key2: Optional[torch.Tensor] = None, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: @@ -342,7 +353,10 @@ class TritonAttentionImpl(AttentionImpl): key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) @@ -1224,11 +1306,11 @@ index cdaff2f6a..5b8202896 100644 if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. -@@ -369,9 +383,22 @@ class TritonAttentionImpl(AttentionImpl): - layer._k_scale, +@@ -370,8 +384,22 @@ class TritonAttentionImpl(AttentionImpl): layer._v_scale, ) -+ if envs.VLLM_USE_REROPE: + ++ if envs.VLLM_USE_REROPE and key2 is not None: + torch.ops._C_cache_ops.reshape_and_cache_flash( + key2, + value, @@ -1239,31 +1321,71 @@ index cdaff2f6a..5b8202896 100644 + layer._k_scale, + layer._v_scale, + ) - ++ if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(self.fp8_dtype) -+ if key_cache2 is not None: ++ if envs.VLLM_USE_REROPE and key_cache2 is not None: + key_cache2 = key_cache2.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) num_tokens, num_heads, head_size = query.shape assert layer._q_scale == 1.0, \ -@@ -384,6 +411,11 @@ class TritonAttentionImpl(AttentionImpl): +@@ -384,6 +412,12 @@ class TritonAttentionImpl(AttentionImpl): (num_tokens, num_heads * head_size)).contiguous(), layer._q_scale) query = query.reshape((num_tokens, num_heads, head_size)) -+ query2, _ = ops.scaled_fp8_quant( -+ query2.reshape( -+ (num_tokens, num_heads * head_size)).contiguous(), -+ layer._q_scale) -+ query2 = query2.reshape((num_tokens, num_heads, head_size)) ++ if envs.VLLM_USE_REROPE and query2 is not None: ++ query2, _ = ops.scaled_fp8_quant( ++ query2.reshape( ++ (num_tokens, num_heads * head_size)).contiguous(), ++ layer._q_scale) ++ query2 = query2.reshape((num_tokens, num_heads, head_size)) use_local_attn = \ (self.use_irope and attn_metadata.local_attn_metadata is not None) -@@ -425,25 +457,49 @@ class TritonAttentionImpl(AttentionImpl): +@@ -403,47 +437,71 @@ class TritonAttentionImpl(AttentionImpl): + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table ++ + if use_prefill_decode_attn: + # Compute attention and update output up to `num_actual_tokens`. 
+ chunked_prefill_paged_decode(query=query[:num_actual_tokens], +- key=key[:num_actual_tokens], +- value=value[:num_actual_tokens], +- output=output[:num_actual_tokens], +- kv_cache_dtype=self.kv_cache_dtype, +- key_cache=key_cache, +- value_cache=value_cache, +- block_table=block_table, +- query_start_loc=cu_seqlens_q, +- seq_lens=seqused_k, +- max_seq_len=max_seqlen_k, +- max_query_len=max_seqlen_q, +- k_scale=layer._k_scale, +- v_scale=layer._v_scale, +- alibi_slopes=self.alibi_slopes, +- sliding_window=self.sliding_window[0], +- sm_scale=self.scale) +- ++ key=key[:num_actual_tokens], ++ value=value[:num_actual_tokens], ++ output=output[:num_actual_tokens], ++ kv_cache_dtype=self.kv_cache_dtype, ++ key_cache=key_cache, ++ value_cache=value_cache, ++ block_table=block_table, ++ query_start_loc=cu_seqlens_q, ++ seq_lens=seqused_k, ++ max_seq_len=max_seqlen_k, ++ max_query_len=max_seqlen_q, ++ k_scale=layer._k_scale, ++ v_scale=layer._v_scale, ++ alibi_slopes=self.alibi_slopes, ++ sliding_window=self.sliding_window[0], ++ sm_scale=self.scale) else: descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) -- + - unified_attention( - q=query[:num_actual_tokens], - k=key_cache, @@ -1283,11 +1405,12 @@ index cdaff2f6a..5b8202896 100644 - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) -+ -+ if not attn_metadata.use_rerope: -+ unified_attention( ++ if attn_metadata.use_rerope: ++ unified_attention_rerope( + q=query[:num_actual_tokens], + k=key_cache, ++ q2=query2[:num_actual_tokens], ++ k2=key_cache2, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, @@ -1296,6 +1419,7 @@ index cdaff2f6a..5b8202896 100644 + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=True, ++ rerope_window=envs.REROPE_WINDOW, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, @@ -1305,11 +1429,9 @@ index cdaff2f6a..5b8202896 100644 + v_descale=layer._v_scale.expand(descale_shape), + ) + else: -+ unified_attention_rerope( ++ unified_attention( + q=query[:num_actual_tokens], + k=key_cache, -+ q2=query2[:num_actual_tokens], -+ k2=key_cache2, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, @@ -1318,7 +1440,6 @@ index cdaff2f6a..5b8202896 100644 + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=True, -+ rerope_window=envs.REROPE_WINDOW, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, @@ -1343,7 +1464,7 @@ index b0ebb00d9..190a3f4ec 100644 M = TypeVar("M") diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py -index 43456a987..a0c3b87d7 100644 +index 43456a987..20edd1f86 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -13,6 +13,8 @@ from vllm.config import VllmConfig @@ -1369,8 +1490,32 @@ index 43456a987..a0c3b87d7 100644 return coef * self.block_size * self.num_kv_heads * self.head_size \ * get_dtype_size(self.dtype) +@@ -88,10 +95,10 @@ class AttentionSpec(KVCacheSpec): + class FullAttentionSpec(AttentionSpec): + sliding_window: Optional[int] = None + """ +- When hybrid allocator is disabled and the model contains both full +- attention layers and sliding window attention layers, sliding +- window attention are regarded as full attention in KV cache manager +- (blocks are allocated for all tokens), while computed as sliding window ++ When hybrid allocator is disabled and the model contains both full ++ attention layers and sliding 
window attention layers, sliding ++ window attention are regarded as full attention in KV cache manager ++ (blocks are allocated for all tokens), while computed as sliding window + attention in model runner. + In this case, we use FullAttentionSpec and record the sliding window size. + Default to None for not using sliding window attention. +@@ -108,7 +115,7 @@ class FullAttentionSpec(AttentionSpec): + @classmethod + def merge(cls, specs: list[Self]) -> Self: + """ +- Merge a list of FullAttentionSpec objects into a single ++ Merge a list of FullAttentionSpec objects into a single + FullAttentionSpec object. + """ + merged_spec = super().merge(specs) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py -index 5a26e88db..12452777d 100644 +index 5a26e88db..f61a38550 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -72,6 +72,8 @@ from ..sample.logits_processor import LogitsProcessorManager @@ -1419,6 +1564,15 @@ index 5a26e88db..12452777d 100644 ) attn_metadata: dict[str, Any] = {} +@@ -1943,7 +1961,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): + Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set. + This is to help balance expert-selection + - during profile_run +- - during DP rank dummy run ++ - during DP rank dummy run + """ + dp_size = self.vllm_config.parallel_config.data_parallel_size + randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1 -- 2.34.1 diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch index eb9848756..705872d78 100644 --- a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch +++ b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch @@ -1,34 +1,37 @@ -From 0431022b90649f7115b89b61aaf2a0f83e896d5a Mon Sep 17 00:00:00 2001 +From 8cb493f9ece884cbc2ba71e367bed2b4116ae1b3 Mon Sep 17 00:00:00 2001 From: wenxinwang -Date: Mon, 10 Nov 2025 20:35:47 +0800 -Subject: [PATCH] adapt to deepseek patch +Date: Tue, 23 Dec 2025 19:44:21 -0800 +Subject: [PATCH] kvcomp qwen deepseek --- - vllm/attention/layer.py | 49 ++++++++++++- - .../kv_transfer/kv_connector/utils.py | 5 ++ - .../v1/shared_storage_connector.py | 7 +- - vllm/v1/attention/backends/mla/common.py | 10 ++- - vllm/v1/core/kv_cache_manager.py | 7 +- - vllm/v1/core/sched/output.py | 3 + - vllm/v1/core/sched/scheduler.py | 37 +++++++--- - vllm/v1/worker/block_table.py | 13 ++++ - vllm/v1/worker/gpu_model_runner.py | 71 +++++++++++++++---- - vllm/v1/worker/gpu_worker.py | 2 + - 10 files changed, 171 insertions(+), 33 deletions(-) + vllm/attention/layer.py | 63 ++++++++++++++++- + vllm/model_executor/models/llama.py | 21 +++++- + vllm/model_executor/models/qwen2.py | 23 ++++++- + vllm/v1/attention/backends/flash_attn.py | 7 ++ + vllm/v1/attention/backends/mla/common.py | 15 +++- + vllm/v1/attention/backends/mla/flashmla.py | 18 ++++- + vllm/v1/core/kv_cache_manager.py | 7 +- + vllm/v1/core/kv_cache_utils.py | 13 ++++ + vllm/v1/core/sched/output.py | 3 + + vllm/v1/core/sched/scheduler.py | 30 +++++++- + vllm/v1/worker/block_table.py | 13 ++++ + vllm/v1/worker/gpu_model_runner.py | 80 +++++++++++++++++++--- + vllm/v1/worker/gpu_worker.py | 2 + + 13 files changed, 275 insertions(+), 20 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py -index f0ad68b16..728ab99fd 100644 +index f0ad68b16..ba93960de 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py -@@ -2,7 +2,6 @@ - # SPDX-FileCopyrightText: Copyright 
contributors to the vLLM project - """Attention layer.""" - from typing import Any, Dict, List, Optional -- - import torch - import torch.nn as nn +@@ -8,6 +8,7 @@ import torch.nn as nn import torch.nn.functional as F -@@ -22,6 +21,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod + + import vllm.envs as envs ++import os + from vllm.attention import AttentionType + from vllm.attention.selector import backend_name_to_enum, get_attn_backend + from vllm.config import CacheConfig, get_current_vllm_config +@@ -22,6 +23,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.platforms import _Backend, current_platform from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.utils import validate_kv_sharing_target @@ -36,11 +39,11 @@ index f0ad68b16..728ab99fd 100644 class Attention(nn.Module): -@@ -409,9 +409,10 @@ def unified_attention( +@@ -409,9 +411,10 @@ def unified_attention( attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] -+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) ++ query, key, value, _ = maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata) - @@ -48,26 +51,34 @@ index f0ad68b16..728ab99fd 100644 maybe_save_kv_layer_to_connector(layer_name, kv_cache) return output -@@ -449,6 +450,8 @@ def unified_attention_with_output( +@@ -449,6 +452,15 @@ def unified_attention_with_output( attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] + if not self.use_mla: -+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) ++ if attn_metadata is not None: ++ if os.getenv("VLLM_HASH_ATTENTION") == "1": ++ kv_cache, k_hash = kv_cache ++ else: ++ k_hash = None ++ query, _, _, _ = maybe_execute_sparse_attention_begin( ++ query, key, value, layer_name, forward_context, output, k_hash=k_hash ++ ) self.impl.forward(self, query, key, -@@ -457,7 +460,8 @@ def unified_attention_with_output( +@@ -457,6 +469,10 @@ def unified_attention_with_output( attn_metadata, output=output, output_scale=output_scale) -- + if not self.use_mla: -+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) - maybe_save_kv_layer_to_connector(layer_name, kv_cache) ++ maybe_execute_sparse_attention_finished( ++ query, key, value, output, layer_name, forward_context ++ ) + maybe_save_kv_layer_to_connector(layer_name, kv_cache) -@@ -479,3 +483,42 @@ direct_register_custom_op( +@@ -479,3 +495,48 @@ direct_register_custom_op( fake_impl=unified_attention_with_output_fake, dispatch_key=current_platform.dispatch_key, ) @@ -78,18 +89,24 @@ index f0ad68b16..728ab99fd 100644 + value: torch.Tensor, + layer_name: str, + forward_context: ForwardContext, ++ output: Optional[torch.Tensor] = None, + phase: Optional[str] = None, ++ k_hash: Optional[torch.Tensor] = None, ++ decode_ql_nope: Optional[torch.Tensor] = None, ++ decode_q_pe: Optional[torch.Tensor] = None, +): + if not has_ucm_sparse(): -+ return ++ return query, key, value, output + + ucm_sparse = get_ucm_sparse() + + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: -+ return ++ return query, key, value, output + -+ ucm_sparse.attention_begin(query, key, 
value, layer_name, forward_context, phase) ++ return ucm_sparse.attention_begin( ++ query, key, value, layer_name, forward_context, output, phase, k_hash, decode_ql_nope, decode_q_pe ++ ) + +def maybe_execute_sparse_attention_finished( + query: torch.Tensor, @@ -110,49 +127,143 @@ index f0ad68b16..728ab99fd 100644 + return + + ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context, phase) -diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py -index b63bf5965..155597c51 100644 ---- a/vllm/distributed/kv_transfer/kv_connector/utils.py -+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py -@@ -3,6 +3,11 @@ - """ - KV cache helper for store. - """ -+from collections import defaultdict -+from collections.abc import Sequence -+from concurrent.futures import CancelledError, Future -+from typing import Optional, cast -+ - import torch - - from collections import defaultdict -diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py -index 3c574d065..223106def 100644 ---- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py -+++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py -@@ -2,7 +2,7 @@ - # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import hashlib - import os --from dataclasses import dataclass -+from dataclasses import dataclass, field - from typing import TYPE_CHECKING - - import safetensors -@@ -53,10 +53,7 @@ class ReqMeta: - - @dataclass - class SharedStorageConnectorMetadata(KVConnectorMetadata): -- requests: list[ReqMeta] +diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py +index 5d5080479..39cb2f4fb 100644 +--- a/vllm/model_executor/models/llama.py ++++ b/vllm/model_executor/models/llama.py +@@ -54,7 +54,12 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) - -- def __init__(self): -- self.requests = [] -+ requests: list[ReqMeta] = field(default_factory=list) - - def add_request( - self, ++from ucm.sparse.state import ( ++ maybe_execute_sparse_ffn_begin, ++ maybe_execute_sparse_ffn_finished, ++ maybe_execute_sparse_layer_begin, ++ maybe_execute_sparse_layer_finished, ++ ) + + class LlamaMLP(nn.Module): + +@@ -305,10 +310,16 @@ class LlamaDecoderLayer(nn.Module): + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + ++ hidden_states, residual = maybe_execute_sparse_ffn_begin( ++ hidden_states, residual ++ ) + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) ++ hidden_states, residual = maybe_execute_sparse_ffn_finished( ++ hidden_states, residual ++ ) + return hidden_states, residual + + +@@ -387,9 +398,17 @@ class LlamaModel(nn.Module): + aux_hidden_states = [] + for idx, layer in enumerate( + self.layers[self.start_layer:self.end_layer]): ++ positions, hidden_states, residual = maybe_execute_sparse_layer_begin( ++ positions, hidden_states, residual ++ ) + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append(hidden_states + residual) + hidden_states, residual = layer(positions, hidden_states, residual) ++ positions, hidden_states, residual = ( ++ maybe_execute_sparse_layer_finished( ++ positions, 
hidden_states, residual ++ ) ++ ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ +diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py +index 7ef9d248d..e35ab2fdc 100644 +--- a/vllm/model_executor/models/qwen2.py ++++ b/vllm/model_executor/models/qwen2.py +@@ -56,6 +56,12 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) ++from ucm.sparse.state import ( ++ maybe_execute_sparse_ffn_begin, ++ maybe_execute_sparse_ffn_finished, ++ maybe_execute_sparse_layer_begin, ++ maybe_execute_sparse_layer_finished, ++ ) + + + class Qwen2MLP(nn.Module): +@@ -255,11 +261,16 @@ class Qwen2DecoderLayer(nn.Module): + positions=positions, + hidden_states=hidden_states, + ) +- ++ residual, hidden_states = maybe_execute_sparse_ffn_begin( ++ residual, hidden_states ++ ) + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) ++ residual, hidden_states = maybe_execute_sparse_ffn_finished( ++ residual, hidden_states ++ ) + return hidden_states, residual + + +@@ -352,11 +363,21 @@ class Qwen2Model(nn.Module): + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: ++ positions, hidden_states, residual = maybe_execute_sparse_layer_begin( ++ positions, ++ hidden_states, ++ residual, ++ ) + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) ++ positions, hidden_states, residual = ( ++ maybe_execute_sparse_layer_finished( ++ positions, hidden_states, residual ++ ) ++ ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, +diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py +index fbc13c06c..2b2244949 100755 +--- a/vllm/v1/attention/backends/flash_attn.py ++++ b/vllm/v1/attention/backends/flash_attn.py +@@ -16,6 +16,8 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states + from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, + get_flash_attn_version, + is_flash_attn_varlen_func_available) ++from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse ++import os + + if is_flash_attn_varlen_func_available(): + from vllm.attention.utils.fa_utils import (flash_attn_varlen_func, +@@ -221,6 +223,11 @@ class FlashAttentionMetadataBuilder( + block_table = self.block_table + block_table_tensor = block_table.get_device_tensor()[:num_reqs] + ++ if has_ucm_sparse(): ++ ucm_sparse = get_ucm_sparse() ++ if os.getenv("VLLM_HASH_ATTENTION") == "1": ++ decode_mask, topk_seq_lens = ucm_sparse.build_decode_attention_meta(query_start_loc, seq_lens, block_table_tensor) ++ + block_table.slot_mapping[:num_actual_tokens].copy_( + block_table.slot_mapping_cpu[:num_actual_tokens], + non_blocking=True) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py -index f2aaf59a4..b56f62b39 100644 +index f2aaf59a4..439bb9b14 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -200,6 +200,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, @@ -163,15 +274,16 @@ index f2aaf59a4..b56f62b39 100644 from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.logger import init_logger from 
vllm.model_executor.layers.linear import (ColumnParallelLinear, -@@ -211,6 +212,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, +@@ -211,6 +212,8 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable +from vllm.attention.layer import (maybe_execute_sparse_attention_begin, maybe_execute_sparse_attention_finished) ++import os try: from vllm.vllm_flash_attn import flash_attn_varlen_func -@@ -908,7 +910,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): +@@ -908,7 +911,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -180,11 +292,22 @@ index f2aaf59a4..b56f62b39 100644 assert output is not None, "Output tensor must be provided." if output_scale is not None: -@@ -957,10 +959,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): +@@ -945,6 +948,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): + prefill_k_pe = k_pe[num_decode_tokens:] + prefill_k_c_normed = k_c_normed[num_decode_tokens:] + ++ if os.getenv("VLLM_HASH_ATTENTION") == "1": ++ kv_cache, k_hash = kv_cache ++ else: ++ k_hash = None + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( +@@ -957,10 +964,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ) if has_prefill: -+ maybe_execute_sparse_attention_begin(prefill_q, prefill_k_c_normed, prefill_k_pe, layer.layer_name, forward_context, "prefill") ++ prefill_q, _, _, _ = maybe_execute_sparse_attention_begin(prefill_q, k_c_normed, k_pe, layer.layer_name, forward_context, output=output, phase="prefill", k_hash=k_hash) output[num_decode_tokens:] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, attn_metadata) @@ -193,17 +316,80 @@ index f2aaf59a4..b56f62b39 100644 if has_decode: assert attn_metadata.decode is not None decode_q_nope, decode_q_pe = decode_q.split( -@@ -971,8 +974,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): +@@ -971,8 +979,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) # Convert from (N, B, L) to (B, N, L) decode_ql_nope = decode_ql_nope.transpose(0, 1) -- -+ maybe_execute_sparse_attention_begin(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, layer.layer_name, forward_context, "decode") ++ _, _, _, _ = maybe_execute_sparse_attention_begin(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), k_c_normed, k_pe, layer.layer_name, forward_context, output=output, phase="decode", k_hash=k_hash, decode_ql_nope=decode_ql_nope, decode_q_pe=decode_q_pe) + output[:num_decode_tokens] = self._forward_decode( decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) +- + maybe_execute_sparse_attention_finished(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, output[:num_decode_tokens], layer.layer_name, forward_context, "decode") - return output_padded +diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py +index be26e0060..4d74e9d5b 100644 +--- a/vllm/v1/attention/backends/mla/flashmla.py ++++ b/vllm/v1/attention/backends/mla/flashmla.py +@@ -5,7 +5,7 @@ from dataclasses import dataclass + from typing import Any, ClassVar, Optional + + import torch +- ++import os + from vllm.attention.backends.abstract import 
(AttentionType, + is_quantized_kv_cache) + from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, +@@ -19,6 +19,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonMetadataBuilder) + from vllm.v1.kv_cache_interface import AttentionSpec + from vllm.v1.worker.block_table import BlockTable ++from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse + + logger = init_logger(__name__) + +@@ -46,6 +47,10 @@ class FlashMLABackend(MLACommonBackend): + class FlashMLADecodeMetadata(MLACommonDecodeMetadata): + tile_scheduler_metadata: torch.Tensor + num_splits: torch.Tensor ++ topk_seq_lens: torch.Tensor ++ topk_tile_scheduler_metadata: torch.Tensor ++ topk_num_splits: torch.Tensor ++ topk_block_table: torch.Tensor = None + + + @dataclass +@@ -74,6 +79,13 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): + self.num_q_heads, + 1, # MQA for the decode path + ) ++ topk_seq_lens = None ++ topk_tile_scheduler_metadata = None ++ topk_num_splits = None ++ if has_ucm_sparse(): ++ ucm_sparse = get_ucm_sparse() ++ if os.getenv("VLLM_HASH_ATTENTION") == "1": ++ topk_seq_lens, topk_tile_scheduler_metadata, topk_num_splits = ucm_sparse.build_decode_hash(seq_lens) + + if self.runner.full_cuda_graph: + # First time around (CUDAGraph capture), allocate the static buffer +@@ -98,12 +110,16 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): + num_splits_view.copy_(num_splits) + self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s + num_splits = num_splits_view ++ topk_tile_scheduler_metadata, topk_num_splits = ucm_sparse.maybe_init_cudagraph_buffers_for_topk(n, tile_scheduler_metadata) + + return FlashMLADecodeMetadata( + block_table=block_table_tensor, + seq_lens=seq_lens, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, ++ topk_seq_lens=topk_seq_lens, ++ topk_tile_scheduler_metadata=topk_tile_scheduler_metadata, ++ topk_num_splits=topk_num_splits, + ) + + diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 6937455e7..bf9aec864 100644 --- a/vllm/v1/core/kv_cache_manager.py @@ -243,11 +429,35 @@ index 6937455e7..bf9aec864 100644 if new_computed_blocks is not None: new_computed_block_list = new_computed_blocks.blocks +diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py +index 2fbcb569e..40c199563 100644 +--- a/vllm/v1/core/kv_cache_utils.py ++++ b/vllm/v1/core/kv_cache_utils.py +@@ -693,6 +693,19 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, + num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec), + available_memory, page_size) + ++ if os.getenv("VLLM_HASH_ATTENTION") == "1": ++ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE ++ ++ if vllm_config.cache_config.cache_dtype == 'auto': ++ dtype = vllm_config.model_config.dtype ++ else: ++ dtype = STR_DTYPE_TO_TORCH_DTYPE[vllm_config.cache_config.cache_dtype] ++ khash_scale = dtype.itemsize * 8 ++ new_num_blocks = num_blocks * khash_scale // (khash_scale + 1) ++ logger.info("[HASH_ATTN] reduce num_blocks from %d to %d to allocate khash_cache", ++ num_blocks, new_num_blocks) ++ num_blocks = new_num_blocks ++ + per_layer_size = page_size * num_blocks + # All layers have the same KV cache spec, so we create one kv cache group + # for all layers. 
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py -index c94e421c0..fff0eeb42 100644 +index d34f39327..141d750b3 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py -@@ -157,3 +157,6 @@ class SchedulerOutput: +@@ -155,3 +155,6 @@ class SchedulerOutput: # KV Cache Connector metadata. kv_connector_metadata: Optional[KVConnectorMetadata] = None @@ -255,19 +465,21 @@ index c94e421c0..fff0eeb42 100644 + # modified slots by sparse algorithm + req_sparsed_slots: dict[str, int] = None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py -index 2d4fd4d59..e99a51788 100644 +index fe552db74..0d8a67eba 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py -@@ -35,6 +35,8 @@ from vllm.v1.request import Request, RequestStatus +@@ -34,6 +34,10 @@ from vllm.v1.outputs import ModelRunnerOutput + from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.structured_output import StructuredOutputManager - from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import MultiConnector ++from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import MultiConnector +from ucm.sparse.state import ensure_ucm_sparse_initialized, get_ucm_sparse, has_ucm_sparse +from ucm.sparse.base import UcmSparseBase, UcmSparseRole, INVALID_SLOT ++from ucm.utils import Config logger = init_logger(__name__) -@@ -80,12 +82,18 @@ class Scheduler(SchedulerInterface): +@@ -79,12 +83,20 @@ class Scheduler(SchedulerInterface): # will have a corresponding KVConnector with Role=WORKER. # KV Connector pushes/pull of remote KVs for P/D and offloading. self.connector = None @@ -279,14 +491,16 @@ index 2d4fd4d59..e99a51788 100644 self.connector = KVConnectorFactory.create_connector_v1( config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + # Initialize UCM Sparse if available -+ if "ucm_sparse_config" in vllm_config.kv_transfer_config.kv_connector_extra_config: ++ ucm_config = Config(self.vllm_config.kv_transfer_config) ++ ucm_sparse_config = ucm_config.get_config().get("ucm_sparse_config") ++ if ucm_sparse_config: + ensure_ucm_sparse_initialized(vllm_config, role=UcmSparseRole.SCHEDULER) + self.ucm_sparse = get_ucm_sparse() + logger.info("UCM Sparse initialized successfully: {}".format(self.ucm_sparse)) self.kv_event_publisher = EventPublisherFactory.create( self.kv_events_config, -@@ -203,8 +211,13 @@ class Scheduler(SchedulerInterface): +@@ -201,8 +213,13 @@ class Scheduler(SchedulerInterface): # First, schedule the RUNNING requests. req_index = 0 @@ -300,7 +514,7 @@ index 2d4fd4d59..e99a51788 100644 num_new_tokens = (request.num_tokens_with_spec - request.num_computed_tokens) -@@ -252,7 +265,8 @@ class Scheduler(SchedulerInterface): +@@ -250,7 +267,8 @@ class Scheduler(SchedulerInterface): request, num_new_tokens, num_draft_tokens=num_draft_tokens, @@ -310,7 +524,7 @@ index 2d4fd4d59..e99a51788 100644 if new_blocks is None: # The request cannot be scheduled. # Preempt the lowest-priority request. -@@ -339,6 +353,10 @@ class Scheduler(SchedulerInterface): +@@ -337,6 +355,10 @@ class Scheduler(SchedulerInterface): break request = self.waiting.peek_request() @@ -321,7 +535,7 @@ index 2d4fd4d59..e99a51788 100644 # KVTransfer: skip request if still waiting for remote kvs. 
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: -@@ -448,6 +466,7 @@ class Scheduler(SchedulerInterface): +@@ -446,6 +468,7 @@ class Scheduler(SchedulerInterface): new_computed_blocks, num_lookahead_tokens=self.num_lookahead_tokens, delay_cache_blocks=load_kv_async, @@ -329,7 +543,7 @@ index 2d4fd4d59..e99a51788 100644 ) if new_blocks is None: # The request cannot be scheduled. -@@ -561,6 +580,7 @@ class Scheduler(SchedulerInterface): +@@ -559,6 +582,7 @@ class Scheduler(SchedulerInterface): scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, scheduled_encoder_inputs=scheduled_encoder_inputs, num_common_prefix_blocks=num_common_prefix_blocks, @@ -337,42 +551,7 @@ index 2d4fd4d59..e99a51788 100644 # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. # It contains the request IDs that are finished in between -@@ -809,16 +829,12 @@ class Scheduler(SchedulerInterface): - new_logprobs = None - new_token_ids = generated_token_ids - kv_transfer_params = None -+ - if model_runner_output.finished_dumping is not None: - request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, [])) - is_prefill = request.num_output_tokens == 0 - if is_prefill: -- if isinstance(self.connector, MultiConnector): -- for c in self.connector._connectors: -- if hasattr(c, 'connector') and hasattr(c.connector, 'commit'): -- c.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) -- else: -- self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) -+ self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) - - # Append generated tokens and check for stop. Note that if - # a request is still being prefilled, we expect the model runner -@@ -870,7 +886,6 @@ class Scheduler(SchedulerInterface): - spec_token_ids[req_index]) - else: - request.spec_token_ids = spec_token_ids[req_index] -- - # Get prompt logprobs for this request. - prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids or pooler_output is not None \ -@@ -897,6 +912,7 @@ class Scheduler(SchedulerInterface): - - if not stopped: - new_running.append(request) -+ - self.running = new_running - - # KV Connector: update state for finished KV Transfers. 
-@@ -955,6 +971,8 @@ class Scheduler(SchedulerInterface): +@@ -927,6 +951,8 @@ class Scheduler(SchedulerInterface): def add_request(self, request: Request) -> None: self.waiting.add_request(request) self.requests[request.request_id] = request @@ -381,7 +560,7 @@ index 2d4fd4d59..e99a51788 100644 if self.log_stats: request.record_event(EngineCoreEventType.QUEUED) -@@ -1004,6 +1022,8 @@ class Scheduler(SchedulerInterface): +@@ -976,6 +1002,8 @@ class Scheduler(SchedulerInterface): def _free_request(self, request: Request) -> Optional[dict[str, Any]]: assert request.is_finished() @@ -390,14 +569,6 @@ index 2d4fd4d59..e99a51788 100644 delay_free_blocks, kv_xfer_params = self._connector_finished(request) self.encoder_cache_manager.free(request) -@@ -1155,7 +1175,6 @@ class Scheduler(SchedulerInterface): - logger.debug("Finished sending KV transfer for request %s", req_id) - self._free_blocks(self.requests[req_id]) - -- - def _update_requests_with_invalid_blocks( - self, requests: Iterable[Request], - invalid_block_ids: set[int]) -> tuple[set[str], int]: diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 8f4e8d64c..f45e39f5c 100644 --- a/vllm/v1/worker/block_table.py @@ -430,10 +601,18 @@ index 8f4e8d64c..f45e39f5c 100644 for i, block_table in enumerate(self.block_tables): block_table.add_row(block_ids[i], row_idx) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py -index c3df1d5d2..dbf1ea7d7 100644 +index 5a26e88db..6a39240d2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py -@@ -72,6 +72,9 @@ from ..sample.logits_processor import LogitsProcessorManager +@@ -15,6 +15,7 @@ import torch.nn as nn + from tqdm import tqdm + + import vllm.envs as envs ++import os + from vllm.attention import AttentionType, get_attn_backend + from vllm.attention.backends.abstract import AttentionBackend + from vllm.attention.layer import Attention +@@ -72,6 +73,9 @@ from ..sample.logits_processor import LogitsProcessorManager from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) @@ -443,7 +622,7 @@ index c3df1d5d2..dbf1ea7d7 100644 if TYPE_CHECKING: import xgrammar as xgr import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501 -@@ -365,6 +368,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -365,6 +369,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): """ # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: @@ -451,7 +630,7 @@ index c3df1d5d2..dbf1ea7d7 100644 self.requests.pop(req_id, None) self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. -@@ -468,12 +472,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -468,11 +473,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank req_data = scheduler_output.scheduled_cached_reqs @@ -461,13 +640,12 @@ index c3df1d5d2..dbf1ea7d7 100644 num_computed_tokens = req_data.num_computed_tokens[i] new_block_ids = req_data.new_block_ids[i] resumed_from_preemption = req_data.resumed_from_preemption[i] - num_output_tokens = req_data.num_output_tokens[i] + is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT # Update the cached states. 
req_state.num_computed_tokens = num_computed_tokens -@@ -510,15 +516,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): - end_idx:old_end_idx] = False +@@ -494,15 +501,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): + new_token_ids[-num_new_tokens:]) # Update the block IDs. - if not resumed_from_preemption: @@ -488,7 +666,7 @@ index c3df1d5d2..dbf1ea7d7 100644 req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: -@@ -531,6 +537,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -515,6 +522,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) @@ -497,7 +675,7 @@ index c3df1d5d2..dbf1ea7d7 100644 self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu -@@ -639,6 +647,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -623,6 +632,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.uses_mrope: self._calc_mrope_positions(scheduler_output) @@ -517,7 +695,7 @@ index c3df1d5d2..dbf1ea7d7 100644 # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] -@@ -668,11 +689,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -652,11 +674,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): # block_size. block_table_indices = ( req_indices * block_table.max_num_blocks_per_req + @@ -531,7 +709,7 @@ index c3df1d5d2..dbf1ea7d7 100644 np.add( block_numbers * block_size, block_offsets, -@@ -682,9 +703,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -666,9 +688,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens @@ -546,7 +724,7 @@ index c3df1d5d2..dbf1ea7d7 100644 # Copy the tensors to the GPU. 
self.input_ids[:total_num_scheduled_tokens].copy_( -@@ -696,6 +719,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -680,6 +704,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): non_blocking=True) else: # Common case (1D positions) @@ -555,7 +733,7 @@ index c3df1d5d2..dbf1ea7d7 100644 self.positions[:total_num_scheduled_tokens].copy_( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) -@@ -1386,6 +1411,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1370,6 +1396,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): skip_cuda_graphs=skip_cuda_graphs, ): self.maybe_setup_kv_connector(scheduler_output) @@ -563,36 +741,36 @@ index c3df1d5d2..dbf1ea7d7 100644 model_output = self.model( input_ids=input_ids, -@@ -1395,6 +1421,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1379,6 +1406,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) - finished_dumping = self.maybe_wait_for_kv_save() -+ self.maybe_execute_ucm_sparse_finished() + self.maybe_wait_for_kv_save() ++ logits_indices = self.maybe_execute_ucm_sparse_finished(logits_indices) + finished_sending, finished_recving = ( self.get_finished_kv_transfers(scheduler_output)) - invalid_block_ids = self.get_block_ids_with_load_errors() -@@ -1741,10 +1769,29 @@ class GPUModelRunner(LoRAModelRunnerMixin): - kv_connector.start_load_kv(get_forward_context()) - @staticmethod -- def maybe_wait_for_kv_save(): -+ def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]: +@@ -1723,6 +1752,30 @@ class GPUModelRunner(LoRAModelRunnerMixin): if has_kv_transfer_group(): - return get_kv_transfer_group().wait_for_save() + get_kv_transfer_group().wait_for_save() + def maybe_execute_ucm_sparse_begin(self, scheduler_output: "SchedulerOutput", attn_metadata: CommonAttentionMetadata): + if not has_ucm_sparse(): + return ++ if has_kv_transfer_group(): ++ uc_connector = get_kv_transfer_group() ++ uc_setup_model = getattr(uc_connector, "setup_model", None) ++ if callable(uc_setup_model): ++ uc_setup_model(self.model) + ucm_sparse = get_ucm_sparse() + ucm_sparse.build_sparse_meta(scheduler_output, self.requests, self.input_batch, attn_metadata) + ucm_sparse.execute_begin(scheduler_output) + -+ def maybe_execute_ucm_sparse_finished(self): ++ def maybe_execute_ucm_sparse_finished(self, logits_indices): + if not has_ucm_sparse(): -+ return ++ return logits_indices + ucm_sparse = get_ucm_sparse() -+ ucm_sparse.execute_finished() ++ return ucm_sparse.execute_finished(logits_indices) + + def ucm_sparse_request_finished_in_worker(self, request_id: str | int): + if not has_ucm_sparse(): @@ -603,11 +781,23 @@ index c3df1d5d2..dbf1ea7d7 100644 @staticmethod def get_finished_kv_transfers( scheduler_output: "SchedulerOutput", +@@ -2570,6 +2623,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): + kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, + kv_cache_raw_tensors) + ++ if has_ucm_sparse(): ++ ucm_sparse = get_ucm_sparse() ++ if os.getenv("VLLM_HASH_ATTENTION") == "1": ++ ucm_sparse.initialize_kv_hash_cache_tensors(kv_caches, self.device) ++ + # Setup `kv_cache_config` and `kv_caches` for models + # with cross-layer KV sharing + if self.shared_kv_cache_layers: diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py -index 1b816b25b..d9666d102 100644 +index 9e7e44d06..d49099346 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py -@@ -30,6 +30,7 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +@@ -28,6 +28,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils 
import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase @@ -615,7 +805,7 @@ index 1b816b25b..d9666d102 100644 logger = init_logger(__name__) -@@ -401,6 +402,7 @@ def init_worker_distributed_environment( +@@ -386,6 +387,7 @@ def init_worker_distributed_environment( parallel_config.pipeline_parallel_size) ensure_kv_transfer_initialized(vllm_config) diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch index 705872d78..3c1e3e76e 100644 --- a/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch +++ b/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch @@ -1,10 +1,1589 @@ -From 8cb493f9ece884cbc2ba71e367bed2b4116ae1b3 Mon Sep 17 00:00:00 2001 +From 19f73a680a9a4dbfb57b28d00e3e7b9502186596 Mon Sep 17 00:00:00 2001 +From: wangxin <1848802892@qq.com> +Date: Sun, 4 Jan 2026 17:59:52 -0800 +Subject: [PATCH 1/2] [feat] feature for triton rerope + +--- + vllm/attention/layer.py | 100 +- + .../ops/triton_unified_attention_rerope.py | 863 ++++++++++++++++++ + vllm/envs.py | 13 + + vllm/model_executor/models/qwen2.py | 30 +- + vllm/model_executor/models/qwen3.py | 30 +- + vllm/model_executor/models/qwen3_moe.py | 30 +- + vllm/v1/attention/backends/triton_attn.py | 132 ++- + vllm/v1/attention/backends/utils.py | 2 + + vllm/v1/kv_cache_interface.py | 19 +- + vllm/v1/worker/gpu_model_runner.py | 20 +- + 10 files changed, 1165 insertions(+), 74 deletions(-) + create mode 100644 vllm/attention/ops/triton_unified_attention_rerope.py + +diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py +index f0ad68b16..39dc4bf1d 100644 +--- a/vllm/attention/layer.py ++++ b/vllm/attention/layer.py +@@ -188,6 +188,8 @@ class Attention(nn.Module): + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, ++ query2: Optional[torch.Tensor] = None, ++ key2: Optional[torch.Tensor] = None, + # For some alternate attention backends like MLA the attention output + # shape does not match the query shape, so we optionally let the model + # definition specify the output tensor shape. 
+@@ -224,6 +226,10 @@ class Attention(nn.Module): + output = output.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) ++ if envs.VLLM_USE_REROPE and query2 is not None: ++ query2 = query2.view(-1, self.num_heads, self.head_size) ++ if envs.VLLM_USE_REROPE and key2 is not None: ++ key2 = key2.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size) + if self.use_direct_call: +@@ -232,16 +238,31 @@ class Attention(nn.Module): + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] +- self.impl.forward(self, +- query, +- key, +- value, +- self_kv_cache, +- attn_metadata, +- output=output) ++ if envs.VLLM_USE_REROPE: ++ self.impl.forward(self, ++ query, ++ key, ++ value, ++ self_kv_cache, ++ attn_metadata, ++ query2=query2, ++ key2=key2, ++ output=output) ++ else: ++ self.impl.forward(self, ++ query, ++ key, ++ value, ++ self_kv_cache, ++ attn_metadata, ++ output=output) + else: +- torch.ops.vllm.unified_attention_with_output( +- query, key, value, output, self.layer_name) ++ if envs.VLLM_USE_REROPE: ++ torch.ops.vllm.unified_attention_with_output( ++ query, key, value, output, self.layer_name, query2=query2, key2=key2) ++ else: ++ torch.ops.vllm.unified_attention_with_output( ++ query, key, value, output, self.layer_name) + return output.view(-1, hidden_size) + else: + if self.use_direct_call: +@@ -250,11 +271,19 @@ class Attention(nn.Module): + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] +- return self.impl.forward(self, query, key, value, +- self_kv_cache, attn_metadata) ++ if envs.VLLM_USE_REROPE: ++ return self.impl.forward(self, query, key, value, ++ self_kv_cache, attn_metadata, query2=query2, key2=key2) ++ else: ++ return self.impl.forward(self, query, key, value, ++ self_kv_cache, attn_metadata) + else: +- return torch.ops.vllm.unified_attention( +- query, key, value, self.layer_name) ++ if envs.VLLM_USE_REROPE: ++ return torch.ops.vllm.unified_attention( ++ query, key, value, self.layer_name, query2=query2, key2=key2) ++ else: ++ return torch.ops.vllm.unified_attention( ++ query, key, value, self.layer_name) + + def calc_kv_scales(self, query, key, value): + self._q_scale.copy_(torch.abs(query).max() / self.q_range) +@@ -400,6 +429,8 @@ def unified_attention( + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, ++ query2: Optional[torch.Tensor] = None, ++ key2: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + wait_for_kv_layer_from_connector(layer_name) + +@@ -409,8 +440,12 @@ def unified_attention( + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] +- output = self.impl.forward(self, query, key, value, kv_cache, +- attn_metadata) ++ if envs.VLLM_USE_REROPE: ++ output = self.impl.forward(self, query, key, value, kv_cache, ++ attn_metadata, query2=query2, key2=key2) ++ else: ++ output = self.impl.forward(self, query, key, value, kv_cache, ++ attn_metadata) + + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return output +@@ -421,6 +456,8 @@ def unified_attention_fake( + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, ++ query2: Optional[torch.Tensor] = None, ++ key2: Optional[torch.Tensor] = None, + ) -> 
torch.Tensor: + return torch.empty_like(query).contiguous() + +@@ -440,6 +477,8 @@ def unified_attention_with_output( + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, ++ query2: Optional[torch.Tensor] = None, ++ key2: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> None: + wait_for_kv_layer_from_connector(layer_name) +@@ -449,15 +488,26 @@ def unified_attention_with_output( + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] +- self.impl.forward(self, +- query, +- key, +- value, +- kv_cache, +- attn_metadata, +- output=output, +- output_scale=output_scale) +- ++ if envs.VLLM_USE_REROPE: ++ self.impl.forward(self, ++ query, ++ key, ++ value, ++ kv_cache, ++ attn_metadata, ++ query2=query2, ++ key2=key2, ++ output=output, ++ output_scale=output_scale) ++ else: ++ self.impl.forward(self, ++ query, ++ key, ++ value, ++ kv_cache, ++ attn_metadata, ++ output=output, ++ output_scale=output_scale) + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + + +@@ -467,6 +517,8 @@ def unified_attention_with_output_fake( + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, ++ query2: Optional[torch.Tensor] = None, ++ key2: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> None: + return +diff --git a/vllm/attention/ops/triton_unified_attention_rerope.py b/vllm/attention/ops/triton_unified_attention_rerope.py +new file mode 100644 +index 000000000..3028d2902 +--- /dev/null ++++ b/vllm/attention/ops/triton_unified_attention_rerope.py +@@ -0,0 +1,863 @@ ++# SPDX-License-Identifier: Apache-2.0 ++# SPDX-FileCopyrightText: Copyright contributors to the vLLM project ++ ++# Authors: ++# - Burkhard Ringlein ++# - Jan van Lunteren ++# - Chih-Chieh Yang ++# - Thomas Parnell ++ ++import torch ++import triton ++import triton.language as tl ++ ++from vllm.logger import init_logger ++ ++logger = init_logger(__name__) ++ ++ ++@triton.jit ++def cdiv_fn(x, y): ++ return (x + y - 1) // y ++ ++ ++@triton.jit ++def apply_softcap(S, x): ++ Sdiv = S / x ++ p1 = tl.exp(Sdiv) ++ p2 = tl.exp(-Sdiv) ++ return x * (p1 - p2) / (p1 + p2) ++ ++ ++@triton.jit ++def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, ++ BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr): ++ left: tl.int32 = 0 ++ right = num_seqs ++ while left < right: ++ mid = (left + right) // 2 ++ val = tl.load(query_start_len_ptr + mid) ++ mid_val = val // BLOCK_Q + mid if use_q_block_mode else val ++ ++ if mid_val <= target_idx: ++ left = mid + 1 ++ else: ++ right = mid ++ ++ return left - 1 ++ ++ ++@triton.jit ++def kernel_unified_attention_2d( ++ output_ptr, # [num_tokens, num_query_heads, head_size] ++ query_ptr, # [num_tokens, num_query_heads, head_size] ++ key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] ++ query2_ptr, # [num_tokens, num_query_heads, head_size] ++ key_cache2_ptr, # [num_blks, blk_size, num_kv_heads, head_size] ++ value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] ++ block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] ++ seq_lens_ptr, # [num_seqs] ++ alibi_slopes_ptr, # [num_query_heads] ++ scale, # float32 ++ k_scale, # float32 ++ v_scale, # float32 ++ softcap, # float32 ++ num_query_heads: tl.constexpr, # int ++ num_queries_per_kv: tl.constexpr, # int ++ block_table_stride: tl.int64, # int ++ query_stride_0: tl.int64, # int ++ query_stride_1: tl.int64, # int, should be equal to head_size ++ 
query2_stride_0: tl.int64, # int ++ query2_stride_1: tl.int64, # int, should be equal to head_size ++ output_stride_0: tl.int64, # int ++ output_stride_1: tl.int64, # int, should be equal to head_size ++ REROPE_WINDOW: tl.constexpr, # int ++ BLOCK_SIZE: tl.constexpr, # int ++ HEAD_SIZE: tl.constexpr, # int ++ HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 ++ USE_ALIBI_SLOPES: tl.constexpr, # bool ++ USE_SOFTCAP: tl.constexpr, # bool ++ SLIDING_WINDOW: tl.constexpr, # int ++ stride_k_cache_0: tl.int64, # int ++ stride_k_cache_1: tl.int64, # int ++ stride_k_cache_2: tl.int64, # int ++ stride_k_cache_3: tl.constexpr, # int ++ stride_k_cache2_0: tl.int64, # int ++ stride_k_cache2_1: tl.int64, # int ++ stride_k_cache2_2: tl.int64, # int ++ stride_k_cache2_3: tl.constexpr, # int ++ stride_v_cache_0: tl.int64, # int ++ stride_v_cache_1: tl.int64, # int ++ stride_v_cache_2: tl.int64, # int ++ stride_v_cache_3: tl.constexpr, # int ++ query_start_len_ptr, # [num_seqs+1] ++ BLOCK_Q: tl.constexpr, # int ++ num_seqs: tl.int32, ++ BLOCK_M: tl.constexpr, # int ++): ++ q_block_global_idx = tl.program_id(0) ++ kv_head_idx = tl.program_id(1) ++ ++ seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, ++ BLOCK_Q, True) ++ ++ q_block_start_idx = tl.load(query_start_len_ptr + ++ seq_idx) // BLOCK_Q + seq_idx ++ ++ q_block_local_idx = q_block_global_idx - q_block_start_idx ++ ++ cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) ++ cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) ++ ++ cur_batch_query_len = cur_batch_in_all_stop_index \ ++ - cur_batch_in_all_start_index ++ ++ if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: ++ return ++ ++ offs_m = tl.arange(0, BLOCK_M) ++ offs_d = tl.arange(0, HEAD_SIZE_PADDED) ++ query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv ++ ++ query_offset_0 = cur_batch_in_all_start_index + query_pos ++ query_offset_1 = kv_head_idx * num_queries_per_kv + \ ++ offs_m % num_queries_per_kv ++ query_offset = (query_offset_0[:, None] * query_stride_0 + ++ query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) ++ query2_offset = (query_offset_0[:, None] * query2_stride_0 + ++ query_offset_1[:, None] * query2_stride_1 + offs_d[None, :]) ++ ++ dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) ++ query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) ++ query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) ++ ++ # Q : (BLOCK_M, HEAD_SIZE_PADDED) ++ Q = tl.load( ++ query_ptr + query_offset, ++ mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], ++ other=0.0, ++ ) ++ Q2 = tl.load( ++ query2_ptr + query2_offset, ++ mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], ++ other=0.0, ++ ) ++ ++ block_table_offset = seq_idx * block_table_stride ++ ++ M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) ++ L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) ++ acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) ++ ++ # sequence len for this particular sequence ++ seq_len = tl.load(seq_lens_ptr + seq_idx) ++ ++ # context length for this particular sequences ++ context_len = seq_len - cur_batch_query_len ++ ++ # alibi slope for this head ++ if USE_ALIBI_SLOPES: ++ alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, ++ mask=query_mask_1, ++ other=0.0) ++ ++ num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) ++ ++ # iterate through tiles ++ for j in range(0, num_blocks): ++ ++ physical_block_idx = 
tl.load(block_tables_ptr + block_table_offset + j) ++ ++ offs_n = tl.arange(0, BLOCK_SIZE) ++ ++ v_offset = (physical_block_idx * stride_v_cache_0 + ++ kv_head_idx * stride_v_cache_2 + ++ offs_d[None, :] * stride_v_cache_3 + ++ offs_n[:, None] * stride_v_cache_1) ++ ++ k_offset = (physical_block_idx * stride_k_cache_0 + ++ kv_head_idx * stride_k_cache_2 + ++ offs_d[:, None] * stride_k_cache_3 + ++ offs_n[None, :] * stride_k_cache_1) ++ ++ k2_offset = (physical_block_idx * stride_k_cache2_0 + ++ kv_head_idx * stride_k_cache2_2 + ++ offs_d[:, None] * stride_k_cache2_3 + ++ offs_n[None, :] * stride_k_cache2_1) ++ ++ # K : (HEAD_SIZE, BLOCK_SIZE) ++ K_load = tl.load(key_cache_ptr + k_offset, ++ mask=dim_mask[:, None], ++ other=0.0) ++ ++ K2_load = tl.load(key_cache2_ptr + k2_offset, ++ mask=dim_mask[:, None], ++ other=0.0) ++ ++ if K_load.dtype.is_fp8(): ++ if Q.dtype.is_fp8(): ++ K = K_load ++ else: ++ K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) ++ else: ++ K = K_load ++ ++ if K2_load.dtype.is_fp8(): ++ if Q2.dtype.is_fp8(): ++ K2 = K2_load ++ else: ++ K2 = (K2_load.to(tl.float32) * tl.load(k_scale)).to(Q2.dtype) ++ else: ++ K2 = K2_load ++ ++ # V : (BLOCK_SIZE, HEAD_SIZE) ++ V_load = tl.load(value_cache_ptr + v_offset, ++ mask=dim_mask[None, :], ++ other=0.0) ++ ++ if V_load.dtype.is_fp8(): ++ if Q.dtype.is_fp8(): ++ V = V_load ++ else: ++ V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) ++ else: ++ V = V_load ++ ++ seq_offset = j * BLOCK_SIZE + offs_n ++ ++ seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 ++ ++ # S : (BLOCK_M, BLOCK_SIZE) ++ S1 = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) ++ ++ S1 += scale * tl.dot(Q, K) ++ ++ ++ # rerope mask ++ S2 = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) ++ S2 += scale * tl.dot(Q2, K2) ++ ++ query_pos_rerope = context_len + query_pos[:, None] + 1 ++ key_pos_rerope = seq_offset[None, :] ++ ++ valid_query_mask = query_pos[:, None] < cur_batch_query_len ++ pos_diff = tl.abs(query_pos_rerope - key_pos_rerope) ++ rerope_mask = pos_diff < REROPE_WINDOW ++ rerope_mask = rerope_mask & valid_query_mask ++ ++ if USE_SOFTCAP: ++ S1 = apply_softcap(S1, softcap) ++ S2 = apply_softcap(S2, softcap) ++ ++ S = tl.where(rerope_mask, S1, S2) ++ ++ ++ # 越界检验 & causal mask ++ S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, ++ S, float("-inf")) ++ ++ if SLIDING_WINDOW > 0: ++ S = tl.where((context_len + query_pos[:, None] - seq_offset) ++ < SLIDING_WINDOW, S, float("-inf")) ++ ++ if USE_ALIBI_SLOPES: ++ S += alibi_slope[:, None] * (seq_offset - context_len) ++ ++ # compute running maximum ++ # m_j : (BLOCK_M,) ++ m_j = tl.maximum(M, tl.max(S, axis=1)) ++ # For sliding window there's a chance the max is -inf due to masking of ++ # the entire row. 
In this case we need to set m_j 0 to avoid NaN ++ m_j = tl.where(m_j > float("-inf"), m_j, 0.0) ++ ++ # P : (BLOCK_M, BLOCK_SIZE) ++ P = tl.exp(S - m_j[:, None]) ++ ++ # l_j : (BLOCK_M,) ++ l_j = tl.sum(P, axis=1) ++ ++ # alpha : (BLOCK_M, ) ++ alpha = tl.exp(M - m_j) ++ ++ # acc : (BLOCK_M, HEAD_SIZE_PADDED) ++ acc = acc * alpha[:, None] ++ ++ # update constants ++ L = L * alpha + l_j ++ M = m_j ++ ++ # acc : (BLOCK_M, HEAD_SIZE_PADDED) ++ acc += tl.dot(P.to(V.dtype), V) ++ ++ # epilogue ++ acc = acc / L[:, None] ++ ++ output_offset = (query_offset_0[:, None] * output_stride_0 + ++ query_offset_1[:, None] * output_stride_1 + ++ offs_d[None, :]) ++ ++ tl.store( ++ output_ptr + output_offset, ++ acc, ++ mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], ++ ) ++ ++ ++@triton.jit ++def kernel_unified_attention_3d( ++ segm_output_ptr, ++ # [num_tokens, num_query_heads, num_segments, head_size] ++ segm_max_ptr, # [num_tokens, num_query_heads, num_segments] ++ segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] ++ query_ptr, # [num_tokens, num_query_heads, head_size] ++ key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] ++ query2_ptr, # [num_tokens, num_query_heads, head_size] ++ key_cache2_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] ++ value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] ++ block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] ++ seq_lens_ptr, # [num_seqs] ++ alibi_slopes_ptr, # [num_query_heads] ++ scale, # float32 ++ k_scale, # float32 ++ v_scale, # float32 ++ softcap, # float32 ++ num_query_heads: tl.constexpr, # int ++ num_queries_per_kv: tl.constexpr, # int ++ block_table_stride: tl.int64, # int ++ query_stride_0: tl.int64, # int ++ query_stride_1: tl.int64, # int, should be equal to head_size ++ query2_stride_0: tl.int64, # int ++ query2_stride_1: tl.int64, # int, should be equal to head_size ++ BLOCK_SIZE: tl.constexpr, # int ++ HEAD_SIZE: tl.constexpr, # int ++ HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 ++ USE_ALIBI_SLOPES: tl.constexpr, # bool ++ USE_SOFTCAP: tl.constexpr, # bool ++ SLIDING_WINDOW: tl.constexpr, # int ++ stride_k_cache_0: tl.int64, # int ++ stride_k_cache_1: tl.int64, # int ++ stride_k_cache_2: tl.int64, # int ++ stride_k_cache_3: tl.constexpr, # int ++ stride_k_cache2_0: tl.int64, # int ++ stride_k_cache2_1: tl.int64, # int ++ stride_k_cache2_2: tl.int64, # int ++ stride_k_cache2_3: tl.constexpr, # int ++ stride_v_cache_0: tl.int64, # int ++ stride_v_cache_1: tl.int64, # int ++ stride_v_cache_2: tl.int64, # int ++ stride_v_cache_3: tl.constexpr, # int ++ query_start_len_ptr, # [num_seqs+1] ++ REROPE_WINDOW: tl.constexpr, # int ++ BLOCK_Q: tl.constexpr, # int ++ num_seqs: tl.int32, ++ BLOCK_M: tl.constexpr, # int ++ NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int ++): ++ q_block_global_idx = tl.program_id(0) ++ kv_head_idx = tl.program_id(1) ++ segm_idx = tl.program_id(2) ++ ++ seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, ++ BLOCK_Q, True) ++ ++ q_block_start_idx = tl.load(query_start_len_ptr + ++ seq_idx) // BLOCK_Q + seq_idx ++ ++ q_block_local_idx = q_block_global_idx - q_block_start_idx ++ ++ cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) ++ cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) ++ ++ cur_batch_query_len = cur_batch_in_all_stop_index \ ++ - cur_batch_in_all_start_index ++ ++ if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: ++ return ++ ++ # sequence len for this 
particular sequence ++ seq_len = tl.load(seq_lens_ptr + seq_idx) ++ ++ # number of segments for this particular sequence ++ num_segments = NUM_SEGMENTS_PER_SEQ ++ blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) ++ ++ if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len: ++ return ++ ++ offs_m = tl.arange(0, BLOCK_M) ++ offs_d = tl.arange(0, HEAD_SIZE_PADDED) ++ ++ query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv ++ ++ query_offset_0 = cur_batch_in_all_start_index + query_pos ++ query_offset_1 = kv_head_idx * num_queries_per_kv + \ ++ offs_m % num_queries_per_kv ++ ++ query_offset = (query_offset_0[:, None] * query_stride_0 + ++ query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) ++ query2_offset = (query_offset_0[:, None] * query2_stride_0 + ++ query_offset_1[:, None] * query2_stride_1 + offs_d[None, :]) ++ ++ dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) ++ query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) ++ query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) ++ ++ # Q : (BLOCK_M, HEAD_SIZE_PADDED) ++ Q = tl.load( ++ query_ptr + query_offset, ++ mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], ++ other=0.0, ++ ) ++ Q2 = tl.load( ++ query2_ptr + query2_offset, ++ mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], ++ other=0.0, ++ ) ++ ++ block_table_offset = seq_idx * block_table_stride ++ ++ M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) ++ L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) ++ acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) ++ ++ # context length for this particular sequences ++ context_len = seq_len - cur_batch_query_len ++ ++ # alibi slope for this head ++ if USE_ALIBI_SLOPES: ++ alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, ++ mask=query_mask_1, ++ other=0.0) ++ ++ num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) ++ ++ # iterate through tiles within current segment ++ for j in range( ++ segm_idx * blocks_per_segment, ++ min((segm_idx + 1) * blocks_per_segment, num_blocks), ++ ): ++ physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) ++ ++ offs_n = tl.arange(0, BLOCK_SIZE) ++ ++ v_offset = (physical_block_idx * stride_v_cache_0 + ++ kv_head_idx * stride_v_cache_2 + ++ offs_d[None, :] * stride_v_cache_3 + ++ offs_n[:, None] * stride_v_cache_1) ++ ++ k_offset = (physical_block_idx * stride_k_cache_0 + ++ kv_head_idx * stride_k_cache_2 + ++ offs_d[:, None] * stride_k_cache_3 + ++ offs_n[None, :] * stride_k_cache_1) ++ ++ k2_offset = (physical_block_idx * stride_k_cache2_0 + ++ kv_head_idx * stride_k_cache2_2 + ++ offs_d[:, None] * stride_k_cache2_3 + ++ offs_n[None, :] * stride_k_cache2_1) ++ ++ # K : (HEAD_SIZE, BLOCK_SIZE) ++ K_load = tl.load(key_cache_ptr + k_offset, ++ mask=dim_mask[:, None], ++ other=0.0) ++ ++ K2_load = tl.load(key_cache2_ptr + k2_offset, ++ mask=dim_mask[:, None], ++ other=0.0) ++ ++ if K_load.dtype.is_fp8(): ++ if Q.dtype.is_fp8(): ++ K = K_load ++ else: ++ K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) ++ else: ++ K = K_load ++ ++ if K2_load.dtype.is_fp8(): ++ if Q2.dtype.is_fp8(): ++ K2 = K2_load ++ else: ++ K2= (K2_load.to(tl.float32) * tl.load(k_scale)).to(Q2.dtype) ++ else: ++ K2= K2_load ++ ++ # V : (BLOCK_SIZE, HEAD_SIZE) ++ V_load = tl.load(value_cache_ptr + v_offset, ++ mask=dim_mask[None, :], ++ other=0.0) ++ ++ if V_load.dtype.is_fp8(): ++ if Q.dtype.is_fp8(): ++ V = V_load ++ else: ++ V = (V_load.to(tl.float32) * 
tl.load(v_scale)).to(Q.dtype)
++        else:
++            V = V_load
++
++        seq_offset = j * BLOCK_SIZE + offs_n
++
++        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
++
++        # S : (BLOCK_M, BLOCK_SIZE)
++        S1 = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
++
++        S1 += scale * tl.dot(Q, K)
++
++
++        # rerope mask
++        S2 = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
++        S2 += scale * tl.dot(Q2, K2)
++
++        query_pos_rerope = context_len + query_pos[:, None] + 1
++        key_pos_rerope = seq_offset[None, :]
++
++        valid_query_mask = query_pos[:, None] < cur_batch_query_len
++        pos_diff = tl.abs(query_pos_rerope - key_pos_rerope)
++        rerope_mask = pos_diff < REROPE_WINDOW
++        rerope_mask = rerope_mask & valid_query_mask
++
++        if USE_SOFTCAP:
++            S1 = apply_softcap(S1, softcap)
++            S2 = apply_softcap(S2, softcap)
++
++        S = tl.where(rerope_mask, S1, S2)
++
++
++        # out-of-bounds check & causal mask
++        S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
++                     S, float("-inf"))
++
++        if SLIDING_WINDOW > 0:
++            S = tl.where((context_len + query_pos[:, None] - seq_offset)
++                         < SLIDING_WINDOW, S, float("-inf"))
++
++        if USE_ALIBI_SLOPES:
++            S += alibi_slope[:, None] * (seq_offset - context_len)
++
++        # compute running maximum
++        # m_j : (BLOCK_M,)
++        m_j = tl.maximum(M, tl.max(S, axis=1))
++        # For sliding window there's a chance the max is -inf due to masking of
++        # the entire row. In this case we need to set m_j 0 to avoid NaN
++        m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
++
++        # P : (BLOCK_M, BLOCK_SIZE,)
++        P = tl.exp(S - m_j[:, None])
++
++        # l_j : (BLOCK_M,)
++        l_j = tl.sum(P, axis=1)
++
++        # alpha : (BLOCK_M, )
++        alpha = tl.exp(M - m_j)
++
++        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
++        acc = acc * alpha[:, None]
++
++        # update constants
++        L = L * alpha + l_j
++        M = m_j
++
++        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
++        acc += tl.dot(P.to(V.dtype), V)
++
++    segm_output_offset = (
++        query_offset_0[:, None].to(tl.int64) *
++        (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
++        query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
++        segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :])
++    tl.store(
++        segm_output_ptr + segm_output_offset,
++        acc,
++        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
++    )
++    segm_offset = (query_offset_0.to(tl.int64) *
++                   (num_query_heads * NUM_SEGMENTS_PER_SEQ) +
++                   query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx)
++    tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1)
++    tl.store(segm_expsum_ptr + segm_offset,
++             L,
++             mask=query_mask_0 & query_mask_1)
++
++
++@triton.jit
++def reduce_segments(
++    output_ptr,  # [num_tokens, num_query_heads, head_size]
++    segm_output_ptr,
++    #[num_tokens, num_query_heads, max_num_segments, head_size]
++    segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
++    segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
++    seq_lens_ptr,  # [num_seqs]
++    num_seqs,  # int
++    num_query_heads: tl.constexpr,  # int
++    output_stride_0: tl.int64,  # int
++    output_stride_1: tl.int64,  # int, should be equal to head_size
++    block_table_stride: tl.int64,  # int
++    BLOCK_SIZE: tl.constexpr,  # int
++    HEAD_SIZE: tl.constexpr,  # int, must be power of 2
++    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
++    query_start_len_ptr,  # [num_seqs+1]
++    BLOCK_Q: tl.constexpr,  # int
++    NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
++):
++    query_token_idx = tl.program_id(0)
++    query_head_idx = tl.program_id(1)
++
++    seq_idx =
find_seq_idx(query_start_len_ptr, query_token_idx, num_seqs, ++ BLOCK_Q, False) ++ ++ # sequence len for this particular sequence ++ seq_len = tl.load(seq_lens_ptr + seq_idx) ++ ++ # number of segments for this particular sequence ++ num_segments = NUM_SEGMENTS_PER_SEQ ++ blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) ++ ++ # create masks for subsequent loads ++ act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE) ++ segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( ++ [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32) ++ dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, ++ 0).to(tl.int1) ++ ++ # load segment maxima ++ segm_offset = (query_token_idx.to(tl.int64) * ++ (num_query_heads * NUM_SEGMENTS_PER_SEQ) + ++ query_head_idx * NUM_SEGMENTS_PER_SEQ + ++ tl.arange(0, NUM_SEGMENTS_PER_SEQ)) ++ segm_max = tl.load(segm_max_ptr + segm_offset, ++ mask=segm_mask, ++ other=float("-inf")) ++ overall_max = tl.max(segm_max) ++ ++ # load and rescale segment exp sums ++ segm_expsum = tl.load(segm_expsum_ptr + segm_offset, ++ mask=segm_mask, ++ other=0.0) ++ segm_expsum = segm_expsum * tl.exp(segm_max - overall_max) ++ overall_expsum = tl.sum(segm_expsum) ++ ++ # load, rescale, and add segment attention outputs ++ segm_output_offset = ( ++ query_token_idx.to(tl.int64) * ++ (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + ++ query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + ++ tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED + ++ tl.arange(0, HEAD_SIZE_PADDED)[None, :]) ++ segm_output = tl.load( ++ segm_output_ptr + segm_output_offset, ++ mask=segm_mask[:, None] & dim_mask[None, :], ++ other=0.0, ++ ) ++ segm_output *= tl.exp(segm_max - overall_max)[:, None] ++ acc_sum = tl.sum(segm_output, axis=0) ++ # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0 ++ acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) ++ ++ # write result ++ output_offset = (query_token_idx * output_stride_0 + ++ query_head_idx * output_stride_1 + ++ tl.arange(0, HEAD_SIZE_PADDED)) ++ tl.store(output_ptr + output_offset, acc, mask=dim_mask) ++ ++ ++def unified_attention_rerope( ++ q, ++ k, ++ q2, ++ k2, ++ v, ++ out, ++ cu_seqlens_q, ++ max_seqlen_q, ++ seqused_k, ++ max_seqlen_k, ++ softmax_scale, ++ causal, ++ rerope_window, ++ window_size, ++ block_table, ++ softcap, ++ q_descale, ++ k_descale, ++ v_descale, ++ alibi_slopes=None, ++): ++ assert causal, "Only causal attention is supported" ++ assert q_descale is None, "Q scales not supported" ++ ++ block_size = v.shape[1] ++ assert q.element_size() >= 2 or block_size >= 32, \ ++ "Block size must be at least 32 for fp8" ++ ++ use_alibi_slopes = alibi_slopes is not None ++ ++ block_size = v.shape[1] ++ num_seqs = len(seqused_k) ++ num_query_heads = q.shape[1] ++ num_kv_heads = k.shape[2] ++ num_queries_per_kv = num_query_heads // num_kv_heads ++ head_size = q.shape[2] ++ ++ BLOCK_M = 16 ++ BLOCK_Q = BLOCK_M // num_queries_per_kv ++ ++ # Ideally we would launch with kernel with: ++ # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. ++ # However, it is slow to realize the query_lens on cpu. 
++ # Instead we use upper-bound: ++ # \sum_i[ceil(query_len[i] / BLOCK_Q)] ++ # <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1] ++ # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs ++ # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs ++ # = floor(q.shape[0] / BLOCK_Q) + num_seqs ++ total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs ++ ++ # if batch contains a prefill ++ if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: ++ with torch.cuda.nvtx.range("atten_2D"): ++ kernel_unified_attention_2d[( ++ total_num_q_blocks, ++ num_kv_heads, ++ )]( ++ output_ptr=out, ++ query_ptr=q, ++ key_cache_ptr=k, ++ query2_ptr=q2, ++ key_cache2_ptr=k2, ++ value_cache_ptr=v, ++ block_tables_ptr=block_table, ++ seq_lens_ptr=seqused_k, ++ alibi_slopes_ptr=alibi_slopes, ++ scale=softmax_scale, ++ k_scale=k_descale, ++ v_scale=v_descale, ++ softcap=softcap, ++ num_query_heads=num_query_heads, ++ num_queries_per_kv=num_queries_per_kv, ++ block_table_stride=block_table.stride(0), ++ query_stride_0=q.stride(0), ++ query_stride_1=q.stride(1),\ ++ query2_stride_0=q2.stride(0), ++ query2_stride_1=q2.stride(1), ++ output_stride_0=out.stride(0), ++ output_stride_1=out.stride(1), ++ REROPE_WINDOW=rerope_window, ++ BLOCK_SIZE=block_size, ++ HEAD_SIZE=head_size, ++ HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), ++ USE_ALIBI_SLOPES=use_alibi_slopes, ++ USE_SOFTCAP=(softcap > 0), ++ SLIDING_WINDOW=(1 + window_size[0]), ++ stride_k_cache_0=k.stride(0), ++ stride_k_cache_1=k.stride(1), ++ stride_k_cache_2=k.stride(2), ++ stride_k_cache_3=k.stride(3), ++ stride_k_cache2_0=k.stride(0), ++ stride_k_cache2_1=k.stride(1), ++ stride_k_cache2_2=k.stride(2), ++ stride_k_cache2_3=k.stride(3), ++ stride_v_cache_0=v.stride(0), ++ stride_v_cache_1=v.stride(1), ++ stride_v_cache_2=v.stride(2), ++ stride_v_cache_3=v.stride(3), ++ query_start_len_ptr=cu_seqlens_q, ++ BLOCK_Q=BLOCK_Q, ++ num_seqs=num_seqs, ++ BLOCK_M=BLOCK_M, ++ ) ++ torch.cuda.synchronize() ++ else: ++ # for initial version, NUM_SEGMENTS = 16 is chosen as a default ++ # value that showed good performance in tests ++ NUM_SEGMENTS = 16 ++ ++ segm_output = torch.empty( ++ q.shape[0], ++ num_query_heads, ++ NUM_SEGMENTS, ++ triton.next_power_of_2(head_size), ++ dtype=torch.float32, ++ device=q.device, ++ ) ++ segm_max = torch.empty( ++ q.shape[0], ++ num_query_heads, ++ NUM_SEGMENTS, ++ dtype=torch.float32, ++ device=q.device, ++ ) ++ segm_expsum = torch.empty( ++ q.shape[0], ++ num_query_heads, ++ NUM_SEGMENTS, ++ dtype=torch.float32, ++ device=q.device, ++ ) ++ ++ kernel_unified_attention_3d[( ++ total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( ++ segm_output_ptr=segm_output, ++ segm_max_ptr=segm_max, ++ segm_expsum_ptr=segm_expsum, ++ query_ptr=q, ++ key_cache_ptr=k, ++ query2_ptr=q2, ++ key_cache2_ptr=k2, ++ value_cache_ptr=v, ++ block_tables_ptr=block_table, ++ seq_lens_ptr=seqused_k, ++ alibi_slopes_ptr=alibi_slopes, ++ scale=softmax_scale, ++ k_scale=k_descale, ++ v_scale=v_descale, ++ softcap=softcap, ++ num_query_heads=num_query_heads, ++ num_queries_per_kv=num_queries_per_kv, ++ block_table_stride=block_table.stride(0), ++ query_stride_0=q.stride(0), ++ query_stride_1=q.stride(1), ++ query2_stride_0=q2.stride(0), ++ query2_stride_1=q2.stride(1), ++ BLOCK_SIZE=block_size, ++ HEAD_SIZE=head_size, ++ HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), ++ USE_ALIBI_SLOPES=use_alibi_slopes, ++ USE_SOFTCAP=(softcap > 0), ++ SLIDING_WINDOW=(1 + window_size[0]), ++ stride_k_cache_0=k.stride(0), ++ stride_k_cache_1=k.stride(1), ++ 
stride_k_cache_2=k.stride(2), ++ stride_k_cache_3=k.stride(3), ++ stride_k_cache2_0=k2.stride(0), ++ stride_k_cache2_1=k2.stride(1), ++ stride_k_cache2_2=k2.stride(2), ++ stride_k_cache2_3=k2.stride(3), ++ stride_v_cache_0=v.stride(0), ++ stride_v_cache_1=v.stride(1), ++ stride_v_cache_2=v.stride(2), ++ stride_v_cache_3=v.stride(3), ++ query_start_len_ptr=cu_seqlens_q, ++ REROPE_WINDOW=rerope_window, ++ BLOCK_Q=BLOCK_Q, ++ num_seqs=num_seqs, ++ BLOCK_M=BLOCK_M, ++ NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, ++ ) ++ ++ reduce_segments[(q.shape[0], num_query_heads)]( ++ output_ptr=out, ++ segm_output_ptr=segm_output, ++ segm_max_ptr=segm_max, ++ segm_expsum_ptr=segm_expsum, ++ seq_lens_ptr=seqused_k, ++ num_seqs=num_seqs, ++ num_query_heads=num_query_heads, ++ output_stride_0=out.stride(0), ++ output_stride_1=out.stride(1), ++ block_table_stride=block_table.stride(0), ++ BLOCK_SIZE=block_size, ++ HEAD_SIZE=head_size, ++ HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), ++ query_start_len_ptr=cu_seqlens_q, ++ BLOCK_Q=BLOCK_Q, ++ NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, ++ ) +diff --git a/vllm/envs.py b/vllm/envs.py +index 0cc6792d7..1b049c2c5 100644 +--- a/vllm/envs.py ++++ b/vllm/envs.py +@@ -83,6 +83,9 @@ if TYPE_CHECKING: + VLLM_SKIP_P2P_CHECK: bool = False + VLLM_DISABLED_KERNELS: list[str] = [] + VLLM_USE_V1: bool = True ++ VLLM_USE_REROPE: bool = False ++ REROPE_WINDOW: int = 32768 ++ TRAINING_LENGTH: int = 32768 + VLLM_ROCM_USE_AITER: bool = False + VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False + VLLM_ROCM_USE_AITER_LINEAR: bool = True +@@ -637,6 +640,16 @@ environment_variables: dict[str, Callable[[], Any]] = { + "VLLM_USE_V1": + lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), + ++ # add REROPE ++ "VLLM_USE_REROPE": ++ lambda: str(os.getenv("VLLM_USE_REROPE", "0")).lower() in {"1", "true", "yes", "on"}, ++ ++ # add REROPE setting ++ "REROPE_WINDOW": ++ lambda: int(os.getenv("REROPE_WINDOW", "32768")), ++ "TRAINING_LENGTH": ++ lambda: int(os.getenv("TRAINING_LENGTH", "32768")), ++ + # Disable aiter ops unless specifically enabled. + # Acts as a parent switch to enable the rest of the other operations. 
+ "VLLM_ROCM_USE_AITER": +diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py +index 7ef9d248d..2d75195eb 100644 +--- a/vllm/model_executor/models/qwen2.py ++++ b/vllm/model_executor/models/qwen2.py +@@ -57,6 +57,10 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + ++import math ++from vllm import envs ++from vllm.forward_context import get_forward_context ++ + + class Qwen2MLP(nn.Module): + +@@ -180,8 +184,30 @@ class Qwen2Attention(nn.Module): + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) +- q, k = self.rotary_emb(positions, q, k) +- attn_output = self.attn(q, k, v) ++ ++ if envs.VLLM_USE_REROPE: ++ attn_metadata = get_forward_context().attn_metadata ++ REROPE_WINDOW = envs.REROPE_WINDOW ++ TRAINING_LENGTH = envs.TRAINING_LENGTH ++ if attn_metadata and next(iter(attn_metadata.values())).use_rerope: ++ q *= ((positions + 1)[:, None].log() / math.log(TRAINING_LENGTH)).clip(1).to(q.dtype) ++ q2 = q.clone() ++ k2 = k.clone() ++ k0 = k.clone() ++ ++ q, k = self.rotary_emb(positions, q, k) ++ q2, _ = self.rotary_emb(positions * 0 + REROPE_WINDOW, q2, k2) ++ del k2 ++ else: ++ k0 = k ++ q, k = self.rotary_emb(positions, q, k) ++ q2 = q ++ ++ attn_output = self.attn(q, k, v, query2=q2, key2=k0) ++ else: ++ q, k = self.rotary_emb(positions, q, k) ++ attn_output = self.attn(q, k, v) ++ + output, _ = self.o_proj(attn_output) + return output + +diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py +index de99a76f2..03904a054 100644 +--- a/vllm/model_executor/models/qwen3.py ++++ b/vllm/model_executor/models/qwen3.py +@@ -50,6 +50,10 @@ from .qwen2 import Qwen2MLP as Qwen3MLP + from .qwen2 import Qwen2Model + from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix + ++import math ++from vllm import envs ++from vllm.forward_context import get_forward_context ++ + logger = init_logger(__name__) + + +@@ -142,8 +146,30 @@ class Qwen3Attention(nn.Module): + self.head_dim) + k_by_head = self.k_norm(k_by_head) + k = k_by_head.view(k.shape) +- q, k = self.rotary_emb(positions, q, k) +- attn_output = self.attn(q, k, v) ++ ++ if envs.VLLM_USE_REROPE: ++ attn_metadata = get_forward_context().attn_metadata ++ REROPE_WINDOW = envs.REROPE_WINDOW ++ TRAINING_LENGTH = envs.TRAINING_LENGTH ++ if attn_metadata and next(iter(attn_metadata.values())).use_rerope: ++ q *= ((positions + 1)[:, None].log() / math.log(TRAINING_LENGTH)).clip(1).to(q.dtype) ++ q2 = q.clone() ++ k2 = k.clone() ++ k0 = k.clone() ++ ++ q, k = self.rotary_emb(positions, q, k) ++ q2, _ = self.rotary_emb(positions * 0 + REROPE_WINDOW, q2, k2) ++ del k2 ++ else: ++ k0 = k ++ q, k = self.rotary_emb(positions, q, k) ++ q2 = q ++ ++ attn_output = self.attn(q, k, v, query2=q2, key2=k0) ++ else: ++ q, k = self.rotary_emb(positions, q, k) ++ attn_output = self.attn(q, k, v) ++ + output, _ = self.o_proj(attn_output) + return output + +diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py +index ff182aadf..f7a787447 100644 +--- a/vllm/model_executor/models/qwen3_moe.py ++++ b/vllm/model_executor/models/qwen3_moe.py +@@ -56,6 +56,10 @@ from .utils import (AutoWeightsLoader, extract_layer_index, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + ++import math ++from vllm import envs ++from vllm.forward_context import 
get_forward_context ++ + logger = init_logger(__name__) + + +@@ -232,8 +236,30 @@ class Qwen3MoeAttention(nn.Module): + self.head_dim) + k_by_head = self.k_norm(k_by_head) + k = k_by_head.view(k.shape) +- q, k = self.rotary_emb(positions, q, k) +- attn_output = self.attn(q, k, v) ++ ++ if envs.VLLM_USE_REROPE: ++ attn_metadata = get_forward_context().attn_metadata ++ REROPE_WINDOW = envs.REROPE_WINDOW ++ TRAINING_LENGTH = envs.TRAINING_LENGTH ++ if attn_metadata and next(iter(attn_metadata.values())).use_rerope: ++ q *= ((positions + 1)[:, None].log() / math.log(TRAINING_LENGTH)).clip(1).to(q.dtype) ++ q2 = q.clone() ++ k2 = k.clone() ++ k0 = k.clone() ++ ++ q, k = self.rotary_emb(positions, q, k) ++ q2, _ = self.rotary_emb(positions * 0 + REROPE_WINDOW, q2, k2) ++ del k2 ++ else: ++ k0 = k ++ q, k = self.rotary_emb(positions, q, k) ++ q2 = q ++ ++ attn_output = self.attn(q, k, v, query2=q2, key2=k0) ++ else: ++ q, k = self.rotary_emb(positions, q, k) ++ attn_output = self.attn(q, k, v) ++ + output, _ = self.o_proj(attn_output) + return output + +diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py +index cdaff2f6a..9d2490ebc 100644 +--- a/vllm/v1/attention/backends/triton_attn.py ++++ b/vllm/v1/attention/backends/triton_attn.py +@@ -23,6 +23,8 @@ from vllm.v1.attention.backends.utils import ( + from vllm.v1.kv_cache_interface import AttentionSpec + from vllm.v1.worker.block_table import BlockTable + ++from vllm.attention.ops.triton_unified_attention_rerope import unified_attention_rerope ++ + if TYPE_CHECKING: + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +@@ -47,6 +49,8 @@ class TritonAttentionMetadata: + block_table: torch.Tensor + slot_mapping: torch.Tensor + ++ use_rerope: bool ++ + # For cascade attention. 
+ use_cascade: bool + common_prefix_len: int +@@ -100,6 +104,8 @@ class TritonAttentionMetadataBuilder( + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + ++ use_rerope = common_attn_metadata.use_rerope ++ + max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens +@@ -177,6 +183,7 @@ class TritonAttentionMetadataBuilder( + suffix_kv_lens=suffix_kv_lens, + local_attn_metadata=local_attn_metadata, + prefix_scheduler_metadata=prefix_scheduler_metadata, ++ use_rerope = use_rerope + ) + return attn_metadata + +@@ -226,6 +233,8 @@ class TritonAttentionBackend(AttentionBackend): + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") ++ if envs.VLLM_USE_REROPE: ++ return (3, num_blocks, block_size, num_kv_heads, head_size) + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod +@@ -299,6 +308,8 @@ class TritonAttentionImpl(AttentionImpl): + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, ++ query2: Optional[torch.Tensor] = None, ++ key2: Optional[torch.Tensor] = None, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: +@@ -342,7 +353,10 @@ class TritonAttentionImpl(AttentionImpl): + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + else: +- key_cache, value_cache = kv_cache.unbind(0) ++ if envs.VLLM_USE_REROPE: ++ key_cache, value_cache, key_cache2 = kv_cache.unbind(0) ++ else: ++ key_cache, value_cache = kv_cache.unbind(0) + + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. +@@ -370,8 +384,22 @@ class TritonAttentionImpl(AttentionImpl): + layer._v_scale, + ) + ++ if envs.VLLM_USE_REROPE and key2 is not None: ++ torch.ops._C_cache_ops.reshape_and_cache_flash( ++ key2, ++ value, ++ key_cache2, ++ value_cache, ++ attn_metadata.slot_mapping, ++ self.kv_cache_dtype, ++ layer._k_scale, ++ layer._v_scale, ++ ) ++ + if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(self.fp8_dtype) ++ if envs.VLLM_USE_REROPE and key_cache2 is not None: ++ key_cache2 = key_cache2.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + num_tokens, num_heads, head_size = query.shape + assert layer._q_scale == 1.0, \ +@@ -384,6 +412,12 @@ class TritonAttentionImpl(AttentionImpl): + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) ++ if envs.VLLM_USE_REROPE and query2 is not None: ++ query2, _ = ops.scaled_fp8_quant( ++ query2.reshape( ++ (num_tokens, num_heads * head_size)).contiguous(), ++ layer._q_scale) ++ query2 = query2.reshape((num_tokens, num_heads, head_size)) + + use_local_attn = \ + (self.use_irope and attn_metadata.local_attn_metadata is not None) +@@ -403,47 +437,71 @@ class TritonAttentionImpl(AttentionImpl): + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table + ++ + if use_prefill_decode_attn: + # Compute attention and update output up to `num_actual_tokens`. 
+ chunked_prefill_paged_decode(query=query[:num_actual_tokens], +- key=key[:num_actual_tokens], +- value=value[:num_actual_tokens], +- output=output[:num_actual_tokens], +- kv_cache_dtype=self.kv_cache_dtype, +- key_cache=key_cache, +- value_cache=value_cache, +- block_table=block_table, +- query_start_loc=cu_seqlens_q, +- seq_lens=seqused_k, +- max_seq_len=max_seqlen_k, +- max_query_len=max_seqlen_q, +- k_scale=layer._k_scale, +- v_scale=layer._v_scale, +- alibi_slopes=self.alibi_slopes, +- sliding_window=self.sliding_window[0], +- sm_scale=self.scale) +- ++ key=key[:num_actual_tokens], ++ value=value[:num_actual_tokens], ++ output=output[:num_actual_tokens], ++ kv_cache_dtype=self.kv_cache_dtype, ++ key_cache=key_cache, ++ value_cache=value_cache, ++ block_table=block_table, ++ query_start_loc=cu_seqlens_q, ++ seq_lens=seqused_k, ++ max_seq_len=max_seqlen_k, ++ max_query_len=max_seqlen_q, ++ k_scale=layer._k_scale, ++ v_scale=layer._v_scale, ++ alibi_slopes=self.alibi_slopes, ++ sliding_window=self.sliding_window[0], ++ sm_scale=self.scale) + else: + descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + +- unified_attention( +- q=query[:num_actual_tokens], +- k=key_cache, +- v=value_cache, +- out=output[:num_actual_tokens], +- cu_seqlens_q=cu_seqlens_q, +- max_seqlen_q=max_seqlen_q, +- seqused_k=seqused_k, +- max_seqlen_k=max_seqlen_k, +- softmax_scale=self.scale, +- causal=True, +- alibi_slopes=self.alibi_slopes, +- window_size=self.sliding_window, +- block_table=block_table, +- softcap=self.logits_soft_cap, +- q_descale=None, # Not supported +- k_descale=layer._k_scale.expand(descale_shape), +- v_descale=layer._v_scale.expand(descale_shape), +- ) ++ if attn_metadata.use_rerope: ++ unified_attention_rerope( ++ q=query[:num_actual_tokens], ++ k=key_cache, ++ q2=query2[:num_actual_tokens], ++ k2=key_cache2, ++ v=value_cache, ++ out=output[:num_actual_tokens], ++ cu_seqlens_q=cu_seqlens_q, ++ max_seqlen_q=max_seqlen_q, ++ seqused_k=seqused_k, ++ max_seqlen_k=max_seqlen_k, ++ softmax_scale=self.scale, ++ causal=True, ++ rerope_window=envs.REROPE_WINDOW, ++ alibi_slopes=self.alibi_slopes, ++ window_size=self.sliding_window, ++ block_table=block_table, ++ softcap=self.logits_soft_cap, ++ q_descale=None, # Not supported ++ k_descale=layer._k_scale.expand(descale_shape), ++ v_descale=layer._v_scale.expand(descale_shape), ++ ) ++ else: ++ unified_attention( ++ q=query[:num_actual_tokens], ++ k=key_cache, ++ v=value_cache, ++ out=output[:num_actual_tokens], ++ cu_seqlens_q=cu_seqlens_q, ++ max_seqlen_q=max_seqlen_q, ++ seqused_k=seqused_k, ++ max_seqlen_k=max_seqlen_k, ++ softmax_scale=self.scale, ++ causal=True, ++ alibi_slopes=self.alibi_slopes, ++ window_size=self.sliding_window, ++ block_table=block_table, ++ softcap=self.logits_soft_cap, ++ q_descale=None, # Not supported ++ k_descale=layer._k_scale.expand(descale_shape), ++ v_descale=layer._v_scale.expand(descale_shape), ++ ) + + return output +diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py +index b0ebb00d9..190a3f4ec 100644 +--- a/vllm/v1/attention/backends/utils.py ++++ b/vllm/v1/attention/backends/utils.py +@@ -43,6 +43,8 @@ class CommonAttentionMetadata: + max_query_len: int + """Longest query in batch""" + ++ use_rerope: bool ++ + + M = TypeVar("M") + +diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py +index 43456a987..20edd1f86 100644 +--- a/vllm/v1/kv_cache_interface.py ++++ b/vllm/v1/kv_cache_interface.py +@@ -13,6 +13,8 @@ from vllm.config import VllmConfig + 
from vllm.logger import init_logger + from vllm.utils import cdiv, get_dtype_size + ++from vllm import envs ++ + logger = init_logger(__name__) + + +@@ -79,7 +81,12 @@ class AttentionSpec(KVCacheSpec): + @property + def page_size_bytes(self) -> int: + # For MLA we only store a single latent vector +- coef = 1 if self.use_mla else 2 ++ if self.use_mla: ++ coef = 1 ++ elif envs.VLLM_USE_REROPE: ++ coef = 3 ++ else: ++ coef = 2 + return coef * self.block_size * self.num_kv_heads * self.head_size \ + * get_dtype_size(self.dtype) + +@@ -88,10 +95,10 @@ class AttentionSpec(KVCacheSpec): + class FullAttentionSpec(AttentionSpec): + sliding_window: Optional[int] = None + """ +- When hybrid allocator is disabled and the model contains both full +- attention layers and sliding window attention layers, sliding +- window attention are regarded as full attention in KV cache manager +- (blocks are allocated for all tokens), while computed as sliding window ++ When hybrid allocator is disabled and the model contains both full ++ attention layers and sliding window attention layers, sliding ++ window attention are regarded as full attention in KV cache manager ++ (blocks are allocated for all tokens), while computed as sliding window + attention in model runner. + In this case, we use FullAttentionSpec and record the sliding window size. + Default to None for not using sliding window attention. +@@ -108,7 +115,7 @@ class FullAttentionSpec(AttentionSpec): + @classmethod + def merge(cls, specs: list[Self]) -> Self: + """ +- Merge a list of FullAttentionSpec objects into a single ++ Merge a list of FullAttentionSpec objects into a single + FullAttentionSpec object. + """ + merged_spec = super().merge(specs) +diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py +index 5a26e88db..f61a38550 100644 +--- a/vllm/v1/worker/gpu_model_runner.py ++++ b/vllm/v1/worker/gpu_model_runner.py +@@ -72,6 +72,8 @@ from ..sample.logits_processor import LogitsProcessorManager + from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, + sanity_check_mm_encoder_outputs, scatter_mm_placeholders) + ++from vllm import envs ++ + if TYPE_CHECKING: + import xgrammar as xgr + import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501 +@@ -317,6 +319,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): + # from the KV cache of `shared_kv_cache_layers[layer_name]`. + self.shared_kv_cache_layers: dict[str, str] = {} + ++ # use_rerope: current batch rerope state ++ # use_rerope_map: save every request rerope state ++ self.use_rerope = False ++ self.use_rerope_map: dict[str, bool] = {} ++ + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: + """ + Update the order of requests in the batch based on the attention +@@ -602,6 +609,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): + num_scheduled_tokens = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = max(tokens) + ++ # Setting use_rerope ++ if envs.VLLM_USE_REROPE: ++ use_rerope_this_batch = False ++ for req in scheduler_output.scheduled_new_reqs: ++ self.use_rerope_map[req.req_id] = len(req.prompt_token_ids) > envs.REROPE_WINDOW ++ for req_id in req_ids: ++ use_rerope_this_batch |= self.use_rerope_map[req_id] ++ self.use_rerope = use_rerope_this_batch ++ ++ + # Get request indices. 
+ # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.arange_np[:num_reqs], +@@ -705,6 +722,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, ++ use_rerope=self.use_rerope + ) + + attn_metadata: dict[str, Any] = {} +@@ -1943,7 +1961,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): + Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set. + This is to help balance expert-selection + - during profile_run +- - during DP rank dummy run ++ - during DP rank dummy run + """ + dp_size = self.vllm_config.parallel_config.data_parallel_size + randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1 +-- +2.34.1 + + +From 7fbc495957cdb76c43ce774735138e3a19c8273b Mon Sep 17 00:00:00 2001 From: wenxinwang Date: Tue, 23 Dec 2025 19:44:21 -0800 -Subject: [PATCH] kvcomp qwen deepseek +Subject: [PATCH 2/2] sparse + cache blend --- - vllm/attention/layer.py | 63 ++++++++++++++++- + vllm/attention/layer.py | 64 +++++++++++++++++- vllm/model_executor/models/llama.py | 21 +++++- vllm/model_executor/models/qwen2.py | 23 ++++++- vllm/v1/attention/backends/flash_attn.py | 7 ++ @@ -15,12 +1594,12 @@ Subject: [PATCH] kvcomp qwen deepseek vllm/v1/core/sched/output.py | 3 + vllm/v1/core/sched/scheduler.py | 30 +++++++- vllm/v1/worker/block_table.py | 13 ++++ - vllm/v1/worker/gpu_model_runner.py | 80 +++++++++++++++++++--- + vllm/v1/worker/gpu_model_runner.py | 79 +++++++++++++++++++--- vllm/v1/worker/gpu_worker.py | 2 + 13 files changed, 275 insertions(+), 20 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py -index f0ad68b16..ba93960de 100644 +index 39dc4bf1d..ff7a26e7f 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -8,6 +8,7 @@ import torch.nn as nn @@ -39,19 +1618,23 @@ index f0ad68b16..ba93960de 100644 class Attention(nn.Module): -@@ -409,9 +411,10 @@ def unified_attention( +@@ -440,13 +442,14 @@ def unified_attention( attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] + query, key, value, _ = maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) - output = self.impl.forward(self, query, key, value, kv_cache, - attn_metadata) + if envs.VLLM_USE_REROPE: + output = self.impl.forward(self, query, key, value, kv_cache, + attn_metadata, query2=query2, key2=key2) + else: + output = self.impl.forward(self, query, key, value, kv_cache, + attn_metadata) - + maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) maybe_save_kv_layer_to_connector(layer_name, kv_cache) return output -@@ -449,6 +452,15 @@ def unified_attention_with_output( +@@ -488,6 +491,15 @@ def unified_attention_with_output( attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] @@ -64,21 +1647,22 @@ index f0ad68b16..ba93960de 100644 + query, _, _, _ = maybe_execute_sparse_attention_begin( + query, key, value, layer_name, forward_context, output, k_hash=k_hash + ) - self.impl.forward(self, - query, - key, -@@ -457,6 +469,10 @@ def unified_attention_with_output( - attn_metadata, - output=output, - output_scale=output_scale) + if envs.VLLM_USE_REROPE: + self.impl.forward(self, + query, +@@ -508,6 +520,11 @@ def unified_attention_with_output( + attn_metadata, + output=output, + 
output_scale=output_scale) + if not self.use_mla: + maybe_execute_sparse_attention_finished( + query, key, value, output, layer_name, forward_context + ) - ++ maybe_save_kv_layer_to_connector(layer_name, kv_cache) -@@ -479,3 +495,48 @@ direct_register_custom_op( + +@@ -531,3 +548,48 @@ direct_register_custom_op( fake_impl=unified_attention_with_output_fake, dispatch_key=current_platform.dispatch_key, ) @@ -181,7 +1765,7 @@ index 5d5080479..39cb2f4fb 100644 if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py -index 7ef9d248d..e35ab2fdc 100644 +index 2d75195eb..512df9345 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -56,6 +56,12 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, @@ -195,9 +1779,9 @@ index 7ef9d248d..e35ab2fdc 100644 + maybe_execute_sparse_layer_finished, + ) - - class Qwen2MLP(nn.Module): -@@ -255,11 +261,16 @@ class Qwen2DecoderLayer(nn.Module): + import math + from vllm import envs +@@ -281,11 +287,16 @@ class Qwen2DecoderLayer(nn.Module): positions=positions, hidden_states=hidden_states, ) @@ -215,7 +1799,7 @@ index 7ef9d248d..e35ab2fdc 100644 return hidden_states, residual -@@ -352,11 +363,21 @@ class Qwen2Model(nn.Module): +@@ -378,11 +389,21 @@ class Qwen2Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] for layer in self.layers[self.start_layer:self.end_layer]: @@ -601,7 +2185,7 @@ index 8f4e8d64c..f45e39f5c 100644 for i, block_table in enumerate(self.block_tables): block_table.add_row(block_ids[i], row_idx) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py -index 5a26e88db..6a39240d2 100644 +index f61a38550..82decc8d7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -15,6 +15,7 @@ import torch.nn as nn @@ -612,17 +2196,16 @@ index 5a26e88db..6a39240d2 100644 from vllm.attention import AttentionType, get_attn_backend from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention -@@ -72,6 +73,9 @@ from ..sample.logits_processor import LogitsProcessorManager - from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, +@@ -73,6 +74,8 @@ from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) + from vllm import envs +from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse +from ucm.sparse.base import UcmSparseMetadata, INVALID_SLOT -+ + if TYPE_CHECKING: import xgrammar as xgr - import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501 -@@ -365,6 +369,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -372,6 +375,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): """ # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: @@ -630,7 +2213,7 @@ index 5a26e88db..6a39240d2 100644 self.requests.pop(req_id, None) self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. -@@ -468,11 +473,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -475,11 +479,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Update the states of the running/resumed requests. 
is_last_rank = get_pp_group().is_last_rank req_data = scheduler_output.scheduled_cached_reqs @@ -644,7 +2227,7 @@ index 5a26e88db..6a39240d2 100644 # Update the cached states. req_state.num_computed_tokens = num_computed_tokens -@@ -494,15 +501,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -501,15 +507,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): new_token_ids[-num_new_tokens:]) # Update the block IDs. @@ -666,7 +2249,7 @@ index 5a26e88db..6a39240d2 100644 req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: -@@ -515,6 +522,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -522,6 +528,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) @@ -675,7 +2258,7 @@ index 5a26e88db..6a39240d2 100644 self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu -@@ -623,6 +632,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -640,6 +648,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.uses_mrope: self._calc_mrope_positions(scheduler_output) @@ -695,7 +2278,7 @@ index 5a26e88db..6a39240d2 100644 # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] -@@ -652,11 +674,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -669,11 +690,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): # block_size. block_table_indices = ( req_indices * block_table.max_num_blocks_per_req + @@ -709,7 +2292,7 @@ index 5a26e88db..6a39240d2 100644 np.add( block_numbers * block_size, block_offsets, -@@ -666,9 +688,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -683,9 +704,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens @@ -724,7 +2307,7 @@ index 5a26e88db..6a39240d2 100644 # Copy the tensors to the GPU. 
self.input_ids[:total_num_scheduled_tokens].copy_( -@@ -680,6 +704,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -697,6 +720,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): non_blocking=True) else: # Common case (1D positions) @@ -733,7 +2316,7 @@ index 5a26e88db..6a39240d2 100644 self.positions[:total_num_scheduled_tokens].copy_( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) -@@ -1370,6 +1396,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1388,6 +1413,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): skip_cuda_graphs=skip_cuda_graphs, ): self.maybe_setup_kv_connector(scheduler_output) @@ -741,7 +2324,7 @@ index 5a26e88db..6a39240d2 100644 model_output = self.model( input_ids=input_ids, -@@ -1379,6 +1406,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1397,6 +1423,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) self.maybe_wait_for_kv_save() @@ -750,7 +2333,7 @@ index 5a26e88db..6a39240d2 100644 finished_sending, finished_recving = ( self.get_finished_kv_transfers(scheduler_output)) -@@ -1723,6 +1752,30 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1741,6 +1769,30 @@ class GPUModelRunner(LoRAModelRunnerMixin): if has_kv_transfer_group(): get_kv_transfer_group().wait_for_save() @@ -781,7 +2364,7 @@ index 5a26e88db..6a39240d2 100644 @staticmethod def get_finished_kv_transfers( scheduler_output: "SchedulerOutput", -@@ -2570,6 +2623,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -2588,6 +2640,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, kv_cache_raw_tensors) diff --git a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_rerope_patch.py b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_rerope_patch.py index e7bc75ed7..36711f249 100644 --- a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_rerope_patch.py +++ b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_rerope_patch.py @@ -416,9 +416,6 @@ def _patch_qwen2_model() -> None: from ucm.sparse.rerope.rerope_utils import default_config - REROPE_WINDOW = default_config.rerope_window - TRAINING_LENGTH = default_config.training_length - def Qwen2Attention_forward( self, positions: torch.Tensor, @@ -430,6 +427,8 @@ def Qwen2Attention_forward( q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) ###################### rerope patch ############### + REROPE_WINDOW = default_config.rerope_window + TRAINING_LENGTH = default_config.training_length if attn_metadata and next(iter(attn_metadata.values())).use_rerope: q *= ( ((positions + 1)[:, None].log() / math.log(TRAINING_LENGTH)) @@ -448,9 +447,10 @@ def Qwen2Attention_forward( k0 = k.clone() q, k = self.rotary_emb(positions, q, k) q2 = q.clone() + + attn_output = self.attn(q, k, v, query2=q2, key2=k0) ###################### rerope patch ############### - attn_output = self.attn(q, k, q2, k0, v) output, _ = self.o_proj(attn_output) return output @@ -472,9 +472,6 @@ def _patch_qwen3_model() -> None: from ucm.sparse.rerope.rerope_utils import default_config - REROPE_WINDOW = default_config.rerope_window - TRAINING_LENGTH = default_config.training_length - def Qwen3Attention_forward( self, positions: torch.Tensor, @@ -497,6 +494,8 @@ def Qwen3Attention_forward( k = k_by_head.view(k.shape) ###################### rerope patch ############### + REROPE_WINDOW = default_config.rerope_window + TRAINING_LENGTH = default_config.training_length if attn_metadata and next(iter(attn_metadata.values())).use_rerope: q *= ( ((positions + 1)[:, None].log() / math.log(TRAINING_LENGTH)) 
@@ -514,9 +513,10 @@ def Qwen3Attention_forward( k0 = k.clone() q, k = self.rotary_emb(positions, q, k) q2 = q.clone() + + attn_output = self.attn(q, k, v, query2=q2, key2=k0) ###################### rerope patch ############### - attn_output = self.attn(q, k, q2, k0, v) output, _ = self.o_proj(attn_output) return output @@ -538,9 +538,6 @@ def _patch_qwen3moe_model() -> None: from ucm.sparse.rerope.rerope_utils import default_config - REROPE_WINDOW = default_config.rerope_window - TRAINING_LENGTH = default_config.training_length - def Qwen3MoeAttention_forward( self, positions: torch.Tensor, @@ -563,6 +560,8 @@ def Qwen3MoeAttention_forward( k = k_by_head.view(k.shape) ###################### rerope patch ############### + REROPE_WINDOW = default_config.rerope_window + TRAINING_LENGTH = default_config.training_length if attn_metadata and next(iter(attn_metadata.values())).use_rerope: q *= ( ((positions + 1)[:, None].log() / math.log(TRAINING_LENGTH)) @@ -580,9 +579,10 @@ def Qwen3MoeAttention_forward( k0 = k.clone() q, k = self.rotary_emb(positions, q, k) q2 = q.clone() + + attn_output = self.attn(q, k, v, query2=q2, key2=k0) ###################### rerope patch ############### - attn_output = self.attn(q, k, q2, k0, v) output, _ = self.o_proj(attn_output) return output @@ -609,9 +609,9 @@ def attn_forward( self, query: torch.Tensor, key: torch.Tensor, - query2: Optional[torch.Tensor], - key2: Optional[torch.Tensor], value: torch.Tensor, + query2: Optional[torch.Tensor] = None, + key2: Optional[torch.Tensor] = None, # For some alternate attention backends like MLA the attention output # shape does not match the query shape, so we optionally let the model # definition specify the output tensor shape. @@ -666,16 +666,22 @@ def attn_forward( self, query, key, - query2, - key2, value, self_kv_cache, attn_metadata, + query2=query2, + key2=key2, output=output, ) else: torch.ops.vllm.unified_attention_with_output( - query, key, query2, key2, value, output, self.layer_name + query, + key, + value, + output, + self.layer_name, + query2=query2, + key2=key2, ) return output.view(-1, hidden_size) else: @@ -689,15 +695,15 @@ def attn_forward( self, query, key, - query2, - key2, value, self_kv_cache, attn_metadata, + query2=query2, + key2=key2, ) else: return torch.ops.vllm.unified_attention( - query, key, query2, key2, value, self.layer_name + query, key, value, self.layer_name, query2=query2, key2=key2 ) ###################### rerope patch ############### @@ -721,10 +727,10 @@ def __getattr__(self, name): def unified_attention_impl( query: torch.Tensor, key: torch.Tensor, - query2: Optional[torch.Tensor], - key2: Optional[torch.Tensor], value: torch.Tensor, layer_name: str, + query2: Optional[torch.Tensor] = None, + key2: Optional[torch.Tensor] = None, ) -> torch.Tensor: wait_for_kv_layer_from_connector(layer_name) @@ -736,7 +742,14 @@ def unified_attention_impl( kv_cache = self.kv_cache[forward_context.virtual_engine] output = self.impl.forward( - self, query, key, query2, key2, value, kv_cache, attn_metadata + self, + query, + key, + value, + kv_cache, + attn_metadata, + query2=query2, + key2=key2, ) maybe_save_kv_layer_to_connector(layer_name, kv_cache) @@ -745,11 +758,11 @@ def unified_attention_impl( def unified_attention_with_output_impl( query: torch.Tensor, key: torch.Tensor, - query2: Optional[torch.Tensor], - key2: Optional[torch.Tensor], value: torch.Tensor, output: torch.Tensor, layer_name: str, + query2: Optional[torch.Tensor] = None, + key2: Optional[torch.Tensor] = None, output_scale: 
Optional[torch.Tensor] = None, ) -> None: wait_for_kv_layer_from_connector(layer_name) @@ -763,11 +776,11 @@ def unified_attention_with_output_impl( self, query, key, - query2, - key2, value, kv_cache, attn_metadata, + query2=query2, + key2=key2, output=output, output_scale=output_scale, ) @@ -987,11 +1000,11 @@ def TritonAttentionImpl_forwad( layer: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, - query2: Optional[torch.Tensor], - key2: Optional[torch.Tensor], value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: TritonAttentionMetadata, + query2: Optional[torch.Tensor] = None, + key2: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -1048,22 +1061,24 @@ def TritonAttentionImpl_forwad( layer._v_scale, ) ###################### rerope patch ############### - torch.ops._C_cache_ops.reshape_and_cache_flash( - key2, - value, - key_cache2, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + if key2 is not None: + torch.ops._C_cache_ops.reshape_and_cache_flash( + key2, + value, + key_cache2, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) ###################### rerope patch ############### if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(self.fp8_dtype) ###################### rerope patch ############### - key_cache2 = key_cache2.view(self.fp8_dtype) + if key_cache2 is not None: + key_cache2 = key_cache2.view(self.fp8_dtype) ###################### rerope patch ############### value_cache = value_cache.view(self.fp8_dtype) num_tokens, num_heads, head_size = query.shape @@ -1079,13 +1094,14 @@ def TritonAttentionImpl_forwad( ) query = query.reshape((num_tokens, num_heads, head_size)) ###################### rerope patch ############### - query2, _ = ops.scaled_fp8_quant( - query2.reshape( - (num_tokens, num_heads * head_size) - ).contiguous(), - layer._q_scale, - ) - query2 = query2.reshape((num_tokens, num_heads, head_size)) + if query2 is not None: + query2, _ = ops.scaled_fp8_quant( + query2.reshape( + (num_tokens, num_heads * head_size) + ).contiguous(), + layer._q_scale, + ) + query2 = query2.reshape((num_tokens, num_heads, head_size)) ###################### rerope patch ############### use_local_attn = (