#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
+ #include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
- #include "kernels.h"
+ #include <executorch/backends/cadence/hifi/kernels/kernels.h>

namespace torch {
namespace executor {
namespace native {
+ namespace {

- #define NNLIB_MAX_DIM 4 /* Add fallback if broadcast and dim > 4 */
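+ // AddInner computes out = a + alpha * b element-wise. The leading bool template
+ // parameter selects the real implementation when CTYPE_IN can be cast to
+ // CTYPE_OUT, and a debug-check stub otherwise.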
+ template <
+     bool can_cast,
+     typename CTYPE_A,
+     typename CTYPE_B,
+     typename CTYPE_IN,
+     typename CTYPE_OUT>
+ struct AddInner;
+
+ template <
+     typename CTYPE_A,
+     typename CTYPE_B,
+     typename CTYPE_IN,
+     typename CTYPE_OUT>
+ struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
+   static void
+   run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
+     apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+         // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
+         [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
+           CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+           CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+           CTYPE_IN value = a_casted + alpha_val * b_casted;
+
+           return static_cast<CTYPE_OUT>(value);
+         },
+         a,
+         b,
+         out);
+   }
+ };
+
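+ // Stub selected when CTYPE_IN cannot be cast to CTYPE_OUT; add_out verifies
+ // canCast() before dispatching, so reaching this indicates a bug.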
+ template <typename CTYPE_IN>
+ struct ReportCanCastBug {
+   static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
+     ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
+   }
+ };
+
+ template <
+     typename CTYPE_A,
+     typename CTYPE_B,
+     typename CTYPE_IN,
+     typename CTYPE_OUT>
+ struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
+     : public ReportCanCastBug<CTYPE_IN> {};
+
+ } // namespace

Tensor& add_out(
-     RuntimeContext& ctx,
+     KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
-   (void)ctx;
+   ET_KERNEL_CHECK(
+       ctx,
+       resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+       InvalidArgument,
+       out);
+
+   ET_KERNEL_CHECK(
+       ctx,
+       executorch::runtime::tensor_is_realhbbf16_type(out),
+       InvalidArgument,
+       out);
+   ET_KERNEL_CHECK(
+       ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
-   ScalarType common_type = promoteTypes(a_type, b_type);
+   ScalarType alpha_type = utils::get_scalar_dtype(alpha);
+   ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
  ScalarType out_type = out.scalar_type();

-   ET_CHECK_MSG(a_type == ScalarType::Float, "Input tensor not a float.\n");
-   ET_CHECK_MSG(b_type == ScalarType::Float, "Input tensor not a float.\n");
-   ET_CHECK_MSG(out_type == ScalarType::Float, "Output tensor not a float.\n");
-
-   ET_CHECK(canCast(common_type, out_type));
-
-   using CTYPE_A = float;
-   using CTYPE_B = float;
-   using CTYPE_IN = float;
-   using CTYPE_OUT = float;
-   CTYPE_IN alpha_val;
-   ET_EXTRACT_SCALAR(alpha, alpha_val);
+   ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
+   ET_KERNEL_CHECK(
+       ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
+
+   float alpha_val;
+   utils::extract_scalar(alpha, &alpha_val);

+   constexpr auto name = "add.out";
+   constexpr int kNnlibMaxDim = 4; /* fallback if broadcast and dim > 4 */
+
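+   // Use the Cadence HiFi NNLIB float kernel only when the output is float,
+   // alpha == 1, neither input is zero-dimensional, and any broadcast involves
+   // at most kNnlibMaxDim dimensions; otherwise fall back to the portable path.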
  int a_dim = a.dim(), b_dim = b.dim(), out_dim = out.dim();
-   int fall_back = 0;
+   bool optimized = 1;
  /* find broadcast*/
-   const int a_is_broadcasted = !out.sizes().equals(a.sizes());
-   const int b_is_broadcasted = !out.sizes().equals(b.sizes());
-   const int broadcast = (a_is_broadcasted || b_is_broadcasted);
+   const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
+   const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
+   const bool broadcast = (a_is_broadcasted || b_is_broadcasted);
  int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
  max_dim = out.dim() > max_dim ? out.dim() : max_dim;

-   if ( (out_type != ScalarType::Float) || (alpha_val != 1.0))
-     fall_back = 1;
+   if ((out_type != ScalarType::Float) || (alpha_val != 1.0))
+     optimized = 0;

-   if ( (a_dim == 0) || (b_dim == 0) )
-     fall_back = 1;
+   if ((a_dim == 0) || (b_dim == 0))
+     optimized = 0;

-   if ((broadcast == 1) && (max_dim > NNLIB_MAX_DIM))
-     fall_back = 1;
+   if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
+     optimized = 0;


-   if (!fall_back)
+   if (optimized)
  {
    const float* const a_data = a.const_data_ptr<float>();
    const float* const b_data = b.const_data_ptr<float>();
    float* const out_data = out.mutable_data_ptr<float>();
    if (broadcast == 1)
    {
-       int out_shape[NNLIB_MAX_DIM];
-       int inp1_shape[NNLIB_MAX_DIM];
-       int inp2_shape[NNLIB_MAX_DIM];
+       int out_shape[kNnlibMaxDim];
+       int inp1_shape[kNnlibMaxDim];
+       int inp2_shape[kNnlibMaxDim];

-       for (int i = 0; i < NNLIB_MAX_DIM; i++)
+       for (int i = 0; i < kNnlibMaxDim; i++)
      {
        out_shape[i] = 1;
        inp1_shape[i] = 1;
        inp2_shape[i] = 1;
      }

-       int off_o = NNLIB_MAX_DIM - out.dim();
-       int off_a = NNLIB_MAX_DIM - a.dim();
-       int off_b = NNLIB_MAX_DIM - b.dim();
+       int off_o = kNnlibMaxDim - out.dim();
+       int off_a = kNnlibMaxDim - a.dim();
+       int off_b = kNnlibMaxDim - b.dim();

      for (int i = 0; i < out.dim(); i++)
        out_shape[i + off_o] = out.size(i);
@@ -97,24 +155,109 @@ Tensor& add_out(
          b_data, inp2_shape);
    }
    else
+     {
      xa_nn_elm_add_f32xf32_f32(out_data, a_data, b_data, out.numel());
-
+     }
+
+     return out;
  }
-   else
-   {
-     apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-         [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
-           CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-           CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-           CTYPE_IN value = a_casted + alpha_val * b_casted;
-
-           return static_cast<CTYPE_OUT>(value);
-         },
-         a,
-         b,
+
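+   // Portable fallback: dispatch over the real/half/bfloat16 input and output
+   // dtypes, promote to a common compute type, and add element-wise via AddInner.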
+   ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
+     ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
+       using CTYPE_IN = typename torch::executor::
+           promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
+       ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+       CTYPE_IN alpha_val;
+       utils::extract_scalar(alpha, &alpha_val);
+
+       ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
+         AddInner<
+             can_cast<CTYPE_IN, CTYPE_OUT>::value,
+             CTYPE_A,
+             CTYPE_B,
+             CTYPE_IN,
+             CTYPE_OUT>::run(a, b, alpha_val, out);
+       });
+     });
+   });
+
+   return out;
+ }
+
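+ // add.Scalar_out: computes out = a + alpha * b, where b and alpha are scalars.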
+ Tensor& add_scalar_out(
+     KernelRuntimeContext& ctx,
+     const Tensor& a,
+     const Scalar& b,
+     const Scalar& alpha,
+     Tensor& out) {
+
+   // Resize for dynamic shape
+   ET_KERNEL_CHECK_MSG(
+       ctx,
+       resize_tensor(out, a.sizes()) == Error::Ok,
+       InvalidArgument,
+       out,
+       "Failed to resize output tensor.");
+
+   ET_KERNEL_CHECK(
+       ctx,
+       executorch::runtime::tensor_is_realhbbf16_type(out),
+       InvalidArgument,
      out);
+   ET_KERNEL_CHECK(
+       ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
+
+   ScalarType a_type = a.scalar_type();
+   ScalarType b_type = utils::get_scalar_dtype(b);
+   ScalarType alpha_type = utils::get_scalar_dtype(alpha);
+   ScalarType common_type =
+       utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
+   ScalarType out_type = out.scalar_type();
+
+   ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
+   ET_KERNEL_CHECK(
+       ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
+
+   /* When Half, first compute the result in float precision
+      and then downcast to half */
+   if (common_type == ScalarType::Half) {
+     common_type = ScalarType::Float;
  }

+   constexpr auto name = "add.Scalar_out";
+
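+   // Dispatch over the tensor dtype and the scalar's C++ type, compute in the
+   // promoted type, and cast each result to the output dtype.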
+   ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
+     ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
+       using CTYPE_IN = typename utils::promote_type_with_scalar_type<
+           CTYPE_A,
+           CTYPE_B,
+           /*half_to_float*/ true>::type;
+       ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+
+       CTYPE_B b_val;
+       utils::extract_scalar(b, &b_val);
+       CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
+
+       CTYPE_IN alpha_val;
+       utils::extract_scalar(alpha, &alpha_val);
+
+       using CTYPE_OUT = typename std::conditional<
+           std::is_same<CTYPE_A, internal::F2>::value,
+           internal::F2,
+           CTYPE_IN>::type;
+
+       apply_unary_map_fn(
+           [b_casted, alpha_val](const CTYPE_A val_a) {
+             CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+             CTYPE_IN value = a_casted + alpha_val * b_casted;
+             return static_cast<CTYPE_OUT>(value);
+           },
+           a.const_data_ptr<CTYPE_A>(),
+           out.mutable_data_ptr<CTYPE_OUT>(),
+           out.numel());
+     });
+   });
+
  return out;
}
