diff --git a/kernels/portable/cpu/op_acos.cpp b/kernels/portable/cpu/op_acos.cpp
index dac3b1546f3..81daf10c9a6 100644
--- a/kernels/portable/cpu/op_acos.cpp
+++ b/kernels/portable/cpu/op_acos.cpp
@@ -15,8 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& acos_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(
-      std::acos, ctx, in, out);
+  static constexpr const char op_name[] = "acos.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::acos(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_acosh.cpp b/kernels/portable/cpu/op_acosh.cpp
index 77f7edf4c5d..b402698d761 100644
--- a/kernels/portable/cpu/op_acosh.cpp
+++ b/kernels/portable/cpu/op_acosh.cpp
@@ -15,8 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& acosh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(
-      std::acosh, ctx, in, out);
+  static constexpr const char op_name[] = "acosh.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::acosh(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_asin.cpp b/kernels/portable/cpu/op_asin.cpp
index 6affa6e4122..ddb52c70e84 100644
--- a/kernels/portable/cpu/op_asin.cpp
+++ b/kernels/portable/cpu/op_asin.cpp
@@ -15,8 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& asin_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(
-      std::asin, ctx, in, out);
+  static constexpr const char op_name[] = "asin.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::asin(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_asinh.cpp b/kernels/portable/cpu/op_asinh.cpp
index bce8dcf6d5a..9441db09589 100644
--- a/kernels/portable/cpu/op_asinh.cpp
+++ b/kernels/portable/cpu/op_asinh.cpp
@@ -15,8 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& asinh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(
-      std::asinh, ctx, in, out);
+  static constexpr const char op_name[] = "asinh.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::asinh(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_atan.cpp b/kernels/portable/cpu/op_atan.cpp
index 23549627a3b..6a73341bf0d 100644
--- a/kernels/portable/cpu/op_atan.cpp
+++ b/kernels/portable/cpu/op_atan.cpp
@@ -15,8 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& atan_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(
-      std::atan, ctx, in, out);
+  static constexpr const char op_name[] = "atan.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::atan(x); }, ctx, in, out);
 }
 
 } // namespace native
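Note: every unary-op hunk in this change follows the same recipe — declare a static op_name string, pass it as a non-type template argument to the pattern helper, and wrap the math call in a generic lambda so one callable can be instantiated for float, double, or a SIMD vector element type. A minimal sketch of that mechanism (sketch_math and apply_elementwise are illustrative stand-ins, not the ExecuTorch API):

    #include <cmath>
    #include <cstddef>

    namespace sketch_math {
    // Assumed overload set standing in for executorch::math::acos; when SIMD
    // headers are available, a Vectorized-type overload would join these.
    inline float acos(float x) { return std::acos(x); }
    inline double acos(double x) { return std::acos(x); }
    } // namespace sketch_math

    template <typename Op>
    void apply_elementwise(const Op& fn, const float* in, float* out, std::size_t n) {
      for (std::size_t i = 0; i < n; ++i) {
        out[i] = fn(in[i]); // generic lambda instantiated per element type
      }
    }

    void acos_example(const float* in, float* out, std::size_t n) {
      apply_elementwise([](auto x) { return sketch_math::acos(x); }, in, out, n);
    }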
diff --git a/kernels/portable/cpu/op_atanh.cpp b/kernels/portable/cpu/op_atanh.cpp
index 13e6e8ca141..9e036a5fb3b 100644
--- a/kernels/portable/cpu/op_atanh.cpp
+++ b/kernels/portable/cpu/op_atanh.cpp
@@ -15,8 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& atanh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(
-      std::atanh, ctx, in, out);
+  static constexpr const char op_name[] = "atanh.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::atanh(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_ceil.cpp b/kernels/portable/cpu/op_ceil.cpp
index 5aa09ba0084..e2c8e6f07b6 100644
--- a/kernels/portable/cpu/op_ceil.cpp
+++ b/kernels/portable/cpu/op_ceil.cpp
@@ -17,7 +17,9 @@ namespace native {
 using executorch::aten::Tensor;
 
 Tensor& ceil_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbf16(std::ceil, ctx, in, out);
+  static constexpr const char op_name[] = "ceil.out";
+  return internal::unary_ufunc_realhbf16<op_name>(
+      [](auto x) { return executorch::math::ceil(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_cos.cpp b/kernels/portable/cpu/op_cos.cpp
index e536060d162..e7876116f94 100644
--- a/kernels/portable/cpu/op_cos.cpp
+++ b/kernels/portable/cpu/op_cos.cpp
@@ -15,7 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& cos_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(std::cos, ctx, in, out);
+  static constexpr const char op_name[] = "cos.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::cos(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_cosh.cpp b/kernels/portable/cpu/op_cosh.cpp
index e622bbe6fcd..9703ff0336c 100644
--- a/kernels/portable/cpu/op_cosh.cpp
+++ b/kernels/portable/cpu/op_cosh.cpp
@@ -15,8 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& cosh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(
-      std::cosh, ctx, in, out);
+  static constexpr const char op_name[] = "cosh.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::cosh(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_erf.cpp b/kernels/portable/cpu/op_erf.cpp
index 6897bcda95b..aee0101fdb4 100644
--- a/kernels/portable/cpu/op_erf.cpp
+++ b/kernels/portable/cpu/op_erf.cpp
@@ -15,7 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& erf_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(std::erf, ctx, in, out);
+  static constexpr const char op_name[] = "erf.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::erf(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_exp.cpp b/kernels/portable/cpu/op_exp.cpp
index cbfc8924cb0..f2241613609 100644
--- a/kernels/portable/cpu/op_exp.cpp
+++ b/kernels/portable/cpu/op_exp.cpp
@@ -15,7 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(std::exp, ctx, in, out);
+  static constexpr const char op_name[] = "exp.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::exp(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_expm1.cpp b/kernels/portable/cpu/op_expm1.cpp
index f2d49f615b1..67af9b343bb 100644
--- a/kernels/portable/cpu/op_expm1.cpp
+++ b/kernels/portable/cpu/op_expm1.cpp
@@ -7,16 +7,19 @@
  */
 
 #include
+#include
 #include
 #include
+#include
 
 namespace torch {
 namespace executor {
 namespace native {
 
 Tensor& expm1_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(
-      std::expm1, ctx, in, out);
+  static constexpr const char op_name[] = "expm1.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::expm1(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_floor.cpp b/kernels/portable/cpu/op_floor.cpp
index 4061722bd27..14b49cafbc1 100644
--- a/kernels/portable/cpu/op_floor.cpp
+++ b/kernels/portable/cpu/op_floor.cpp
@@ -17,7 +17,9 @@ namespace native {
 using executorch::aten::Tensor;
 
 Tensor& floor_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbf16(std::floor, ctx, in, out);
+  static constexpr const char op_name[] = "floor.out";
+  return internal::unary_ufunc_realhbf16<op_name>(
+      [](auto x) { return executorch::math::floor(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_isinf.cpp b/kernels/portable/cpu/op_isinf.cpp
index 92d1e563a2e..42798231a84 100644
--- a/kernels/portable/cpu/op_isinf.cpp
+++ b/kernels/portable/cpu/op_isinf.cpp
@@ -17,8 +17,9 @@ namespace native {
 Tensor& isinf_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
   // Lambda is syntactic sugar needed to workaround compilation on some older
   // non-compatible distros where isnan is returning int rather than bool
-  return internal::unary_ufunc_realhb_to_bool(
-      [](double x) -> bool { return std::isinf(x); }, ctx, in, out);
+  static constexpr const char op_name[] = "isinf.out";
+  return internal::unary_ufunc_realhb_to_bool<op_name>(
+      [](auto x) -> bool { return std::isinf(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_isnan.cpp b/kernels/portable/cpu/op_isnan.cpp
index 51e189992ee..817d314fd2b 100644
--- a/kernels/portable/cpu/op_isnan.cpp
+++ b/kernels/portable/cpu/op_isnan.cpp
@@ -17,8 +17,9 @@ namespace native {
 Tensor& isnan_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
   // Lambda is syntactic sugar needed to workaround compilation on some older
   // non-compatible distros where isnan is returning int rather than bool
-  return internal::unary_ufunc_realhb_to_bool(
-      [](double x) -> bool { return std::isnan(x); }, ctx, in, out);
+  static constexpr const char op_name[] = "isnan.out";
+  return internal::unary_ufunc_realhb_to_bool<op_name>(
+      [](auto x) -> bool { return std::isnan(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_log.cpp b/kernels/portable/cpu/op_log.cpp
index 8a36bce8c49..5b0c32549aa 100644
--- a/kernels/portable/cpu/op_log.cpp
+++ b/kernels/portable/cpu/op_log.cpp
@@ -15,7 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& log_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(std::log, ctx, in, out);
+  static constexpr const char op_name[] = "log.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::log(x); }, ctx, in, out);
 }
 
 } // namespace native
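Note on isinf/isnan: the lambda parameter changes from double to auto because the templated helper now invokes the callable at the selected compute type instead of funneling every element through double; std::isinf/std::isnan already have float and double overloads, and the explicit -> bool return keeps the workaround for toolchains where they return int. A small illustration of why the generic form is safe (sketch only):

    #include <cmath>
    #include <type_traits>

    // The same generic predicate works for whichever compute type the pattern
    // helpers pick (float or double); -> bool forces the return type even on
    // platforms where std::isinf yields int.
    constexpr auto is_inf_pred = [](auto x) -> bool { return std::isinf(x); };

    static_assert(std::is_same_v<decltype(is_inf_pred(1.0f)), bool>);
    static_assert(std::is_same_v<decltype(is_inf_pred(1.0)), bool>);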
diff --git a/kernels/portable/cpu/op_log10.cpp b/kernels/portable/cpu/op_log10.cpp
index 89f9b672476..5251aea201d 100644
--- a/kernels/portable/cpu/op_log10.cpp
+++ b/kernels/portable/cpu/op_log10.cpp
@@ -15,8 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& log10_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(
-      std::log10, ctx, in, out);
+  static constexpr const char op_name[] = "log10.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::log10(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_log1p.cpp b/kernels/portable/cpu/op_log1p.cpp
index 2daa31e37ff..f352750a944 100644
--- a/kernels/portable/cpu/op_log1p.cpp
+++ b/kernels/portable/cpu/op_log1p.cpp
@@ -15,8 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& log1p_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(
-      std::log1p, ctx, in, out);
+  static constexpr const char op_name[] = "log1p.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::log1p(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_log2.cpp b/kernels/portable/cpu/op_log2.cpp
index 4d7406832e4..42d17ea83b9 100644
--- a/kernels/portable/cpu/op_log2.cpp
+++ b/kernels/portable/cpu/op_log2.cpp
@@ -15,8 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& log2_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(
-      std::log2, ctx, in, out);
+  static constexpr const char op_name[] = "log2.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::log2(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_reciprocal.cpp b/kernels/portable/cpu/op_reciprocal.cpp
index f22f9883858..a1bd116a962 100644
--- a/kernels/portable/cpu/op_reciprocal.cpp
+++ b/kernels/portable/cpu/op_reciprocal.cpp
@@ -12,18 +12,11 @@ namespace torch {
 namespace executor {
 namespace native {
 
-namespace {
-
-double reciprocal(double x) {
-  return 1.0 / x;
-}
-
-} // namespace
-
 Tensor&
 reciprocal_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(
-      reciprocal, ctx, in, out);
+  static constexpr const char op_name[] = "reciprocal.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::reciprocal(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_rsqrt.cpp b/kernels/portable/cpu/op_rsqrt.cpp
index 19c4c6c1a57..a14eb15d7ec 100644
--- a/kernels/portable/cpu/op_rsqrt.cpp
+++ b/kernels/portable/cpu/op_rsqrt.cpp
@@ -12,16 +12,11 @@ namespace torch {
 namespace executor {
 namespace native {
 
-namespace {
-
-double rsqrt(double x) {
-  return 1.0 / std::sqrt(x);
-}
-
-} // namespace
 
 Tensor& rsqrt_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(rsqrt, ctx, in, out);
+  static constexpr const char op_name[] = "rsqrt.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::rsqrt(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_sin.cpp b/kernels/portable/cpu/op_sin.cpp
index ad65c4be18b..aeb73009729 100644
--- a/kernels/portable/cpu/op_sin.cpp
+++ b/kernels/portable/cpu/op_sin.cpp
@@ -15,7 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& sin_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(std::sin, ctx, in, out);
+  static constexpr const char op_name[] = "sin.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::sin(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_sinh.cpp b/kernels/portable/cpu/op_sinh.cpp
index 21666392392..f4cc67ad35f 100644
--- a/kernels/portable/cpu/op_sinh.cpp
+++ b/kernels/portable/cpu/op_sinh.cpp
@@ -15,8 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& sinh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(
-      std::sinh, ctx, in, out);
+  static constexpr const char op_name[] = "sinh.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::sinh(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_sqrt.cpp b/kernels/portable/cpu/op_sqrt.cpp
index bd2075f5b04..1b3d2ff6de5 100644
--- a/kernels/portable/cpu/op_sqrt.cpp
+++ b/kernels/portable/cpu/op_sqrt.cpp
@@ -15,8 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& sqrt_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(
-      std::sqrt, ctx, in, out);
+  static constexpr const char op_name[] = "sqrt.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::sqrt(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_tan.cpp b/kernels/portable/cpu/op_tan.cpp
index a2b921d5146..19ccb84935b 100644
--- a/kernels/portable/cpu/op_tan.cpp
+++ b/kernels/portable/cpu/op_tan.cpp
@@ -15,7 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& tan_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(std::tan, ctx, in, out);
+  static constexpr const char op_name[] = "tan.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::tan(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_tanh.cpp b/kernels/portable/cpu/op_tanh.cpp
index ae9f93dc62c..623968ac721 100644
--- a/kernels/portable/cpu/op_tanh.cpp
+++ b/kernels/portable/cpu/op_tanh.cpp
@@ -15,8 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& tanh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbbf16_to_floathbf16(
-      std::tanh, ctx, in, out);
+  static constexpr const char op_name[] = "tanh.out";
+  return internal::unary_ufunc_realhbbf16_to_floathbf16<op_name>(
+      [](auto x) { return executorch::math::tanh(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_trunc.cpp b/kernels/portable/cpu/op_trunc.cpp
index 2d70a3b1724..9c96865db0e 100644
--- a/kernels/portable/cpu/op_trunc.cpp
+++ b/kernels/portable/cpu/op_trunc.cpp
@@ -15,7 +15,9 @@ namespace executor {
 namespace native {
 
 Tensor& trunc_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
-  return internal::unary_ufunc_realhbf16(std::trunc, ctx, in, out);
+  static constexpr const char op_name[] = "trunc.out";
+  return internal::unary_ufunc_realhbf16<op_name>(
+      [](auto x) { return executorch::math::trunc(x); }, ctx, in, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/pattern/pattern.cpp b/kernels/portable/cpu/pattern/pattern.cpp
new file mode 100644
index 00000000000..61571f25ddc
--- /dev/null
+++ b/kernels/portable/cpu/pattern/pattern.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+namespace torch::executor::native::internal {
+
+bool check_and_resize_inputs(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    Tensor& out) {
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, false);
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      resize_tensor(out, in.sizes()) == Error::Ok,
+      InvalidArgument,
+      false,
+      "Failed to resize output tensor.");
+  return true;
+}
+
+} // namespace torch::executor::native::internal
diff --git a/kernels/portable/cpu/pattern/pattern.h b/kernels/portable/cpu/pattern/pattern.h
index 2d4b2ac509c..02690739a01 100644
--- a/kernels/portable/cpu/pattern/pattern.h
+++ b/kernels/portable/cpu/pattern/pattern.h
@@ -46,6 +46,7 @@ question is a bit more specific, then add a descriptive sufix.
  */
 
 #pragma once
+#include
 #include
 
 namespace torch {
@@ -53,29 +54,78 @@ namespace executor {
 namespace native {
 namespace internal {
 
+// Implementation detail for the other helpers in this header. Returns
+// true on success, false on failure.
+bool check_and_resize_inputs(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    Tensor& out);
+
 /**
  * Implements an op pattern for ops that take a single input tensor of any
- * realh dtye, no additional arguments, and outputs a tensor of the same size
- * and dtype. The function fn specifies the math operation which is applied to
- * the input tensor element-wise.
+ * realhbf16 dtype, no additional arguments, and outputs a tensor of the same
+ * size and dtype. The function fn specifies the math operation which is applied
+ * to the input tensor element-wise.
  */
+template <const char* op_name, typename Op>
 Tensor& unary_ufunc_realhbf16(
-    double (*fn)(double),
+    const Op& fn,
     KernelRuntimeContext& ctx,
     const Tensor& in,
-    Tensor& out);
+    Tensor& out) {
+  if (!check_and_resize_inputs(ctx, in, out)) {
+    return out;
+  }
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out);
+
+  ET_SWITCH_REALHBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE, [&] {
+    utils::apply_unitensor_elementwise_fn<
+        CTYPE,
+        op_name,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON>(
+        fn, ctx, in, utils::SupportedTensorDtypes::REALHBF16, out);
+  });
+  return out;
+}
 
 /**
  * Implements an op pattern for ops that take a single input tensor of any
- * realhb dtye (real, half and boolean), no additional arguments, and outputs a
+ * realhb dtype (real, half and boolean), no additional arguments, and outputs a
  * boolean tensor of the same size. The function fn specifies the math
 * operation which is applied to the input tensor element-wise.
  */
+template <const char* op_name, typename Op>
 Tensor& unary_ufunc_realhb_to_bool(
-    bool (*fn)(double),
+    const Op& fn,
     KernelRuntimeContext& ctx,
     const Tensor& in,
-    Tensor& out);
+    Tensor& out) {
+  if (!check_and_resize_inputs(ctx, in, out)) {
+    return out;
+  }
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      out.scalar_type() == executorch::aten::ScalarType::Bool,
+      InvalidArgument,
+      out,
+      "Expected out tensor to have dtype Bool, but got %" PRId8 " instead.",
+      static_cast<int8_t>(out.scalar_type()));
+
+  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
+    utils::apply_unitensor_elementwise_fn<
+        CTYPE_IN,
+        op_name,
+        utils::SupportedTensorDtypes::BOOL>(
+        [fn](const CTYPE_IN val_in) { return fn(val_in); },
+        ctx,
+        in,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out);
+  });
+
+  return out;
+}
 
 /**
  * Implements an op pattern for ops that take a single input tensor of any
@@ -83,11 +133,35 @@ Tensor& unary_ufunc_realhb_to_bool(
  * outputs a floating point tensor of the same size. The function fn specifies
  * the math operation which is applied to the input tensor element-wise.
  */
+template <const char* op_name, typename Op>
 Tensor& unary_ufunc_realhbbf16_to_floathbf16(
-    double (*fn)(double),
+    const Op& fn,
     KernelRuntimeContext& ctx,
     const Tensor& in,
-    Tensor& out);
+    Tensor& out) {
+  ET_KERNEL_CHECK(ctx, tensor_is_floating_type(out), InvalidArgument, out);
+
+  if (!check_and_resize_inputs(ctx, in, out)) {
+    return out;
+  }
+
+  ScalarType compute_type = in.scalar_type() == ScalarType::Double
+      ? ScalarType::Double
+      : ScalarType::Float;
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&] {
+    utils::apply_unitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::FLOATHBF16>(
+        [fn](const auto val_in) { return fn(val_in); },
+        ctx,
+        in,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out);
+  });
+
+  return out;
+}
 
 } // namespace internal
 } // namespace native
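Note: the three pattern helpers are now header-only templates, so each kernel instantiates them with its own op_name and callable, and the body reduces to a load-compute-store loop driven by apply_unitensor_elementwise_fn. A rough sketch of that round trip under simplified types (unary_round_trip is illustrative, not the real utility):

    #include <cstddef>

    // Load each element, convert to the compute type, apply the op, convert to
    // the output element type, store. The real helper additionally resizes the
    // output and switches on runtime dtypes.
    template <typename ComputeT, typename InT, typename OutT, typename Op>
    void unary_round_trip(const Op& fn, const InT* in, OutT* out, std::size_t n) {
      for (std::size_t i = 0; i < n; ++i) {
        ComputeT x = static_cast<ComputeT>(in[i]); // load_and_convert
        out[i] = static_cast<OutT>(fn(x));         // convert_and_store
      }
    }

    // E.g. an int32 input with a float compute type and a float output:
    //   unary_round_trip<float>(fn, int32_ptr, float_ptr, n);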
diff --git a/kernels/portable/cpu/pattern/targets.bzl b/kernels/portable/cpu/pattern/targets.bzl
index 5fc73ccd911..4140e4e0f14 100644
--- a/kernels/portable/cpu/pattern/targets.bzl
+++ b/kernels/portable/cpu/pattern/targets.bzl
@@ -49,18 +49,14 @@ def define_common_targets():
 
     runtime.cxx_library(
         name = "pattern",
-        srcs = [
-            "unary_ufunc_realhb_to_bool.cpp",
-            "unary_ufunc_realhbbf16_to_floathbf16.cpp",
-            "unary_ufunc_realhbf16.cpp",
-        ],
+        srcs = ["pattern.cpp"],
         exported_headers = [
             "pattern.h",
         ],
         compiler_flags = ["-Wno-missing-prototypes"],
         exported_deps = [
             "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:functional_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/runtime/kernel:kernel_includes",
         ],
         visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
diff --git a/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp b/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp
deleted file mode 100644
index 367137ad02c..00000000000
--- a/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include
-#include
-#include
-
-namespace torch {
-namespace executor {
-namespace native {
-namespace internal {
-
-Tensor& unary_ufunc_realhb_to_bool(
-    bool (*fn)(double),
-    KernelRuntimeContext& ctx,
-    const Tensor& in,
-    Tensor& out) {
-  (void)ctx;
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, in.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      out.scalar_type() == executorch::aten::ScalarType::Bool,
-      InvalidArgument,
-      out,
-      "Expected out tensor to have dtype Bool, but got %" PRId8 " instead.",
-      static_cast<int8_t>(out.scalar_type()));
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
-
-  const auto in_type = in.scalar_type();
-
-  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, __func__, CTYPE_IN, [&] {
-    apply_unary_map_fn(
-        [fn](const CTYPE_IN val_in) { return fn(val_in); },
-        in.const_data_ptr<CTYPE_IN>(),
-        out.mutable_data_ptr<bool>(),
-        in.numel());
-  });
-
-  return out;
-}
-
-} // namespace internal
-} // namespace native
-} // namespace executor
-} // namespace torch
diff --git a/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp b/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp
deleted file mode 100644
index 602b5b1bfd2..00000000000
--- a/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include
-#include
-#include
-
-namespace torch {
-namespace executor {
-namespace native {
-namespace internal {
-
-Tensor& unary_ufunc_realhbbf16_to_floathbf16(
-    double (*fn)(double),
-    KernelRuntimeContext& ctx,
-    const Tensor& in,
-    Tensor& out) {
-  (void)ctx;
-
-  ET_KERNEL_CHECK(ctx, tensor_is_floating_type(out), InvalidArgument, out);
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, in.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
-
-  const auto in_type = in.scalar_type();
-  const auto out_type = out.scalar_type();
-
-  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, __func__, CTYPE_IN, [&] {
-    ET_SWITCH_FLOATHBF16_TYPES(out_type, ctx, __func__, CTYPE_OUT, [&] {
-      apply_unary_map_fn(
-          [fn](const CTYPE_IN val_in) {
-            CTYPE_OUT xi = static_cast<CTYPE_OUT>(val_in);
-            return static_cast<CTYPE_OUT>(fn(xi));
-          },
-          in.const_data_ptr<CTYPE_IN>(),
-          out.mutable_data_ptr<CTYPE_OUT>(),
-          in.numel());
-    });
-  });
-
-  return out;
-}
-
-} // namespace internal
-} // namespace native
-} // namespace executor
-} // namespace torch
diff --git a/kernels/portable/cpu/pattern/unary_ufunc_realhbf16.cpp b/kernels/portable/cpu/pattern/unary_ufunc_realhbf16.cpp
deleted file mode 100644
index 3672e223b7e..00000000000
--- a/kernels/portable/cpu/pattern/unary_ufunc_realhbf16.cpp
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include
-#include
-#include
-
-namespace torch {
-namespace executor {
-namespace native {
-namespace internal {
-
-Tensor& unary_ufunc_realhbf16(
-    double (*fn)(double),
-    KernelRuntimeContext& ctx,
-    const Tensor& in,
-    Tensor& out) {
-  (void)ctx;
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, in.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out);
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
-
-  ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, __func__, CTYPE, [&] {
-    apply_unary_map_fn(
-        [fn](const CTYPE val_in) { return static_cast<CTYPE>(fn(val_in)); },
-        in.const_data_ptr<CTYPE>(),
-        out.mutable_data_ptr<CTYPE>(),
-        in.numel());
-  });
-
-  return out;
-}
-
-} // namespace internal
-} // namespace native
-} // namespace executor
-} // namespace torch
diff --git a/kernels/portable/cpu/util/dtype_util.cpp b/kernels/portable/cpu/util/dtype_util.cpp
index d240b9f83bc..525199a6f78 100644
--- a/kernels/portable/cpu/util/dtype_util.cpp
+++ b/kernels/portable/cpu/util/dtype_util.cpp
@@ -27,6 +27,8 @@ bool check_tensor_dtype(
       return executorch::runtime::tensor_is_floating_type(t);
     case SupportedTensorDtypes::INTB:
       return executorch::runtime::tensor_is_integral_type(t, true);
+    case SupportedTensorDtypes::BOOL:
+      return executorch::runtime::tensor_is_type(t, ScalarType::Bool);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return (executorch::runtime::tensor_is_type(
           t, ScalarType::Bool, ScalarType::Byte));
diff --git a/kernels/portable/cpu/util/dtype_util.h b/kernels/portable/cpu/util/dtype_util.h
index 1e7901c80b2..15732219c8f 100644
--- a/kernels/portable/cpu/util/dtype_util.h
+++ b/kernels/portable/cpu/util/dtype_util.h
@@ -72,6 +72,16 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_intb(const Tensor& t) {
   return result;
 }
 
+template <typename CTYPE_COMPUTE, const char* op_name>
+load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool(const Tensor& t) {
+  ET_CHECK_MSG(
+      t.scalar_type() == ScalarType::Bool,
+      "Unhandled dtype %s for %s",
+      ::executorch::runtime::toString(t.scalar_type()),
+      op_name);
+  return internal::load_and_convert<CTYPE_COMPUTE, bool>;
+}
+
 template <typename CTYPE_COMPUTE, const char* op_name>
 load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool_or_byte(
     const Tensor& t) {
@@ -165,6 +175,17 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_intb(
   return result;
 }
 
+template <typename CTYPE_COMPUTE, const char* op_name>
+store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_bool(
+    const Tensor& t) {
+  ET_CHECK_MSG(
+      t.scalar_type() == ScalarType::Bool,
+      "Unhandled dtype %s for %s",
+      ::executorch::runtime::toString(t.scalar_type()),
+      op_name);
+  return internal::convert_and_store<bool, CTYPE_COMPUTE>;
+}
+
 template <typename CTYPE_COMPUTE, const char* op_name>
 store_compute_to_tensor_fn<CTYPE_COMPUTE>
 get_store_compute_to_tensor_fn_bool_or_byte(const Tensor& t) {
@@ -219,6 +240,7 @@ enum class SupportedTensorDtypes {
   REALHBF16,
   FLOATHBF16,
   INTB,
+  BOOL,
   BOOL_OR_BYTE,
   // DEPRECATED: not likely to be correct; use SAME_AS_COMMON.
   SAME_AS_COMPUTE,
@@ -240,6 +262,8 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_impl(
       return get_load_to_compute_fn_realhbf16<CTYPE_COMPUTE, op_name>(t);
     case SupportedTensorDtypes::INTB:
       return get_load_to_compute_fn_intb<CTYPE_COMPUTE, op_name>(t);
+    case SupportedTensorDtypes::BOOL:
+      return get_load_to_compute_fn_bool<CTYPE_COMPUTE, op_name>(t);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return get_load_to_compute_fn_bool_or_byte<CTYPE_COMPUTE, op_name>(t);
     case SupportedTensorDtypes::SAME_AS_COMPUTE:
@@ -271,6 +295,8 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn(
           t);
     case SupportedTensorDtypes::INTB:
       return get_store_compute_to_tensor_fn_intb<CTYPE_COMPUTE, op_name>(t);
+    case SupportedTensorDtypes::BOOL:
+      return get_store_compute_to_tensor_fn_bool<CTYPE_COMPUTE, op_name>(t);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return get_store_compute_to_tensor_fn_bool_or_byte<
           CTYPE_COMPUTE,
@@ -318,12 +344,14 @@ bool check_tensor_dtype(
     const ScalarType compute_type);
 
 /// Return the one output type we are willing to emit specialized code
-/// to handle, given a compute type of CTYPE_COMMON and supported
+/// to handle, given a compute type of CTYPE_COMPUTE and supported
 /// output types of out_dtypes.
 template <typename CTYPE_COMPUTE>
 inline constexpr ScalarType specialized_output_scalar_type(
     SupportedTensorDtypes out_dtypes) {
   switch (out_dtypes) {
+    case SupportedTensorDtypes::BOOL:
+      return ScalarType::Bool;
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return ScalarType::Bool;
     case SupportedTensorDtypes::REALHBBF16:
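Note: SupportedTensorDtypes::BOOL lets a kernel state that a tensor must be exactly Bool (what unary_ufunc_realhb_to_bool needs for its output), while BOOL_OR_BYTE remains for ops that also accept Byte. The dispatch idea, sketched with stand-in names rather than the real ExecuTorch types:

    // Pick a store function from a supported-dtype tag, mirroring how
    // get_store_compute_to_tensor_fn routes BOOL to a bool-only converter.
    enum class OutDtype { Bool, Float };

    template <typename ComputeT>
    using StoreFn = void (*)(ComputeT value, void* dst);

    template <typename ComputeT>
    void store_as_bool(ComputeT v, void* dst) {
      *static_cast<bool*>(dst) = static_cast<bool>(v);
    }

    template <typename ComputeT>
    void store_as_float(ComputeT v, void* dst) {
      *static_cast<float*>(dst) = static_cast<float>(v);
    }

    template <typename ComputeT>
    StoreFn<ComputeT> pick_store_fn(OutDtype d) {
      return d == OutDtype::Bool ? &store_as_bool<ComputeT>
                                 : &store_as_float<ComputeT>;
    }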
diff --git a/kernels/portable/cpu/util/vectorized_math.h b/kernels/portable/cpu/util/vectorized_math.h
index 9e706ace56d..823d0ccc39a 100644
--- a/kernels/portable/cpu/util/vectorized_math.h
+++ b/kernels/portable/cpu/util/vectorized_math.h
@@ -104,11 +104,14 @@ auto convert_to_vectorized_n_of_float(at::vec::Vectorized<T> vec) {
 #endif // ET_USE_PYTORCH_HEADERS
 
 // To simplify client code, we provide coverage for a bunch of float ops (the
-// same ones listed in ATen vml.h) here.
+// same ones listed in ATen vml.h, plus acosh, asinh, atanh) here.
 ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(abs)
 ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(acos)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(acosh)
 ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(asin)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(asinh)
 ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(atan)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(atanh)
 ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(ceil)
 ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(cos)
 ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(cosh)
@@ -131,12 +134,30 @@ ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(trunc)
 ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(lgamma)
 
 #ifdef ET_USE_PYTORCH_HEADERS
-ET_INTERNAL_VECTORIZED_FLOAT_BINARY_FUNC(rsqrt)
+ET_INTERNAL_VECTORIZED_FLOAT_UNARY_FUNC(reciprocal)
+ET_INTERNAL_VECTORIZED_FLOAT_UNARY_FUNC(rsqrt)
 #endif // ET_USE_PYTORCH_HEADERS
 
 namespace executorch {
 inline namespace math {
-template <typename T, typename = std::enable_if_t<std::is_floating_point_v<T>>>
+inline float reciprocal(float x) {
+  return 1.0f / x;
+}
+
+inline double reciprocal(double x) {
+  return 1.0 / x;
+}
+
+template <
+    typename Integer,
+    std::enable_if_t<std::is_integral_v<Integer>, bool> = true>
+double reciprocal(Integer x) {
+  return reciprocal((double)x);
+}
+
+template <
+    typename T,
+    std::enable_if_t<std::is_floating_point_v<T>, bool> = true>
 T rsqrt(T x) {
   return T(1) / std::sqrt(x);
 }
diff --git a/kernels/test/op_mul_test.cpp b/kernels/test/op_mul_test.cpp
index c21cceeaae3..34433fbe95c 100644
--- a/kernels/test/op_mul_test.cpp
+++ b/kernels/test/op_mul_test.cpp
@@ -746,6 +746,21 @@ TEST_F(OpMulOutTest, DynamicShapeUnbound) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
+// >>> torch.ops.aten.mul(torch.tensor([100], dtype=torch.int8),
+// torch.tensor([100], dtype=torch.int8), out=torch.zeros([1],
+// dtype=torch.long)) tensor([16])
+TEST_F(OpMulOutTest, MixedIntegerDtypeMatchesATen) {
+  TensorFactory<ScalarType::Char> tf_in;
+  TensorFactory<ScalarType::Long> tf_out;
+
+  Tensor in = tf_in.make({1}, {100});
+  Tensor out = tf_out.zeros({1});
+  Tensor ret = op_mul_out(in, in, out);
+
+  Tensor expected = tf_out.make({1}, {16});
+  EXPECT_TENSOR_CLOSE(out, expected);
+}
+
 TEST_F(OpMulScalarOutTest, SanityCheck) {
   TensorFactory tf_a;
   TensorFactory tf_out;
diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
index a731ce5c674..84c6567b495 100644
--- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -533,6 +533,7 @@ ATEN_OPS = (
         name = "op_expm1",
         deps = [
             "//executorch/kernels/portable/cpu/pattern:pattern",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
     op_target(
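Note on the op_mul_test addition above: ATen performs the multiply in the promoted input type (int8 for two int8 operands), so 100 * 100 = 10000 wraps to 10000 mod 256 = 16 before being written to the int64 output, which is why the expected tensor is {16}. The same arithmetic in plain C++ (two's-complement narrowing assumed, as on the targets ExecuTorch supports):

    #include <cstdint>
    #include <iostream>

    int main() {
      int8_t a = 100;
      // a * a promotes to int (10000); narrowing back to int8 keeps the low
      // 8 bits: 10000 % 256 = 16. Converting that to int64 preserves 16.
      int64_t out = static_cast<int8_t>(a * a);
      std::cout << out << "\n"; // prints 16
      return 0;
    }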