Skip to content

Commit 97669e4

Browse files
authored
opencl: add attn sinks support for FA kernels (ggml-org#15706)
1 parent 2f85368 commit 97669e4

File tree

4 files changed: 102 additions and 16 deletions.

4 files changed: 102 additions and 16 deletions.

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2776,10 +2776,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
27762776
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
27772777
case GGML_OP_FLASH_ATTN_EXT:
27782778
{
2779-
if (op->src[4]) {
2780-
return false;
2781-
}
2782-
27832779
const ggml_tensor * q = op->src[0];
27842780
const ggml_tensor * k = op->src[1];
27852781
const ggml_tensor * v = op->src[2];
@@ -5765,13 +5761,17 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
57655761
static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
57665762
const ggml_tensor * v = dst->src[2];
57675763
const ggml_tensor * mask = dst->src[3];
5764+
const ggml_tensor * sinks = dst->src[4];
57685765
GGML_ASSERT(q->extra);
57695766
GGML_ASSERT(k->extra);
57705767
GGML_ASSERT(v->extra);
57715768
GGML_ASSERT(dst->extra);
57725769
if (mask) {
57735770
GGML_ASSERT(mask->extra);
57745771
}
5772+
if (sinks) {
5773+
GGML_ASSERT(sinks->extra);
5774+
}
57755775

57765776
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
57775777

@@ -5813,13 +5813,16 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
58135813
ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
58145814
ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
58155815
ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;
5816+
ggml_tensor_extra_cl * extra_sinks = sinks ? (ggml_tensor_extra_cl *)sinks->extra : NULL;
58165817

58175818
cl_ulong offset_q = extra_q->offset + q->view_offs;
58185819
cl_ulong offset_k = extra_k->offset + k->view_offs;
58195820
cl_ulong offset_v = extra_v->offset + v->view_offs;
58205821
cl_ulong offset_o = extra_o->offset + dst->view_offs;
58215822
cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL;
58225823
cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;
5824+
cl_mem sinks_buffer = extra_sinks ? extra_sinks->data_device : NULL;
5825+
cl_ulong offset_sinks = extra_sinks ? extra_sinks->offset + sinks->view_offs : 0;
58235826

58245827
const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
58255828
const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
@@ -5874,6 +5877,8 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
58745877
CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
58755878
CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2));
58765879
CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3));
5880+
CL_CHECK(clSetKernelArg(kernel, 38, sizeof(cl_mem), &sinks_buffer));
5881+
CL_CHECK(clSetKernelArg(kernel, 39, sizeof(cl_ulong), &offset_sinks));
58775882

