@@ -431,6 +431,7 @@ struct vk_device_struct {
431431
432432    // [src/dst 0=fp32,1=fp16]
433433    vk_pipeline pipeline_gelu[2];
434+     vk_pipeline pipeline_gelu_erf[2];
434435    vk_pipeline pipeline_gelu_quick[2];
435436    vk_pipeline pipeline_silu[2];
436437    vk_pipeline pipeline_relu[2];
@@ -2761,6 +2762,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
27612762    ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
27622763
27632764    CREATE_UNARY(gelu)
2765+     CREATE_UNARY(gelu_erf)
27642766    CREATE_UNARY(gelu_quick)
27652767    CREATE_UNARY(silu)
27662768    CREATE_UNARY(relu)
@@ -6481,6 +6483,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
64816483                return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
64826484            case GGML_UNARY_OP_GELU:
64836485                return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
6486+             case GGML_UNARY_OP_GELU_ERF:
6487+                 return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];
64846488            case GGML_UNARY_OP_GELU_QUICK:
64856489                return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
64866490            case GGML_UNARY_OP_RELU:
@@ -8827,6 +8831,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
88278831        switch (ggml_get_unary_op(node)) {
88288832        case GGML_UNARY_OP_SILU:
88298833        case GGML_UNARY_OP_GELU:
8834+         case GGML_UNARY_OP_GELU_ERF:
88308835        case GGML_UNARY_OP_GELU_QUICK:
88318836        case GGML_UNARY_OP_RELU:
88328837        case GGML_UNARY_OP_TANH:
@@ -9072,6 +9077,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
90729077        switch (ggml_get_unary_op(node)) {
90739078        case GGML_UNARY_OP_SILU:
90749079        case GGML_UNARY_OP_GELU:
9080+         case GGML_UNARY_OP_GELU_ERF:
90759081        case GGML_UNARY_OP_GELU_QUICK:
90769082        case GGML_UNARY_OP_RELU:
90779083        case GGML_UNARY_OP_TANH:
@@ -9289,6 +9295,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
92899295        switch (ggml_get_unary_op(tensor)) {
92909296        case GGML_UNARY_OP_SILU:
92919297        case GGML_UNARY_OP_GELU:
9298+         case GGML_UNARY_OP_GELU_ERF:
92929299        case GGML_UNARY_OP_GELU_QUICK:
92939300        case GGML_UNARY_OP_RELU:
92949301        case GGML_UNARY_OP_TANH:
@@ -10095,6 +10102,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1009510102        case GGML_OP_UNARY:
1009610103            switch (ggml_get_unary_op(op)) {
1009710104                case GGML_UNARY_OP_GELU:
10105+                 case GGML_UNARY_OP_GELU_ERF:
1009810106                case GGML_UNARY_OP_GELU_QUICK:
1009910107                case GGML_UNARY_OP_SILU:
1010010108                case GGML_UNARY_OP_RELU:
@@ -10835,6 +10843,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
1083510843        case GGML_UNARY_OP_GELU:
1083610844            tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
1083710845            break;
10846+         case GGML_UNARY_OP_GELU_ERF:
10847+             tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);
10848+             break;
1083810849        case GGML_UNARY_OP_GELU_QUICK:
1083910850            tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
1084010851            break;
0 commit comments