Skip to content

Commit d5e4a58

Browse files
authored
add GEGLU_ERF for vulkan
1 parent 38593bc commit d5e4a58

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ struct vk_device_struct {
440440
vk_pipeline pipeline_geglu[2];
441441
vk_pipeline pipeline_reglu[2];
442442
vk_pipeline pipeline_swiglu[2];
443+
vk_pipeline pipeline_geglu_erf[2];
443444
vk_pipeline pipeline_geglu_quick[2];
444445

445446
vk_pipeline pipeline_leaky_relu_f32;
@@ -2776,6 +2777,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
27762777
CREATE_GLU(geglu)
27772778
CREATE_GLU(reglu)
27782779
CREATE_GLU(swiglu)
2780+
CREATE_GLU(geglu_erf)
27792781
CREATE_GLU(geglu_quick)
27802782
#undef CREATE_GLU
27812783

@@ -6509,6 +6511,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
65096511
return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
65106512
case GGML_GLU_OP_SWIGLU:
65116513
return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
6514+
case GGML_GLU_OP_GEGLU_ERF:
6515+
return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
65126516
case GGML_GLU_OP_GEGLU_QUICK:
65136517
return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
65146518
default:
@@ -8845,6 +8849,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
88458849
case GGML_GLU_OP_GEGLU:
88468850
case GGML_GLU_OP_REGLU:
88478851
case GGML_GLU_OP_SWIGLU:
8852+
case GGML_GLU_OP_GEGLU_ERF:
88488853
case GGML_GLU_OP_GEGLU_QUICK:
88498854
break;
88508855
default:
@@ -9092,6 +9097,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
90929097
case GGML_GLU_OP_GEGLU:
90939098
case GGML_GLU_OP_REGLU:
90949099
case GGML_GLU_OP_SWIGLU:
9100+
case GGML_GLU_OP_GEGLU_ERF:
90959101
case GGML_GLU_OP_GEGLU_QUICK:
90969102
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
90979103
break;
@@ -9310,6 +9316,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
93109316
case GGML_GLU_OP_GEGLU:
93119317
case GGML_GLU_OP_REGLU:
93129318
case GGML_GLU_OP_SWIGLU:
9319+
case GGML_GLU_OP_GEGLU_ERF:
93139320
case GGML_GLU_OP_GEGLU_QUICK:
93149321
buf = tensor->buffer;
93159322
break;
@@ -10120,6 +10127,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1012010127
case GGML_GLU_OP_GEGLU:
1012110128
case GGML_GLU_OP_REGLU:
1012210129
case GGML_GLU_OP_SWIGLU:
10130+
case GGML_GLU_OP_GEGLU_ERF:
1012310131
case GGML_GLU_OP_GEGLU_QUICK:
1012410132
return ggml_is_contiguous(op->src[0]) &&
1012510133
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#version 450
2+
3+
#include "glu_head.comp"
4+
5+
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
6+
// ref: https://www.johndcook.com/blog/python_erf/
7+
// Coefficients of the rational/polynomial erf approximation used by op():
//   erf(x) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2),
//   with t = 1 / (1 + p*x), evaluated for x >= 0.
// p_erf is the `p` in the substitution for t.
const float p_erf = 0.3275911f;
8+
// a1_erf..a5_erf are the polynomial coefficients, lowest power of t first.
// Their sum is ~1, so the approximation yields ~0 at x == 0 (where t == 1).
const float a1_erf = 0.254829592f;
9+
const float a2_erf = -0.284496736f;
10+
const float a3_erf = 1.421413741f;
11+
const float a4_erf = -1.453152027f;
12+
const float a5_erf = 1.061405429f;
13+
14+
// 1 / sqrt(2): op() multiplies its first argument by this before taking erf.
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
15+
16+
// Element-wise GLU kernel body: returns 0.5 * a * (1 + erf(a / sqrt(2))) * b,
// i.e. the erf form of GELU applied to `a`, multiplied by `b`.
// erf is not called directly; it is approximated with the polynomial whose
// coefficients are declared above (p_erf, a1_erf..a5_erf).
// NOTE(review): which input halves arrive as `a` and `b` is decided by
// glu_main.comp, which is not visible here - confirm against that file.
float op(float a, float b) {
17+
// Argument of erf: a / sqrt(2).
const float a_div_sqr2 = a * SQRT_2_INV;
18+
// The approximation is only evaluated for non-negative x; the sign is
// split off here and re-applied below using erf(-x) == -erf(x).
// GLSL sign() returns 0.0 for a zero input, so erf_approx is exactly 0 there.
const float sign_x = sign(a_div_sqr2);
19+
const float x = abs(a_div_sqr2);
20+
// Substitution variable of the approximation; x >= 0 keeps the denominator >= 1.
const float t = 1.0f / (1.0f + p_erf * x);
21+
// Horner evaluation of (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5), scaled by
// exp(-x^2) and subtracted from 1: the approximation of erf(|a| / sqrt(2)).
// NOTE(review): the cited formula has a nominal absolute error of about 1.5e-7;
// float32 rounding in this shader is on top of that - confirm if accuracy matters.
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
22+
// Restore the sign to obtain erf(a / sqrt(2)) for either sign of a.
const float erf_approx = sign_x * y;
23+
24+
// GELU(a) * b, with GELU(a) = 0.5 * a * (1 + erf(a / sqrt(2))).
return 0.5f * a * (1.0f + erf_approx) * b;
25+
}
26+
27+
#include "glu_main.comp"

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,8 @@ void process_shaders() {
591591
string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
592592
string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
593593
string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
594+
string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
595+
string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
594596
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
595597
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
596598

0 commit comments

Comments
 (0)