130 changes: 93 additions & 37 deletions tensorrt_llm/_torch/attention_backend/sparse/rocket.py
@@ -125,30 +125,104 @@ def __init__(
self.kernel_size = sparse_attention_config.kernel_size
self.page_size = sparse_attention_config.page_size

def sparse_attention_predict(
self, q: torch.Tensor, k: torch.Tensor,
metadata: RocketTrtllmAttentionMetadata
def sparse_attn_predict(
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
metadata: TrtllmAttentionMetadata,
**kwargs,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Predict sparse KV indices and sparse attention indices for the input sequence.
Predict sparse attention indices.
For RocketKV:
- Generation phase: predict RocketKV sparse attention indices

Returns:
- sparse_attn_indices: [total_selected_indices, num_kv_heads]
- sparse_attn_offsets: [batch_size + 1] with cumulative indices count
"""
if k is None:
q, k, _ = q.split([
self.num_heads * self.head_dim, self.num_kv_heads *
self.head_dim, self.num_kv_heads * self.head_dim
],
dim=-1)

num_contexts = metadata.num_contexts
num_generations = metadata.num_generations
seq_lens = metadata.seq_lens
seq_lens_kv = metadata.seq_lens_kv if metadata.seq_lens_kv is not None else seq_lens
past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq

sparse_attn_indices = []
sparse_attn_offsets = [0]

q_offset = 0
k_offset = 0

for i in range(num_contexts + num_generations):
seq_len = seq_lens[i].item()
seq_len_kv = seq_lens_kv[i].item()

if seq_len <= 0 or seq_len_kv <= 0:
    raise ValueError(
        f"Invalid sequence length: seq_len={seq_len}, seq_len_kv={seq_len_kv}")

single_q = q[q_offset:q_offset + seq_len]
single_k = k[k_offset:k_offset + seq_len_kv]

single_q = single_q.view(1, seq_len, self.num_heads,
self.head_dim).transpose(1, 2)
single_k = single_k.view(1, seq_len_kv, self.num_kv_heads,
self.head_dim)

past_seen_token = past_seen_tokens[i]
# Generation phase: RocketKV sparse attention indices
if i >= num_contexts:
_sparse_attn_indices = self._rocketkv_selection(
single_q, single_k, past_seen_token, metadata, i)
if _sparse_attn_indices is not None:
sparse_attn_indices.append(
_sparse_attn_indices.squeeze(0)) # [topk, num_kv_heads]
sparse_attn_offsets.append(sparse_attn_offsets[-1] +
_sparse_attn_indices.size(1))
else:
sparse_attn_offsets.append(sparse_attn_offsets[-1])

q_offset += seq_len
k_offset += seq_len_kv

if len(sparse_attn_indices) == 0:
sparse_attn_indices, sparse_attn_offsets = None, None
else:
sparse_attn_indices = torch.cat(sparse_attn_indices,
dim=0).to(torch.int32)
sparse_attn_offsets = torch.tensor(sparse_attn_offsets,
dtype=torch.int32).to(q.device)
return sparse_attn_indices, sparse_attn_offsets
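
Note: the flattened (indices, offsets) layout documented in the docstring above can be decoded back into per-request selections with a prefix-sum slice. A minimal sketch, not part of this PR; the tensors below are made-up stand-ins:

    import torch

    def split_by_offsets(indices, offsets):
        # indices: [total_selected_indices, num_kv_heads]
        # offsets: [batch_size + 1] cumulative counts, offsets[0] == 0
        return [indices[offsets[i]:offsets[i + 1]]
                for i in range(offsets.numel() - 1)]

    # Two requests: 3 indices selected for the first, 2 for the second.
    indices = torch.arange(10, dtype=torch.int32).reshape(5, 2)
    offsets = torch.tensor([0, 3, 5], dtype=torch.int32)
    assert [t.shape[0] for t in split_by_offsets(indices, offsets)] == [3, 2]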

def sparse_kv_predict(
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
metadata: TrtllmAttentionMetadata,
**kwargs,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Predict sparse KV indices.

For RocketKV:
- Context phase: predict SnapKV sparse KV indices
  (generation-phase index selection is handled by sparse_attn_predict)

Returns:
Tuple of (flattened_indices, batch_offsets)
- flattened_indices: [total_selected_indices, num_kv_heads]
- batch_offsets: [batch_size + 1] with cumulative indices count
"""
q, k, _ = q.split([
self.num_heads * self.head_dim, self.num_kv_heads * self.head_dim,
self.num_kv_heads * self.head_dim
],
dim=-1)

if metadata is None:
    return None, None
if k is None:
    q, k, _ = q.split([
        self.num_heads * self.head_dim, self.num_kv_heads *
        self.head_dim, self.num_kv_heads * self.head_dim
    ],
        dim=-1)

num_contexts = metadata.num_contexts
num_generations = metadata.num_generations
@@ -157,9 +231,7 @@ def sparse_attention_predict(
past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq

sparse_kv_indices = []
sparse_attn_indices = []
sparse_kv_offsets = [0]
sparse_attn_offsets = [0]

q_offset = 0
k_offset = 0
@@ -191,17 +263,6 @@
_sparse_kv_indices.size(1))
else:
sparse_kv_offsets.append(sparse_kv_offsets[-1])
else:
# Generation phase: RocketKV sparse attention indices
_sparse_attn_indices = self._rocketkv_selection(
single_q, single_k, past_seen_token, metadata, i)
if _sparse_attn_indices is not None:
sparse_attn_indices.append(
_sparse_attn_indices.squeeze(0)) # [topk, num_kv_heads]
sparse_attn_offsets.append(sparse_attn_offsets[-1] +
_sparse_attn_indices.size(1))
else:
sparse_attn_offsets.append(sparse_attn_offsets[-1])

q_offset += seq_len
k_offset += seq_len_kv
@@ -211,17 +272,10 @@
else:
sparse_kv_indices = torch.cat(sparse_kv_indices,
dim=0).to(torch.int32)
sparse_kv_indices = sparse_kv_indices.transpose(0, 1).contiguous()
sparse_kv_offsets = torch.tensor(sparse_kv_offsets,
dtype=torch.int32).to(q.device)
if len(sparse_attn_indices) == 0:
sparse_attn_indices, sparse_attn_offsets = None, None
else:
sparse_attn_indices = torch.cat(sparse_attn_indices,
dim=0).to(torch.int32)
sparse_attn_offsets = torch.tensor(sparse_attn_offsets,
dtype=torch.int32).to(q.device)

return sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets
return sparse_kv_indices, sparse_kv_offsets
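
Note: both predictors accept a fused QKV tensor when k is None and recover q/k via the same three-way split. A small stand-alone illustration of the split sizes used above, with hypothetical dimensions (not from the PR):

    import torch

    num_heads, num_kv_heads, head_dim, tokens = 8, 2, 64, 5
    fused = torch.randn(tokens, (num_heads + 2 * num_kv_heads) * head_dim)
    q, k, v = fused.split([num_heads * head_dim,
                           num_kv_heads * head_dim,
                           num_kv_heads * head_dim], dim=-1)
    assert q.shape == (tokens, num_heads * head_dim)
    assert k.shape == (tokens, num_kv_heads * head_dim)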

def _get_snapkv_indices(self, q: Tensor, k: Tensor, past_seen_token: int,
metadata: RocketTrtllmAttentionMetadata,
@@ -802,7 +856,9 @@ def get_kt_block_offsets(self, request_ids: List[int]) -> torch.Tensor:
block_ids = self.paged_kt_block_ids[request_ids[i]]
block_num = len(block_ids)
kt_block_offsets[i, 0:block_num].copy_(
self.base_kt_block_offsets[block_ids])
self.base_kt_block_offsets[torch.tensor(block_ids,
dtype=torch.int32,
device="cpu")])
return kt_block_offsets
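
Note: indexing base_kt_block_offsets with an explicit torch.tensor rather than a raw Python list makes the index dtype and device deliberate instead of letting PyTorch infer them on every call; the gather result is unchanged. A quick stand-alone check with stand-in tensors (not the PR's):

    import torch

    base = torch.arange(16, dtype=torch.int32)  # stand-in for base_kt_block_offsets
    block_ids = [3, 0, 7]
    idx = torch.tensor(block_ids, dtype=torch.int32, device="cpu")
    assert torch.equal(base[block_ids], base[idx])  # same gather, explicit index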

def prepare_resources(self, scheduled_batch):
31 changes: 23 additions & 8 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -1342,11 +1342,10 @@ def forward(

sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets = None, None, None, None
if self.sparse_attention_config is not None:
sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets = self.sparse_attention_predict(
sparse_kv_indices, sparse_kv_offsets = self.sparse_kv_predict(
q, k, metadata)
sparse_attn_indices, sparse_attn_offsets = self.sparse_attn_predict(
q, k, metadata)
if sparse_kv_indices is not None:
sparse_kv_indices = sparse_kv_indices.transpose(0,
1).contiguous()
if sparse_attn_indices is not None:
sparse_attn_indices, sparse_attn_offsets = convert_token_to_page_sparse_indices(
sparse_attn_indices, sparse_attn_offsets, metadata)
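
Note: with the split API, forward() now owns the layout normalization that rocket.py previously applied itself (the transpose deleted above). A sketch of its effect, with hypothetical sizes:

    import torch

    total_selected, num_kv_heads = 6, 2
    sparse_kv_indices = torch.randint(0, 100, (total_selected, num_kv_heads),
                                      dtype=torch.int32)
    # Predictors return [total_selected_indices, num_kv_heads]; downstream
    # code receives the transposed [num_kv_heads, total_selected] layout.
    kernel_layout = sparse_kv_indices.transpose(0, 1).contiguous()
    assert kernel_layout.shape == (num_kv_heads, total_selected)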
@@ -1618,10 +1617,26 @@ def merge_attention_for_mla(
self.mla_params.v_head_dim,
)

def sparse_attention_predict(
self, q: torch.Tensor, k: torch.Tensor,
metadata: TrtllmAttentionMetadata
def sparse_attn_predict(
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
metadata: TrtllmAttentionMetadata,
**kwargs,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Predict sparse attn indices. It's implemented in the derived class.
"""
raise NotImplementedError

def sparse_kv_predict(
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
metadata: TrtllmAttentionMetadata,
**kwargs,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Predict sparse kv indices and sparse attn indices for the input sequence. It's implemented in the derived class.
Predict sparse KV indices. Implemented by derived classes.
"""
raise NotImplementedError
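
Note: a minimal, self-contained sketch of the contract these two stubs define. The base class below is a stand-in for the real attention backend class, and _DenseBackend is hypothetical; a backend with no sparsity returns (None, None) from both hooks, so forward() above leaves the sparse index arguments unset:

    class _SparsePredictorBase:  # stand-in for the real backend base class
        def sparse_attn_predict(self, q, k, metadata, **kwargs):
            raise NotImplementedError

        def sparse_kv_predict(self, q, k, metadata, **kwargs):
            raise NotImplementedError

    class _DenseBackend(_SparsePredictorBase):
        """Hypothetical backend with no sparsity: both hooks report
        no selected indices, so the caller keeps dense attention."""

        def sparse_attn_predict(self, q, k, metadata, **kwargs):
            return None, None

        def sparse_kv_predict(self, q, k, metadata, **kwargs):
            return None, None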