Commit 8e28f90

update the vllm-ascend patch for KVComp in NPU
1 parent 4361da0 commit 8e28f90


ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch

Lines changed: 233 additions & 7 deletions
@@ -1,7 +1,7 @@
-From c92cb68fd1fa6215cd6d5b207b95c841ac20dbe1 Mon Sep 17 00:00:00 2001
+From 3ce6a6053ca9854d95828ab600e01b848a253a56 Mon Sep 17 00:00:00 2001
 From: wenxinwang <[email protected]>
 Date: Tue, 23 Dec 2025 19:21:33 -0800
-Subject: [PATCH] sparse patch for vllm-ascend
+Subject: [PATCH 1/4] sparse patch for vllm-ascend

 ---
 vllm_ascend/attention/attention_v1.py | 80 ++++++++++++++++++++++
@@ -11,7 +11,7 @@ Subject: [PATCH] sparse patch for vllm-ascend
 4 files changed, 201 insertions(+), 16 deletions(-)

 diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
-index 7d7f488f..ea982244 100644
+index 7d7f488..ea98224 100644
 --- a/vllm_ascend/attention/attention_v1.py
 +++ b/vllm_ascend/attention/attention_v1.py
 @@ -24,6 +24,9 @@ import torch_npu
@@ -129,7 +129,7 @@ index 7d7f488f..ea982244 100644
 def unified_attention_with_output_fake(
 query: torch.Tensor,
 diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
-index f50fe56e..ae8f50bf 100644
+index f50fe56..ae8f50b 100644
 --- a/vllm_ascend/attention/mla_v1.py
 +++ b/vllm_ascend/attention/mla_v1.py
 @@ -13,10 +13,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size
@@ -185,7 +185,7 @@ index f50fe56e..ae8f50bf 100644

 return output_padded
 diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
-index eabcdbcc..782b9a3b 100644
+index eabcdbc..782b9a3 100644
 --- a/vllm_ascend/worker/model_runner_v1.py
 +++ b/vllm_ascend/worker/model_runner_v1.py
 @@ -39,7 +39,10 @@ from vllm.config import CompilationLevel, VllmConfig
@@ -386,7 +386,7 @@ index eabcdbcc..782b9a3b 100644
 + ucm_sparse.request_finished_in_worker(request_id)
 \ No newline at end of file
 diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
-index df03d508..5d5d9b5a 100644
+index df03d50..5d5d9b5 100644
 --- a/vllm_ascend/worker/worker_v1.py
 +++ b/vllm_ascend/worker/worker_v1.py
 @@ -17,6 +17,7 @@
@@ -458,5 +458,231 @@ index df03d508..5d5d9b5a 100644
 def _init_profiler(self):
 # Torch profiler. Enabled and configured through env vars:
 --
-2.34.1
+2.43.0
+
+
+From 9b685ed319bae31f5e56d596499fbd7c7f60c4b0 Mon Sep 17 00:00:00 2001
+From: ldeng <[email protected]>
+Date: Mon, 29 Dec 2025 17:59:28 +0800
+Subject: [PATCH 2/4] update attention_v1 for kvcomp in NPU
+
+---
+vllm_ascend/attention/attention_v1.py | 59 +++++++++++++++++++--------
+1 file changed, 43 insertions(+), 16 deletions(-)
+
+diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
+index ea98224..b924d8e 100644
+--- a/vllm_ascend/attention/attention_v1.py
++++ b/vllm_ascend/attention/attention_v1.py
+@@ -37,7 +37,7 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
+nd_to_nz_2d, nd_to_nz_spec)
+
+from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
+-
++import os
+
+class AscendAttentionBackend(AttentionBackend):
+accept_output_buffer: bool = True
+@@ -132,8 +132,9 @@ class AscendMetadata:
+# the computed tokens + new tokens None if it is a decoding.
+query_start_loc: torch.Tensor
+query_lens: torch.Tensor
++ query_lens_device: torch.Tensor # (ldeng) added for KVComp
+seq_lens: torch.Tensor
+-
++ seq_lens_device: torch.Tensor # (ldeng) added for KVComp
+# max value of number of tokens across dp group
+max_num_tokens_across_dp: int = 0
+
+@@ -182,15 +183,22 @@ class AscendAttentionMetadataBuilder:
+block_table[:num_reqs])
+
+query_lens = self.runner.query_lens
++ query_lens_device = query_lens.pin_memory().to(self.runner.device, non_blocking=True)
+seq_lens = self.runner.seq_lens_cpu[:num_reqs]
++ seq_lens_device = seq_lens.pin_memory().to(self.runner.device, non_blocking=True)
+slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
+self.runner.device, non_blocking=True)
+attn_mask = self.runner.attn_mask
+attn_state = self.runner.attn_state
+query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
+- query_start_loc = query_start_loc_cpu.to(self.runner.device,
++ query_start_loc = query_start_loc_cpu.pin_memory().to(self.runner.device,
+non_blocking=True)
+
++ if has_ucm_sparse():
++ ucm_sparse = get_ucm_sparse()
++ if os.getenv("VLLM_HASH_ATTENTION", "0") == "1":
++ ucm_sparse.build_decode_attention_meta_npu(query_lens, seq_lens, block_table)
++
+if is_310p():
+if attn_state == AscendAttentionState.PrefillNoCache:
+mask_nz = nd_to_nz_2d(attn_mask)
+@@ -206,7 +214,9 @@ class AscendAttentionMetadataBuilder:
+block_tables=block_table,
+query_start_loc=query_start_loc,
+query_lens=query_lens,
++ query_lens_device=query_lens_device,
+seq_lens=seq_lens,
++ seq_lens_device=seq_lens_device,
+max_query_len=max_query_len,
+slot_mapping=slot_mapping,
+attn_mask=attn_mask,
+@@ -279,8 +289,17 @@ class AscendAttentionBackendImpl(AttentionImpl):
+shape = [batch_size * seq_len, num_heads, head_size]
+"""
+num_tokens = query.shape[0]
+- use_kv_cache_int8 = kv_cache.numel(
+- ) > 0 and kv_cache[0].dtype == torch.int8
++
++ # In NPU, forward could be called directly, not by unified_ascend_attention_with_output
++ actual_cache = kv_cache[0] if isinstance(kv_cache, tuple) else kv_cache
++ if actual_cache is not None:
++ use_kv_cache_int8 = actual_cache.numel() > 0 and actual_cache.dtype == torch.int8
++ else:
++ use_kv_cache_int8 = False
++ kv_cache = actual_cache
++
++ #use_kv_cache_int8 = kv_cache.numel(
++ #) > 0 and kv_cache[0].dtype == torch.int8
+if output is None:
+output = torch.empty(num_tokens,
+self.num_heads,
+@@ -449,14 +468,20 @@ def unified_ascend_attention_with_output(
+output: torch.Tensor,
+layer_name: str,
+) -> None:
+- wait_for_kv_layer_from_connector(layer_name)
++ # wait_for_kv_layer_from_connector(layer_name)
+
+forward_context: ForwardContext = get_forward_context()
+attn_metadata = forward_context.attn_metadata
+self = forward_context.no_compile_layers[layer_name]
+kv_cache = self.kv_cache[forward_context.virtual_engine]
+- if not self.use_mla:
+- query, _, _, _ = maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
++
++ # In NPU, during dummy_run, kv_cache could be a empty tensor, so we need to check the length of kv_cache
++ if os.getenv("VLLM_HASH_ATTENTION", "0") == "1" and len(kv_cache) > 0:
++ kv_cache, k_hash = kv_cache
++ else:
++ k_hash = None
++ if attn_metadata is not None:
++ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context, output, k_hash=k_hash)
+self.impl.forward(self,
+query,
+key,
+@@ -465,9 +490,10 @@ def unified_ascend_attention_with_output(
+attn_metadata,
+output,
+trace_flag=False)
+- if not self.use_mla:
++
++ if attn_metadata is not None:
+maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
+- maybe_save_kv_layer_to_connector(layer_name, kv_cache)
++ # maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+return
+
+def wait_for_kv_layer_from_connector(layer_name: str):
+@@ -506,19 +532,20 @@ maybe_execute_sparse_attention_begin(
+forward_context: ForwardContext,
+output: Optional[torch.Tensor] = None,
+phase: Optional[str] = None,
++ k_hash: Optional[torch.Tensor] = None,
++ decode_ql_nope: Optional[torch.Tensor] = None,
++ decode_q_pe: Optional[torch.Tensor] = None,
+):
+if not has_ucm_sparse():
+- return query, key, value, output
++ return
+
+ucm_sparse = get_ucm_sparse()
+
+attn_metadata = forward_context.attn_metadata
+if attn_metadata is None:
+- return query, key, value, output
++ return
+
+- return ucm_sparse.attention_begin(
+- query, key, value, layer_name, forward_context, output, phase
+- )
++ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context, output, phase, k_hash, decode_ql_nope, decode_q_pe)
+
+def maybe_execute_sparse_attention_finished(
+query: torch.Tensor,
+@@ -555,4 +582,4 @@ direct_register_custom_op(
+mutates_args=["output"],
+fake_impl=unified_attention_with_output_fake,
+dispatch_key="PrivateUse1",
+-)
++)
+\ No newline at end of file
+--
+2.43.0
+
+
+From d82f08b6981ca4d01432309b659bea57b9d18576 Mon Sep 17 00:00:00 2001
+From: ldeng <[email protected]>
+Date: Mon, 29 Dec 2025 18:04:02 +0800
+Subject: [PATCH 3/4] call initialize_kv_hash_cache_tensors_npu to allocate
+hashk cache in NPUModelRunner when VLLM_HASH_ATTENTION is enabled for KVComp
+
+---
+vllm_ascend/worker/model_runner_v1.py | 5 +++++
+1 file changed, 5 insertions(+)
+
+diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
+index 782b9a3..766316e 100644
+--- a/vllm_ascend/worker/model_runner_v1.py
++++ b/vllm_ascend/worker/model_runner_v1.py
+@@ -1993,6 +1993,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
+# KV cache specs.
+raise ValueError("Unknown KV cache spec type.")
+
++ if has_ucm_sparse():
++ ucm_sparse = get_ucm_sparse()
++ if os.getenv("VLLM_HASH_ATTENTION", "0") == "1":
++ ucm_sparse.initialize_kv_hash_cache_tensors_npu(kv_caches, self.device)
++
+bind_kv_cache(
+kv_caches,
+self.vllm_config.compilation_config.static_forward_context,
+--
+2.43.0
+
+
+From 3d28d477a4887e7d5c909a66b448086994751567 Mon Sep 17 00:00:00 2001
+From: ldeng <[email protected]>
+Date: Mon, 29 Dec 2025 18:04:50 +0800
+Subject: [PATCH 4/4] uncomment connector
+
+---
+vllm_ascend/attention/attention_v1.py | 4 ++--
+1 file changed, 2 insertions(+), 2 deletions(-)
+
+diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
+index b924d8e..ece7d17 100644
+--- a/vllm_ascend/attention/attention_v1.py
++++ b/vllm_ascend/attention/attention_v1.py
+@@ -468,7 +468,7 @@ def unified_ascend_attention_with_output(
+output: torch.Tensor,
+layer_name: str,
+) -> None:
+- # wait_for_kv_layer_from_connector(layer_name)
++ wait_for_kv_layer_from_connector(layer_name)
+
+forward_context: ForwardContext = get_forward_context()
+attn_metadata = forward_context.attn_metadata
+@@ -493,7 +493,7 @@ def unified_ascend_attention_with_output(
+
+if attn_metadata is not None:
+maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
+- # maybe_save_kv_layer_to_connector(layer_name, kv_cache)
++ maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+return
+
+def wait_for_kv_layer_from_connector(layer_name: str):
+--
+2.43.0
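
Note on the gating used throughout the added patches: every new KVComp code path is guarded twice, by a registered UCM sparse module and by the VLLM_HASH_ATTENTION environment variable being "1". A minimal standalone sketch of that guard is below; the helper name kvcomp_hash_attention_enabled is introduced here only for illustration, and setting the variable in the serving environment is an assumed usage rather than something this commit documents. The has_ucm_sparse import is the same one the patched files use.

import os

from ucm.sparse.state import has_ucm_sparse  # same import as in the patched attention_v1.py

def kvcomp_hash_attention_enabled() -> bool:
    # Mirrors the guard added in attention_v1.py and model_runner_v1.py:
    # the NPU hash-attention / KVComp path only runs when a UCM sparse
    # module is registered and VLLM_HASH_ATTENTION is set to "1".
    return has_ucm_sparse() and os.getenv("VLLM_HASH_ATTENTION", "0") == "1"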
