
Commit 59b51d2

fix rocketkv interface. (#47)
Signed-off-by: Fanrong Li <[email protected]>
1 parent 2963c18 commit 59b51d2

2 files changed (+80, −81 lines)
tensorrt_llm/_torch/attention_backend/sparse/rocket.py

Lines changed: 80 additions & 0 deletions
@@ -89,6 +89,82 @@ def prepare(self):
             non_blocking=True)
 
 
+@torch.compile(dynamic=True)
+def convert_token_to_page_sparse_indices(
+        sparse_attn_indices: torch.Tensor, sparse_attn_offsets: torch.Tensor,
+        metadata: 'TrtllmAttentionMetadata'
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Convert token-based sparse attention indices to page-based sparse attention indices.
+
+    Args:
+        sparse_attn_indices: Token-based indices with shape [num_tokens, num_kv_heads]
+        sparse_attn_offsets: Offsets with shape [batch_size+1] indicating token boundaries for each batch
+        metadata: Attention metadata containing tokens_per_block (page_size)
+
+    Returns:
+        Tuple of (page_indices, page_offsets):
+        - page_indices: Page-based indices with shape [num_pages, num_kv_heads]
+        - page_offsets: Updated offsets with shape [batch_size+1] indicating page boundaries for each batch
+
+    Example:
+        If sparse_attn_indices first dimension is [1, 30, 67] and page_size=32,
+        the result will be [0, 2] (token 1 -> page 0, token 30 -> page 0, token 67 -> page 2)
+    """
+    page_size = metadata.tokens_per_block
+    batch_size = sparse_attn_offsets.size(0) - 1
+    num_kv_heads = sparse_attn_indices.size(1)
+
+    # Convert token indices to page indices
+    page_indices = sparse_attn_indices // page_size
+
+    # Process each batch and each kv_head separately to remove duplicates
+    new_page_indices_list = []
+    new_offsets = torch.zeros_like(sparse_attn_offsets)
+
+    current_offset = 0
+    for batch_idx in range(batch_size):
+        start_idx = sparse_attn_offsets[batch_idx]
+        end_idx = sparse_attn_offsets[batch_idx + 1]
+
+        if start_idx >= end_idx:
+            # Empty batch
+            new_offsets[batch_idx + 1] = current_offset
+            continue
+
+        batch_page_indices = page_indices[
+            start_idx:end_idx]  # [num_tokens_in_batch, num_kv_heads]
+
+        # For each kv_head, remove duplicates while preserving order
+        batch_unique_pages = []
+        for head_idx in range(num_kv_heads):
+            head_pages = batch_page_indices[:, head_idx]
+            unique_pages = torch.unique(head_pages, sorted=False)
+            batch_unique_pages.append(unique_pages)
+
+        # Find the maximum length among all heads for this batch
+        max_len = max(pages.size(0) for pages in batch_unique_pages)
+
+        if max_len > 0:
+            batch_result = torch.full((max_len, num_kv_heads),
+                                      fill_value=-1,
+                                      dtype=page_indices.dtype,
+                                      device=page_indices.device)
+
+            for head_idx in range(num_kv_heads):
+                unique_pages = batch_unique_pages[head_idx]
+                batch_result[:unique_pages.size(0), head_idx] = unique_pages
+
+            new_page_indices_list.append(batch_result)
+            current_offset += max_len
+
+        new_offsets[batch_idx + 1] = current_offset
+
+    new_page_indices = torch.cat(new_page_indices_list, dim=0)
+
+    return new_page_indices, new_offsets
+
+
 class RocketTrtllmAttention(TrtllmAttention):
     Metadata = RocketTrtllmAttentionMetadata
 
@@ -198,6 +274,10 @@ def sparse_attn_predict(
             dim=0).to(torch.int32)
         sparse_attn_offsets = torch.tensor(sparse_attn_offsets,
                                            dtype=torch.int32).to(q.device)
+        sparse_attn_indices, sparse_attn_offsets = convert_token_to_page_sparse_indices(
+            sparse_attn_indices, sparse_attn_offsets, metadata)
+        sparse_attn_indices = sparse_attn_indices.transpose(0,
+                                                            1).contiguous()
         return sparse_attn_indices, sparse_attn_offsets
 
     def sparse_kv_predict(
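To sanity-check the docstring example above outside the library, here is a minimal standalone sketch that mirrors the conversion; the `token_to_page_sketch` helper and the `SimpleNamespace` metadata stand-in are illustrative assumptions and not part of this commit or of the TensorRT-LLM API.

```python
# Minimal sketch mirroring convert_token_to_page_sparse_indices; illustrative
# only -- the helper name and the SimpleNamespace metadata stand-in are not
# part of TensorRT-LLM. torch.unique uses its default sorted output here,
# which is enough to reproduce the docstring example.
from types import SimpleNamespace

import torch


def token_to_page_sketch(token_indices, offsets, metadata):
    page_size = metadata.tokens_per_block
    pages = token_indices // page_size  # [num_tokens, num_kv_heads]
    out, new_offsets, cur = [], [0], 0
    for b in range(offsets.numel() - 1):
        chunk = pages[offsets[b]:offsets[b + 1]]
        # Deduplicate per KV head, pad shorter heads with -1 to a common length.
        per_head = [torch.unique(chunk[:, h]) for h in range(chunk.size(1))]
        max_len = max((u.numel() for u in per_head), default=0)
        if max_len > 0:
            block = torch.full((max_len, chunk.size(1)), -1, dtype=pages.dtype)
            for h, u in enumerate(per_head):
                block[:u.numel(), h] = u
            out.append(block)
            cur += max_len
        new_offsets.append(cur)
    return torch.cat(out, dim=0), torch.tensor(new_offsets, dtype=offsets.dtype)


metadata = SimpleNamespace(tokens_per_block=32)              # hypothetical stand-in
tokens = torch.tensor([[1], [30], [67]], dtype=torch.int32)  # docstring example, 1 KV head
offsets = torch.tensor([0, 3], dtype=torch.int32)            # single request
page_idx, page_off = token_to_page_sketch(tokens, offsets, metadata)
print(page_idx.squeeze(1).tolist(), page_off.tolist())       # [0, 2] [0, 2]
```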

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 0 additions & 81 deletions
@@ -544,82 +544,6 @@ def is_nvfp4_output_kernel_available(
     )
 
 
-@torch.compile(dynamic=True)
-def convert_token_to_page_sparse_indices(
-        sparse_attn_indices: torch.Tensor, sparse_attn_offsets: torch.Tensor,
-        metadata: 'TrtllmAttentionMetadata'
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Convert token-based sparse attention indices to page-based sparse attention indices.
-
-    Args:
-        sparse_attn_indices: Token-based indices with shape [num_tokens, num_kv_heads]
-        sparse_attn_offsets: Offsets with shape [batch_size+1] indicating token boundaries for each batch
-        metadata: Attention metadata containing tokens_per_block (page_size)
-
-    Returns:
-        Tuple of (page_indices, page_offsets):
-        - page_indices: Page-based indices with shape [num_pages, num_kv_heads]
-        - page_offsets: Updated offsets with shape [batch_size+1] indicating page boundaries for each batch
-
-    Example:
-        If sparse_attn_indices first dimension is [1, 30, 67] and page_size=32,
-        the result will be [0, 2] (token 1 -> page 0, token 30 -> page 0, token 67 -> page 2)
-    """
-    page_size = metadata.tokens_per_block
-    batch_size = sparse_attn_offsets.size(0) - 1
-    num_kv_heads = sparse_attn_indices.size(1)
-
-    # Convert token indices to page indices
-    page_indices = sparse_attn_indices // page_size
-
-    # Process each batch and each kv_head separately to remove duplicates
-    new_page_indices_list = []
-    new_offsets = torch.zeros_like(sparse_attn_offsets)
-
-    current_offset = 0
-    for batch_idx in range(batch_size):
-        start_idx = sparse_attn_offsets[batch_idx]
-        end_idx = sparse_attn_offsets[batch_idx + 1]
-
-        if start_idx >= end_idx:
-            # Empty batch
-            new_offsets[batch_idx + 1] = current_offset
-            continue
-
-        batch_page_indices = page_indices[
-            start_idx:end_idx]  # [num_tokens_in_batch, num_kv_heads]
-
-        # For each kv_head, remove duplicates while preserving order
-        batch_unique_pages = []
-        for head_idx in range(num_kv_heads):
-            head_pages = batch_page_indices[:, head_idx]
-            unique_pages = torch.unique(head_pages, sorted=False)
-            batch_unique_pages.append(unique_pages)
-
-        # Find the maximum length among all heads for this batch
-        max_len = max(pages.size(0) for pages in batch_unique_pages)
-
-        if max_len > 0:
-            batch_result = torch.full((max_len, num_kv_heads),
-                                      fill_value=-1,
-                                      dtype=page_indices.dtype,
-                                      device=page_indices.device)
-
-            for head_idx in range(num_kv_heads):
-                unique_pages = batch_unique_pages[head_idx]
-                batch_result[:unique_pages.size(0), head_idx] = unique_pages
-
-            new_page_indices_list.append(batch_result)
-            current_offset += max_len
-
-        new_offsets[batch_idx + 1] = current_offset
-
-    new_page_indices = torch.cat(new_page_indices_list, dim=0)
-
-    return new_page_indices, new_offsets
-
-
 @dataclass(kw_only=True)
 class TrtllmAttentionMetadata(AttentionMetadata):
     workspace: Optional[torch.Tensor] = None
@@ -1346,11 +1270,6 @@ def forward(
             q, k, metadata)
         sparse_attn_indices, sparse_attn_offsets = self.sparse_attn_predict(
             q, k, metadata)
-        if sparse_attn_indices is not None:
-            sparse_attn_indices, sparse_attn_offsets = convert_token_to_page_sparse_indices(
-                sparse_attn_indices, sparse_attn_offsets, metadata)
-            sparse_attn_indices = sparse_attn_indices.transpose(
-                0, 1).contiguous()
 
         self.wrapper.plan(
             layer_idx=self.get_local_layer_idx(metadata),
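A consequence of moving the conversion into `RocketTrtllmAttention.sparse_attn_predict` is that the Rocket backend now also applies the `.transpose(0, 1).contiguous()` itself rather than relying on the generic `forward` path. The following toy example sketches the padded page-index layout and the transpose; the concrete numbers are made up for illustration and are not taken from the commit.

```python
# Hand-constructed toy values showing the padded page-index layout produced by
# convert_token_to_page_sparse_indices and the transpose added in this commit.
import torch

# One request, two KV heads: head 0 selected pages {0, 1, 2}, head 1 selected
# pages {0, 2}. Shorter heads are padded with -1 up to the longest head.
page_indices = torch.tensor([[0, 0],
                             [1, 2],
                             [2, -1]], dtype=torch.int32)  # [num_pages, num_kv_heads]
page_offsets = torch.tensor([0, 3], dtype=torch.int32)     # 3 padded page rows for this request

# sparse_attn_predict now returns the head-major, contiguous layout,
# matching the `.transpose(0, 1).contiguous()` in the diff above.
head_major = page_indices.transpose(0, 1).contiguous()     # [num_kv_heads, num_pages]
print(head_major.tolist())    # [[0, 1, 2], [0, 2, -1]]
print(page_offsets.tolist())  # [0, 3]
```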
