Skip to content

Commit 9462a98

Browse files
committed
Reapply "CUDA: fix misaligned synchronization in FA (ggml-org#13469)"
This reverts commit 4721a56.
1 parent c80f76a commit 9462a98

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
736736
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
737737
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
738738
}
739+
} else if (np > 1) {
740+
// Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.
741+
// Therefore, all other warps also need to execute a __syncthreads().
742+
// Otherwise the points at which warps synchronize with each other would become misaligned.
743+
__syncthreads();
739744
}
740745

741746
if (np > 1) {

0 commit comments

Comments
 (0)