@@ -34,9 +34,9 @@ Tensor& sub_out(
3434 const Tensor& b,
3535 const Scalar& alpha,
3636 Tensor& out) {
37+ #ifdef OP_ARG_CHECK
3738 ScalarType alpha_type =
3839 torch::executor::native::utils::get_scalar_dtype (alpha);
39- #ifdef OP_ARG_CHECK
4040 // Check alpha type
4141 ET_KERNEL_CHECK (ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
4242
@@ -108,9 +108,10 @@ Tensor& sub_out(
108108 }
109109
110110 if (((broadcast == 1 ) && (max_dim > kTensorDimensionLimit )) ||
111- (!(((a.scalar_type () == ScalarType::Int) || (a.scalar_type () == ScalarType::Float)) &&
112- (a.scalar_type () == b.scalar_type ()) && (a.scalar_type () == out.scalar_type ())
113- && (a.scalar_type () == alpha_type)))) {
111+ (!(((a.scalar_type () == ScalarType::Int) ||
112+ (a.scalar_type () == ScalarType::Float)) &&
113+ (a.scalar_type () == b.scalar_type ()) &&
114+ (a.scalar_type () == out.scalar_type ())))) {
114115 optimized = 0 ;
115116 }
116117
@@ -201,10 +202,10 @@ Tensor& sub_out(
201202 } else {
202203 // Common Dtype
203204 ScalarType common_type =
204- executorch::runtime::promoteTypes (a.scalar_type (), b.scalar_type ());
205+ executorch::runtime::promoteTypes (a.scalar_type (), b.scalar_type ());
205206 // Compute Dtype
206207 ScalarType compute_type =
207- torch::executor::native::utils::get_compute_type (common_type);
208+ torch::executor::native::utils::get_compute_type (common_type);
208209
209210 ET_SWITCH_REAL_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
210211 const CTYPE_COMPUTE val_alpha =
@@ -234,9 +235,9 @@ Tensor& sub_scalar_out(
234235 const Scalar& b,
235236 const Scalar& alpha,
236237 Tensor& out) {
238+ #ifdef OP_ARG_CHECK
237239 ScalarType alpha_type =
238240 torch::executor::native::utils::get_scalar_dtype (alpha);
239- #ifdef OP_ARG_CHECK
240241 // Check alpha type
241242 ET_KERNEL_CHECK (ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
242243
@@ -268,9 +269,10 @@ Tensor& sub_scalar_out(
268269 bool optimized = 1 ;
269270 ScalarType b_type = torch::executor::native::utils::get_scalar_dtype (b);
270271
271- if (!(((a.scalar_type () == ScalarType::Int) || (a.scalar_type () == ScalarType::Float)) &&
272- (a.scalar_type () == b_type) && (a.scalar_type () == out.scalar_type ())
273- && (a.scalar_type () == alpha_type))) {
272+ if (!(((a.scalar_type () == ScalarType::Int) ||
273+ (a.scalar_type () == ScalarType::Float)) &&
274+ (a.scalar_type () == b_type) &&
275+ (a.scalar_type () == out.scalar_type ()))) {
274276 optimized = 0 ;
275277 }
276278
@@ -315,11 +317,11 @@ Tensor& sub_scalar_out(
315317 } else {
316318 // Common Dtype
317319 ScalarType common_type =
318- torch::executor::native::utils::promote_type_with_scalar (
319- a.scalar_type (), b);
320+ torch::executor::native::utils::promote_type_with_scalar (
321+ a.scalar_type (), b);
320322 // Compute Dtype
321323 ScalarType compute_type =
322- torch::executor::native::utils::get_compute_type (common_type);
324+ torch::executor::native::utils::get_compute_type (common_type);
323325 ET_SWITCH_REAL_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
324326 const CTYPE_COMPUTE val_b =
325327 torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(b);
0 commit comments