10
10
#include < ATen/cpu/vec/vec.h>
11
11
#include < executorch/kernels/optimized/cpu/binary_ops.h>
12
12
#include < executorch/kernels/portable/cpu/scalar_utils.h>
13
- #include < executorch/kernels/portable/cpu/util/broadcast_util.h>
13
+ #include < executorch/kernels/portable/cpu/util/elementwise_util.h>
14
+ #include < executorch/kernels/portable/cpu/util/kernel_ops_util.h>
14
15
#include < executorch/runtime/kernel/kernel_includes.h>
15
16
#include < executorch/runtime/platform/assert.h>
16
17
@@ -31,6 +32,26 @@ Tensor& opt_add_out(
31
32
ScalarType a_type = a.scalar_type ();
32
33
ScalarType b_type = b.scalar_type ();
33
34
ScalarType out_type = out.scalar_type ();
35
+ ScalarType common_type = promoteTypes (a_type, b_type);
36
+
37
+ ET_KERNEL_CHECK (
38
+ ctx,
39
+ (canCast (common_type, out_type) &&
40
+ check_alpha_type (utils::get_scalar_dtype (alpha), common_type)),
41
+ InvalidArgument,
42
+ out);
43
+
44
+ ET_KERNEL_CHECK (
45
+ ctx, tensors_have_same_dim_order (a, b, out), InvalidArgument, out);
46
+
47
+ ET_KERNEL_CHECK (
48
+ ctx,
49
+ resize_to_broadcast_target_size (a, b, out) == Error::Ok,
50
+ InvalidArgument,
51
+ out);
52
+
53
+ // @lint-ignore CLANGTIDY facebook-hte-CArray
54
+ static constexpr const char op_name[] = " add.out" ;
34
55
35
56
if (b.numel () == 1 ) {
36
57
if (executorch::runtime::isComplexType (a_type) ||
@@ -40,13 +61,8 @@ Tensor& opt_add_out(
40
61
// output tensors have the same dtype. Support mixed dtypes in the future.
41
62
ET_KERNEL_CHECK (
42
63
ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
43
- ET_KERNEL_CHECK (
44
- ctx,
45
- resize_to_broadcast_target_size (a, b, out) == Error::Ok,
46
- InvalidArgument,
47
- out);
48
64
49
- ET_SWITCH_COMPLEXH_TYPES (out_type, ctx, " add.out " , CTYPE, [&]() {
65
+ ET_SWITCH_COMPLEXH_TYPES (out_type, ctx, op_name , CTYPE, [&]() {
50
66
CTYPE alpha_val = utils::scalar_to<CTYPE>(alpha);
51
67
CTYPE b_val = *b.const_data_ptr <CTYPE>();
52
68
@@ -61,14 +77,8 @@ Tensor& opt_add_out(
61
77
} else if (
62
78
a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
63
79
a_type != ScalarType::BFloat16) {
64
- ET_KERNEL_CHECK (
65
- ctx,
66
- resize_to_broadcast_target_size (a, b, out) == Error::Ok,
67
- InvalidArgument,
68
- out);
69
-
70
- ET_SWITCH_REALB_TYPES (a_type, ctx, " add.out" , CTYPE, [&]() {
71
- ET_SWITCH_REALB_TYPES (b_type, ctx, " add.out" , CTYPE_B, [&]() {
80
+ ET_SWITCH_REALB_TYPES (a_type, ctx, op_name, CTYPE, [&]() {
81
+ ET_SWITCH_REALB_TYPES (b_type, ctx, op_name, CTYPE_B, [&]() {
72
82
CTYPE alpha_val;
73
83
ET_KERNEL_CHECK (
74
84
ctx, utils::extract_scalar (alpha, &alpha_val), InvalidArgument, );
@@ -91,7 +101,6 @@ Tensor& opt_add_out(
91
101
return opt_add_out (ctx, b, a, alpha, out);
92
102
}
93
103
94
- static constexpr const char op_name[] = " add.out" ;
95
104
return torch::executor::kernels::impl::opt_add_sub_out_impl<false , op_name>(
96
105
ctx, a, b, alpha, out);
97
106
}
@@ -102,26 +111,29 @@ Tensor& opt_add_scalar_out(
102
111
const Scalar& b,
103
112
const Scalar& alpha,
104
113
Tensor& out) {
105
- (void )ctx;
106
-
107
114
ScalarType a_type = a.scalar_type ();
108
- ScalarType common_type =
109
- utils::promote_type_with_scalar (a_type, b, /* half_to_float*/ false );
115
+ ScalarType common_type = utils::promote_type_with_scalar (a_type, b);
110
116
ScalarType out_type = out.scalar_type ();
111
117
112
- ET_CHECK (common_type == out_type);
118
+ ET_KERNEL_CHECK (
119
+ ctx,
120
+ (common_type == a_type &&
121
+ check_alpha_type (utils::get_scalar_dtype (alpha), common_type)),
122
+ InvalidArgument,
123
+ out);
113
124
114
- if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
115
- common_type = ScalarType::Float;
116
- }
125
+ ET_KERNEL_CHECK (
126
+ ctx, tensors_have_same_dim_order (a, out), InvalidArgument, out);
127
+
128
+ ET_KERNEL_CHECK (
129
+ ctx, resize_tensor (out, a.sizes ()) == Error::Ok, InvalidArgument, out);
117
130
118
- // Resize for dynamic shape
119
- auto error = resize_tensor (out, a.sizes ());
120
- ET_CHECK_MSG (error == Error::Ok, " Failed to resize output tensor." );
131
+ // @lint-ignore CLANGTIDY facebook-hte-CArray
132
+ static constexpr const char op_name[] = " add.Scalar_out" ;
121
133
122
134
if (a_type == common_type && a_type == out_type &&
123
135
a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
124
- ET_SWITCH_REALB_TYPES (a_type, ctx, " add.Scalar_out " , CTYPE, [&]() {
136
+ ET_SWITCH_REALB_TYPES (a_type, ctx, op_name , CTYPE, [&]() {
125
137
CTYPE b_casted = utils::scalar_to<CTYPE>(b);
126
138
CTYPE alpha_val;
127
139
ET_KERNEL_CHECK (
@@ -137,28 +149,28 @@ Tensor& opt_add_scalar_out(
137
149
out.numel ());
138
150
});
139
151
} else {
140
- ET_SWITCH_REALHBBF16_TYPES (a_type, ctx, " add.Scalar_out " , CTYPE_A, [&]() {
141
- ET_SWITCH_REALB_TYPES (
142
- common_type , ctx, " add.Scalar_out " , CTYPE_IN , [&]() {
143
- ET_SWITCH_REALHBBF16_TYPES (
144
- out_type, ctx, " add.Scalar_out " , CTYPE_OUT, [&]() {
145
- CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
146
- CTYPE_IN alpha_val ;
147
- ET_KERNEL_CHECK (
148
- ctx,
149
- utils::extract_scalar (alpha, &alpha_val) ,
150
- InvalidArgument, );
151
-
152
- const size_t n = a. numel ();
153
- const CTYPE_A* a_data = a. const_data_ptr <CTYPE_A>();
154
- CTYPE_OUT* out_data = out. mutable_data_ptr <CTYPE_OUT>();
155
- for ( auto i = 0 ; i < n; ++i) {
156
- out_data[i] = static_cast <CTYPE_OUT>(
157
- static_cast <CTYPE_IN>(a_data[i]) +
158
- alpha_val * b_casted);
159
- }
160
- });
161
- } );
152
+ ScalarType compute_type = utils::internal::get_compute_type (common_type);
153
+
154
+ ET_SWITCH_REALB_TYPES (compute_type , ctx, op_name, CTYPE_COMPUTE , [&]() {
155
+ CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
156
+ CTYPE_COMPUTE val_alpha;
157
+ ET_KERNEL_CHECK (
158
+ ctx, utils::extract_scalar (alpha, &val_alpha), InvalidArgument, ) ;
159
+ auto val_alpha_times_b = val_alpha * val_b;
160
+ utils::apply_unitensor_elementwise_fn<
161
+ CTYPE_COMPUTE ,
162
+ op_name,
163
+ utils::SupportedTensorDtypes::SAME_AS_COMMON>(
164
+ [val_alpha_times_b]( const auto val_a) {
165
+ // Cast here supports vectorization; either it does nothing
166
+ // or it casts from CTYPE_COMPUTE to
167
+ // Vectorized<CTYPE_COMPUTE>.
168
+ return val_a + decltype (val_a)(val_alpha_times_b);
169
+ },
170
+ ctx,
171
+ a,
172
+ utils::SupportedTensorDtypes::REALHBBF16,
173
+ out );
162
174
});
163
175
}
164
176
0 commit comments