Skip to content

Commit b94e207

Browse files
authored
Merge branch 'ROCm:main' into main
2 parents 3f8fe9b + ca0de0e commit b94e207

File tree

17 files changed

+271
-744
lines changed

17 files changed

+271
-744
lines changed

.github/workflows/aiter-test.yaml

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,6 @@ jobs:
9292
git submodule update --init --recursive --depth 1 --jobs 4
9393
fi
9494
95-
- name: Clean up Rocm processes
96-
run: |
97-
./.github/scripts/clean_up_rocm.sh
98-
9995
- name: Run the container
10096
run: |
10197
set -ex
@@ -158,11 +154,6 @@ jobs:
158154
if: always()
159155
run: |
160156
docker rm -f aiter_test || true
161-
162-
- name: Clean up Rocm processes
163-
if: always()
164-
run: |
165-
./.github/scripts/clean_up_rocm.sh
166157
167158
multi-gpu:
168159
name: Multi-GPU Tests (8 GPU)

.github/workflows/sglang_downstream.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ jobs:
9797
docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
9898
-v "${GITHUB_WORKSPACE:-$PWD}/sglang:/sglang-checkout" \
9999
--ipc=host --group-add video \
100+
--network=host \
100101
--shm-size 32g \
101102
--cap-add=SYS_PTRACE \
102103
-e HF_TOKEN="${HF_TOKEN:-}" \

.github/workflows/triton-test.yaml

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,6 @@ jobs:
5151
fetch-depth: 1
5252
submodules: 'recursive'
5353

54-
- name: Clean up Rocm processes
55-
run: |
56-
./.github/scripts/clean_up_rocm.sh
57-
5854
- name: Run the container
5955
run: |
6056
set -ex
@@ -148,8 +144,3 @@ jobs:
148144
if: always()
149145
run: |
150146
docker rm -f triton_test || true
151-
152-
- name: Clean up Rocm processes
153-
if: always()
154-
run: |
155-
./.github/scripts/clean_up_rocm.sh

