From adfaf17e386a0b80bc1fa32ebf4f40926ac94918 Mon Sep 17 00:00:00 2001 From: Abhishek Chandra Date: Tue, 17 Sep 2024 15:13:26 -0700 Subject: [PATCH] aten.hardsigmoid.default in unary_ops (#5396) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/5396 Implement aten.hardsigmoid in unary_ops Reviewed By: jorgep31415 Differential Revision: D62584402 --- backends/vulkan/runtime/graph/ops/glsl/activations.h | 12 ++++++++++++ backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml | 2 ++ backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp | 10 +++------- backends/vulkan/test/op_tests/cases.py | 1 + 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/activations.h b/backends/vulkan/runtime/graph/ops/glsl/activations.h index c5ee3b20855..32ec3894687 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/activations.h +++ b/backends/vulkan/runtime/graph/ops/glsl/activations.h @@ -30,3 +30,15 @@ vec4 hardshrink(vec4 tex, float lambda, float neg_lambda) { (vec4(greaterThan(tex, vec4(lambda))) + vec4(lessThan(tex, vec4(neg_lambda)))); } + +float hardsigmoid(float x) { + return mix(float(x >= 0.0), x / 6 + 0.5, float(abs(x) <= 3.0)); +} + +vec4 hardsigmoid(vec4 tex) { + return vec4( + hardsigmoid(tex.x), + hardsigmoid(tex.y), + hardsigmoid(tex.z), + hardsigmoid(tex.w)); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index eb05b10b108..2b9f0032f41 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -38,3 +38,5 @@ unary_op: OPERATOR: hardshrink(X, A, B) - NAME: hardswish OPERATOR: hardswish(X) + - NAME: hardsigmoid + OPERATOR: hardsigmoid(X) diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index 075c0bc923a..2fede692ac1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -114,12 +114,6 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) { "hardshrink"); \ } -#define DEFINE_HARDSWISH_FN(op_name) \ - void op_name(ComputeGraph& graph, const std::vector& args) { \ - return add_unary_op_node( \ - graph, args[0], kDummyFloat, kDummyFloat, args[1], #op_name); \ - } - void gelu(ComputeGraph& graph, const std::vector& args) { // args[1] is the `approximate` string // https://fburl.com/code/9omngmyo @@ -140,7 +134,8 @@ DEFINE_CLAMP_FN(clamp); DEFINE_CLAMP_FN(hardtanh); DEFINE_RELU_FN(relu); DEFINE_HARDSHRINK_FN(hardshrink); -DEFINE_HARDSWISH_FN(hardswish); +DEFINE_ACTIVATION_FN(hardswish); +DEFINE_ACTIVATION_FN(hardsigmoid); REGISTER_OPERATORS { VK_REGISTER_OP(aten.abs.default, abs); @@ -157,6 +152,7 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.tanh.default, tanh); VK_REGISTER_OP(aten.hardshrink.default, hardshrink); VK_REGISTER_OP(aten.hardswish.default, hardswish); + VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index c839db274c0..f2276b0247c 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -879,6 +879,7 @@ def get_softmax_inputs(): "aten.neg.default", "aten.cos.default", "aten.hardswish.default", + "aten.hardsigmoid.default", ] ) def get_unary_ops_inputs():