@@ -53,7 +53,7 @@ def get_state_cls() -> Type["AiterMLAState"]:
5353
5454@dataclass
5555class AiterMLAMetadata (MLACommonMetadata ):
56- # The following 4 tensors are for current version of AITER MLA
56+ # The following 5 tensors are for the current version of AITER MLA
5757 block_table_bound : Optional [torch .Tensor ] = None
5858 # The indptr of the paged kv cache, shape: [batch_size + 1]
5959 paged_kv_indptr : Optional [torch .Tensor ] = None
@@ -63,6 +63,10 @@ class AiterMLAMetadata(MLACommonMetadata):
6363 # the paged kv cache, shape: [batch_size]
6464 paged_kv_last_page_lens : Optional [torch .Tensor ] = None
6565
66+ # qo_indptr is passed to the new AITER MLA decode API; the builder advances
67+ # it by 1 per sequence for now -- MTP support is not added yet.
68+ qo_indptr : Optional [torch .Tensor ] = None
69+
6670 @property
6771 def prefill_metadata (self ):
6872 prefill_metadata = super ().prefill_metadata
@@ -74,6 +78,7 @@ def prefill_metadata(self):
7478 prefill_metadata \
7579 .paged_kv_last_page_lens = self .paged_kv_last_page_lens
7680 prefill_metadata .block_table_bound = self .block_table_bound
81+ prefill_metadata .qo_indptr = self .qo_indptr
7782
7883 # update the cache
7984 self ._cached_prefill_metadata = self .__class__ (
@@ -93,6 +98,7 @@ def decode_metadata(self):
9398 decode_metadata \
9499 .paged_kv_last_page_lens = self .paged_kv_last_page_lens
95100 decode_metadata .block_table_bound = self .block_table_bound
101+ decode_metadata .qo_indptr = self .qo_indptr
96102
97103 # update the cache
98104 self ._cached_decode_metadata = self .__class__ (
@@ -136,6 +142,7 @@ def prepare(self):
136142 self .paged_kv_indptr : list [int ] = [0 ]
137143 self .paged_kv_last_page_lens : list [int ] = []
138144 self .total_blocks = 0
145+ self .qo_indptr : list [int ] = [0 ]
139146
140147 def _add_seq_group (self , inter_data , chunked_prefill_enabled : bool ,
141148 prefix_cache_hit : bool ):
@@ -210,6 +217,7 @@ def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
210217 self .paged_kv_indices .extend (block_table [:block_table_bound ])
211218 self .paged_kv_indptr .append (self .paged_kv_indptr [- 1 ] +
212219 block_table_bound )
220+ self .qo_indptr .append (self .qo_indptr [- 1 ] + 1 )
213221
214222 last_page_len = seq_len % self .block_size
215223 if last_page_len == 0 :
@@ -228,6 +236,8 @@ def build(self, seq_lens: list[int], query_lens: list[int],
228236 self .paged_kv_indptr .extend ([last_paged_kv_indptr ] *
229237 cuda_graph_pad_size )
230238 self .paged_kv_last_page_lens .extend ([0 ] * cuda_graph_pad_size )
239+ last_qo_indptr = self .qo_indptr [- 1 ]
240+ self .qo_indptr .extend ([last_qo_indptr ] * cuda_graph_pad_size )
231241
232242 # For current version of AITER MLA
233243 if len (self .paged_kv_indptr ) > 0 :
@@ -247,16 +257,22 @@ def build(self, seq_lens: list[int], query_lens: list[int],
247257 1 ,
248258 device = device ,
249259 dtype = torch .int )
260+
261+ qo_indptr = torch .tensor (self .qo_indptr ,
262+ device = device ,
263+ dtype = torch .int )
250264 else :
251265 paged_kv_indices_tensor = None
252266 paged_kv_indptr_tensor = None
253267 paged_kv_last_page_lens_tensor = None
254268 block_table_bound_tensor = None
269+ qo_indptr = None
255270
256271 metadata .paged_kv_indptr = paged_kv_indptr_tensor
257272 metadata .paged_kv_indices = paged_kv_indices_tensor
258273 metadata .paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
259274 metadata .block_table_bound = block_table_bound_tensor
275+ metadata .qo_indptr = qo_indptr
260276
261277 return metadata
262278
@@ -265,21 +281,25 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
265281
266282 @contextmanager
267283 def graph_capture (self , max_batch_size : int ):
268- kv_indices , kv_indptr , last_page_lens = get_aiter_mla_metadata (
269- max_batch_size = max_batch_size ,
270- block_size = self .runner .block_size ,
271- max_block_per_batch = self .runner .get_max_block_per_batch (),
272- device = self .runner .device )
284+ kv_indices , kv_indptr , last_page_lens , qo_indptr = \
285+ get_aiter_mla_metadata (
286+ max_batch_size = max_batch_size ,
287+ block_size = self .runner .block_size ,
288+ max_block_per_batch = \
289+ self .runner .get_max_block_per_batch (),
290+ device = self .runner .device )
273291 self ._paged_kv_indices_tensor = kv_indices
274292 self ._paged_kv_indptr_tensor = kv_indptr
275293 self ._paged_kv_last_page_lens_tensor = last_page_lens
294+ self ._qo_indptr_tensor = qo_indptr
276295
277296 with super ().graph_capture (max_batch_size ):
278297 yield
279298
280299 del self ._paged_kv_indices_tensor
281300 del self ._paged_kv_indptr_tensor
282301 del self ._paged_kv_last_page_lens_tensor
302+ del self ._qo_indptr_tensor
283303
284304 def graph_capture_get_metadata_for_batch (
285305 self ,
@@ -293,10 +313,12 @@ def graph_capture_get_metadata_for_batch(
293313 paged_kv_indices = self ._paged_kv_indices_tensor
294314 paged_kv_last_page_lens = self ._paged_kv_last_page_lens_tensor [:
295315 batch_size ]
316+ qo_indptr = self ._qo_indptr_tensor [:batch_size + 1 ]
296317
297318 metadata .paged_kv_indptr = paged_kv_indptr
298319 metadata .paged_kv_indices = paged_kv_indices
299320 metadata .paged_kv_last_page_lens = paged_kv_last_page_lens
321+ metadata .qo_indptr = qo_indptr
300322
301323 return metadata
302324
@@ -313,6 +335,7 @@ def get_graph_input_buffers(self,
313335 input_buffers [
314336 "paged_kv_last_page_lens" ] = attn_metadata .\
315337 decode_metadata .paged_kv_last_page_lens
338+ input_buffers ['qo_indptr' ] = attn_metadata .qo_indptr
316339
317340 return input_buffers
318341
@@ -332,6 +355,8 @@ def prepare_graph_input_buffers(self,
332355 input_buffers ["paged_kv_last_page_lens" ].copy_ (
333356 attn_metadata .decode_metadata .paged_kv_last_page_lens ,
334357 non_blocking = True )
358+ input_buffers ["qo_indptr" ].copy_ (
359+ attn_metadata .decode_metadata .qo_indptr , non_blocking = True )
335360
336361
337362class AiterMLAImpl (MLACommonImpl [AiterMLAMetadata ]):
@@ -372,11 +397,9 @@ def _flash_attn_varlen_diff_headdims(
372397 softmax_scale : float , return_softmax_lse : bool ,
373398 ** kwargs ) -> Union [tuple [torch .Tensor , ...], torch .Tensor ]:
374399 output = self .flash_attn_varlen_func (
375- q = q ,
376- k = k ,
377- v = v ,
378- softmax_scale = softmax_scale ,
379- return_lse = return_softmax_lse ,
400+ q ,
401+ k ,
402+ v ,
380403 ** kwargs ,
381404 )
382405
@@ -396,7 +419,7 @@ def _forward_decode(
396419 B = q_nope .shape [0 ]
397420
398421 q = torch .cat ([q_nope , q_pe ], dim = - 1 )
399- o = torch .zeros (B ,
422+ o = torch .empty (B ,
400423 self .num_heads ,
401424 self .kv_lora_rank ,
402425 dtype = q .dtype ,
@@ -405,6 +428,8 @@ def _forward_decode(
405428 kv_buffer = kv_c_and_k_pe_cache .unsqueeze (2 )
406429
407430 aiter_mla_decode_fwd (q , kv_buffer , o , self .scale ,
431+ attn_metadata .qo_indptr ,
432+ attn_metadata .max_query_len ,
408433 attn_metadata .paged_kv_indptr ,
409434 attn_metadata .paged_kv_indices ,
410435 attn_metadata .paged_kv_last_page_lens )
0 commit comments