@@ -793,8 +793,6 @@ void launch_fattn(
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

-    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -878,7 +876,7 @@ void launch_fattn(
     // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
     // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
     // multiple sequences of possibly different lengths.
-    if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+    if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
         const int s31 = mask->nb[1] / sizeof(half2);
         const int s33 = mask->nb[3] / sizeof(half2);

@@ -916,8 +914,7 @@ void launch_fattn(

         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
     } else {
-        GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
-        const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
+        const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.

         // parallel_blocks must not be larger than what the tensor size allows:
         parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
@@ -946,7 +943,7 @@ void launch_fattn(

         blocks_num.x = ntiles_x;
         blocks_num.y = parallel_blocks;
-        blocks_num.z = Q->ne[2]*Q->ne[3];
+        blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];

         if (parallel_blocks > 1) {
             dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
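
Taken together, these hunks stop requiring the KV length to be padded to FATTN_KQ_STRIDE: the hard asserts are removed, the mask-scan skip is only taken when K->ne[1] is a multiple of FATTN_KQ_STRIDE, the tile count is rounded up instead of divided exactly, and the grid's z dimension is divided by ncols2. The following is a minimal standalone sketch of the rounding/gating idiom, not the actual launch_fattn code; kv_len, ceil_div and the constant values are illustrative assumptions.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Integer ceiling division: smallest n such that n*b >= a.
static int64_t ceil_div(int64_t a, int64_t b) {
    return (a + b - 1) / b;
}

int main() {
    const int64_t FATTN_KQ_STRIDE    = 256;  // assumed KQ tile height in KV rows
    const int64_t KQ_row_granularity = 64;   // assumed row granularity of the kernel
    const int64_t kv_len             = 1000; // KV length, no longer padded to FATTN_KQ_STRIDE

    // The mask scan can only skip whole FATTN_KQ_STRIDE x FATTN_KQ_STRIDE tiles,
    // so divisibility is checked as a gate instead of being asserted.
    const bool can_scan_mask = kv_len % FATTN_KQ_STRIDE == 0;

    // Rounding up means the last, partially filled tile still gets a block.
    const int64_t ntiles_KQ = ceil_div(kv_len, KQ_row_granularity);

    int64_t parallel_blocks = 32; // hypothetical value from the occupancy heuristic
    parallel_blocks = std::min(parallel_blocks, ntiles_KQ);

    std::printf("scan mask: %d, ntiles_KQ: %lld, parallel_blocks: %lld\n",
                (int) can_scan_mask, (long long) ntiles_KQ, (long long) parallel_blocks);
    return 0;
}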