@@ -270,19 +270,20 @@ def attention_begin(
             )
         else:  # NPU
             if not self.is_tensor_computed:
-                decode_req_ids = torch.nonzero(
-                    self.decode_mask, as_tuple=False
-                ).flatten()
-                decode_req_ids_npu = torch.nonzero(
-                    self.decode_mask_npu, as_tuple=False
-                ).flatten()
-                batch_size_for_hamming = self.decode_mask.sum().item()
-                self.query_lens_device = attn_metadata.query_lens_device[decode_req_ids_npu]
-                self.topk_for_hamming = self.topk_for_hamming_full[:batch_size_for_hamming]
-                self.chunk_sizes_for_hamming = self.chunk_sizes_for_hamming_full[:batch_size_for_hamming]
-                self.seq_lens_for_hamming = attn_metadata.seq_lens_device[decode_req_ids_npu]
-                self.max_seq_len_for_hamming = torch.max(attn_metadata.seq_lens[decode_req_ids]).item()
-                self.is_tensor_computed = True
+                if self.decode_mask.any():  # at least one decode request
+                    decode_req_ids = torch.nonzero(
+                        self.decode_mask, as_tuple=False
+                    ).flatten()
+                    decode_req_ids_npu = torch.nonzero(
+                        self.decode_mask_npu, as_tuple=False
+                    ).flatten()
+                    batch_size_for_hamming = self.decode_mask.sum().item()
+                    self.query_lens_device = attn_metadata.query_lens_device[decode_req_ids_npu]
+                    self.topk_for_hamming = self.topk_for_hamming_full[:batch_size_for_hamming]
+                    self.chunk_sizes_for_hamming = self.chunk_sizes_for_hamming_full[:batch_size_for_hamming]
+                    self.seq_lens_for_hamming = attn_metadata.seq_lens_device[decode_req_ids_npu]
+                    self.max_seq_len_for_hamming = torch.max(attn_metadata.seq_lens[decode_req_ids]).item()
+                    self.is_tensor_computed = True
 
         k_hash_compute = self.hash_encoder.compute_hash(key)
         k_hash_compute = k_hash_compute.transpose(0, 1).reshape(-1, k_hash_compute.shape[-1]).contiguous()
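For context, the added `if self.decode_mask.any():` guard skips the index computation entirely when the batch contains no decode-phase requests; on an all-False mask, `torch.nonzero` returns empty index tensors and the `[:batch_size_for_hamming]` slices become zero-sized. Below is a minimal standalone sketch of the masking pattern used in the guarded block, with made-up example tensors standing in for `attn_metadata` and the `self` attributes:

```python
import torch

# Assumed example values, not the real attn_metadata fields.
decode_mask = torch.tensor([True, False, True, True])  # which requests are decoding
seq_lens = torch.tensor([128, 7, 512, 64])             # per-request sequence lengths

if decode_mask.any():  # skip the gather when no request is in the decode phase
    # Indices of decode-phase requests, as a 1-D LongTensor.
    decode_req_ids = torch.nonzero(decode_mask, as_tuple=False).flatten()
    # Number of decode requests; in the diff this slices the preallocated *_full buffers.
    batch_size_for_hamming = decode_mask.sum().item()
    # Gather metadata rows for the decode requests only.
    seq_lens_for_hamming = seq_lens[decode_req_ids]
    max_seq_len_for_hamming = torch.max(seq_lens[decode_req_ids]).item()

    print(decode_req_ids)           # tensor([0, 2, 3])
    print(batch_size_for_hamming)   # 3
    print(seq_lens_for_hamming)     # tensor([128, 512,  64])
    print(max_seq_len_for_hamming)  # 512
```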