
Commit f6f5b4b

Merge pull request #626 from ROCm/deepseek_085_mla_persistent
change mla itf into persistent version
2 parents: 2380171 + ef23553

2 files changed: 98 additions & 41 deletions

vllm/attention/ops/rocm_aiter_mla.py

Lines changed: 51 additions & 23 deletions
@@ -39,11 +39,17 @@ def aiter_mla_decode_fwd(
     logit_cap=0.0,
     num_kv_splits=None,
     num_kv_splits_indptr=None,
-    batch_split_table=None,
-    split_table=None,
-    splits=None,
-    q_rope=None,
-    k_rope=None,
+    work_indptr=None,
+    work_info_set=None,
+    reduce_indptr=None,
+    reduce_final_map=None,
+    reduce_partial_map=None,
+
+    # batch_split_table=None,
+    # split_table=None,
+    # splits=None,
+    # q_rope=None,
+    # k_rope=None,
 ):
     torch.ops.vllm.rocm_aiter_mla_decode_fwd(
         q,
@@ -59,11 +65,16 @@ def aiter_mla_decode_fwd(
         logit_cap,
         num_kv_splits,
         num_kv_splits_indptr,
-        batch_split_table,
-        split_table,
-        splits,
-        q_rope,
-        k_rope,
+        work_indptr,
+        work_info_set,
+        reduce_indptr,
+        reduce_final_map,
+        reduce_partial_map,
+        # batch_split_table,
+        # split_table,
+        # splits,
+        # q_rope,
+        # k_rope,
     )


@@ -81,11 +92,17 @@ def mla_decode_fwd_impl(
     logit_cap: Optional[float] = 0.0,
     num_kv_splits: Optional[int] = 1,
     num_kv_splits_indptr: Optional[torch.Tensor] = None,
-    batch_split_table: Optional[torch.Tensor] = None,
-    split_table: Optional[torch.Tensor] = None,
-    splits: Optional[torch.Tensor] = None,
-    q_rope: Optional[torch.Tensor] = None,
-    k_rope: Optional[torch.Tensor] = None,
+    work_indptr: Optional[torch.Tensor] = None,
+    work_info_set: Optional[torch.Tensor] = None,
+    reduce_indptr: Optional[torch.Tensor] = None,
+    reduce_final_map: Optional[torch.Tensor] = None,
+    reduce_partial_map: Optional[torch.Tensor] = None,
+
+    # batch_split_table: Optional[torch.Tensor] = None,
+    # split_table: Optional[torch.Tensor] = None,
+    # splits: Optional[torch.Tensor] = None,
+    # q_rope: Optional[torch.Tensor] = None,
+    # k_rope: Optional[torch.Tensor] = None,
 ) -> None:
     from aiter.mla import mla_decode_fwd_dispatch

@@ -101,9 +118,14 @@ def mla_decode_fwd_impl(
         logit_cap=logit_cap,
         num_kv_splits=num_kv_splits,
         num_kv_splits_indptr=num_kv_splits_indptr,
-        batch_split_table=batch_split_table,
-        split_table=split_table,
-        cu_num=splits,
+        work_indptr=work_indptr,
+        work_info_set=work_info_set,
+        reduce_indptr=reduce_indptr,
+        reduce_final_map=reduce_final_map,
+        reduce_partial_map=reduce_partial_map,
+        # batch_split_table=batch_split_table,
+        # split_table=split_table,
+        # cu_num=splits,
     )


@@ -121,11 +143,17 @@ def mla_decode_fwd_fake(
     logit_cap: Optional[float] = 0.0,
     num_kv_splits: Optional[int] = 1,
     num_kv_splits_indptr: Optional[torch.Tensor] = None,
-    batch_split_table: Optional[torch.Tensor] = None,
-    split_table: Optional[torch.Tensor] = None,
-    splits: Optional[torch.Tensor] = None,
-    q_rope: Optional[torch.Tensor] = None,
-    k_rope: Optional[torch.Tensor] = None,
+    work_indptr: Optional[torch.Tensor] = None,
+    work_info_set: Optional[torch.Tensor] = None,
+    reduce_indptr: Optional[torch.Tensor] = None,
+    reduce_final_map: Optional[torch.Tensor] = None,
+    reduce_partial_map: Optional[torch.Tensor] = None,
+
+    # batch_split_table: Optional[torch.Tensor] = None,
+    # split_table: Optional[torch.Tensor] = None,
+    # splits: Optional[torch.Tensor] = None,
+    # q_rope: Optional[torch.Tensor] = None,
+    # k_rope: Optional[torch.Tensor] = None,
 ) -> None:
     pass
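The three functions above follow vLLM's usual custom-op pattern: aiter_mla_decode_fwd is the public wrapper, mla_decode_fwd_impl forwards to aiter's mla_decode_fwd_dispatch, and mla_decode_fwd_fake is the no-op registration that keeps the op traceable under torch.compile, which is why its signature must stay in sync with the real implementation. The diff itself does not document the new buffers, but the names suggest a CSR-style work list consumed by persistent workgroups. A minimal CPU-side sketch of that reading (an assumption, not aiter's documented semantics):

    import torch

    def walk_work_items(work_indptr: torch.Tensor, work_info_set: torch.Tensor):
        # Hypothetical consumption pattern: work_indptr[i]:work_indptr[i + 1]
        # brackets the work items owned by persistent workgroup i; what the
        # 8 ints per work_info_set row encode is not specified by this diff.
        num_workgroups = work_indptr.numel() - 1
        for wg in range(num_workgroups):
            start, end = int(work_indptr[wg]), int(work_indptr[wg + 1])
            for idx in range(start, end):
                yield wg, work_info_set[idx]  # one 8-int work descriptor

    # Toy data: two workgroups, three work items in total.
    work_indptr = torch.tensor([0, 2, 3], dtype=torch.int32)
    work_info_set = torch.zeros([3, 8], dtype=torch.int32)
    for wg, desc in walk_work_items(work_indptr, work_info_set):
        print(wg, desc.tolist())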

vllm/v1/attention/backends/mla/rocm_aiter_mla.py

Lines changed: 47 additions & 18 deletions
@@ -60,6 +60,11 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
     batch_split_table: Optional[torch.Tensor] = None
     split_table: Optional[torch.Tensor] = None
     splits: Optional[torch.Tensor] = None
+    work_indptr: Optional[torch.Tensor] = None
+    work_info_set: Optional[torch.Tensor] = None
+    reduce_indptr: Optional[torch.Tensor] = None
+    reduce_final_map: Optional[torch.Tensor] = None
+    reduce_partial_map: Optional[torch.Tensor] = None


 class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
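The reduce_* fields added here are sized per sequence in _build_decode below (reduce_indptr has batch_size + 1 entries, reduce_final_map is [batch_size, 2], reduce_partial_map has batch_size entries), which points at bookkeeping for combining split-KV partial results. One plausible reading, again an assumption rather than aiter's actual contract:

    import torch

    def combine_partials(reduce_indptr, reduce_partial_map, partial_out):
        # Hypothetical reduction: sequence b gathers the partial outputs
        # listed in reduce_partial_map[reduce_indptr[b]:reduce_indptr[b + 1]].
        # A real attention kernel would merge with running-max log-sum-exp
        # rescaling, not a plain sum; the sum only illustrates the indexing.
        batch_size = reduce_indptr.numel() - 1
        rows = []
        for b in range(batch_size):
            lo, hi = int(reduce_indptr[b]), int(reduce_indptr[b + 1])
            rows.append(partial_out[reduce_partial_map[lo:hi].long()].sum(dim=0))
        return torch.stack(rows)

    # Toy data: 2 sequences, 3 partial outputs of width 4.
    print(combine_partials(torch.tensor([0, 2, 3]), torch.tensor([0, 1, 2]),
                           torch.ones(3, 4)))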
@@ -150,31 +155,48 @@ def _build_decode(self, input_positions: torch.Tensor,
             qo_indptr,
         ) = self._get_paged_kv_tensors(block_table, seq_lens)

-        num_kv_splits_indptr = torch.empty(200, dtype=torch.int32, device=block_table.device)
-        batch_split_table = torch.empty(480, dtype=torch.int32, device=block_table.device)
-        split_table = torch.empty(480, dtype=torch.int32, device=block_table.device)
-        splits = torch.empty(1, dtype=torch.int32, device=block_table.device)
+        # num_kv_splits_indptr = torch.empty(200, dtype=torch.int32, device=block_table.device)
+        # batch_split_table = torch.empty(480, dtype=torch.int32, device=block_table.device)
+        # split_table = torch.empty(480, dtype=torch.int32, device=block_table.device)
+        # splits = torch.empty(1, dtype=torch.int32, device=block_table.device)

         import aiter
         max_seqlen_qo = 1
+        num_kv_splits_indptr = None
+        # work_indptr = None
+        # work_info_set = None
+        # reduce_indptr = None
+        # reduce_final_map = None
+        # reduce_partial_map = None
+
+        work_indptr = torch.empty([81], dtype=torch.int32, device="cuda")
+        work_info_set = torch.empty([batch_size + 80, 8], dtype=torch.int32, device="cuda")
+        reduce_indptr = torch.empty([batch_size + 1], dtype=torch.int32, device="cuda")
+        reduce_final_map = torch.empty([batch_size, 2], dtype=torch.int32, device="cuda")
+        reduce_partial_map = torch.empty([batch_size], dtype=torch.int32, device="cuda")

         if max_seqlen_qo == 1 or paged_kv_indptr[-1] < 16 * 128:
-            num_kv_splits_indptr = None
             batch_split_table = None
             split_table = None
             splits = None
         else:
-            aiter.get_mla_metadata_impl(paged_kv_indptr, num_kv_splits_indptr, batch_split_table, split_table, splits)
+            # aiter.get_mla_metadata_impl(paged_kv_indptr, num_kv_splits_indptr, batch_split_table, split_table, splits)
             # if the GPU hangs, fall back to CPU-side metadata as follows:
             # num_kv_splits_indptr = torch.empty(200, dtype=torch.int32, device=block_table.device)
             # kv_seq_les = torch.empty(200, dtype=torch.int32, device=block_table.device)
             # aiter.mla.get_meta_param_balanced(paged_kv_indptr, num_kv_splits_indptr, batch_split_table, split_table, kv_seq_les, splits)
-
-            # double check
-            #if num_kv_splits_indptr[0] == -1:
-            #    num_kv_splits_indptr=None
-            #    batch_split_table=None
-            #    split_table=None
+            aiter.get_mla_metadata_v1(
+                qo_indptr,
+                paged_kv_indptr,
+                16,  # nhead // nhead_kv,
+                1,  # nhead_kv,
+                True,
+                work_info_set,
+                work_indptr,
+                reduce_indptr,
+                reduce_final_map,
+                reduce_partial_map,
+            )

         attn_metadata = AiterMLADecodeMetadata(
             input_positions=input_positions,
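The buffer sizes in this hunk are fixed relative to batch_size: work_indptr gets 81 = 80 + 1 entries, and work_info_set gets batch_size + 80 rows of 8 ints. The 80 presumably matches the number of persistent workgroups (roughly one per compute unit on the target GPU); that is our inference, not something the diff states. A helper that pins the sizing down, under that assumption:

    import torch

    NUM_PERSISTENT_WORKGROUPS = 80  # inferred from the 81 and "+ 80" sizes above

    def alloc_persistent_mla_buffers(batch_size: int, device: str = "cuda"):
        # Mirrors the torch.empty calls in this hunk; the buffers are left
        # uninitialized because aiter.get_mla_metadata_v1 is expected to
        # fill them on the persistent path.
        wg = NUM_PERSISTENT_WORKGROUPS
        return {
            "work_indptr": torch.empty([wg + 1], dtype=torch.int32,
                                       device=device),
            "work_info_set": torch.empty([batch_size + wg, 8],
                                         dtype=torch.int32, device=device),
            "reduce_indptr": torch.empty([batch_size + 1], dtype=torch.int32,
                                         device=device),
            "reduce_final_map": torch.empty([batch_size, 2],
                                            dtype=torch.int32, device=device),
            "reduce_partial_map": torch.empty([batch_size], dtype=torch.int32,
                                              device=device),
        }

Note the gate above the else branch: when max_seqlen_qo == 1 or paged_kv_indptr[-1] < 16 * 128 (that is, fewer than 2048 paged-KV entries across the batch), the legacy split tables stay None and get_mla_metadata_v1 is not called.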
@@ -184,9 +206,11 @@ def _build_decode(self, input_positions: torch.Tensor,
             paged_kv_indices=paged_kv_indices,
             paged_kv_last_page_len=paged_last_page_len,
             num_kv_splits_indptr=num_kv_splits_indptr,
-            batch_split_table=batch_split_table,
-            split_table=split_table,
-            splits=splits,
+            work_indptr=work_indptr,
+            work_info_set=work_info_set,
+            reduce_indptr=reduce_indptr,
+            reduce_final_map=reduce_final_map,
+            reduce_partial_map=reduce_partial_map,
             qo_indptr=qo_indptr)

         return attn_metadata
@@ -279,9 +303,14 @@ def _forward_decode(
             max_seqlen_qo, self.scale,
             True, 0.0, 1,
             attn_metadata.decode.num_kv_splits_indptr,
-            attn_metadata.decode.batch_split_table,
-            attn_metadata.decode.split_table,
-            attn_metadata.decode.splits,
+            attn_metadata.decode.work_indptr,
+            attn_metadata.decode.work_info_set,
+            attn_metadata.decode.reduce_indptr,
+            attn_metadata.decode.reduce_final_map,
+            attn_metadata.decode.reduce_partial_map,
+            # attn_metadata.decode.batch_split_table,
+            # attn_metadata.decode.split_table,
+            # attn_metadata.decode.splits,
         )

         return self._v_up_proj_and_o_proj(o)
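Taken together, the two files thread the same five tensors through one call chain. A schematic recap in comment form, using only names that appear in this diff:

    # _build_decode (vllm/v1/attention/backends/mla/rocm_aiter_mla.py)
    #   allocates work_indptr / work_info_set / reduce_indptr /
    #   reduce_final_map / reduce_partial_map, optionally filled by
    #   aiter.get_mla_metadata_v1, and stores them on AiterMLADecodeMetadata.
    #
    # _forward_decode reads them back off attn_metadata.decode and calls:
    #   aiter_mla_decode_fwd                 (vllm/attention/ops/rocm_aiter_mla.py)
    #     -> torch.ops.vllm.rocm_aiter_mla_decode_fwd
    #       -> mla_decode_fwd_impl
    #         -> aiter.mla.mla_decode_fwd_dispatch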
