Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 120 additions & 71 deletions ggml/src/ggml-opencl/ggml-opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2480,6 +2480,13 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_OP_SCALE:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
case GGML_OP_ADD:
if (op->type == GGML_TYPE_F16) {
const bool src0_ok = op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32;
const bool src1_ok = op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32;
if (src0_ok && src1_ok) {
return true;
}
}
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_SUB:
Expand Down Expand Up @@ -3718,34 +3725,30 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);

GGML_ASSERT(src0->type == src1->type);
GGML_ASSERT(src0->type == dst->type);
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);

const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];

const cl_ulong nb00 = src0->nb[0];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];

const int ne10 = src1->ne[0];
const int ne11 = src1->ne[1];
const int ne12 = src1->ne[2];
const int ne13 = src1->ne[3]; UNUSED(ne13);
const int ne10 = src1->ne[0];
const int ne11 = src1->ne[1];
const int ne12 = src1->ne[2];
const int ne13 = src1->ne[3];

const cl_ulong nb10 = src1->nb[0];
const cl_ulong nb11 = src1->nb[1];
const cl_ulong nb12 = src1->nb[2];
const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);
const cl_ulong nb13 = src1->nb[3];

const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
const int ne2 = dst->ne[2];
const int ne3 = dst->ne[3];
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
const int ne2 = dst->ne[2];
const int ne3 = dst->ne[3];

const cl_ulong nb0 = dst->nb[0];
const cl_ulong nb1 = dst->nb[1];
Expand All @@ -3762,68 +3765,114 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;

bool bcast_row = false;
cl_kernel kernel;

if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(src0));
const bool bcast_row = ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0;

// src1 is a row
if (bcast_row) {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ne11 == 1);
}

bcast_row = true;
int ne = ne00 / 4;

if (src0->type == GGML_TYPE_F32) {
if (dst->type == GGML_TYPE_F32) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32);
if (bcast_row) {
kernel = backend_ctx->kernel_add_row;
const int ne = ne00 / 4;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
} else {
kernel = backend_ctx->kernel_add_row_f16;
}

CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
} else {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_add;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
}
} else if (dst->type == GGML_TYPE_F16) {
GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
const int type_src0 = (src0->type == GGML_TYPE_F32);
const int type_src1 = (src1->type == GGML_TYPE_F32);
if (bcast_row) {
kernel = backend_ctx->kernel_add_row_f16;
const int ne = ne00 / 4;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &type_src0));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &type_src1));
} else {
kernel = backend_ctx->kernel_add_f16;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &type_src0));
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(int), &type_src1));
}

CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
} else {
GGML_ASSERT(false && "unsupported data types for add");
}

if (bcast_row) {
Expand All @@ -3833,13 +3882,13 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const

size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
local_work_size_ptr = nullptr;
}

backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size_ptr, dst);
} else {
unsigned int nth = MIN(64, ne0);
size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {nth, 1, 1};

backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
Expand Down
50 changes: 42 additions & 8 deletions ggml/src/ggml-opencl/kernels/add.cl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ kernel void kernel_add_f16(
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
ulong nb3,
int type_src0,
int type_src1
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
Expand All @@ -132,25 +134,57 @@ kernel void kernel_add_f16(

for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const int i10 = i0 % ne10;
*((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) + *((global half *)(src1_ptr + i10*nb10));

half v0, v1;
if (type_src0 == 1) {
v0 = convert_half(*((global float *)(src0_ptr + i0*nb00)));
} else {
v0 = *((global half *)(src0_ptr + i0*nb00));
}

if (type_src1 == 1) {
v1 = convert_half(*((global float *)(src1_ptr + i10*nb10)));
} else {
v1 = *((global half *)(src1_ptr + i10*nb10));
}

*((global half *)(dst_ptr + i0*nb0)) = v0 + v1;
}
}

kernel void kernel_add_row_f16(
global half4 * src0,
global char * src0,
ulong offset0,
global half4 * src1,
global char * src1,
ulong offset1,
global half4 * dst,
ulong offsetd,
int ne
int ne,
int type_src0,
int type_src1
) {
src0 = (global half4*)((global char*)src0 + offset0);
src1 = (global half4*)((global char*)src1 + offset1);
dst = (global half4*)((global char*)dst + offsetd);

// This performs better than using %.
uint gid = get_global_id(0);
uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
dst[gid] = src0[gid] + src1[idx1];

half4 v0, v1;
if (type_src0 == 1) {
global float4* src0_f32 = (global float4*)((global char*)src0 + offset0);
v0 = convert_half4(src0_f32[gid]);
} else {
global half4* src0_f16 = (global half4*)((global char*)src0 + offset0);
v0 = src0_f16[gid];
}

if (type_src1 == 1) {
global float4* src1_f32 = (global float4*)((global char*)src1 + offset1);
v1 = convert_half4(src1_f32[idx1]);
} else {
global half4* src1_f16 = (global half4*)((global char*)src1 + offset1);
v1 = src1_f16[idx1];
}

dst[gid] = v0 + v1;
}
Loading