58785883
if (n_q == 1) {
58795884
const size_t wg_size = 64;

ggml/src/ggml-opencl/kernels/flash_attn_f16.cl

Lines changed: 31 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,9 @@ __kernel void flash_attn_f16(
4949
const ulong mask_nb2,
5050
const ulong mask_nb3,
5151
const int mask_ne2,
52-
const int mask_ne3
52+
const int mask_ne3,
53+
const global void* sinks_void,
54+
const ulong sinks_offset
5355
) {
5456
const int tid = get_local_id(0);
5557
const int block_q_idx = get_group_id(0);
@@ -171,6 +173,20 @@ __kernel void flash_attn_f16(
171173
}
172174

173175
if (my_query_row < n_q) {
176+
if (sinks_void != NULL) {
177+
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
178+
const ACC_TYPE m_sink = sinks_ptr[head_idx];
179+
const ACC_TYPE m_final = max(m_i, m_sink);
180+
181+
const ACC_TYPE scale_o = exp(m_i - m_final);
182+
#pragma unroll
183+
for (int i = 0; i < DV_VEC; ++i) {
184+
o_acc[i] *= scale_o;
185+
}
186+
187+
l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
188+
}
189+
174190
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
175191
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
176192
if (l_i > 0.0f) {
@@ -214,7 +230,9 @@ __kernel void flash_attn_f16_q1(
214230
const ulong mask_nb2,
215231
const ulong mask_nb3,
216232
const int mask_ne2,
217-
const int mask_ne3
233+
const int mask_ne3,
234+
const global void* sinks_void,
235+
const ulong sinks_offset
218236
) {
219237
const int tid = get_local_id(0);
220238
const int head_batch_idx = get_global_id(1);
@@ -247,7 +265,12 @@ __kernel void flash_attn_f16_q1(
247265

248266
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
249267

250-
ACC_TYPE m_i = -INFINITY;
268+
const global ACC_TYPE* sinks_ptr = NULL;
269+
if (sinks_void != NULL) {
270+
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
271+
}
272+
273+
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
251274
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
252275
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
253276
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
@@ -320,7 +343,11 @@ __kernel void flash_attn_f16_q1(
320343

321344
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
322345
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
323-
const ACC_TYPE l_final = local_l[0];
346+
ACC_TYPE l_final = local_l[0];
347+
348+
if (sinks_ptr != NULL) {
349+
l_final += exp(sinks_ptr[head_idx] - m_final);
350+
}
324351

325352
if (l_final > 0.0f) {
326353
const ACC_TYPE l_inv = 1.0f / l_final;

ggml/src/ggml-opencl/kernels/flash_attn_f32.cl

Lines changed: 31 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,9 @@ __kernel void flash_attn_f32(
4949
const ulong mask_nb2,
5050
const ulong mask_nb3,
5151
const int mask_ne2,
52-
const int mask_ne3
52+
const int mask_ne3,
53+
const global void* sinks_void,
54+
const ulong sinks_offset
5355
) {
5456
const int tid = get_local_id(0);
5557
const int block_q_idx = get_group_id(0);
@@ -171,6 +173,20 @@ __kernel void flash_attn_f32(
171173
}
172174

173175
if (my_query_row < n_q) {
176+
if (sinks_void != NULL) {
177+
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
178+
const ACC_TYPE m_sink = sinks_ptr[head_idx];
179+
const ACC_TYPE m_final = max(m_i, m_sink);
180+
181+
const ACC_TYPE scale_o = exp(m_i - m_final);
182+
#pragma unroll
183+
for (int i = 0; i < DV_VEC; ++i) {
184+
o_acc[i] *= scale_o;
185+
}
186+
187+
l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
188+
}
189+
174190
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
175191
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
176192
if (l_i > 0.0f) {
@@ -214,7 +230,9 @@ __kernel void flash_attn_f32_q1(
214230
const ulong mask_nb2,
215231
const ulong mask_nb3,
216232
const int mask_ne2,
217-
const int mask_ne3
233+
const int mask_ne3,
234+
const global void* sinks_void,
235+
const ulong sinks_offset
218236
) {
219237
const int tid = get_local_id(0);
220238
const int head_batch_idx = get_global_id(1);
@@ -247,7 +265,12 @@ __kernel void flash_attn_f32_q1(
247265

248266
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
249267

250-
ACC_TYPE m_i = -INFINITY;
268+
const global ACC_TYPE* sinks_ptr = NULL;
269+
if (sinks_void != NULL) {
270+
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
271+
}
272+
273+
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
251274
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
252275
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
253276
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
@@ -320,7 +343,11 @@ __kernel void flash_attn_f32_q1(
320343

321344
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
322345
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
323-
const ACC_TYPE l_final = local_l[0];
346+
ACC_TYPE l_final = local_l[0];
347+
348+
if (sinks_ptr != NULL) {
349+
l_final += exp(sinks_ptr[head_idx] - m_final);
350+
}
324351

325352
if (l_final > 0.0f) {
326353
const ACC_TYPE l_inv = 1.0f / l_final;

ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl

Lines changed: 31 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,9 @@ __kernel void flash_attn_f32_f16(
5252
const ulong mask_nb2,
5353
const ulong mask_nb3,
5454
const int mask_ne2,
55-
const int mask_ne3
55+
const int mask_ne3,
56+
const global void* sinks_void,
57+
const ulong sinks_offset
5658
) {
5759
const int tid = get_local_id(0);
5860
const int block_q_idx = get_group_id(0);
@@ -174,6 +176,20 @@ __kernel void flash_attn_f32_f16(
174176
}
175177

176178
if (my_query_row < n_q) {
179+
if (sinks_void != NULL) {
180+
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
181+
const ACC_TYPE m_sink = sinks_ptr[head_idx];
182+
const ACC_TYPE m_final = max(m_i, m_sink);
183+
184+
const ACC_TYPE scale_o = exp(m_i - m_final);
185+
#pragma unroll
186+
for (int i = 0; i < DV_VEC; ++i) {
187+
o_acc[i] *= scale_o;
188+
}
189+
190+
l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
191+
}
192+
177193
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
178194
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
179195
if (l_i > 0.0f) {
@@ -217,7 +233,9 @@ __kernel void flash_attn_f32_f16_q1(
217233
const ulong mask_nb2,
218234
const ulong mask_nb3,
219235
const int mask_ne2,
220-
const int mask_ne3
236+
const int mask_ne3,
237+
const global void* sinks_void,
238+
const ulong sinks_offset
221239
) {
222240
const int tid = get_local_id(0);
223241
const int head_batch_idx = get_global_id(1);
@@ -250,7 +268,12 @@ __kernel void flash_attn_f32_f16_q1(
250268

251269
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
252270

253-
ACC_TYPE m_i = -INFINITY;
271+
const global ACC_TYPE* sinks_ptr = NULL;
272+
if (sinks_void != NULL) {
273+
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
274+
}
275+
276+
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
254277
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
255278
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
256279
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
@@ -323,7 +346,11 @@ __kernel void flash_attn_f32_f16_q1(
323346

324347
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
325348
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
326-
const ACC_TYPE l_final = local_l[0];
349+
ACC_TYPE l_final = local_l[0];
350+
351+
if (sinks_ptr != NULL) {
352+
l_final += exp(sinks_ptr[head_idx] - m_final);
353+
}
327354

328355
if (l_final > 0.0f) {
329356
const ACC_TYPE l_inv = 1.0f / l_final;

0 commit comments

Comments
 (0)