@@ -45,9 +45,7 @@ Tensor& opt_add_out(
45
45
ET_SWITCH_REALB_TYPES (b_type, ctx, " add.out" , CTYPE_B, [&]() {
46
46
CTYPE alpha_val;
47
47
ET_KERNEL_CHECK (
48
- ctx,
49
- torch::executor::native::utils::extract_scalar (alpha, &alpha_val),
50
- InvalidArgument, );
48
+ ctx, utils::extract_scalar (alpha, &alpha_val), InvalidArgument, );
51
49
CTYPE_B b_val = *b.const_data_ptr <CTYPE_B>();
52
50
CTYPE b_casted = static_cast <CTYPE>(b_val);
53
51
@@ -81,7 +79,6 @@ Tensor& opt_add_scalar_out(
81
79
(void )ctx;
82
80
83
81
ScalarType a_type = a.scalar_type ();
84
- ScalarType b_type = utils::get_scalar_dtype (b);
85
82
ScalarType common_type =
86
83
utils::promote_type_with_scalar (a_type, b, /* half_to_float*/ false );
87
84
ScalarType out_type = out.scalar_type ();
@@ -99,47 +96,43 @@ Tensor& opt_add_scalar_out(
99
96
if (a_type == common_type && a_type == out_type &&
100
97
a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
101
98
ET_SWITCH_REALB_TYPES (a_type, ctx, " add.Scalar_out" , CTYPE, [&]() {
102
- ET_SWITCH_SCALAR_OBJ_TYPES (b_type, ctx, " add.Scalar_out" , CTYPE_B, [&]() {
103
- CTYPE_B b_val;
104
- ET_EXTRACT_SCALAR (b, b_val);
105
- CTYPE b_casted = static_cast <CTYPE>(b_val);
106
- CTYPE alpha_val;
107
- ET_EXTRACT_SCALAR (alpha, alpha_val);
108
-
109
- using Vec = at::vec::Vectorized<CTYPE>;
110
- at::vec::map<CTYPE>(
111
- [alpha_val, b_casted](Vec x) {
112
- return x + Vec (alpha_val * b_casted);
113
- },
114
- out.mutable_data_ptr <CTYPE>(),
115
- a.const_data_ptr <CTYPE>(),
116
- out.numel ());
117
- });
99
+ CTYPE b_casted = utils::scalar_to<CTYPE>(b);
100
+ CTYPE alpha_val;
101
+ ET_KERNEL_CHECK (
102
+ ctx, utils::extract_scalar (alpha, &alpha_val), InvalidArgument, );
103
+
104
+ using Vec = at::vec::Vectorized<CTYPE>;
105
+ at::vec::map<CTYPE>(
106
+ [alpha_val, b_casted](Vec x) {
107
+ return x + Vec (alpha_val * b_casted);
108
+ },
109
+ out.mutable_data_ptr <CTYPE>(),
110
+ a.const_data_ptr <CTYPE>(),
111
+ out.numel ());
118
112
});
119
113
} else {
120
114
ET_SWITCH_REALHBBF16_TYPES (a_type, ctx, " add.Scalar_out" , CTYPE_A, [&]() {
121
- ET_SWITCH_SCALAR_OBJ_TYPES (b_type, ctx, " add.Scalar_out" , CTYPE_B, [&]() {
122
- ET_SWITCH_REALB_TYPES (
123
- common_type, ctx, " add.Scalar_out" , CTYPE_IN, [&]() {
124
- ET_SWITCH_REALHBBF16_TYPES (
125
- out_type, ctx, " add.Scalar_out" , CTYPE_OUT, [&]() {
126
- CTYPE_B b_val;
127
- ET_EXTRACT_SCALAR (b, b_val);
128
- CTYPE_IN b_casted = static_cast <CTYPE_IN>(b_val);
129
- CTYPE_IN alpha_val;
130
- ET_EXTRACT_SCALAR (alpha, alpha_val);
131
-
132
- const size_t n = a.numel ();
133
- const CTYPE_A* a_data = a.const_data_ptr <CTYPE_A>();
134
- CTYPE_OUT* out_data = out.mutable_data_ptr <CTYPE_OUT>();
135
- for (auto i = 0 ; i < n; ++i) {
136
- out_data[i] = static_cast <CTYPE_OUT>(
137
- static_cast <CTYPE_IN>(a_data[i]) +
138
- alpha_val * b_casted);
139
- }
140
- });
141
- });
142
- });
115
+ ET_SWITCH_REALB_TYPES (
116
+ common_type, ctx, " add.Scalar_out" , CTYPE_IN, [&]() {
117
+ ET_SWITCH_REALHBBF16_TYPES (
118
+ out_type, ctx, " add.Scalar_out" , CTYPE_OUT, [&]() {
119
+ CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
120
+ CTYPE_IN alpha_val;
121
+ ET_KERNEL_CHECK (
122
+ ctx,
123
+ utils::extract_scalar (alpha, &alpha_val),
124
+ InvalidArgument, );
125
+
126
+ const size_t n = a.numel ();
127
+ const CTYPE_A* a_data = a.const_data_ptr <CTYPE_A>();
128
+ CTYPE_OUT* out_data = out.mutable_data_ptr <CTYPE_OUT>();
129
+ for (auto i = 0 ; i < n; ++i) {
130
+ out_data[i] = static_cast <CTYPE_OUT>(
131
+ static_cast <CTYPE_IN>(a_data[i]) +
132
+ alpha_val * b_casted);
133
+ }
134
+ });
135
+ });
143
136
});
144
137
}
145
138
0 commit comments