Skip to content

Commit 31333b1

Browse files
[EE/BE] add float variant to unary_ufunc_realhbf16 (#12290)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #12278 by @manuelcandales
^ Please use this as the source of truth for the PR details, comments, and reviews.
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/131/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/131/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/130/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/131/orig
@diff-train-skip-merge
---------
Co-authored-by: Manuel Candales <[email protected]>
1 parent 6cbee41 commit 31333b1

File tree

5 files changed

+23
-14
lines changed

5 files changed

+23
-14
lines changed

kernels/portable/cpu/op_ceil.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ namespace native {
1616

1717
using executorch::aten::Tensor;
1818

19-
Tensor& ceil_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
20-
return internal::unary_ufunc_realhbf16(std::ceil, ctx, in, out);
21-
}
19+
DEFINE_UNARY_UFUNC_REALHBF16(ceil_out, std::ceil)
2220

2321
} // namespace native
2422
} // namespace executor

kernels/portable/cpu/op_floor.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ namespace native {
1616

1717
using executorch::aten::Tensor;
1818

19-
Tensor& floor_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
20-
return internal::unary_ufunc_realhbf16(std::floor, ctx, in, out);
21-
}
19+
DEFINE_UNARY_UFUNC_REALHBF16(floor_out, std::floor)
2220

2321
} // namespace native
2422
} // namespace executor

kernels/portable/cpu/op_trunc.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@ namespace torch {
1414
namespace executor {
1515
namespace native {
1616

17-
Tensor& trunc_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
18-
return internal::unary_ufunc_realhbf16(std::trunc, ctx, in, out);
19-
}
17+
DEFINE_UNARY_UFUNC_REALHBF16(trunc_out, std::trunc)
2018

2119
} // namespace native
2220
} // namespace executor

kernels/portable/cpu/pattern/pattern.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,17 @@ namespace internal {
6060
* the input tensor element-wise.
6161
*/
6262
Tensor& unary_ufunc_realhbf16(
63-
double (*fn)(double),
63+
float (*fn_float)(float),
64+
double (*fn_double)(double),
6465
KernelRuntimeContext& ctx,
6566
const Tensor& in,
6667
Tensor& out);
6768

69+
#define DEFINE_UNARY_UFUNC_REALHBF16(op_name, fn) \
70+
Tensor& op_name(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { \
71+
return internal::unary_ufunc_realhbf16(fn, fn, ctx, in, out); \
72+
}
73+
6874
/**
6975
* Implements an op pattern for ops that take a single input tensor of any
7076
* realhb dtype (real, half and boolean), no additional arguments, and outputs a

kernels/portable/cpu/pattern/unary_ufunc_realhbf16.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,11 @@ namespace native {
1616
namespace internal {
1717

1818
Tensor& unary_ufunc_realhbf16(
19-
double (*fn)(double),
19+
float (*fn_float)(float),
20+
double (*fn_double)(double),
2021
KernelRuntimeContext& ctx,
2122
const Tensor& in,
2223
Tensor& out) {
23-
(void)ctx;
24-
2524
// Resize for dynamic shape
2625
ET_KERNEL_CHECK_MSG(
2726
ctx,
@@ -38,7 +37,17 @@ Tensor& unary_ufunc_realhbf16(
3837

3938
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, __func__, CTYPE, [&] {
4039
apply_unary_map_fn(
41-
[fn](const CTYPE val_in) { return static_cast<CTYPE>(fn(val_in)); },
40+
[fn_double, fn_float](const CTYPE val_in) {
41+
if constexpr (std::is_same_v<CTYPE, double>) {
42+
(void)fn_float;
43+
double xi = static_cast<double>(val_in);
44+
return fn_double(xi);
45+
} else {
46+
(void)fn_double;
47+
float xi = static_cast<float>(val_in);
48+
return static_cast<CTYPE>(fn_float(xi));
49+
}
50+
},
4251
in.const_data_ptr<CTYPE>(),
4352
out.mutable_data_ptr<CTYPE>(),
4453
in.numel());

0 commit comments

Comments
 (0)