diff --git a/kernels/portable/cpu/op_ceil.cpp b/kernels/portable/cpu/op_ceil.cpp index 5aa09ba0084..a39d0aa4f3b 100644 --- a/kernels/portable/cpu/op_ceil.cpp +++ b/kernels/portable/cpu/op_ceil.cpp @@ -16,9 +16,7 @@ namespace native { using executorch::aten::Tensor; -Tensor& ceil_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbf16(std::ceil, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBF16(ceil_out, std::ceil) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_floor.cpp b/kernels/portable/cpu/op_floor.cpp index 4061722bd27..a5bb9c740e0 100644 --- a/kernels/portable/cpu/op_floor.cpp +++ b/kernels/portable/cpu/op_floor.cpp @@ -16,9 +16,7 @@ namespace native { using executorch::aten::Tensor; -Tensor& floor_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbf16(std::floor, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBF16(floor_out, std::floor) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_trunc.cpp b/kernels/portable/cpu/op_trunc.cpp index 2d70a3b1724..edc717b2ade 100644 --- a/kernels/portable/cpu/op_trunc.cpp +++ b/kernels/portable/cpu/op_trunc.cpp @@ -14,9 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& trunc_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - return internal::unary_ufunc_realhbf16(std::trunc, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBF16(trunc_out, std::trunc) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/pattern/pattern.h b/kernels/portable/cpu/pattern/pattern.h index f4be3bddc30..6efb6fc1a53 100644 --- a/kernels/portable/cpu/pattern/pattern.h +++ b/kernels/portable/cpu/pattern/pattern.h @@ -60,11 +60,17 @@ namespace internal { * the input tensor element-wise. 
*/ Tensor& unary_ufunc_realhbf16( - double (*fn)(double), + float (*fn_float)(float), + double (*fn_double)(double), KernelRuntimeContext& ctx, const Tensor& in, Tensor& out); +#define DEFINE_UNARY_UFUNC_REALHBF16(op_name, fn) \ + Tensor& op_name(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { \ + return internal::unary_ufunc_realhbf16(fn, fn, ctx, in, out); \ + } + /** * Implements an op pattern for ops that take a single input tensor of any * realhb dtye (real, half and boolean), no additional arguments, and outputs a diff --git a/kernels/portable/cpu/pattern/unary_ufunc_realhbf16.cpp b/kernels/portable/cpu/pattern/unary_ufunc_realhbf16.cpp index 3672e223b7e..584dfb153ab 100644 --- a/kernels/portable/cpu/pattern/unary_ufunc_realhbf16.cpp +++ b/kernels/portable/cpu/pattern/unary_ufunc_realhbf16.cpp @@ -16,12 +16,11 @@ namespace native { namespace internal { Tensor& unary_ufunc_realhbf16( - double (*fn)(double), + float (*fn_float)(float), + double (*fn_double)(double), KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - (void)ctx; - // Resize for dynamic shape ET_KERNEL_CHECK_MSG( ctx, @@ -38,7 +37,17 @@ Tensor& unary_ufunc_realhbf16( ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, __func__, CTYPE, [&] { apply_unary_map_fn( - [fn](const CTYPE val_in) { return static_cast<CTYPE>(fn(val_in)); }, + [fn_double, fn_float](const CTYPE val_in) { + if constexpr (std::is_same_v<CTYPE, double>) { + (void)fn_float; + double xi = static_cast<double>(val_in); + return fn_double(xi); + } else { + (void)fn_double; + float xi = static_cast<float>(val_in); + return static_cast<CTYPE>(fn_float(xi)); + } + }, in.const_data_ptr<CTYPE>(), out.mutable_data_ptr<CTYPE>(), in.numel());