Skip to content

Commit 3972d7b

Browse files
committed
Revert "[Executorch][Portable] Dont upcast to double for sigmoid"
This reverts commit c242a59. Attempting to debug/fix #7019.
1 parent 0070680 commit 3972d7b

File tree

2 files changed

+15
-24
lines changed

2 files changed

+15
-24
lines changed

kernels/portable/cpu/op_sigmoid.cpp

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
#include <cmath>
1010

11-
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1211
#include <executorch/kernels/portable/cpu/util/functional_util.h>
1312
#include <executorch/runtime/kernel/kernel_includes.h>
1413

@@ -36,26 +35,21 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
3635
out,
3736
"Failed to resize output tensor.");
3837

39-
ScalarType compute_type =
40-
executorch::runtime::isFloatingType(in.scalar_type()) ? in.scalar_type()
41-
: ScalarType::Float;
42-
compute_type = utils::get_compute_type(compute_type);
43-
44-
// @lint-ignore CLANGTIDY facebook-hte-CArray
45-
static constexpr const char op_name[] = "sigmoid.out";
46-
47-
ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
48-
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
49-
[](const CTYPE_COMPUTE val_in) {
50-
CTYPE_COMPUTE out_val = static_cast<CTYPE_COMPUTE>(1.0) /
51-
(static_cast<CTYPE_COMPUTE>(1.0) + exp(-val_in));
52-
return out_val;
53-
},
54-
ctx,
55-
in,
56-
utils::SupportedTensorDtypes::REALHBBF16,
57-
out,
58-
utils::SupportedTensorDtypes::FLOATHBF16);
38+
ScalarType in_type = in.scalar_type();
39+
ScalarType out_type = out.scalar_type();
40+
ET_SWITCH_REALHB_TYPES(in_type, ctx, "sigmoid.out", CTYPE_IN, [&]() {
41+
ET_SWITCH_FLOATH_TYPES(out_type, ctx, "sigmoid.out", CTYPE_OUT, [&]() {
42+
apply_unary_map_fn(
43+
[](const CTYPE_IN val_in) {
44+
// perform math in double to preserve precision
45+
double in_casted = static_cast<double>(val_in);
46+
double out_val = 1.0 / (1.0 + exp(-in_casted));
47+
return static_cast<CTYPE_OUT>(out_val);
48+
},
49+
in.const_data_ptr<CTYPE_IN>(),
50+
out.mutable_data_ptr<CTYPE_OUT>(),
51+
in.numel());
52+
});
5953
});
6054

6155
return out;

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,9 +1080,6 @@ ATEN_OPS = (
10801080
name = "op_sigmoid",
10811081
deps = [
10821082
"//executorch/kernels/portable/cpu/util:functional_util",
1083-
"//executorch/kernels/portable/cpu/util:elementwise_util",
1084-
"//executorch/kernels/portable/cpu/util:broadcast_util",
1085-
"//executorch/kernels/portable/cpu/util:dtype_util",
10861083
],
10871084
),
10881085
op_target(

0 commit comments

Comments
 (0)