2 files changed, +6 −0 lines changed

@@ -198,6 +198,8 @@ static __global__ void flash_attn_vec_ext_f16(
 
         // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
         // In such cases, skip the KV slice.
+        // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
+#ifndef GGML_USE_HIP
         bool skip = true;
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
@@ -212,6 +214,7 @@ static __global__ void flash_attn_vec_ext_f16(
         if (__all_sync(0xFFFFFFFF, skip)) {
             continue;
         }
+#endif // GGML_USE_HIP
     }
 
     // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,

@@ -204,6 +204,8 @@ static __global__ void flash_attn_vec_ext_f32(
 
         // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
         // In such cases, skip the KV slice.
+        // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
+#ifndef GGML_USE_HIP
         bool skip = true;
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
@@ -217,6 +219,7 @@ static __global__ void flash_attn_vec_ext_f32(
         if (__all_sync(0xFFFFFFFF, skip)) {
             continue;
         }
+#endif // GGML_USE_HIP
     }
 
     float kqmax_new_arr[ncols];
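
Both hunks take the same approach: rather than making the all-lanes reduction wavefront-aware, the early exit is compiled out entirely on HIP. For illustration only, a minimal sketch of a backend-aware check is shown below; the helper name is hypothetical and not part of this patch. On HIP it dispatches to the native __all intrinsic, which reduces over the full 64-lane wavefront, while on CUDA it keeps the full-mask __all_sync over the 32-lane warp:

// Hypothetical sketch, not part of this patch: a backend-aware
// "do all lanes agree" check for CUDA warps and AMD wavefronts.
static __device__ __forceinline__ bool all_lanes_true(const bool pred) {
#ifdef GGML_USE_HIP
    return __all(pred);                   // HIP: reduces over the full 64-lane wavefront
#else
    return __all_sync(0xFFFFFFFF, pred);  // CUDA: full mask over the 32-lane warp
#endif
}

Note that this sketch is not semantically equivalent to the CUDA path: on AMD it would reduce over 64 lanes instead of 32, so two 32-wide logical warps would have to agree before a slice is skipped. Compiling the early exit out on HIP, as the patch does, sidesteps that subtlety at the cost of a missed optimization on AMD hardware.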