@@ -2520,8 +2520,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
25202520 case GGML_OP_CLAMP:
25212521 return op->src [0 ]->type == GGML_TYPE_F32;
25222522 case GGML_OP_SOFT_MAX:
2523- // TODO: support attention sinks [TAG_ATTN_SINKS]
2524- return op->src [2 ] == nullptr ;
25252523 case GGML_OP_NORM:
25262524 case GGML_OP_RMS_NORM:
25272525 return true ;
@@ -6594,17 +6592,24 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
65946592 GGML_ASSERT (src1->extra );
65956593 }
65966594
6595+ const ggml_tensor * src2 = dst->src [2 ];
6596+ if (src2) {
6597+ GGML_ASSERT (src2->extra );
6598+ }
6599+
65976600 ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context ;
65986601
65996602 ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra ;
66006603 ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra ;
66016604
66026605 ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr ;
6606+ ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr ;
66036607
66046608 cl_ulong offset0 = extra0->offset + src0->view_offs ;
66056609 cl_ulong offsetd = extrad->offset + dst->view_offs ;
66066610
66076611 cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
6612+ cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;
66086613
66096614 const int ne00 = src0->ne [0 ];
66106615 const int ne01 = src0->ne [1 ];
@@ -6672,25 +6677,27 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
66726677 CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
66736678 CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), extra1 ? &extra1->data_device : &extra0->data_device ));
66746679 CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
6675- CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
6676- CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
6677- CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
6678- CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (cl_ulong), &nb01));
6679- CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &nb02));
6680- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb03));
6681- CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ), &ne12));
6682- CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (int ), &ne13));
6683- CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb11));
6684- CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb12));
6685- CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (cl_ulong), &nb13));
6686- CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (cl_ulong), &nb1));
6687- CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (cl_ulong), &nb2));
6688- CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (cl_ulong), &nb3));
6689- CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (float ), &scale));
6690- CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (float ), &max_bias));
6691- CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (float ), &m0));
6692- CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (float ), &m1));
6693- CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (int ), &n_head_log2));
6680+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), extra2 ? &extra2->data_device : &extra0->data_device ));
6681+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offset2));
6682+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (cl_mem), &extrad->data_device ));
6683+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (cl_ulong), &offsetd));
6684+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne00));
6685+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb01));
6686+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb02));
6687+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb03));
6688+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (int ), &ne12));
6689+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ), &ne13));
6690+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (cl_ulong), &nb11));
6691+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (cl_ulong), &nb12));
6692+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (cl_ulong), &nb13));
6693+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (cl_ulong), &nb1));
6694+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (cl_ulong), &nb2));
6695+ CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (cl_ulong), &nb3));
6696+ CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (float ), &scale));
6697+ CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (float ), &max_bias));
6698+ CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (float ), &m0));
6699+ CL_CHECK (clSetKernelArg (kernel, 23 , sizeof (float ), &m1));
6700+ CL_CHECK (clSetKernelArg (kernel, 24 , sizeof (int ), &n_head_log2));
66946701
66956702 size_t global_work_size[] = {(size_t )ne01*nth, (size_t )ne02, (size_t )ne03};
66966703 size_t local_work_size[] = {(size_t )nth, 1 , 1 };
0 commit comments