Skip to content

Commit 47bd3d2

Browse files
committed
vulkan: add ABS op
Signed-off-by: Giuseppe Scrivano <[email protected]>
1 parent 52842ec commit 47bd3d2

File tree

4 files changed

+32
-1
lines changed

4 files changed

+32
-1
lines changed

docs/ops.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Legend:
1414

1515
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
1616
|-----------|------|------|------|------|------|------|------|------|------|
17-
| ABS |||| 🟡 | 🟡 || 🟡 | ||
17+
| ABS |||| 🟡 | 🟡 || 🟡 | ||
1818
| ACC ||||||||||
1919
| ADD ||||| 🟡 | 🟡 ||||
2020
| ADD1 ||||||||||

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,7 @@ struct vk_device_struct {
659659
vk_pipeline pipeline_sigmoid[2];
660660
vk_pipeline pipeline_hardsigmoid[2];
661661
vk_pipeline pipeline_hardswish[2];
662+
vk_pipeline pipeline_abs[2];
662663

663664
vk_pipeline pipeline_geglu[2];
664665
vk_pipeline pipeline_reglu[2];
@@ -3735,6 +3736,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
37353736
CREATE_UNARY(sigmoid)
37363737
CREATE_UNARY(hardsigmoid)
37373738
CREATE_UNARY(hardswish)
3739+
CREATE_UNARY(abs)
37383740
#undef CREATE_UNARY
37393741

37403742
#define CREATE_UNARY_RTE(name) \
@@ -8334,6 +8336,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
83348336
return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];
83358337
case GGML_UNARY_OP_HARDSWISH:
83368338
return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
8339+
case GGML_UNARY_OP_ABS:
8340+
return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16];
83378341
default:
83388342
break;
83398343
}
@@ -11286,6 +11290,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1128611290
case GGML_UNARY_OP_SIGMOID:
1128711291
case GGML_UNARY_OP_HARDSIGMOID:
1128811292
case GGML_UNARY_OP_HARDSWISH:
11293+
case GGML_UNARY_OP_ABS:
1128911294
break;
1129011295
default:
1129111296
return false;
@@ -11617,6 +11622,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1161711622
case GGML_UNARY_OP_SIGMOID:
1161811623
case GGML_UNARY_OP_HARDSIGMOID:
1161911624
case GGML_UNARY_OP_HARDSWISH:
11625+
case GGML_UNARY_OP_ABS:
1162011626
ggml_vk_unary(ctx, compute_ctx, src0, node);
1162111627
break;
1162211628
default:
@@ -11888,6 +11894,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1188811894
case GGML_UNARY_OP_SIGMOID:
1188911895
case GGML_UNARY_OP_HARDSIGMOID:
1189011896
case GGML_UNARY_OP_HARDSWISH:
11897+
case GGML_UNARY_OP_ABS:
1189111898
buf = tensor->buffer;
1189211899
break;
1189311900
default:
@@ -13383,6 +13390,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1338313390
case GGML_UNARY_OP_SIGMOID:
1338413391
case GGML_UNARY_OP_HARDSIGMOID:
1338513392
case GGML_UNARY_OP_HARDSWISH:
13393+
case GGML_UNARY_OP_ABS:
1338613394
return ggml_is_contiguous(op->src[0]) &&
1338713395
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1338813396
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#version 450
2+
3+
#include "generic_head.glsl"
4+
#include "types.glsl"
5+
6+
#extension GL_EXT_control_flow_attributes : enable
7+
8+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9+
10+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12+
13+
void main() {
14+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
15+
16+
if (i >= p.KX) {
17+
return;
18+
}
19+
20+
data_d[i] = D_TYPE(abs(float(data_a[i])));
21+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,8 @@ void process_shaders() {
837837
string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
838838
string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
839839
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
840+
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
841+
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
840842

841843
for (auto rte : {false, true}) {
842844
std::string suffix = rte ? "_rte" : "";

0 commit comments

Comments
 (0)