1010#include < ATen/cpu/vec/vec.h>
1111#include < executorch/kernels/optimized/cpu/binary_ops.h>
1212#include < executorch/kernels/portable/cpu/scalar_utils.h>
13- #include < executorch/kernels/portable/cpu/util/broadcast_util.h>
13+ #include < executorch/kernels/portable/cpu/util/elementwise_util.h>
14+ #include < executorch/kernels/portable/cpu/util/kernel_ops_util.h>
1415#include < executorch/runtime/kernel/kernel_includes.h>
1516#include < executorch/runtime/platform/assert.h>
1617
@@ -31,6 +32,26 @@ Tensor& opt_add_out(
3132 ScalarType a_type = a.scalar_type ();
3233 ScalarType b_type = b.scalar_type ();
3334 ScalarType out_type = out.scalar_type ();
35+ ScalarType common_type = promoteTypes (a_type, b_type);
36+
37+ ET_KERNEL_CHECK (
38+ ctx,
39+ (canCast (common_type, out_type) &&
40+ check_alpha_type (utils::get_scalar_dtype (alpha), common_type)),
41+ InvalidArgument,
42+ out);
43+
44+ ET_KERNEL_CHECK (
45+ ctx, tensors_have_same_dim_order (a, b, out), InvalidArgument, out);
46+
47+ ET_KERNEL_CHECK (
48+ ctx,
49+ resize_to_broadcast_target_size (a, b, out) == Error::Ok,
50+ InvalidArgument,
51+ out);
52+
53+ // @lint-ignore CLANGTIDY facebook-hte-CArray
54+ static constexpr const char op_name[] = " add.out" ;
3455
3556 if (b.numel () == 1 ) {
3657 if (executorch::runtime::isComplexType (a_type) ||
@@ -40,13 +61,8 @@ Tensor& opt_add_out(
4061 // output tensors have the same dtype. Support mixed dtypes in the future.
4162 ET_KERNEL_CHECK (
4263 ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
43- ET_KERNEL_CHECK (
44- ctx,
45- resize_to_broadcast_target_size (a, b, out) == Error::Ok,
46- InvalidArgument,
47- out);
4864
49- ET_SWITCH_COMPLEXH_TYPES (out_type, ctx, " add.out " , CTYPE, [&]() {
65+ ET_SWITCH_COMPLEXH_TYPES (out_type, ctx, op_name , CTYPE, [&]() {
5066 CTYPE alpha_val = utils::scalar_to<CTYPE>(alpha);
5167 CTYPE b_val = *b.const_data_ptr <CTYPE>();
5268
@@ -61,14 +77,8 @@ Tensor& opt_add_out(
6177 } else if (
6278 a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
6379 a_type != ScalarType::BFloat16) {
64- ET_KERNEL_CHECK (
65- ctx,
66- resize_to_broadcast_target_size (a, b, out) == Error::Ok,
67- InvalidArgument,
68- out);
69-
70- ET_SWITCH_REALB_TYPES (a_type, ctx, " add.out" , CTYPE, [&]() {
71- ET_SWITCH_REALB_TYPES (b_type, ctx, " add.out" , CTYPE_B, [&]() {
80+ ET_SWITCH_REALB_TYPES (a_type, ctx, op_name, CTYPE, [&]() {
81+ ET_SWITCH_REALB_TYPES (b_type, ctx, op_name, CTYPE_B, [&]() {
7282 CTYPE alpha_val;
7383 ET_KERNEL_CHECK (
7484 ctx, utils::extract_scalar (alpha, &alpha_val), InvalidArgument, );
@@ -91,7 +101,6 @@ Tensor& opt_add_out(
91101 return opt_add_out (ctx, b, a, alpha, out);
92102 }
93103
94- static constexpr const char op_name[] = " add.out" ;
95104 return torch::executor::kernels::impl::opt_add_sub_out_impl<false , op_name>(
96105 ctx, a, b, alpha, out);
97106}
@@ -102,26 +111,29 @@ Tensor& opt_add_scalar_out(
102111 const Scalar& b,
103112 const Scalar& alpha,
104113 Tensor& out) {
105- (void )ctx;
106-
107114 ScalarType a_type = a.scalar_type ();
108- ScalarType common_type =
109- utils::promote_type_with_scalar (a_type, b, /* half_to_float*/ false );
115+ ScalarType common_type = utils::promote_type_with_scalar (a_type, b);
110116 ScalarType out_type = out.scalar_type ();
111117
112- ET_CHECK (common_type == out_type);
118+ ET_KERNEL_CHECK ( // out must carry the promoted dtype, as the replaced ET_CHECK required
119+ ctx,
120+ (common_type == out_type &&
121+ check_alpha_type (utils::get_scalar_dtype (alpha), common_type)),
122+ InvalidArgument,
123+ out);
113124
114- if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
115- common_type = ScalarType::Float;
116- }
125+ ET_KERNEL_CHECK (
126+ ctx, tensors_have_same_dim_order (a, out), InvalidArgument, out);
127+
128+ ET_KERNEL_CHECK (
129+ ctx, resize_tensor (out, a.sizes ()) == Error::Ok, InvalidArgument, out);
117130
118- // Resize for dynamic shape
119- auto error = resize_tensor (out, a.sizes ());
120- ET_CHECK_MSG (error == Error::Ok, " Failed to resize output tensor." );
131+ // @lint-ignore CLANGTIDY facebook-hte-CArray
132+ static constexpr const char op_name[] = " add.Scalar_out" ;
121133
122134 if (a_type == common_type && a_type == out_type &&
123135 a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
124- ET_SWITCH_REALB_TYPES (a_type, ctx, " add.Scalar_out " , CTYPE, [&]() {
136+ ET_SWITCH_REALB_TYPES (a_type, ctx, op_name , CTYPE, [&]() {
125137 CTYPE b_casted = utils::scalar_to<CTYPE>(b);
126138 CTYPE alpha_val;
127139 ET_KERNEL_CHECK (
@@ -137,28 +149,28 @@ Tensor& opt_add_scalar_out(
137149 out.numel ());
138150 });
139151 } else {
140- ET_SWITCH_REALHBBF16_TYPES (a_type, ctx, " add.Scalar_out " , CTYPE_A, [&]() {
141- ET_SWITCH_REALB_TYPES (
142- common_type , ctx, " add.Scalar_out " , CTYPE_IN , [&]() {
143- ET_SWITCH_REALHBBF16_TYPES (
144- out_type, ctx, " add.Scalar_out " , CTYPE_OUT, [&]() {
145- CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
146- CTYPE_IN alpha_val ;
147- ET_KERNEL_CHECK (
148- ctx,
149- utils::extract_scalar (alpha, &alpha_val) ,
150- InvalidArgument, );
151-
152- const size_t n = a. numel ();
153- const CTYPE_A* a_data = a. const_data_ptr <CTYPE_A>();
154- CTYPE_OUT* out_data = out. mutable_data_ptr <CTYPE_OUT>();
155- for ( auto i = 0 ; i < n; ++i) {
156- out_data[i] = static_cast <CTYPE_OUT>(
157- static_cast <CTYPE_IN>(a_data[i]) +
158- alpha_val * b_casted);
159- }
160- });
161- } );
152+ ScalarType compute_type = utils::internal::get_compute_type (common_type);
153+
154+ ET_SWITCH_REALB_TYPES (compute_type , ctx, op_name, CTYPE_COMPUTE , [&]() {
155+ CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
156+ CTYPE_COMPUTE val_alpha;
157+ ET_KERNEL_CHECK (
158+ ctx, utils::extract_scalar (alpha, &val_alpha), InvalidArgument, ) ;
159+ auto val_alpha_times_b = val_alpha * val_b;
160+ utils::apply_unitensor_elementwise_fn<
161+ CTYPE_COMPUTE ,
162+ op_name,
163+ utils::SupportedTensorDtypes::SAME_AS_COMMON>(
164+ [val_alpha_times_b]( const auto val_a) {
165+ // Cast here supports vectorization; either it does nothing
166+ // or it casts from CTYPE_COMPUTE to
167+ // Vectorized<CTYPE_COMPUTE>.
168+ return val_a + decltype (val_a)(val_alpha_times_b);
169+ },
170+ ctx,
171+ a,
172+ utils::SupportedTensorDtypes::REALHBBF16,
173+ out );
162174 });
163175 }
164176
0 commit comments