Skip to content

Commit b6996a7

Browse files
committed
optimized sigmoid
basically use exp approximation Differential Revision: [D64156864](https://our.internmc.facebook.com/intern/diff/D64156864/) ghstack-source-id: 248160684 Pull Request resolved: #6241
1 parent 7036fd9 commit b6996a7

File tree

4 files changed

+109
-0
lines changed

4 files changed

+109
-0
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <cmath>
10+
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
#include <executorch/kernels/optimized/vec/vec.h>
13+
#include <executorch/kernels/optimized/vec/functional.h>
14+
15+
namespace torch {
16+
namespace executor {
17+
namespace native {
18+
19+
namespace {
20+
template <
21+
typename CTYPE_IN,
22+
typename CTYPE_OUT,
23+
typename std::enable_if<
24+
std::is_same_v<CTYPE_IN, CTYPE_OUT> &&
25+
!std::is_same_v<CTYPE_IN, exec_aten::Half> &&
26+
!std::is_same_v<CTYPE_OUT, exec_aten::BFloat16>,
27+
int>::type = 0>
28+
void sigmoid_data(
29+
const CTYPE_IN* in_data,
30+
const size_t numel,
31+
CTYPE_OUT* out_data) {
32+
using Vec = executorch::vec::Vectorized<CTYPE_IN>;
33+
executorch::vec::map<CTYPE_IN>(
34+
[](Vec x) {
35+
auto one_plus_exp = x.neg().exp() + Vec(1.0);
36+
return one_plus_exp.reciprocal();
37+
}, out_data, in_data, numel);
38+
}
39+
40+
template <
41+
typename CTYPE_IN,
42+
typename CTYPE_OUT,
43+
typename std::enable_if<
44+
!std::is_same_v<CTYPE_IN, CTYPE_OUT> ||
45+
std::is_same_v<CTYPE_IN, exec_aten::Half> ||
46+
std::is_same_v<CTYPE_IN, exec_aten::BFloat16> ||
47+
std::is_same_v<CTYPE_OUT, exec_aten::Half> ||
48+
std::is_same_v<CTYPE_OUT, exec_aten::BFloat16>,
49+
int>::type = 0>
50+
void sigmoid_data(
51+
const CTYPE_IN* in_data,
52+
const size_t numel,
53+
CTYPE_OUT* out_data) {
54+
for (size_t i = 0; i < numel; i++) {
55+
CTYPE_OUT xi = static_cast<CTYPE_OUT>(in_data[i]);
56+
out_data[i] = (1.0 / (1.0 + std::exp(-xi)));
57+
}
58+
}
59+
60+
}
61+
62+
using Tensor = exec_aten::Tensor;
63+
64+
Tensor& opt_sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
65+
(void)ctx;
66+
67+
ET_KERNEL_CHECK(
68+
ctx, in.scalar_type() != ScalarType::Bool, InvalidArgument, out);
69+
ET_KERNEL_CHECK(ctx, tensor_is_floating_type(out), InvalidArgument, out);
70+
71+
ET_KERNEL_CHECK(
72+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
73+
74+
// Resize for dynamic shape
75+
ET_KERNEL_CHECK_MSG(
76+
ctx,
77+
resize_tensor(out, in.sizes()) == Error::Ok,
78+
InvalidArgument,
79+
out,
80+
"Failed to resize output tensor.");
81+
82+
ScalarType in_type = in.scalar_type();
83+
ScalarType out_type = out.scalar_type();
84+
ET_SWITCH_REALHB_TYPES(in_type, ctx, "sigmoid.out", CTYPE_IN, [&]() {
85+
ET_SWITCH_FLOATH_TYPES(out_type, ctx, "sigmoid.out", CTYPE_OUT, [&]() {
86+
sigmoid_data<CTYPE_IN, CTYPE_OUT>(
87+
in.const_data_ptr<CTYPE_IN>(),
88+
in.numel(),
89+
out.mutable_data_ptr<CTYPE_OUT>());
90+
});
91+
});
92+
93+
return out;
94+
}
95+
96+
} // namespace native
97+
} // namespace executor
98+
} // namespace torch

kernels/optimized/cpu/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ _OPTIMIZED_ATEN_OPS = (
2525
],
2626
),
2727
op_target(name = "op_exp"),
28+
op_target(name = "op_sigmoid"),
2829
op_target(
2930
name = "op_gelu",
3031
deps = select({

kernels/optimized/optimized-oss.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@
3535
- arg_meta: null
3636
kernel_name: torch::executor::opt_exp_out
3737

38+
- op: sigmoid.out
39+
kernels:
40+
- arg_meta: null
41+
kernel_name: torch::executor::opt_exp_out
42+
3843
- op: le.Scalar_out
3944
kernels:
4045
- arg_meta: null

kernels/optimized/optimized.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@
3737
- arg_meta: null
3838
kernel_name: torch::executor::opt_exp_out
3939

40+
- op: sigmoid.out
41+
kernels:
42+
- arg_meta: null
43+
kernel_name: torch::executor::opt_sigmoid_out
44+
4045
- op: gelu.out
4146
kernels:
4247
- arg_meta: null

0 commit comments

Comments
 (0)