Skip to content

Commit 7fbd3b4

Browse files
committed
opencl: add swiglu-oai
1 parent fd1234c commit 7fbd3b4

File tree

2 files changed

+56
-2
lines changed

2 files changed

+56
-2
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ struct ggml_backend_opencl_context {
412412
cl_kernel kernel_relu;
413413
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
414414
cl_kernel kernel_clamp;
415-
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
415+
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
416416
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
417417
cl_kernel kernel_norm;
418418
cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
@@ -787,6 +787,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
787787
CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
788788
CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
789789
CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
790+
CL_CHECK((backend_ctx->kernel_swiglu_oai = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_oai", &err), err));
790791
CL_CHECK((backend_ctx->kernel_geglu_erf = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err));
791792
CL_CHECK((backend_ctx->kernel_geglu_quick = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err));
792793
CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
@@ -2488,6 +2489,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
24882489
case GGML_GLU_OP_GEGLU:
24892490
case GGML_GLU_OP_REGLU:
24902491
case GGML_GLU_OP_SWIGLU:
2492+
case GGML_GLU_OP_SWIGLU_OAI:
24912493
case GGML_GLU_OP_GEGLU_ERF:
24922494
case GGML_GLU_OP_GEGLU_QUICK:
24932495
return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
@@ -7005,6 +7007,9 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
70057007
kernel = backend_ctx->kernel_swiglu_f16;
70067008
}
70077009
break;
7010+
case GGML_GLU_OP_SWIGLU_OAI:
7011+
kernel = backend_ctx->kernel_swiglu_oai;
7012+
break;
70087013
case GGML_GLU_OP_GEGLU_ERF:
70097014
if (dst->type == GGML_TYPE_F32) {
70107015
kernel = backend_ctx->kernel_geglu_erf;
@@ -7040,7 +7045,10 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
70407045

70417046
const cl_ulong nb1 = dst->nb[1];
70427047

7043-
const int swp = ((const int32_t *) dst->op_params)[1];
7048+
const int swp = ggml_get_op_params_i32(dst, 1);
7049+
const float alpha = ggml_get_op_params_f32(dst, 2);
7050+
const float limit = ggml_get_op_params_f32(dst, 3);
7051+
70447052
const int ne00_off = src1 ? 0 : (swp ? ne0 : 0);
70457053
const int ne10_off = src1 ? 0 : (swp ? 0 : ne0);
70467054

@@ -7057,6 +7065,11 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
70577065
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne00_off));
70587066
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10_off));
70597067

7068+
if (ggml_get_glu_op(dst) == GGML_GLU_OP_SWIGLU_OAI) {
7069+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &limit));
7070+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &alpha));
7071+
}
7072+
70607073
const size_t nrows = ggml_nrows(src0);
70617074
size_t nth = 512;
70627075
size_t global_work_size[] = {nrows*nth, 1, 1};

ggml/src/ggml-opencl/kernels/glu.cl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,47 @@ kernel void kernel_swiglu_f16(
202202
}
203203
}
204204

205+
//------------------------------------------------------------------------------
206+
// swiglu_oai
207+
//------------------------------------------------------------------------------
208+
kernel void kernel_swiglu_oai(
209+
global char * src0,
210+
ulong offset0,
211+
global char * src1,
212+
ulong offset1,
213+
global char * dst,
214+
ulong offsetd,
215+
ulong nb01,
216+
ulong nb11,
217+
int ne0,
218+
ulong nb1,
219+
int ne00_off,
220+
int ne10_off,
221+
float limit,
222+
float alpha
223+
) {
224+
src0 = (global char*)((global char*)src0 + offset0);
225+
src1 = (global char*)((global char*)src1 + offset1);
226+
dst = (global char*)((global char*)dst + offsetd);
227+
228+
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
229+
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
230+
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
231+
232+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
233+
float x0 = src0_row[i0];
234+
float x1 = src1_row[i0];
235+
236+
x0 = min(x0, limit);
237+
x1 = max(min(x1, limit), -limit);
238+
239+
float out_glu = x0 / (1.0f + exp(-x0 * alpha));
240+
out_glu = out_glu * (1.0f + x1);
241+
242+
dst_row[i0] = out_glu;
243+
}
244+
}
245+
205246
//------------------------------------------------------------------------------
206247
// geglu_erf
207248
//------------------------------------------------------------------------------

0 commit comments

Comments
 (0)