@@ -25,7 +25,7 @@ typedef void (* fattn_kernel_t)(
2525 const float m1,
2626 const uint32_t n_head_log2,
2727 const float logit_softcap,
28- const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
28+ const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
2929 const int32_t nb01, const int32_t nb02, const int32_t nb03,
3030 const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
3131 const int32_t nb11, const int32_t nb12, const int64_t nb13,
@@ -621,7 +621,8 @@ static __global__ void flash_attn_mask_to_KV_max(
621621template <int D, int ncols1, int ncols2> // D == head size
622622__launch_bounds__ (D, 1 )
623623static __global__ void flash_attn_stream_k_fixup(
624- float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
624+ float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
625+ const int nbatch_fa) {
625626 constexpr int ncols = ncols1*ncols2;
626627
627628 const int bidx0 = blockIdx .x ;
@@ -632,8 +633,8 @@ static __global__ void flash_attn_stream_k_fixup(
632633
633634 const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim .x *(2 *2 *ncols);
634635
635- const int iter_k = ne11 / FATTN_KQ_STRIDE ;
636- const int iter_j = (ne01 + (ncols1 - 1 )) / ncols1;
636+ const int iter_k = ( ne11 + (nbatch_fa - 1 )) / nbatch_fa ;
637+ const int iter_j = (ne01 + (ncols1 - 1 )) / ncols1;
637638
638639 const int kbc0 = (bidx0 + 0 )*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim .x ;
639640 const int kbc0_stop = (bidx0 + 1 )*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim .x ;
@@ -765,7 +766,7 @@ static __global__ void flash_attn_combine_results(
765766template <int DV, int ncols1, int ncols2>
766767void launch_fattn (
767768 ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
768- const int KQ_row_granularity , const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
769+ const int nbatch_fa , const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
769770) {
770771 constexpr int ncols = ncols1 * ncols2;
771772
@@ -790,8 +791,6 @@ void launch_fattn(
790791 GGML_ASSERT (!V || V->nb [0 ] == ggml_element_size (V));
791792
792793 GGML_ASSERT (!mask || mask->type == GGML_TYPE_F16);
793- GGML_ASSERT (!mask || mask->ne [1 ] >= GGML_PAD (Q->ne [1 ], 16 ) &&
794- " the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big" );
795794
796795 ggml_cuda_pool & pool = ctx.pool ();
797796 cudaStream_t main_stream = ctx.stream ();
@@ -915,7 +914,7 @@ void launch_fattn(
915914
916915 dst_tmp_meta.alloc (blocks_num.x *ncols * (2 *2 + DV) * sizeof (float ));
917916 } else {
918- const int ntiles_KQ = (K->ne [1 ] + KQ_row_granularity - 1 ) / KQ_row_granularity ; // Max. number of parallel blocks limited by tensor size.
917+ const int ntiles_KQ = (K->ne [1 ] + nbatch_fa - 1 ) / nbatch_fa ; // Max. number of parallel blocks limited by tensor size.
919918
920919 // parallel_blocks must not be larger than what the tensor size allows:
921920 parallel_blocks = std::min (parallel_blocks, ntiles_KQ);
@@ -970,6 +969,9 @@ void launch_fattn(
970969 const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
971970 const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
972971
972+ // TODO other tensor dimensions after removal of WMMA kernel:
973+ const uint3 ne01 = init_fastdiv_values (Q->ne [1 ]);
974+
973975 GGML_ASSERT (block_dim.x % warp_size == 0 );
974976 fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>> (
975977 (const char *) Q->data ,
@@ -980,7 +982,7 @@ void launch_fattn(
980982 KV_max.ptr ,
981983 !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data , dst_tmp_meta.ptr ,
982984 scale, max_bias, m0, m1, n_head_log2, logit_softcap,
983- Q->ne [0 ], Q-> ne [ 1 ], Q->ne [2 ], Q->ne [3 ], Q->nb [1 ], Q->nb [2 ], Q->nb [3 ],
985+ Q->ne [0 ], ne01, Q->ne [2 ], Q->ne [3 ], Q->nb [1 ], Q->nb [2 ], Q->nb [3 ],
984986 K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ], nb11, nb12, nb13,
985987 nb21, nb22, nb23,
986988 mask ? mask->ne [1 ] : 0 , mask ? mask->ne [2 ] : 0 , mask ? mask->ne [3 ] : 0 ,
@@ -995,7 +997,7 @@ void launch_fattn(
995997
996998 flash_attn_stream_k_fixup<DV, ncols1, ncols2>
997999 <<<blocks_num_combine, block_dim_combine, 0 , main_stream>>>
998- ((float *) KQV->data , dst_tmp_meta.ptr , Q->ne [1 ], Q->ne [2 ], Q->ne [3 ], K->ne [1 ]);
1000+ ((float *) KQV->data , dst_tmp_meta.ptr , Q->ne [1 ], Q->ne [2 ], Q->ne [3 ], K->ne [1 ], nbatch_fa );
9991001 }
10001002 } else if (parallel_blocks > 1 ) {
10011003 const dim3 block_dim_combine (DV, 1 , 1 );
0 commit comments