Skip to content

Commit d93f843

Browse files
authored
opencl: fix FA for f32 (ggml-org#16584)
1 parent f9fb33f commit d93f843

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

ggml/src/ggml-opencl/kernels/flash_attn_f32.cl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#define ACC_TYPE4 float4
55
#define DATA_TYPE float
66
#define DATA_TYPE4 float4
7+
#define MASK_DATA_TYPE half
78
#define CONVERT_ACC4(x) (x)
89
#define CONVERT_DATA4(x) (x)
910

@@ -148,7 +149,7 @@ __kernel void flash_attn_f32(
148149
if (k_row1 >= n_kv) score1 = -INFINITY;
149150

150151
if (mask_base != NULL) {
151-
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
152+
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
152153
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
153154
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
154155
}
@@ -281,7 +282,7 @@ __kernel void flash_attn_f32_q1(
281282
}
282283
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
283284
if (mask_base != NULL) {
284-
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
285+
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
285286
score += slope * (ACC_TYPE)mask_ptr[k_idx];
286287
}
287288
if (logit_softcap > 0.0f) {
@@ -317,7 +318,7 @@ __kernel void flash_attn_f32_q1(
317318
}
318319
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
319320
if (mask_base != NULL) {
320-
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
321+
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
321322
score += slope * (ACC_TYPE)mask_ptr[k_idx];
322323
}
323324
if (logit_softcap > 0.0f) {

0 commit comments

Comments
 (0)