Skip to content

Commit e402655

Browse files
[ET][Optimized] Eliminate usage of ET_SWITCH_SCALAR in optimized kernels (#12046)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12033 by @manuelcandales ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/129/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/129/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/116/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/129/orig @diff-train-skip-merge --------- Co-authored-by: Manuel Candales <[email protected]> Co-authored-by: Manuel Candales <[email protected]>
1 parent 8bf5ffd commit e402655

File tree

3 files changed

+90
-115
lines changed

3 files changed

+90
-115
lines changed

kernels/optimized/cpu/op_add.cpp

Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,7 @@ Tensor& opt_add_out(
4545
ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
4646
CTYPE alpha_val;
4747
ET_KERNEL_CHECK(
48-
ctx,
49-
torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
50-
InvalidArgument, );
48+
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
5149
CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
5250
CTYPE b_casted = static_cast<CTYPE>(b_val);
5351

@@ -81,7 +79,6 @@ Tensor& opt_add_scalar_out(
8179
(void)ctx;
8280

8381
ScalarType a_type = a.scalar_type();
84-
ScalarType b_type = utils::get_scalar_dtype(b);
8582
ScalarType common_type =
8683
utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
8784
ScalarType out_type = out.scalar_type();
@@ -99,47 +96,43 @@ Tensor& opt_add_scalar_out(
9996
if (a_type == common_type && a_type == out_type &&
10097
a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
10198
ET_SWITCH_REALB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE, [&]() {
102-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
103-
CTYPE_B b_val;
104-
ET_EXTRACT_SCALAR(b, b_val);
105-
CTYPE b_casted = static_cast<CTYPE>(b_val);
106-
CTYPE alpha_val;
107-
ET_EXTRACT_SCALAR(alpha, alpha_val);
108-
109-
using Vec = at::vec::Vectorized<CTYPE>;
110-
at::vec::map<CTYPE>(
111-
[alpha_val, b_casted](Vec x) {
112-
return x + Vec(alpha_val * b_casted);
113-
},
114-
out.mutable_data_ptr<CTYPE>(),
115-
a.const_data_ptr<CTYPE>(),
116-
out.numel());
117-
});
99+
CTYPE b_casted = utils::scalar_to<CTYPE>(b);
100+
CTYPE alpha_val;
101+
ET_KERNEL_CHECK(
102+
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
103+
104+
using Vec = at::vec::Vectorized<CTYPE>;
105+
at::vec::map<CTYPE>(
106+
[alpha_val, b_casted](Vec x) {
107+
return x + Vec(alpha_val * b_casted);
108+
},
109+
out.mutable_data_ptr<CTYPE>(),
110+
a.const_data_ptr<CTYPE>(),
111+
out.numel());
118112
});
119113
} else {
120114
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
121-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
122-
ET_SWITCH_REALB_TYPES(
123-
common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
124-
ET_SWITCH_REALHBBF16_TYPES(
125-
out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
126-
CTYPE_B b_val;
127-
ET_EXTRACT_SCALAR(b, b_val);
128-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
129-
CTYPE_IN alpha_val;
130-
ET_EXTRACT_SCALAR(alpha, alpha_val);
131-
132-
const size_t n = a.numel();
133-
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
134-
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
135-
for (auto i = 0; i < n; ++i) {
136-
out_data[i] = static_cast<CTYPE_OUT>(
137-
static_cast<CTYPE_IN>(a_data[i]) +
138-
alpha_val * b_casted);
139-
}
140-
});
141-
});
142-
});
115+
ET_SWITCH_REALB_TYPES(
116+
common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
117+
ET_SWITCH_REALHBBF16_TYPES(
118+
out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
119+
CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
120+
CTYPE_IN alpha_val;
121+
ET_KERNEL_CHECK(
122+
ctx,
123+
utils::extract_scalar(alpha, &alpha_val),
124+
InvalidArgument, );
125+
126+
const size_t n = a.numel();
127+
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
128+
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
129+
for (auto i = 0; i < n; ++i) {
130+
out_data[i] = static_cast<CTYPE_OUT>(
131+
static_cast<CTYPE_IN>(a_data[i]) +
132+
alpha_val * b_casted);
133+
}
134+
});
135+
});
143136
});
144137
}
145138

