Skip to content

Commit 6e5ce8a

Browse files
Abhi-hppfacebook-github-bot
authored andcommitted
aten.hardsigmoid.default in unary_ops (#5396)
Summary: Pull Request resolved: #5396 Implement aten.hardsigmoid in unary_ops Reviewed By: jorgep31415 Differential Revision: D62584402
1 parent 06c0fa3 commit 6e5ce8a

File tree

4 files changed

+19
-2
lines changed

4 files changed

+19
-2
lines changed

backends/vulkan/runtime/graph/ops/glsl/activations.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,15 @@ vec4 hardshrink(vec4 tex, float lambda, float neg_lambda) {
3030
(vec4(greaterThan(tex, vec4(lambda))) +
3131
vec4(lessThan(tex, vec4(neg_lambda))));
3232
}
33+
34+
float hardsigmoid(float x) {
35+
return mix(float(x >= 0.0), x / 6 + 0.5, float(abs(x) <= 3.0));
36+
}
37+
38+
vec4 hardsigmoid(vec4 tex) {
39+
return vec4(
40+
hardsigmoid(tex.x),
41+
hardsigmoid(tex.y),
42+
hardsigmoid(tex.z),
43+
hardsigmoid(tex.w));
44+
}

backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,5 @@ unary_op:
3838
OPERATOR: hardshrink(X, A, B)
3939
- NAME: hardswish
4040
OPERATOR: hardswish(X)
41+
- NAME: hardsigmoid
42+
OPERATOR: hardsigmoid(X)

backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
114114
"hardshrink"); \
115115
}
116116

117-
#define DEFINE_HARDSWISH_FN(op_name) \
117+
#define DEFINE_ACTIVATION_FN(op_name) \
118118
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
119119
return add_unary_op_node( \
120120
graph, args[0], kDummyFloat, kDummyFloat, args[1], #op_name); \
@@ -140,7 +140,8 @@ DEFINE_CLAMP_FN(clamp);
140140
DEFINE_CLAMP_FN(hardtanh);
141141
DEFINE_RELU_FN(relu);
142142
DEFINE_HARDSHRINK_FN(hardshrink);
143-
DEFINE_HARDSWISH_FN(hardswish);
143+
DEFINE_ACTIVATION_FN(hardswish);
144+
DEFINE_ACTIVATION_FN(hardsigmoid);
144145

145146
REGISTER_OPERATORS {
146147
VK_REGISTER_OP(aten.abs.default, abs);
@@ -157,6 +158,7 @@ REGISTER_OPERATORS {
157158
VK_REGISTER_OP(aten.tanh.default, tanh);
158159
VK_REGISTER_OP(aten.hardshrink.default, hardshrink);
159160
VK_REGISTER_OP(aten.hardswish.default, hardswish);
161+
VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid);
160162
}
161163

162164
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,7 @@ def get_softmax_inputs():
879879
"aten.neg.default",
880880
"aten.cos.default",
881881
"aten.hardswish.default",
882+
"aten.hardsigmoid.default",
882883
]
883884
)
884885
def get_unary_ops_inputs():

0 commit comments

Comments
 (0)