@@ -2889,10 +2889,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
28892889 case GGML_OP_REPEAT:
28902890 return op->src [0 ]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
28912891 case GGML_OP_PAD:
2892- return op->src [0 ]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
2893- op->src [0 ]->ne [3 ] == 1 && op->ne [3 ] == 1 &&
2894- (ggml_get_op_params_i32 (op, 0 ) == 0 ) && (ggml_get_op_params_i32 (op, 2 ) == 0 ) &&
2895- (ggml_get_op_params_i32 (op, 4 ) == 0 ) && (ggml_get_op_params_i32 (op, 6 ) == 0 );
2892+ return op->src [0 ]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
28962893 case GGML_OP_UPSCALE:
28972894 return op->src [0 ]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
28982895 case GGML_OP_CONV_2D:
@@ -4222,15 +4219,19 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
42224219 GGML_ASSERT (dst);
42234220 GGML_ASSERT (dst->extra );
42244221
4225- const int ne00 = src0 ? src0->ne [0 ] : 0 ;
4226- const cl_ulong nb01 = src0 ? src0->nb [1 ] : 0 ;
4227- const cl_ulong nb02 = src0 ? src0->nb [2 ] : 0 ;
4228- const int ne10 = src1 ? src1->ne [0 ] : 0 ;
4229- const cl_ulong nb10 = src1 ? src1->nb [0 ] : 0 ;
4230- const int ne11 = src1 ? src1->ne [1 ] : 0 ;
4231- const cl_ulong nb11 = src1 ? src1->nb [1 ] : 0 ;
4232- const cl_ulong nb1 = dst ? dst->nb [1 ] : 0 ;
4233- const cl_ulong nb2 = dst ? dst->nb [2 ] : 0 ;
4222+ const int ne00 = src0->ne [0 ];
4223+ const cl_ulong nb01 = src0->nb [1 ];
4224+ const cl_ulong nb02 = src0->nb [2 ];
4225+ const cl_ulong nb03 = src0->nb [3 ];
4226+ const int ne10 = src1->ne [0 ];
4227+ const cl_ulong nb10 = src1->nb [0 ];
4228+ const int ne11 = src1->ne [1 ];
4229+ const int ne12 = src1->ne [2 ];
4230+ const cl_ulong nb11 = src1->nb [1 ];
4231+ const cl_ulong nb12 = src1->nb [2 ];
4232+ const cl_ulong nb1 = dst->nb [1 ];
4233+ const cl_ulong nb2 = dst->nb [2 ];
4234+ const cl_ulong nb3 = dst->nb [3 ];
42344235
42354236 ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context ;
42364237
@@ -4267,14 +4268,17 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
42674268 CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &ne00));
42684269 CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (cl_ulong), &nb01));
42694270 CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &nb02));
4270- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne10));
4271- CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb10));
4272- CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb11));
4273- CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb1));
4274- CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb2));
4275-
4276- size_t global_work_size[] = {(size_t )ne10, (size_t )ne11, 1 };
4277- size_t local_work_size[] = {1 , 1 , 1 };
4271+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &nb03));
4272+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ), &ne10));
4273+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb10));
4274+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb11));
4275+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb12));
4276+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (cl_ulong), &nb1));
4277+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (cl_ulong), &nb2));
4278+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (cl_ulong), &nb3));
4279+
4280+ size_t global_work_size[] = {(size_t )ne10*64 , (size_t )ne11, (size_t )ne12};
4281+ size_t local_work_size[] = {64 , 1 , 1 };
42784282
42794283 backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size, dst);
42804284}
@@ -5874,7 +5878,6 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
58745878 GGML_ASSERT (dst->extra );
58755879 GGML_ASSERT (src0->type == GGML_TYPE_F32);
58765880 GGML_ASSERT (dst->type == GGML_TYPE_F32);
5877- GGML_ASSERT (src0->ne [3 ] == 1 && dst->ne [3 ] == 1 );
58785881
58795882 ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context ;
58805883
@@ -5892,28 +5895,67 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
58925895 const int s_ne0 = src0->ne [0 ];
58935896 const int s_ne1 = src0->ne [1 ];
58945897 const int s_ne2 = src0->ne [2 ];
5898+ const int s_ne3 = src0->ne [3 ];
5899+
5900+ const cl_ulong s_nb0 = src0->nb [0 ]; // byte strides are bound with sizeof (cl_ulong), so the host variable must be 8 bytes wide
5901+ const cl_ulong s_nb1 = src0->nb [1 ];
5902+ const cl_ulong s_nb2 = src0->nb [2 ];
5903+ const cl_ulong s_nb3 = src0->nb [3 ];
58955904
58965905 const int d_ne0 = dst->ne [0 ];
58975906 const int d_ne1 = dst->ne [1 ];
58985907 const int d_ne2 = dst->ne [2 ];
5908+ const int d_ne3 = dst->ne [3 ];
5909+
5910+ const cl_ulong d_nb0 = dst->nb [0 ]; // same as the source strides: passed to the kernel as cl_ulong
5911+ const cl_ulong d_nb1 = dst->nb [1 ];
5912+ const cl_ulong d_nb2 = dst->nb [2 ];
5913+ const cl_ulong d_nb3 = dst->nb [3 ];
5914+
5915+ const int lp0 = ((const int *)(dst->op_params ))[0 ];
5916+ const int rp0 = ((const int *)(dst->op_params ))[1 ];
5917+ const int lp1 = ((const int *)(dst->op_params ))[2 ];
5918+ const int rp1 = ((const int *)(dst->op_params ))[3 ];
5919+ const int lp2 = ((const int *)(dst->op_params ))[4 ];
5920+ const int rp2 = ((const int *)(dst->op_params ))[5 ];
5921+ const int lp3 = ((const int *)(dst->op_params ))[6 ];
5922+ const int rp3 = ((const int *)(dst->op_params ))[7 ];
58995923
59005924 cl_kernel kernel = backend_ctx->kernel_pad ;
59015925
5902- CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra_src0->data_device ));
5903- CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &off_src0));
5904- CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra_dst->data_device ));
5905- CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &off_dst));
5906- CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &s_ne0));
5907- CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (int ), &s_ne1));
5908- CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &s_ne2));
5909- CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &d_ne0));
5910- CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &d_ne1));
5911- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &d_ne2));
5926+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra_src0->data_device ));
5927+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &off_src0));
5928+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra_dst->data_device ));
5929+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &off_dst));
5930+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &s_ne0));
5931+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (int ), &s_ne1));
5932+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &s_ne2));
5933+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &s_ne3));
5934+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &s_nb0));
5935+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &s_nb1));
5936+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &s_nb2));
5937+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &s_nb3));
5938+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (int ), &d_ne0));
5939+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ), &d_ne1));
5940+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &d_ne2));
5941+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (int ), &d_ne3));
5942+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (cl_ulong), &d_nb0));
5943+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (cl_ulong), &d_nb1));
5944+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (cl_ulong), &d_nb2));
5945+ CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (cl_ulong), &d_nb3));
5946+ CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (int ), &lp0));
5947+ CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (int ), &rp0));
5948+ CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (int ), &lp1));
5949+ CL_CHECK (clSetKernelArg (kernel, 23 , sizeof (int ), &rp1));
5950+ CL_CHECK (clSetKernelArg (kernel, 24 , sizeof (int ), &lp2));
5951+ CL_CHECK (clSetKernelArg (kernel, 25 , sizeof (int ), &rp2));
5952+ CL_CHECK (clSetKernelArg (kernel, 26 , sizeof (int ), &lp3));
5953+ CL_CHECK (clSetKernelArg (kernel, 27 , sizeof (int ), &rp3));
59125954
59135955 size_t lws0 = 64 ;
59145956 size_t gws0 = (( (size_t )d_ne0 + lws0 - 1 ) / lws0) * lws0;
59155957
5916- size_t global_work_size[] = { gws0, (size_t )d_ne1, (size_t )d_ne2 };
5958+ size_t global_work_size[] = { gws0, (size_t )d_ne1, (size_t )d_ne2*d_ne3 };
59175959 size_t local_work_size[] = { lws0, 1 , 1 };
59185960
59195961 size_t * local_work_size_ptr = local_work_size;
0 commit comments