@@ -440,6 +440,7 @@ struct vk_device_struct {
440440 vk_pipeline pipeline_geglu[2];
441441 vk_pipeline pipeline_reglu[2];
442442 vk_pipeline pipeline_swiglu[2];
443+ vk_pipeline pipeline_geglu_erf[2];
443444 vk_pipeline pipeline_geglu_quick[2];
444445
445446 vk_pipeline pipeline_leaky_relu_f32;
@@ -2776,6 +2777,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
27762777 CREATE_GLU(geglu)
27772778 CREATE_GLU(reglu)
27782779 CREATE_GLU(swiglu)
2780+ CREATE_GLU(geglu_erf)
27792781 CREATE_GLU(geglu_quick)
27802782#undef CREATE_GLU
27812783
@@ -6509,6 +6511,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
65096511 return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
65106512 case GGML_GLU_OP_SWIGLU:
65116513 return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
6514+ case GGML_GLU_OP_GEGLU_ERF:
6515+ return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
65126516 case GGML_GLU_OP_GEGLU_QUICK:
65136517 return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
65146518 default:
@@ -8845,6 +8849,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
88458849 case GGML_GLU_OP_GEGLU:
88468850 case GGML_GLU_OP_REGLU:
88478851 case GGML_GLU_OP_SWIGLU:
8852+ case GGML_GLU_OP_GEGLU_ERF:
88488853 case GGML_GLU_OP_GEGLU_QUICK:
88498854 break;
88508855 default:
@@ -9092,6 +9097,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
90929097 case GGML_GLU_OP_GEGLU:
90939098 case GGML_GLU_OP_REGLU:
90949099 case GGML_GLU_OP_SWIGLU:
9100+ case GGML_GLU_OP_GEGLU_ERF:
90959101 case GGML_GLU_OP_GEGLU_QUICK:
90969102 ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
90979103 break;
@@ -9310,6 +9316,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
93109316 case GGML_GLU_OP_GEGLU:
93119317 case GGML_GLU_OP_REGLU:
93129318 case GGML_GLU_OP_SWIGLU:
9319+ case GGML_GLU_OP_GEGLU_ERF:
93139320 case GGML_GLU_OP_GEGLU_QUICK:
93149321 buf = tensor->buffer;
93159322 break;
@@ -10120,6 +10127,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1012010127 case GGML_GLU_OP_GEGLU:
1012110128 case GGML_GLU_OP_REGLU:
1012210129 case GGML_GLU_OP_SWIGLU:
10130+ case GGML_GLU_OP_GEGLU_ERF:
1012310131 case GGML_GLU_OP_GEGLU_QUICK:
1012410132 return ggml_is_contiguous(op->src[0]) &&
1012510133 (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
0 commit comments