kernels/optimized/cpu/op_mul.cpp

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,6 @@ Tensor& opt_mul_scalar_out(
218218
(void)ctx;
219219

220220
ScalarType a_type = a.scalar_type();
221-
ScalarType b_type = utils::get_scalar_dtype(b);
222221
ScalarType common_type =
223222
utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
224223
ScalarType out_type = out.scalar_type();
@@ -236,40 +235,32 @@ Tensor& opt_mul_scalar_out(
236235
if (a_type == common_type && a_type == out_type &&
237236
a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
238237
ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE, [&]() {
239-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() {
240-
CTYPE_B b_val;
241-
ET_EXTRACT_SCALAR(b, b_val);
242-
CTYPE b_casted = static_cast<CTYPE>(b_val);
243-
244-
using Vec = at::vec::Vectorized<CTYPE>;
245-
at::vec::map<CTYPE>(
246-
[b_casted](Vec x) { return x * Vec(b_casted); },
247-
out.mutable_data_ptr<CTYPE>(),
248-
a.const_data_ptr<CTYPE>(),
249-
out.numel());
250-
});
238+
CTYPE b_casted = utils::scalar_to<CTYPE>(b);
239+
240+
using Vec = at::vec::Vectorized<CTYPE>;
241+
at::vec::map<CTYPE>(
242+
[b_casted](Vec x) { return x * Vec(b_casted); },
243+
out.mutable_data_ptr<CTYPE>(),
244+
a.const_data_ptr<CTYPE>(),
245+
out.numel());
251246
});
252247
} else {
253248
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
254-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() {
255-
ET_SWITCH_REALB_TYPES(
256-
common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
257-
ET_SWITCH_REALHBBF16_TYPES(
258-
out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
259-
CTYPE_B b_val;
260-
ET_EXTRACT_SCALAR(b, b_val);
261-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
262-
263-
const size_t n = a.numel();
264-
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
265-
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
266-
for (auto i = 0; i < n; ++i) {
267-
out_data[i] = static_cast<CTYPE_OUT>(
268-
static_cast<CTYPE_IN>(a_data[i]) * b_casted);
269-
}
270-
});
271-
});
272-
});
249+
ET_SWITCH_REALB_TYPES(
250+
common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
251+
ET_SWITCH_REALHBBF16_TYPES(
252+
out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
253+
CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
254+
255+
const size_t n = a.numel();
256+
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
257+
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
258+
for (auto i = 0; i < n; ++i) {
259+
out_data[i] = static_cast<CTYPE_OUT>(
260+
static_cast<CTYPE_IN>(a_data[i]) * b_casted);
261+
}
262+
});
263+
});
273264
});
274265
}
275266

kernels/optimized/cpu/op_sub.cpp

Lines changed: 32 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,6 @@ Tensor& opt_sub_scalar_out(
154154
(void)ctx;
155155

156156
ScalarType a_type = a.scalar_type();
157-
ScalarType b_type = utils::get_scalar_dtype(b);
158157
ScalarType common_type =
159158
utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
160159
ScalarType out_type = out.scalar_type();
@@ -172,49 +171,41 @@ Tensor& opt_sub_scalar_out(
172171
if (a_type == common_type && a_type == out_type &&
173172
a_type != ScalarType::Half) {
174173
ET_SWITCH_REAL_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE, [&]() {
175-
ET_SWITCH_SCALAR_OBJ_REAL_TYPES(
176-
b_type, ctx, "sub.Scalar_out", CTYPE_B, [&]() {
177-
CTYPE_B b_val;
178-
ET_EXTRACT_SCALAR(b, b_val);
179-
CTYPE b_casted = static_cast<CTYPE>(b_val);
180-
CTYPE alpha_val;
181-
ET_EXTRACT_SCALAR(alpha, alpha_val);
182-
183-
using Vec = at::vec::Vectorized<CTYPE>;
184-
at::vec::map<CTYPE>(
185-
[alpha_val, b_casted](Vec x) {
186-
return x - Vec(alpha_val * b_casted);
187-
},
188-
out.mutable_data_ptr<CTYPE>(),
189-
a.const_data_ptr<CTYPE>(),
190-
out.numel());
191-
});
174+
CTYPE b_casted = utils::scalar_to<CTYPE>(b);
175+
CTYPE alpha_val;
176+
ET_KERNEL_CHECK(
177+
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
178+
179+
using Vec = at::vec::Vectorized<CTYPE>;
180+
at::vec::map<CTYPE>(
181+
[alpha_val, b_casted](Vec x) {
182+
return x - Vec(alpha_val * b_casted);
183+
},
184+
out.mutable_data_ptr<CTYPE>(),
185+
a.const_data_ptr<CTYPE>(),
186+
out.numel());
192187
});
193188
} else {
194189
ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE_A, [&]() {
195-
ET_SWITCH_SCALAR_OBJ_REAL_TYPES(
196-
b_type, ctx, "sub.Scalar_out", CTYPE_B, [&]() {
197-
ET_SWITCH_REAL_TYPES(
198-
common_type, ctx, "sub.Scalar_out", CTYPE_IN, [&]() {
199-
ET_SWITCH_REALH_TYPES(
200-
out_type, ctx, "sub.Scalar_out", CTYPE_OUT, [&]() {
201-
CTYPE_B b_val;
202-
ET_EXTRACT_SCALAR(b, b_val);
203-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
204-
CTYPE_IN alpha_val;
205-
ET_EXTRACT_SCALAR(alpha, alpha_val);
206-
207-
const size_t n = a.numel();
208-
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
209-
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
210-
for (auto i = 0; i < n; ++i) {
211-
out_data[i] = static_cast<CTYPE_OUT>(
212-
static_cast<CTYPE_IN>(a_data[i]) -
213-
alpha_val * b_casted);
214-
}
215-
});
216-
});
217-
});
190+
ET_SWITCH_REAL_TYPES(common_type, ctx, "sub.Scalar_out", CTYPE_IN, [&]() {
191+
ET_SWITCH_REALH_TYPES(
192+
out_type, ctx, "sub.Scalar_out", CTYPE_OUT, [&]() {
193+
CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
194+
CTYPE_IN alpha_val;
195+
ET_KERNEL_CHECK(
196+
ctx,
197+
utils::extract_scalar(alpha, &alpha_val),
198+
InvalidArgument, );
199+
200+
const size_t n = a.numel();
201+
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
202+
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
203+
for (auto i = 0; i < n; ++i) {
204+
out_data[i] = static_cast<CTYPE_OUT>(
205+
static_cast<CTYPE_IN>(a_data[i]) - alpha_val * b_casted);
206+
}
207+
});
208+
});
218209
});
219210
}
220211

0 commit comments

Comments
 (0)