diff --git a/backends/cadence/fusion_g3/operators/op_exp.cpp b/backends/cadence/fusion_g3/operators/op_exp.cpp index 41b5d70b222..84d2ac0b94e 100644 --- a/backends/cadence/fusion_g3/operators/op_exp.cpp +++ b/backends/cadence/fusion_g3/operators/op_exp.cpp @@ -60,7 +60,7 @@ Tensor& exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { return out; } else { return torch::executor::native::internal:: - unary_ufunc_realhbbf16_to_floathbf16(std::exp, ctx, in, out); + unary_ufunc_realhbbf16_to_floathbf16(std::exp, std::exp, ctx, in, out); } } diff --git a/backends/cadence/fusion_g3/operators/op_rsqrt.cpp b/backends/cadence/fusion_g3/operators/op_rsqrt.cpp index 5a869fadd09..59f9094aa29 100644 --- a/backends/cadence/fusion_g3/operators/op_rsqrt.cpp +++ b/backends/cadence/fusion_g3/operators/op_rsqrt.cpp @@ -27,7 +27,8 @@ namespace native { namespace { -double rsqrt(double x) { +template +T rsqrt(T x) { return 1.0 / std::sqrt(x); } @@ -61,11 +62,11 @@ Tensor& rsqrt_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { return out; } else { return torch::executor::native::internal:: - unary_ufunc_realhbbf16_to_floathbf16(rsqrt, ctx, in, out); + unary_ufunc_realhbbf16_to_floathbf16(rsqrt, rsqrt, ctx, in, out); } } } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence \ No newline at end of file +} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_sqrt.cpp b/backends/cadence/fusion_g3/operators/op_sqrt.cpp index c6a5a29fab8..4b0de889a39 100644 --- a/backends/cadence/fusion_g3/operators/op_sqrt.cpp +++ b/backends/cadence/fusion_g3/operators/op_sqrt.cpp @@ -55,7 +55,8 @@ Tensor& sqrt_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { return out; } else { return torch::executor::native::internal:: - unary_ufunc_realhbbf16_to_floathbf16(std::sqrt, ctx, in, out); + unary_ufunc_realhbbf16_to_floathbf16( + std::sqrt, std::sqrt, ctx, in, out); } } diff --git a/backends/cadence/fusion_g3/operators/op_tanh.cpp b/backends/cadence/fusion_g3/operators/op_tanh.cpp index 05f39f1361e..14a21066632 100644 --- a/backends/cadence/fusion_g3/operators/op_tanh.cpp +++ b/backends/cadence/fusion_g3/operators/op_tanh.cpp @@ -55,7 +55,8 @@ Tensor& tanh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { return out; } else { return torch::executor::native::internal:: - unary_ufunc_realhbbf16_to_floathbf16(std::tanh, ctx, in, out); + unary_ufunc_realhbbf16_to_floathbf16( + std::tanh, std::tanh, ctx, in, out); } } diff --git a/backends/cadence/hifi/operators/op_rsqrt.cpp b/backends/cadence/hifi/operators/op_rsqrt.cpp index 885c26723ae..81a20398087 100644 --- a/backends/cadence/hifi/operators/op_rsqrt.cpp +++ b/backends/cadence/hifi/operators/op_rsqrt.cpp @@ -21,7 +21,8 @@ namespace HiFi { namespace native { namespace { -double rsqrt(double x) { +template +T rsqrt(T x) { return 1.0 / std::sqrt(x); } @@ -46,7 +47,7 @@ Tensor& rsqrt_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) { } return torch::executor::native::internal:: - unary_ufunc_realhbbf16_to_floathbf16(rsqrt, ctx, in, out); + unary_ufunc_realhbbf16_to_floathbf16(rsqrt, rsqrt, ctx, in, out); } } // namespace native diff --git a/backends/cadence/hifi/operators/op_tanh.cpp b/backends/cadence/hifi/operators/op_tanh.cpp index 3fdd3111ef8..1132efee3d8 100644 --- a/backends/cadence/hifi/operators/op_tanh.cpp +++ b/backends/cadence/hifi/operators/op_tanh.cpp @@ -35,10 +35,10 @@ Tensor& tanh_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) { } return torch::executor::native::internal:: - unary_ufunc_realhbbf16_to_floathbf16(std::tanh, ctx, in, out); + unary_ufunc_realhbbf16_to_floathbf16(std::tanh, std::tanh, ctx, in, out); } } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence \ No newline at end of file +} // namespace cadence diff --git a/kernels/portable/cpu/op_acos.cpp b/kernels/portable/cpu/op_acos.cpp index dac3b1546f3..3fc30473fe5 100644 --- a/kernels/portable/cpu/op_acos.cpp +++ b/kernels/portable/cpu/op_acos.cpp @@ -14,10 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& acos_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16( - std::acos, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(acos_out, std::acos) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_acosh.cpp b/kernels/portable/cpu/op_acosh.cpp index 77f7edf4c5d..1d38655b543 100644 --- a/kernels/portable/cpu/op_acosh.cpp +++ b/kernels/portable/cpu/op_acosh.cpp @@ -14,10 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& acosh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16( - std::acosh, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(acosh_out, std::acosh) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_asin.cpp b/kernels/portable/cpu/op_asin.cpp index 6affa6e4122..cdadf8c8bec 100644 --- a/kernels/portable/cpu/op_asin.cpp +++ b/kernels/portable/cpu/op_asin.cpp @@ -14,10 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& asin_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16( - std::asin, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(asin_out, std::asin) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_asinh.cpp b/kernels/portable/cpu/op_asinh.cpp index bce8dcf6d5a..6c96510ac8f 100644 --- a/kernels/portable/cpu/op_asinh.cpp +++ b/kernels/portable/cpu/op_asinh.cpp @@ -14,10 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& asinh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16( - std::asinh, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(asinh_out, std::asinh) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_atan.cpp b/kernels/portable/cpu/op_atan.cpp index 23549627a3b..6c6c6df38c4 100644 --- a/kernels/portable/cpu/op_atan.cpp +++ b/kernels/portable/cpu/op_atan.cpp @@ -14,10 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& atan_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16( - std::atan, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(atan_out, std::atan) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_atanh.cpp b/kernels/portable/cpu/op_atanh.cpp index 13e6e8ca141..df52330015d 100644 --- a/kernels/portable/cpu/op_atanh.cpp +++ b/kernels/portable/cpu/op_atanh.cpp @@ -14,10 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& atanh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16( - std::atanh, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(atanh_out, std::atanh) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_ceil.cpp b/kernels/portable/cpu/op_ceil.cpp index 5aa09ba0084..a39d0aa4f3b 100644 --- a/kernels/portable/cpu/op_ceil.cpp +++ b/kernels/portable/cpu/op_ceil.cpp @@ -16,9 +16,7 @@ namespace native { using executorch::aten::Tensor; -Tensor& ceil_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbf16(std::ceil, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBF16(ceil_out, std::ceil) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_cos.cpp b/kernels/portable/cpu/op_cos.cpp index e536060d162..9a2bb2d610d 100644 --- a/kernels/portable/cpu/op_cos.cpp +++ b/kernels/portable/cpu/op_cos.cpp @@ -14,9 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& cos_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16(std::cos, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(cos_out, std::cos) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_cosh.cpp b/kernels/portable/cpu/op_cosh.cpp index e622bbe6fcd..01de2d81fe9 100644 --- a/kernels/portable/cpu/op_cosh.cpp +++ b/kernels/portable/cpu/op_cosh.cpp @@ -14,10 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& cosh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16( - std::cosh, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(cosh_out, std::cosh) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_erf.cpp b/kernels/portable/cpu/op_erf.cpp index 6897bcda95b..30c78e130dc 100644 --- a/kernels/portable/cpu/op_erf.cpp +++ b/kernels/portable/cpu/op_erf.cpp @@ -14,9 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& erf_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16(std::erf, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(erf_out, std::erf) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_exp.cpp b/kernels/portable/cpu/op_exp.cpp index cbfc8924cb0..c4a120b328f 100644 --- a/kernels/portable/cpu/op_exp.cpp +++ b/kernels/portable/cpu/op_exp.cpp @@ -14,9 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16(std::exp, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(exp_out, std::exp) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_expm1.cpp b/kernels/portable/cpu/op_expm1.cpp index f2d49f615b1..0a6cc86ffe7 100644 --- a/kernels/portable/cpu/op_expm1.cpp +++ b/kernels/portable/cpu/op_expm1.cpp @@ -14,10 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& expm1_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16( - std::expm1, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(expm1_out, std::expm1) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_floor.cpp b/kernels/portable/cpu/op_floor.cpp index 4061722bd27..a5bb9c740e0 100644 --- a/kernels/portable/cpu/op_floor.cpp +++ b/kernels/portable/cpu/op_floor.cpp @@ -16,9 +16,7 @@ namespace native { using executorch::aten::Tensor; -Tensor& floor_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbf16(std::floor, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBF16(floor_out, std::floor) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_isinf.cpp b/kernels/portable/cpu/op_isinf.cpp index 92d1e563a2e..ac0c19f0f7a 100644 --- a/kernels/portable/cpu/op_isinf.cpp +++ b/kernels/portable/cpu/op_isinf.cpp @@ -14,12 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& isinf_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - // Lambda is syntactic sugar needed to workaround compilation on some older - // non-compatible distros where isnan is returning int rather than bool - return internal::unary_ufunc_realhb_to_bool( - [](double x) -> bool { return std::isinf(x); }, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_BOOL(isinf_out, std::isinf) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_isnan.cpp b/kernels/portable/cpu/op_isnan.cpp index 51e189992ee..dad38a2619a 100644 --- a/kernels/portable/cpu/op_isnan.cpp +++ b/kernels/portable/cpu/op_isnan.cpp @@ -14,12 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& isnan_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - // Lambda is syntactic sugar needed to workaround compilation on some older - // non-compatible distros where isnan is returning int rather than bool - return internal::unary_ufunc_realhb_to_bool( - [](double x) -> bool { return std::isnan(x); }, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_BOOL(isnan_out, std::isnan) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_log.cpp b/kernels/portable/cpu/op_log.cpp index 8a36bce8c49..51300ee7441 100644 --- a/kernels/portable/cpu/op_log.cpp +++ b/kernels/portable/cpu/op_log.cpp @@ -14,9 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& log_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16(std::log, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(log_out, std::log) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_log10.cpp b/kernels/portable/cpu/op_log10.cpp index 89f9b672476..f159c10eeaa 100644 --- a/kernels/portable/cpu/op_log10.cpp +++ b/kernels/portable/cpu/op_log10.cpp @@ -14,10 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& log10_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16( - std::log10, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(log10_out, std::log10) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_log1p.cpp b/kernels/portable/cpu/op_log1p.cpp index 2daa31e37ff..1d8ed064152 100644 --- a/kernels/portable/cpu/op_log1p.cpp +++ b/kernels/portable/cpu/op_log1p.cpp @@ -14,10 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& log1p_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16( - std::log1p, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(log1p_out, std::log1p) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_log2.cpp b/kernels/portable/cpu/op_log2.cpp index 4d7406832e4..88c4776e001 100644 --- a/kernels/portable/cpu/op_log2.cpp +++ b/kernels/portable/cpu/op_log2.cpp @@ -14,10 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& log2_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16( - std::log2, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(log2_out, std::log2) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_reciprocal.cpp b/kernels/portable/cpu/op_reciprocal.cpp index f22f9883858..4713ce756bd 100644 --- a/kernels/portable/cpu/op_reciprocal.cpp +++ b/kernels/portable/cpu/op_reciprocal.cpp @@ -14,17 +14,14 @@ namespace executor { namespace native { namespace { -double reciprocal(double x) { +template +T reciprocal(T x) { return 1.0 / x; } } // namespace -Tensor& -reciprocal_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16( - reciprocal, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(reciprocal_out, reciprocal) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_rsqrt.cpp b/kernels/portable/cpu/op_rsqrt.cpp index 19c4c6c1a57..c2a47ce4c26 100644 --- a/kernels/portable/cpu/op_rsqrt.cpp +++ b/kernels/portable/cpu/op_rsqrt.cpp @@ -14,15 +14,14 @@ namespace executor { namespace native { namespace { -double rsqrt(double x) { +template +T rsqrt(T x) { return 1.0 / std::sqrt(x); } } // namespace -Tensor& rsqrt_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16(rsqrt, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(rsqrt_out, rsqrt) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_sin.cpp b/kernels/portable/cpu/op_sin.cpp index ad65c4be18b..a763c216353 100644 --- a/kernels/portable/cpu/op_sin.cpp +++ b/kernels/portable/cpu/op_sin.cpp @@ -14,9 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& sin_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16(std::sin, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(sin_out, std::sin) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_sinh.cpp b/kernels/portable/cpu/op_sinh.cpp index 21666392392..363936e586e 100644 --- a/kernels/portable/cpu/op_sinh.cpp +++ b/kernels/portable/cpu/op_sinh.cpp @@ -14,10 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& sinh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16( - std::sinh, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(sinh_out, std::sinh) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_sqrt.cpp b/kernels/portable/cpu/op_sqrt.cpp index bd2075f5b04..ad31580f5d4 100644 --- a/kernels/portable/cpu/op_sqrt.cpp +++ b/kernels/portable/cpu/op_sqrt.cpp @@ -14,10 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& sqrt_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16( - std::sqrt, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(sqrt_out, std::sqrt) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_tan.cpp b/kernels/portable/cpu/op_tan.cpp index a2b921d5146..453cfba5638 100644 --- a/kernels/portable/cpu/op_tan.cpp +++ b/kernels/portable/cpu/op_tan.cpp @@ -14,9 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& tan_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16(std::tan, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(tan_out, std::tan) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_tanh.cpp b/kernels/portable/cpu/op_tanh.cpp index ae9f93dc62c..7de7c3adc75 100644 --- a/kernels/portable/cpu/op_tanh.cpp +++ b/kernels/portable/cpu/op_tanh.cpp @@ -14,10 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& tanh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbbf16_to_floathbf16( - std::tanh, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(tanh_out, std::tanh) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_trunc.cpp b/kernels/portable/cpu/op_trunc.cpp index 2d70a3b1724..edc717b2ade 100644 --- a/kernels/portable/cpu/op_trunc.cpp +++ b/kernels/portable/cpu/op_trunc.cpp @@ -14,9 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& trunc_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbf16(std::trunc, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBF16(trunc_out, std::trunc) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/pattern/pattern.h b/kernels/portable/cpu/pattern/pattern.h index 2d4b2ac509c..adebeeea9cd 100644 --- a/kernels/portable/cpu/pattern/pattern.h +++ b/kernels/portable/cpu/pattern/pattern.h @@ -60,23 +60,35 @@ namespace internal { * the input tensor element-wise. */ Tensor& unary_ufunc_realhbf16( - double (*fn)(double), + float (*fn_float)(float), + double (*fn_double)(double), KernelRuntimeContext& ctx, const Tensor& in, Tensor& out); +#define DEFINE_UNARY_UFUNC_REALHBF16(op_name, fn) \ + Tensor& op_name(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { \ + return internal::unary_ufunc_realhbf16(fn, fn, ctx, in, out); \ + } + /** * Implements an op pattern for ops that take a single input tensor of any - * realhb dtye (real, half and boolean), no additional arguments, and outputs a - * boolean tensor of the same size. The function fn specifies the math + * realhbbf16 dtype (real/half/bool/bfloat16), no additional arguments, and + * outputs a boolean tensor of the same size. The function fn specifies the math * operation which is applied to the input tensor element-wise. */ -Tensor& unary_ufunc_realhb_to_bool( - bool (*fn)(double), +Tensor& unary_ufunc_realhbbf16_to_bool( + bool (*fn_float)(float), + bool (*fn_double)(double), KernelRuntimeContext& ctx, const Tensor& in, Tensor& out); +#define DEFINE_UNARY_UFUNC_REALHBBF16_TO_BOOL(op_name, fn) \ + Tensor& op_name(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { \ + return internal::unary_ufunc_realhbbf16_to_bool(fn, fn, ctx, in, out); \ + } + /** * Implements an op pattern for ops that take a single input tensor of any * realhbbf16 dtype (real/half/bool/bfloat16), no additional arguments, and @@ -84,11 +96,18 @@ Tensor& unary_ufunc_realhb_to_bool( * the math operation which is applied to the input tensor element-wise. */ Tensor& unary_ufunc_realhbbf16_to_floathbf16( - double (*fn)(double), + float (*fn_float)(float), + double (*fn_double)(double), KernelRuntimeContext& ctx, const Tensor& in, Tensor& out); +#define DEFINE_UNARY_UFUNC_REALHBBF16_TO_FLOATHBF16(op_name, fn) \ + Tensor& op_name(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { \ + return internal::unary_ufunc_realhbbf16_to_floathbf16( \ + fn, fn, ctx, in, out); \ + } + } // namespace internal } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/pattern/targets.bzl b/kernels/portable/cpu/pattern/targets.bzl index 5fc73ccd911..2217efe5e7f 100644 --- a/kernels/portable/cpu/pattern/targets.bzl +++ b/kernels/portable/cpu/pattern/targets.bzl @@ -50,7 +50,7 @@ def define_common_targets(): runtime.cxx_library( name = "pattern", srcs = [ - "unary_ufunc_realhb_to_bool.cpp", + "unary_ufunc_realhbbf16_to_bool.cpp", "unary_ufunc_realhbbf16_to_floathbf16.cpp", "unary_ufunc_realhbf16.cpp", ], diff --git a/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp b/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_bool.cpp similarity index 73% rename from kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp rename to kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_bool.cpp index 367137ad02c..58c814dc4ca 100644 --- a/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp +++ b/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_bool.cpp @@ -15,13 +15,12 @@ namespace executor { namespace native { namespace internal { -Tensor& unary_ufunc_realhb_to_bool( - bool (*fn)(double), +Tensor& unary_ufunc_realhbbf16_to_bool( + bool (*fn_float)(float), + bool (*fn_double)(double), KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - (void)ctx; - // Resize for dynamic shape ET_KERNEL_CHECK_MSG( ctx, @@ -45,7 +44,17 @@ Tensor& unary_ufunc_realhb_to_bool( ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, __func__, CTYPE_IN, [&] { apply_unary_map_fn( - [fn](const CTYPE_IN val_in) { return fn(val_in); }, + [fn_double, fn_float](const CTYPE_IN val_in) { + if constexpr (std::is_same_v) { + (void)fn_float; + double xi = static_cast(val_in); + return static_cast(fn_double(xi)); + } else { + (void)fn_double; + float xi = static_cast(val_in); + return static_cast(fn_float(xi)); + } + }, in.const_data_ptr(), out.mutable_data_ptr(), in.numel()); diff --git a/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp b/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp index 602b5b1bfd2..9c513c15890 100644 --- a/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp +++ b/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp @@ -16,12 +16,11 @@ namespace native { namespace internal { Tensor& unary_ufunc_realhbbf16_to_floathbf16( - double (*fn)(double), + float (*fn_float)(float), + double (*fn_double)(double), KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - (void)ctx; - ET_KERNEL_CHECK(ctx, tensor_is_floating_type(out), InvalidArgument, out); // Resize for dynamic shape @@ -41,9 +40,16 @@ Tensor& unary_ufunc_realhbbf16_to_floathbf16( ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, __func__, CTYPE_IN, [&] { ET_SWITCH_FLOATHBF16_TYPES(out_type, ctx, __func__, CTYPE_OUT, [&] { apply_unary_map_fn( - [fn](const CTYPE_IN val_in) { - CTYPE_OUT xi = static_cast(val_in); - return static_cast(fn(xi)); + [fn_double, fn_float](const CTYPE_IN val_in) { + if constexpr (std::is_same_v) { + (void)fn_float; + double xi = static_cast(val_in); + return static_cast(fn_double(xi)); + } else { + (void)fn_double; + float xi = static_cast(val_in); + return static_cast(fn_float(xi)); + } }, in.const_data_ptr(), out.mutable_data_ptr(), diff --git a/kernels/portable/cpu/pattern/unary_ufunc_realhbf16.cpp b/kernels/portable/cpu/pattern/unary_ufunc_realhbf16.cpp index 3672e223b7e..584dfb153ab 100644 --- a/kernels/portable/cpu/pattern/unary_ufunc_realhbf16.cpp +++ b/kernels/portable/cpu/pattern/unary_ufunc_realhbf16.cpp @@ -16,12 +16,11 @@ namespace native { namespace internal { Tensor& unary_ufunc_realhbf16( - double (*fn)(double), + float (*fn_float)(float), + double (*fn_double)(double), KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - (void)ctx; - // Resize for dynamic shape ET_KERNEL_CHECK_MSG( ctx, @@ -38,7 +37,17 @@ Tensor& unary_ufunc_realhbf16( ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, __func__, CTYPE, [&] { apply_unary_map_fn( - [fn](const CTYPE val_in) { return static_cast(fn(val_in)); }, + [fn_double, fn_float](const CTYPE val_in) { + if constexpr (std::is_same_v) { + (void)fn_float; + double xi = static_cast(val_in); + return fn_double(xi); + } else { + (void)fn_double; + float xi = static_cast(val_in); + return static_cast(fn_float(xi)); + } + }, in.const_data_ptr(), out.mutable_data_ptr(), in.numel());