Skip to content

Commit 0014fb4

Browse files
authored
ggml vulkan: add hardsigmoid and hardswish operations (#15762)
1 parent 661ae31 commit 0014fb4

File tree

4 files changed

+70
-0
lines changed

4 files changed

+70
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,8 @@ struct vk_device_struct {
529529
vk_pipeline pipeline_relu[2];
530530
vk_pipeline pipeline_tanh[2];
531531
vk_pipeline pipeline_sigmoid[2];
532+
vk_pipeline pipeline_hardsigmoid[2];
533+
vk_pipeline pipeline_hardswish[2];
532534

533535
vk_pipeline pipeline_geglu[2];
534536
vk_pipeline pipeline_reglu[2];
@@ -3261,6 +3263,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
32613263
CREATE_UNARY(relu)
32623264
CREATE_UNARY(tanh)
32633265
CREATE_UNARY(sigmoid)
3266+
CREATE_UNARY(hardsigmoid)
3267+
CREATE_UNARY(hardswish)
32643268
#undef CREATE_UNARY
32653269

32663270
#define CREATE_GLU(name) \
@@ -7533,6 +7537,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
75337537
return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
75347538
case GGML_UNARY_OP_SIGMOID:
75357539
return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
7540+
case GGML_UNARY_OP_HARDSIGMOID:
7541+
return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];
7542+
case GGML_UNARY_OP_HARDSWISH:
7543+
return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
75367544
default:
75377545
break;
75387546
}
@@ -10201,6 +10209,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1020110209
case GGML_UNARY_OP_RELU:
1020210210
case GGML_UNARY_OP_TANH:
1020310211
case GGML_UNARY_OP_SIGMOID:
10212+
case GGML_UNARY_OP_HARDSIGMOID:
10213+
case GGML_UNARY_OP_HARDSWISH:
1020410214
break;
1020510215
default:
1020610216
return false;
@@ -10571,6 +10581,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1057110581
case GGML_UNARY_OP_RELU:
1057210582
case GGML_UNARY_OP_TANH:
1057310583
case GGML_UNARY_OP_SIGMOID:
10584+
case GGML_UNARY_OP_HARDSIGMOID:
10585+
case GGML_UNARY_OP_HARDSWISH:
1057410586
ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
1057510587
break;
1057610588
default:
@@ -10813,6 +10825,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1081310825
case GGML_UNARY_OP_RELU:
1081410826
case GGML_UNARY_OP_TANH:
1081510827
case GGML_UNARY_OP_SIGMOID:
10828+
case GGML_UNARY_OP_HARDSIGMOID:
10829+
case GGML_UNARY_OP_HARDSWISH:
1081610830
buf = tensor->buffer;
1081710831
break;
1081810832
default:
@@ -11764,6 +11778,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1176411778
case GGML_UNARY_OP_RELU:
1176511779
case GGML_UNARY_OP_TANH:
1176611780
case GGML_UNARY_OP_SIGMOID:
11781+
case GGML_UNARY_OP_HARDSIGMOID:
11782+
case GGML_UNARY_OP_HARDSWISH:
1176711783
return ggml_is_contiguous(op->src[0]) &&
1176811784
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1176911785
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -12580,6 +12596,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1258012596
case GGML_UNARY_OP_SIGMOID:
1258112597
tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
1258212598
break;
12599+
case GGML_UNARY_OP_HARDSIGMOID:
12600+
tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]);
12601+
break;
12602+
case GGML_UNARY_OP_HARDSWISH:
12603+
tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]);
12604+
break;
1258312605
default:
1258412606
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
1258512607
GGML_ABORT("fatal error");
ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#version 450
2+
3+
#include "generic_head.comp"
4+
#include "types.comp"
5+
6+
#extension GL_EXT_control_flow_attributes : enable
7+
8+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9+
10+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12+
13+
void main() { // Shader entry point: each invocation writes the hard-sigmoid of one element of data_a into data_d.
14+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; // Flatten the 3D invocation ID into one element index (262144 == 512 * 512; 512 matches local_size_x).
15+
16+
if (i >= p.KX) { // Invocations at or past p.KX do nothing. NOTE(review): `p` is not declared in this file; presumably it comes from generic_head.comp - confirm.
17+
return;
18+
}
19+
20+
const float x = float(data_a[i]); // Convert the A_TYPE input to float before doing the arithmetic.
21+
data_d[i] = D_TYPE(min(1.0f, max(0.0f, (x + 3.0f) / 6.0f))); // hardsigmoid(x): (x + 3) / 6 clamped to [0, 1], then converted to D_TYPE.
22+
}
ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#version 450
2+
3+
#include "generic_head.comp"
4+
#include "types.comp"
5+
6+
#extension GL_EXT_control_flow_attributes : enable
7+
8+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9+
10+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12+
13+
void main() { // Shader entry point: each invocation writes the hard-swish of one element of data_a into data_d.
14+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; // Flatten the 3D invocation ID into one element index (262144 == 512 * 512; 512 matches local_size_x).
15+
16+
if (i >= p.KX) { // Invocations at or past p.KX do nothing. NOTE(review): `p` is not declared in this file; presumably it comes from generic_head.comp - confirm.
17+
return;
18+
}
19+
20+
const float x = float(data_a[i]); // Convert the A_TYPE input to float before doing the arithmetic.
21+
data_d[i] = D_TYPE(x * min(1.0f, max(0.0f, (x + 3.0f) / 6.0f))); // hardswish(x): x times ((x + 3) / 6 clamped to [0, 1]), then converted to D_TYPE.
22+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,10 @@ void process_shaders() {
657657
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
658658
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
659659
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
660+
string_to_spv("hardsigmoid_f16","hardsigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
661+
string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
662+
string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
663+
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
660664

661665
for (auto rte : {false, true}) {
662666
std::string suffix = rte ? "_rte" : "";

0 commit comments

Comments
 (0)