Skip to content

Commit 13a1a30

Browse files
Removed unnecessary conditions in op_add and op_mul (#7135)
* Removed unnecessary conditions in op_add and op_mul
* Added scalar function call in op_add and op_mul
* Updated undefining of macros in op_quantize
* Updated elseif() instead of if() in op_cat
* Removed linter errors

Signed-off-by: [email protected] <[email protected]>
Co-authored-by: [email protected] <[email protected]>
1 parent 7eab0cc commit 13a1a30

File tree

4 files changed

+68
-52
lines changed

4 files changed

+68
-52
lines changed

backends/cadence/fusion_g3/operators/op_add.cpp

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -66,34 +66,20 @@ Tensor& add_out(
6666
// @lint-ignore CLANGTIDY facebook-hte-CArray
6767
static constexpr const char op_name[] = "add.out";
6868

69-
const exec_aten::ArrayRef<Tensor::SizesType> a_size = a.sizes();
70-
const exec_aten::ArrayRef<Tensor::SizesType> b_size = b.sizes();
71-
const exec_aten::ArrayRef<Tensor::SizesType> out_size = out.sizes();
72-
7369
int kTensorDimensionLimit = 5;
7470

7571
int inp1_shape[kTensorDimensionLimit];
7672
int inp2_shape[kTensorDimensionLimit];
7773
int out_shape[kTensorDimensionLimit];
7874

79-
/*find broadcast*/
80-
const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
81-
const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
82-
const bool broadcast = (a_is_broadcasted || b_is_broadcasted);
75+
bool broadcast = 0;
8376

8477
int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
8578
max_dim = out.dim() > max_dim ? out.dim() : max_dim;
8679

8780
bool optimized = 1;
8881

89-
if ((a.dim() == 0) || (b.dim() == 0)) {
90-
optimized = 0;
91-
}
92-
93-
if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
94-
optimized = 0;
95-
}
96-
82+
/* Added change to work with input dimensions more than 5 */
9783
for (int i = 0; i < max_dim; i++) {
9884
out_shape[i] = 1;
9985
inp1_shape[i] = 1;
@@ -114,14 +100,33 @@ Tensor& add_out(
114100
inp2_shape[i + offset_inp2] = b.size(i);
115101
}
116102

103+
/*find broadcast*/
104+
for (int i = 0; i < out.dim(); i++) {
105+
if (((inp1_shape[i]) != (out_shape[i])) ||
106+
((inp2_shape[i]) != (out_shape[i]))) {
107+
broadcast = 1;
108+
}
109+
}
110+
111+
if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
112+
optimized = 0;
113+
}
114+
117115
if ((compute_type == ScalarType::Int) && (optimized)) {
118116
const int* const inp1_data = a.const_data_ptr<int>();
119117
const int* const inp2_data = b.const_data_ptr<int>();
120118
int* const out_data = out.mutable_data_ptr<int>();
121119

122120
int alpha_val;
123121
torch::executor::native::utils::extract_scalar(alpha, &alpha_val);
124-
if (broadcast) {
122+
123+
if ((a.numel() == 1) && (alpha_val == 1)) {
124+
xa_nn_elm_add_scalar_32x32_32(
125+
out_data, inp2_data, inp1_data[0], alpha_val, out.numel());
126+
} else if (b.numel() == 1) {
127+
xa_nn_elm_add_scalar_32x32_32(
128+
out_data, inp1_data, inp2_data[0], alpha_val, out.numel());
129+
} else if (broadcast) {
125130
xa_nn_elm_add_broadcast_5D_32x32_32(
126131
out_data,
127132
out_shape,
@@ -143,7 +148,13 @@ Tensor& add_out(
143148
float alpha_val;
144149
torch::executor::native::utils::extract_scalar(alpha, &alpha_val);
145150

146-
if (broadcast) {
151+
if ((a.numel() == 1) && (alpha_val == 1.0)) {
152+
xa_nn_elm_add_scalar_f32xf32_f32(
153+
out_data, inp2_data, inp1_data[0], alpha_val, out.numel());
154+
} else if (b.numel() == 1) {
155+
xa_nn_elm_add_scalar_f32xf32_f32(
156+
out_data, inp1_data, inp2_data[0], alpha_val, out.numel());
157+
} else if (broadcast) {
147158
xa_nn_elm_add_broadcast_5D_f32xf32_f32(
148159
out_data,
149160
out_shape,
@@ -176,7 +187,6 @@ Tensor& add_out(
176187
torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16);
177188
});
178189
}
179-
180190
return out;
181191
}
182192

@@ -234,6 +244,7 @@ Tensor& add_scalar_out(
234244

235245
xa_nn_elm_add_scalar_32x32_32(
236246
out_data, inp1_data, inp2_val, alpha_val, out.numel());
247+
237248
} else if (compute_type == ScalarType::Float) {
238249
const float* const inp1_data = a.const_data_ptr<float>();
239250
float inp2_val;
@@ -246,6 +257,7 @@ Tensor& add_scalar_out(
246257

247258
xa_nn_elm_add_scalar_f32xf32_f32(
248259
out_data, inp1_data, inp2_val, alpha_val, out.numel());
260+
249261
} else {
250262
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
251263
torch::executor::native::utils::
@@ -266,6 +278,7 @@ Tensor& add_scalar_out(
266278
SAME_AS_COMMON);
267279
});
268280
}
281+
269282
return out;
270283
}
271284

backends/cadence/fusion_g3/operators/op_cat.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@ using torch::executor::KernelRuntimeContext;
2222
* updated to have support for below data types, these can be removed and
2323
* operator need to be updated accordingly
2424
*/
25-
enum datatype {
26-
Ushort = 20,
27-
Uint = 23,
28-
};
25+
enum datatype { Ushort = 20, Uint = 23 };
2926

3027
namespace cadence {
3128
namespace impl {
@@ -118,8 +115,7 @@ Tensor& cat_out(
118115
tensors.size(),
119116
(int)dim,
120117
sizeof(char));
121-
}
122-
if (out.scalar_type() == (ScalarType)Uint) {
118+
} else if (out.scalar_type() == (ScalarType)Uint) {
123119
xa_nn_cat(
124120
out_data,
125121
out_shapes,
@@ -164,7 +160,6 @@ Tensor& cat_out(
164160
if (all_1d_empty) {
165161
return out;
166162
}
167-
168163
const size_t outer = executorch::runtime::getLeadingDims(out, dim);
169164
const size_t dim_stride = executorch::runtime::getTrailingDims(out, dim);
170165
const size_t ninputs = tensors.size();

backends/cadence/fusion_g3/operators/op_mul.cpp

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -58,34 +58,20 @@ Tensor& mul_out(
5858
// @lint-ignore CLANGTIDY facebook-hte-CArray
5959
static constexpr const char op_name[] = "mul.out";
6060

61-
const exec_aten::ArrayRef<Tensor::SizesType> a_size = a.sizes();
62-
const exec_aten::ArrayRef<Tensor::SizesType> b_size = b.sizes();
63-
const exec_aten::ArrayRef<Tensor::SizesType> out_size = out.sizes();
64-
6561
int kTensorDimensionLimit = 5;
6662

6763
int inp1_shape[kTensorDimensionLimit];
6864
int inp2_shape[kTensorDimensionLimit];
6965
int out_shape[kTensorDimensionLimit];
7066

71-
/*find broadcast*/
72-
const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
73-
const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
74-
const bool broadcast = (a_is_broadcasted || b_is_broadcasted);
67+
bool broadcast = 0;
7568

7669
int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
7770
max_dim = out.dim() > max_dim ? out.dim() : max_dim;
7871

7972
bool optimized = 1;
8073

81-
if ((a.dim() == 0) || (b.dim() == 0)) {
82-
optimized = 0;
83-
}
84-
85-
if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
86-
optimized = 0;
87-
}
88-
74+
/* Added change to work with input dimensions more than 5 */
8975
for (int i = 0; i < max_dim; i++) {
9076
out_shape[i] = 1;
9177
inp1_shape[i] = 1;
@@ -106,12 +92,30 @@ Tensor& mul_out(
10692
inp2_shape[i + offset_inp2] = b.size(i);
10793
}
10894

95+
/*find broadcast*/
96+
for (int i = 0; i < out.dim(); i++) {
97+
if (((inp1_shape[i]) != (out_shape[i])) ||
98+
((inp2_shape[i]) != (out_shape[i]))) {
99+
broadcast = 1;
100+
}
101+
}
102+
103+
if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
104+
optimized = 0;
105+
}
106+
109107
if ((compute_type == ScalarType::Int) && (optimized)) {
110108
const int* const inp1_data = a.const_data_ptr<int>();
111109
const int* const inp2_data = b.const_data_ptr<int>();
112110
int* const out_data = out.mutable_data_ptr<int>();
113111

114-
if (broadcast) {
112+
if (a.numel() == 1) {
113+
xa_nn_elm_mul_scalar_32x32_32(
114+
out_data, inp2_data, inp1_data[0], out.numel());
115+
} else if (b.numel() == 1) {
116+
xa_nn_elm_mul_scalar_32x32_32(
117+
out_data, inp1_data, inp2_data[0], out.numel());
118+
} else if (broadcast) {
115119
xa_nn_elm_mul_broadcast_5D_32x32_32(
116120
out_data,
117121
out_shape,
@@ -128,7 +132,13 @@ Tensor& mul_out(
128132
const float* const inp2_data = b.const_data_ptr<float>();
129133
float* const out_data = out.mutable_data_ptr<float>();
130134

131-
if (broadcast) {
135+
if (a.numel() == 1) {
136+
xa_nn_elm_mul_scalar_f32xf32_f32(
137+
out_data, inp2_data, inp1_data[0], out.numel());
138+
} else if (b.numel() == 1) {
139+
xa_nn_elm_mul_scalar_f32xf32_f32(
140+
out_data, inp1_data, inp2_data[0], out.numel());
141+
} else if (broadcast) {
132142
xa_nn_elm_mul_broadcast_5D_f32xf32_f32(
133143
out_data,
134144
out_shape,
@@ -157,7 +167,6 @@ Tensor& mul_out(
157167
torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16);
158168
});
159169
}
160-
161170
return out;
162171
}
163172

backends/cadence/fusion_g3/operators/op_quantize.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ enum datatype { Ushort = 20, Bits4u = 21, Bits4 = 22 };
3131
*/
3232
namespace cadence {
3333
namespace impl {
34-
namespace FusionG3 {
34+
namespace G3 {
3535
namespace native {
3636

3737
namespace {
@@ -364,8 +364,8 @@ void quantize_impl(
364364

365365
#undef ASYM_CALCULATE_FLOAT_TYPE_TENSOR
366366
#undef ASYM_CALCULATE_FLOAT_TYPE_CHANNEL
367-
#undef ASYM_ASYM_QUANTIZE_IMPL_CHANNEL_TENSOR
368-
#undef ASYM_ASYM_QUANTIZE_IMPL_CHANNEL_CHANNEL
367+
#undef ASYM_QUANTIZE_IMPL_TENSOR
368+
#undef ASYM_QUANTIZE_IMPL_CHANNEL
369369
}
370370
} else {
371371
if (out.scalar_type() == ScalarType::Byte) {
@@ -549,8 +549,8 @@ void quantize_impl(
549549
}
550550
#undef SYM_CALCULATE_FLOAT_TYPE_TENSOR
551551
#undef SYM_CALCULATE_FLOAT_TYPE_CHANNEL
552-
#undef SYM_ASYM_QUANTIZE_IMPL_CHANNEL_TENSOR
553-
#undef SYM_ASYM_QUANTIZE_IMPL_CHANNEL_CHANNEL
552+
#undef SYM_QUANTIZE_IMPL_TENSOR
553+
#undef SYM_QUANTIZE_IMPL_CHANNEL
554554
}
555555
}
556556
}
@@ -719,7 +719,6 @@ Tensor& quantize_per_channel_out(
719719
axis_ptr,
720720
(int)quant_min,
721721
(int)quant_max);
722-
723722
return out;
724723
}
725724

@@ -802,6 +801,6 @@ Tensor& quantize_per_token_out(
802801
}
803802

804803
} // namespace native
805-
} // namespace FusionG3
804+
} // namespace G3
806805
} // namespace impl
807806
} // namespace cadence

0 commit comments

Comments (0)