Skip to content

Commit e725a1a

Browse files
authored
opencl: add swiglu_oai and add_id (#15121)
* opencl: add `swiglu-oai` * opencl: add `add_id` * opencl: add missing `add_id.cl`
1 parent 3db4da5 commit e725a1a

File tree

4 files changed

+194
-2
lines changed

4 files changed

+194
-2
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ endfunction()
5555

5656
set(GGML_OPENCL_KERNELS
5757
add
58+
add_id
5859
argsort
5960
clamp
6061
cpy

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ struct ggml_backend_opencl_context {
345345
cl_command_queue queue;
346346

347347
cl_program program_add;
348+
cl_program program_add_id;
348349
cl_program program_clamp;
349350
cl_program program_cpy;
350351
cl_program program_cvt;
@@ -404,6 +405,7 @@ struct ggml_backend_opencl_context {
404405
cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
405406
cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16;
406407
cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16;
408+
cl_kernel kernel_add_id;
407409
cl_kernel kernel_scale;
408410
cl_kernel kernel_silu, kernel_silu_4;
409411
cl_kernel kernel_gelu, kernel_gelu_4;
@@ -412,7 +414,7 @@ struct ggml_backend_opencl_context {
412414
cl_kernel kernel_relu;
413415
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
414416
cl_kernel kernel_clamp;
415-
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
417+
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
416418
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
417419
cl_kernel kernel_norm;
418420
cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
@@ -681,6 +683,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
681683
GGML_LOG_CONT(".");
682684
}
683685

686+
// add_id
687+
{
688+
#ifdef GGML_OPENCL_EMBED_KERNELS
689+
const std::string kernel_src {
690+
#include "add_id.cl.h"
691+
};
692+
#else
693+
const std::string kernel_src = read_file("add_id.cl");
694+
#endif
695+
backend_ctx->program_add_id =
696+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
697+
698+
CL_CHECK((backend_ctx->kernel_add_id = clCreateKernel(backend_ctx->program_add_id, "kernel_add_id", &err), err));
699+
GGML_LOG_CONT(".");
700+
}
701+
684702
// clamp
685703
{
686704
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -787,6 +805,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
787805
CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
788806
CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
789807
CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
808+
CL_CHECK((backend_ctx->kernel_swiglu_oai = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_oai", &err), err));
790809
CL_CHECK((backend_ctx->kernel_geglu_erf = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err));
791810
CL_CHECK((backend_ctx->kernel_geglu_quick = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err));
792811
CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
@@ -2467,6 +2486,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
24672486
return (op->src[0]->type == op->src[1]->type) &&
24682487
(op->src[0]->type == op->type) &&
24692488
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
2489+
case GGML_OP_ADD_ID:
2490+
return op->src[0]->type == GGML_TYPE_F32;
24702491
case GGML_OP_UNARY:
24712492
switch (ggml_get_unary_op(op)) {
24722493
case GGML_UNARY_OP_GELU:
@@ -2488,6 +2509,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
24882509
case GGML_GLU_OP_GEGLU:
24892510
case GGML_GLU_OP_REGLU:
24902511
case GGML_GLU_OP_SWIGLU:
2512+
case GGML_GLU_OP_SWIGLU_OAI:
24912513
case GGML_GLU_OP_GEGLU_ERF:
24922514
case GGML_GLU_OP_GEGLU_QUICK:
24932515
return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
@@ -3824,6 +3846,75 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
38243846
}
38253847
}
38263848

3849+
// Launches `kernel_add_id` for an ADD_ID node.
//
//   src0         F32 input; asserted to have contiguous rows
//   src1         F32 tensor whose rows are selected by index inside the kernel
//   dst->src[2]  I32 index tensor (third operand, fetched from dst)
//   dst          F32 output
//
// Every tensor must carry a ggml_tensor_extra_cl in `extra`. The launch uses
// src0->ne[1] x src0->ne[2] work-groups, i.e. one work-group per src0 row.
static void ggml_cl_add_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src0);
    GGML_ASSERT(src0->extra);
    GGML_ASSERT(src1);
    GGML_ASSERT(src1->extra);
    GGML_ASSERT(dst);
    GGML_ASSERT(dst->extra);

    // The index tensor is not passed in directly; it is the third source of dst.
    const ggml_tensor * src2 = dst->src[2];
    GGML_ASSERT(src2);
    GGML_ASSERT(src2->extra);

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(src2->type == GGML_TYPE_I32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);

    GGML_ASSERT(ggml_is_contiguous_rows(src0));

    ggml_backend_opencl_context * cl_ctx = (ggml_backend_opencl_context *) backend->context;
    cl_kernel kernel = cl_ctx->kernel_add_id;

    // Per-tensor device buffer handle and byte offset (extra offset + view offset).
    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *) src0->extra;
    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *) src1->extra;
    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *) src2->extra;
    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *) dst->extra;

    cl_ulong offset0 = extra0->offset + src0->view_offs;
    cl_ulong offset1 = extra1->offset + src1->view_offs;
    cl_ulong offset2 = extra2->offset + src2->view_offs;
    cl_ulong offsetd = extrad->offset + dst->view_offs;

    // Byte strides forwarded to the kernel.
    const cl_ulong nb01 = src0->nb[1];
    const cl_ulong nb02 = src0->nb[2];
    const cl_ulong nb11 = src1->nb[1];
    const cl_ulong nb21 = src2->nb[1];

    // Output extents forwarded to the kernel.
    const int ne0 = dst->ne[0];
    const int ne1 = dst->ne[1];

    // Argument order must match the parameter list of kernel_add_id.
    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));
    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb01));
    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb02));
    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11));
    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb21));
    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne0));
    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne1));

    // Launch geometry comes from src0's shape.
    const int src0_cols   = src0->ne[0];
    const int src0_rows   = src0->ne[1];
    const int src0_planes = src0->ne[2];

    // Work-group width: the kernel's work-group size, capped at the row length.
    int nth = (int) cl_ctx->get_kernel_workgroup_size(kernel);
    if (src0_cols < nth) {
        nth = src0_cols;
    }

    size_t global_work_size[] = { (size_t)src0_rows*nth, (size_t)src0_planes, 1 };
    size_t local_work_size[]  = { (size_t)nth, 1, 1 };

    cl_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
