Skip to content

Commit ee4aff1

Browse files
Abhi-hppfacebook-github-bot
authored andcommitted
aten.hardsigmoid.default in unary_ops (#5396)
Summary: Pull Request resolved: #5396 Implement aten.hardsigmoid in unary_ops Differential Revision: D62584402
1 parent f7954f6 commit ee4aff1

File tree

4 files changed

+23
-0
lines changed

4 files changed

+23
-0
lines changed

backends/vulkan/runtime/graph/ops/glsl/activations.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,15 @@ vec4 hardshrink(vec4 tex, float lambda, float neg_lambda) {
3030
(vec4(greaterThan(tex, vec4(lambda))) +
3131
vec4(lessThan(tex, vec4(neg_lambda))));
3232
}
33+
34+
float hardsigmoid(float x) {
35+
return mix(float(x >= 0.0), x / 6 + 0.5, float(abs(x) <= 3.0));
36+
}
37+
38+
vec4 hardsigmoid(vec4 tex) {
39+
return vec4(
40+
hardsigmoid(tex.x),
41+
hardsigmoid(tex.y),
42+
hardsigmoid(tex.z),
43+
hardsigmoid(tex.z));
44+
}

backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,5 @@ unary_op:
3838
OPERATOR: hardshrink(X, A, B)
3939
- NAME: hardswish
4040
OPERATOR: hardswish(X)
41+
- NAME: hardsigmoid
42+
OPERATOR: hardsigmoid(X)

backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
120120
graph, args[0], kDummyFloat, kDummyFloat, args[1], #op_name); \
121121
}
122122

123+
#define DEFINE_HARDSIGMOID_FN(op_name) \
124+
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
125+
return add_unary_op_node( \
126+
graph, args[0], kDummyFloat, kDummyFloat, args[1], #op_name); \
127+
}
128+
123129
void gelu(ComputeGraph& graph, const std::vector<ValueRef>& args) {
124130
// args[1] is the `approximate` string
125131
// https://fburl.com/code/9omngmyo
@@ -141,6 +147,7 @@ DEFINE_CLAMP_FN(hardtanh);
141147
DEFINE_RELU_FN(relu);
142148
DEFINE_HARDSHRINK_FN(hardshrink);
143149
DEFINE_HARDSWISH_FN(hardswish);
150+
DEFINE_HARDSIGMOID_FN(hardsigmoid);
144151

145152
REGISTER_OPERATORS {
146153
VK_REGISTER_OP(aten.abs.default, abs);
@@ -157,6 +164,7 @@ REGISTER_OPERATORS {
157164
VK_REGISTER_OP(aten.tanh.default, tanh);
158165
VK_REGISTER_OP(aten.hardshrink.default, hardshrink);
159166
VK_REGISTER_OP(aten.hardswish.default, hardswish);
167+
VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid);
160168
}
161169

162170
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,7 @@ def get_softmax_inputs():
878878
"aten.neg.default",
879879
"aten.cos.default",
880880
"aten.hardswish.default",
881+
"aten.hardsigmoid.default",
881882
]
882883
)
883884
def get_unary_ops_inputs():

0 commit comments

Comments
 (0)