@@ -149,6 +149,8 @@ static __global__ void flash_attn_vec_ext_f16(
149149 VKQ += V_k*KQ2[k0/2 ];
150150 }
151151 }
152+
153+ __syncthreads ();
152154 }
153155
154156 if (tid >= D) {
@@ -547,7 +549,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
547549 dst_tmp_meta.alloc (parallel_blocks*ggml_nrows (KQV));
548550 }
549551
550- constexpr int nwarps = ((D) + WARP_SIZE - 1 ) / WARP_SIZE;
552+ constexpr int nwarps = (D + WARP_SIZE - 1 ) / WARP_SIZE;
551553 constexpr dim3 block_dim (WARP_SIZE, nwarps, 1 );
552554 const dim3 blocks_num (parallel_blocks*Q->ne [1 ], Q->ne [2 ], Q->ne [3 ]);
553555 const int shmem = 0 ;
@@ -561,7 +563,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
561563 (const char *) K->data ,
562564 (const char *) V->data ,
563565 mask ? ((const char *) mask->data ) : nullptr ,
564- ( parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr , dst_tmp_meta.ptr ,
566+ parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr , dst_tmp_meta.ptr ,
565567 scale,
566568 Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
567569 K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
@@ -572,7 +574,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
572574 );
573575 CUDA_CHECK (cudaGetLastError ());
574576
575- if (( parallel_blocks) == 1 ) {
577+ if (parallel_blocks == 1 ) {
576578 return ;
577579 }
578580
0 commit comments