From 6e97cb3d5faa3f0dd3396f737292033391bf0d29 Mon Sep 17 00:00:00 2001 From: Li He Date: Mon, 8 Sep 2025 19:36:38 -0700 Subject: [PATCH] opencl: support pad_ext --- ggml/src/ggml-opencl/ggml-opencl.cpp | 67 +++++++++++++++++++++------- ggml/src/ggml-opencl/kernels/pad.cl | 49 +++++++++++--------- 2 files changed, 80 insertions(+), 36 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 727163b7fdf95..77bfc4c49f58d 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2700,10 +2700,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_REPEAT: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded case GGML_OP_PAD: - return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && - op->src[0]->ne[3] == 1 && op->ne[3] == 1 && - (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && - (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_CONV_2D: @@ -5423,7 +5420,6 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t GGML_ASSERT(dst->extra); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -5441,28 +5437,67 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t const int s_ne0 = src0->ne[0]; const int s_ne1 = src0->ne[1]; const int s_ne2 = src0->ne[2]; + const int s_ne3 = src0->ne[3]; + + const int s_nb0 = src0->nb[0]; + const int s_nb1 = src0->nb[1]; + const int s_nb2 = src0->nb[2]; + const int s_nb3 = src0->nb[3]; const int d_ne0 = dst->ne[0]; const int d_ne1 = dst->ne[1]; const int d_ne2 = dst->ne[2]; + const int d_ne3 = dst->ne[3]; + + const int d_nb0 = dst->nb[0]; + const int d_nb1 = dst->nb[1]; + const int d_nb2 = dst->nb[2]; + const int d_nb3 = dst->nb[3]; + + const int lp0 = ((const int*)(dst->op_params))[0]; + const int rp0 = ((const int*)(dst->op_params))[1]; + const int lp1 = ((const int*)(dst->op_params))[2]; + const int rp1 = ((const int*)(dst->op_params))[3]; + const int lp2 = ((const int*)(dst->op_params))[4]; + const int rp2 = ((const int*)(dst->op_params))[5]; + const int lp3 = ((const int*)(dst->op_params))[6]; + const int rp3 = ((const int*)(dst->op_params))[7]; cl_kernel kernel = backend_ctx->kernel_pad; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne0)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne1)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne2)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &s_ne3)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &s_nb0)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &s_nb1)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &s_nb2)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &s_nb3)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &d_ne3)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &d_nb0)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &d_nb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &d_nb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &d_nb3)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &lp0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &rp0)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &lp1)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &rp1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &lp2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &rp2)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &lp3)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(int), &rp3)); size_t lws0 = 64; size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0; - size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2 }; + size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2*d_ne3 }; size_t local_work_size[] = { lws0, 1, 1 }; size_t * local_work_size_ptr = local_work_size; diff --git a/ggml/src/ggml-opencl/kernels/pad.cl b/ggml/src/ggml-opencl/kernels/pad.cl index 747fa7febcc74..31fb7ccd3b081 100644 --- a/ggml/src/ggml-opencl/kernels/pad.cl +++ b/ggml/src/ggml-opencl/kernels/pad.cl @@ -1,30 +1,39 @@ kernel void kernel_pad( - global const void * src0_ptr, - ulong src0_offset, - global void * dst_ptr, - ulong dst_offset, - int s_ne0, int s_ne1, int s_ne2, - int d_ne0, int d_ne1, int d_ne2 + global void * src0, + ulong offset0, + global void * dst, + ulong offsetd, + int ne00, int ne01, int ne02, int ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, + int ne0, int ne1, int ne2, int ne3, + ulong nb0, ulong nb1, ulong nb2, ulong nb3, + int lp0, int rp0, + int lp1, int rp1, + int lp2, int rp2, + int lp3, int rp3 ) { - global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset); - global float * dst = (global float *)((global char *)dst_ptr + dst_offset); + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); - int nidx = get_global_id(0); - int idx_d1 = get_group_id(1); - int idx_d2 = get_group_id(2); + int i0 = get_global_id(0); + int i1 = get_group_id(1); + int i2 = get_group_id(2) % ne2; + int i3 = get_group_id(2) / ne2; - if (nidx >= d_ne0) { + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { return; } - int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1; + uint src0_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00; + uint dst_idx = i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; - bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2); + global float * src0_ptr = (global float *)((global char *)src0 + src0_idx); + global float * dst_ptr = (global float *)((global char *)dst + dst_idx); - if (in_src_bounds) { - int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * s_ne1; - dst[dst_el_offset] = src0[src_el_offset]; - } else { - dst[dst_el_offset] = 0.0f; - } + bool in_src_bounds = (i0 >= lp0 && i0 < ne0 - rp0) && + (i1 >= lp1 && i1 < ne1 - rp1) && + (i2 >= lp2 && i2 < ne2 - rp2) && + (i3 >= lp3 && i3 < ne3 - rp3); + + *dst_ptr = in_src_bounds ? *src0_ptr : 0.0f; }