@@ -580,22 +580,22 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     for (; tile_scheduler.is_valid(); ++tile_scheduler) {
       auto blk_coord = tile_scheduler.get_block_coord();
       auto problem_shape = params.problem_shape;
-      auto local_split_kv = params.split_kv;
+      auto local_split_kv = params.split_kv;
       if (params.mainloop.ptr_seq != nullptr) {
         get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
-        if (params.ptr_split_kv != nullptr) {
+        if (params.ptr_split_kv != nullptr) {
           local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
         }
       }
-      if (local_split_kv <= get<3>(blk_coord))
-        continue;
+      if (local_split_kv <= get<3>(blk_coord))
+        continue;
       load_page_table(
         blk_coord,
         problem_shape,
         params.mainloop,
         shared_storage.tensors,
         pipeline_page_table, pipeline_pt_producer_state,
-        local_split_kv
+        local_split_kv
       );
     }
   }
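Every warp role in this kernel repeats the same per-tile guard before doing any work: resolve the effective KV-split count for the current batch, then skip split indices beyond it. Below is a minimal standalone sketch of that logic; the names here (Params, tile_has_work, the unpacked coordinates) are simplified stand-ins, not the kernel's real types — the kernel itself reads get<2>(blk_coord) (batch index) and get<3>(blk_coord) (split index) out of a cute coordinate tuple.

// Standalone sketch of the per-tile split-KV guard; simplified names.
#include <cstdint>

struct Params {
  int32_t        split_kv;      // launch-wide split count
  int32_t const* ptr_split_kv;  // optional per-batch override, may be null
  int32_t const* ptr_seq;       // optional per-batch KV lengths, may be null
};

bool tile_has_work(Params const& params, int batch_idx, int split_idx,
                   int32_t& seqlen_kv, int32_t& local_split_kv) {
  local_split_kv = params.split_kv;
  if (params.ptr_seq != nullptr) {
    // Variable-length batch: patch this batch's KV length into the
    // problem shape and take its individual split count, if provided.
    seqlen_kv = params.ptr_seq[batch_idx];
    if (params.ptr_split_kv != nullptr) {
      local_split_kv = params.ptr_split_kv[batch_idx];
    }
  }
  // Split indices at or beyond the local count carry no work;
  // the kernel `continue`s to the next tile.
  return split_idx < local_split_kv;
}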
@@ -604,15 +604,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     CUTLASS_PRAGMA_NO_UNROLL
     for (; tile_scheduler.is_valid(); ++tile_scheduler) {
       auto blk_coord = tile_scheduler.get_block_coord();
-      auto problem_shape = params.problem_shape;
-      auto local_split_kv = params.split_kv;
+      auto problem_shape = params.problem_shape;
+      auto local_split_kv = params.split_kv;
       if (params.mainloop.ptr_seq != nullptr) {
         get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
-        if (params.ptr_split_kv != nullptr) {
+        if (params.ptr_split_kv != nullptr) {
           local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
         }
       }
-      if (local_split_kv <= get<3>(blk_coord))
+      if (local_split_kv <= get<3>(blk_coord))
         continue;
       load_cpasync(
         blk_coord,
@@ -621,7 +621,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
         params.mainloop_params,
         shared_storage.tensors,
         pipeline_load_qk, pipeline_load_qk_producer_state,
-        local_split_kv,
+        local_split_kv,
         /* must be shared pipe */
         pipeline_page_table, pipeline_pt_consumer_state
       );
@@ -633,15 +633,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     CUTLASS_PRAGMA_NO_UNROLL
     for (; tile_scheduler.is_valid(); ++tile_scheduler) {
      auto blk_coord = tile_scheduler.get_block_coord();
-      auto problem_shape = params.problem_shape;
-      auto local_split_kv = params.split_kv;
+      auto problem_shape = params.problem_shape;
+      auto local_split_kv = params.split_kv;
       if (params.mainloop.ptr_seq != nullptr) {
         get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
-        if (params.ptr_split_kv != nullptr) {
-          local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
-        }
+        if (params.ptr_split_kv != nullptr) {
+          local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
+        }
       }
-      if (local_split_kv <= get<3>(blk_coord))
+      if (local_split_kv <= get<3>(blk_coord))
         continue;
       load_tma</* paged= */ true>(
         blk_coord,
@@ -651,7 +651,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
         shared_storage.tensors,
         pipeline_load_qk, pipeline_load_qk_producer_state,
         pipeline_load_qk, pipeline_load_qk_producer_state,
-        local_split_kv
+        local_split_kv
       );
       cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
     }
@@ -660,15 +660,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     CUTLASS_PRAGMA_NO_UNROLL
     for (; tile_scheduler.is_valid(); ++tile_scheduler) {
       auto blk_coord = tile_scheduler.get_block_coord();
-      auto problem_shape = params.problem_shape;
-      auto local_split_kv = params.split_kv;
+      auto problem_shape = params.problem_shape;
+      auto local_split_kv = params.split_kv;
       if (params.mainloop.ptr_seq != nullptr) {
         get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
-        if (params.ptr_split_kv != nullptr) {
+        if (params.ptr_split_kv != nullptr) {
           local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
-        }
+        }
       }
-      if (local_split_kv <= get<3>(blk_coord))
+      if (local_split_kv <= get<3>(blk_coord))
         continue;
       load_tma<false>(
         blk_coord,
@@ -678,7 +678,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
         shared_storage.tensors,
         pipeline_load_qk, pipeline_load_qk_producer_state,
         pipeline_load_qk, pipeline_load_qk_producer_state,
-        local_split_kv
+        local_split_kv
       );
       cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
     }
@@ -694,14 +694,14 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     for (; tile_scheduler.is_valid(); ++tile_scheduler) {
       auto blk_coord = tile_scheduler.get_block_coord();
       auto problem_shape = params.problem_shape;
-      auto local_split_kv = params.split_kv;
+      auto local_split_kv = params.split_kv;
       if (params.mainloop.ptr_seq != nullptr) {
         get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
         if (params.ptr_split_kv != nullptr) {
           local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
         }
       }
-      if (local_split_kv <= get<3>(blk_coord))
+      if (local_split_kv <= get<3>(blk_coord))
         continue;
       mma(blk_coord,
         problem_shape,
@@ -711,7 +711,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
         pipeline_mma_s, pipeline_mma_s_producer_state,
         pipeline_p_mma, pipeline_p_mma_consumer_state,
         pipeline_mma_o, pipeline_mma_o_producer_state,
-        local_split_kv
+        local_split_kv
       );
     }
   }
@@ -726,15 +726,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     for (; tile_scheduler.is_valid(); ++tile_scheduler) {
       auto blk_coord = tile_scheduler.get_block_coord();
       auto problem_shape = params.problem_shape;
-      auto split_kv = params.split_kv;
-      auto local_split_kv = split_kv;
+      auto split_kv = params.split_kv;
+      auto local_split_kv = split_kv;
       if (params.mainloop.ptr_seq != nullptr) {
         get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
-        if (params.ptr_split_kv != nullptr) {
+        if (params.ptr_split_kv != nullptr) {
           local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
         }
       }
-      if (local_split_kv <= get<3>(blk_coord))
+      if (local_split_kv <= get<3>(blk_coord))
         continue;
       compute(
         blk_coord,
@@ -745,7 +745,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
         pipeline_mma_s, pipeline_mma_s_consumer_state,
         pipeline_p_mma, pipeline_p_mma_producer_state,
         pipeline_mma_o, pipeline_mma_o_consumer_state,
-        local_split_kv
+        local_split_kv
       );
     }

@@ -1900,7 +1900,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
       cutlass::arch::NamedBarrier(
         (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
         kNamedBarrierEpilogue
-      ).arrive();
+      ).arrive_and_wait();

       return;
     }
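Across the hunks above the -/+ pairs are textually identical after extraction (a whitespace-only cleanup); the one visible behavioral change is here at line 1903, where the epilogue named barrier's arrive() becomes arrive_and_wait(). arrive() only signals arrival and lets the calling thread continue, while arrive_and_wait() additionally blocks until all participating threads have reached the barrier, keeping the early-returning warps in step with the compute warps. A minimal CUDA sketch contrasting the two calls, assuming the CUTLASS 3.x NamedBarrier API (cutlass/arch/barrier.h); the thread count and barrier id are illustrative stand-ins, not the kernel's values:

// Sketch only: contrasts NamedBarrier::arrive() and arrive_and_wait().
#include <cutlass/arch/barrier.h>

__global__ void barrier_flavors() {
  constexpr uint32_t kNumThreads = 2 * 32;  // e.g. load + compute warps
  constexpr uint32_t kBarrierId  = 1;       // named barrier slot

  cutlass::arch::NamedBarrier bar(kNumThreads, kBarrierId);

  // arrive(): signal this thread's arrival and continue immediately.
  // A warp that arrives and then returns may leave before its peers
  // have passed the barrier.
  // bar.arrive();

  // arrive_and_wait(): signal arrival, then block until all kNumThreads
  // threads have arrived -- the form the commit switches to on the
  // early-return path.
  bar.arrive_and_wait();
}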