Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions kernels/portable/cpu/op_isinf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@ namespace torch {
namespace executor {
namespace native {

Tensor& isinf_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
// Lambda is syntactic sugar needed to workaround compilation on some older
// non-compatible distros where isnan is returning int rather than bool
return internal::unary_ufunc_realhb_to_bool(
[](double x) -> bool { return std::isinf(x); }, ctx, in, out);
}
DEFINE_UNARY_UFUNC_REALHBBF16_TO_BOOL(isinf_out, std::isinf)

} // namespace native
} // namespace executor
Expand Down
7 changes: 1 addition & 6 deletions kernels/portable/cpu/op_isnan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@ namespace torch {
namespace executor {
namespace native {

Tensor& isnan_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
// Lambda is syntactic sugar needed to workaround compilation on some older
// non-compatible distros where isnan is returning int rather than bool
return internal::unary_ufunc_realhb_to_bool(
[](double x) -> bool { return std::isnan(x); }, ctx, in, out);
}
DEFINE_UNARY_UFUNC_REALHBBF16_TO_BOOL(isnan_out, std::isnan)

} // namespace native
} // namespace executor
Expand Down
14 changes: 10 additions & 4 deletions kernels/portable/cpu/pattern/pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,22 @@ Tensor& unary_ufunc_realhbf16(

/**
* Implements an op pattern for ops that take a single input tensor of any
* realhb dtye (real, half and boolean), no additional arguments, and outputs a
* boolean tensor of the same size. The function fn specifies the math
* realhbbf16 dtype (real/half/bool/bfloat16), no additional arguments, and
* outputs a boolean tensor of the same size. The function fn specifies the math
* operation which is applied to the input tensor element-wise.
*/
Tensor& unary_ufunc_realhb_to_bool(
bool (*fn)(double),
Tensor& unary_ufunc_realhbbf16_to_bool(
bool (*fn_float)(float),
bool (*fn_double)(double),
KernelRuntimeContext& ctx,
const Tensor& in,
Tensor& out);

#define DEFINE_UNARY_UFUNC_REALHBBF16_TO_BOOL(op_name, fn) \
Tensor& op_name(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { \
return internal::unary_ufunc_realhbbf16_to_bool(fn, fn, ctx, in, out); \
}

/**
* Implements an op pattern for ops that take a single input tensor of any
* realhbbf16 dtype (real/half/bool/bfloat16), no additional arguments, and
Expand Down
2 changes: 1 addition & 1 deletion kernels/portable/cpu/pattern/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def define_common_targets():
runtime.cxx_library(
name = "pattern",
srcs = [
"unary_ufunc_realhb_to_bool.cpp",
"unary_ufunc_realhbbf16_to_bool.cpp",
"unary_ufunc_realhbbf16_to_floathbf16.cpp",
"unary_ufunc_realhbf16.cpp",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@ namespace executor {
namespace native {
namespace internal {

Tensor& unary_ufunc_realhb_to_bool(
bool (*fn)(double),
Tensor& unary_ufunc_realhbbf16_to_bool(
bool (*fn_float)(float),
bool (*fn_double)(double),
KernelRuntimeContext& ctx,
const Tensor& in,
Tensor& out) {
(void)ctx;

// Resize for dynamic shape
ET_KERNEL_CHECK_MSG(
ctx,
Expand All @@ -45,7 +44,17 @@ Tensor& unary_ufunc_realhb_to_bool(

ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, __func__, CTYPE_IN, [&] {
apply_unary_map_fn(
[fn](const CTYPE_IN val_in) { return fn(val_in); },
[fn_double, fn_float](const CTYPE_IN val_in) {
if constexpr (std::is_same_v<CTYPE_IN, double>) {
(void)fn_float;
double xi = static_cast<double>(val_in);
return static_cast<bool>(fn_double(xi));
} else {
(void)fn_double;
float xi = static_cast<float>(val_in);
return static_cast<bool>(fn_float(xi));
}
},
in.const_data_ptr<CTYPE_IN>(),
out.mutable_data_ptr<bool>(),
in.numel());
Expand Down
Loading