@@ -729,11 +729,12 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
729
729
PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,
730
730
OrderBarrierSoftmax& order_s) {
731
731
732
- int mask_tile_count = Mask (params.window_size_left , params.window_size_right ).get_unmasked_trip_count (blk_coord, TileShape{}, problem_shape);
733
-
734
- auto min_max = Mask (params.window_size_left , params.window_size_right ).get_n_block_min_max (blk_coord, TileShape{}, problem_shape);
732
+ Mask mask (params.window_size_left , params.window_size_right );
733
+ auto min_max = mask.get_n_block_min_max (blk_coord, TileShape{}, problem_shape);
735
734
int n_block_min = get<0 >(min_max);
736
- // int n_block_max = get<1>(min_max);
735
+ const int n_block_max = get<1 >(min_max);
736
+ const int n_block_start_unmask = mask.get_n_block_start_unmask (blk_coord, TileShape{}, problem_shape);
737
+ const int n_block_stop_unmask = mask.get_n_block_stop_unmask (blk_coord, TileShape{}, problem_shape);
737
738
738
739
ElementQK row_max = -INFINITY;
739
740
ElementQK row_sum = 0 ;
@@ -747,35 +748,73 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
747
748
748
749
pipeline_c.producer_acquire (pipeline_c_producer_state);
749
750
750
- CUTLASS_PRAGMA_NO_UNROLL
751
- for (; mask_tile_count > 0 ; mask_tile_count -= 1 ) {
752
- softmax_step<false /* need_apply_mask */ >(
753
- row_max, row_sum, stage,
754
- (mask_tile_count == 1 ) &&
755
- (Mask (params.window_size_left , params.window_size_right ).get_masked_trip_count (blk_coord, TileShape{}, problem_shape) == 0 ),
756
- blk_coord, cS, params, problem_shape,
757
- pipeline_s, pipeline_s_consumer_state,
758
- pipeline_c, pipeline_c_producer_state,
759
- order_s
760
- );
761
-
762
- cS.data () = cS.data () + E<1 >{} * get<1 >(ThreadShape{}) * get<1 >(TileShapeQK{});
763
- }
764
-
765
- // Masked iterations
766
- mask_tile_count = Mask (params.window_size_left , params.window_size_right ).get_masked_trip_count (blk_coord, TileShape{}, problem_shape);
767
-
768
- CUTLASS_PRAGMA_NO_UNROLL
769
- for (; mask_tile_count > 0 ; mask_tile_count -= 1 ) {
770
- softmax_step<true /* need_apply_mask */ >(
771
- row_max, row_sum, stage, mask_tile_count == 1 ,
772
- blk_coord, cS, params, problem_shape,
773
- pipeline_s, pipeline_s_consumer_state,
774
- pipeline_c, pipeline_c_producer_state,
775
- order_s
776
- );
751
+ // Empirically, per-tile dispatch performs better for the mask -> unmask -> mask tile pattern of local masks, and when the number of tiles is small.
752
+ if constexpr (std::is_base_of_v<cutlass::fmha::collective::LocalMask<true >, Mask>
753
+ || std::is_base_of_v<cutlass::fmha::collective::LocalMask<false >, Mask>) {
754
+ auto dispatch_bool = [](bool b, auto fn) {
755
+ if (b) {
756
+ fn (cute::true_type{});
757
+ }
758
+ else {
759
+ fn (cute::false_type{});
760
+ }
761
+ };
762
+
763
+ CUTLASS_PRAGMA_NO_UNROLL
764
+ for (; n_block_min < n_block_max; n_block_min += 1 ) {
765
+ // Apply the mask only to tiles outside the unmasked range [n_block_start_unmask, n_block_stop_unmask)
766
+ // for local mask, we don't guarantee n_block_start_unmask <= n_block_stop_unmask <= n_block_max
767
+ bool need_apply_mask = warp_uniform (n_block_min < n_block_start_unmask || n_block_min >= n_block_stop_unmask);
768
+
769
+ dispatch_bool (need_apply_mask, [&](auto is_masked_tile) {
770
+ if constexpr (decltype (is_masked_tile)::value) {
771
+ softmax_step<true /* need_apply_mask */ >(
772
+ row_max, row_sum, stage, (n_block_min == n_block_max - 1 ),
773
+ blk_coord, cS, params, problem_shape,
774
+ pipeline_s, pipeline_s_consumer_state,
775
+ pipeline_c, pipeline_c_producer_state,
776
+ order_s
777
+ );
778
+ } else {
779
+ softmax_step<false /* need_apply_mask */ >(
780
+ row_max, row_sum, stage, (n_block_min == n_block_max - 1 ),
781
+ blk_coord, cS, params, problem_shape,
782
+ pipeline_s, pipeline_s_consumer_state,
783
+ pipeline_c, pipeline_c_producer_state,
784
+ order_s
785
+ );
786
+ }
787
+ });
788
+
789
+ cS.data () = cS.data () + E<1 >{} * get<1 >(ThreadShape{}) * get<1 >(TileShapeQK{});
790
+ }
791
+ } else {
792
+ CUTLASS_PRAGMA_NO_UNROLL
793
+ for (; n_block_min < n_block_stop_unmask; n_block_min += 1 ) {
794
+ softmax_step<false /* need_apply_mask */ >(
795
+ row_max, row_sum, stage,
796
+ (n_block_min == n_block_max - 1 ),
797
+ blk_coord, cS, params, problem_shape,
798
+ pipeline_s, pipeline_s_consumer_state,
799
+ pipeline_c, pipeline_c_producer_state,
800
+ order_s
801
+ );
802
+
803
+ cS.data () = cS.data () + E<1 >{} * get<1 >(ThreadShape{}) * get<1 >(TileShapeQK{});
804
+ }
777
805
778
- cS.data () = cS.data () + E<1 >{} * get<1 >(ThreadShape{}) * get<1 >(TileShapeQK{});
806
+ CUTLASS_PRAGMA_NO_UNROLL
807
+ for (; n_block_min < n_block_max; n_block_min += 1 ) {
808
+ softmax_step<true /* need_apply_mask */ >(
809
+ row_max, row_sum, stage, n_block_min == n_block_max - 1 ,
810
+ blk_coord, cS, params, problem_shape,
811
+ pipeline_s, pipeline_s_consumer_state,
812
+ pipeline_c, pipeline_c_producer_state,
813
+ order_s
814
+ );
815
+
816
+ cS.data () = cS.data () + E<1 >{} * get<1 >(ThreadShape{}) * get<1 >(TileShapeQK{});
817
+ }
779
818
}
780
819
781
820
pipeline_c.producer_commit (pipeline_c_producer_state);
0 commit comments