Skip to content
29 changes: 16 additions & 13 deletions kernels/optimized/cpu/op_exp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ template <
typename CTYPE_IN,
typename CTYPE_OUT,
typename std::enable_if<
std::is_same<CTYPE_IN, CTYPE_OUT>::value &&
!std::is_same<CTYPE_IN, exec_aten::Half>::value &&
!std::is_same<CTYPE_OUT, exec_aten::Half>::value,
std::is_same_v<CTYPE_IN, CTYPE_OUT> &&
!std::is_same_v<CTYPE_IN, exec_aten::Half> &&
!std::is_same_v<CTYPE_OUT, exec_aten::BFloat16>,
int>::type = 0>
void exp_data(
const CTYPE_IN* in_data,
Expand All @@ -46,9 +46,11 @@ template <
typename CTYPE_IN,
typename CTYPE_OUT,
typename std::enable_if<
!std::is_same<CTYPE_IN, CTYPE_OUT>::value ||
std::is_same<CTYPE_IN, exec_aten::Half>::value ||
std::is_same<CTYPE_OUT, exec_aten::Half>::value,
!std::is_same_v<CTYPE_IN, CTYPE_OUT> ||
std::is_same_v<CTYPE_IN, exec_aten::Half> ||
std::is_same_v<CTYPE_IN, exec_aten::BFloat16> ||
std::is_same_v<CTYPE_OUT, exec_aten::Half> ||
std::is_same_v<CTYPE_OUT, exec_aten::BFloat16>,
int>::type = 0>
void exp_data(
const CTYPE_IN* in_data,
Expand Down Expand Up @@ -76,13 +78,14 @@ Tensor& opt_exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {

ET_KERNEL_CHECK(ctx, tensor_is_floating_type(out), InvalidArgument, out);

ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "exp.out", CTYPE_IN, [&] {
ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, "exp.out", CTYPE_OUT, [&] {
exp_data<CTYPE_IN, CTYPE_OUT>(
in.const_data_ptr<CTYPE_IN>(),
in.numel(),
out.mutable_data_ptr<CTYPE_OUT>());
});
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "exp.out", CTYPE_IN, [&] {
ET_SWITCH_FLOATHBF16_TYPES(
out.scalar_type(), ctx, "exp.out", CTYPE_OUT, [&] {
exp_data<CTYPE_IN, CTYPE_OUT>(
in.const_data_ptr<CTYPE_IN>(),
in.numel(),
out.mutable_data_ptr<CTYPE_OUT>());
});
});

return out;
Expand Down
Loading