Skip to content

Commit 5961a92

Browse files
authored
add max split num to sparse mla (ROCm#1617)
1 parent a0328ce commit 5961a92

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

op_tests/test_mla_sparse.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ def test_mla(
326326
page_size,
327327
varlen,
328328
decode_qlen,
329+
max_split_per_batch,
329330
):
330331
ret = {}
331332

@@ -412,6 +413,7 @@ def test_mla(
412413
kvtype,
413414
is_sparse=True,
414415
fast_mode=True,
416+
num_kv_splits=max_split_per_batch,
415417
)
416418

417419
# aiter implementation
@@ -451,6 +453,7 @@ def test_mla(
451453
max_seqlen_qo=int(max_seqlen_qo),
452454
uni_seqlen_qo=decode_qlen,
453455
fast_mode=True,
456+
max_split_per_batch=max_split_per_batch,
454457
topk=2048,
455458
dtype_q=dtype,
456459
dtype_kv=kvtype,
@@ -501,6 +504,7 @@ def test_sparse_mla_bf16():
501504
kv_last_page_lens,
502505
1,
503506
sm_scale,
507+
num_kv_splits=max_split_per_batch,
504508
work_meta_data=work_meta_data,
505509
work_indptr=work_indptr,
506510
work_info_set=work_info_set,
@@ -561,6 +565,7 @@ def test_absorb_decode_fp8():
561565
kv_last_page_lens,
562566
1,
563567
sm_scale,
568+
num_kv_splits=max_split_per_batch,
564569
q_scale=q_scale,
565570
kv_scale=kv_scale,
566571
work_meta_data=work_meta_data,
@@ -721,6 +726,15 @@ def test_absorb_decode_fp8():
721726
help="""Number of heads.
722727
e.g.: -n 16,1""",
723728
)
729+
parser.add_argument(
730+
"-ms",
731+
"--max_split_per_batch",
732+
type=int,
733+
nargs="*",
734+
default=[16],
735+
help="""kv seqlens max split num for per batch.
736+
e.g.: -ms 32""",
737+
)
724738
parser.add_argument(
725739
"--varlen",
726740
action="store_true",
@@ -738,8 +752,8 @@ def test_absorb_decode_fp8():
738752

739753
for nhead, decode_qlen in list_nhead:
740754
df = []
741-
for dtype, kvtype, ctx_len, batch_size in itertools.product(
742-
list_dtype, l_kv_dtype, args.ctxLen, args.batchSize
755+
for dtype, kvtype, ctx_len, batch_size, max_split_per_batch in itertools.product(
756+
list_dtype, l_kv_dtype, args.ctxLen, args.batchSize, args.max_split_per_batch
743757
):
744758
ret = test_mla(
745759
ctx_len,
@@ -754,6 +768,7 @@ def test_absorb_decode_fp8():
754768
args.block_size,
755769
varlen=args.varlen,
756770
decode_qlen=decode_qlen,
771+
max_split_per_batch=max_split_per_batch,
757772
)
758773
df.append(ret)
759774
df = pd.DataFrame(df)

0 commit comments

Comments
 (0)