@@ -2481,6 +2481,13 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
24812481 case GGML_OP_SCALE:
24822482 return op->src [0 ]->type == GGML_TYPE_F32 && ggml_is_contiguous (op->src [0 ]);
24832483 case GGML_OP_ADD:
2484+ if (op->type == GGML_TYPE_F16) {
2485+ const bool src0_ok = op->src [0 ]->type == GGML_TYPE_F16 || op->src [0 ]->type == GGML_TYPE_F32;
2486+ const bool src1_ok = op->src [1 ]->type == GGML_TYPE_F16 || op->src [1 ]->type == GGML_TYPE_F32;
2487+ if (src0_ok && src1_ok) {
2488+ return true ;
2489+ }
2490+ }
24842491 case GGML_OP_MUL:
24852492 case GGML_OP_DIV:
24862493 case GGML_OP_SUB:
@@ -3717,34 +3724,30 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
37173724 GGML_ASSERT (dst);
37183725 GGML_ASSERT (dst->extra );
37193726
3720- GGML_ASSERT (src0->type == src1->type );
3721- GGML_ASSERT (src0->type == dst->type );
3722- GGML_ASSERT (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
3723-
3724- const int ne00 = src0->ne [0 ];
3725- const int ne01 = src0->ne [1 ];
3726- const int ne02 = src0->ne [2 ];
3727- const int ne03 = src0->ne [3 ];
3727+ const int ne00 = src0->ne [0 ];
3728+ const int ne01 = src0->ne [1 ];
3729+ const int ne02 = src0->ne [2 ];
3730+ const int ne03 = src0->ne [3 ];
37283731
37293732 const cl_ulong nb00 = src0->nb [0 ];
37303733 const cl_ulong nb01 = src0->nb [1 ];
37313734 const cl_ulong nb02 = src0->nb [2 ];
37323735 const cl_ulong nb03 = src0->nb [3 ];
37333736
3734- const int ne10 = src1->ne [0 ];
3735- const int ne11 = src1->ne [1 ];
3736- const int ne12 = src1->ne [2 ];
3737- const int ne13 = src1->ne [3 ]; UNUSED (ne13) ;
3737+ const int ne10 = src1->ne [0 ];
3738+ const int ne11 = src1->ne [1 ];
3739+ const int ne12 = src1->ne [2 ];
3740+ const int ne13 = src1->ne [3 ];
37383741
37393742 const cl_ulong nb10 = src1->nb [0 ];
37403743 const cl_ulong nb11 = src1->nb [1 ];
37413744 const cl_ulong nb12 = src1->nb [2 ];
3742- const cl_ulong nb13 = src1->nb [3 ]; UNUSED (nb13);
3745+ const cl_ulong nb13 = src1->nb [3 ];
37433746
3744- const int ne0 = dst->ne [0 ];
3745- const int ne1 = dst->ne [1 ];
3746- const int ne2 = dst->ne [2 ];
3747- const int ne3 = dst->ne [3 ];
3747+ const int ne0 = dst->ne [0 ];
3748+ const int ne1 = dst->ne [1 ];
3749+ const int ne2 = dst->ne [2 ];
3750+ const int ne3 = dst->ne [3 ];
37483751
37493752 const cl_ulong nb0 = dst->nb [0 ];
37503753 const cl_ulong nb1 = dst->nb [1 ];
@@ -3761,68 +3764,114 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
37613764 cl_ulong offset1 = extra1->offset + src1->view_offs ;
37623765 cl_ulong offsetd = extrad->offset + dst->view_offs ;
37633766
3764- bool bcast_row = false ;
37653767 cl_kernel kernel;
37663768
3767- if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
3768- GGML_ASSERT (ggml_is_contiguous (src0));
3769+ const bool bcast_row = ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ;
37693770
3770- // src1 is a row
3771+ if (bcast_row) {
3772+ GGML_ASSERT (ggml_is_contiguous (src0));
37713773 GGML_ASSERT (ne11 == 1 );
3774+ }
37723775
3773- bcast_row = true ;
3774- int ne = ne00 / 4 ;
3775-
3776- if (src0->type == GGML_TYPE_F32) {
3776+ if (dst->type == GGML_TYPE_F32) {
3777+ GGML_ASSERT (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32);
3778+ if (bcast_row) {
37773779 kernel = backend_ctx->kernel_add_row ;
3780+ const int ne = ne00 / 4 ;
3781+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3782+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3783+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3784+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3785+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3786+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3787+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne));
37783788 } else {
3779- kernel = backend_ctx->kernel_add_row_f16 ;
3780- }
3781-
3782- CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3783- CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3784- CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3785- CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3786- CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3787- CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3788- CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne));
3789- } else {
3790- if (src0->type == GGML_TYPE_F32) {
37913789 kernel = backend_ctx->kernel_add ;
3790+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3791+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3792+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3793+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3794+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3795+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3796+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
3797+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne01));
3798+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne02));
3799+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne03));
3800+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb00));
3801+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb01));
3802+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb02));
3803+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb03));
3804+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &ne10));
3805+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (int ), &ne11));
3806+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (int ), &ne12));
3807+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (int ), &ne13));
3808+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (cl_ulong), &nb10));
3809+ CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (cl_ulong), &nb11));
3810+ CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (cl_ulong), &nb12));
3811+ CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (cl_ulong), &nb13));
3812+ CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (int ), &ne0));
3813+ CL_CHECK (clSetKernelArg (kernel, 23 , sizeof (int ), &ne1));
3814+ CL_CHECK (clSetKernelArg (kernel, 24 , sizeof (int ), &ne2));
3815+ CL_CHECK (clSetKernelArg (kernel, 25 , sizeof (int ), &ne3));
3816+ CL_CHECK (clSetKernelArg (kernel, 26 , sizeof (cl_ulong), &nb0));
3817+ CL_CHECK (clSetKernelArg (kernel, 27 , sizeof (cl_ulong), &nb1));
3818+ CL_CHECK (clSetKernelArg (kernel, 28 , sizeof (cl_ulong), &nb2));
3819+ CL_CHECK (clSetKernelArg (kernel, 29 , sizeof (cl_ulong), &nb3));
3820+ }
3821+ } else if (dst->type == GGML_TYPE_F16) {
3822+ GGML_ASSERT (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
3823+ GGML_ASSERT (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
3824+ const int type_src0 = (src0->type == GGML_TYPE_F32);
3825+ const int type_src1 = (src1->type == GGML_TYPE_F32);
3826+ if (bcast_row) {
3827+ kernel = backend_ctx->kernel_add_row_f16 ;
3828+ const int ne = ne00 / 4 ;
3829+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3830+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3831+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3832+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3833+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3834+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3835+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne));
3836+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &type_src0));
3837+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &type_src1));
37923838 } else {
37933839 kernel = backend_ctx->kernel_add_f16 ;
3840+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3841+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3842+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3843+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3844+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3845+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3846+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
3847+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne01));
3848+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne02));
3849+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne03));
3850+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb00));
3851+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb01));
3852+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb02));
3853+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb03));
3854+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &ne10));
3855+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (int ), &ne11));
3856+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (int ), &ne12));
3857+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (int ), &ne13));
3858+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (cl_ulong), &nb10));
3859+ CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (cl_ulong), &nb11));
3860+ CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (cl_ulong), &nb12));
3861+ CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (cl_ulong), &nb13));
3862+ CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (int ), &ne0));
3863+ CL_CHECK (clSetKernelArg (kernel, 23 , sizeof (int ), &ne1));
3864+ CL_CHECK (clSetKernelArg (kernel, 24 , sizeof (int ), &ne2));
3865+ CL_CHECK (clSetKernelArg (kernel, 25 , sizeof (int ), &ne3));
3866+ CL_CHECK (clSetKernelArg (kernel, 26 , sizeof (cl_ulong), &nb0));
3867+ CL_CHECK (clSetKernelArg (kernel, 27 , sizeof (cl_ulong), &nb1));
3868+ CL_CHECK (clSetKernelArg (kernel, 28 , sizeof (cl_ulong), &nb2));
3869+ CL_CHECK (clSetKernelArg (kernel, 29 , sizeof (cl_ulong), &nb3));
3870+ CL_CHECK (clSetKernelArg (kernel, 30 , sizeof (int ), &type_src0));
3871+ CL_CHECK (clSetKernelArg (kernel, 31 , sizeof (int ), &type_src1));
37943872 }
3795-
3796- CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
3797- CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
3798- CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
3799- CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
3800- CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
3801- CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
3802- CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
3803- CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &ne01));
3804- CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne02));
3805- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne03));
3806- CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb00));
3807- CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb01));
3808- CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb02));
3809- CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb03));
3810- CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &ne10));
3811- CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (int ), &ne11));
3812- CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (int ), &ne12));
3813- CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (int ), &ne13));
3814- CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (cl_ulong), &nb10));
3815- CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (cl_ulong), &nb11));
3816- CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (cl_ulong), &nb12));
3817- CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (cl_ulong), &nb13));
3818- CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (int ), &ne0));
3819- CL_CHECK (clSetKernelArg (kernel, 23 , sizeof (int ), &ne1));
3820- CL_CHECK (clSetKernelArg (kernel, 24 , sizeof (int ), &ne2));
3821- CL_CHECK (clSetKernelArg (kernel, 25 , sizeof (int ), &ne3));
3822- CL_CHECK (clSetKernelArg (kernel, 26 , sizeof (cl_ulong), &nb0));
3823- CL_CHECK (clSetKernelArg (kernel, 27 , sizeof (cl_ulong), &nb1));
3824- CL_CHECK (clSetKernelArg (kernel, 28 , sizeof (cl_ulong), &nb2));
3825- CL_CHECK (clSetKernelArg (kernel, 29 , sizeof (cl_ulong), &nb3));
3873+ } else {
3874+ GGML_ASSERT (false && " unsupported data types for add" );
38263875 }
38273876
38283877 if (bcast_row) {
@@ -3832,13 +3881,13 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
38323881
38333882 size_t * local_work_size_ptr = local_work_size;
38343883 if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups ) {
3835- local_work_size_ptr = nullptr ; // Let driver choose the work-group sizes.
3884+ local_work_size_ptr = nullptr ;
38363885 }
38373886
3838- backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size_ptr, dst);
3887+ backend_ctx->enqueue_ndrange_kernel (kernel, 1 , global_work_size, local_work_size_ptr, dst);
38393888 } else {
38403889 unsigned int nth = MIN (64 , ne0);
3841- size_t global_work_size[] = {ne01*nth, (size_t )ne02, (size_t )ne03};
3890+ size_t global_work_size[] = {( size_t ) ne01*nth, (size_t )ne02, (size_t )ne03};
38423891 size_t local_work_size[] = {nth, 1 , 1 };
38433892
38443893 backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size, dst);
0 commit comments