@@ -60,6 +60,11 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
6060 batch_split_table : Optional [torch .Tensor ] = None
6161 split_table : Optional [torch .Tensor ] = None
6262 splits : Optional [torch .Tensor ] = None
63+ work_indptr : Optional [torch .Tensor ] = None
64+ work_info_set : Optional [torch .Tensor ] = None
65+ reduce_indptr : Optional [torch .Tensor ] = None
66+ reduce_final_map : Optional [torch .Tensor ] = None
67+ reduce_partial_map : Optional [torch .Tensor ] = None
6368
6469
6570class AiterMLAMetadata (MLACommonMetadata [AiterMLADecodeMetadata ]):
@@ -150,31 +155,48 @@ def _build_decode(self, input_positions: torch.Tensor,
150155 qo_indptr ,
151156 ) = self ._get_paged_kv_tensors (block_table , seq_lens )
152157
153- num_kv_splits_indptr = torch .empty (200 , dtype = torch .int32 , device = block_table .device )
154- batch_split_table = torch .empty (480 , dtype = torch .int32 , device = block_table .device )
155- split_table = torch .empty (480 , dtype = torch .int32 , device = block_table .device )
156- splits = torch .empty (1 , dtype = torch .int32 , device = block_table .device )
158+ # num_kv_splits_indptr = torch.empty(200, dtype=torch.int32, device=block_table.device)
159+ # batch_split_table = torch.empty(480, dtype=torch.int32, device=block_table.device)
160+ # split_table = torch.empty(480, dtype=torch.int32, device=block_table.device)
161+ # splits = torch.empty(1, dtype=torch.int32, device=block_table.device)
157162
158163 import aiter
159164 max_seqlen_qo = 1
165+ num_kv_splits_indptr = None
166+ # work_indptr = None
167+ # work_info_set = None
168+ # reduce_indptr = None
169+ # reduce_final_map = None
170+ # reduce_partial_map = None
171+
172+ work_indptr = torch .empty ([81 ], dtype = torch .int32 , device = "cuda" )
173+ work_info_set = torch .empty ([batch_size + 80 , 8 ], dtype = torch .int32 , device = "cuda" )
174+ reduce_indptr = torch .empty ([batch_size + 1 ], dtype = torch .int32 , device = "cuda" )
175+ reduce_final_map = torch .empty ([batch_size , 2 ], dtype = torch .int32 , device = "cuda" )
176+ reduce_partial_map = torch .empty ([batch_size ], dtype = torch .int32 , device = "cuda" )
160177
161178 if max_seqlen_qo == 1 or paged_kv_indptr [- 1 ] < 16 * 128 :
162- num_kv_splits_indptr = None
163179 batch_split_table = None
164180 split_table = None
165181 splits = None
166182 else :
167- aiter .get_mla_metadata_impl (paged_kv_indptr , num_kv_splits_indptr , batch_split_table , split_table , splits )
183+ # aiter.get_mla_metadata_impl(paged_kv_indptr, num_kv_splits_indptr, batch_split_table, split_table, splits)
168184 # if get gpu hang, please use cpu metadata as following:
169185 # num_kv_splits_indptr = torch.empty(200, dtype=torch.int32, device=block_table.device)
170186 # kv_seq_les = torch.empty(200, dtype=torch.int32, device=block_table.device)
171187 # aiter.mla.get_meta_param_balanced(paged_kv_indptr, num_kv_splits_indptr, batch_split_table, split_table, kv_seq_les, splits)
172-
173- # double check
174- #if num_kv_splits_indptr[0] == -1:
175- # num_kv_splits_indptr=None
176- # batch_split_table=None
177- # split_table=None
188+ aiter .get_mla_metadata_v1 (
189+ qo_indptr ,
190+ paged_kv_indptr ,
191+ 16 , # nhead // nhead_kv,
192+ 1 , # nhead_kv,
193+ True ,
194+ work_info_set ,
195+ work_indptr ,
196+ reduce_indptr ,
197+ reduce_final_map ,
198+ reduce_partial_map ,
199+ )
178200
179201 attn_metadata = AiterMLADecodeMetadata (
180202 input_positions = input_positions ,
@@ -184,9 +206,11 @@ def _build_decode(self, input_positions: torch.Tensor,
184206 paged_kv_indices = paged_kv_indices ,
185207 paged_kv_last_page_len = paged_last_page_len ,
186208 num_kv_splits_indptr = num_kv_splits_indptr ,
187- batch_split_table = batch_split_table ,
188- split_table = split_table ,
189- splits = splits ,
209+ work_indptr = work_indptr ,
210+ work_info_set = work_info_set ,
211+ reduce_indptr = reduce_indptr ,
212+ reduce_final_map = reduce_final_map ,
213+ reduce_partial_map = reduce_partial_map ,
190214 qo_indptr = qo_indptr )
191215
192216 return attn_metadata
@@ -279,9 +303,14 @@ def _forward_decode(
279303 max_seqlen_qo , self .scale ,
280304 True , 0.0 , 1 ,
281305 attn_metadata .decode .num_kv_splits_indptr ,
282- attn_metadata .decode .batch_split_table ,
283- attn_metadata .decode .split_table ,
284- attn_metadata .decode .splits ,
306+ attn_metadata .decode .work_indptr ,
307+ attn_metadata .decode .work_info_set ,
308+ attn_metadata .decode .reduce_indptr ,
309+ attn_metadata .decode .reduce_final_map ,
310+ attn_metadata .decode .reduce_partial_map ,
311+ # attn_metadata.decode.batch_split_table,
312+ # attn_metadata.decode.split_table,
313+ # attn_metadata.decode.splits,
285314 )
286315
287316 return self ._v_up_proj_and_o_proj (o )