|
8 | 8 |
|
9 | 9 | #include <cmath> |
10 | 10 |
|
11 | | -#include <executorch/kernels/portable/cpu/util/elementwise_util.h> |
12 | 11 | #include <executorch/kernels/portable/cpu/util/functional_util.h> |
13 | 12 | #include <executorch/runtime/kernel/kernel_includes.h> |
14 | 13 |
|
@@ -36,26 +35,21 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { |
36 | 35 | out, |
37 | 36 | "Failed to resize output tensor."); |
38 | 37 |
|
39 | | - ScalarType compute_type = |
40 | | - executorch::runtime::isFloatingType(in.scalar_type()) ? in.scalar_type() |
41 | | - : ScalarType::Float; |
42 | | - compute_type = utils::get_compute_type(compute_type); |
43 | | - |
44 | | - // @lint-ignore CLANGTIDY facebook-hte-CArray |
45 | | - static constexpr const char op_name[] = "sigmoid.out"; |
46 | | - |
47 | | - ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { |
48 | | - utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>( |
49 | | - [](const CTYPE_COMPUTE val_in) { |
50 | | - CTYPE_COMPUTE out_val = static_cast<CTYPE_COMPUTE>(1.0) / |
51 | | - (static_cast<CTYPE_COMPUTE>(1.0) + exp(-val_in)); |
52 | | - return out_val; |
53 | | - }, |
54 | | - ctx, |
55 | | - in, |
56 | | - utils::SupportedTensorDtypes::REALHBBF16, |
57 | | - out, |
58 | | - utils::SupportedTensorDtypes::FLOATHBF16); |
| 38 | + ScalarType in_type = in.scalar_type(); |
| 39 | + ScalarType out_type = out.scalar_type(); |
| 40 | + ET_SWITCH_REALHB_TYPES(in_type, ctx, "sigmoid.out", CTYPE_IN, [&]() { |
| 41 | + ET_SWITCH_FLOATH_TYPES(out_type, ctx, "sigmoid.out", CTYPE_OUT, [&]() { |
| 42 | + apply_unary_map_fn( |
| 43 | + [](const CTYPE_IN val_in) { |
| 44 | + // perform math in double to preserve precision |
| 45 | + double in_casted = static_cast<double>(val_in); |
| 46 | + double out_val = 1.0 / (1.0 + exp(-in_casted)); |
| 47 | + return static_cast<CTYPE_OUT>(out_val); |
| 48 | + }, |
| 49 | + in.const_data_ptr<CTYPE_IN>(), |
| 50 | + out.mutable_data_ptr<CTYPE_OUT>(), |
| 51 | + in.numel()); |
| 52 | + }); |
59 | 53 | }); |
60 | 54 |
|
61 | 55 | return out; |
|
0 commit comments