@@ -516,6 +516,104 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+template <int D, int ncols, int KQ_stride> // D == head size
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_stream_k_fixup(
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / KQ_stride;
+    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+
+    const int bidx0 = blockIdx.x;
+
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
+
+    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
+    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
+    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+        return;
+    }
+
+    const int channel = kbc0 / (iter_k*iter_j);
+    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+
+    dst += jt*ncols*ne02*D + channel*D;
+
+    // Load the partial result that needs a fixup:
+    float dst_val[ncols] = {0.0f};
+    float max_val[ncols] = {0.0f};
+    float rowsum[ncols]  = {0.0f};
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (jt*ncols + j >= ne01) {
+            break;
+        }
+        dst_val[j] = dst[j*ne02*D + threadIdx.x];
+
+        const float2 tmp = dst_fixup[bidx0*ncols + j];
+        max_val[j] = tmp.x;
+        rowsum[j]  = tmp.y;
+    }
+
+    // Iterate over previous blocks and compute the combined results.
+    // All CUDA blocks that get here must have a previous block that needs a fixup.
+    int bidx = bidx0 - 1;
+    int kbc_stop = kbc0;
+    while (true) {
+        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
+        if (kbc == kbc_stop) { // Did not have any data.
+            bidx--;
+            kbc_stop = kbc;
+            continue;
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (jt*ncols + j >= ne01) {
+                break;
+            }
+            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
+
+            const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
+
+            // Scale the current and new value accumulators depending on the max. values.
+            const float max_val_new = fmaxf(max_val[j], tmp.x);
+
+            const float diff_val = max_val[j] - max_val_new;
+            const float diff_add = tmp.x      - max_val_new;
+
+            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+
+            dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
+            rowsum[j]  = scale_val*rowsum[j]  + scale_add*tmp.y;
+
+            max_val[j] = max_val_new;
+        }
+
+        // If this block started in a previous tile we are done and don't need to combine additional partial results.
+        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+            break;
+        }
+        bidx--;
+        kbc_stop = kbc;
+    }
+
+    // Write back final result:
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (jt*ncols + j >= ne01) {
+            return;
+        }
+        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
+    }
+}
+
 template <int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
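The rescaling in the fixup loop is the usual numerically stable way of merging partial softmax results: each partial carries an unnormalized value accumulator together with the row maximum and row sum it was computed with, and two partials are merged by rescaling both sides to the shared maximum before adding. A minimal host-side C++ sketch of that merge for a single output element (the Partial struct and merge function are illustrative names, not part of this patch; the SOFTMAX_FTZ_THRESHOLD flush-to-zero used in the kernel is omitted for brevity):

    #include <algorithm>
    #include <cmath>

    // One partial attention result for a single output element.
    struct Partial {
        float val;     // unnormalized accumulator: sum_i exp(s_i - max) * v_i
        float max;     // running maximum of the scores s_i
        float rowsum;  // softmax denominator so far: sum_i exp(s_i - max)
    };

    // Merge two partials by rescaling both to the common maximum,
    // mirroring the scale_val/scale_add logic in flash_attn_stream_k_fixup.
    static Partial merge(const Partial & a, const Partial & b) {
        const float m       = std::max(a.max, b.max);
        const float scale_a = std::exp(a.max - m);
        const float scale_b = std::exp(b.max - m);
        return {scale_a*a.val + scale_b*b.val, m, scale_a*a.rowsum + scale_b*b.rowsum};
    }

    // The final output is merged.val / merged.rowsum, which is exactly what the
    // kernel writes back in its last loop.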
@@ -581,10 +679,11 @@ static void on_no_fattn_vec_case(const int D) {
     }
 }
 
-template <int D, int parallel_blocks>
+// parallel_blocks == 0 is stream-k decomposition
+template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-    const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
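With the new signature the decomposition is chosen at compile time: parallel_blocks > 1 keeps the existing split-K path, parallel_blocks == 0 selects the stream-k path, and the kernel's dynamic shared memory is now passed in bytes (nbytes_shared) rather than implied by the column count. A hypothetical call site might look like the following; the concrete values and the surrounding ctx/dst/fattn_kernel variables are assumptions for illustration, not taken from this patch:

    // Hypothetical instantiation: head size 128, 32 query columns per block,
    // stream-k decomposition (parallel_blocks == 0), FATTN_KQ_STRIDE-sized K chunks,
    // and no dynamically allocated shared memory for this particular kernel.
    constexpr int D              = 128;
    constexpr int cols_per_block = 32;
    constexpr int nwarps         = 8;
    launch_fattn<D, cols_per_block, /*parallel_blocks=*/0, FATTN_KQ_STRIDE>(
        ctx, dst, fattn_kernel, nwarps, /*nbytes_shared=*/0, /*need_f16_K=*/true, /*need_f16_V=*/true);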
@@ -603,20 +702,23 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
+    GGML_ASSERT(Q->ne[3] == 1);
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
     ggml_cuda_pool_alloc<half>   K_f16(pool);
     ggml_cuda_pool_alloc<half>   V_f16(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
-    char * K_data = (char *) K->data;
+    const char * K_data = (const char *) K->data;
     size_t nb11 = K->nb[1];
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    char * V_data = (char *) V->data;
+    const char * V_data = (const char *) V->data;
     size_t nb21 = V->nb[1];
     size_t nb22 = V->nb[2];
     size_t nb23 = V->nb[3];
@@ -649,39 +751,60 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
+    const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
+    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
+    dim3 blocks_num;
+    if (parallel_blocks == 0) {
+        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
+        const int  tiles_nwaves      = (ntiles_total - nsm - 1) / nsm;
+        const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
+        const bool short_context     = K->ne[1] < 4096;
+
+        const int nblocks_stream_k = 2*nsm;
+
+        blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
+        blocks_num.y = 1;
+        blocks_num.z = 1;
+
+        dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
+    } else {
+        blocks_num.x = parallel_blocks*ntiles_x;
+        blocks_num.y = Q->ne[2];
+        blocks_num.z = Q->ne[3];
+
+        if (parallel_blocks > 1) {
+            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+        }
+    }
+
 
     float scale         = 1.0f;
     float max_bias      = 0.0f;
     float logit_softcap = 0.0f;
 
-    memcpy(&scale,         (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias,      (float *) KQV->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+    memcpy(&scale,         (const float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (const float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
     if (logit_softcap != 0.0f) {
         scale /= logit_softcap;
     }
 
     const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
 
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+    fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
-        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
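The grid sizing above packs all heads and query tiles into a single one-dimensional grid. To make the heuristic concrete, here is a small standalone sketch that replays the same arithmetic on the host for one example shape; all numbers are illustrative, and nsm would normally come from the device info as in the code above:

    #include <cstdio>

    int main() {
        // Example values, chosen for illustration only.
        const int nsm            = 108;   // number of streaming multiprocessors (e.g. A100)
        const int cols_per_block = 32;
        const int ne01           = 7;     // Q->ne[1], number of query tokens
        const int ne02           = 32;    // Q->ne[2], number of heads
        const int ne11           = 2048;  // K->ne[1], KV cache length

        const int ntiles_x     = (ne01 + cols_per_block - 1) / cols_per_block;
        const int ntiles_total = ntiles_x*ne02;  // Q->ne[3] == 1 is asserted above

        // Same heuristic as in launch_fattn for parallel_blocks == 0:
        const int  tiles_nwaves      = (ntiles_total - nsm - 1) / nsm;
        const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
        const bool short_context     = ne11 < 4096;

        const int nblocks_stream_k = 2*nsm;
        const int blocks_x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;

        // With these numbers: ntiles_total = 32, the context is short and the tiles are not
        // flagged as inefficient, so one CUDA block is launched per tile and the fixup kernel
        // is skipped entirely (blocks_num.x % ntiles_total == 0).
        printf("ntiles_total = %d, blocks_num.x = %d\n", ntiles_total, blocks_x);
        return 0;
    }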
@@ -693,16 +816,22 @@ void launch_fattn(
     );
     CUDA_CHECK(cudaGetLastError());
 
-    if ((parallel_blocks) == 1) {
-        return;
-    }
+    if constexpr (parallel_blocks == 0) {
+        if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+            const dim3 block_dim_combine(D, 1, 1);
+            const dim3 blocks_num_combine = blocks_num;
 
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
+            flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
+                <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+        }
+    } else if constexpr (parallel_blocks > 1) {
+        const dim3 block_dim_combine(D, 1, 1);
+        const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
 
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+        flash_attn_combine_results<D, parallel_blocks>
+            <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    }
     CUDA_CHECK(cudaGetLastError());
 }
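The fixup launch reuses the main kernel's grid (blocks_num_combine = blocks_num), and only blocks whose work range starts mid-tile but reaches that tile's last K chunk have anything to do: they are the ones that wrote the tile's output and must fold in the partials of the preceding blocks. The following standalone sketch replays that work split on the host with made-up sizes to show which blocks would take the fixup path:

    #include <cstdio>

    int main() {
        // Illustrative values only.
        const int nblocks = 4;    // gridDim.x of the main kernel
        const int iter_k  = 8;    // ne11 / KQ_stride: K chunks per tile
        const int iter_j  = 3;    // ceil(ne01 / ncols): tiles along the query dimension
        const int ne02    = 2;    // number of heads

        const int total = iter_k*iter_j*ne02;
        for (int bidx = 0; bidx < nblocks; ++bidx) {
            const int kbc0      = (bidx + 0)*total / nblocks;
            const int kbc0_stop = (bidx + 1)*total / nblocks;

            // Same conditions as at the top of flash_attn_stream_k_fixup:
            const bool no_data        = kbc0 == kbc0_stop;
            const bool tile_start     = kbc0 % iter_k == 0;
            const bool not_last_chunk = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
            const bool needs_fixup    = !(no_data || tile_start || not_last_chunk);

            printf("block %d: kbc [%2d, %2d)%s\n", bidx, kbc0, kbc0_stop,
                   needs_fixup ? "  <- runs the fixup for its first tile" : "");
        }
        return 0;
    }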