@@ -89,6 +89,82 @@ def prepare(self):
            non_blocking=True)


+@torch.compile(dynamic=True)
+def convert_token_to_page_sparse_indices(
+        sparse_attn_indices: torch.Tensor, sparse_attn_offsets: torch.Tensor,
+        metadata: 'TrtllmAttentionMetadata'
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Convert token-based sparse attention indices to page-based sparse attention indices.
+
+    Args:
+        sparse_attn_indices: Token-based indices with shape [num_tokens, num_kv_heads]
+        sparse_attn_offsets: Offsets with shape [batch_size+1] indicating token boundaries for each batch
+        metadata: Attention metadata containing tokens_per_block (page_size)
+
+    Returns:
+        Tuple of (page_indices, page_offsets):
+            - page_indices: Page-based indices with shape [num_pages, num_kv_heads]
+            - page_offsets: Updated offsets with shape [batch_size+1] indicating page boundaries for each batch
+
+    Example:
+        If the token indices for one kv_head are [1, 30, 67] and page_size=32,
+        the resulting page indices for that head are [0, 2]
+        (token 1 -> page 0, token 30 -> page 0, token 67 -> page 2).
+    """
+    page_size = metadata.tokens_per_block
+    batch_size = sparse_attn_offsets.size(0) - 1
+    num_kv_heads = sparse_attn_indices.size(1)
+
+    # Convert token indices to page indices
+    page_indices = sparse_attn_indices // page_size
+
+    # Process each batch and each kv_head separately to remove duplicates
+    new_page_indices_list = []
+    new_offsets = torch.zeros_like(sparse_attn_offsets)
+
+    current_offset = 0
+    for batch_idx in range(batch_size):
+        start_idx = sparse_attn_offsets[batch_idx]
+        end_idx = sparse_attn_offsets[batch_idx + 1]
+
+        if start_idx >= end_idx:
+            # Empty batch
+            new_offsets[batch_idx + 1] = current_offset
+            continue
+
+        batch_page_indices = page_indices[
+            start_idx:end_idx]  # [num_tokens_in_batch, num_kv_heads]
+
+        # For each kv_head, remove duplicate page indices
+        # (torch.unique does not guarantee the original order)
+        batch_unique_pages = []
+        for head_idx in range(num_kv_heads):
+            head_pages = batch_page_indices[:, head_idx]
+            unique_pages = torch.unique(head_pages, sorted=False)
+            batch_unique_pages.append(unique_pages)
+
+        # Find the maximum length among all heads for this batch
+        max_len = max(pages.size(0) for pages in batch_unique_pages)
+
+        if max_len > 0:
+            batch_result = torch.full((max_len, num_kv_heads),
+                                      fill_value=-1,
+                                      dtype=page_indices.dtype,
+                                      device=page_indices.device)
+
+            for head_idx in range(num_kv_heads):
+                unique_pages = batch_unique_pages[head_idx]
+                batch_result[:unique_pages.size(0), head_idx] = unique_pages
+
+            new_page_indices_list.append(batch_result)
+            current_offset += max_len
+
+        new_offsets[batch_idx + 1] = current_offset
+
+    new_page_indices = torch.cat(new_page_indices_list, dim=0)
+
+    return new_page_indices, new_offsets
+
+
class RocketTrtllmAttention(TrtllmAttention):
    Metadata = RocketTrtllmAttentionMetadata

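A minimal usage sketch of the new helper (not part of the commit), assuming a standalone setting: the metadata argument is replaced by a hypothetical SimpleNamespace stand-in, since the function only reads tokens_per_block, and the input values mirror the docstring example.

import torch
from types import SimpleNamespace

# Hypothetical stand-in for TrtllmAttentionMetadata; only tokens_per_block is read.
metadata = SimpleNamespace(tokens_per_block=32)

# One request, one kv_head: token indices 1, 30 and 67.
sparse_attn_indices = torch.tensor([[1], [30], [67]], dtype=torch.int32)
sparse_attn_offsets = torch.tensor([0, 3], dtype=torch.int32)

page_indices, page_offsets = convert_token_to_page_sparse_indices(
    sparse_attn_indices, sparse_attn_offsets, metadata)

# Tokens 1 and 30 fall in page 0, token 67 in page 2, so this request keeps two
# pages (the order within a head is not guaranteed by torch.unique).
print(page_indices.squeeze(1).tolist())  # e.g. [0, 2]
print(page_offsets.tolist())             # [0, 2]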
@@ -198,6 +274,10 @@ def sparse_attn_predict(
                                        dim=0).to(torch.int32)
        sparse_attn_offsets = torch.tensor(sparse_attn_offsets,
                                           dtype=torch.int32).to(q.device)
+        sparse_attn_indices, sparse_attn_offsets = convert_token_to_page_sparse_indices(
+            sparse_attn_indices, sparse_attn_offsets, metadata)
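+        # [num_pages, num_kv_heads] -> [num_kv_heads, num_pages]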
+        sparse_attn_indices = sparse_attn_indices.transpose(0,
+                                                            1).contiguous()
        return sparse_attn_indices, sparse_attn_offsets

    def sparse_kv_predict(