@@ -653,6 +653,104 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+template <int D, int ncols, int KQ_stride> // D == head size
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_stream_k_fixup(
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / KQ_stride;
+    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+
+    const int bidx0 = blockIdx.x;
+
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
+
+    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
+    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
+    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+        return;
+    }
+
+    const int channel = kbc0 / (iter_k*iter_j);
+    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+
+    dst += jt*ncols*ne02*D + channel*D;
+
+    // Load the partial result that needs a fixup:
+    float dst_val[ncols] = {0.0f};
+    float max_val[ncols] = {0.0f};
+    float rowsum[ncols]  = {0.0f};
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (jt*ncols + j >= ne01) {
+            break;
+        }
+        dst_val[j] = dst[j*ne02*D + threadIdx.x];
+
+        const float2 tmp = dst_fixup[bidx0*ncols + j];
+        max_val[j] = tmp.x;
+        rowsum[j]  = tmp.y;
+    }
+
+    // Iterate over previous blocks and compute the combined results.
+    // All CUDA blocks that get here must have a previous block that needs a fixup.
+    int bidx = bidx0 - 1;
+    int kbc_stop = kbc0;
+    while (true) {
+        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
+        if (kbc == kbc_stop) { // Did not have any data.
+            bidx--;
+            kbc_stop = kbc;
+            continue;
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (jt*ncols + j >= ne01) {
+                break;
+            }
+            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
+
+            const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
+
+            // Scale the current and new value accumulators depending on the max. values.
+            const float max_val_new = fmaxf(max_val[j], tmp.x);
+
+            const float diff_val = max_val[j] - max_val_new;
+            const float diff_add = tmp.x     - max_val_new;
+
+            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+
+            dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
+            rowsum[j]  = scale_val*rowsum[j]  + scale_add*tmp.y;
+
+            max_val[j] = max_val_new;
+        }
+
+        // If this block started in a previous tile we are done and don't need to combine additional partial results.
+        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+            break;
+        }
+        bidx--;
+        kbc_stop = kbc;
+    }
+
+    // Write back final result:
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (jt*ncols + j >= ne01) {
+            return;
+        }
+        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
+    }
+}
+
 template <int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
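
The new flash_attn_stream_k_fixup kernel merges partial attention results written by CUDA blocks whose work range starts in the middle of a KQ tile. Each partial result carries, per output row, a running maximum, a row sum of exponentials, and a value accumulator scaled relative to that maximum; combining two partials means rescaling both sides to the larger maximum before adding, the standard numerically stable log-sum-exp merge. Below is a minimal host-side sketch of that per-row merge rule; the names PartialRow and merge_partial are illustrative only and do not appear in the diff.

#include <algorithm>
#include <cmath>

struct PartialRow {
    float max_val; // running maximum of the KQ logits covered by this partial
    float rowsum;  // sum of exp(logit - max_val) over the covered keys
    float acc;     // one component of the value accumulator, scaled like rowsum
};

// Combine two partial softmax results for the same output row.
static PartialRow merge_partial(const PartialRow & a, const PartialRow & b) {
    const float m       = std::max(a.max_val, b.max_val);
    const float scale_a = std::exp(a.max_val - m); // rescale a to the new maximum
    const float scale_b = std::exp(b.max_val - m); // rescale b to the new maximum
    return {m, scale_a*a.rowsum + scale_b*b.rowsum, scale_a*a.acc + scale_b*b.acc};
}
// The finished attention output for the row is acc / rowsum, which is what the
// kernel writes back after folding in all earlier partial blocks.

The kernel additionally flushes scale factors to zero when the max difference falls below SOFTMAX_FTZ_THRESHOLD instead of evaluating expf on a very negative argument.
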
@@ -722,10 +820,11 @@ static void on_no_fattn_vec_case(const int D) {
     }
 }
 
-template <int D, int parallel_blocks>
+// parallel_blocks == 0 is stream-k decomposition
+template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-        const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+        const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
@@ -744,20 +843,23 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
+    GGML_ASSERT(Q->ne[3] == 1);
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
     ggml_cuda_pool_alloc<half>   K_f16(pool);
     ggml_cuda_pool_alloc<half>   V_f16(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
-    char * K_data = (char *) K->data;
+    const char * K_data = (const char *) K->data;
     size_t nb11 = K->nb[1];
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    char * V_data = (char *) V->data;
+    const char * V_data = (const char *) V->data;
     size_t nb21 = V->nb[1];
     size_t nb22 = V->nb[2];
     size_t nb23 = V->nb[3];
@@ -790,39 +892,60 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
+    const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
+    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
+    dim3 blocks_num;
+    if (parallel_blocks == 0) {
+        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
+        const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
+        const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
+        const bool short_context = K->ne[1] < 4096;
+
+        const int nblocks_stream_k = 2*nsm;
+
+        blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
+        blocks_num.y = 1;
+        blocks_num.z = 1;
+
+        dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
+    } else {
+        blocks_num.x = parallel_blocks*ntiles_x;
+        blocks_num.y = Q->ne[2];
+        blocks_num.z = Q->ne[3];
+
+        if (parallel_blocks > 1) {
+            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+        }
+    }
+
 
     float scale         = 1.0f;
     float max_bias      = 0.0f;
     float logit_softcap = 0.0f;
 
-    memcpy(&scale,         (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias,      (float *) KQV->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+    memcpy(&scale,         (const float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (const float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
     if (logit_softcap != 0.0f) {
         scale /= logit_softcap;
     }
 
     const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
 
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+    fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
-        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -834,16 +957,22 @@ void launch_fattn(
     );
     CUDA_CHECK(cudaGetLastError());
 
-    if ((parallel_blocks) == 1) {
-        return;
-    }
+    if constexpr (parallel_blocks == 0) {
+        if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+            const dim3 block_dim_combine(D, 1, 1);
+            const dim3 blocks_num_combine = blocks_num;
 
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
+            flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
+                <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+        }
+    } else if constexpr (parallel_blocks > 1) {
+        const dim3 block_dim_combine(D, 1, 1);
+        const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
 
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+        flash_attn_combine_results<D, parallel_blocks>
+            <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    }
     CUDA_CHECK(cudaGetLastError());
 }
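
For the parallel_blocks == 0 path, the stream-k launch gives each of the nblocks_stream_k blocks a contiguous slice of the iter_k*iter_j*ne02 KQ chunks, computed as bidx*total/gridDim.x, and the fixup kernel is only launched when the grid does not divide the tiles evenly. A block then needs the fixup exactly when its slice starts in the middle of a tile that an earlier block began and it wrote the last chunk of that tile, mirroring the early-exit checks at the top of flash_attn_stream_k_fixup. The following standalone sketch of that partition arithmetic uses example sizes only; none of the constants are taken from the diff.

#include <cstdio>

int main() {
    // Example sizes: chunks per tile, column tiles, heads, and a hypothetical grid size.
    const int iter_k  = 8, iter_j = 3, ne02 = 4;
    const int nblocks = 10; // would be 2*nsm on the stream-k path
    const int total   = iter_k*iter_j*ne02;

    for (int bidx = 0; bidx < nblocks; ++bidx) {
        const int kbc      = (bidx + 0)*total / nblocks; // first KQ chunk owned by this block
        const int kbc_stop = (bidx + 1)*total / nblocks; // one past the last owned chunk
        const bool has_data    = kbc != kbc_stop;
        const bool starts_tile = kbc % iter_k == 0;
        const bool wrote_last  = kbc/iter_k != kbc_stop/iter_k || kbc_stop % iter_k == 0;
        // Same conditions as the early returns in flash_attn_stream_k_fixup: only a
        // block that has data, starts mid-tile, and wrote the last chunk of that
        // tile has to fold in the partial results from earlier blocks.
        const bool needs_fixup = has_data && !starts_tile && wrote_last;
        printf("block %2d: kbc [%2d, %2d)  needs_fixup=%d\n", bidx, kbc, kbc_stop, needs_fixup);
    }
    return 0;
}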