Skip to content

Commit 085b817

Browse files
swolchokfacebook-github-bot
authored andcommitted
Make optimized op_exp support bf16 (#5677)
Summary: Pull Request resolved: pytorch/executorch#5677 needed to unblock UnaryUfuncRealHBToFloatHTest change that should be stacked on this rev. ghstack-source-id: 245578280 exported-using-ghexport Reviewed By: manuelcandales Differential Revision: D63438655 fbshipit-source-id: ba71c4f8aca9ed9e676fe2987f9d689fc4c42611
1 parent 07bcd7f commit 085b817

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

kernels/optimized/cpu/op_exp.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ template <
2626
typename CTYPE_IN,
2727
typename CTYPE_OUT,
2828
typename std::enable_if<
29-
std::is_same<CTYPE_IN, CTYPE_OUT>::value &&
30-
!std::is_same<CTYPE_IN, exec_aten::Half>::value &&
31-
!std::is_same<CTYPE_OUT, exec_aten::Half>::value,
29+
std::is_same_v<CTYPE_IN, CTYPE_OUT> &&
30+
!std::is_same_v<CTYPE_IN, exec_aten::Half> &&
31+
!std::is_same_v<CTYPE_OUT, exec_aten::BFloat16>,
3232
int>::type = 0>
3333
void exp_data(
3434
const CTYPE_IN* in_data,
@@ -46,9 +46,11 @@ template <
4646
typename CTYPE_IN,
4747
typename CTYPE_OUT,
4848
typename std::enable_if<
49-
!std::is_same<CTYPE_IN, CTYPE_OUT>::value ||
50-
std::is_same<CTYPE_IN, exec_aten::Half>::value ||
51-
std::is_same<CTYPE_OUT, exec_aten::Half>::value,
49+
!std::is_same_v<CTYPE_IN, CTYPE_OUT> ||
50+
std::is_same_v<CTYPE_IN, exec_aten::Half> ||
51+
std::is_same_v<CTYPE_IN, exec_aten::BFloat16> ||
52+
std::is_same_v<CTYPE_OUT, exec_aten::Half> ||
53+
std::is_same_v<CTYPE_OUT, exec_aten::BFloat16>,
5254
int>::type = 0>
5355
void exp_data(
5456
const CTYPE_IN* in_data,
@@ -76,13 +78,14 @@ Tensor& opt_exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
7678

7779
ET_KERNEL_CHECK(ctx, tensor_is_floating_type(out), InvalidArgument, out);
7880

79-
ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "exp.out", CTYPE_IN, [&] {
80-
ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, "exp.out", CTYPE_OUT, [&] {
81-
exp_data<CTYPE_IN, CTYPE_OUT>(
82-
in.const_data_ptr<CTYPE_IN>(),
83-
in.numel(),
84-
out.mutable_data_ptr<CTYPE_OUT>());
85-
});
81+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "exp.out", CTYPE_IN, [&] {
82+
ET_SWITCH_FLOATHBF16_TYPES(
83+
out.scalar_type(), ctx, "exp.out", CTYPE_OUT, [&] {
84+
exp_data<CTYPE_IN, CTYPE_OUT>(
85+
in.const_data_ptr<CTYPE_IN>(),
86+
in.numel(),
87+
out.mutable_data_ptr<CTYPE_OUT>());
88+
});
8689
});
8790

8891
return out;

0 commit comments

Comments
 (0)