@@ -432,6 +432,7 @@ struct vk_device_struct {
 
     // [src/dst 0=fp32,1=fp16]
     vk_pipeline pipeline_gelu[2];
+    vk_pipeline pipeline_gelu_erf[2];
     vk_pipeline pipeline_gelu_quick[2];
     vk_pipeline pipeline_silu[2];
     vk_pipeline pipeline_relu[2];
@@ -2798,6 +2799,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
     CREATE_UNARY(gelu)
+    CREATE_UNARY(gelu_erf)
     CREATE_UNARY(gelu_quick)
     CREATE_UNARY(silu)
     CREATE_UNARY(relu)
@@ -6531,6 +6533,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
         case GGML_UNARY_OP_GELU:
             return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
+        case GGML_UNARY_OP_GELU_ERF:
+            return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];
         case GGML_UNARY_OP_GELU_QUICK:
             return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
         case GGML_UNARY_OP_RELU:
@@ -8822,6 +8826,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         switch (ggml_get_unary_op(node)) {
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
+        case GGML_UNARY_OP_GELU_ERF:
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
@@ -9072,6 +9077,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         switch (ggml_get_unary_op(node)) {
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
+        case GGML_UNARY_OP_GELU_ERF:
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
@@ -9290,6 +9296,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
         switch (ggml_get_unary_op(tensor)) {
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
+        case GGML_UNARY_OP_GELU_ERF:
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
@@ -10096,6 +10103,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_GELU_ERF:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
@@ -10836,6 +10844,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         case GGML_UNARY_OP_GELU:
             tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
             break;
+        case GGML_UNARY_OP_GELU_ERF:
+            tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);
+            break;
         case GGML_UNARY_OP_GELU_QUICK:
             tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
             break;