@@ -2480,6 +2480,13 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
24802480 case GGML_OP_SCALE:
24812481 return op->src [0 ]->type == GGML_TYPE_F32 && ggml_is_contiguous (op->src [0 ]);
24822482 case GGML_OP_ADD:
2483+ if (op->type == GGML_TYPE_F16) {
2484+ const bool src0_ok = op->src [0 ]->type == GGML_TYPE_F16 || op->src [0 ]->type == GGML_TYPE_F32;
2485+ const bool src1_ok = op->src [1 ]->type == GGML_TYPE_F16 || op->src [1 ]->type == GGML_TYPE_F32;
2486+ if (src0_ok && src1_ok) {
2487+ return true ;
2488+ }
2489+ }
24832490 case GGML_OP_MUL:
24842491 case GGML_OP_DIV:
24852492 case GGML_OP_SUB:
@@ -3718,34 +3725,30 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
37183725 GGML_ASSERT (dst);
37193726 GGML_ASSERT (dst->extra );
37203727
3721- GGML_ASSERT (src0->type == src1->type );
3722- GGML_ASSERT (src0->type == dst->type );
3723- GGML_ASSERT (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
3724-
3725- const int ne00 = src0->ne [0 ];
3726- const int ne01 = src0->ne [1 ];
3727- const int ne02 = src0->ne [2 ];
3728- const int ne03 = src0->ne [3 ];
3728+ const int ne00 = src0->ne [0 ];
3729+ const int ne01 = src0->ne [1 ];
3730+ const int ne02 = src0->ne [2 ];
3731+ const int ne03 = src0->ne [3 ];
37293732
37303733 const cl_ulong nb00 = src0->nb [0 ];
37313734 const cl_ulong nb01 = src0->nb [1 ];
37323735 const cl_ulong nb02 = src0->nb [2 ];
37333736 const cl_ulong nb03 = src0->nb [3 ];
37343737
3735- const int ne10 = src1->ne [0 ];
3736- const int ne11 = src1->ne [1 ];
3737- const int ne12 = src1->ne [2 ];
3738- const int ne13 = src1->ne [3 ]; UNUSED (ne13) ;
3738+ const int ne10 = src1->ne [0 ];
3739+ const int ne11 = src1->ne [1 ];
3740+ const int ne12 = src1->ne [2 ];
3741+ const int ne13 = src1->ne [3 ];
37393742
37403743 const cl_ulong nb10 = src1->nb [0 ];
37413744 const cl_ulong nb11 = src1->nb [1 ];
37423745 const cl_ulong nb12 = src1->nb [2 ];
3743- const cl_ulong nb13 = src1->nb [3 ]; UNUSED (nb13);
3746+ const cl_ulong nb13 = src1->nb [3 ];
37443747
3745- const int ne0 = dst->ne [0 ];
3746- const int ne1 = dst->ne [1 ];
3747- const int ne2 = dst->ne [2 ];
3748- const int ne3 = dst->ne [3 ];
3748+ const int ne0 = dst->ne [0 ];
3749+ const int ne1 = dst->ne [1 ];
3750+ const int ne2 = dst->ne [2 ];
3751+ const int ne3 = dst->ne [3 ];
37493752
37503753 const cl_ulong nb0 = dst->nb [0 ];
37513754 const cl_ulong nb1 = dst->nb [1 ];
@@ -3762,68 +3765,114 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
37623765 cl_ulong offset1 = extra1->offset + src1->view_offs ;
37633766 cl_ulong offsetd = extrad->offset + dst->view_offs ;
37643767
3765- bool bcast_row = false ;
37663768 cl_kernel kernel;
37673769
3768- if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
3769- GGML_ASSERT (ggml_is_contiguous (src0));
3770+ const bool bcast_row = ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ;
37703771
3771- // src1 is a row
3772+ if (bcast_row) {
3773+ GGML_ASSERT (ggml_is_contiguous (src0));
37723774 GGML_ASSERT (ne11 == 1 );
3775+ }
37733776
3774- bcast_row = true ;
3775- int ne = ne00 / 4 ;
3776-
3777- if (src0->type == GGML_TYPE_F32) {
3777+ if (dst->type == GGML_TYPE_F32) {
3778+ GGML_ASSERT (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32);
3779+ if (bcast_row) {
37783780 kernel = backend_ctx->kernel_add_row ;
3781+ const int ne = ne00 / 4 ;
3782+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3783+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3784+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3785+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3786+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3787+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3788+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne));
37793789 } else {
3780- kernel = backend_ctx->kernel_add_row_f16 ;
3781- }
3782-
3783- CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3784- CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3785- CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3786- CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3787- CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3788- CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3789- CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne));
3790- } else {
3791- if (src0->type == GGML_TYPE_F32) {
37923790 kernel = backend_ctx->kernel_add ;
3791+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3792+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3793+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3794+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3795+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3796+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3797+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
3798+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne01));
3799+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne02));
3800+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne03));
3801+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb00));
3802+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb01));
3803+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb02));
3804+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb03));
3805+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &ne10));
3806+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (int ), &ne11));
3807+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (int ), &ne12));
3808+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (int ), &ne13));
3809+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (cl_ulong), &nb10));
3810+ CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (cl_ulong), &nb11));
3811+ CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (cl_ulong), &nb12));
3812+ CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (cl_ulong), &nb13));
3813+ CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (int ), &ne0));
3814+ CL_CHECK (clSetKernelArg (kernel, 23 , sizeof (int ), &ne1));
3815+ CL_CHECK (clSetKernelArg (kernel, 24 , sizeof (int ), &ne2));
3816+ CL_CHECK (clSetKernelArg (kernel, 25 , sizeof (int ), &ne3));
3817+ CL_CHECK (clSetKernelArg (kernel, 26 , sizeof (cl_ulong), &nb0));
3818+ CL_CHECK (clSetKernelArg (kernel, 27 , sizeof (cl_ulong), &nb1));
3819+ CL_CHECK (clSetKernelArg (kernel, 28 , sizeof (cl_ulong), &nb2));
3820+ CL_CHECK (clSetKernelArg (kernel, 29 , sizeof (cl_ulong), &nb3));
3821+ }
3822+ } else if (dst->type == GGML_TYPE_F16) {
3823+ GGML_ASSERT (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
3824+ GGML_ASSERT (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
3825+ const int type_src0 = (src0->type == GGML_TYPE_F32);
3826+ const int type_src1 = (src1->type == GGML_TYPE_F32);
3827+ if (bcast_row) {
3828+ kernel = backend_ctx->kernel_add_row_f16 ;
3829+ const int ne = ne00 / 4 ;
3830+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3831+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3832+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3833+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3834+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3835+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3836+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne));
3837+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &type_src0));
3838+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &type_src1));
37933839 } else {
37943840 kernel = backend_ctx->kernel_add_f16 ;
3841+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3842+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3843+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3844+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3845+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3846+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3847+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
3848+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne01));
3849+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne02));
3850+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne03));
3851+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb00));
3852+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb01));
3853+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb02));
3854+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb03));
3855+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &ne10));
3856+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (int ), &ne11));
3857+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (int ), &ne12));
3858+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (int ), &ne13));
3859+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (cl_ulong), &nb10));
3860+ CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (cl_ulong), &nb11));
3861+ CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (cl_ulong), &nb12));
3862+ CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (cl_ulong), &nb13));
3863+ CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (int ), &ne0));
3864+ CL_CHECK (clSetKernelArg (kernel, 23 , sizeof (int ), &ne1));
3865+ CL_CHECK (clSetKernelArg (kernel, 24 , sizeof (int ), &ne2));
3866+ CL_CHECK (clSetKernelArg (kernel, 25 , sizeof (int ), &ne3));
3867+ CL_CHECK (clSetKernelArg (kernel, 26 , sizeof (cl_ulong), &nb0));
3868+ CL_CHECK (clSetKernelArg (kernel, 27 , sizeof (cl_ulong), &nb1));
3869+ CL_CHECK (clSetKernelArg (kernel, 28 , sizeof (cl_ulong), &nb2));
3870+ CL_CHECK (clSetKernelArg (kernel, 29 , sizeof (cl_ulong), &nb3));
3871+ CL_CHECK (clSetKernelArg (kernel, 30 , sizeof (int ), &type_src0));
3872+ CL_CHECK (clSetKernelArg (kernel, 31 , sizeof (int ), &type_src1));
37953873 }
3796-
3797- CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3798- CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3799- CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3800- CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3801- CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3802- CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3803- CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
3804- CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne01));
3805- CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne02));
3806- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne03));
3807- CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb00));
3808- CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb01));
3809- CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb02));
3810- CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb03));
3811- CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &ne10));
3812- CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (int ), &ne11));
3813- CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (int ), &ne12));
3814- CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (int ), &ne13));
3815- CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (cl_ulong), &nb10));
3816- CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (cl_ulong), &nb11));
3817- CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (cl_ulong), &nb12));
3818- CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (cl_ulong), &nb13));
3819- CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (int ), &ne0));
3820- CL_CHECK (clSetKernelArg (kernel, 23 , sizeof (int ), &ne1));
3821- CL_CHECK (clSetKernelArg (kernel, 24 , sizeof (int ), &ne2));
3822- CL_CHECK (clSetKernelArg (kernel, 25 , sizeof (int ), &ne3));
3823- CL_CHECK (clSetKernelArg (kernel, 26 , sizeof (cl_ulong), &nb0));
3824- CL_CHECK (clSetKernelArg (kernel, 27 , sizeof (cl_ulong), &nb1));
3825- CL_CHECK (clSetKernelArg (kernel, 28 , sizeof (cl_ulong), &nb2));
3826- CL_CHECK (clSetKernelArg (kernel, 29 , sizeof (cl_ulong), &nb3));
3874+ } else {
3875+ GGML_ASSERT (false && " unsupported data types for add" );
38273876 }
38283877
38293878 if (bcast_row) {
@@ -3833,13 +3882,13 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
38333882
38343883 size_t * local_work_size_ptr = local_work_size;
38353884 if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups ) {
3836- local_work_size_ptr = nullptr ; // Let driver choose the work-group sizes.
3885+ local_work_size_ptr = nullptr ;
38373886 }
38383887
3839- backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size_ptr, dst);
3888+ backend_ctx->enqueue_ndrange_kernel (kernel, 1 , global_work_size, local_work_size_ptr, dst);
38403889 } else {
38413890 unsigned int nth = MIN (64 , ne0);
3842- size_t global_work_size[] = {ne01*nth, (size_t )ne02, (size_t )ne03};
3891+ size_t global_work_size[] = {( size_t ) ne01*nth, (size_t )ne02, (size_t )ne03};
38433892 size_t local_work_size[] = {nth, 1 , 1 };
38443893
38453894 backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size, dst);
0 commit comments