Skip to content

Commit 02721cc

Browse files
committed
opencl: add add_id
1 parent 7fbd3b4 commit 02721cc

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ endfunction()
5555

5656
set(GGML_OPENCL_KERNELS
5757
add
58+
add_id
5859
argsort
5960
clamp
6061
cpy

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ struct ggml_backend_opencl_context {
345345
cl_command_queue queue;
346346

347347
cl_program program_add;
348+
cl_program program_add_id;
348349
cl_program program_clamp;
349350
cl_program program_cpy;
350351
cl_program program_cvt;
@@ -404,6 +405,7 @@ struct ggml_backend_opencl_context {
404405
cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
405406
cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16;
406407
cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16;
408+
cl_kernel kernel_add_id;
407409
cl_kernel kernel_scale;
408410
cl_kernel kernel_silu, kernel_silu_4;
409411
cl_kernel kernel_gelu, kernel_gelu_4;
@@ -681,6 +683,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
681683
GGML_LOG_CONT(".");
682684
}
683685

686+
// add_id
687+
{
688+
#ifdef GGML_OPENCL_EMBED_KERNELS
689+
const std::string kernel_src {
690+
#include "add_id.cl.h"
691+
};
692+
#else
693+
const std::string kernel_src = read_file("add_id.cl");
694+
#endif
695+
backend_ctx->program_add_id =
696+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
697+
698+
CL_CHECK((backend_ctx->kernel_add_id = clCreateKernel(backend_ctx->program_add_id, "kernel_add_id", &err), err));
699+
GGML_LOG_CONT(".");
700+
}
701+
684702
// clamp
685703
{
686704
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2468,6 +2486,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
24682486
return (op->src[0]->type == op->src[1]->type) &&
24692487
(op->src[0]->type == op->type) &&
24702488
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
2489+
case GGML_OP_ADD_ID:
2490+
return op->src[0]->type == GGML_TYPE_F32;
24712491
case GGML_OP_UNARY:
24722492
switch (ggml_get_unary_op(op)) {
24732493
case GGML_UNARY_OP_GELU:
@@ -3826,6 +3846,75 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
38263846
}
38273847
}
38283848

3849+
static void ggml_cl_add_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3850+
GGML_ASSERT(src0);
3851+
GGML_ASSERT(src0->extra);
3852+
GGML_ASSERT(src1);
3853+
GGML_ASSERT(src1->extra);
3854+
GGML_ASSERT(dst);
3855+
GGML_ASSERT(dst->extra);
3856+
3857+
const ggml_tensor * src2 = dst->src[2];
3858+
GGML_ASSERT(src2);
3859+
GGML_ASSERT(src2->extra);
3860+
3861+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
3862+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
3863+
GGML_ASSERT(src2->type == GGML_TYPE_I32);
3864+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
3865+
3866+
GGML_ASSERT(ggml_is_contiguous_rows(src0));
3867+
3868+
const int ne00 = src0->ne[0];
3869+
const int ne01 = src0->ne[1];
3870+
const int ne02 = src0->ne[2];
3871+
3872+
const cl_ulong nb01 = src0->nb[1];
3873+
const cl_ulong nb02 = src0->nb[2];
3874+
3875+
const cl_ulong nb11 = src1->nb[1];
3876+
3877+
const cl_ulong nb21 = src2->nb[1];
3878+
3879+
const int ne0 = dst->ne[0];
3880+
const int ne1 = dst->ne[1];
3881+
3882+
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
3883+
3884+
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
3885+
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
3886+
ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
3887+
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
3888+
3889+
cl_ulong offset0 = extra0->offset + src0->view_offs;
3890+
cl_ulong offset1 = extra1->offset + src1->view_offs;
3891+
cl_ulong offset2 = extra2->offset + src2->view_offs;
3892+
cl_ulong offsetd = extrad->offset + dst->view_offs;
3893+
3894+
cl_kernel kernel = backend_ctx->kernel_add_id;
3895+
3896+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
3897+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
3898+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
3899+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
3900+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
3901+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
3902+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
3903+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
3904+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
3905+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
3906+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11));
3907+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb21));
3908+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0));
3909+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1));
3910+
3911+
int nth = MIN(ne00, (int) backend_ctx->get_kernel_workgroup_size(kernel));
3912+
size_t global_work_size[] = { (size_t)ne01*nth, (size_t)ne02, 1 };
3913+
size_t local_work_size[] = { (size_t)nth, 1, 1 };
3914+
3915+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
3916+
}
3917+
38293918
static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
38303919
GGML_ASSERT(src0);
38313920
GGML_ASSERT(src0->extra);
@@ -7126,6 +7215,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
71267215
}
71277216
func = ggml_cl_add;
71287217
break;
7218+
case GGML_OP_ADD_ID:
7219+
if (!any_on_device) {
7220+
return false;
7221+
}
7222+
func = ggml_cl_add_id;
7223+
break;
71297224
case GGML_OP_MUL:
71307225
if (!any_on_device) {
71317226
return false;

0 commit comments

Comments
 (0)