@@ -412,7 +412,7 @@ struct ggml_backend_opencl_context {
412
412
cl_kernel kernel_relu;
413
413
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
414
414
cl_kernel kernel_clamp;
415
- cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
415
+ cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
416
416
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
417
417
cl_kernel kernel_norm;
418
418
cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
@@ -787,6 +787,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
787
787
CL_CHECK ((backend_ctx->kernel_geglu = clCreateKernel (backend_ctx->program_glu , " kernel_geglu" , &err), err));
788
788
CL_CHECK ((backend_ctx->kernel_reglu = clCreateKernel (backend_ctx->program_glu , " kernel_reglu" , &err), err));
789
789
CL_CHECK ((backend_ctx->kernel_swiglu = clCreateKernel (backend_ctx->program_glu , " kernel_swiglu" , &err), err));
790
+ CL_CHECK ((backend_ctx->kernel_swiglu_oai = clCreateKernel (backend_ctx->program_glu , " kernel_swiglu_oai" , &err), err));
790
791
CL_CHECK ((backend_ctx->kernel_geglu_erf = clCreateKernel (backend_ctx->program_glu , " kernel_geglu_erf" , &err), err));
791
792
CL_CHECK ((backend_ctx->kernel_geglu_quick = clCreateKernel (backend_ctx->program_glu , " kernel_geglu_quick" , &err), err));
792
793
CL_CHECK ((backend_ctx->kernel_geglu_f16 = clCreateKernel (backend_ctx->program_glu , " kernel_geglu_f16" , &err), err));
@@ -2488,6 +2489,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2488
2489
case GGML_GLU_OP_GEGLU:
2489
2490
case GGML_GLU_OP_REGLU:
2490
2491
case GGML_GLU_OP_SWIGLU:
2492
+ case GGML_GLU_OP_SWIGLU_OAI:
2491
2493
case GGML_GLU_OP_GEGLU_ERF:
2492
2494
case GGML_GLU_OP_GEGLU_QUICK:
2493
2495
return ggml_is_contiguous_1 (op->src [0 ]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
@@ -7005,6 +7007,9 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
7005
7007
kernel = backend_ctx->kernel_swiglu_f16 ;
7006
7008
}
7007
7009
break ;
7010
+ case GGML_GLU_OP_SWIGLU_OAI:
7011
+ kernel = backend_ctx->kernel_swiglu_oai ;
7012
+ break ;
7008
7013
case GGML_GLU_OP_GEGLU_ERF:
7009
7014
if (dst->type == GGML_TYPE_F32) {
7010
7015
kernel = backend_ctx->kernel_geglu_erf ;
@@ -7040,7 +7045,10 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
7040
7045
7041
7046
const cl_ulong nb1 = dst->nb [1 ];
7042
7047
7043
- const int swp = ((const int32_t *) dst->op_params )[1 ];
7048
+ const int swp = ggml_get_op_params_i32 (dst, 1 );
7049
+ const float alpha = ggml_get_op_params_f32 (dst, 2 );
7050
+ const float limit = ggml_get_op_params_f32 (dst, 3 );
7051
+
7044
7052
const int ne00_off = src1 ? 0 : (swp ? ne0 : 0 );
7045
7053
const int ne10_off = src1 ? 0 : (swp ? 0 : ne0);
7046
7054
@@ -7057,6 +7065,11 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
7057
7065
CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ), &ne00_off));
7058
7066
CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (int ), &ne10_off));
7059
7067
7068
+ if (ggml_get_glu_op (dst) == GGML_GLU_OP_SWIGLU_OAI) {
7069
+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (float ), &limit));
7070
+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (float ), &alpha));
7071
+ }
7072
+
7060
7073
const size_t nrows = ggml_nrows (src0);
7061
7074
size_t nth = 512 ;
7062
7075
size_t global_work_size[] = {nrows*nth, 1 , 1 };
0 commit comments