Skip to content

Commit a798eff

Browse files
authored
add GEGLU_ERF and GEGLU_QUICK for opencl
1 parent 95e15ae commit a798eff

File tree

2 files changed

+162
-8
lines changed

2 files changed

+162
-8
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -402,8 +402,8 @@ struct ggml_backend_opencl_context {
402402
cl_kernel kernel_relu;
403403
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
404404
cl_kernel kernel_clamp;
405-
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu,
406-
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16;
405+
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
406+
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
407407
cl_kernel kernel_norm;
408408
cl_kernel kernel_rms_norm;
409409
cl_kernel kernel_group_norm;
@@ -753,12 +753,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
753753
backend_ctx->program_glu =
754754
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
755755

756-
CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
757-
CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
758-
CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
759-
CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
760-
CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
761-
CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
756+
CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
757+
CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
758+
CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
759+
CL_CHECK((backend_ctx->kernel_geglu_erf = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err));
760+
CL_CHECK((backend_ctx->kernel_geglu_quick = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err));
761+
CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
762+
CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
763+
CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
764+
CL_CHECK((backend_ctx->kernel_geglu_erf_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf_f16", &err), err));
765+
CL_CHECK((backend_ctx->kernel_geglu_quick_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick_f16", &err), err));
762766
GGML_LOG_CONT(".");
763767
}
764768

@@ -6217,6 +6221,20 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
62176221
kernel = backend_ctx->kernel_swiglu_f16;
62186222
}
62196223
break;
6224+
case GGML_GLU_OP_GEGLU_ERF:
6225+
if (dst->type == GGML_TYPE_F32) {
6226+
kernel = backend_ctx->kernel_geglu_erf;
6227+
} else {
6228+
kernel = backend_ctx->kernel_geglu_erf_f16;
6229+
}
6230+
break;
6231+
case GGML_GLU_OP_GEGLU_QUICK:
6232+
if (dst->type == GGML_TYPE_F32) {
6233+
kernel = backend_ctx->kernel_geglu_quick;
6234+
} else {
6235+
kernel = backend_ctx->kernel_geglu_quick_f16;
6236+
}
6237+
break;
62206238
default:
62216239
GGML_ABORT("Unsupported glu op");
62226240
}

ggml/src/ggml-opencl/kernels/glu.cl

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
// Enable the cl_khr_fp16 extension; the *_f16 kernels in this file declare `half` variables.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
22

33
// NOTE(review): the use of GELU_COEF_A is not visible in this hunk; presumably the cubic
// coefficient of the tanh-based GELU approximation - confirm against the geglu kernels.
#define GELU_COEF_A 0.044715f
4+
// Factor applied to x inside exp() by the geglu_quick kernels: x * (1 / (1 + exp(GELU_QUICK_COEF * x))).
// NOTE(review): expands to an unparenthesized negative literal; fine for the current
// `GELU_QUICK_COEF*x0` uses, but it would mis-parse after a binary `-` or similar.
#define GELU_QUICK_COEF -1.702f
45
// Numerically sqrt(2/pi). NOTE(review): its use is not visible in this hunk.
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
6+
// Numerically 1/sqrt(2); the geglu_erf kernels scale x by it before calling erf().
#define SQRT_2_INV 0.70710678118654752440084436210484f
57

68
//------------------------------------------------------------------------------
79
// geglu
@@ -199,3 +201,137 @@ kernel void kernel_swiglu_f16(
199201
dst_row[i0] = silu*x1;
200202
}
201203
}
204+
205+
//------------------------------------------------------------------------------
206+
// geglu_erf
207+
//------------------------------------------------------------------------------
208+
// GEGLU using the erf form of GELU, float variant.
// For the row selected by get_group_id(0) and every i0 in [0, ne0) this writes
//   dst_row[i0] = 0.5*x0*(1 + erf(x0*SQRT_2_INV)) * x1
// with x0 = src0_row[i0] and x1 = src1_row[i0].
//
// Parameters:
//   src0                      - input the activation is applied to
//   src1                      - input the activated value is multiplied with
//   dst                       - output buffer
//   offset0/offset1/offsetd   - byte offsets added to src0 / src1 / dst
//   nb01, nb11, nb1           - byte distance between consecutive rows of src0 / src1 / dst
//   ne0                       - number of elements processed per row
//   ne00_off, ne10_off        - offsets into the src0 / src1 row, counted in float elements
kernel void kernel_geglu_erf(
209+
global char * src0,
210+
ulong offset0,
211+
global char * src1,
212+
ulong offset1,
213+
global char * dst,
214+
ulong offsetd,
215+
ulong nb01,
216+
ulong nb11,
217+
int ne0,
218+
ulong nb1,
219+
int ne00_off,
220+
int ne10_off
221+
) {
222+
// Apply the offsets to the base pointers; these are char pointers, so the offsets are in bytes.
src0 = (global char*)((global char*)src0 + offset0);
223+
src1 = (global char*)((global char*)src1 + offset1);
224+
dst = (global char*)((global char*)dst + offsetd);
225+
226+
// Row start for this work-group. The row stride is applied on the char pointer (bytes),
// while ne00_off / ne10_off are added after the cast and therefore count floats.
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
227+
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
228+
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
229+
230+
// Each work-item starts at its local id and advances by the local size, so the
// work-items of one group share the ne0 elements of the row between them.
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
231+
const float x0 = src0_row[i0];
232+
const float x1 = src1_row[i0];
233+
234+
// erf-based GELU of x0: 0.5 * x0 * (1 + erf(x0 / sqrt(2))).
const float gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
235+
236+
// Scale the activated value by the matching src1 element.
dst_row[i0] = gelu_erf*x1;
237+
}
238+
}
239+
240+
// GEGLU using the erf form of GELU, half-precision variant.
// For the row selected by get_group_id(0) and every i0 in [0, ne0) this writes
//   dst_row[i0] = 0.5*x0*(1 + erf(x0*SQRT_2_INV)) * x1
// with x0 = src0_row[i0] and x1 = src1_row[i0], all stored as half.
//
// Parameters:
//   src0                      - input the activation is applied to
//   src1                      - input the activated value is multiplied with
//   dst                       - output buffer
//   offset0/offset1/offsetd   - byte offsets added to src0 / src1 / dst
//   nb01, nb11, nb1           - byte distance between consecutive rows of src0 / src1 / dst
//   ne0                       - number of elements processed per row
//   ne00_off, ne10_off        - offsets into the src0 / src1 row, counted in half elements
kernel void kernel_geglu_erf_f16(
241+
global char * src0,
242+
ulong offset0,
243+
global char * src1,
244+
ulong offset1,
245+
global char * dst,
246+
ulong offsetd,
247+
ulong nb01,
248+
ulong nb11,
249+
int ne0,
250+
ulong nb1,
251+
int ne00_off,
252+
int ne10_off
253+
) {
254+
// Apply the offsets to the base pointers; these are char pointers, so the offsets are in bytes.
src0 = (global char*)((global char*)src0 + offset0);
255+
src1 = (global char*)((global char*)src1 + offset1);
256+
dst = (global char*)((global char*)dst + offsetd);
257+
258+
// Row start for this work-group. The row stride is applied on the char pointer (bytes),
// while ne00_off / ne10_off are added after the cast and therefore count halfs.
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
259+
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
260+
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
261+
262+
// Each work-item starts at its local id and advances by the local size, so the
// work-items of one group share the ne0 elements of the row between them.
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
263+
const half x0 = src0_row[i0];
264+
const half x1 = src1_row[i0];
265+
266+
// erf-based GELU of x0: 0.5 * x0 * (1 + erf(x0 / sqrt(2))).
// NOTE(review): the half operand is mixed with float literals, so the expression is
// presumably evaluated in float and narrowed to half on assignment - confirm.
const half gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
267+
268+
// Scale the activated value by the matching src1 element.
dst_row[i0] = gelu_erf*x1;
269+
}
270+
}
271+
272+
//------------------------------------------------------------------------------
273+
// geglu_quick
274+
//------------------------------------------------------------------------------
275+
// GEGLU using the "quick" sigmoid form of GELU, float variant.
// For the row selected by get_group_id(0) and every i0 in [0, ne0) this writes
//   dst_row[i0] = x0 * (1 / (1 + exp(GELU_QUICK_COEF*x0))) * x1
// with x0 = src0_row[i0] and x1 = src1_row[i0].
//
// Parameters:
//   src0                      - input the activation is applied to
//   src1                      - input the activated value is multiplied with
//   dst                       - output buffer
//   offset0/offset1/offsetd   - byte offsets added to src0 / src1 / dst
//   nb01, nb11, nb1           - byte distance between consecutive rows of src0 / src1 / dst
//   ne0                       - number of elements processed per row
//   ne00_off, ne10_off        - offsets into the src0 / src1 row, counted in float elements
kernel void kernel_geglu_quick(
276+
global char * src0,
277+
ulong offset0,
278+
global char * src1,
279+
ulong offset1,
280+
global char * dst,
281+
ulong offsetd,
282+
ulong nb01,
283+
ulong nb11,
284+
int ne0,
285+
ulong nb1,
286+
int ne00_off,
287+
int ne10_off
288+
) {
289+
// Apply the offsets to the base pointers; these are char pointers, so the offsets are in bytes.
src0 = (global char*)((global char*)src0 + offset0);
290+
src1 = (global char*)((global char*)src1 + offset1);
291+
dst = (global char*)((global char*)dst + offsetd);
292+
293+
// Row start for this work-group. The row stride is applied on the char pointer (bytes),
// while ne00_off / ne10_off are added after the cast and therefore count floats.
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
294+
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
295+
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
296+
297+
// Each work-item starts at its local id and advances by the local size, so the
// work-items of one group share the ne0 elements of the row between them.
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
298+
const float x0 = src0_row[i0];
299+
const float x1 = src1_row[i0];
300+
301+
// Quick GELU of x0: x0 times a logistic sigmoid, written as 1 / (1 + exp(c*x0))
// with the negative constant c = GELU_QUICK_COEF.
const float gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
302+
303+
// Scale the activated value by the matching src1 element.
dst_row[i0] = gelu_quick*x1;
304+
}
305+
}
306+
307+
// GEGLU using the "quick" sigmoid form of GELU, half-precision variant.
// For the row selected by get_group_id(0) and every i0 in [0, ne0) this writes
//   dst_row[i0] = x0 * (1 / (1 + exp(GELU_QUICK_COEF*x0))) * x1
// with x0 = src0_row[i0] and x1 = src1_row[i0], all stored as half.
//
// Parameters:
//   src0                      - input the activation is applied to
//   src1                      - input the activated value is multiplied with
//   dst                       - output buffer
//   offset0/offset1/offsetd   - byte offsets added to src0 / src1 / dst
//   nb01, nb11, nb1           - byte distance between consecutive rows of src0 / src1 / dst
//   ne0                       - number of elements processed per row
//   ne00_off, ne10_off        - offsets into the src0 / src1 row, counted in half elements
kernel void kernel_geglu_quick_f16(
308+
global char * src0,
309+
ulong offset0,
310+
global char * src1,
311+
ulong offset1,
312+
global char * dst,
313+
ulong offsetd,
314+
ulong nb01,
315+
ulong nb11,
316+
int ne0,
317+
ulong nb1,
318+
int ne00_off,
319+
int ne10_off
320+
) {
321+
// Apply the offsets to the base pointers; these are char pointers, so the offsets are in bytes.
src0 = (global char*)((global char*)src0 + offset0);
322+
src1 = (global char*)((global char*)src1 + offset1);
323+
dst = (global char*)((global char*)dst + offsetd);
324+
325+
// Row start for this work-group. The row stride is applied on the char pointer (bytes),
// while ne00_off / ne10_off are added after the cast and therefore count halfs.
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
326+
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
327+
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
328+
329+
// Each work-item starts at its local id and advances by the local size, so the
// work-items of one group share the ne0 elements of the row between them.
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
330+
const half x0 = src0_row[i0];
331+
const half x1 = src1_row[i0];
332+
333+
// Quick GELU of x0: x0 times a logistic sigmoid, written as 1 / (1 + exp(c*x0))
// with the negative constant c = GELU_QUICK_COEF.
// NOTE(review): the half operand is mixed with float literals, so the expression is
// presumably evaluated in float and narrowed to half on assignment - confirm.
const half gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
334+
335+
// Scale the activated value by the matching src1 element.
dst_row[i0] = gelu_quick*x1;
336+
}
337+
}

0 commit comments

Comments
 (0)