Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ struct vk_device_struct {
vk_pipeline pipeline_l2_norm_f32;

// [src/dst 0=fp32,1=fp16]
vk_pipeline pipeline_exp[2];
vk_pipeline pipeline_gelu[2];
vk_pipeline pipeline_gelu_erf[2];
vk_pipeline pipeline_gelu_quick[2];
Expand Down Expand Up @@ -3062,6 +3063,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

CREATE_UNARY(exp)
CREATE_UNARY(gelu)
CREATE_UNARY(gelu_erf)
CREATE_UNARY(gelu_quick)
Expand Down Expand Up @@ -7097,6 +7099,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}

switch (ggml_get_unary_op(dst)) {
case GGML_UNARY_OP_EXP:
return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_SILU:
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_GELU:
Expand Down Expand Up @@ -9702,6 +9706,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
return false;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(node)) {
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_ERF:
Expand Down Expand Up @@ -9979,6 +9984,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(node)) {
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_ERF:
Expand Down Expand Up @@ -10215,6 +10221,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_ERF:
Expand Down Expand Up @@ -11125,6 +11132,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
switch (op->op) {
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
Expand Down Expand Up @@ -11924,6 +11932,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
}
} else if (tensor->op == GGML_OP_UNARY) {
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_EXP:
tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_SILU:
tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
break;
Expand Down
20 changes: 20 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/exp.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

if (i >= p.KX) {
return;
}
data_d[i] = D_TYPE(exp(float(data_a[i])));
}
2 changes: 2 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,8 @@ void process_shaders() {

string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});

string_to_spv("exp_f16", "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("exp_f32", "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
Expand Down
Loading