Commit ea76b36
[feat] KvCompOnDevice: per-KV-head Top-K for Qwen (#589)
Parent: 1e1c3e6

File tree

4 files changed: +43, -18 lines

ucm/sparse/kvcomp/ham_dist/paged_ham_dist_mla.cu
ucm/sparse/kvcomp/hamming_topk.py
ucm/sparse/kvcomp/kvcomp_hbm.py
ucm/sparse/state.py

ucm/sparse/kvcomp/ham_dist/paged_ham_dist_mla.cu

Lines changed: 17 additions & 2 deletions
@@ -59,6 +59,21 @@
         { \
             __VA_ARGS__ \
         } \
+    } else if ((val) == 2) { \
+        constexpr int NumKVHead = 2; \
+        { \
+            __VA_ARGS__ \
+        } \
+    } else if ((val) == 4) { \
+        constexpr int NumKVHead = 4; \
+        { \
+            __VA_ARGS__ \
+        } \
+    } else if ((val) == 8) { \
+        constexpr int NumKVHead = 8; \
+        { \
+            __VA_ARGS__ \
+        } \
     } else { \
         LOG(FATAL) << "NumKVHead is not support"; \
     } \
@@ -295,7 +310,7 @@ torch::Tensor HammingScoreContiCUDA(torch::Tensor& key_codes,
     bool is_block_mode = block_table_opt.has_value();
 
     int32_t bsz = query_code.size(0);
-    int32_t num_kv_head = is_block_mode ? key_codes.size(1) : key_codes.size(2);
+    int32_t num_kv_head = key_codes.size(2);
     int32_t num_chunk = key_codes.size(3);
 
     int32_t num_head = query_code.size(2);
@@ -309,7 +324,7 @@ torch::Tensor HammingScoreContiCUDA(torch::Tensor& key_codes,
 
     if(is_block_mode) {
         int32_t num_blocks = key_codes.size(0);
-        int32_t block_size = key_codes.size(2);
+        int32_t block_size = key_codes.size(1);
        const auto& block_table = block_table_opt.value(); // *block_table_opt;
        int32_t max_num_block_per_seq = block_table.size(1);
        TORCH_CHECK(bsz == block_table.size(0), "batch size mismatch between query_code and block_table");
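
Both indexing fixes assume that in block mode key_codes now uses the same layout as the non-block path, with the KV-head axis at dim 2. A minimal PyTorch sketch of that assumed layout (all sizes hypothetical; the real tensor holds packed hash codes):

import torch

# Hypothetical sizes, for illustration only.
num_blocks, block_size, num_kv_head, num_chunk = 128, 64, 8, 18
key_codes = torch.zeros(num_blocks, block_size, num_kv_head, num_chunk, dtype=torch.int32)

# Mirrors the corrected block-mode indexing in HammingScoreContiCUDA:
assert key_codes.size(1) == block_size   # previously read from size(2) in block mode
assert key_codes.size(2) == num_kv_head  # previously read from size(1) in block mode
assert key_codes.size(3) == num_chunk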

ucm/sparse/kvcomp/hamming_topk.py

Lines changed: 19 additions & 13 deletions
@@ -20,16 +20,16 @@ def cuda_hamming_topk(
     topk_token,
     sink_token,
     recent_token,
+    is_mla,
 ):
     q_hash = q_hash.view(torch.int32)
     k_hash = k_hash.view(torch.int32)
-    assert k_hash.shape[1] == 1
+    # assert k_hash.shape[1] == 1
     # assert k_hash.shape[-1] == 18 and q_hash.shape[-1] == 18
-    block_size = k_hash.shape[2]
+    block_size = k_hash.shape[1]
     assert topk_token % block_size == 0
     assert recent_token > 0 and topk_token > (sink_token + recent_token)
     max_seqlen = block_size * block_table.shape[1]
-
     output = hamming.hamming_score(
         k_hash,
         q_hash,
@@ -40,17 +40,23 @@
         recent_token,
     )
 
-    block_output = torch.min(
-        output.view(output.shape[0], output.shape[-1] // block_size, block_size), dim=-1
-    )[0]
+    k_blocks = topk_token // block_size
+    B, Hk, S = output.shape
+    num_blocks = S // block_size
 
-    ind = torch.topk(block_output, k=(topk_token // block_size), dim=-1, largest=False)[
-        1
-    ]
-    ind = torch.sort(ind, dim=-1, descending=False)[0]
+    # block_output: [B, Hk, num_blocks]
+    block_output = output.view(B, Hk, num_blocks, block_size).amin(dim=-1)
 
-    new_block_table = torch.gather(block_table, dim=-1, index=ind)
-    return new_block_table
+    if is_mla:
+        block_score = block_output[:, 0, :]
+        ind = torch.topk(block_score, k=k_blocks, dim=-1, largest=False).indices
+        ind = ind.sort(dim=-1).values
+        return torch.gather(block_table, dim=-1, index=ind)
+
+    block_score = block_output.amin(dim=1)  # [B, num_blocks]
+    ind = torch.topk(block_score, k=k_blocks, dim=-1, largest=False).indices
+    ind = ind.sort(dim=-1).values
+    return torch.gather(block_table, dim=-1, index=ind)
 
 
 def fake_hamming_topk(
@@ -66,7 +72,7 @@ def fake_hamming_topk(
     k_hash = k_hash.view(torch.int32)
     assert k_hash.shape[1] == 1
     assert k_hash.shape[-1] == 18 and q_hash.shape[-1] == 18
-    block_size = k_hash.shape[2]
+    block_size = k_hash.shape[1]
     assert topk_token % block_size == 0
     assert recent_token > 0 and topk_token > (sink_token + recent_token)
     max_seqlen = block_size * block_table.shape[1]
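
The rewritten selection first reduces token scores to block scores, then branches: the MLA path keeps the single head's scores, while the new GQA path takes the best (lowest) score across KV heads so that all heads share one top-k block table. A toy run of the GQA reduction with made-up scores (lower Hamming distance = more relevant):

import torch

# Toy scores: batch 1, 2 KV heads, 4 blocks of 2 tokens each.
B, Hk, num_blocks, block_size = 1, 2, 4, 2
output = torch.tensor([[[5, 7, 1, 9, 8, 8, 2, 6],
                        [4, 4, 9, 9, 0, 3, 7, 7]]], dtype=torch.float32)

# Per-block score = best (minimum) token score inside the block.
block_output = output.view(B, Hk, num_blocks, block_size).amin(dim=-1)  # [1, 2, 4]

# GQA path: a block survives if *any* KV head matches it well, so take the
# minimum across heads before the single shared top-k over blocks.
block_score = block_output.amin(dim=1)  # [1, 4] -> [[4., 1., 0., 2.]]
ind = torch.topk(block_score, k=2, dim=-1, largest=False).indices
ind = ind.sort(dim=-1).values           # [[1, 2]]: blocks 1 and 2 are kept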

ucm/sparse/kvcomp/kvcomp_hbm.py

Lines changed: 6 additions & 2 deletions
@@ -34,6 +34,8 @@ def kvcomp_config_path_for_model(vllm_config) -> str:
         rel = "ucm/sparse/kvcomp/configs/kvcomp_deepseek_r1_awq_config.json"
     elif "qwen3" in model and "32b" in model:
         rel = "ucm/sparse/kvcomp/configs/kvcomp_qwen3_32B_config.json"
+    elif "deepseek" in model and "v2" in model:
+        rel = "ucm/sparse/kvcomp/configs/kvcomp_deepseek_v2_lite_config.json"
     else:
         raise ValueError(f"[KvCompOnDevice] Unsupported model for kvcomp: {model}")
 
@@ -270,12 +272,13 @@ def attention_begin(
         topk_token = self.hash_topk_tokens
         block_table = cuda_hamming_topk(
             q_hash.unsqueeze(1),
-            k_hash.unsqueeze(1),
+            k_hash.unsqueeze(2),
             attn_metadata.decode.block_table,
             attn_metadata.decode.seq_lens,
             topk_token=topk_token,
             sink_token=64,
             recent_token=512,
+            is_mla=self.is_mla,
         )
         attn_metadata.decode.topk_block_table = block_table
 
@@ -324,12 +327,13 @@
         )
         block_table_decode = cuda_hamming_topk(
             q_hash.unsqueeze(1),
-            k_hash.unsqueeze(1),
+            k_hash,
             block_table_decode,
             seq_len_decode,
             topk_token=topk_token,
             sink_token=64,
             recent_token=512,
+            is_mla=self.is_mla,
        )
        # update topk_block_table
        topk = block_table_decode.shape[1]
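
The two call-site changes track the new hash layout: where k_hash apparently has no KV-head axis, a singleton one is inserted at dim 2; where it already carries one (the per-KV-head Qwen path), it is passed through unchanged. A shape sketch under those assumptions (sizes illustrative; the chunk count 18 echoes the asserts in hamming_topk.py):

import torch

num_blocks, block_size, num_chunk = 16, 64, 18  # illustrative sizes

# MLA: one latent KV head, so insert a singleton head axis at dim 2 to get the
# [num_blocks, block_size, num_kv_head, num_chunk] layout the kernel expects.
k_hash_mla = torch.zeros(num_blocks, block_size, num_chunk, dtype=torch.int32)
assert k_hash_mla.unsqueeze(2).shape == (num_blocks, block_size, 1, num_chunk)

# GQA (e.g. Qwen3): k_hash already carries a real KV-head axis, so it is
# passed through unchanged and every head is scored in the kernel.
num_kv_head = 8
k_hash_gqa = torch.zeros(num_blocks, block_size, num_kv_head, num_chunk, dtype=torch.int32)
assert k_hash_gqa.size(2) == num_kv_head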

ucm/sparse/state.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def ensure_ucm_sparse_initialized(
 
     # Check if UCM sparse is enabled
     ucm_config = Config(vllm_config.kv_transfer_config)
-    ucm_sparse_config = ucm_config.get_config().get("ucm_sparse_method")
+    ucm_sparse_config = ucm_config.get_config().get("ucm_sparse_config")
     if not ucm_sparse_config:
         return
 
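The guard was reading a key that presumably is never written, so it would return early and leave sparse attention off. A minimal sketch of the corrected lookup, with a hypothetical dict standing in for the parsed kv_transfer_config:

# Hypothetical config contents, for illustration only.
config = {"ucm_sparse_config": {"method": "KvCompOnDevice"}}

ucm_sparse_config = config.get("ucm_sparse_config")  # old key: "ucm_sparse_method"
if not ucm_sparse_config:
    ...  # ensure_ucm_sparse_initialized returns early; sparse stays disabled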
