@@ -606,48 +606,47 @@ static __global__ void flash_attn_stream_k_fixup(
     *dst = dst_val / rowsum;
 }
 
-template <int D, int parallel_blocks> // D == head size
+template <int D> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_combine_results(
         const float  * __restrict__ VKQ_parts,
         const float2 * __restrict__ VKQ_meta,
-        float * __restrict__ dst) {
-    VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
-    VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
-    dst       += D                 * gridDim.y*blockIdx.x;
+        float * __restrict__ dst,
+        const int parallel_blocks) {
+    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
+    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
+    dst       += D                 * gridDim.z*blockIdx.x;
 
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);
 
-    __shared__ float2 meta[parallel_blocks];
+    extern __shared__ float2 meta[];
     if (tid < 2*parallel_blocks) {
-        ((float *) meta)[threadIdx.x] = ((const float *) VKQ_meta)[blockIdx.y*(2*parallel_blocks) + tid];
+        ((float *) meta)[threadIdx.x] = ((const float *) VKQ_meta)[blockIdx.z*(2*parallel_blocks) + tid];
     }
 
     __syncthreads();
 
     float kqmax = meta[0].x;
-#pragma unroll
     for (int l = 1; l < parallel_blocks; ++l) {
         kqmax = max(kqmax, meta[l].x);
     }
 
     float VKQ_numerator   = 0.0f;
     float VKQ_denominator = 0.0f;
-#pragma unroll
     for (int l = 0; l < parallel_blocks; ++l) {
         const float diff = meta[l].x - kqmax;
         const float KQ_max_scale = expf(diff);
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;
 
-        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }
 
-    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
 }
 
 static void on_no_fattn_vec_case(const int D) {
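Note on the hunk above: with `parallel_blocks` demoted from template parameter to runtime argument, the statically sized `__shared__ float2 meta[parallel_blocks]` is no longer possible, so the buffer becomes `extern __shared__` and its byte count is supplied as the third launch parameter (`nbytes_shared_combine` in the last hunk below). A minimal standalone sketch of that pattern follows; the kernel, buffer, and variable names are made up for illustration and are not code from this PR.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Sums each row of an nrows x n matrix; n is only known at runtime, so the
// per-block scratch buffer is declared extern __shared__ and sized at launch.
__global__ void row_sum(const float * x, float * y, const int n) {
    extern __shared__ float buf[]; // sized by the 3rd <<<>>> argument
    const int tid = threadIdx.x;
    buf[tid] = tid < n ? x[blockIdx.x*n + tid] : 0.0f;
    __syncthreads();

    // blockDim.x is chosen as a power of two >= n by the host code below.
    for (int offset = blockDim.x/2; offset > 0; offset >>= 1) {
        if (tid < offset) {
            buf[tid] += buf[tid + offset];
        }
        __syncthreads();
    }
    if (tid == 0) {
        y[blockIdx.x] = buf[0];
    }
}

int main() {
    const int nrows = 2;
    const int n     = 3;
    const float hx[nrows*n] = {1, 2, 3, 4, 5, 6};
    float hy[nrows];

    float * dx;
    float * dy;
    cudaMalloc(&dx, sizeof(hx));
    cudaMalloc(&dy, sizeof(hy));
    cudaMemcpy(dx, hx, sizeof(hx), cudaMemcpyHostToDevice);

    const int    block_size    = 4;                        // power of two >= n
    const size_t nbytes_shared = block_size*sizeof(float); // runtime-sized shared memory
    row_sum<<<nrows, block_size, nbytes_shared, 0>>>(dx, dy, n);

    cudaMemcpy(hy, dy, sizeof(hy), cudaMemcpyDeviceToHost);
    printf("%g %g\n", hy[0], hy[1]); // expected: 6 15
    cudaFree(dx);
    cudaFree(dy);
    return 0;
}
```

The only hard requirement is that the launch-time byte count covers everything the kernel indexes through the extern array, which is why the combine launch below sizes it as `parallel_blocks*sizeof(float2)`.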
@@ -671,12 +670,10 @@ static void on_no_fattn_vec_case(const int D) {
     }
 }
 
-// parallel_blocks == 0 is stream-k decomposition
-template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
+template <int D, int ncols1, int ncols2, int KQ_stride>
 void launch_fattn(
-        ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-        const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
-        const int warp_size = WARP_SIZE
+        ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
+        const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
@@ -748,12 +745,14 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
+    int parallel_blocks = 1;
+
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(warp_size, nwarps, 1);
     dim3 blocks_num;
-    if (parallel_blocks == 0) {
+    if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
         const int max_blocks = 2*nsm;
         const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
@@ -769,9 +768,43 @@
 
         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
-        blocks_num.x = parallel_blocks*ntiles_x;
-        blocks_num.y = Q->ne[2];
-        blocks_num.z = Q->ne[3];
+        GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
+        const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
+
+        int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
+        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+
+        // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
+        parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
+
+        // parallel_blocks must not be larger than what the tensor size allows:
+        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
+
+        // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
+        // Test whether parallel_blocks can be set to a higher value for better efficiency.
+        const int blocks_per_wave = nsm * max_blocks_per_sm;
+        int nwaves_best = 0;
+        int efficiency_percent_best = 0;
+        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
+            const int nblocks_total = ntiles_total * parallel_blocks_test;
+            const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
+            const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
+
+            // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
+            if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
+                break;
+            }
+
+            if (efficiency_percent > efficiency_percent_best) {
+                nwaves_best = nwaves;
+                efficiency_percent_best = efficiency_percent;
+                parallel_blocks = parallel_blocks_test;
+            }
+        }
+
+        blocks_num.x = ntiles_x;
+        blocks_num.y = parallel_blocks;
+        blocks_num.z = Q->ne[2]*Q->ne[3];
 
         if (parallel_blocks > 1) {
             dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
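The new else-branch above replaces the compile-time split with a runtime heuristic: `cudaOccupancyMaxActiveBlocksPerMultiprocessor` reports how many copies of `fattn_kernel` can be resident per SM, one wave is `nsm * max_blocks_per_sm` blocks, and the loop then probes larger splits, keeping whichever gives the best ratio of launched blocks to wave capacity and stopping early once efficiency is already at least 90 % and going further would add a wave. A host-only sketch of the same arithmetic, handy for experimenting with the formula; the helper name and the numbers in `main()` are made up and not part of the PR.

```cpp
#include <algorithm>
#include <cstdio>

// Illustrative, host-only re-implementation of the parallel_blocks heuristic above.
static int choose_parallel_blocks(int nsm, int max_blocks_per_sm, int ntiles_total, int ntiles_KQ) {
    // At least enough splits to fill one wave of blocks:
    int parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
    // But never more splits than there are KQ tiles:
    parallel_blocks = std::min(parallel_blocks, ntiles_KQ);

    const int blocks_per_wave = nsm * max_blocks_per_sm;
    int nwaves_best = 0;
    int efficiency_percent_best = 0;
    for (int pb = parallel_blocks; pb <= ntiles_KQ; ++pb) {
        const int nblocks_total = ntiles_total * pb;
        const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
        const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);

        // Good enough already; don't pay for an extra wave.
        if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
            break;
        }
        if (efficiency_percent > efficiency_percent_best) {
            nwaves_best = nwaves;
            efficiency_percent_best = efficiency_percent;
            parallel_blocks = pb;
        }
    }
    return parallel_blocks;
}

int main() {
    // e.g. 108 SMs, 2 resident blocks per SM, 4 query tiles, up to 64 KQ tiles:
    printf("parallel_blocks = %d\n", choose_parallel_blocks(108, 2, 4, 64));
    return 0;
}
```

With these example inputs the helper settles on 54 splits, i.e. 4 tiles * 54 splits = 216 blocks, exactly one full wave at 100 % efficiency.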
@@ -803,7 +836,7 @@
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
-        (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
+        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -815,7 +848,7 @@
     );
     CUDA_CHECK(cudaGetLastError());
 
-    if constexpr (parallel_blocks == 0) {
+    if (stream_k) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
@@ -824,13 +857,14 @@
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
-    } else if constexpr (parallel_blocks > 1) {
+    } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(D, 1, 1);
-        const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+        const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+        const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
-        flash_attn_combine_results<D, parallel_blocks>
-            <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+        flash_attn_combine_results<D>
+            <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
+            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
     }
     CUDA_CHECK(cudaGetLastError());
 }
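For context on what `flash_attn_combine_results` computes (the math is unchanged by this PR, only how `parallel_blocks` reaches the kernel): each split contributes a running KQ maximum and softmax denominator via `VKQ_meta` plus a partial numerator via `VKQ_parts`, and the kernel rescales all of them to the common maximum before dividing. A host-side sketch of that merge for a single output element, with made-up names and the flush-to-zero threshold omitted:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct partial {
    float kqmax;     // running max of KQ values seen by this split   (meta[l].x)
    float sumexp;    // softmax denominator accumulated by this split (meta[l].y)
    float numerator; // one element of the partial V*softmax(KQ) sum  (VKQ_parts[...])
};

static float combine(const std::vector<partial> & parts) {
    float kqmax = parts[0].kqmax;
    for (const partial & p : parts) {
        kqmax = std::max(kqmax, p.kqmax);
    }
    float num = 0.0f;
    float den = 0.0f;
    for (const partial & p : parts) {
        const float scale = std::exp(p.kqmax - kqmax); // <= 1, rescale to the common max
        num += scale * p.numerator;
        den += scale * p.sumexp;
    }
    return num / den;
}

int main() {
    // two splits of the same row, made-up numbers:
    printf("%f\n", combine({{2.0f, 1.5f, 0.75f}, {3.0f, 2.0f, 1.25f}}));
    return 0;
}
```

Because every scale factor is the exponential of a non-positive difference, the rescaling cannot overflow no matter how many splits are chosen.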