|
9 | 9 | #include <executorch/kernels/optimized/cpu/binary_ops.h> |
10 | 10 | #include <executorch/kernels/optimized/vec/functional.h> |
11 | 11 | #include <executorch/kernels/optimized/vec/vec.h> |
| 12 | +#include <executorch/kernels/portable/cpu/op_add_impl.h> |
12 | 13 | #include <executorch/kernels/portable/cpu/scalar_utils.h> |
13 | 14 | #include <executorch/kernels/portable/cpu/util/broadcast_util.h> |
14 | 15 | #include <executorch/runtime/kernel/kernel_includes.h> |
@@ -176,35 +177,7 @@ Tensor& opt_add_out( |
176 | 177 | lhs->sizes()[lhs->dim() - 1]); |
177 | 178 | }); |
178 | 179 | } else { |
179 | | - ScalarType common_type = |
180 | | - promoteTypes(a_type, b_type, /*half_to_float*/ true); |
181 | | - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); |
182 | | - |
183 | | - ET_KERNEL_CHECK( |
184 | | - ctx, |
185 | | - resize_to_broadcast_target_size(a, b, out) == Error::Ok, |
186 | | - InvalidArgument, |
187 | | - out); |
188 | | - |
189 | | - ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() { |
190 | | - ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() { |
191 | | - using CTYPE_IN = typename torch::executor:: |
192 | | - promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type; |
193 | | - ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type); |
194 | | - ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() { |
195 | | - CTYPE_IN alpha_val; |
196 | | - ET_KERNEL_CHECK( |
197 | | - ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, ); |
198 | | - |
199 | | - AddInner< |
200 | | - can_cast<CTYPE_IN, CTYPE_OUT>::value, |
201 | | - CTYPE_A, |
202 | | - CTYPE_B, |
203 | | - CTYPE_IN, |
204 | | - CTYPE_OUT>::run(a, b, alpha_val, out); |
205 | | - }); |
206 | | - }); |
207 | | - }); |
| 180 | + add_out_impl(ctx, a, b, alpha, out); |
208 | 181 | } |
209 | 182 |
|
210 | 183 | return out; |
@@ -255,30 +228,7 @@ Tensor& opt_add_scalar_out( |
255 | 228 | }); |
256 | 229 | }); |
257 | 230 | } else { |
258 | | - ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() { |
259 | | - ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() { |
260 | | - ET_SWITCH_REALB_TYPES( |
261 | | - common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() { |
262 | | - ET_SWITCH_REALHBBF16_TYPES( |
263 | | - out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() { |
264 | | - CTYPE_B b_val; |
265 | | - ET_EXTRACT_SCALAR(b, b_val); |
266 | | - CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val); |
267 | | - CTYPE_IN alpha_val; |
268 | | - ET_EXTRACT_SCALAR(alpha, alpha_val); |
269 | | - |
270 | | - const size_t n = a.numel(); |
271 | | - const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>(); |
272 | | - CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>(); |
273 | | - for (auto i = 0; i < n; ++i) { |
274 | | - out_data[i] = static_cast<CTYPE_OUT>( |
275 | | - static_cast<CTYPE_IN>(a_data[i]) + |
276 | | - alpha_val * b_casted); |
277 | | - } |
278 | | - }); |
279 | | - }); |
280 | | - }); |
281 | | - }); |
| 231 | + add_scalar_out_impl(ctx, a, b, alpha, out); |
282 | 232 | } |
283 | 233 |
|
284 | 234 | return out; |
|
0 commit comments