Skip to content

Commit 7e33791

Browse files
[EE/BE] add float variant to unary_ufunc_realhbbf16_to_bool (#12291)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12279 by @manuelcandales ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/132/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/132/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/131/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/132/orig @diff-train-skip-merge --------- Co-authored-by: Manuel Candales <[email protected]>
1 parent 31333b1 commit 7e33791

File tree

5 files changed

+27
-22
lines changed

5 files changed

+27
-22
lines changed

kernels/portable/cpu/op_isinf.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,7 @@ namespace torch {
1414
namespace executor {
1515
namespace native {
1616

17-
Tensor& isinf_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
18-
// Lambda is syntactic sugar needed to workaround compilation on some older
19-
// non-compatible distros where isnan is returning int rather than bool
20-
return internal::unary_ufunc_realhb_to_bool(
21-
[](double x) -> bool { return std::isinf(x); }, ctx, in, out);
22-
}
17+
DEFINE_UNARY_UFUNC_REALHBBF16_TO_BOOL(isinf_out, std::isinf)
2318

2419
} // namespace native
2520
} // namespace executor

kernels/portable/cpu/op_isnan.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,7 @@ namespace torch {
1414
namespace executor {
1515
namespace native {
1616

17-
Tensor& isnan_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
18-
// Lambda is syntactic sugar needed to workaround compilation on some older
19-
// non-compatible distros where isnan is returning int rather than bool
20-
return internal::unary_ufunc_realhb_to_bool(
21-
[](double x) -> bool { return std::isnan(x); }, ctx, in, out);
22-
}
17+
DEFINE_UNARY_UFUNC_REALHBBF16_TO_BOOL(isnan_out, std::isnan)
2318

2419
} // namespace native
2520
} // namespace executor

kernels/portable/cpu/pattern/pattern.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,22 @@ Tensor& unary_ufunc_realhbf16(
7373

7474
/**
7575
* Implements an op pattern for ops that take a single input tensor of any
76-
* realhb dtye (real, half and boolean), no additional arguments, and outputs a
77-
* boolean tensor of the same size. The function fn specifies the math
76+
* realhbbf16 dtype (real/half/bool/bfloat16), no additional arguments, and
77+
* outputs a boolean tensor of the same size. The function fn specifies the math
7878
* operation which is applied to the input tensor element-wise.
7979
*/
80-
Tensor& unary_ufunc_realhb_to_bool(
81-
bool (*fn)(double),
80+
Tensor& unary_ufunc_realhbbf16_to_bool(
81+
bool (*fn_float)(float),
82+
bool (*fn_double)(double),
8283
KernelRuntimeContext& ctx,
8384
const Tensor& in,
8485
Tensor& out);
8586

87+
#define DEFINE_UNARY_UFUNC_REALHBBF16_TO_BOOL(op_name, fn) \
88+
Tensor& op_name(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { \
89+
return internal::unary_ufunc_realhbbf16_to_bool(fn, fn, ctx, in, out); \
90+
}
91+
8692
/**
8793
* Implements an op pattern for ops that take a single input tensor of any
8894
* realhbbf16 dtype (real/half/bool/bfloat16), no additional arguments, and

kernels/portable/cpu/pattern/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def define_common_targets():
5050
runtime.cxx_library(
5151
name = "pattern",
5252
srcs = [
53-
"unary_ufunc_realhb_to_bool.cpp",
53+
"unary_ufunc_realhbbf16_to_bool.cpp",
5454
"unary_ufunc_realhbbf16_to_floathbf16.cpp",
5555
"unary_ufunc_realhbf16.cpp",
5656
],

kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp renamed to kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_bool.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@ namespace executor {
1515
namespace native {
1616
namespace internal {
1717

18-
Tensor& unary_ufunc_realhb_to_bool(
19-
bool (*fn)(double),
18+
Tensor& unary_ufunc_realhbbf16_to_bool(
19+
bool (*fn_float)(float),
20+
bool (*fn_double)(double),
2021
KernelRuntimeContext& ctx,
2122
const Tensor& in,
2223
Tensor& out) {
23-
(void)ctx;
24-
2524
// Resize for dynamic shape
2625
ET_KERNEL_CHECK_MSG(
2726
ctx,
@@ -45,7 +44,17 @@ Tensor& unary_ufunc_realhb_to_bool(
4544

4645
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, __func__, CTYPE_IN, [&] {
4746
apply_unary_map_fn(
48-
[fn](const CTYPE_IN val_in) { return fn(val_in); },
47+
[fn_double, fn_float](const CTYPE_IN val_in) {
48+
if constexpr (std::is_same_v<CTYPE_IN, double>) {
49+
(void)fn_float;
50+
double xi = static_cast<double>(val_in);
51+
return static_cast<bool>(fn_double(xi));
52+
} else {
53+
(void)fn_double;
54+
float xi = static_cast<float>(val_in);
55+
return static_cast<bool>(fn_float(xi));
56+
}
57+
},
4958
in.const_data_ptr<CTYPE_IN>(),
5059
out.mutable_data_ptr<bool>(),
5160
in.numel());

0 commit comments

Comments
 (0)