1414#include < executorch/runtime/kernel/kernel_includes.h>
1515#include < executorch/runtime/platform/assert.h>
1616
17+ #include < executorch/kernels/optimized/cpu/op_add_sub_impl.h>
18+
1719namespace torch {
1820namespace executor {
1921namespace native {
20- namespace {
21-
// Primary template of the inner add helper. It is only declared here; the
// partial specializations on `can_cast` (true / false) that follow provide
// the actual `run` entry point.
// NOTE(review): these lines carry '-' diff markers, i.e. the patch shown here
// deletes this helper; the tail of opt_add_out now forwards to
// opt_add_sub_out_impl instead.
22- template <
23- bool can_cast,
24- typename CTYPE_A,
25- typename CTYPE_B,
26- typename CTYPE_IN,
27- typename CTYPE_OUT>
28- struct AddInner ;
29-
// Specialization selected when `can_cast` is true.
// `run` hands a lambda to apply_binary_elementwise_fn together with `a`, `b`
// and `out`. For each pair of elements the lambda computes
// `a + alpha_val * b` in CTYPE_IN and converts the result to CTYPE_OUT.
//
// Template parameters:
//   CTYPE_A / CTYPE_B - element types the lambda receives for `a` and `b`.
//   CTYPE_IN          - type in which the arithmetic is carried out.
//   CTYPE_OUT         - type of the value the lambda returns.
30- template <
31- typename CTYPE_A,
32- typename CTYPE_B,
33- typename CTYPE_IN,
34- typename CTYPE_OUT>
35- struct AddInner <true , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
36- static void
37- run (const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
38- apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
39- // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
40- [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
// Both operands are converted to CTYPE_IN before any arithmetic, so the
// multiply-add below happens entirely in CTYPE_IN.
41- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
42- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
43- CTYPE_IN value = a_casted + alpha_val * b_casted;
44-
// Only the final result is narrowed/converted to the output element type.
45- return static_cast <CTYPE_OUT>(value);
46- },
47- a,
48- b,
49- out);
50- }
51- };
52-
// Fallback whose `run` has the same parameter list as AddInner<true,...>::run
// but ignores every argument and unconditionally trips ET_DCHECK_MSG(false).
// The message indicates that reaching this code means an earlier canCast
// check was missed.
// NOTE(review): presumably ET_DCHECK_MSG is a debug-only check, in which case
// this is a no-op in release builds — confirm against platform/assert.h.
53- template <typename CTYPE_IN>
54- struct ReportCanCastBug {
55- static void run (const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
56- ET_DCHECK_MSG (false , " BUG: canCast should have been checked above" );
57- }
58- };
59-
// Specialization selected when `can_cast` is false. It has no body of its
// own: it inherits `run` from ReportCanCastBug<CTYPE_IN>, so instantiating
// this combination still compiles but invoking it reports the bug instead of
// performing an add.
60- template <
61- typename CTYPE_A,
62- typename CTYPE_B,
63- typename CTYPE_IN,
64- typename CTYPE_OUT>
65- struct AddInner <false , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
66- : public ReportCanCastBug<CTYPE_IN> {};
67-
68- } // namespace
69-
7022using Tensor = executorch::aten::Tensor;
7123using ScalarType = executorch::aten::ScalarType;
7224
@@ -76,8 +28,6 @@ Tensor& opt_add_out(
7628 const Tensor& b,
7729 const Scalar& alpha,
7830 Tensor& out) {
79- (void )ctx;
80-
8131 ScalarType a_type = a.scalar_type ();
8232 ScalarType b_type = b.scalar_type ();
8333 ScalarType out_type = out.scalar_type ();
@@ -95,7 +45,9 @@ Tensor& opt_add_out(
9545 ET_SWITCH_REALB_TYPES (b_type, ctx, " add.out" , CTYPE_B, [&]() {
9646 CTYPE alpha_val;
9747 ET_KERNEL_CHECK (
98- ctx, utils::extract_scalar (alpha, &alpha_val), InvalidArgument, );
48+ ctx,
49+ torch::executor::native::utils::extract_scalar (alpha, &alpha_val),
50+ InvalidArgument, );
9951 CTYPE_B b_val = *b.const_data_ptr <CTYPE_B>();
10052 CTYPE b_casted = static_cast <CTYPE>(b_val);
10153
@@ -115,100 +67,9 @@ Tensor& opt_add_out(
11567 return opt_add_out (ctx, b, a, alpha, out);
11668 }
11769
118- auto selected_optimized_path = select_optimized_path (a, b, out);
119- if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d ) {
120- // Resize for dynamic shape
121- auto error = resize_tensor (out, a.sizes ());
122- ET_KERNEL_CHECK_MSG (
123- ctx,
124- error == Error::Ok,
125- InvalidArgument,
126- out,
127- " Failed to resize output tensor." );
128-
129- ET_SWITCH_REALB_TYPES (a_type, ctx, " add.out" , CTYPE, [&]() {
130- CTYPE alpha_val;
131- ET_KERNEL_CHECK (
132- ctx, utils::extract_scalar (alpha, &alpha_val), InvalidArgument, );
133-
134- using Vec = executorch::vec::Vectorized<CTYPE>;
135- executorch::vec::map2<CTYPE>(
136- [alpha_val](Vec x, Vec y) { return x + Vec (alpha_val) * y; },
137- out.mutable_data_ptr <CTYPE>(),
138- a.const_data_ptr <CTYPE>(),
139- b.const_data_ptr <CTYPE>(),
140- out.numel ());
141- });
142- } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone ) {
143- ET_SWITCH_REALB_TYPES (out_type, ctx, " add.out" , CTYPE, [&]() {
144- CTYPE alpha_val;
145- ET_KERNEL_CHECK_MSG (
146- ctx,
147- utils::extract_scalar (alpha, &alpha_val),
148- InvalidArgument,
149- out,
150- " Failed to extract scalar alpha." );
151- using Vec = executorch::vec::Vectorized<CTYPE>;
152- Vec alpha_val_vec (alpha_val);
153- if (selected_optimized_path ==
154- ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
155- selected_optimized_path ==
156- ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
157- selected_optimized_path ==
158- ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments ) {
159- // Reason we swap out args here is because handle_broadcast_elementwise
160- // handles this selected_optimized_path option a bit differently.
161- // This should really be resolved in handle_broadcast_elementwise.
162- // However, the current blocker is that handle_broadcast_elementwise
163- // tries to be agnostic of op. This should be fixed, likely by moving
164- // lambda creation to handle_broadcast_elementwise and it be aware of
165- // which op is being executed.
166- auto add_lambda = [&alpha_val_vec](auto x, auto y) {
167- return y + alpha_val_vec * x;
168- };
169- return torch::executor::handle_broadcast_elementwise<CTYPE>(
170- ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
171- } else {
172- auto add_lambda = [&alpha_val_vec](auto x, auto y) {
173- return x + alpha_val_vec * y;
174- };
175- return torch::executor::handle_broadcast_elementwise<CTYPE>(
176- ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
177- }
178- });
179- } else {
180- ScalarType common_type =
181- promoteTypes (a_type, b_type, /* half_to_float*/ true );
182- ET_KERNEL_CHECK (ctx, canCast (common_type, out_type), InvalidArgument, out);
183-
184- ET_KERNEL_CHECK (
185- ctx,
186- resize_to_broadcast_target_size (a, b, out) == Error::Ok,
187- InvalidArgument,
188- out);
189-
190- ET_SWITCH_REALHBBF16_TYPES (a_type, ctx, " add.out" , CTYPE_A, [&]() {
191- ET_SWITCH_REALHBBF16_TYPES (b_type, ctx, " add.out" , CTYPE_B, [&]() {
192- using CTYPE_IN = typename torch::executor::
193- promote_types<CTYPE_A, CTYPE_B, /* half_to_float*/ true >::type;
194- ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
195- ET_SWITCH_REALHBBF16_TYPES (out_type, ctx, " add.out" , CTYPE_OUT, [&]() {
196- CTYPE_IN alpha_val;
197- ET_KERNEL_CHECK (
198- ctx, utils::extract_scalar (alpha, &alpha_val), InvalidArgument, );
199-
200- AddInner<
201- can_cast<CTYPE_IN, CTYPE_OUT>::value,
202- CTYPE_A,
203- CTYPE_B,
204- CTYPE_IN,
205- CTYPE_OUT>::run (a, b, alpha_val, out);
206- });
207- });
208- });
209- }
210-
211- return out;
70+ static constexpr const char op_name[] = " add.out" ;
71+ return torch::executor::kernels::impl::opt_add_sub_out_impl<false , op_name>(
72+ ctx, a, b, alpha, out);
21273}
21374
21475Tensor& opt_add_scalar_out (
0 commit comments