File tree Expand file tree Collapse file tree 3 files changed +8
-10
lines changed
Expand file tree Collapse file tree 3 files changed +8
-10
lines changed Original file line number Diff line number Diff line change @@ -636,7 +636,7 @@ def get_mla_metadata_info_v1(
636636 tile_cnt = batch_size * max_qo_tiles_per_batch
637637
638638 if fast_mode :
639- max_work = tile_cnt + cu_num - 1
639+ max_work = ( batch_size + cu_num - 1 ) * max_qo_tiles_per_batch
640640 max_split_tiles = (
641641 min (batch_size + cu_num - 1 , (cu_num - 1 ) * 2 ) * max_qo_tiles_per_batch
642642 )
Original file line number Diff line number Diff line change @@ -244,8 +244,8 @@ def test_mla(
244244 batch_size ,
245245 max_seqlen_qo ,
246246 nhead ,
247- q . dtype ,
248- kv_buffer . dtype ,
247+ dtype ,
248+ kvtype ,
249249 is_sparse = False ,
250250 fast_mode = True if not non_persistent_mode else False ,
251251 num_kv_splits = max_split_per_batch ,
@@ -292,9 +292,7 @@ def test_mla(
292292 max_split_per_batch = max_split_per_batch ,
293293 intra_batch_mode = non_persistent_mode ,
294294 dtype_q = dtype ,
295- dtype_kv = (
296- kvtype if dtype == kvtype else dtype
297- ), # if q bf16 k fp8 should be same as bf16bf16 for dp mode
295+ dtype_kv = kvtype ,
298296 )
299297
300298 def test_absorb_decode_bf16 ():
Original file line number Diff line number Diff line change @@ -408,8 +408,8 @@ def test_mla(
408408 batch_size ,
409409 max_seqlen_qo ,
410410 nhead ,
411- q . dtype ,
412- kv_buffer . dtype ,
411+ dtype ,
412+ kvtype ,
413413 is_sparse = True ,
414414 fast_mode = True ,
415415 )
@@ -452,8 +452,8 @@ def test_mla(
452452 uni_seqlen_qo = decode_qlen ,
453453 fast_mode = True ,
454454 topk = 2048 ,
455- dtype_q = q . dtype ,
456- dtype_kv = kv_buffer . dtype ,
455+ dtype_q = dtype ,
456+ dtype_kv = kvtype ,
457457 )
458458
459459 # generate kv topk per token & convert indices into per token
You can’t perform that action at this time.
0 commit comments