74 changes: 51 additions & 23 deletions vllm/attention/ops/rocm_aiter_mla.py
@@ -39,11 +39,17 @@ def aiter_mla_decode_fwd(
logit_cap=0.0,
num_kv_splits=None,
num_kv_splits_indptr=None,
batch_split_table=None,
split_table=None,
splits=None,
q_rope=None,
k_rope=None,
work_indptr=None,
work_info_set=None,
reduce_indptr=None,
reduce_final_map=None,
reduce_partial_map=None,

# batch_split_table=None,
# split_table=None,
# splits=None,
# q_rope=None,
# k_rope=None,
):
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
q,
@@ -59,11 +65,16 @@ def aiter_mla_decode_fwd(
logit_cap,
num_kv_splits,
num_kv_splits_indptr,
batch_split_table,
split_table,
splits,
q_rope,
k_rope,
work_indptr,
work_info_set,
reduce_indptr,
reduce_final_map,
reduce_partial_map,
# batch_split_table,
# split_table,
# splits,
# q_rope,
# k_rope,
)
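
# Usage sketch (not part of the diff): how a caller might drive the updated
# wrapper with the v1 scheduler outputs. The leading positional parameters of
# aiter_mla_decode_fwd sit outside this hunk, so the names and order used here
# are assumptions; only the keyword arguments visible in the signature above
# come from the source. `sched` bundles the five tensors produced by
# aiter.get_mla_metadata_v1 in the builder further down.
def _decode_with_v1_schedule(q, kv_buffer, o, qo_indptr, max_seqlen_qo,
                             kv_indptr, kv_indices, kv_last_page_lens,
                             sm_scale, sched):
    aiter_mla_decode_fwd(
        q, kv_buffer, o, qo_indptr, max_seqlen_qo,
        kv_indptr, kv_indices, kv_last_page_lens, sm_scale,
        logit_cap=0.0,
        num_kv_splits=1,
        num_kv_splits_indptr=None,
        work_indptr=sched.work_indptr,
        work_info_set=sched.work_info_set,
        reduce_indptr=sched.reduce_indptr,
        reduce_final_map=sched.reduce_final_map,
        reduce_partial_map=sched.reduce_partial_map,
    )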


@@ -81,11 +92,17 @@ def mla_decode_fwd_impl(
logit_cap: Optional[float] = 0.0,
num_kv_splits: Optional[int] = 1,
num_kv_splits_indptr: Optional[torch.Tensor] = None,
batch_split_table: Optional[torch.Tensor] = None,
split_table: Optional[torch.Tensor] = None,
splits: Optional[torch.Tensor] = None,
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
work_indptr: Optional[torch.Tensor] = None,
work_info_set: Optional[torch.Tensor] = None,
reduce_indptr: Optional[torch.Tensor] = None,
reduce_final_map: Optional[torch.Tensor] = None,
reduce_partial_map: Optional[torch.Tensor] = None,

# batch_split_table: Optional[torch.Tensor] = None,
# split_table: Optional[torch.Tensor] = None,
# splits: Optional[torch.Tensor] = None,
# q_rope: Optional[torch.Tensor] = None,
# k_rope: Optional[torch.Tensor] = None,
) -> None:
from aiter.mla import mla_decode_fwd_dispatch

@@ -101,9 +118,14 @@ def mla_decode_fwd_impl(
logit_cap=logit_cap,
num_kv_splits=num_kv_splits,
num_kv_splits_indptr=num_kv_splits_indptr,
batch_split_table=batch_split_table,
split_table=split_table,
cu_num=splits,
work_indptr=work_indptr,
work_info_set=work_info_set,
reduce_indptr=reduce_indptr,
reduce_final_map=reduce_final_map,
reduce_partial_map=reduce_partial_map,
# batch_split_table=batch_split_table,
# split_table=split_table,
# cu_num=splits,
)


@@ -121,11 +143,17 @@ def mla_decode_fwd_fake(
logit_cap: Optional[float] = 0.0,
num_kv_splits: Optional[int] = 1,
num_kv_splits_indptr: Optional[torch.Tensor] = None,
batch_split_table: Optional[torch.Tensor] = None,
split_table: Optional[torch.Tensor] = None,
splits: Optional[torch.Tensor] = None,
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
work_indptr: Optional[torch.Tensor] = None,
work_info_set: Optional[torch.Tensor] = None,
reduce_indptr: Optional[torch.Tensor] = None,
reduce_final_map: Optional[torch.Tensor] = None,
reduce_partial_map: Optional[torch.Tensor] = None,

# batch_split_table: Optional[torch.Tensor] = None,
# split_table: Optional[torch.Tensor] = None,
# splits: Optional[torch.Tensor] = None,
# q_rope: Optional[torch.Tensor] = None,
# k_rope: Optional[torch.Tensor] = None,
) -> None:
pass
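
# Why the fake: torch.compile / meta tracing needs a shape-only stand-in that
# never launches the ROCm kernel, and since the real op mutates `o` in place
# and returns None, the fake can simply do nothing. A minimal sketch of how
# such a real/fake pair is typically registered -- the actual registration in
# this file is outside the hunk, so the op name and bodies below are
# illustrative, not vLLM's:

import torch

torch.library.define(
    "demo::mla_decode_fwd", "(Tensor q, Tensor(a!) o) -> ()"
)

@torch.library.impl("demo::mla_decode_fwd", "CUDA")
def _demo_impl(q, o):
    o.copy_(q)  # stand-in for the real mla_decode_fwd_impl

@torch.library.register_fake("demo::mla_decode_fwd")
def _demo_fake(q, o):
    pass  # shape-only: nothing to compute, `o` is mutated in place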

65 changes: 47 additions & 18 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -60,6 +60,11 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
batch_split_table: Optional[torch.Tensor] = None
split_table: Optional[torch.Tensor] = None
splits: Optional[torch.Tensor] = None
work_indptr: Optional[torch.Tensor] = None
work_info_set: Optional[torch.Tensor] = None
reduce_indptr: Optional[torch.Tensor] = None
reduce_final_map: Optional[torch.Tensor] = None
reduce_partial_map: Optional[torch.Tensor] = None
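# Field-shape notes (assumptions inferred from the torch.empty(...) calls in
# _build_decode below, not a documented AITER contract):
#   work_indptr:        int32 [num_workgroups + 1]; CSR offsets into work_info_set
#   work_info_set:      int32 [num_work_items, 8]; one packed descriptor per work item
#   reduce_indptr:      int32 [batch_size + 1]; per-request offsets into the partials
#   reduce_final_map:   int32 [batch_size, 2]; where each request's final output lands
#   reduce_partial_map: int32 [num_partials]; partial-buffer slot feeding each request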


class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
@@ -150,31 +155,48 @@ def _build_decode(self, input_positions: torch.Tensor,
qo_indptr,
) = self._get_paged_kv_tensors(block_table, seq_lens)

num_kv_splits_indptr = torch.empty(200, dtype=torch.int32, device=block_table.device)
batch_split_table = torch.empty(480, dtype=torch.int32, device=block_table.device)
split_table = torch.empty(480, dtype=torch.int32, device=block_table.device)
splits = torch.empty(1, dtype=torch.int32, device=block_table.device)
# num_kv_splits_indptr = torch.empty(200, dtype=torch.int32, device=block_table.device)
# batch_split_table = torch.empty(480, dtype=torch.int32, device=block_table.device)
# split_table = torch.empty(480, dtype=torch.int32, device=block_table.device)
# splits = torch.empty(1, dtype=torch.int32, device=block_table.device)

import aiter
max_seqlen_qo = 1  # pure decode: one query token per request
num_kv_splits_indptr = None
# work_indptr = None
# work_info_set = None
# reduce_indptr = None
# reduce_final_map = None
# reduce_partial_map = None

# Persistent-scheduler buffers. The sizes appear to assume at most 80
# persistent workgroups (an 81-entry CSR indptr, plus up to 80 work items
# beyond one per request); this is an inference, not a documented contract.
work_indptr = torch.empty([81], dtype=torch.int32, device="cuda")
work_info_set = torch.empty([batch_size + 80, 8], dtype=torch.int32, device="cuda")
reduce_indptr = torch.empty([batch_size + 1], dtype=torch.int32, device="cuda")
reduce_final_map = torch.empty([batch_size, 2], dtype=torch.int32, device="cuda")
reduce_partial_map = torch.empty([batch_size], dtype=torch.int32, device="cuda")

if max_seqlen_qo == 1 or paged_kv_indptr[-1] < 16 * 128:
num_kv_splits_indptr = None
batch_split_table = None
split_table = None
splits = None
else:
aiter.get_mla_metadata_impl(paged_kv_indptr, num_kv_splits_indptr, batch_split_table, split_table, splits)
# aiter.get_mla_metadata_impl(paged_kv_indptr, num_kv_splits_indptr, batch_split_table, split_table, splits)
# If the GPU hangs here, fall back to the CPU metadata path instead:
# num_kv_splits_indptr = torch.empty(200, dtype=torch.int32, device=block_table.device)
# kv_seq_lens = torch.empty(200, dtype=torch.int32, device=block_table.device)
# aiter.mla.get_meta_param_balanced(paged_kv_indptr, num_kv_splits_indptr, batch_split_table, split_table, kv_seq_lens, splits)

# double check
# if num_kv_splits_indptr[0] == -1:
#     num_kv_splits_indptr = None
#     batch_split_table = None
#     split_table = None
aiter.get_mla_metadata_v1(
qo_indptr,
paged_kv_indptr,
16, # nhead // nhead_kv,
1, # nhead_kv,
True,
work_info_set,
work_indptr,
reduce_indptr,
reduce_final_map,
reduce_partial_map,
)

attn_metadata = AiterMLADecodeMetadata(
input_positions=input_positions,
@@ -184,9 +206,11 @@ def _build_decode(self, input_positions: torch.Tensor,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_last_page_len,
num_kv_splits_indptr=num_kv_splits_indptr,
batch_split_table=batch_split_table,
split_table=split_table,
splits=splits,
work_indptr=work_indptr,
work_info_set=work_info_set,
reduce_indptr=reduce_indptr,
reduce_final_map=reduce_final_map,
reduce_partial_map=reduce_partial_map,
qo_indptr=qo_indptr)

return attn_metadata
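
# Debug sketch (not part of the diff): if work_indptr/work_info_set really form
# a CSR-style schedule -- entries work_indptr[wg]..work_indptr[wg+1] listing the
# work items assigned to persistent workgroup wg, as the [81] and
# [batch_size + 80, 8] shapes above suggest -- the schedule can be dumped
# host-side like this. The meaning of the 8-int descriptor is internal to
# AITER, so it is printed raw.

import torch

def dump_mla_schedule(work_indptr: torch.Tensor, work_info_set: torch.Tensor) -> None:
    indptr = work_indptr.cpu().tolist()
    info = work_info_set.cpu()
    for wg in range(len(indptr) - 1):
        for item in range(indptr[wg], indptr[wg + 1]):
            print(f"workgroup {wg}: work item {item}: {info[item].tolist()}")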
@@ -279,9 +303,14 @@ def _forward_decode(
max_seqlen_qo, self.scale,
True, 0.0, 1,
attn_metadata.decode.num_kv_splits_indptr,
attn_metadata.decode.batch_split_table,
attn_metadata.decode.split_table,
attn_metadata.decode.splits,
attn_metadata.decode.work_indptr,
attn_metadata.decode.work_info_set,
attn_metadata.decode.reduce_indptr,
attn_metadata.decode.reduce_final_map,
attn_metadata.decode.reduce_partial_map,
# attn_metadata.decode.batch_split_table,
# attn_metadata.decode.split_table,
# attn_metadata.decode.splits,
)

return self._v_up_proj_and_o_proj(o)
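
# Readability note (sketch, not part of the diff): the long positional tail in
# the decode call above is easy to misalign against aiter_mla_decode_fwd's
# signature; judging by that wrapper's defaults, `True, 0.0, 1` presumably map
# to an enable/causal flag, logit_cap, and num_kv_splits. A keyword form of the
# scheduler tail would make the mapping explicit:
#
#     aiter_mla_decode_fwd(
#         ...,  # leading tensor arguments unchanged
#         num_kv_splits_indptr=attn_metadata.decode.num_kv_splits_indptr,
#         work_indptr=attn_metadata.decode.work_indptr,
#         work_info_set=attn_metadata.decode.work_info_set,
#         reduce_indptr=attn_metadata.decode.reduce_indptr,
#         reduce_final_map=attn_metadata.decode.reduce_final_map,
#         reduce_partial_map=attn_metadata.decode.reduce_partial_map,
#     )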