@@ -793,6 +793,8 @@ void launch_fattn(
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
+    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
     ggml_cuda_pool & pool        = ctx.pool();
     cudaStream_t    main_stream  = ctx.stream();
     const int id = ggml_cuda_get_device();
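Note: the added assertion makes the KV-cache padding requirement explicit at the top of launch_fattn: K->ne[1] (the number of KV rows) must be a multiple of FATTN_KQ_STRIDE. A minimal, self-contained sketch of how a caller satisfies this by rounding the KV length up; the local GGML_PAD and FATTN_KQ_STRIDE definitions are repeated here only to keep the sketch compilable, and 256 is assumed to match the value in fattn-common.cuh:

    #include <cstdint>
    #include <cstdio>

    #define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n)) // round x up to a multiple of n
    #define FATTN_KQ_STRIDE 256                          // assumed to match fattn-common.cuh

    int main() {
        const int64_t n_kv        = 1000;                            // unpadded KV length
        const int64_t n_kv_padded = GGML_PAD(n_kv, FATTN_KQ_STRIDE); // 1024, satisfies the new assert
        printf("%lld -> %lld\n", (long long) n_kv, (long long) n_kv_padded);
        return 0;
    }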
@@ -876,7 +878,7 @@ void launch_fattn(
     // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
     // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
     // multiple sequences of possibly different lengths.
-    if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+    if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
         const int s31 = mask->nb[1] / sizeof(half2);
         const int s33 = mask->nb[3] / sizeof(half2);
 
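Note: with the assert above guaranteeing K->ne[1] % FATTN_KQ_STRIDE == 0 for every call, the divisibility clause in this condition is always true and can be dropped, so the mask-scan optimization is now gated only on the query count and the number of sequences. A standalone sketch of the simplified gate; use_mask_scan and its parameter names are invented here for illustration:

    #include <cstdint>

    // Sketch only: enable the mask scan when it is likely to pay off, i.e. for large
    // query batches or when multiple sequences (of possibly different lengths) are present.
    bool use_mask_scan(bool has_mask, int64_t n_queries, int64_t n_sequences) {
        return has_mask && (n_queries >= 1024 || n_sequences > 1);
    }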
@@ -915,7 +917,8 @@ void launch_fattn(
 
         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
     } else {
-        const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
+        GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
+        const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
 
         // parallel_blocks must not be larger than what the tensor size allows:
         parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
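Note: the added GGML_ASSERT turns the previous implicit rounding into a documented precondition; once K->ne[1] is an exact multiple of KQ_row_granularity, ceiling division and plain division produce the same ntiles_KQ. A small worked example with hypothetical values (4096 KV rows, granularity 64):

    #include <cassert>

    int main() {
        const int kv_len      = 4096; // hypothetical padded KV length
        const int granularity = 64;   // hypothetical KQ_row_granularity
        const int ntiles_old  = (kv_len + granularity - 1) / granularity; // previous: ceiling division
        const int ntiles_new  = kv_len / granularity;                     // new: exact division under the assert
        assert(kv_len % granularity == 0 && ntiles_old == ntiles_new);    // both are 64
        return 0;
    }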
@@ -944,7 +947,7 @@ void launch_fattn(
 
         blocks_num.x = ntiles_x;
         blocks_num.y = parallel_blocks;
-        blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
+        blocks_num.z = Q->ne[2]*Q->ne[3];
 
         if (parallel_blocks > 1) {
             dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));