Skip to content

Commit adfaf17

Browse files
Abhi-hppfacebook-github-bot
authored andcommitted
aten.hardsigmoid.default in unary_ops (#5396)
Summary: Pull Request resolved: #5396 Implement aten.hardsigmoid in unary_ops Reviewed By: jorgep31415 Differential Revision: D62584402
1 parent 06c0fa3 commit adfaf17

File tree

4 files changed

+18
-7
lines changed

4 files changed

+18
-7
lines changed

backends/vulkan/runtime/graph/ops/glsl/activations.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,15 @@ vec4 hardshrink(vec4 tex, float lambda, float neg_lambda) {
3030
(vec4(greaterThan(tex, vec4(lambda))) +
3131
vec4(lessThan(tex, vec4(neg_lambda))));
3232
}
33+
34+
float hardsigmoid(float x) {
35+
return mix(float(x >= 0.0), x / 6 + 0.5, float(abs(x) <= 3.0));
36+
}
37+
38+
vec4 hardsigmoid(vec4 tex) {
39+
return vec4(
40+
hardsigmoid(tex.x),
41+
hardsigmoid(tex.y),
42+
hardsigmoid(tex.z),
43+
hardsigmoid(tex.w));
44+
}

backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,5 @@ unary_op:
3838
OPERATOR: hardshrink(X, A, B)
3939
- NAME: hardswish
4040
OPERATOR: hardswish(X)
41+
- NAME: hardsigmoid
42+
OPERATOR: hardsigmoid(X)

backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,6 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
114114
"hardshrink"); \
115115
}
116116

117-
#define DEFINE_HARDSWISH_FN(op_name) \
118-
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
119-
return add_unary_op_node( \
120-
graph, args[0], kDummyFloat, kDummyFloat, args[1], #op_name); \
121-
}
122-
123117
void gelu(ComputeGraph& graph, const std::vector<ValueRef>& args) {
124118
// args[1] is the `approximate` string
125119
// https://fburl.com/code/9omngmyo
@@ -140,7 +134,8 @@ DEFINE_CLAMP_FN(clamp);
140134
DEFINE_CLAMP_FN(hardtanh);
141135
DEFINE_RELU_FN(relu);
142136
DEFINE_HARDSHRINK_FN(hardshrink);
143-
DEFINE_HARDSWISH_FN(hardswish);
137+
DEFINE_ACTIVATION_FN(hardswish);
138+
DEFINE_ACTIVATION_FN(hardsigmoid);
144139

145140
REGISTER_OPERATORS {
146141
VK_REGISTER_OP(aten.abs.default, abs);
@@ -157,6 +152,7 @@ REGISTER_OPERATORS {
157152
VK_REGISTER_OP(aten.tanh.default, tanh);
158153
VK_REGISTER_OP(aten.hardshrink.default, hardshrink);
159154
VK_REGISTER_OP(aten.hardswish.default, hardswish);
155+
VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid);
160156
}
161157

162158
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,7 @@ def get_softmax_inputs():
879879
"aten.neg.default",
880880
"aten.cos.default",
881881
"aten.hardswish.default",
882+
"aten.hardsigmoid.default",
882883
]
883884
)
884885
def get_unary_ops_inputs():

0 commit comments

Comments
 (0)