@@ -34,13 +34,9 @@ Tensor& sub_out(
3434 const Tensor& b,
3535 const Scalar& alpha,
3636 Tensor& out) {
37- // Common Dtype
38- ScalarType common_type =
39- executorch::runtime::promoteTypes (a.scalar_type (), b.scalar_type ());
40- #ifdef OP_ARG_CHECK
4137 ScalarType alpha_type =
4238 torch::executor::native::utils::get_scalar_dtype (alpha);
43-
39+ # ifdef OP_ARG_CHECK
4440 // Check alpha type
4541 ET_KERNEL_CHECK (ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
4642
@@ -67,10 +63,6 @@ Tensor& sub_out(
6763 out);
6864#endif
6965
70- // Compute Dtype
71- ScalarType compute_type =
72- torch::executor::native::utils::get_compute_type (common_type);
73-
7466 // @lint-ignore CLANGTIDY facebook-hte-CArray
7567 static constexpr const char op_name[] = " sub.out" ;
7668
@@ -115,11 +107,14 @@ Tensor& sub_out(
115107 }
116108 }
117109
118- if ((broadcast == 1 ) && (max_dim > kTensorDimensionLimit )) {
110+ if (((broadcast == 1 ) && (max_dim > kTensorDimensionLimit )) ||
111+ (!(((a.scalar_type () == ScalarType::Int) || (a.scalar_type () == ScalarType::Float)) &&
112+ (a.scalar_type () == b.scalar_type ()) && (a.scalar_type () == out.scalar_type ())
113+ && (a.scalar_type () == alpha_type)))) {
119114 optimized = 0 ;
120115 }
121116
122- if ((compute_type == ScalarType::Int) && (optimized)) {
117+ if ((a. scalar_type () == ScalarType::Int) && (optimized)) {
123118 const int * const inp1_data = a.const_data_ptr <int >();
124119 const int * const inp2_data = b.const_data_ptr <int >();
125120 int * const out_data = out.mutable_data_ptr <int >();
@@ -161,7 +156,7 @@ Tensor& sub_out(
161156 alpha_val,
162157 out.numel ());
163158 }
164- } else if ((compute_type == ScalarType::Float) && (optimized)) {
159+ } else if ((a. scalar_type () == ScalarType::Float) && (optimized)) {
165160 const float * const inp1_data = a.const_data_ptr <float >();
166161 const float * const inp2_data = b.const_data_ptr <float >();
167162 float * const out_data = out.mutable_data_ptr <float >();
@@ -204,6 +199,13 @@ Tensor& sub_out(
204199 out.numel ());
205200 }
206201 } else {
202+ // Common Dtype
203+ ScalarType common_type =
204+ executorch::runtime::promoteTypes (a.scalar_type (), b.scalar_type ());
205+ // Compute Dtype
206+ ScalarType compute_type =
207+ torch::executor::native::utils::get_compute_type (common_type);
208+
207209 ET_SWITCH_REAL_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
208210 const CTYPE_COMPUTE val_alpha =
209211 torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(alpha);
@@ -232,14 +234,9 @@ Tensor& sub_scalar_out(
232234 const Scalar& b,
233235 const Scalar& alpha,
234236 Tensor& out) {
235- // Common Dtype
236- ScalarType common_type =
237- torch::executor::native::utils::promote_type_with_scalar (
238- a.scalar_type (), b);
239- #ifdef OP_ARG_CHECK
240237 ScalarType alpha_type =
241238 torch::executor::native::utils::get_scalar_dtype (alpha);
242-
239+ # ifdef OP_ARG_CHECK
243240 // Check alpha type
244241 ET_KERNEL_CHECK (ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
245242
@@ -265,14 +262,19 @@ Tensor& sub_scalar_out(
265262 out);
266263#endif
267264
268- // Compute Dtype
269- ScalarType compute_type =
270- torch::executor::native::utils::get_compute_type (common_type);
271-
272265 // @lint-ignore CLANGTIDY facebook-hte-CArray
273266 static constexpr const char op_name[] = " sub.Scalar_out" ;
274267
275- if (compute_type == ScalarType::Int) {
268+ bool optimized = 1 ;
269+ ScalarType b_type = torch::executor::native::utils::get_scalar_dtype (b);
270+
271+ if (!(((a.scalar_type () == ScalarType::Int) || (a.scalar_type () == ScalarType::Float)) &&
272+ (a.scalar_type () == b_type) && (a.scalar_type () == out.scalar_type ())
273+ && (a.scalar_type () == alpha_type))) {
274+ optimized = 0 ;
275+ }
276+
277+ if ((a.scalar_type () == ScalarType::Int) && (optimized)) {
276278 const int * const inp1_data = a.const_data_ptr <int >();
277279 int inp2_val;
278280 torch::executor::native::utils::extract_scalar (b, &inp2_val);
@@ -291,7 +293,7 @@ Tensor& sub_scalar_out(
291293 inp2_val,
292294 alpha_val,
293295 out.numel ());
294- } else if (compute_type == ScalarType::Float) {
296+ } else if ((a. scalar_type () == ScalarType::Float) && (optimized) ) {
295297 const float * const inp1_data = a.const_data_ptr <float >();
296298 float inp2_val;
297299 torch::executor::native::utils::extract_scalar (b, &inp2_val);
@@ -311,6 +313,13 @@ Tensor& sub_scalar_out(
311313 alpha_val,
312314 out.numel ());
313315 } else {
316+ // Common Dtype
317+ ScalarType common_type =
318+ torch::executor::native::utils::promote_type_with_scalar (
319+ a.scalar_type (), b);
320+ // Compute Dtype
321+ ScalarType compute_type =
322+ torch::executor::native::utils::get_compute_type (common_type);
314323 ET_SWITCH_REAL_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
315324 const CTYPE_COMPUTE val_b =
316325 torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(b);
0 commit comments