@@ -35,21 +35,7 @@ Tensor& add_out(
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  // Common Dtype
-  ScalarType common_type =
-      executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
-
 #ifdef OP_ARG_CHECK
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (canCast(common_type, out.scalar_type()) &&
-       torch::executor::check_alpha_type(
-           torch::executor::native::utils::get_scalar_dtype(alpha),
-           common_type)),
-      InvalidArgument,
-      out);
-
   // Check Dim Order
   ET_KERNEL_CHECK(
       ctx,
@@ -65,10 +51,6 @@ Tensor& add_out(
       out);
 #endif
 
-  // Compute Dtype
-  ScalarType compute_type =
-      torch::executor::native::utils::get_compute_type(common_type);
-
   static constexpr const char op_name[] = "add.out";
 
   int kTensorDimensionLimit = 5;
@@ -77,12 +59,12 @@ Tensor& add_out(
   int inp2_shape[kTensorDimensionLimit];
   int out_shape[kTensorDimensionLimit];
 
-  bool broadcast = 0;
+  bool broadcast = false;
 
   int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
   max_dim = out.dim() > max_dim ? out.dim() : max_dim;
 
-  bool optimized = 1;
+  bool optimized = true;
 
   /* Added change to work with input dimensions more than 5 */
   for (int i = 0; i < max_dim; i++) {
@@ -109,15 +91,19 @@ Tensor& add_out(
   for (int i = 0; i < out.dim(); i++) {
     if (((inp1_shape[i]) != (out_shape[i])) ||
         ((inp2_shape[i]) != (out_shape[i]))) {
-      broadcast = 1;
+      broadcast = true;
     }
   }
 
-  if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
-    optimized = 0;
+  if (((broadcast) && (max_dim > kTensorDimensionLimit)) ||
+      (!(((a.scalar_type() == ScalarType::Int) ||
+          (a.scalar_type() == ScalarType::Float)) &&
+         (a.scalar_type() == b.scalar_type()) &&
+         (a.scalar_type() == out.scalar_type())))) {
+    optimized = false;
   }
 
-  if ((compute_type == ScalarType::Int) && (optimized)) {
+  if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
     const int* const inp1_data = a.const_data_ptr<int>();
     const int* const inp2_data = b.const_data_ptr<int>();
     int* const out_data = out.mutable_data_ptr<int>();
@@ -169,7 +155,7 @@ Tensor& add_out(
           alpha_val,
           out.numel());
     }
-  } else if ((compute_type == ScalarType::Float) && (optimized)) {
+  } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
     const float* const inp1_data = a.const_data_ptr<float>();
     const float* const inp2_data = b.const_data_ptr<float>();
     float* const out_data = out.mutable_data_ptr<float>();
@@ -222,6 +208,23 @@ Tensor& add_out(
           out.numel());
     }
   } else {
+    // Common Dtype
+    ScalarType common_type =
+        executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
+    // Compute Dtype
+    ScalarType compute_type =
+        torch::executor::native::utils::get_compute_type(common_type);
+
+    // Check Common Dtype
+    ET_KERNEL_CHECK(
+        ctx,
+        (canCast(common_type, out.scalar_type()) &&
+         torch::executor::check_alpha_type(
+             torch::executor::native::utils::get_scalar_dtype(alpha),
+             common_type)),
+        InvalidArgument,
+        out);
+
     ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
       const CTYPE_COMPUTE val_alpha =
           torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(alpha);
@@ -249,22 +252,7 @@ Tensor& add_scalar_out(
     const Scalar& b,
     const Scalar& alpha,
     Tensor& out) {
-  // Common Dtype
-  ScalarType common_type =
-      torch::executor::native::utils::promote_type_with_scalar(
-          a.scalar_type(), b);
-
 #ifdef OP_ARG_CHECK
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (common_type == out.scalar_type() &&
-       torch::executor::check_alpha_type(
-           torch::executor::native::utils::get_scalar_dtype(alpha),
-           common_type)),
-      InvalidArgument,
-      out);
-
   // Check Dim Order
   ET_KERNEL_CHECK(
       ctx,
@@ -279,14 +267,23 @@ Tensor& add_scalar_out(
       InvalidArgument,
       out);
 #endif
-  // Compute Dtype
-  ScalarType compute_type =
-      torch::executor::native::utils::get_compute_type(common_type);
 
   // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "add.Scalar_out";
 
-  if (compute_type == ScalarType::Int) {
+  bool optimized = true;
+
+  if (!(((a.scalar_type() == ScalarType::Int) ||
+         (a.scalar_type() == ScalarType::Float)) &&
+        (a.scalar_type() == out.scalar_type()))) {
+    optimized = false;
+  }
+
+  if ((b.isFloatingPoint()) && (a.scalar_type() == ScalarType::Int)) {
+    optimized = false;
+  }
+
+  if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
     const int* const inp1_data = a.const_data_ptr<int>();
     int inp2_val;
     torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -306,7 +303,7 @@ Tensor& add_scalar_out(
         alpha_val,
         out.numel());
 
-  } else if (compute_type == ScalarType::Float) {
+  } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
     const float* const inp1_data = a.const_data_ptr<float>();
     float inp2_val;
     torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -327,6 +324,24 @@ Tensor& add_scalar_out(
         out.numel());
 
   } else {
+    // Common Dtype
+    ScalarType common_type =
+        torch::executor::native::utils::promote_type_with_scalar(
+            a.scalar_type(), b);
+    // Compute Dtype
+    ScalarType compute_type =
+        torch::executor::native::utils::get_compute_type(common_type);
+
+    // Check Common Dtype
+    ET_KERNEL_CHECK(
+        ctx,
+        (common_type == out.scalar_type() &&
+         torch::executor::check_alpha_type(
+             torch::executor::native::utils::get_scalar_dtype(alpha),
+             common_type)),
+        InvalidArgument,
+        out);
+
     ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
       torch::executor::native::utils::
           apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
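
Net effect of the patch: dtype promotion and the Common-Dtype check now run only in the generic fallback branch, while the optimized paths are gated on all tensors already sharing an Int or Float dtype. Below is a minimal standalone C++ sketch of that dispatch pattern; `fast_add` and `generic_add` are hypothetical stand-ins for the vendor kernel and the portable fallback, not functions from this file.

```cpp
#include <cstdio>

enum class ScalarType { Int, Float, Double };

// Hypothetical stand-ins: a vendor-optimized kernel and the portable
// fallback that performs dtype promotion and argument checks.
void fast_add() { std::puts("vendor fast path"); }
void generic_add() { std::puts("portable fallback with dtype promotion"); }

void dispatch(ScalarType a, ScalarType b, ScalarType out) {
  // Fast path only when every tensor already shares the same Int/Float
  // dtype; otherwise defer promotion and checks to the fallback.
  bool optimized = (a == ScalarType::Int || a == ScalarType::Float) &&
      a == b && a == out;
  if (optimized) {
    fast_add();
  } else {
    generic_add();
  }
}

int main() {
  dispatch(ScalarType::Float, ScalarType::Float, ScalarType::Float); // fast
  dispatch(ScalarType::Int, ScalarType::Float, ScalarType::Float);   // fallback
  return 0;
}
```

Keeping promotion off the hot path avoids per-call dtype bookkeeping in the common homogeneous Int/Float case, at the cost of duplicating the argument checks in the fallback branch.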