Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions vllm/attention/ops/rocm_aiter_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def aiter_mla_decode_fwd(
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
logit_cap: float = 0.0,
num_kv_splits: int | None = None,
num_kv_splits_indptr: torch.Tensor | None = None,
):
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
q,
Expand All @@ -45,6 +47,8 @@ def aiter_mla_decode_fwd(
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap,
num_kv_splits=num_kv_splits,
num_kv_splits_indptr=num_kv_splits_indptr,
)


Expand All @@ -59,6 +63,8 @@ def mla_decode_fwd_impl(
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
num_kv_splits: int | None = None,
num_kv_splits_indptr: torch.Tensor | None = None,
) -> None:
from aiter.mla import mla_decode_fwd

Expand All @@ -73,6 +79,8 @@ def mla_decode_fwd_impl(
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap,
num_kv_splits=num_kv_splits,
num_kv_splits_indptr=num_kv_splits_indptr,
)


Expand All @@ -87,6 +95,8 @@ def mla_decode_fwd_fake(
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
num_kv_splits: int | None = None,
num_kv_splits_indptr: torch.Tensor | None = None,
) -> None:
pass

Expand Down
25 changes: 25 additions & 0 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
paged_kv_last_page_len: torch.Tensor | None = None
# The query indptr, shape : [num_decode + 1]
qo_indptr: torch.Tensor | None = None
# The num_kv_splits indptr, shape : [num_decode + 1]
num_kv_splits_indptr: torch.Tensor | None = None
# Number of kv splits
num_kv_splits: int | None = 16


class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
Expand All @@ -80,6 +84,7 @@ def __init__(
)

self.compilation_config = vllm_config.compilation_config
self.num_kv_splits = 16
max_num_pages_per_req = cdiv(
vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size
)
Expand Down Expand Up @@ -109,6 +114,13 @@ def __init__(
self.qo_indptr = torch.arange(
0, max_num_reqs + 1, dtype=torch.int32, device=device
)
self.num_kv_splits_indptr = torch.arange(
0,
(max_num_reqs + 1) * self.num_kv_splits,
self.num_kv_splits,
dtype=torch.int32,
device=device,
)

def _build_decode(
self,
Expand Down Expand Up @@ -185,12 +197,21 @@ def _build_decode(
self.paged_kv_last_page_len[num_reqs:].fill_(1)
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]

num_kv_splits_indptr = self.num_kv_splits_indptr[: num_reqs + 1]

qo_indptr = self.qo_indptr[: 1 + num_reqs]

else:
qo_indptr = torch.arange(
0, num_reqs + 1, step=1, dtype=torch.int32, device=device
)
num_kv_splits_indptr = torch.arange(
0,
(num_reqs + 1) * self.num_kv_splits,
step=self.num_kv_splits,
dtype=torch.int32,
device=device,
)

attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor,
Expand All @@ -199,6 +220,8 @@ def _build_decode(
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr,
num_kv_splits_indptr=num_kv_splits_indptr,
num_kv_splits=self.num_kv_splits,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)

Expand Down Expand Up @@ -298,6 +321,8 @@ def _forward_decode(
attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len,
num_kv_splits=attn_metadata.decode.num_kv_splits,
num_kv_splits_indptr=attn_metadata.decode.num_kv_splits_indptr,
)

return o, None