@@ -400,10 +400,10 @@ struct ggml_backend_opencl_context {
400400 cl_program program_mul_mm_f32_f32_l4_lm;
401401 cl_program program_mul_mm_f16_f32_l4_lm;
402402
403- cl_kernel kernel_add, kernel_add_row;
404- cl_kernel kernel_mul, kernel_mul_row;
405- cl_kernel kernel_div, kernel_div_row;
406- cl_kernel kernel_sub, kernel_sub_row;
403+ cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16 ;
404+ cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16 ;
405+ cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16 ;
406+ cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16 ;
407407 cl_kernel kernel_scale;
408408 cl_kernel kernel_silu, kernel_silu_4;
409409 cl_kernel kernel_gelu, kernel_gelu_4;
@@ -674,8 +674,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
674674 backend_ctx->program_add =
675675 build_program_from_source (backend_ctx->context , backend_ctx->device , kernel_src.c_str (), compile_opts);
676676
677- CL_CHECK ((backend_ctx->kernel_add = clCreateKernel (backend_ctx->program_add , " kernel_add" , &err), err));
678- CL_CHECK ((backend_ctx->kernel_add_row = clCreateKernel (backend_ctx->program_add , " kernel_add_row" , &err), err));
677+ CL_CHECK ((backend_ctx->kernel_add = clCreateKernel (backend_ctx->program_add , " kernel_add" , &err), err));
678+ CL_CHECK ((backend_ctx->kernel_add_row = clCreateKernel (backend_ctx->program_add , " kernel_add_row" , &err), err));
679+ CL_CHECK ((backend_ctx->kernel_add_f16 = clCreateKernel (backend_ctx->program_add , " kernel_add_f16" , &err), err));
680+ CL_CHECK ((backend_ctx->kernel_add_row_f16 = clCreateKernel (backend_ctx->program_add , " kernel_add_row_f16" , &err), err));
679681 GGML_LOG_CONT (" ." );
680682 }
681683
@@ -1089,8 +1091,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
10891091 backend_ctx->program_mul =
10901092 build_program_from_source (backend_ctx->context , backend_ctx->device , kernel_src.c_str (), compile_opts);
10911093
1092- CL_CHECK ((backend_ctx->kernel_mul = clCreateKernel (backend_ctx->program_mul , " kernel_mul" , &err), err));
1093- CL_CHECK ((backend_ctx->kernel_mul_row = clCreateKernel (backend_ctx->program_mul , " kernel_mul_row" , &err), err));
1094+ CL_CHECK ((backend_ctx->kernel_mul = clCreateKernel (backend_ctx->program_mul , " kernel_mul" , &err), err));
1095+ CL_CHECK ((backend_ctx->kernel_mul_row = clCreateKernel (backend_ctx->program_mul , " kernel_mul_row" , &err), err));
1096+ CL_CHECK ((backend_ctx->kernel_mul_f16 = clCreateKernel (backend_ctx->program_mul , " kernel_mul_f16" , &err), err));
1097+ CL_CHECK ((backend_ctx->kernel_mul_row_f16 = clCreateKernel (backend_ctx->program_mul , " kernel_mul_row_f16" , &err), err));
10941098 GGML_LOG_CONT (" ." );
10951099 }
10961100
@@ -1288,11 +1292,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
12881292#else
12891293 const std::string kernel_src = read_file (" div.cl" );
12901294#endif
1295+ std::string compile_opts = std::string (" -cl-std=" ) + opencl_c_std +
1296+ " -cl-mad-enable -cl-finite-math-only " ;
1297+
12911298 backend_ctx->program_div =
12921299 build_program_from_source (backend_ctx->context , backend_ctx->device , kernel_src.c_str (), compile_opts);
12931300
1294- CL_CHECK ((backend_ctx->kernel_div = clCreateKernel (backend_ctx->program_div , " kernel_div" , &err), err));
1295- CL_CHECK ((backend_ctx->kernel_div_row = clCreateKernel (backend_ctx->program_div , " kernel_div_row" , &err), err));
1301+ CL_CHECK ((backend_ctx->kernel_div = clCreateKernel (backend_ctx->program_div , " kernel_div" , &err), err));
1302+ CL_CHECK ((backend_ctx->kernel_div_row = clCreateKernel (backend_ctx->program_div , " kernel_div_row" , &err), err));
1303+ CL_CHECK ((backend_ctx->kernel_div_f16 = clCreateKernel (backend_ctx->program_div , " kernel_div_f16" , &err), err));
1304+ CL_CHECK ((backend_ctx->kernel_div_row_f16 = clCreateKernel (backend_ctx->program_div , " kernel_div_row_f16" , &err), err));
12961305 GGML_LOG_CONT (" ." );
12971306 }
12981307
@@ -1308,8 +1317,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
13081317 backend_ctx->program_sub =
13091318 build_program_from_source (backend_ctx->context , backend_ctx->device , kernel_src.c_str (), compile_opts);
13101319
1311- CL_CHECK ((backend_ctx->kernel_sub = clCreateKernel (backend_ctx->program_sub , " kernel_sub" , &err), err));
1312- CL_CHECK ((backend_ctx->kernel_sub_row = clCreateKernel (backend_ctx->program_sub , " kernel_sub_row" , &err), err));
1320+ CL_CHECK ((backend_ctx->kernel_sub = clCreateKernel (backend_ctx->program_sub , " kernel_sub" , &err), err));
1321+ CL_CHECK ((backend_ctx->kernel_sub_row = clCreateKernel (backend_ctx->program_sub , " kernel_sub_row" , &err), err));
1322+ CL_CHECK ((backend_ctx->kernel_sub_f16 = clCreateKernel (backend_ctx->program_sub , " kernel_sub_f16" , &err), err));
1323+ CL_CHECK ((backend_ctx->kernel_sub_row_f16 = clCreateKernel (backend_ctx->program_sub , " kernel_sub_row_f16" , &err), err));
13131324 GGML_LOG_CONT (" ." );
13141325 }
13151326
@@ -2447,12 +2458,15 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
24472458 default :
24482459 return false ;
24492460 }
2450- case GGML_OP_ADD:
24512461 case GGML_OP_SCALE:
2462+ return op->src [0 ]->type == GGML_TYPE_F32 && ggml_is_contiguous (op->src [0 ]);
2463+ case GGML_OP_ADD:
24522464 case GGML_OP_MUL:
24532465 case GGML_OP_DIV:
24542466 case GGML_OP_SUB:
2455- return op->src [0 ]->type == GGML_TYPE_F32;
2467+ return (op->src [0 ]->type == op->src [1 ]->type ) &&
2468+ (op->src [0 ]->type == op->type ) &&
2469+ (op->src [0 ]->type == GGML_TYPE_F32 || op->src [0 ]->type == GGML_TYPE_F16);
24562470 case GGML_OP_UNARY:
24572471 switch (ggml_get_unary_op (op)) {
24582472 case GGML_UNARY_OP_GELU:
@@ -3680,35 +3694,39 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
36803694 GGML_ASSERT (dst);
36813695 GGML_ASSERT (dst->extra );
36823696
3683- const int ne00 = src0 ? src0->ne [0 ] : 0 ;
3684- const int ne01 = src0 ? src0->ne [1 ] : 0 ;
3685- const int ne02 = src0 ? src0->ne [2 ] : 0 ;
3686- const int ne03 = src0 ? src0->ne [3 ] : 0 ;
3697+ GGML_ASSERT (src0->type == src1->type );
3698+ GGML_ASSERT (src0->type == dst->type );
3699+ GGML_ASSERT (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
36873700
3688- const cl_ulong nb00 = src0 ? src0-> nb [0 ] : 0 ;
3689- const cl_ulong nb01 = src0 ? src0-> nb [1 ] : 0 ;
3690- const cl_ulong nb02 = src0 ? src0-> nb [2 ] : 0 ;
3691- const cl_ulong nb03 = src0 ? src0-> nb [3 ] : 0 ;
3701+ const int ne00 = src0-> ne [0 ];
3702+ const int ne01 = src0-> ne [1 ];
3703+ const int ne02 = src0-> ne [2 ];
3704+ const int ne03 = src0-> ne [3 ];
36923705
3693- const int ne10 = src1 ? src1-> ne [0 ] : 0 ;
3694- const int ne11 = src1 ? src1-> ne [1 ] : 0 ;
3695- const int ne12 = src1 ? src1-> ne [2 ] : 0 ;
3696- const int ne13 = src1 ? src1-> ne [3 ] : 0 ; UNUSED (ne13) ;
3706+ const cl_ulong nb00 = src0-> nb [0 ];
3707+ const cl_ulong nb01 = src0-> nb [1 ];
3708+ const cl_ulong nb02 = src0-> nb [2 ];
3709+ const cl_ulong nb03 = src0-> nb [3 ];
36973710
3698- const cl_ulong nb10 = src1 ? src1-> nb [0 ] : 0 ;
3699- const cl_ulong nb11 = src1 ? src1-> nb [1 ] : 0 ;
3700- const cl_ulong nb12 = src1 ? src1-> nb [2 ] : 0 ;
3701- const cl_ulong nb13 = src1 ? src1-> nb [3 ] : 0 ; UNUSED (nb13 );
3711+ const int ne10 = src1-> ne [0 ];
3712+ const int ne11 = src1-> ne [1 ];
3713+ const int ne12 = src1-> ne [2 ];
3714+ const int ne13 = src1-> ne [3 ]; UNUSED (ne13 );
37023715
3703- const int ne0 = dst ? dst-> ne [0 ] : 0 ;
3704- const int ne1 = dst ? dst-> ne [1 ] : 0 ;
3705- const int ne2 = dst ? dst-> ne [2 ] : 0 ;
3706- const int ne3 = dst ? dst-> ne [3 ] : 0 ;
3716+ const cl_ulong nb10 = src1-> nb [0 ];
3717+ const cl_ulong nb11 = src1-> nb [1 ];
3718+ const cl_ulong nb12 = src1-> nb [2 ];
3719+ const cl_ulong nb13 = src1-> nb [3 ]; UNUSED (nb13) ;
37073720
3708- const cl_ulong nb0 = dst ? dst->nb [0 ] : 0 ;
3709- const cl_ulong nb1 = dst ? dst->nb [1 ] : 0 ;
3710- const cl_ulong nb2 = dst ? dst->nb [2 ] : 0 ;
3711- const cl_ulong nb3 = dst ? dst->nb [3 ] : 0 ;
3721+ const int ne0 = dst->ne [0 ];
3722+ const int ne1 = dst->ne [1 ];
3723+ const int ne2 = dst->ne [2 ];
3724+ const int ne3 = dst->ne [3 ];
3725+
3726+ const cl_ulong nb0 = dst->nb [0 ];
3727+ const cl_ulong nb1 = dst->nb [1 ];
3728+ const cl_ulong nb2 = dst->nb [2 ];
3729+ const cl_ulong nb3 = dst->nb [3 ];
37123730
37133731 ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context ;
37143732
@@ -3731,7 +3749,12 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
37313749
37323750 bcast_row = true ;
37333751 int ne = ne00 / 4 ;
3734- kernel = backend_ctx->kernel_add_row ;
3752+
3753+ if (src0->type == GGML_TYPE_F32) {
3754+ kernel = backend_ctx->kernel_add_row ;
3755+ } else {
3756+ kernel = backend_ctx->kernel_add_row_f16 ;
3757+ }
37353758
37363759 CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
37373760 CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
@@ -3741,7 +3764,11 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
37413764 CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
37423765 CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne));
37433766 } else {
3744- kernel = backend_ctx->kernel_add ;
3767+ if (src0->type == GGML_TYPE_F32) {
3768+ kernel = backend_ctx->kernel_add ;
3769+ } else {
3770+ kernel = backend_ctx->kernel_add_f16 ;
3771+ }
37453772
37463773 CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
37473774 CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
@@ -3803,35 +3830,39 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
38033830 GGML_ASSERT (dst);
38043831 GGML_ASSERT (dst->extra );
38053832
3806- const int ne00 = src0 ? src0->ne [0 ] : 0 ;
3807- const int ne01 = src0 ? src0->ne [1 ] : 0 ;
3808- const int ne02 = src0 ? src0->ne [2 ] : 0 ;
3809- const int ne03 = src0 ? src0->ne [3 ] : 0 ;
3833+ GGML_ASSERT (src0->type == src1->type );
3834+ GGML_ASSERT (src0->type == dst->type );
3835+ GGML_ASSERT (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
38103836
3811- const cl_ulong nb00 = src0 ? src0-> nb [0 ] : 0 ;
3812- const cl_ulong nb01 = src0 ? src0-> nb [1 ] : 0 ;
3813- const cl_ulong nb02 = src0 ? src0-> nb [2 ] : 0 ;
3814- const cl_ulong nb03 = src0 ? src0-> nb [3 ] : 0 ;
3837+ const int ne00 = src0-> ne [0 ];
3838+ const int ne01 = src0-> ne [1 ];
3839+ const int ne02 = src0-> ne [2 ];
3840+ const int ne03 = src0-> ne [3 ];
38153841
3816- const int ne10 = src1 ? src1-> ne [0 ] : 0 ;
3817- const int ne11 = src1 ? src1-> ne [1 ] : 0 ;
3818- const int ne12 = src1 ? src1-> ne [2 ] : 0 ;
3819- const int ne13 = src1 ? src1-> ne [3 ] : 0 ; UNUSED (ne13) ;
3842+ const cl_ulong nb00 = src0-> nb [0 ];
3843+ const cl_ulong nb01 = src0-> nb [1 ];
3844+ const cl_ulong nb02 = src0-> nb [2 ];
3845+ const cl_ulong nb03 = src0-> nb [3 ];
38203846
3821- const cl_ulong nb10 = src1 ? src1->nb [0 ] : 0 ;
3822- const cl_ulong nb11 = src1 ? src1->nb [1 ] : 0 ;
3823- const cl_ulong nb12 = src1 ? src1->nb [2 ] : 0 ;
3824- const cl_ulong nb13 = src1 ? src1->nb [3 ] : 0 ; UNUSED (nb13);
3847+ const int ne10 = src1->ne [0 ];
3848+ const int ne11 = src1->ne [1 ];
3849+ const int ne12 = src1->ne [2 ];
3850+ const int ne13 = src1->ne [3 ]; UNUSED (ne13);
3851+
3852+ const cl_ulong nb10 = src1->nb [0 ];
3853+ const cl_ulong nb11 = src1->nb [1 ];
3854+ const cl_ulong nb12 = src1->nb [2 ];
3855+ const cl_ulong nb13 = src1->nb [3 ]; UNUSED (nb13);
38253856
3826- const int ne0 = dst ? dst ->ne [0 ] : 0 ;
3827- const int ne1 = dst ? dst ->ne [1 ] : 0 ;
3828- const int ne2 = dst ? dst ->ne [2 ] : 0 ;
3829- const int ne3 = dst ? dst ->ne [3 ] : 0 ;
3857+ const int ne0 = dst->ne [0 ];
3858+ const int ne1 = dst->ne [1 ];
3859+ const int ne2 = dst->ne [2 ];
3860+ const int ne3 = dst->ne [3 ];
38303861
3831- const cl_ulong nb0 = dst ? dst ->nb [0 ] : 0 ;
3832- const cl_ulong nb1 = dst ? dst ->nb [1 ] : 0 ;
3833- const cl_ulong nb2 = dst ? dst ->nb [2 ] : 0 ;
3834- const cl_ulong nb3 = dst ? dst ->nb [3 ] : 0 ;
3862+ const cl_ulong nb0 = dst->nb [0 ];
3863+ const cl_ulong nb1 = dst->nb [1 ];
3864+ const cl_ulong nb2 = dst->nb [2 ];
3865+ const cl_ulong nb3 = dst->nb [3 ];
38353866
38363867 ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context ;
38373868
@@ -3854,7 +3885,12 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
38543885
38553886 bcast_row = true ;
38563887 int ne = ne00 / 4 ;
3857- kernel = backend_ctx->kernel_mul_row ;
3888+
3889+ if (src0->type == GGML_TYPE_F32) {
3890+ kernel = backend_ctx->kernel_mul_row ;
3891+ } else {
3892+ kernel = backend_ctx->kernel_mul_row_f16 ;
3893+ }
38583894
38593895 CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
38603896 CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
@@ -3864,7 +3900,11 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
38643900 CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
38653901 CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne));
38663902 } else {
3867- kernel = backend_ctx->kernel_mul ;
3903+ if (src0->type == GGML_TYPE_F32) {
3904+ kernel = backend_ctx->kernel_mul ;
3905+ } else {
3906+ kernel = backend_ctx->kernel_mul_f16 ;
3907+ }
38683908
38693909 CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
38703910 CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
@@ -3926,6 +3966,10 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
39263966 GGML_ASSERT (dst);
39273967 GGML_ASSERT (dst->extra );
39283968
3969+ GGML_ASSERT (src0->type == src1->type );
3970+ GGML_ASSERT (src0->type == dst->type );
3971+ GGML_ASSERT (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
3972+
39293973 const int ne00 = src0->ne [0 ];
39303974 const int ne01 = src0->ne [1 ];
39313975 const int ne02 = src0->ne [2 ];
@@ -3974,7 +4018,12 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
39744018
39754019 bcast_row = true ;
39764020 int ne = ne00 / 4 ;
3977- kernel = backend_ctx->kernel_div_row ;
4021+
4022+ if (src0->type == GGML_TYPE_F32) {
4023+ kernel = backend_ctx->kernel_div_row ;
4024+ } else {
4025+ kernel = backend_ctx->kernel_div_row_f16 ;
4026+ }
39784027
39794028 CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
39804029 CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
@@ -3984,7 +4033,11 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
39844033 CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
39854034 CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne));
39864035 } else {
3987- kernel = backend_ctx->kernel_div ;
4036+ if (src0->type == GGML_TYPE_F32) {
4037+ kernel = backend_ctx->kernel_div ;
4038+ } else {
4039+ kernel = backend_ctx->kernel_div_f16 ;
4040+ }
39884041
39894042 CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
39904043 CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
@@ -4034,6 +4087,10 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
40344087 GGML_ASSERT (dst);
40354088 GGML_ASSERT (dst->extra );
40364089
4090+ GGML_ASSERT (src0->type == src1->type );
4091+ GGML_ASSERT (src0->type == dst->type );
4092+ GGML_ASSERT (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
4093+
40374094 const int ne00 = src0->ne [0 ];
40384095 const int ne01 = src0->ne [1 ];
40394096 const int ne02 = src0->ne [2 ];
@@ -4082,7 +4139,12 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
40824139
40834140 bcast_row = true ;
40844141 int ne = ne00 / 4 ;
4085- kernel = backend_ctx->kernel_sub_row ;
4142+
4143+ if (src0->type == GGML_TYPE_F32) {
4144+ kernel = backend_ctx->kernel_sub_row ;
4145+ } else {
4146+ kernel = backend_ctx->kernel_sub_row_f16 ;
4147+ }
40864148
40874149 CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
40884150 CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
@@ -4092,7 +4154,11 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
40924154 CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
40934155 CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne));
40944156 } else {
4095- kernel = backend_ctx->kernel_sub ;
4157+ if (src0->type == GGML_TYPE_F32) {
4158+ kernel = backend_ctx->kernel_sub ;
4159+ } else {
4160+ kernel = backend_ctx->kernel_sub_f16 ;
4161+ }
40964162
40974163 CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
40984164 CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
0 commit comments