@@ -17,13 +17,17 @@ namespace executor {
1717namespace native {
1818
1919namespace {
20+
// Compile-time flag: true exactly when T is exec_aten::Half or
// exec_aten::BFloat16, false for every other type. The enable_if conditions
// on the sigmoid_data overloads test it for both CTYPE_IN and CTYPE_OUT: one
// overload requires it to be false for both types (and the types to match),
// the other is enabled when the types differ or either one is Half/BFloat16.
21+ template <typename T>
22+ constexpr bool is_half_or_bf16_v = std::is_same_v<T, exec_aten::Half> ||
23+ std::is_same_v<T, exec_aten::BFloat16>;
24+
2025template <
2126 typename CTYPE_IN,
2227 typename CTYPE_OUT,
2328 typename std::enable_if<
24- std::is_same_v<CTYPE_IN, CTYPE_OUT> &&
25- !std::is_same_v<CTYPE_IN, exec_aten::Half> &&
26- !std::is_same_v<CTYPE_OUT, exec_aten::BFloat16>,
29+ std::is_same_v<CTYPE_IN, CTYPE_OUT> && !is_half_or_bf16_v<CTYPE_IN> &&
30+ !is_half_or_bf16_v<CTYPE_OUT>,
2731 int >::type = 0 >
2832void sigmoid_data (
2933 const CTYPE_IN* in_data,
@@ -32,7 +36,7 @@ void sigmoid_data(
3236 using Vec = executorch::vec::Vectorized<CTYPE_IN>;
3337 executorch::vec::map<CTYPE_IN>(
3438 [](Vec x) {
35- auto one_plus_exp = x.neg ().exp () + Vec (1.0 );
39+ auto one_plus_exp = x.neg ().exp () + Vec (static_cast <CTYPE_IN>( 1.0 ) );
3640 return one_plus_exp.reciprocal ();
3741 },
3842 out_data,
@@ -44,19 +48,16 @@ template <
4448 typename CTYPE_IN,
4549 typename CTYPE_OUT,
4650 typename std::enable_if<
47- !std::is_same_v<CTYPE_IN, CTYPE_OUT> ||
48- std::is_same_v<CTYPE_IN, exec_aten::Half> ||
49- std::is_same_v<CTYPE_IN, exec_aten::BFloat16> ||
50- std::is_same_v<CTYPE_OUT, exec_aten::Half> ||
51- std::is_same_v<CTYPE_OUT, exec_aten::BFloat16>,
51+ !std::is_same_v<CTYPE_IN, CTYPE_OUT> || is_half_or_bf16_v<CTYPE_IN> ||
52+ is_half_or_bf16_v<CTYPE_OUT>,
5253 int >::type = 0 >
5354void sigmoid_data (
5455 const CTYPE_IN* in_data,
5556 const size_t numel,
5657 CTYPE_OUT* out_data) {
5758 for (size_t i = 0 ; i < numel; i++) {
5859 CTYPE_OUT xi = static_cast <CTYPE_OUT>(in_data[i]);
59- out_data[i] = (1.0 / (1.0 + std::exp (-xi)));
60+ out_data[i] = (1 .0f / (1 .0f + std::exp (-xi)));
6061 }
6162}
6263
0 commit comments