diff --git a/kernels/portable/cpu/op_isinf.cpp b/kernels/portable/cpu/op_isinf.cpp index 92d1e563a2e..ac0c19f0f7a 100644 --- a/kernels/portable/cpu/op_isinf.cpp +++ b/kernels/portable/cpu/op_isinf.cpp @@ -14,12 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& isinf_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - // Lambda is syntactic sugar needed to workaround compilation on some older - // non-compatible distros where isnan is returning int rather than bool - return internal::unary_ufunc_realhb_to_bool( - [](double x) -> bool { return std::isinf(x); }, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_BOOL(isinf_out, std::isinf) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/op_isnan.cpp b/kernels/portable/cpu/op_isnan.cpp index 51e189992ee..dad38a2619a 100644 --- a/kernels/portable/cpu/op_isnan.cpp +++ b/kernels/portable/cpu/op_isnan.cpp @@ -14,12 +14,7 @@ namespace torch { namespace executor { namespace native { -Tensor& isnan_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - // Lambda is syntactic sugar needed to workaround compilation on some older - // non-compatible distros where isnan is returning int rather than bool - return internal::unary_ufunc_realhb_to_bool( - [](double x) -> bool { return std::isnan(x); }, ctx, in, out); -} +DEFINE_UNARY_UFUNC_REALHBBF16_TO_BOOL(isnan_out, std::isnan) } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/pattern/pattern.h b/kernels/portable/cpu/pattern/pattern.h index 6efb6fc1a53..adebeeea9cd 100644 --- a/kernels/portable/cpu/pattern/pattern.h +++ b/kernels/portable/cpu/pattern/pattern.h @@ -73,16 +73,22 @@ Tensor& unary_ufunc_realhbf16( /** * Implements an op pattern for ops that take a single input tensor of any - * realhb dtye (real, half and boolean), no additional arguments, and outputs a - * boolean tensor of the same size. The function fn specifies the math + * realhbbf16 dtype (real/half/bool/bfloat16), no additional arguments, and + * outputs a boolean tensor of the same size. The function fn specifies the math * operation which is applied to the input tensor element-wise. */ -Tensor& unary_ufunc_realhb_to_bool( - bool (*fn)(double), +Tensor& unary_ufunc_realhbbf16_to_bool( + bool (*fn_float)(float), + bool (*fn_double)(double), KernelRuntimeContext& ctx, const Tensor& in, Tensor& out); +#define DEFINE_UNARY_UFUNC_REALHBBF16_TO_BOOL(op_name, fn) \ + Tensor& op_name(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { \ + return internal::unary_ufunc_realhbbf16_to_bool(fn, fn, ctx, in, out); \ + } + /** * Implements an op pattern for ops that take a single input tensor of any * realhbbf16 dtype (real/half/bool/bfloat16), no additional arguments, and diff --git a/kernels/portable/cpu/pattern/targets.bzl b/kernels/portable/cpu/pattern/targets.bzl index 5fc73ccd911..2217efe5e7f 100644 --- a/kernels/portable/cpu/pattern/targets.bzl +++ b/kernels/portable/cpu/pattern/targets.bzl @@ -50,7 +50,7 @@ def define_common_targets(): runtime.cxx_library( name = "pattern", srcs = [ - "unary_ufunc_realhb_to_bool.cpp", + "unary_ufunc_realhbbf16_to_bool.cpp", "unary_ufunc_realhbbf16_to_floathbf16.cpp", "unary_ufunc_realhbf16.cpp", ], diff --git a/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp b/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_bool.cpp similarity index 73% rename from kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp rename to kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_bool.cpp index 367137ad02c..58c814dc4ca 100644 --- a/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp +++ b/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_bool.cpp @@ -15,13 +15,12 @@ namespace executor { namespace native { namespace internal { -Tensor& unary_ufunc_realhb_to_bool( - bool (*fn)(double), +Tensor& unary_ufunc_realhbbf16_to_bool( + bool (*fn_float)(float), + bool (*fn_double)(double), KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { - (void)ctx; - // Resize for dynamic shape ET_KERNEL_CHECK_MSG( ctx, @@ -45,7 +44,17 @@ Tensor& unary_ufunc_realhb_to_bool( ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, __func__, CTYPE_IN, [&] { apply_unary_map_fn( - [fn](const CTYPE_IN val_in) { return fn(val_in); }, + [fn_double, fn_float](const CTYPE_IN val_in) { + if constexpr (std::is_same_v) { + (void)fn_float; + double xi = static_cast(val_in); + return static_cast(fn_double(xi)); + } else { + (void)fn_double; + float xi = static_cast(val_in); + return static_cast(fn_float(xi)); + } + }, in.const_data_ptr(), out.mutable_data_ptr(), in.numel());