Skip to content

Commit 9edc3ed

Browse files
Added conditions to call the optimized kernels only when the input, output, and alpha datatypes are all the same type; otherwise, the generic implementation is called.
1 parent 91f09aa commit 9edc3ed

File tree

1 file changed

+33
-24
lines changed

1 file changed

+33
-24
lines changed

backends/cadence/fusion_g3/operators/op_sub.cpp

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,9 @@ Tensor& sub_out(
3434
const Tensor& b,
3535
const Scalar& alpha,
3636
Tensor& out) {
37-
// Common Dtype
38-
ScalarType common_type =
39-
executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
40-
#ifdef OP_ARG_CHECK
4137
ScalarType alpha_type =
4238
torch::executor::native::utils::get_scalar_dtype(alpha);
43-
39+
#ifdef OP_ARG_CHECK
4440
// Check alpha type
4541
ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
4642

@@ -67,10 +63,6 @@ Tensor& sub_out(
6763
out);
6864
#endif
6965

70-
// Compute Dtype
71-
ScalarType compute_type =
72-
torch::executor::native::utils::get_compute_type(common_type);
73-
7466
// @lint-ignore CLANGTIDY facebook-hte-CArray
7567
static constexpr const char op_name[] = "sub.out";
7668

@@ -115,11 +107,14 @@ Tensor& sub_out(
115107
}
116108
}
117109

118-
if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
110+
if (((broadcast == 1) && (max_dim > kTensorDimensionLimit)) ||
111+
(!(((a.scalar_type() == ScalarType::Int) || (a.scalar_type() == ScalarType::Float)) &&
112+
(a.scalar_type() == b.scalar_type()) && (a.scalar_type() == out.scalar_type())
113+
&& (a.scalar_type() == alpha_type)))) {
119114
optimized = 0;
120115
}
121116

122-
if ((compute_type == ScalarType::Int) && (optimized)) {
117+
if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
123118
const int* const inp1_data = a.const_data_ptr<int>();
124119
const int* const inp2_data = b.const_data_ptr<int>();
125120
int* const out_data = out.mutable_data_ptr<int>();
@@ -161,7 +156,7 @@ Tensor& sub_out(
161156
alpha_val,
162157
out.numel());
163158
}
164-
} else if ((compute_type == ScalarType::Float) && (optimized)) {
159+
} else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
165160
const float* const inp1_data = a.const_data_ptr<float>();
166161
const float* const inp2_data = b.const_data_ptr<float>();
167162
float* const out_data = out.mutable_data_ptr<float>();
@@ -204,6 +199,13 @@ Tensor& sub_out(
204199
out.numel());
205200
}
206201
} else {
202+
// Common Dtype
203+
ScalarType common_type =
204+
executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
205+
// Compute Dtype
206+
ScalarType compute_type =
207+
torch::executor::native::utils::get_compute_type(common_type);
208+
207209
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
208210
const CTYPE_COMPUTE val_alpha =
209211
torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(alpha);
@@ -232,14 +234,9 @@ Tensor& sub_scalar_out(
232234
const Scalar& b,
233235
const Scalar& alpha,
234236
Tensor& out) {
235-
// Common Dtype
236-
ScalarType common_type =
237-
torch::executor::native::utils::promote_type_with_scalar(
238-
a.scalar_type(), b);
239-
#ifdef OP_ARG_CHECK
240237
ScalarType alpha_type =
241238
torch::executor::native::utils::get_scalar_dtype(alpha);
242-
239+
#ifdef OP_ARG_CHECK
243240
// Check alpha type
244241
ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
245242

@@ -265,14 +262,19 @@ Tensor& sub_scalar_out(
265262
out);
266263
#endif
267264

268-
// Compute Dtype
269-
ScalarType compute_type =
270-
torch::executor::native::utils::get_compute_type(common_type);
271-
272265
// @lint-ignore CLANGTIDY facebook-hte-CArray
273266
static constexpr const char op_name[] = "sub.Scalar_out";
274267

275-
if (compute_type == ScalarType::Int) {
268+
bool optimized = 1;
269+
ScalarType b_type = torch::executor::native::utils::get_scalar_dtype(b);
270+
271+
if (!(((a.scalar_type() == ScalarType::Int) || (a.scalar_type() == ScalarType::Float)) &&
272+
(a.scalar_type() == b_type) && (a.scalar_type() == out.scalar_type())
273+
&& (a.scalar_type() == alpha_type))) {
274+
optimized = 0;
275+
}
276+
277+
if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
276278
const int* const inp1_data = a.const_data_ptr<int>();
277279
int inp2_val;
278280
torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -291,7 +293,7 @@ Tensor& sub_scalar_out(
291293
inp2_val,
292294
alpha_val,
293295
out.numel());
294-
} else if (compute_type == ScalarType::Float) {
296+
} else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
295297
const float* const inp1_data = a.const_data_ptr<float>();
296298
float inp2_val;
297299
torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -311,6 +313,13 @@ Tensor& sub_scalar_out(
311313
alpha_val,
312314
out.numel());
313315
} else {
316+
// Common Dtype
317+
ScalarType common_type =
318+
torch::executor::native::utils::promote_type_with_scalar(
319+
a.scalar_type(), b);
320+
// Compute Dtype
321+
ScalarType compute_type =
322+
torch::executor::native::utils::get_compute_type(common_type);
314323
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
315324
const CTYPE_COMPUTE val_b =
316325
torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(b);

0 commit comments

Comments
 (0)