Skip to content

Commit 1726e93

Browse files
LucasWilkinson, Robert Shaw, rshaw@neuralmagic.com
authored
[BugFix][DP/EP] Fix CUTLASS MLA hang under load (vllm-project#26026)
Signed-off-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
Co-authored-by: [email protected] <[email protected]>
1 parent ee04c0c commit 1726e93

File tree

1 file changed

+32
-32
lines changed

1 file changed

+32
-32
lines changed

csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp

Lines changed: 32 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -580,22 +580,22 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
580580
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
581581
auto blk_coord = tile_scheduler.get_block_coord();
582582
auto problem_shape = params.problem_shape;
583-
auto local_split_kv = params.split_kv;
583+
auto local_split_kv = params.split_kv;
584584
if (params.mainloop.ptr_seq != nullptr) {
585585
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
586-
if (params.ptr_split_kv != nullptr) {
586+
if (params.ptr_split_kv != nullptr) {
587587
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
588588
}
589589
}
590-
if (local_split_kv <= get<3>(blk_coord))
591-
continue;
590+
if (local_split_kv <= get<3>(blk_coord))
591+
continue;
592592
load_page_table(
593593
blk_coord,
594594
problem_shape,
595595
params.mainloop,
596596
shared_storage.tensors,
597597
pipeline_page_table, pipeline_pt_producer_state,
598-
local_split_kv
598+
local_split_kv
599599
);
600600
}
601601
}
@@ -604,15 +604,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
604604
CUTLASS_PRAGMA_NO_UNROLL
605605
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
606606
auto blk_coord = tile_scheduler.get_block_coord();
607-
auto problem_shape = params.problem_shape;
608-
auto local_split_kv = params.split_kv;
607+
auto problem_shape = params.problem_shape;
608+
auto local_split_kv = params.split_kv;
609609
if (params.mainloop.ptr_seq != nullptr) {
610610
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
611-
if (params.ptr_split_kv != nullptr) {
611+
if (params.ptr_split_kv != nullptr) {
612612
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
613613
}
614614
}
615-
if (local_split_kv <= get<3>(blk_coord))
615+
if (local_split_kv <= get<3>(blk_coord))
616616
continue;
617617
load_cpasync(
618618
blk_coord,
@@ -621,7 +621,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
621621
params.mainloop_params,
622622
shared_storage.tensors,
623623
pipeline_load_qk, pipeline_load_qk_producer_state,
624-
local_split_kv,
624+
local_split_kv,
625625
/* must be shared pipe */
626626
pipeline_page_table, pipeline_pt_consumer_state
627627
);
@@ -633,15 +633,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
633633
CUTLASS_PRAGMA_NO_UNROLL
634634
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
635635
auto blk_coord = tile_scheduler.get_block_coord();
636-
auto problem_shape = params.problem_shape;
637-
auto local_split_kv = params.split_kv;
636+
auto problem_shape = params.problem_shape;
637+
auto local_split_kv = params.split_kv;
638638
if (params.mainloop.ptr_seq != nullptr) {
639639
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
640-
if (params.ptr_split_kv != nullptr) {
641-
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
642-
}
640+
if (params.ptr_split_kv != nullptr) {
641+
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
642+
}
643643
}
644-
if (local_split_kv <= get<3>(blk_coord))
644+
if (local_split_kv <= get<3>(blk_coord))
645645
continue;
646646
load_tma</* paged= */ true>(
647647
blk_coord,
@@ -651,7 +651,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
651651
shared_storage.tensors,
652652
pipeline_load_qk, pipeline_load_qk_producer_state,
653653
pipeline_load_qk, pipeline_load_qk_producer_state,
654-
local_split_kv
654+
local_split_kv
655655
);
656656
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
657657
}
@@ -660,15 +660,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
660660
CUTLASS_PRAGMA_NO_UNROLL
661661
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
662662
auto blk_coord = tile_scheduler.get_block_coord();
663-
auto problem_shape = params.problem_shape;
664-
auto local_split_kv = params.split_kv;
663+
auto problem_shape = params.problem_shape;
664+
auto local_split_kv = params.split_kv;
665665
if (params.mainloop.ptr_seq != nullptr) {
666666
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
667-
if (params.ptr_split_kv != nullptr) {
667+
if (params.ptr_split_kv != nullptr) {
668668
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
669-
}
669+
}
670670
}
671-
if (local_split_kv <= get<3>(blk_coord))
671+
if (local_split_kv <= get<3>(blk_coord))
672672
continue;
673673
load_tma<false>(
674674
blk_coord,
@@ -678,7 +678,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
678678
shared_storage.tensors,
679679
pipeline_load_qk, pipeline_load_qk_producer_state,
680680
pipeline_load_qk, pipeline_load_qk_producer_state,
681-
local_split_kv
681+
local_split_kv
682682
);
683683
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
684684
}
@@ -694,14 +694,14 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
694694
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
695695
auto blk_coord = tile_scheduler.get_block_coord();
696696
auto problem_shape = params.problem_shape;
697-
auto local_split_kv = params.split_kv;
697+
auto local_split_kv = params.split_kv;
698698
if (params.mainloop.ptr_seq != nullptr) {
699699
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
700700
if (params.ptr_split_kv != nullptr) {
701701
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
702702
}
703703
}
704-
if (local_split_kv <= get<3>(blk_coord))
704+
if (local_split_kv <= get<3>(blk_coord))
705705
continue;
706706
mma(blk_coord,
707707
problem_shape,
@@ -711,7 +711,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
711711
pipeline_mma_s, pipeline_mma_s_producer_state,
712712
pipeline_p_mma, pipeline_p_mma_consumer_state,
713713
pipeline_mma_o, pipeline_mma_o_producer_state,
714-
local_split_kv
714+
local_split_kv
715715
);
716716
}
717717
}
@@ -726,15 +726,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
726726
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
727727
auto blk_coord = tile_scheduler.get_block_coord();
728728
auto problem_shape = params.problem_shape;
729-
auto split_kv = params.split_kv;
730-
auto local_split_kv = split_kv;
729+
auto split_kv = params.split_kv;
730+
auto local_split_kv = split_kv;
731731
if (params.mainloop.ptr_seq != nullptr) {
732732
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
733-
if (params.ptr_split_kv != nullptr) {
733+
if (params.ptr_split_kv != nullptr) {
734734
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
735735
}
736736
}
737-
if (local_split_kv <= get<3>(blk_coord))
737+
if (local_split_kv <= get<3>(blk_coord))
738738
continue;
739739
compute(
740740
blk_coord,
@@ -745,7 +745,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
745745
pipeline_mma_s, pipeline_mma_s_consumer_state,
746746
pipeline_p_mma, pipeline_p_mma_producer_state,
747747
pipeline_mma_o, pipeline_mma_o_consumer_state,
748-
local_split_kv
748+
local_split_kv
749749
);
750750
}
751751

@@ -1900,7 +1900,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
19001900
cutlass::arch::NamedBarrier(
19011901
(kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
19021902
kNamedBarrierEpilogue
1903-
).arrive();
1903+
).arrive_and_wait();
19041904

19051905
return;
19061906
}

0 commit comments

Comments
 (0)