
Commit 627f499

Enhance tensor computation logic by adding a check for decode requests before processing in the KvCompOnDevice class
1 parent 60db08d · commit 627f499

1 file changed: +14 −13 lines


ucm/sparse/kvcomp/kvcomp_hbm.py

Lines changed: 14 additions & 13 deletions
@@ -270,19 +270,20 @@ def attention_begin(
             )
         else:  # NPU
             if not self.is_tensor_computed:
-                decode_req_ids = torch.nonzero(
-                    self.decode_mask, as_tuple=False
-                ).flatten()
-                decode_req_ids_npu = torch.nonzero(
-                    self.decode_mask_npu, as_tuple=False
-                ).flatten()
-                batch_size_for_hamming = self.decode_mask.sum().item()
-                self.query_lens_device = attn_metadata.query_lens_device[decode_req_ids_npu]
-                self.topk_for_hamming = self.topk_for_hamming_full[:batch_size_for_hamming]
-                self.chunk_sizes_for_hamming = self.chunk_sizes_for_hamming_full[:batch_size_for_hamming]
-                self.seq_lens_for_hamming = attn_metadata.seq_lens_device[decode_req_ids_npu]
-                self.max_seq_len_for_hamming = torch.max(attn_metadata.seq_lens[decode_req_ids]).item()
-                self.is_tensor_computed = True
+                if self.decode_mask.any():  # with at least one decode request
+                    decode_req_ids = torch.nonzero(
+                        self.decode_mask, as_tuple=False
+                    ).flatten()
+                    decode_req_ids_npu = torch.nonzero(
+                        self.decode_mask_npu, as_tuple=False
+                    ).flatten()
+                    batch_size_for_hamming = self.decode_mask.sum().item()
+                    self.query_lens_device = attn_metadata.query_lens_device[decode_req_ids_npu]
+                    self.topk_for_hamming = self.topk_for_hamming_full[:batch_size_for_hamming]
+                    self.chunk_sizes_for_hamming = self.chunk_sizes_for_hamming_full[:batch_size_for_hamming]
+                    self.seq_lens_for_hamming = attn_metadata.seq_lens_device[decode_req_ids_npu]
+                    self.max_seq_len_for_hamming = torch.max(attn_metadata.seq_lens[decode_req_ids]).item()
+                    self.is_tensor_computed = True
 
         k_hash_compute = self.hash_encoder.compute_hash(key)
         k_hash_compute = k_hash_compute.transpose(0,1).reshape(-1, k_hash_compute.shape[-1]).contiguous()
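
Why the guard matters: in an all-prefill batch, `self.decode_mask` is all False, so the `torch.nonzero(...)` calls return empty index tensors; the empty-index slices then yield empty tensors, and the `torch.max(...)` reduction over an empty selection raises a RuntimeError. Below is a minimal standalone sketch of that failure mode; the local `decode_mask` and `seq_lens` tensors are illustrative stand-ins, not the class's actual state.

import torch

# All-prefill batch: no decode requests, so the mask is all False.
decode_mask = torch.zeros(4, dtype=torch.bool)
seq_lens = torch.tensor([7, 12, 3, 9])

decode_req_ids = torch.nonzero(decode_mask, as_tuple=False).flatten()
assert decode_req_ids.numel() == 0  # empty index tensor

if decode_mask.any():  # the check this commit adds
    max_seq_len = torch.max(seq_lens[decode_req_ids]).item()
else:
    print("no decode requests; skip hamming tensor setup")

# Without the guard, torch.max(seq_lens[decode_req_ids]) raises a
# RuntimeError: reducing over zero elements needs an explicit dim.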
