@@ -299,6 +299,8 @@ struct ggml_backend_opencl_context {
     cl_program program_mul_mv_f16_f32;
     cl_program program_mul_mv_f32_f32;
     cl_program program_mul;
+    cl_program program_div;
+    cl_program program_sub;
     cl_program program_norm;
     cl_program program_relu;
     cl_program program_rms_norm;
@@ -315,6 +317,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_add, kernel_add_row;
     cl_kernel kernel_mul, kernel_mul_row;
     cl_kernel kernel_div, kernel_div_row;
+    cl_kernel kernel_sub, kernel_sub_row;
     cl_kernel kernel_scale;
     cl_kernel kernel_silu, kernel_silu_4;
     cl_kernel kernel_gelu, kernel_gelu_4;
@@ -1016,11 +1019,28 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
 #else
         const std::string kernel_src = read_file("div.cl");
 #endif
-        backend_ctx->program_mul =
+        backend_ctx->program_div =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_div     = clCreateKernel(backend_ctx->program_div, "kernel_div",     &err), err));
+        CL_CHECK((backend_ctx->kernel_div_row = clCreateKernel(backend_ctx->program_div, "kernel_div_row", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // sub
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "sub.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("sub.cl");
+#endif
+        backend_ctx->program_sub =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_div     = clCreateKernel(backend_ctx->program_mul, "kernel_div",     &err), err));
-        CL_CHECK((backend_ctx->kernel_div_row = clCreateKernel(backend_ctx->program_mul, "kernel_div_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_sub     = clCreateKernel(backend_ctx->program_sub, "kernel_sub",     &err), err));
+        CL_CHECK((backend_ctx->kernel_sub_row = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row", &err), err));
         GGML_LOG_CONT(".");
     }
 
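Note on the hunk above: besides adding the "sub" program, it fixes a copy-paste bug in the existing "div" block, where kernel_div and kernel_div_row were created from program_mul; they are now created from the dedicated program_div. The #include "sub.cl.h" inside a braced initializer is ggml's kernel-embedding idiom: it only compiles if the generated header expands to a single string literal. A minimal self-contained C++ illustration of the idiom (the real sub.cl.h is produced by the build system and is not part of this commit; the kernel body here is a placeholder):

    #include <iostream>
    #include <string>

    // Assumed shape of a generated "sub.cl.h": the entire .cl file wrapped in
    // one raw string literal, so that
    //     const std::string kernel_src { /* #include "sub.cl.h" */ };
    // initializes the string with the kernel source at compile time.
    static const std::string kernel_src {
        R"(
        kernel void kernel_placeholder(global float * a) {
            a[get_global_id(0)] += 1.0f;
        }
        )"
    };

    int main() {
        std::cout << kernel_src << "\n"; // embedded OpenCL source, built at runtime via clBuildProgram
    }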
@@ -1911,6 +1931,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         case GGML_OP_SCALE:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+        case GGML_OP_SUB:
            return op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(op)) {
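With GGML_OP_SUB added to the F32 whitelist above, the scheduler may now place element-wise subtraction on this backend instead of falling back to the CPU. A hypothetical caller-side snippet using the standard ggml API (not part of this commit; ctx is an initialized ggml_context):

    // Element-wise subtraction, now eligible for OpenCL offload (F32 only).
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 32);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 32);
    struct ggml_tensor * d = ggml_sub(ctx, a, b); // dispatched to ggml_cl_sub at graph compute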
@@ -3422,6 +3443,131 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
     }
 }
 
+static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const cl_ulong nb10 = src1->nb[0];
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3];
+
+    const int ne0 = dst->ne[0];
+
+    const cl_ulong nb0 = dst->nb[0];
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    bool bcast_row = false;
+    cl_kernel kernel;
+
+    if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+        GGML_ASSERT(ggml_is_contiguous(src0));
+
+        // src1 is a row
+        GGML_ASSERT(ne11 == 1);
+
+        bcast_row = true;
+        int ne = ne00 / 4;
+        kernel = backend_ctx->kernel_sub_row;
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));
+    } else {
+        kernel = backend_ctx->kernel_sub;
+
+        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb00));
+        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));
+        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));
+        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));
+        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne10));
+        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne11));
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne12));
+        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne13));
+        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10));
+        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11));
+        CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12));
+        CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13));
+        CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne0));
+        CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0));
+        CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1));
+        CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
+        CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
+    }
+
+    if (bcast_row) {
+        int n = ggml_nelements(dst)/4;
+        size_t global_work_size[] = {(size_t)n, 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    } else {
+        unsigned int nth = MIN(64, ne0);
+        size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};
+        size_t local_work_size[] = {nth, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    }
+}
+
 static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -5331,6 +5477,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
            }
            func = ggml_cl_div;
            break;
+        case GGML_OP_SUB:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_sub;
+            break;
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(tensor)) {
                case GGML_UNARY_OP_GELU:
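The new sub.cl kernel source is referenced but not shown in this commit. From the argument lists set in ggml_cl_sub above (and the pattern of the backend's existing add/div kernels), the two entry points plausibly look like the sketch below; parameter names and bodies are assumptions, not the committed file. The "_row" variant is the fast path taken when src1 is one contiguous row and both row lengths are divisible by 4, so it works on float4 and broadcasts src1 with a modulo on the flattened index; the general kernel launches one work-group per (i01, i02, i03) row and reduces src1 indices modulo its extents for broadcasting, matching the {ne01*nth, ne02, ne03} dispatch above.

    // Sketch of sub.cl (assumed, modeled on the backend's add/div kernels).

    kernel void kernel_sub_row(
            global float4 * src0, ulong offset0,
            global float4 * src1, ulong offset1,
            global float4 * dst,  ulong offsetd,
            int ne   // src1 row length / 4
    ) {
        src0 = (global float4*)((global char*)src0 + offset0);
        src1 = (global float4*)((global char*)src1 + offset1);
        dst  = (global float4*)((global char*)dst  + offsetd);

        uint gid = get_global_id(0);
        dst[gid] = src0[gid] - src1[gid % ne]; // broadcast the single src1 row
    }

    kernel void kernel_sub(
            global char * src0, ulong offset0,
            global char * src1, ulong offset1,
            global char * dst,  ulong offsetd,
            ulong nb00, ulong nb01, ulong nb02, ulong nb03,
            int   ne10, int   ne11, int   ne12, int   ne13,
            ulong nb10, ulong nb11, ulong nb12, ulong nb13,
            int   ne0,
            ulong nb0,  ulong nb1,  ulong nb2,  ulong nb3
    ) {
        src0 = src0 + offset0;
        src1 = src1 + offset1;
        dst  = dst  + offsetd;

        // one work-group per output row
        int i03 = get_group_id(2);
        int i02 = get_group_id(1);
        int i01 = get_group_id(0);

        // broadcast src1 over dims where its extent is smaller
        int i13 = i03 % ne13;
        int i12 = i02 % ne12;
        int i11 = i01 % ne11;

        global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
        global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
        global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;

        // work-items stride across the row
        for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
            const int i10 = i0 % ne10;
            *((global float *)(dst_ptr + i0*nb0)) =
                *((global float *)(src0_ptr + i0*nb00)) - *((global float *)(src1_ptr + i10*nb10));
        }
    }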