diff --git a/tensorrt_llm/_torch/attention_backend/sparse/rocket.py b/tensorrt_llm/_torch/attention_backend/sparse/rocket.py
index e352dc81094..b06f5bb9ae0 100644
--- a/tensorrt_llm/_torch/attention_backend/sparse/rocket.py
+++ b/tensorrt_llm/_torch/attention_backend/sparse/rocket.py
@@ -125,30 +125,104 @@ def __init__(
         self.kernel_size = sparse_attention_config.kernel_size
         self.page_size = sparse_attention_config.page_size
 
-    def sparse_attention_predict(
-            self, q: torch.Tensor, k: torch.Tensor,
-            metadata: RocketTrtllmAttentionMetadata
+    def sparse_attn_predict(
+        self,
+        q: torch.Tensor,
+        k: Optional[torch.Tensor],
+        metadata: TrtllmAttentionMetadata,
+        **kwargs,
     ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         """
-        Predict sparse KV indices and sparse attention indices for the input sequence.
+        Predict sparse attention indices.
+        For RocketKV:
+        - Generation phase: predict RocketKV sparse attention indices
+
+        Returns:
+        - sparse_attn_indices: [total_selected_indices, num_kv_heads]
+        - sparse_attn_offsets: [batch_size + 1] with cumulative indices count
+        """
+        if k is None:
+            q, k, _ = q.split([
+                self.num_heads * self.head_dim, self.num_kv_heads *
+                self.head_dim, self.num_kv_heads * self.head_dim
+            ],
+                              dim=-1)
+
+        num_contexts = metadata.num_contexts
+        num_generations = metadata.num_generations
+        seq_lens = metadata.seq_lens
+        seq_lens_kv = metadata.seq_lens_kv if metadata.seq_lens_kv is not None else seq_lens
+        past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq
+
+        sparse_attn_indices = []
+        sparse_attn_offsets = [0]
+
+        q_offset = 0
+        k_offset = 0
+
+        for i in range(num_contexts + num_generations):
+            seq_len = seq_lens[i].item()
+            seq_len_kv = seq_lens_kv[i].item()
+
+            if seq_len <= 0 or seq_len_kv <= 0:
+                assert False, "Invalid sequence length"
+
+            single_q = q[q_offset:q_offset + seq_len]
+            single_k = k[k_offset:k_offset + seq_len_kv]
+
+            single_q = single_q.view(1, seq_len, self.num_heads,
+                                     self.head_dim).transpose(1, 2)
+            single_k = single_k.view(1, seq_len_kv, self.num_kv_heads,
+                                     self.head_dim)
+
+            past_seen_token = past_seen_tokens[i]
+            # Generation phase: RocketKV sparse attention indices
+            if i >= num_contexts:
+                _sparse_attn_indices = self._rocketkv_selection(
+                    single_q, single_k, past_seen_token, metadata, i)
+                if _sparse_attn_indices is not None:
+                    sparse_attn_indices.append(
+                        _sparse_attn_indices.squeeze(0))  # [topk, num_kv_heads]
+                    sparse_attn_offsets.append(sparse_attn_offsets[-1] +
+                                               _sparse_attn_indices.size(1))
+                else:
+                    sparse_attn_offsets.append(sparse_attn_offsets[-1])
+
+            q_offset += seq_len
+            k_offset += seq_len_kv
+
+        if len(sparse_attn_indices) == 0:
+            sparse_attn_indices, sparse_attn_offsets = None, None
+        else:
+            sparse_attn_indices = torch.cat(sparse_attn_indices,
+                                            dim=0).to(torch.int32)
+            sparse_attn_offsets = torch.tensor(sparse_attn_offsets,
+                                               dtype=torch.int32).to(q.device)
+        return sparse_attn_indices, sparse_attn_offsets
+
+    def sparse_kv_predict(
+        self,
+        q: torch.Tensor,
+        k: Optional[torch.Tensor],
+        metadata: TrtllmAttentionMetadata,
+        **kwargs,
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """
+        Predict sparse KV indices.
         For RocketKV:
         - Context phase: predict SnapKV sparse kv indices
-        - Generation phase: predict RocketKV sparse attention indices
 
         Returns:
         - Tuple of (flattened_indices, batch_offsets)
         - flattened_indices: [total_selected_indices, num_kv_heads]
         - batch_offsets: [batch_size + 1] with cumulative indices count
         """
-        q, k, _ = q.split([
-            self.num_heads * self.head_dim, self.num_kv_heads * self.head_dim,
-            self.num_kv_heads * self.head_dim
-        ],
-                          dim=-1)
-
-        if k is None or metadata is None:
-            return None, None
+        if k is None:
+            q, k, _ = q.split([
+                self.num_heads * self.head_dim, self.num_kv_heads *
+                self.head_dim, self.num_kv_heads * self.head_dim
+            ],
+                              dim=-1)
 
         num_contexts = metadata.num_contexts
         num_generations = metadata.num_generations
@@ -157,9 +231,7 @@ def sparse_attention_predict(
         past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq
 
         sparse_kv_indices = []
-        sparse_attn_indices = []
         sparse_kv_offsets = [0]
-        sparse_attn_offsets = [0]
 
         q_offset = 0
         k_offset = 0
@@ -191,17 +263,6 @@ def sparse_attention_predict(
                             _sparse_kv_indices.size(1))
                 else:
                     sparse_kv_offsets.append(sparse_kv_offsets[-1])
-            else:
-                # Generation phase: RocketKV sparse attention indices
-                _sparse_attn_indices = self._rocketkv_selection(
-                    single_q, single_k, past_seen_token, metadata, i)
-                if _sparse_attn_indices is not None:
-                    sparse_attn_indices.append(
-                        _sparse_attn_indices.squeeze(0))  # [topk, num_kv_heads]
-                    sparse_attn_offsets.append(sparse_attn_offsets[-1] +
-                                               _sparse_attn_indices.size(1))
-                else:
-                    sparse_attn_offsets.append(sparse_attn_offsets[-1])
 
             q_offset += seq_len
             k_offset += seq_len_kv
@@ -211,17 +272,10 @@ def sparse_attention_predict(
         else:
             sparse_kv_indices = torch.cat(sparse_kv_indices,
                                           dim=0).to(torch.int32)
+            sparse_kv_indices = sparse_kv_indices.transpose(0, 1).contiguous()
             sparse_kv_offsets = torch.tensor(sparse_kv_offsets,
                                              dtype=torch.int32).to(q.device)
-        if len(sparse_attn_indices) == 0:
-            sparse_attn_indices, sparse_attn_offsets = None, None
-        else:
-            sparse_attn_indices = torch.cat(sparse_attn_indices,
-                                            dim=0).to(torch.int32)
-            sparse_attn_offsets = torch.tensor(sparse_attn_offsets,
-                                               dtype=torch.int32).to(q.device)
-
-        return sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets
+        return sparse_kv_indices, sparse_kv_offsets
 
     def _get_snapkv_indices(self, q: Tensor, k: Tensor, past_seen_token: int,
                             metadata: RocketTrtllmAttentionMetadata,
@@ -802,7 +856,9 @@ def get_kt_block_offsets(self, request_ids: List[int]) -> torch.Tensor:
             block_ids = self.paged_kt_block_ids[request_ids[i]]
             block_num = len(block_ids)
             kt_block_offsets[i, 0:block_num].copy_(
-                self.base_kt_block_offsets[block_ids])
+                self.base_kt_block_offsets[torch.tensor(block_ids,
+                                                        dtype=torch.int32,
+                                                        device="cpu")])
         return kt_block_offsets
 
     def prepare_resources(self, scheduled_batch):
diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py
index 30055b82df7..56bb148c6bd 100644
--- a/tensorrt_llm/_torch/attention_backend/trtllm.py
+++ b/tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -1342,11 +1342,10 @@ def forward(
         sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets = None, None, None, None
         if self.sparse_attention_config is not None:
-            sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets = self.sparse_attention_predict(
+            sparse_kv_indices, sparse_kv_offsets = self.sparse_kv_predict(
+                q, k, metadata)
+            sparse_attn_indices, sparse_attn_offsets = self.sparse_attn_predict(
                 q, k, metadata)
-            if sparse_kv_indices is not None:
-                sparse_kv_indices = sparse_kv_indices.transpose(0,
-                                                                1).contiguous()
             if sparse_attn_indices is not None:
                 sparse_attn_indices, sparse_attn_offsets = convert_token_to_page_sparse_indices(
                     sparse_attn_indices, sparse_attn_offsets, metadata)
@@ -1618,10 +1617,26 @@ def merge_attention_for_mla(
             self.mla_params.v_head_dim,
         )
 
-    def sparse_attention_predict(
-            self, q: torch.Tensor, k: torch.Tensor,
-            metadata: TrtllmAttentionMetadata
+    def sparse_attn_predict(
+        self,
+        q: torch.Tensor,
+        k: Optional[torch.Tensor],
+        metadata: TrtllmAttentionMetadata,
+        **kwargs,
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """
+        Predict sparse attention indices. Implemented by the derived class.
+        """
+        raise NotImplementedError
+
+    def sparse_kv_predict(
+        self,
+        q: torch.Tensor,
+        k: Optional[torch.Tensor],
+        metadata: TrtllmAttentionMetadata,
+        **kwargs,
     ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         """
-        Predict sparse kv indices and sparse attn indices for the input sequence. It's implemented in the derived class.
+        Predict sparse KV indices. Implemented by the derived class.
         """
+        raise NotImplementedError
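A few illustrative sketches follow; none are part of the patch. First, both new predictors accept `k=None` and recover `q`/`k` from the fused QKV tensor themselves, which is what lets `forward` pass `q, k` through unchanged. A minimal sketch of that split, with toy head counts standing in for the real config values:

```python
import torch

# Toy sizes; the real values come from the attention config.
num_heads, num_kv_heads, head_dim, num_tokens = 4, 2, 8, 5

# Fused QKV layout assumed by the `k is None` branch:
# [num_tokens, (num_heads + 2 * num_kv_heads) * head_dim]
qkv = torch.randn(num_tokens, (num_heads + 2 * num_kv_heads) * head_dim)

q, k, _v = qkv.split([
    num_heads * head_dim,
    num_kv_heads * head_dim,
    num_kv_heads * head_dim,
], dim=-1)

assert q.shape == (num_tokens, num_heads * head_dim)
assert k.shape == (num_tokens, num_kv_heads * head_dim)
```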
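Second, the return contract shared by both predictors is a flattened index tensor plus cumulative per-sequence offsets. A standalone sketch with made-up top-k values, where `per_seq_indices` stands in for the per-sequence output of `_rocketkv_selection` after `squeeze(0)`:

```python
import torch

num_kv_heads = 2
# Hypothetical per-sequence selections: topk=3 and topk=2 token indices.
per_seq_indices = [
    torch.randint(0, 16, (3, num_kv_heads)),
    torch.randint(0, 16, (2, num_kv_heads)),
]

sparse_attn_indices = []
sparse_attn_offsets = [0]
for idx in per_seq_indices:
    sparse_attn_indices.append(idx)
    # Cumulative count: offsets[i]:offsets[i+1] slices out sequence i.
    sparse_attn_offsets.append(sparse_attn_offsets[-1] + idx.size(0))

sparse_attn_indices = torch.cat(sparse_attn_indices, dim=0).to(torch.int32)
sparse_attn_offsets = torch.tensor(sparse_attn_offsets, dtype=torch.int32)

assert sparse_attn_indices.shape == (5, num_kv_heads)  # total_selected_indices
assert sparse_attn_offsets.tolist() == [0, 3, 5]       # [batch_size + 1]
seq1 = sparse_attn_indices[sparse_attn_offsets[1]:sparse_attn_offsets[2]]
```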
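Finally, on the `get_kt_block_offsets` change: indexing `base_kt_block_offsets` with a Python list and with an explicit index tensor gather the same rows, so wrapping `block_ids` in a CPU int32 tensor up front pins the index dtype and device rather than leaving the conversion implicit. Toy data below; the motivation is inferred from the diff, not stated in it:

```python
import torch

base_kt_block_offsets = torch.arange(10, dtype=torch.int32) * 100
block_ids = [3, 0, 7]  # per-request KT block ids (toy values)

out_list = base_kt_block_offsets[block_ids]
out_tensor = base_kt_block_offsets[torch.tensor(block_ids,
                                                dtype=torch.int32,
                                                device="cpu")]
assert torch.equal(out_list, out_tensor)  # identical gather
```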