@@ -14,59 +14,11 @@
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 
+#include <executorch/kernels/optimized/cpu/op_add_sub_impl.h>
+
 namespace torch {
 namespace executor {
 namespace native {
-namespace {
-
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void
-  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = a_casted + alpha_val * b_casted;
-
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-template <typename CTYPE_IN>
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug<CTYPE_IN> {};
-
-} // namespace
-
 using Tensor = executorch::aten::Tensor;
 using ScalarType = executorch::aten::ScalarType;
 
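The fifty lines removed above are the add-specific canCast dispatch that the new shared header absorbs. The idiom is worth spelling out: the primary template is declared but never defined, the <true, ...> specialization carries the real elementwise loop, and the <false, ...> specialization is a trap that is reachable only if the runtime canCast() check was skipped. Below is a minimal self-contained sketch of the same idiom; the name AddInnerSketch is invented, and raw pointers stand in for ExecuTorch Tensors.

    #include <cassert>

    // Primary template: declared, never defined, so an unexpected
    // instantiation fails at compile time.
    template <bool can_cast, typename CTYPE_IN, typename CTYPE_OUT>
    struct AddInnerSketch;

    // can_cast == true: the real kernel body, out[i] = a[i] + alpha * b[i].
    template <typename CTYPE_IN, typename CTYPE_OUT>
    struct AddInnerSketch<true, CTYPE_IN, CTYPE_OUT> {
      static void run(
          const CTYPE_IN* a, const CTYPE_IN* b, CTYPE_IN alpha,
          CTYPE_OUT* out, int n) {
        for (int i = 0; i < n; ++i) {
          out[i] = static_cast<CTYPE_OUT>(a[i] + alpha * b[i]);
        }
      }
    };

    // can_cast == false: a trap, reachable only if the runtime canCast()
    // check was skipped -- the analogue of ReportCanCastBug above.
    template <typename CTYPE_IN, typename CTYPE_OUT>
    struct AddInnerSketch<false, CTYPE_IN, CTYPE_OUT> {
      static void run(const CTYPE_IN*, const CTYPE_IN*, CTYPE_IN, CTYPE_OUT*, int) {
        assert(false && "BUG: canCast should have been checked above");
      }
    };

A caller would instantiate this as AddInnerSketch<SomeCanCastTrait, float, int>::run(...), mirroring how the deleted code keys AddInner on can_cast<CTYPE_IN, CTYPE_OUT>::value.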
@@ -76,8 +28,6 @@ Tensor& opt_add_out(
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  (void)ctx;
-
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();
@@ -95,7 +45,9 @@ Tensor& opt_add_out(
         ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
           CTYPE alpha_val;
           ET_KERNEL_CHECK(
-              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
+              ctx,
+              torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
+              InvalidArgument, );
           CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
           CTYPE b_casted = static_cast<CTYPE>(b_val);
 
@@ -115,101 +67,9 @@ Tensor& opt_add_out(
     return opt_add_out(ctx, b, a, alpha, out);
   }
 
-  auto selected_optimized_path = select_optimized_path(a, b, out);
-  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
-    // Resize for dynamic shape
-    auto error = resize_tensor(out, a.sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-
-    ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK(
-          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::map2<CTYPE>(
-          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
-          out.mutable_data_ptr<CTYPE>(),
-          a.const_data_ptr<CTYPE>(),
-          b.const_data_ptr<CTYPE>(),
-          out.numel());
-    });
-  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    ScalarType out_type = out.scalar_type();
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK_MSG(
-          ctx,
-          utils::extract_scalar(alpha, &alpha_val),
-          InvalidArgument,
-          out,
-          "Failed to extract scalar alpha.");
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      Vec alpha_val_vec(alpha_val);
-      if (selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
-          selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
-          selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
-        // Reason we swap out args here is because handle_broadcast_elementwise
-        // handles this selected_optimized_path option a bit differently.
-        // This should really be resolved in handle_broadcast_elementwise.
-        // However, the current blocker is that handle_broadcast_elementwise
-        // tries to be agnostic of op. This should be fixed, likely by moving
-        // lambda creation to handle_broadcast_elementwise and it be aware of
-        // which op is being executed.
-        auto add_lambda = [&alpha_val_vec](auto x, auto y) {
-          return y + alpha_val_vec * x;
-        };
-        return torch::executor::handle_broadcast_elementwise<CTYPE>(
-            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-      } else {
-        auto add_lambda = [&alpha_val_vec](auto x, auto y) {
-          return x + alpha_val_vec * y;
-        };
-        return torch::executor::handle_broadcast_elementwise<CTYPE>(
-            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-      }
-    });
-  } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
-          CTYPE_IN alpha_val;
-          ET_KERNEL_CHECK(
-              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-          AddInner<
-              can_cast<CTYPE_IN, CTYPE_OUT>::value,
-              CTYPE_A,
-              CTYPE_B,
-              CTYPE_IN,
-              CTYPE_OUT>::run(a, b, alpha_val, out);
-        });
-      });
-    });
-  }
-
-  return out;
+  static constexpr const char op_name[] = "add.out";
+  return torch::executor::kernels::impl::opt_add_sub_out_impl<false, op_name>(
+      ctx, a, b, alpha, out);
 }
 
 Tensor& opt_add_scalar_out(
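Net effect of the last hunk: roughly ninety-five lines of add-only path selection, vectorized loops, and AddInner dispatch collapse into a three-line call into the shared implementation from op_add_sub_impl.h. The diff shows only the call site, so what follows is a guess at the shape of the technique, not the library's actual opt_add_sub_out_impl: the leading bool template parameter plausibly selects subtraction (false here for add.out), letting add and sub share one body whose sign folds away at compile time. A runnable toy version under those assumptions, with invented names:

    #include <cstddef>
    #include <cstdio>

    // Hypothetical stand-in for the shared add/sub kernel: the bool picks
    // the sign at compile time, so out = a + (+/- alpha) * b with no branch.
    template <bool is_sub, typename CTYPE>
    void add_sub_out_sketch(
        const CTYPE* a, const CTYPE* b, CTYPE alpha, CTYPE* out, size_t n) {
      constexpr CTYPE sign = is_sub ? CTYPE(-1) : CTYPE(1);
      for (size_t i = 0; i < n; ++i) {
        out[i] = a[i] + sign * alpha * b[i];
      }
    }

    int main() {
      float a[] = {1.0f, 2.0f};
      float b[] = {10.0f, 20.0f};
      float out[2];
      add_sub_out_sketch<false>(a, b, 2.0f, out, 2); // add: {21, 42}
      std::printf("add: %g %g\n", out[0], out[1]);
      add_sub_out_sketch<true>(a, b, 2.0f, out, 2); // sub: {-19, -38}
      std::printf("sub: %g %g\n", out[0], out[1]);
      return 0;
    }

On this reading, an opt_sub_out would call the same impl with true as the first template argument, which is presumably why the commit adds the shared op_add_sub_impl.h include and deletes the add-only helpers.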