|
4 | 4 | #define ACC_TYPE4 float4 |
5 | 5 | #define DATA_TYPE float |
6 | 6 | #define DATA_TYPE4 float4 |
| 7 | +#define MASK_DATA_TYPE half |
7 | 8 | #define CONVERT_ACC4(x) (x) |
8 | 9 | #define CONVERT_DATA4(x) (x) |
9 | 10 |
|
@@ -148,7 +149,7 @@ __kernel void flash_attn_f32( |
148 | 149 | if (k_row1 >= n_kv) score1 = -INFINITY; |
149 | 150 |
|
150 | 151 | if (mask_base != NULL) { |
151 | | - const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1); |
| 152 | + const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1); |
152 | 153 | if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0]; |
153 | 154 | if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1]; |
154 | 155 | } |
@@ -281,7 +282,7 @@ __kernel void flash_attn_f32_q1( |
281 | 282 | } |
282 | 283 | ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; |
283 | 284 | if (mask_base != NULL) { |
284 | | - const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base); |
| 285 | + const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base); |
285 | 286 | score += slope * (ACC_TYPE)mask_ptr[k_idx]; |
286 | 287 | } |
287 | 288 | if (logit_softcap > 0.0f) { |
@@ -317,7 +318,7 @@ __kernel void flash_attn_f32_q1( |
317 | 318 | } |
318 | 319 | ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; |
319 | 320 | if (mask_base != NULL) { |
320 | | - const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base); |
| 321 | + const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base); |
321 | 322 | score += slope * (ACC_TYPE)mask_ptr[k_idx]; |
322 | 323 | } |
323 | 324 | if (logit_softcap > 0.0f) { |
|
0 commit comments