Skip to content

Commit 20c2dac

Browse files
ddwkimaeseulgi
andauthored
vulkan: add exp operation (ggml-org#15456)
Co-authored-by: aeseulgi <[email protected]>
1 parent 96452a3 commit 20c2dac

File tree

3 files changed

+33
-0
lines changed

3 files changed

+33
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,7 @@ struct vk_device_struct {
490490
vk_pipeline pipeline_l2_norm_f32;
491491

492492
// [src/dst 0=fp32,1=fp16]
493+
vk_pipeline pipeline_exp[2];
493494
vk_pipeline pipeline_gelu[2];
494495
vk_pipeline pipeline_gelu_erf[2];
495496
vk_pipeline pipeline_gelu_quick[2];
@@ -3066,6 +3067,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
30663067
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
30673068
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
30683069

3070+
CREATE_UNARY(exp)
30693071
CREATE_UNARY(gelu)
30703072
CREATE_UNARY(gelu_erf)
30713073
CREATE_UNARY(gelu_quick)
@@ -7133,6 +7135,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
71337135
}
71347136

71357137
switch (ggml_get_unary_op(dst)) {
7138+
case GGML_UNARY_OP_EXP:
7139+
return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
71367140
case GGML_UNARY_OP_SILU:
71377141
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
71387142
case GGML_UNARY_OP_GELU:
@@ -9738,6 +9742,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
97389742
return false;
97399743
case GGML_OP_UNARY:
97409744
switch (ggml_get_unary_op(node)) {
9745+
case GGML_UNARY_OP_EXP:
97419746
case GGML_UNARY_OP_SILU:
97429747
case GGML_UNARY_OP_GELU:
97439748
case GGML_UNARY_OP_GELU_ERF:
@@ -10015,6 +10020,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1001510020
break;
1001610021
case GGML_OP_UNARY:
1001710022
switch (ggml_get_unary_op(node)) {
10023+
case GGML_UNARY_OP_EXP:
1001810024
case GGML_UNARY_OP_SILU:
1001910025
case GGML_UNARY_OP_GELU:
1002010026
case GGML_UNARY_OP_GELU_ERF:
@@ -10251,6 +10257,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1025110257
break;
1025210258
case GGML_OP_UNARY:
1025310259
switch (ggml_get_unary_op(tensor)) {
10260+
case GGML_UNARY_OP_EXP:
1025410261
case GGML_UNARY_OP_SILU:
1025510262
case GGML_UNARY_OP_GELU:
1025610263
case GGML_UNARY_OP_GELU_ERF:
@@ -11166,6 +11173,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1116611173
switch (op->op) {
1116711174
case GGML_OP_UNARY:
1116811175
switch (ggml_get_unary_op(op)) {
11176+
case GGML_UNARY_OP_EXP:
1116911177
case GGML_UNARY_OP_GELU:
1117011178
case GGML_UNARY_OP_GELU_ERF:
1117111179
case GGML_UNARY_OP_GELU_QUICK:
@@ -11965,6 +11973,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1196511973
}
1196611974
} else if (tensor->op == GGML_OP_UNARY) {
1196711975
switch (ggml_get_unary_op(tensor)) {
11976+
case GGML_UNARY_OP_EXP:
11977+
tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
11978+
break;
1196811979
case GGML_UNARY_OP_SILU:
1196911980
tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
1197011981
break;
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#version 450
2+
3+
#include "generic_head.comp"
4+
#include "types.comp"
5+
6+
#extension GL_EXT_control_flow_attributes : enable
7+
8+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9+
10+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12+
13+
void main() {
14+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
15+
16+
if (i >= p.KX) {
17+
return;
18+
}
19+
data_d[i] = D_TYPE(exp(float(data_a[i])));
20+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,8 @@ void process_shaders() {
586586

587587
string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
588588

589+
string_to_spv("exp_f16", "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
590+
string_to_spv("exp_f32", "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
589591
string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
590592
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
591593
string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});

0 commit comments

Comments
 (0)