Skip to content

Commit f6fad2f

Browse files
authored
Fix issues (ROCm#1606)
1 parent 3cd58c2 commit f6fad2f

File tree

3 files changed

+8
-10
lines changed

3 files changed

+8
-10
lines changed

aiter/ops/attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,7 @@ def get_mla_metadata_info_v1(
636636
tile_cnt = batch_size * max_qo_tiles_per_batch
637637

638638
if fast_mode:
639-
max_work = tile_cnt + cu_num - 1
639+
max_work = (batch_size + cu_num - 1) * max_qo_tiles_per_batch
640640
max_split_tiles = (
641641
min(batch_size + cu_num - 1, (cu_num - 1) * 2) * max_qo_tiles_per_batch
642642
)

op_tests/test_mla_persistent.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,8 @@ def test_mla(
244244
batch_size,
245245
max_seqlen_qo,
246246
nhead,
247-
q.dtype,
248-
kv_buffer.dtype,
247+
dtype,
248+
kvtype,
249249
is_sparse=False,
250250
fast_mode=True if not non_persistent_mode else False,
251251
num_kv_splits=max_split_per_batch,
@@ -292,9 +292,7 @@ def test_mla(
292292
max_split_per_batch=max_split_per_batch,
293293
intra_batch_mode=non_persistent_mode,
294294
dtype_q=dtype,
295-
dtype_kv=(
296-
kvtype if dtype == kvtype else dtype
297-
), # if q bf16 k fp8 should be same as bf16bf16 for dp mode
295+
dtype_kv=kvtype,
298296
)
299297

300298
def test_absorb_decode_bf16():

op_tests/test_mla_sparse.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,8 @@ def test_mla(
408408
batch_size,
409409
max_seqlen_qo,
410410
nhead,
411-
q.dtype,
412-
kv_buffer.dtype,
411+
dtype,
412+
kvtype,
413413
is_sparse=True,
414414
fast_mode=True,
415415
)
@@ -452,8 +452,8 @@ def test_mla(
452452
uni_seqlen_qo=decode_qlen,
453453
fast_mode=True,
454454
topk=2048,
455-
dtype_q=q.dtype,
456-
dtype_kv=kv_buffer.dtype,
455+
dtype_q=dtype,
456+
dtype_kv=kvtype,
457457
)
458458

459459
# generate kv topk per token & convert indices into per token

0 commit comments

Comments
 (0)