3917+
38273918
static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
38283919
GGML_ASSERT(src0);
38293920
GGML_ASSERT(src0->extra);
@@ -7005,6 +7096,9 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
70057096
kernel = backend_ctx->kernel_swiglu_f16;
70067097
}
70077098
break;
7099+
case GGML_GLU_OP_SWIGLU_OAI:
7100+
kernel = backend_ctx->kernel_swiglu_oai;
7101+
break;
70087102
case GGML_GLU_OP_GEGLU_ERF:
70097103
if (dst->type == GGML_TYPE_F32) {
70107104
kernel = backend_ctx->kernel_geglu_erf;
@@ -7040,7 +7134,10 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
70407134

70417135
const cl_ulong nb1 = dst->nb[1];
70427136

7043-
const int swp = ((const int32_t *) dst->op_params)[1];
7137+
const int swp = ggml_get_op_params_i32(dst, 1);
7138+
const float alpha = ggml_get_op_params_f32(dst, 2);
7139+
const float limit = ggml_get_op_params_f32(dst, 3);
7140+
70447141
const int ne00_off = src1 ? 0 : (swp ? ne0 : 0);
70457142
const int ne10_off = src1 ? 0 : (swp ? 0 : ne0);
70467143

@@ -7057,6 +7154,11 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
70577154
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne00_off));
70587155
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10_off));
70597156

7157+
if (ggml_get_glu_op(dst) == GGML_GLU_OP_SWIGLU_OAI) {
7158+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &limit));
7159+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &alpha));
7160+
}
7161+
70607162
const size_t nrows = ggml_nrows(src0);
70617163
size_t nth = 512;
70627164
size_t global_work_size[] = {nrows*nth, 1, 1};
@@ -7113,6 +7215,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
71137215
}
71147216
func = ggml_cl_add;
71157217
break;
7218+
case GGML_OP_ADD_ID:
7219+
if (!any_on_device) {
7220+
return false;
7221+
}
7222+
func = ggml_cl_add_id;
7223+
break;
71167224
case GGML_OP_MUL:
71177225
if (!any_on_device) {
71187226
return false;
ggml/src/ggml-opencl/kernels/add_id.cl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2+
3+
//------------------------------------------------------------------------------
4+
// add_id
5+
//------------------------------------------------------------------------------
6+
kernel void kernel_add_id(
7+
global char * src0,
8+
ulong offset0,
9+
global char * src1,
10+
ulong offset1,
11+
global char * src2,
12+
ulong offset2,
13+
global char * dst,
14+
ulong offsetd,
15+
ulong nb01,
16+
ulong nb02,
17+
ulong nb11,
18+
ulong nb21,
19+
int ne0,
20+
int ne1
21+
) {
22+
src0 = (global char*)((global char*)src0 + offset0);
23+
src1 = (global char*)((global char*)src1 + offset1);
24+
src2 = (global char*)((global char*)src2 + offset2);
25+
dst = (global char*)((global char*)dst + offsetd);
26+
27+
int i1 = get_group_id(0);
28+
int i2 = get_group_id(1);
29+
30+
const int i11 = *((global const int *) (src2 + i1*sizeof(int) + i2*nb21));
31+
32+
const size_t nb1 = ne0 * sizeof(float);
33+
const size_t nb2 = ne1 * nb1;
34+
35+
global float * dst_row = (global float *)((global char *)dst + i1*nb1 + i2*nb2);
36+
global float * src0_row = (global float *)((global char *)src0 + i1*nb01 + i2*nb02);
37+
global float * src1_row = (global float *)((global char *)src1 + i11*nb11);
38+
39+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
40+
dst_row[i0] = src0_row[i0] + src1_row[i0];
41+
}
42+
}

ggml/src/ggml-opencl/kernels/glu.cl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,47 @@ kernel void kernel_swiglu_f16(
202202
}
203203
}
204204

205+
//------------------------------------------------------------------------------
206+
// swiglu_oai
207+
//------------------------------------------------------------------------------
208+
kernel void kernel_swiglu_oai(
209+
global char * src0,
210+
ulong offset0,
211+
global char * src1,
212+
ulong offset1,
213+
global char * dst,
214+
ulong offsetd,
215+
ulong nb01,
216+
ulong nb11,
217+
int ne0,
218+
ulong nb1,
219+
int ne00_off,
220+
int ne10_off,
221+
float limit,
222+
float alpha
223+
) {
224+
src0 = (global char*)((global char*)src0 + offset0);
225+
src1 = (global char*)((global char*)src1 + offset1);
226+
dst = (global char*)((global char*)dst + offsetd);
227+
228+
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
229+
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
230+
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
231+
232+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
233+
float x0 = src0_row[i0];
234+
float x1 = src1_row[i0];
235+
236+
x0 = min(x0, limit);
237+
x1 = max(min(x1, limit), -limit);
238+
239+
float out_glu = x0 / (1.0f + exp(-x0 * alpha));
240+
out_glu = out_glu * (1.0f + x1);
241+
242+
dst_row[i0] = out_glu;
243+
}
244+
}
245+
205246
//------------------------------------------------------------------------------
206247
// geglu_erf
207248
//------------------------------------------------------------------------------

0 commit comments

Comments
 (0)