
Commit be5bcbc

[feat] cherry-pick KVComp in NPU -- HBM version into the 0.2.0-release branch (#619)
# Purpose

What this PR does / why we need it?

Cherry-picks KVComp in NPU (HBM version) into the 0.2.0-release branch.

Co-authored-by: Lei Deng <[email protected]>
1 parent db04137 commit be5bcbc

File tree

8 files changed: +533 −78 lines


examples/offline_inference_kvcomphbm.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
                 },
             }
         ],
-        "ucm_sparse_config": {"GSA": {}},
+        "ucm_sparse_config": {"KvCompOnDevice": {}},
     },
 )
 
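For context, a minimal sketch of how this option is typically passed to vLLM. Only the ucm_sparse_config entry comes from this diff; the connector name, role, and model path below are placeholders, not taken from this commit.

from vllm import LLM
from vllm.config import KVTransferConfig

# Sketch only: surrounding keys are assumptions, not part of this diff.
llm = LLM(
    model="/path/to/model",                       # placeholder
    kv_transfer_config=KVTransferConfig(
        kv_connector="UnifiedCacheConnectorV1",   # assumed UCM connector name
        kv_role="kv_both",
        kv_connector_extra_config={
            "ucm_sparse_config": {"KvCompOnDevice": {}},  # from this diff
        },
    ),
)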

examples/ucm_config_example.yaml

Lines changed: 1 addition & 2 deletions
@@ -31,8 +31,7 @@ load_only_first_rank: false
 # Or for GSA:
 #   GSA: {}
 # Or for KvCompOnDevice:
-#   GSA:
-#     "kvcompOnDevice_config_path": "workspace/unified-cache-management/ucm/sparse/kvcomp/configs/kvcomp_qwen3_32B_config.json"
+#   KvCompOnDevice: {}
 
 
 # Whether to use layerwise loading/saving (optional, default: True for UCMConnector)

ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch

Lines changed: 230 additions & 4 deletions
@@ -11,7 +11,7 @@ Subject: [PATCH] modify ascend patch for register_kv_cache
  4 files changed, 204 insertions(+), 16 deletions(-)
 
 diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
-index 7d7f488f..ea982244 100644
+index 7d7f488..ea98224 100644
 --- a/vllm_ascend/attention/attention_v1.py
 +++ b/vllm_ascend/attention/attention_v1.py
 @@ -24,6 +24,9 @@ import torch_npu
@@ -129,7 +129,7 @@ index 7d7f488f..ea982244 100644
  def unified_attention_with_output_fake(
      query: torch.Tensor,
 diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
-index f50fe56e..ae8f50bf 100644
+index f50fe56..ae8f50b 100644
 --- a/vllm_ascend/attention/mla_v1.py
 +++ b/vllm_ascend/attention/mla_v1.py
 @@ -13,10 +13,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size
@@ -396,7 +396,7 @@ index eabcdbcc..2762fbc7 100644
 +        ucm_sparse.request_finished_in_worker(request_id)
 \ No newline at end of file
 diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
-index df03d508..5d5d9b5a 100644
+index df03d50..5d5d9b5 100644
 --- a/vllm_ascend/worker/worker_v1.py
 +++ b/vllm_ascend/worker/worker_v1.py
 @@ -17,6 +17,7 @@
@@ -468,5 +468,231 @@ index df03d508..5d5d9b5a 100644
      def _init_profiler(self):
          # Torch profiler. Enabled and configured through env vars:
 --
-2.50.1.windows.1
+2.43.0
+
+
+From 9b685ed319bae31f5e56d596499fbd7c7f60c4b0 Mon Sep 17 00:00:00 2001
+From: ldeng <[email protected]>
+Date: Mon, 29 Dec 2025 17:59:28 +0800
+Subject: [PATCH 2/4] update attention_v1 for kvcomp in NPU
+
+---
+ vllm_ascend/attention/attention_v1.py | 59 +++++++++++++++++++--------
+ 1 file changed, 43 insertions(+), 16 deletions(-)
+
+diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
+index ea98224..b924d8e 100644
+--- a/vllm_ascend/attention/attention_v1.py
++++ b/vllm_ascend/attention/attention_v1.py
+@@ -37,7 +37,7 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
+                                nd_to_nz_2d, nd_to_nz_spec)
+ 
+ from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse
+-
++import os
+ 
+ class AscendAttentionBackend(AttentionBackend):
+     accept_output_buffer: bool = True
+@@ -132,8 +132,9 @@ class AscendMetadata:
+     # the computed tokens + new tokens None if it is a decoding.
+     query_start_loc: torch.Tensor
+     query_lens: torch.Tensor
++    query_lens_device: torch.Tensor  # (ldeng) added for KVComp
+     seq_lens: torch.Tensor
+-
++    seq_lens_device: torch.Tensor  # (ldeng) added for KVComp
+     # max value of number of tokens across dp group
+     max_num_tokens_across_dp: int = 0
+ 
+@@ -182,15 +183,22 @@ class AscendAttentionMetadataBuilder:
+             block_table[:num_reqs])
+ 
+         query_lens = self.runner.query_lens
++        query_lens_device = query_lens.pin_memory().to(self.runner.device, non_blocking=True)
+         seq_lens = self.runner.seq_lens_cpu[:num_reqs]
++        seq_lens_device = seq_lens.pin_memory().to(self.runner.device, non_blocking=True)
+         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
+             self.runner.device, non_blocking=True)
+         attn_mask = self.runner.attn_mask
+         attn_state = self.runner.attn_state
+         query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
+-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
++        query_start_loc = query_start_loc_cpu.pin_memory().to(self.runner.device,
+                                                  non_blocking=True)
+ 
++        if has_ucm_sparse():
++            ucm_sparse = get_ucm_sparse()
++            if os.getenv("VLLM_HASH_ATTENTION", "0") == "1":
++                ucm_sparse.build_decode_attention_meta_npu(query_lens, seq_lens, block_table)
++
+         if is_310p():
+             if attn_state == AscendAttentionState.PrefillNoCache:
+                 mask_nz = nd_to_nz_2d(attn_mask)
+@@ -206,7 +214,9 @@ class AscendAttentionMetadataBuilder:
+             block_tables=block_table,
+             query_start_loc=query_start_loc,
+             query_lens=query_lens,
++            query_lens_device=query_lens_device,
+             seq_lens=seq_lens,
++            seq_lens_device=seq_lens_device,
+             max_query_len=max_query_len,
+             slot_mapping=slot_mapping,
+             attn_mask=attn_mask,
+@@ -279,8 +289,17 @@ class AscendAttentionBackendImpl(AttentionImpl):
+             shape = [batch_size * seq_len, num_heads, head_size]
+         """
+         num_tokens = query.shape[0]
+-        use_kv_cache_int8 = kv_cache.numel(
+-        ) > 0 and kv_cache[0].dtype == torch.int8
++
++        # In NPU, forward could be called directly, not by unified_ascend_attention_with_output
++        actual_cache = kv_cache[0] if isinstance(kv_cache, tuple) else kv_cache
++        if actual_cache is not None:
++            use_kv_cache_int8 = actual_cache.numel() > 0 and actual_cache.dtype == torch.int8
++        else:
++            use_kv_cache_int8 = False
++        kv_cache = actual_cache
++
++        #use_kv_cache_int8 = kv_cache.numel(
++        #) > 0 and kv_cache[0].dtype == torch.int8
+         if output is None:
+             output = torch.empty(num_tokens,
+                                  self.num_heads,
+@@ -449,14 +468,20 @@ def unified_ascend_attention_with_output(
+     output: torch.Tensor,
+     layer_name: str,
+ ) -> None:
+-    wait_for_kv_layer_from_connector(layer_name)
++    # wait_for_kv_layer_from_connector(layer_name)
+ 
+     forward_context: ForwardContext = get_forward_context()
+     attn_metadata = forward_context.attn_metadata
+     self = forward_context.no_compile_layers[layer_name]
+     kv_cache = self.kv_cache[forward_context.virtual_engine]
+-    if not self.use_mla:
+-        query, _, _, _ = maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
++
++    # In NPU, during dummy_run, kv_cache could be a empty tensor, so we need to check the length of kv_cache
++    if os.getenv("VLLM_HASH_ATTENTION", "0") == "1" and len(kv_cache) > 0:
++        kv_cache, k_hash = kv_cache
++    else:
++        k_hash = None
++    if attn_metadata is not None:
++        maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context, output, k_hash=k_hash)
+     self.impl.forward(self,
+                       query,
+                       key,
+@@ -465,9 +490,10 @@ def unified_ascend_attention_with_output(
+                       attn_metadata,
+                       output,
+                       trace_flag=False)
+-    if not self.use_mla:
++
++    if attn_metadata is not None:
+         maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
+-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
++    # maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+     return
+ 
+ def wait_for_kv_layer_from_connector(layer_name: str):
+@@ -506,19 +532,20 @@ def maybe_execute_sparse_attention_begin(
+     forward_context: ForwardContext,
+     output: Optional[torch.Tensor] = None,
+     phase: Optional[str] = None,
++    k_hash: Optional[torch.Tensor] = None,
++    decode_ql_nope: Optional[torch.Tensor] = None,
++    decode_q_pe: Optional[torch.Tensor] = None,
+ ):
+     if not has_ucm_sparse():
+-        return query, key, value, output
++        return
+ 
+     ucm_sparse = get_ucm_sparse()
+ 
+     attn_metadata = forward_context.attn_metadata
+     if attn_metadata is None:
+-        return query, key, value, output
++        return
+ 
+-    return ucm_sparse.attention_begin(
+-        query, key, value, layer_name, forward_context, output, phase
+-    )
++    ucm_sparse.attention_begin(query, key, value, layer_name, forward_context, output, phase, k_hash, decode_ql_nope, decode_q_pe)
+ 
+ def maybe_execute_sparse_attention_finished(
+     query: torch.Tensor,
+@@ -555,4 +582,4 @@ direct_register_custom_op(
+     mutates_args=["output"],
+     fake_impl=unified_attention_with_output_fake,
+     dispatch_key="PrivateUse1",
+-)
++)
+\ No newline at end of file
+--
+2.43.0
+
+
+From d82f08b6981ca4d01432309b659bea57b9d18576 Mon Sep 17 00:00:00 2001
+From: ldeng <[email protected]>
+Date: Mon, 29 Dec 2025 18:04:02 +0800
+Subject: [PATCH 3/4] call initialize_kv_hash_cache_tensors_npu to allocate
+ hashk cache in NPUModelRunner when VLLM_HASH_ATTENTION is enabled for KVComp
+
+---
+ vllm_ascend/worker/model_runner_v1.py | 5 +++++
+ 1 file changed, 5 insertions(+)
+
+diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
+index 782b9a3..766316e 100644
+--- a/vllm_ascend/worker/model_runner_v1.py
++++ b/vllm_ascend/worker/model_runner_v1.py
+@@ -1993,6 +1993,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
+             # KV cache specs.
+             raise ValueError("Unknown KV cache spec type.")
+ 
++        if has_ucm_sparse():
++            ucm_sparse = get_ucm_sparse()
++            if os.getenv("VLLM_HASH_ATTENTION", "0") == "1":
++                ucm_sparse.initialize_kv_hash_cache_tensors_npu(kv_caches, self.device)
++
+         bind_kv_cache(
+             kv_caches,
+             self.vllm_config.compilation_config.static_forward_context,
+--
+2.43.0
+
+
+From 3d28d477a4887e7d5c909a66b448086994751567 Mon Sep 17 00:00:00 2001
+From: ldeng <[email protected]>
+Date: Mon, 29 Dec 2025 18:04:50 +0800
+Subject: [PATCH 4/4] uncomment connector
+
+---
+ vllm_ascend/attention/attention_v1.py | 4 ++--
+ 1 file changed, 2 insertions(+), 2 deletions(-)
+
+diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
+index b924d8e..ece7d17 100644
+--- a/vllm_ascend/attention/attention_v1.py
++++ b/vllm_ascend/attention/attention_v1.py
+@@ -468,7 +468,7 @@ def unified_ascend_attention_with_output(
+     output: torch.Tensor,
+     layer_name: str,
+ ) -> None:
+-    # wait_for_kv_layer_from_connector(layer_name)
++    wait_for_kv_layer_from_connector(layer_name)
+ 
+     forward_context: ForwardContext = get_forward_context()
+     attn_metadata = forward_context.attn_metadata
+@@ -493,7 +493,7 @@ def unified_ascend_attention_with_output(
+ 
+     if attn_metadata is not None:
+         maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
+-    # maybe_save_kv_layer_to_connector(layer_name, kv_cache)
++    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+     return
+ 
+ def wait_for_kv_layer_from_connector(layer_name: str):
+--
+2.43.0
 
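The behavioral core of the appended patches is that, with VLLM_HASH_ATTENTION=1, each layer's registered cache becomes a (kv_cache, k_hash) pair that must be unpacked before use, and that it may be empty during dummy runs. A standalone sketch of that unpacking pattern, for illustration only (this helper does not exist in the patch):

import os
from typing import Optional, Tuple

import torch

def split_kv_and_hash(kv_cache) -> Tuple[object, Optional[torch.Tensor]]:
    """Illustrative only. Mirrors the gating added to
    unified_ascend_attention_with_output: when VLLM_HASH_ATTENTION=1 the
    registered cache is a (kv_cache, k_hash) pair; during dummy runs it can
    be empty, in which case no hash tensor is available."""
    if os.getenv("VLLM_HASH_ATTENTION", "0") == "1" and len(kv_cache) > 0:
        kv_cache, k_hash = kv_cache
    else:
        k_hash = None
    return kv_cache, k_hash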

ucm/integration/vllm/ucm_connector.py

Lines changed: 1 addition & 1 deletion
@@ -229,7 +229,7 @@ def _generate_storage_backends(
         return backends
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
-        if os.getenv("VLLM_HASH_ATTENTION") == "1":
+        if os.getenv("VLLM_HASH_ATTENTION", "0") == "1":
             for layer_name, value in kv_caches.items():
                 kv_cache, k_hash = value
                 self.kv_caches[layer_name] = kv_cache

ucm/sparse/factory.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def create_sparse_method(
 UcmSparseFactory.register_sparse_method("ESA", "ucm.sparse.esa.esa", "ESA")
 UcmSparseFactory.register_sparse_method("KvComp", "ucm.sparse.kvcomp.kvcomp", "KvComp")
 UcmSparseFactory.register_sparse_method(
-    "GSA", "ucm.sparse.kvcomp.kvcomp_hbm", "KvCompOnDevice"
+    "KvCompOnDevice", "ucm.sparse.kvcomp.kvcomp_hbm", "KvCompOnDevice"
 )
 # UcmSparseFactory.register_sparse_method("GSA", "ucm.sparse.gsa.gsa", "GSA")
 UcmSparseFactory.register_sparse_method(
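With this change, the "KvCompOnDevice" key in ucm_sparse_config resolves to the KvCompOnDevice class in ucm.sparse.kvcomp.kvcomp_hbm instead of being aliased under "GSA". A rough sketch of what such a name-to-class registration ultimately resolves to; this is not the factory's actual code, only an illustration of dynamic import by module path and class name.

import importlib

# Illustration only: resolve the registered module path and class name,
# as a lazy factory typically would.
module = importlib.import_module("ucm.sparse.kvcomp.kvcomp_hbm")
KvCompOnDevice = getattr(module, "KvCompOnDevice")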

ucm/sparse/kvcomp/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
@@ -47,5 +47,9 @@ else()
     message(STATUS "Skipping numactl build...")
 endif()
 
-add_subdirectory(hash_retrieval)
-add_subdirectory(ham_dist)
+string(TOLOWER "$ENV{PLATFORM}" PLATFORM_ENV)
+if(PLATFORM_ENV STREQUAL "cuda")
+    message(STATUS "Building kvcomp for CUDA...")
+    add_subdirectory(hash_retrieval)
+    add_subdirectory(ham_dist)
+endif()

ucm/sparse/kvcomp/hamming_topk.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 import torch
 
-from ucm.sparse.kvcomp.ham_dist import hamming
+if hasattr(torch, "cuda") and torch.cuda.is_available():
+    from ucm.sparse.kvcomp.ham_dist import hamming
 
 
 @torch.compile()
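With the import now guarded, the hamming kernel from the ham_dist extension is only pulled in when CUDA is present, matching the CMake change above that skips building hash_retrieval and ham_dist on non-CUDA platforms. For reference, a pure-PyTorch Hamming distance over packed uint8 hash codes could look like the sketch below; this is illustrative only and is not the kernel shipped in ucm.sparse.kvcomp.ham_dist.

import torch

# Lookup table of per-byte popcounts (illustrative fallback, not the CUDA kernel).
_POPCOUNT = torch.tensor([bin(i).count("1") for i in range(256)], dtype=torch.int32)

def hamming_fallback(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """a: (n, d) uint8, b: (m, d) uint8 -> (n, m) integer Hamming distances."""
    xor = a.unsqueeze(1) ^ b.unsqueeze(0)                 # (n, m, d) byte-wise XOR
    return _POPCOUNT.to(xor.device)[xor.long()].sum(dim=-1)  # count differing bits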
