@@ -34,13 +34,9 @@ Tensor& sub_out(
3434 const Tensor& b,
3535 const Scalar& alpha,
3636 Tensor& out) {
37- // Common Dtype
38- ScalarType common_type =
39- executorch::runtime::promoteTypes (a.scalar_type (), b.scalar_type ());
4037#ifdef OP_ARG_CHECK
4138 ScalarType alpha_type =
4239 torch::executor::native::utils::get_scalar_dtype (alpha);
43-
4440 // Check alpha type
4541 ET_KERNEL_CHECK (ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
4642
@@ -67,10 +63,6 @@ Tensor& sub_out(
6763 out);
6864#endif
6965
70- // Compute Dtype
71- ScalarType compute_type =
72- torch::executor::native::utils::get_compute_type (common_type);
73-
7466 // @lint-ignore CLANGTIDY facebook-hte-CArray
7567 static constexpr const char op_name[] = " sub.out" ;
7668
@@ -115,11 +107,15 @@ Tensor& sub_out(
115107 }
116108 }
117109
118- if ((broadcast == 1 ) && (max_dim > kTensorDimensionLimit )) {
110+ if (((broadcast == 1 ) && (max_dim > kTensorDimensionLimit )) ||
111+ (!(((a.scalar_type () == ScalarType::Int) ||
112+ (a.scalar_type () == ScalarType::Float)) &&
113+ (a.scalar_type () == b.scalar_type ()) &&
114+ (a.scalar_type () == out.scalar_type ())))) {
119115 optimized = 0 ;
120116 }
121117
122- if ((compute_type == ScalarType::Int) && (optimized)) {
118+ if ((a. scalar_type () == ScalarType::Int) && (optimized)) {
123119 const int * const inp1_data = a.const_data_ptr <int >();
124120 const int * const inp2_data = b.const_data_ptr <int >();
125121 int * const out_data = out.mutable_data_ptr <int >();
@@ -161,7 +157,7 @@ Tensor& sub_out(
161157 alpha_val,
162158 out.numel ());
163159 }
164- } else if ((compute_type == ScalarType::Float) && (optimized)) {
160+ } else if ((a. scalar_type () == ScalarType::Float) && (optimized)) {
165161 const float * const inp1_data = a.const_data_ptr <float >();
166162 const float * const inp2_data = b.const_data_ptr <float >();
167163 float * const out_data = out.mutable_data_ptr <float >();
@@ -204,6 +200,13 @@ Tensor& sub_out(
204200 out.numel ());
205201 }
206202 } else {
203+ // Common Dtype
204+ ScalarType common_type =
205+ executorch::runtime::promoteTypes (a.scalar_type (), b.scalar_type ());
206+ // Compute Dtype
207+ ScalarType compute_type =
208+ torch::executor::native::utils::get_compute_type (common_type);
209+
207210 ET_SWITCH_REAL_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
208211 const CTYPE_COMPUTE val_alpha =
209212 torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(alpha);
@@ -232,14 +235,9 @@ Tensor& sub_scalar_out(
232235 const Scalar& b,
233236 const Scalar& alpha,
234237 Tensor& out) {
235- // Common Dtype
236- ScalarType common_type =
237- torch::executor::native::utils::promote_type_with_scalar (
238- a.scalar_type (), b);
239238#ifdef OP_ARG_CHECK
240239 ScalarType alpha_type =
241240 torch::executor::native::utils::get_scalar_dtype (alpha);
242-
243241 // Check alpha type
244242 ET_KERNEL_CHECK (ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
245243
@@ -265,14 +263,20 @@ Tensor& sub_scalar_out(
265263 out);
266264#endif
267265
268- // Compute Dtype
269- ScalarType compute_type =
270- torch::executor::native::utils::get_compute_type (common_type);
271-
272266 // @lint-ignore CLANGTIDY facebook-hte-CArray
273267 static constexpr const char op_name[] = " sub.Scalar_out" ;
274268
275- if (compute_type == ScalarType::Int) {
269+ bool optimized = 1 ;
270+ ScalarType b_type = torch::executor::native::utils::get_scalar_dtype (b);
271+
272+ if (!(((a.scalar_type () == ScalarType::Int) ||
273+ (a.scalar_type () == ScalarType::Float)) &&
274+ (a.scalar_type () == b_type) &&
275+ (a.scalar_type () == out.scalar_type ()))) {
276+ optimized = 0 ;
277+ }
278+
279+ if ((a.scalar_type () == ScalarType::Int) && (optimized)) {
276280 const int * const inp1_data = a.const_data_ptr <int >();
277281 int inp2_val;
278282 torch::executor::native::utils::extract_scalar (b, &inp2_val);
@@ -291,7 +295,7 @@ Tensor& sub_scalar_out(
291295 inp2_val,
292296 alpha_val,
293297 out.numel ());
294- } else if (compute_type == ScalarType::Float) {
298+ } else if ((a. scalar_type () == ScalarType::Float) && (optimized) ) {
295299 const float * const inp1_data = a.const_data_ptr <float >();
296300 float inp2_val;
297301 torch::executor::native::utils::extract_scalar (b, &inp2_val);
@@ -311,6 +315,13 @@ Tensor& sub_scalar_out(
311315 alpha_val,
312316 out.numel ());
313317 } else {
318+ // Common Dtype
319+ ScalarType common_type =
320+ torch::executor::native::utils::promote_type_with_scalar (
321+ a.scalar_type (), b);
322+ // Compute Dtype
323+ ScalarType compute_type =
324+ torch::executor::native::utils::get_compute_type (common_type);
314325 ET_SWITCH_REAL_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
315326 const CTYPE_COMPUTE val_b =
316327 torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(b);
0 commit comments