Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 51 additions & 16 deletions ggml/src/ggml-opencl/ggml-opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2700,10 +2700,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_OP_REPEAT:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
case GGML_OP_PAD:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
op->src[0]->ne[3] == 1 && op->ne[3] == 1 &&
(ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D:
Expand Down Expand Up @@ -5423,7 +5420,6 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
GGML_ASSERT(dst->extra);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1);

ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

Expand All @@ -5441,28 +5437,67 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
const int s_ne0 = src0->ne[0];
const int s_ne1 = src0->ne[1];
const int s_ne2 = src0->ne[2];
const int s_ne3 = src0->ne[3];

const int s_nb0 = src0->nb[0];
const int s_nb1 = src0->nb[1];
const int s_nb2 = src0->nb[2];
const int s_nb3 = src0->nb[3];

const int d_ne0 = dst->ne[0];
const int d_ne1 = dst->ne[1];
const int d_ne2 = dst->ne[2];
const int d_ne3 = dst->ne[3];

const int d_nb0 = dst->nb[0];
const int d_nb1 = dst->nb[1];
const int d_nb2 = dst->nb[2];
const int d_nb3 = dst->nb[3];

const int lp0 = ((const int*)(dst->op_params))[0];
const int rp0 = ((const int*)(dst->op_params))[1];
const int lp1 = ((const int*)(dst->op_params))[2];
const int rp1 = ((const int*)(dst->op_params))[3];
const int lp2 = ((const int*)(dst->op_params))[4];
const int rp2 = ((const int*)(dst->op_params))[5];
const int lp3 = ((const int*)(dst->op_params))[6];
const int rp3 = ((const int*)(dst->op_params))[7];

cl_kernel kernel = backend_ctx->kernel_pad;

CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne0));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne1));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne2));
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &s_ne3));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &s_nb0));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &s_nb1));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &s_nb2));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &s_nb3));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &d_ne3));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &d_nb0));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &d_nb1));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &d_nb2));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &d_nb3));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &lp0));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &rp0));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &lp1));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &rp1));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &lp2));
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &rp2));
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &lp3));
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(int), &rp3));

size_t lws0 = 64;
size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0;

size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2 };
size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2*d_ne3 };
size_t local_work_size[] = { lws0, 1, 1 };

size_t * local_work_size_ptr = local_work_size;
Expand Down
49 changes: 29 additions & 20 deletions ggml/src/ggml-opencl/kernels/pad.cl
Original file line number Diff line number Diff line change
@@ -1,30 +1,39 @@
kernel void kernel_pad(
global const void * src0_ptr,
ulong src0_offset,
global void * dst_ptr,
ulong dst_offset,
int s_ne0, int s_ne1, int s_ne2,
int d_ne0, int d_ne1, int d_ne2
global void * src0,
ulong offset0,
global void * dst,
ulong offsetd,
int ne00, int ne01, int ne02, int ne03,
ulong nb00, ulong nb01, ulong nb02, ulong nb03,
int ne0, int ne1, int ne2, int ne3,
ulong nb0, ulong nb1, ulong nb2, ulong nb3,
int lp0, int rp0,
int lp1, int rp1,
int lp2, int rp2,
int lp3, int rp3
) {
global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset);
global float * dst = (global float *)((global char *)dst_ptr + dst_offset);
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);

int nidx = get_global_id(0);
int idx_d1 = get_group_id(1);
int idx_d2 = get_group_id(2);
int i0 = get_global_id(0);
int i1 = get_group_id(1);
int i2 = get_group_id(2) % ne2;
int i3 = get_group_id(2) / ne2;

if (nidx >= d_ne0) {
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}

int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1;
uint src0_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
uint dst_idx = i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;

bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2);
global float * src0_ptr = (global float *)((global char *)src0 + src0_idx);
global float * dst_ptr = (global float *)((global char *)dst + dst_idx);

if (in_src_bounds) {
int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * s_ne1;
dst[dst_el_offset] = src0[src_el_offset];
} else {
dst[dst_el_offset] = 0.0f;
}
bool in_src_bounds = (i0 >= lp0 && i0 < ne0 - rp0) &&
(i1 >= lp1 && i1 < ne1 - rp1) &&
(i2 >= lp2 && i2 < ne2 - rp2) &&
(i3 >= lp3 && i3 < ne3 - rp3);

*dst_ptr = in_src_bounds ? *src0_ptr : 0.0f;
}
Loading