
Commit 7fafa2a

fix the shape bug of hamming output
1 parent c765625 commit 7fafa2a

File tree

1 file changed: +3 −3 lines

ucm/sparse/kvcomp/kvcomp_hbm.py

Lines changed: 3 additions & 3 deletions
@@ -224,7 +224,7 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
             [self.max_batch_size], dtype=torch.int32, device=self.device
         )
         self.hamming_output = torch.zeros(
-            [self.max_batch_size, self.hash_topk_tokens // self.block_size],
+            [self.max_batch_size, self.num_key_heads, self.hash_topk_tokens // self.block_size],
             dtype=torch.int32,
             device=self.device,
         )
@@ -495,9 +495,9 @@ def attention_begin(
             block_table_decode,
             self.hamming_output[: len(decode_req_ids)],
         )
-        topk = self.hamming_output.shape[1]
+        topk = self.hamming_output.shape[-1]
         attn_metadata.block_table[decode_req_ids, :topk] = (
-            self.hamming_output[: len(decode_req_ids)]
+            self.hamming_output[: len(decode_req_ids), 0, :]
         )
         attn_metadata.block_table[decode_req_ids, topk:] = 0
