Skip to content

Commit aaa3d07

Browse files
authored
opencl: support sink in soft_max (attn sinks) (#15152)
1 parent 50aa938 commit aaa3d07

File tree

5 files changed: 68 additions (+68) and 29 deletions (-29)

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 28 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -2520,8 +2520,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
25202520
case GGML_OP_CLAMP:
25212521
return op->src[0]->type == GGML_TYPE_F32;
25222522
case GGML_OP_SOFT_MAX:
2523-
// TODO: support attention sinks [TAG_ATTN_SINKS]
2524-
return op->src[2] == nullptr;
25252523
case GGML_OP_NORM:
25262524
case GGML_OP_RMS_NORM:
25272525
return true;
@@ -6594,17 +6592,24 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
65946592
GGML_ASSERT(src1->extra);
65956593
}
65966594

6595+
const ggml_tensor * src2 = dst->src[2];
6596+
if (src2) {
6597+
GGML_ASSERT(src2->extra);
6598+
}
6599+
65976600
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
65986601

65996602
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
66006603
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
66016604

66026605
ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
6606+
ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr;
66036607

66046608
cl_ulong offset0 = extra0->offset + src0->view_offs;
66056609
cl_ulong offsetd = extrad->offset + dst->view_offs;
66066610

66076611
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
6612+
cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;
66086613

66096614
const int ne00 = src0->ne[0];
66106615
const int ne01 = src0->ne[1];
@@ -6672,25 +6677,27 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
66726677
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
66736678
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? &extra1->data_device : &extra0->data_device));
66746679
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
6675-
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
6676-
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
6677-
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
6678-
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
6679-
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
6680-
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
6681-
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
6682-
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13));
6683-
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
6684-
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
6685-
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
6686-
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
6687-
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
6688-
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
6689-
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale));
6690-
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias));
6691-
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0));
6692-
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1));
6693-
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2));
6680+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), extra2 ? &extra2->data_device : &extra0->data_device));
6681+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
6682+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
6683+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
6684+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
6685+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
6686+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
6687+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
6688+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
6689+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13));
6690+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
6691+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
6692+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
6693+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
6694+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
6695+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
6696+
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &scale));
6697+
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &max_bias));
6698+
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float), &m0));
6699+
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &m1));
6700+
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &n_head_log2));
66946701

66956702
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
66966703
size_t local_work_size[] = {(size_t)nth, 1, 1};

ggml/src/ggml-opencl/kernels/softmax_4_f16.cl

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4_f16(
2626
ulong offset0,
2727
global char * src1,
2828
ulong offset1,
29+
global char * src2,
30+
ulong offset2,
2931
global char * dst,
3032
ulong offsetd,
3133
int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4_f16(
4850
) {
4951
src0 = src0 + offset0;
5052
src1 = src1 + offset1;
53+
src2 = src2 + offset2;
5154
dst = dst + offsetd;
5255

5356
int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4_f16(
6063

6164
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
6265
global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
66+
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
6367
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
6468

6569
float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4_f16(
7579
}
7680

7781
// parallel max
78-
float4 lmax4 = -INFINITY;
82+
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
7983
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
8084
lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));
8185
}
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4_f16(
9296
}
9397
float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
9498

95-
const float sum = sub_group_reduce_add(lsum);
99+
float sum = sub_group_reduce_add(lsum);
100+
101+
if (psrc2) {
102+
sum += exp(psrc2[i02] - max);
103+
}
96104

97105
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
98106
pdst4[i00] /= sum;

ggml/src/ggml-opencl/kernels/softmax_4_f32.cl

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4(
2626
ulong offset0,
2727
global char * src1,
2828
ulong offset1,
29+
global char * src2,
30+
ulong offset2,
2931
global char * dst,
3032
ulong offsetd,
3133
int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4(
4850
) {
4951
src0 = src0 + offset0;
5052
src1 = src1 + offset1;
53+
src2 = src2 + offset2;
5154
dst = dst + offsetd;
5255

5356
int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4(
6063

6164
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
6265
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
66+
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
6367
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
6468

6569
float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4(
7579
}
7680

7781
// parallel max
78-
float4 lmax4 = -INFINITY;
82+
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
7983
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
8084
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
8185
}
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4(
9296
}
9397
float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
9498

95-
const float sum = sub_group_reduce_add(lsum);
99+
float sum = sub_group_reduce_add(lsum);
100+
101+
if (psrc2) {
102+
sum += exp(psrc2[i02] - max);
103+
}
96104

97105
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
98106
pdst4[i00] /= sum;

ggml/src/ggml-opencl/kernels/softmax_f16.cl

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_f16(
2626
ulong offset0,
2727
global char * src1,
2828
ulong offset1,
29+
global char * src2,
30+
ulong offset2,
2931
global char * dst,
3032
ulong offsetd,
3133
int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_f16(
4850
) {
4951
src0 = src0 + offset0;
5052
src1 = src1 + offset1;
53+
src2 = src2 + offset2;
5154
dst = dst + offsetd;
5255

5356
int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_f16(
6063

6164
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
6265
global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
66+
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
6367
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
6468

6569
float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_f16(
7579
}
7680

7781
// parallel max
78-
float lmax = -INFINITY;
82+
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
7983
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
8084
lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
8185
}
@@ -91,7 +95,11 @@ kernel void kernel_soft_max_f16(
9195
pdst[i00] = exp_psrc0;
9296
}
9397

94-
const float sum = sub_group_reduce_add(lsum);
98+
float sum = sub_group_reduce_add(lsum);
99+
100+
if (psrc2) {
101+
sum += exp(psrc2[i02] - max);
102+
}
95103

96104
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
97105
pdst[i00] /= sum;

ggml/src/ggml-opencl/kernels/softmax_f32.cl

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,8 @@ kernel void kernel_soft_max(
2626
ulong offset0,
2727
global char * src1,
2828
ulong offset1,
29+
global char * src2,
30+
ulong offset2,
2931
global char * dst,
3032
ulong offsetd,
3133
int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max(
4850
) {
4951
src0 = src0 + offset0;
5052
src1 = src1 + offset1;
53+
src2 = src2 + offset2;
5154
dst = dst + offsetd;
5255

5356
int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max(
6063

6164
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
6265
global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
66+
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
6367
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
6468

6569
float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max(
7579
}
7680

7781
// parallel max
78-
float lmax = -INFINITY;
82+
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
7983
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
8084
lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
8185
}
@@ -91,7 +95,11 @@ kernel void kernel_soft_max(
9195
pdst[i00] = exp_psrc0;
9296
}
9397

94-
const float sum = sub_group_reduce_add(lsum);
98+
float sum = sub_group_reduce_add(lsum);
99+
100+
if (psrc2) {
101+
sum += exp(psrc2[i02] - max);
102+
}
95103

96104
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
97105
pdst[i00] /= sum;

0 commit comments

Comments
 (0)