@@ -26,9 +26,9 @@ template <
2626 typename CTYPE_IN,
2727 typename CTYPE_OUT,
2828 typename std::enable_if<
29- std::is_same <CTYPE_IN, CTYPE_OUT>::value &&
30- !std::is_same <CTYPE_IN, exec_aten::Half>::value &&
31- !std::is_same <CTYPE_OUT, exec_aten::Half>::value ,
29+ std::is_same_v <CTYPE_IN, CTYPE_OUT> &&
30+ !std::is_same_v <CTYPE_IN, exec_aten::Half> &&
31+ !std::is_same_v <CTYPE_OUT, exec_aten::BFloat16> ,
3232 int >::type = 0 >
3333void exp_data (
3434 const CTYPE_IN* in_data,
@@ -46,9 +46,11 @@ template <
4646 typename CTYPE_IN,
4747 typename CTYPE_OUT,
4848 typename std::enable_if<
49- !std::is_same<CTYPE_IN, CTYPE_OUT>::value ||
50- std::is_same<CTYPE_IN, exec_aten::Half>::value ||
51- std::is_same<CTYPE_OUT, exec_aten::Half>::value,
49+ !std::is_same_v<CTYPE_IN, CTYPE_OUT> ||
50+ std::is_same_v<CTYPE_IN, exec_aten::Half> ||
51+ std::is_same_v<CTYPE_IN, exec_aten::BFloat16> ||
52+ std::is_same_v<CTYPE_OUT, exec_aten::Half> ||
53+ std::is_same_v<CTYPE_OUT, exec_aten::BFloat16>,
5254 int >::type = 0 >
5355void exp_data (
5456 const CTYPE_IN* in_data,
@@ -76,13 +78,14 @@ Tensor& opt_exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
7678
7779 ET_KERNEL_CHECK (ctx, tensor_is_floating_type (out), InvalidArgument, out);
7880
79- ET_SWITCH_REALHB_TYPES (in.scalar_type (), ctx, " exp.out" , CTYPE_IN, [&] {
80- ET_SWITCH_FLOATH_TYPES (out.scalar_type (), ctx, " exp.out" , CTYPE_OUT, [&] {
81- exp_data<CTYPE_IN, CTYPE_OUT>(
82- in.const_data_ptr <CTYPE_IN>(),
83- in.numel (),
84- out.mutable_data_ptr <CTYPE_OUT>());
85- });
81+ ET_SWITCH_REALHBBF16_TYPES (in.scalar_type (), ctx, " exp.out" , CTYPE_IN, [&] {
82+ ET_SWITCH_FLOATHBF16_TYPES (
83+ out.scalar_type (), ctx, " exp.out" , CTYPE_OUT, [&] {
84+ exp_data<CTYPE_IN, CTYPE_OUT>(
85+ in.const_data_ptr <CTYPE_IN>(),
86+ in.numel (),
87+ out.mutable_data_ptr <CTYPE_OUT>());
88+ });
8689 });
8790
8891 return out;
0 commit comments