@@ -516,7 +516,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
-template <int D, int ncols1, int ncols2, int KQ_stride> // D == head size
+template <int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
@@ -665,13 +665,13 @@ static void on_no_fattn_vec_case(const int D) {
         fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
         GGML_ABORT("fatal error");
     } else {
-        fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
+        fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
         fprintf(stderr, "Only f16 is supported.\n");
         GGML_ABORT("fatal error");
     }
 }
 
-template <int D, int ncols1, int ncols2, int KQ_stride>
+template <int DV, int ncols1, int ncols2>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
     const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
@@ -691,7 +691,7 @@ void launch_fattn(
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
-                                 "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
@@ -754,10 +754,13 @@ void launch_fattn(
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(warp_size, nwarps, 1);
+    int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
+    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+
     dim3 blocks_num;
     if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int max_blocks = 2*nsm;
+        const int max_blocks = max_blocks_per_sm*nsm;
         const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
         const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
 
@@ -769,14 +772,11 @@ void launch_fattn(
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
     } else {
         GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
         const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
 
-        int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
-        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
-
         // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
         parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
 
@@ -853,19 +853,19 @@ void launch_fattn(
 
     if (stream_k) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
-            const dim3 block_dim_combine(D, 1, 1);
+            const dim3 block_dim_combine(DV, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
-            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
+            flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
-        const dim3 block_dim_combine(D, 1, 1);
+        const dim3 block_dim_combine(DV, 1, 1);
         const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
-        flash_attn_combine_results<D>
+        flash_attn_combine_results<DV>
             <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
             (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
     }
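
For reference, here is a minimal, self-contained sketch of the occupancy-based grid sizing this patch switches to: it queries `cudaOccupancyMaxActiveBlocksPerMultiprocessor` for a placeholder kernel and derives the stream-k block count the same way as the patched `launch_fattn` (`max_blocks_per_sm*nsm` instead of the previous fixed `2*nsm`). The kernel, block size, and shared-memory amount below are illustrative assumptions, not values from this patch.

```cuda
// Standalone occupancy-query sketch (compile with nvcc); not part of the patch.
#include <cstdio>
#include <cuda_runtime.h>

// Placeholder kernel standing in for fattn_kernel; never launched here.
__global__ void dummy_kernel(float * dst) {
    dst[blockIdx.x*blockDim.x + threadIdx.x] = 0.0f;
}

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    const int nsm = prop.multiProcessorCount; // number of SMs on the device

    const int    block_size    = 256; // assumed threads per block (block_dim.x*block_dim.y*block_dim.z)
    const size_t nbytes_shared = 0;   // assumed dynamic shared memory per block

    // Max. number of blocks that can be resident per SM for this kernel/config:
    int max_blocks_per_sm = 1;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_blocks_per_sm, dummy_kernel, block_size, nbytes_shared);

    // Stream-k grid size: one full wave of blocks that saturates the device,
    // analogous to `const int max_blocks = max_blocks_per_sm*nsm;` above.
    const int max_blocks = max_blocks_per_sm*nsm;
    printf("SMs: %d, max blocks/SM: %d, stream-k grid size: %d\n",
           nsm, max_blocks_per_sm, max_blocks);
    return 0;
}
```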