@@ -516,6 +516,104 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+template <int D, int ncols, int KQ_stride> // D == head size
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_stream_k_fixup(
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / KQ_stride;
+    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+
+    const int bidx0 = blockIdx.x;
+
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
+
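+    // The flattened work iter_k*iter_j*ne02 was split evenly across gridDim.x blocks, so a single
+    // output tile may have been processed by more than one block and then needs this fixup.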
+    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
+    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
+    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+        return;
+    }
+
+    const int channel = kbc0 / (iter_k*iter_j);
+    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+
+    dst += jt*ncols*ne02*D + channel*D;
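+    // dst now points to the first element of the (column tile, head) output slice that this block
+    // started mid-way through the KV iteration.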
+
+    // Load the partial result that needs a fixup:
+    float dst_val[ncols] = {0.0f};
+    float max_val[ncols] = {0.0f};
+    float rowsum[ncols]  = {0.0f};
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (jt*ncols + j >= ne01) {
+            break;
+        }
+        dst_val[j] = dst[j*ne02*D + threadIdx.x];
+
+        const float2 tmp = dst_fixup[bidx0*ncols + j];
+        max_val[j] = tmp.x;
+        rowsum[j]  = tmp.y;
+    }
+
+    // Iterate over previous blocks and compute the combined results.
+    // All CUDA blocks that get here must have a previous block that needs a fixup.
+    int bidx = bidx0 - 1;
+    int kbc_stop = kbc0;
+    while (true) {
+        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
+        if (kbc == kbc_stop) { // Did not have any data.
+            bidx--;
+            kbc_stop = kbc;
+            continue;
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (jt*ncols + j >= ne01) {
+                break;
+            }
+            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
+
+            const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
+
+            // Scale the current and new value accumulators depending on the max. values.
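+            // Online softmax: each partial result was accumulated relative to its own running maximum,
+            // so both sides are rescaled by exp(old_max - new_max) before they are added together.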
+            const float max_val_new = fmaxf(max_val[j], tmp.x);
+
+            const float diff_val = max_val[j] - max_val_new;
+            const float diff_add = tmp.x     - max_val_new;
+
+            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+
+            dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
+            rowsum[j]  = scale_val*rowsum[j]  + scale_add*tmp.y;
+
+            max_val[j] = max_val_new;
+        }
+
+        // If this block started in a previous tile we are done and don't need to combine additional partial results.
+        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+            break;
+        }
+        bidx--;
+        kbc_stop = kbc;
+    }
+
+    // Write back final result:
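+    // Dividing the combined numerator by the combined rowsum yields the normalized attention output.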
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (jt*ncols + j >= ne01) {
+            return;
+        }
+        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
+    }
+}
+
 template <int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -581,10 +679,11 @@ static void on_no_fattn_vec_case(const int D) {
     }
 }
 
-template <int D, int parallel_blocks>
+// parallel_blocks == 0 is stream-k decomposition
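+// With stream-k, the grid size is decoupled from the number of output tiles: each block walks a
+// contiguous range of flattened (KV chunk, column tile, head) work items, and tiles that end up
+// split between blocks are combined afterwards by flash_attn_stream_k_fixup.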
+template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-        const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+        const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
@@ -603,20 +702,23 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
+    GGML_ASSERT(Q->ne[3] == 1);
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
     ggml_cuda_pool_alloc<half>   K_f16(pool);
     ggml_cuda_pool_alloc<half>   V_f16(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
-    char * K_data = (char *) K->data;
+    const char * K_data = (const char *) K->data;
     size_t nb11 = K->nb[1];
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    char * V_data = (char *) V->data;
+    const char * V_data = (const char *) V->data;
     size_t nb21 = V->nb[1];
     size_t nb22 = V->nb[2];
     size_t nb23 = V->nb[3];
@@ -649,39 +751,60 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
+    const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
+    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
+    dim3 blocks_num;
+    if (parallel_blocks == 0) {
+        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
+        const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
+        const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
+        const bool short_context = K->ne[1] < 4096;
+
+        const int nblocks_stream_k = 2*nsm;
+
+        blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
+        blocks_num.y = 1;
+        blocks_num.z = 1;
+
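+        // Scratch for the stream-k partial results: per block and column, a D-wide partial numerator
+        // plus its softmax metadata (running max and rowsum), combined later by flash_attn_stream_k_fixup.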
+        dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
+    } else {
+        blocks_num.x = parallel_blocks*ntiles_x;
+        blocks_num.y = Q->ne[2];
+        blocks_num.z = Q->ne[3];
+
+        if (parallel_blocks > 1) {
+            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+        }
+    }
+
 
     float scale         = 1.0f;
     float max_bias      = 0.0f;
     float logit_softcap = 0.0f;
 
-    memcpy(&scale,         (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias,      (float *) KQV->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+    memcpy(&scale,         (const float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (const float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
     if (logit_softcap != 0.0f) {
         scale /= logit_softcap;
     }
 
     const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
 
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+    fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
-        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -693,16 +816,22 @@ void launch_fattn(
     );
     CUDA_CHECK(cudaGetLastError());
 
-    if ((parallel_blocks) == 1) {
-        return;
-    }
+    if constexpr (parallel_blocks == 0) {
+        if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
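+            // This is only reached for the stream-k launch (blocks_num.x == 2*nsm); the whole-tile
+            // launch for short contexts sets blocks_num.x == ntiles_total and the fixup is skipped.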
+            const dim3 block_dim_combine(D, 1, 1);
+            const dim3 blocks_num_combine = blocks_num;
 
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
+            flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
+                <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+        }
+    } else if constexpr (parallel_blocks > 1) {
+        const dim3 block_dim_combine(D, 1, 1);
+        const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
 
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+        flash_attn_combine_results<D, parallel_blocks>
+            <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    }
     CUDA_CHECK(cudaGetLastError());
 }