aiter/mla.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,6 @@ def mla_decode_fwd(
150150
kv_indices,
151151
kv_last_page_lens,
152152
max_seqlen_q,
153-
page_size=1,
154-
nhead_kv=1,
155153
sm_scale=None, # 1.0 / (qk_head_dim**0.5)
156154
logit_cap=0.0,
157155
num_kv_splits=None, # for experts only!!!
@@ -170,11 +168,7 @@ def mla_decode_fwd(
170168
):
171169
device = q.device
172170
assert logit_cap <= 0, f"{logit_cap=} is not support yet"
173-
if kv_buffer.dtype != torch.uint8:
174-
_, _, _, qk_head_dim = kv_buffer.shape
175-
else:
176-
_, _, qk_head_dim = q.shape
177-
171+
num_page, page_size, nhead_kv, qk_head_dim = kv_buffer.shape
178172
if sm_scale is None:
179173
sm_scale = 1.0 / (qk_head_dim**0.5)
180174

@@ -233,8 +227,6 @@ def mla_decode_fwd(
233227
None,
234228
None,
235229
max_seqlen_q,
236-
page_size,
237-
nhead_kv,
238230
sm_scale,
239231
logits,
240232
attn_lse,
@@ -327,8 +319,6 @@ def mla_decode_fwd(
327319
work_indptr,
328320
work_info_set,
329321
max_seqlen_q,
330-
page_size,
331-
nhead_kv,
332322
sm_scale,
333323
logits,
334324
attn_lse,

aiter/ops/attention.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -566,8 +566,6 @@ def mla_decode_stage1_asm_fwd(
566566
work_indptr: Optional[torch.Tensor],
567567
work_info_set: Optional[torch.Tensor],
568568
max_seqlen_q: int,
569-
page_size: int,
570-
nhead_kv: int,
571569
softmax_scale: float,
572570
# [batch_size, num_kv_splits, num_heads, v_head_dim]
573571
splitData: torch.Tensor,
@@ -856,7 +854,6 @@ def get_mla_metadata_info_v1(
856854
def get_mla_metadata_v1(
857855
seqlens_qo_indptr: torch.Tensor,
858856
seqlens_kv_indptr: torch.Tensor,
859-
kv_last_page_lens: torch.Tensor,
860857
num_heads_per_head_k: int,
861858
num_heads_k: int,
862859
is_causal: bool,
@@ -866,7 +863,6 @@ def get_mla_metadata_v1(
866863
reduce_indptr: torch.Tensor,
867864
reduce_final_map: torch.Tensor,
868865
reduce_partial_map: torch.Tensor,
869-
page_size: int = 1,
870866
kv_granularity: int = 16,
871867
max_seqlen_qo: int = -1,
872868
uni_seqlen_qo: int = -1,
@@ -880,14 +876,12 @@ def get_mla_metadata_v1(
880876
"""
881877
Inputs:
882878
cumulated seqlens of q/o: (batch_size + 1), dtype torch.int32.
883-
cumulated page indices of k/v: (batch_size + 1), dtype torch.int32.
884-
Length of last page of k/v: (batch_size), dtype torch.int32.
879+
cumulated seqlens of k/v: (batch_size + 1), dtype torch.int32.
885880
num_heads_per_head_k: Equals to num_heads_q // num_heads_k.
886881
num_heads_k: num_heads_k.
887882
is_causal: Whether causal mask is enabled.
888883
Options: Detailed settings for spliting. All of them are optional.
889-
page_size: default=1. The size of a page.
890-
kv_granularity: default=16. The granularity on kv page nums when cutting batch.
884+
kv_granularity: default=16. The granularity on kv sequence length when cutting batch.
891885
max_seqlen_qo: default=-1. Used to check lds usage and save time. value less than 1 means unknown.
892886
uni_seqlen_qo: default=-1. Sequence length of qo is uniform across batches. value less than 1 means the
893887
length is not fixed.
@@ -905,11 +899,11 @@ def get_mla_metadata_v1(
905899
[2.2] q_start: (#work), The global index in seq where q/o starts. Use global index here can
906900
reduce memory access count in kernel.
907901
[2.3] q_end: (#work), The global index in seq where q/o ends (not included).
908-
[2.4] kv_start: (#work), The global index in page where k/v starts.
909-
[2.5] kv_end: (#work), The global index in page where k/v ends (not included). Note that
902+
[2.4] kv_start: (#work), The global index in seq where k/v starts.
903+
[2.5] kv_end: (#work), The global index in seq where k/v ends (not included). Note that
910904
this value indicates the end of last qo sequence if there are
911905
multiple qo sequences included in the current work and causal mask
912-
is enabled when page_size is 1.
906+
is enabled.
913907
[2.6] kv_offset: (#work), Remaining length in seq from kv_end to the end of current batch.
914908
[2.7] pad (#work, 1), Pad to 8 DWs.
915909
[3] reduce_indptr: (sum(qo_seqlen_blk_count) + 1),

csrc/include/attention_asm_mla.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ void mla_decode_stage1_asm_fwd(
1515
std::optional<torch::Tensor>& work_indptr, // metadata
1616
std::optional<torch::Tensor>& work_info_set, // [batch_size+1]
1717
int max_seqlen_q,
18-
int page_size,
19-
int nhead_kv,
2018
float softmax_scale,
2119
// following are output
2220
torch::Tensor& splitData, //[batch_size, num_kv_splits, num_heads, v_head_dim]

csrc/include/mla.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// SPDX-License-Identifier: MIT
2-
// Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
2+
// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
33

44
#pragma once
55

@@ -37,7 +37,6 @@ static_assert(kSizeMlaPartialTileInfoInDw == 2);
3737

3838
void get_mla_metadata_v1(const torch::Tensor& seqlens_qo_indptr, // [batch size + 1]
3939
const torch::Tensor& seqlens_kv_indptr, // [batch size + 1]
40-
const torch::Tensor& kv_last_page_lens, // [batch size]
4140
const int32_t num_heads_per_head_k,
4241
const int32_t num_heads_k,
4342
const bool is_causal,
@@ -47,14 +46,13 @@ void get_mla_metadata_v1(const torch::Tensor& seqlens_qo_indptr, // [batch size
4746
torch::Tensor& reduce_indptr,
4847
torch::Tensor& reduce_final_map,
4948
torch::Tensor& reduce_partial_map,
50-
const int32_t page_size,
5149
const int32_t kv_granularity,
5250
const int32_t max_seqlen_qo,
5351
const int32_t uni_seqlen_qo,
5452
const bool fast_mode,
5553
const int32_t topk,
5654
const int32_t max_split_per_batch,
57-
const bool intra_batch_mode,
55+
const bool intra_batch_mode,
5856
const std::optional<at::ScalarType> dtype_q,
5957
const std::optional<at::ScalarType> dtype_kv);
6058

csrc/include/rocm_ops.hpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ namespace py = pybind11;
5757
py::arg("work_indptr"), \
5858
py::arg("work_info_set"), \
5959
py::arg("max_seqlen_q"), \
60-
py::arg("page_size"), \
61-
py::arg("nhead_kv"), \
6260
py::arg("softmax_scale"), \
6361
py::arg("splitData"), \
6462
py::arg("splitLse"), \
@@ -1656,7 +1654,6 @@ namespace py = pybind11;
16561654
"get_mla_metadata_v1", \
16571655
py::arg("seqlens_qo_indptr"), \
16581656
py::arg("seqlens_kv_indptr"), \
1659-
py::arg("kv_last_page_lens"), \
16601657
py::arg("num_heads_per_head_k"), \
16611658
py::arg("num_heads_k"), \
16621659
py::arg("is_causal"), \
@@ -1666,7 +1663,6 @@ namespace py = pybind11;
16661663
py::arg("reduce_indptr"), \
16671664
py::arg("reduce_final_map"), \
16681665
py::arg("reduce_partial_map"), \
1669-
py::arg("page_size") = 1, \
16701666
py::arg("kv_granularity") = 16, \
16711667
py::arg("max_seqlen_qo") = -1, \
16721668
py::arg("uni_seqlen_qo") = -1, \

csrc/kernels/mla/metadata.cu

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// SPDX-License-Identifier: MIT
2-
// Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
2+
// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
33

44
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
55
#include "metadata/v1_0_device.cuh"
@@ -40,7 +40,6 @@
4040
void get_mla_metadata_v1(
4141
const torch::Tensor& seqlens_qo_indptr, // [batch size + 1]
4242
const torch::Tensor& seqlens_kv_indptr, // [batch size + 1]
43-
const torch::Tensor& kv_last_page_lens, // [batch size]
4443
const int32_t num_heads_per_head_k,
4544
const int32_t num_heads_k,
4645
const bool is_causal,
@@ -50,7 +49,6 @@ void get_mla_metadata_v1(
5049
torch::Tensor& reduce_indptr,
5150
torch::Tensor& reduce_final_map,
5251
torch::Tensor& reduce_partial_map,
53-
const int32_t page_size,
5452
const int32_t kv_granularity,
5553
const int32_t max_seqlen_qo,
5654
const int32_t uni_seqlen_qo,
@@ -65,8 +63,6 @@ void get_mla_metadata_v1(
6563

6664
TORCH_CHECK((kv_granularity & (kv_granularity - 1)) == 0,
6765
__func__, ": kv_granularity Must be power of 2!");
68-
TORCH_CHECK((page_size & (page_size - 1)) == 0,
69-
__func__, ": page_size Must be power of 2!");
7066
TORCH_CHECK(seqlens_qo_indptr.stride(0) == 1,
7167
__func__, ": seqlens_qo_indptr should be continuous!");
7268
TORCH_CHECK(seqlens_qo_indptr.scalar_type() == at::ScalarType::Int,
@@ -75,10 +71,6 @@ void get_mla_metadata_v1(
7571
__func__, ": seqlens_kv_indptr should be continuous!");
7672
TORCH_CHECK(seqlens_kv_indptr.scalar_type() == at::ScalarType::Int,
7773
__func__, ": seqlens_kv_indptr's element type should be int!");
78-
TORCH_CHECK(kv_last_page_lens.stride(0) == 1,
79-
__func__, ": kv_last_page_lens should be continuous!");
80-
TORCH_CHECK(kv_last_page_lens.scalar_type() == at::ScalarType::Int,
81-
__func__, ": kv_last_page_lens's element type should be int!");
8274

8375
at::ScalarType q_dtype = dtype_q.has_value() ? dtype_q.value() : at::ScalarType::BFloat16;
8476
at::ScalarType kv_dtype = dtype_kv.has_value() ? dtype_kv.value() : at::ScalarType::BFloat16;
@@ -88,11 +80,9 @@ void get_mla_metadata_v1(
8880
get_mla_metadata_v1_2_device(
8981
seqlens_qo_indptr,
9082
seqlens_kv_indptr,
91-
kv_last_page_lens,
9283
num_heads_per_head_k,
9384
num_heads_k,
9485
is_causal,
95-
page_size,
9686
kv_granularity,
9787
max_seqlen_qo,
9888
uni_seqlen_qo,

0 commit comments

Comments (0)