@@ -326,6 +326,7 @@ def test_mla(
326326 page_size ,
327327 varlen ,
328328 decode_qlen ,
329+ max_split_per_batch ,
329330):
330331 ret = {}
331332
@@ -412,6 +413,7 @@ def test_mla(
412413 kvtype ,
413414 is_sparse = True ,
414415 fast_mode = True ,
416+ num_kv_splits = max_split_per_batch ,
415417 )
416418
417419 # aiter implementation
@@ -451,6 +453,7 @@ def test_mla(
451453 max_seqlen_qo = int (max_seqlen_qo ),
452454 uni_seqlen_qo = decode_qlen ,
453455 fast_mode = True ,
456+ max_split_per_batch = max_split_per_batch ,
454457 topk = 2048 ,
455458 dtype_q = dtype ,
456459 dtype_kv = kvtype ,
@@ -501,6 +504,7 @@ def test_sparse_mla_bf16():
501504 kv_last_page_lens ,
502505 1 ,
503506 sm_scale ,
507+ num_kv_splits = max_split_per_batch ,
504508 work_meta_data = work_meta_data ,
505509 work_indptr = work_indptr ,
506510 work_info_set = work_info_set ,
@@ -561,6 +565,7 @@ def test_absorb_decode_fp8():
561565 kv_last_page_lens ,
562566 1 ,
563567 sm_scale ,
568+ num_kv_splits = max_split_per_batch ,
564569 q_scale = q_scale ,
565570 kv_scale = kv_scale ,
566571 work_meta_data = work_meta_data ,
@@ -721,6 +726,15 @@ def test_absorb_decode_fp8():
721726 help = """Number of heads.
722727 e.g.: -n 16,1""" ,
723728)
729+ parser .add_argument (
730+ "-ms" ,
731+ "--max_split_per_batch" ,
732+ type = int ,
733+ nargs = "*" ,
734+ default = [16 ],
735+ help = """kv seqlens max split num for per batch.
736+ e.g.: -ms 32""" ,
737+ )
724738parser .add_argument (
725739 "--varlen" ,
726740 action = "store_true" ,
@@ -738,8 +752,8 @@ def test_absorb_decode_fp8():
738752
739753for nhead , decode_qlen in list_nhead :
740754 df = []
741- for dtype , kvtype , ctx_len , batch_size in itertools .product (
742- list_dtype , l_kv_dtype , args .ctxLen , args .batchSize
755+ for dtype , kvtype , ctx_len , batch_size , max_split_per_batch in itertools .product (
756+ list_dtype , l_kv_dtype , args .ctxLen , args .batchSize , args . max_split_per_batch
743757 ):
744758 ret = test_mla (
745759 ctx_len ,
@@ -754,6 +768,7 @@ def test_absorb_decode_fp8():
754768 args .block_size ,
755769 varlen = args .varlen ,
756770 decode_qlen = decode_qlen ,
771+ max_split_per_batch = max_split_per_batch ,
757772 )
758773 df .append (ret )
759774 df = pd .DataFrame (df )
0 commit comments