@@ -2520,8 +2520,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2520
2520
case GGML_OP_CLAMP:
2521
2521
return op->src [0 ]->type == GGML_TYPE_F32;
2522
2522
case GGML_OP_SOFT_MAX:
2523
- // TODO: support attention sinks [TAG_ATTN_SINKS]
2524
- return op->src [2 ] == nullptr ;
2525
2523
case GGML_OP_NORM:
2526
2524
case GGML_OP_RMS_NORM:
2527
2525
return true ;
@@ -6594,17 +6592,24 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
6594
6592
GGML_ASSERT (src1->extra );
6595
6593
}
6596
6594
6595
+ const ggml_tensor * src2 = dst->src [2 ];
6596
+ if (src2) {
6597
+ GGML_ASSERT (src2->extra );
6598
+ }
6599
+
6597
6600
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context ;
6598
6601
6599
6602
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra ;
6600
6603
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra ;
6601
6604
6602
6605
ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr ;
6606
+ ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr ;
6603
6607
6604
6608
cl_ulong offset0 = extra0->offset + src0->view_offs ;
6605
6609
cl_ulong offsetd = extrad->offset + dst->view_offs ;
6606
6610
6607
6611
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
6612
+ cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;
6608
6613
6609
6614
const int ne00 = src0->ne [0 ];
6610
6615
const int ne01 = src0->ne [1 ];
@@ -6672,25 +6677,27 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
6672
6677
CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
6673
6678
CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), extra1 ? &extra1->data_device : &extra0->data_device ));
6674
6679
CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
6675
- CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
6676
- CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
6677
- CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
6678
- CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (cl_ulong), &nb01));
6679
- CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &nb02));
6680
- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb03));
6681
- CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ), &ne12));
6682
- CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (int ), &ne13));
6683
- CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb11));
6684
- CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb12));
6685
- CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (cl_ulong), &nb13));
6686
- CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (cl_ulong), &nb1));
6687
- CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (cl_ulong), &nb2));
6688
- CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (cl_ulong), &nb3));
6689
- CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (float ), &scale));
6690
- CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (float ), &max_bias));
6691
- CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (float ), &m0));
6692
- CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (float ), &m1));
6693
- CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (int ), &n_head_log2));
6680
+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), extra2 ? &extra2->data_device : &extra0->data_device ));
6681
+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offset2));
6682
+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (cl_mem), &extrad->data_device ));
6683
+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (cl_ulong), &offsetd));
6684
+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne00));
6685
+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb01));
6686
+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb02));
6687
+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb03));
6688
+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (int ), &ne12));
6689
+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ), &ne13));
6690
+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (cl_ulong), &nb11));
6691
+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (cl_ulong), &nb12));
6692
+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (cl_ulong), &nb13));
6693
+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (cl_ulong), &nb1));
6694
+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (cl_ulong), &nb2));
6695
+ CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (cl_ulong), &nb3));
6696
+ CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (float ), &scale));
6697
+ CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (float ), &max_bias));
6698
+ CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (float ), &m0));
6699
+ CL_CHECK (clSetKernelArg (kernel, 23 , sizeof (float ), &m1));
6700
+ CL_CHECK (clSetKernelArg (kernel, 24 , sizeof (int ), &n_head_log2));
6694
6701
6695
6702
size_t global_work_size[] = {(size_t )ne01*nth, (size_t )ne02, (size_t )ne03};
6696
6703
size_t local_work_size[] = {(size_t )nth, 1 , 1 };
0 commit comments