|  | 
| 8 | 8 | 
 | 
| 9 | 9 | #include <cmath> | 
| 10 | 10 | 
 | 
|  | 11 | +#include <executorch/kernels/portable/cpu/util/elementwise_util.h> | 
| 11 | 12 | #include <executorch/kernels/portable/cpu/util/functional_util.h> | 
| 12 | 13 | #include <executorch/runtime/kernel/kernel_includes.h> | 
| 13 | 14 | 
 | 
| @@ -35,21 +36,26 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { | 
| 35 | 36 |       out, | 
| 36 | 37 |       "Failed to resize output tensor."); | 
| 37 | 38 | 
 | 
| 38 |  | -  ScalarType in_type = in.scalar_type(); | 
| 39 |  | -  ScalarType out_type = out.scalar_type(); | 
| 40 |  | -  ET_SWITCH_REALHB_TYPES(in_type, ctx, "sigmoid.out", CTYPE_IN, [&]() { | 
| 41 |  | -    ET_SWITCH_FLOATH_TYPES(out_type, ctx, "sigmoid.out", CTYPE_OUT, [&]() { | 
| 42 |  | -      apply_unary_map_fn( | 
| 43 |  | -          [](const CTYPE_IN val_in) { | 
| 44 |  | -            // perform math in double to preserve precision | 
| 45 |  | -            double in_casted = static_cast<double>(val_in); | 
| 46 |  | -            double out_val = 1.0 / (1.0 + exp(-in_casted)); | 
| 47 |  | -            return static_cast<CTYPE_OUT>(out_val); | 
| 48 |  | -          }, | 
| 49 |  | -          in.const_data_ptr<CTYPE_IN>(), | 
| 50 |  | -          out.mutable_data_ptr<CTYPE_OUT>(), | 
| 51 |  | -          in.numel()); | 
| 52 |  | -    }); | 
|  | 39 | +  ScalarType compute_type = | 
|  | 40 | +      executorch::runtime::isFloatingType(in.scalar_type()) ? in.scalar_type() | 
|  | 41 | +                                                            : ScalarType::Float; | 
|  | 42 | +  compute_type = utils::get_compute_type(compute_type); | 
|  | 43 | + | 
|  | 44 | +  // @lint-ignore CLANGTIDY facebook-hte-CArray | 
|  | 45 | +  static constexpr const char op_name[] = "sigmoid.out"; | 
|  | 46 | + | 
|  | 47 | +  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { | 
|  | 48 | +    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>( | 
|  | 49 | +        [](const CTYPE_COMPUTE val_in) { | 
|  | 50 | +          CTYPE_COMPUTE out_val = static_cast<CTYPE_COMPUTE>(1.0) / | 
|  | 51 | +              (static_cast<CTYPE_COMPUTE>(1.0) + exp(-val_in)); | 
|  | 52 | +          return out_val; | 
|  | 53 | +        }, | 
|  | 54 | +        ctx, | 
|  | 55 | +        in, | 
|  | 56 | +        utils::SupportedTensorDtypes::REALHBBF16, | 
|  | 57 | +        out, | 
|  | 58 | +        utils::SupportedTensorDtypes::FLOATHBF16); | 
| 53 | 59 |   }); | 
| 54 | 60 | 
 | 
| 55 | 61 |   return out; | 
|  | 
0 commit comments