 * LICENSE file in the root directory of this source tree.
 */

+#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
#include <executorch/kernels/portable/cpu/util/math_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
#include <cmath>
-#include <executorch/backends/cadence/hifi/kernels/kernels.h>

using exec_aten::Scalar;
using exec_aten::ScalarType;
@@ -22,7 +22,7 @@ using executorch::aten::RuntimeContext;
using torch::executor::Error;

namespace impl {
-namespace HiFi {
+namespace HiFi {
namespace native {

namespace {
@@ -74,29 +74,27 @@ div_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
  int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
  max_dim = out.dim() > max_dim ? out.dim() : max_dim;

-  if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
+  if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
    optimized = 0;

-  if ((a_dim == 0) || (b_dim == 0))
+  if ((a_dim == 0) || (b_dim == 0))
    optimized = 0;

-  if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
+  if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
    optimized = 0;

-  if (optimized)
-  {
+  if (optimized) {
    float* a_data = a.mutable_data_ptr<float>();
    float* b_data = b.mutable_data_ptr<float>();
    float* out_data = out.mutable_data_ptr<float>();

-    if (broadcast == 1)
-    {
+    if (broadcast == 1) {

      int out_shape[kNnlibMaxDim];
      int inp1_shape[kNnlibMaxDim];
      int inp2_shape[kNnlibMaxDim];

-      for (int i = 0; i < kNnlibMaxDim; i++)
+      for (int i = 0; i < kNnlibMaxDim; i++)
      {
        out_shape[i] = 1;
        inp1_shape[i] = 1;
@@ -106,34 +104,35 @@ div_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
      int off_o = kNnlibMaxDim - out.dim();
      int off_a = kNnlibMaxDim - a.dim();
      int off_b = kNnlibMaxDim - b.dim();
-      for (int i = 0; i < out.dim(); i++)
+      for (int i = 0; i < out.dim(); i++)
        out_shape[i+off_o] = out.size(i);
-      for (int i = 0; i < a.dim(); i++)
+      for (int i = 0; i < a.dim(); i++)
        inp1_shape[i+off_a] = a.size(i);
-      for (int i = 0; i < b.dim(); i++)
+      for (int i = 0; i < b.dim(); i++)
        inp2_shape[i+off_b] = b.size(i);

-      xa_nn_elm_div_broadcast_4D_f32xf32_f32(out_data, out_shape, a_data, inp1_shape, b_data, inp2_shape);
+      xa_nn_elm_div_broadcast_4D_f32xf32_f32(
+          out_data, out_shape, a_data, inp1_shape, b_data, inp2_shape);
    }
    else
    {
-
      xa_nn_elm_div_f32xf32_f32(out_data, a_data, b_data, out.numel());
    }
-
+
    return out;
  }
-
+
  ScalarType common_type = get_compute_type(a_type, b_type);
  ScalarType out_type = out.scalar_type();
-
+
  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
+
  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "div.out", CTYPE_A, [&]() {
    ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div.out", CTYPE_B, [&]() {
      ET_SWITCH_FLOAT_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() {
        ET_SWITCH_FLOAT_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() {
-          torch::executor::apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+          torch::executor::
+              apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
              [](const CTYPE_A val_a, const CTYPE_B val_b) {
                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
                CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
@@ -188,13 +187,13 @@ Tensor& div_out_mode(
  int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
  max_dim = out.dim() > max_dim ? out.dim() : max_dim;

-  if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
+  if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
    optimized = 0;

-  if ((a_dim == 0) || (b_dim == 0))
+  if ((a_dim == 0) || (b_dim == 0))
    optimized = 0;

-  if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
+  if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
    optimized = 0;
  int mode_val = -1;
  if (mode.has_value() && mode.value() == "trunc")
@@ -204,20 +203,17 @@ Tensor& div_out_mode(
  else
    optimized = 0;

-  if (optimized)
-  {
+  if (optimized) {
    float* a_data = a.mutable_data_ptr<float>();
    float* b_data = b.mutable_data_ptr<float>();
    float* out_data = out.mutable_data_ptr<float>();

-    if (broadcast)
-    {
+    if (broadcast) {
      int out_shape[kNnlibMaxDim];
      int inp1_shape[kNnlibMaxDim];
      int inp2_shape[kNnlibMaxDim];

-      for (int i = 0; i < kNnlibMaxDim; i++)
-      {
+      for (int i = 0; i < kNnlibMaxDim; i++) {
        inp1_shape[i] = 1;
        inp2_shape[i] = 1;
        out_shape[i] = 1;
@@ -227,18 +223,20 @@ Tensor& div_out_mode(
      int off_a = kNnlibMaxDim - a.dim();
      int off_b = kNnlibMaxDim - b.dim();

-      for (int i = 0; i < out.dim(); i++)
+      for (int i = 0; i < out.dim(); i++)
        out_shape[i+off_o] = out.size(i);
-      for (int i = 0; i < a.dim(); i++)
+      for (int i = 0; i < a.dim(); i++)
        inp1_shape[i+off_a] = a.size(i);
-      for (int i = 0; i < b.dim(); i++)
+      for (int i = 0; i < b.dim(); i++)
        inp2_shape[i+off_b] = b.size(i);

-      xa_nn_elm_div_mode_broadcast_4D_f32xf32_f32(out_data, out_shape, a_data, inp1_shape, b_data, inp2_shape, mode_val);
+      xa_nn_elm_div_mode_broadcast_4D_f32xf32_f32(
+          out_data, out_shape, a_data, inp1_shape, b_data, inp2_shape, mode_val);
    }
    else
    {
-      xa_nn_elm_div_mode_f32xf32_f32(out_data, a_data, b_data, out.numel(), mode_val);
+      xa_nn_elm_div_mode_f32xf32_f32(
+          out_data, a_data, b_data, out.numel(), mode_val);
    }

    return out;
@@ -248,7 +246,8 @@ Tensor& div_out_mode(
    ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div.out_mode", CTYPE_B, [&]() {
      ET_SWITCH_FLOAT_TYPES(common_type, ctx, "div.out_mode", CTYPE_IN, [&]() {
        ET_SWITCH_REAL_TYPES(out_type, ctx, "div.out_mode", CTYPE_OUT, [&]() {
-          torch::executor::apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+          torch::executor::
+              apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
              [mode](const CTYPE_A val_a, const CTYPE_B val_b) {
                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
                CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
@@ -272,6 +271,6 @@ Tensor& div_out_mode(
}


-} // namespace impl
-} // namespace HiFi
} // namespace native
+} // namespace HiFi
+} // namespace impl
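
For readers skimming the hunks above: before calling the 4D NNLib kernels, both functions right-align each tensor's sizes into fixed-length shape arrays, padding unused leading dimensions with 1. Below is a minimal standalone sketch of that padding step, assuming kNnlibMaxDim is 4 (as the "_4D_" kernel names suggest); the pad_shape helper is hypothetical and not part of the diff.

#include <algorithm>
#include <cstdint>

constexpr int kNnlibMaxDim = 4; // assumption: matches the "_4D_" kernel names

// Hypothetical helper mirroring the padding loops in the diff: the real dims
// are right-aligned, and unused leading slots are set to 1, the identity
// size for broadcasting.
void pad_shape(const int64_t* sizes, int dim, int padded[kNnlibMaxDim]) {
  std::fill(padded, padded + kNnlibMaxDim, 1);
  const int off = kNnlibMaxDim - dim;
  for (int i = 0; i < dim; i++)
    padded[i + off] = static_cast<int>(sizes[i]);
}

For example, a rank-2 tensor of sizes {2, 3} pads to {1, 1, 2, 3}, so it can broadcast against a rank-4 operand inside a single 4D kernel call.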
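The 'optimized' gating that both hunks reformat can also be restated compactly. As a hedged summary of the checks visible in the diff (parameter names below are illustrative, not from the source): the NNLib fast path runs only when both inputs are float, neither input is zero-dimensional, and any broadcast spans at most kNnlibMaxDim dimensions; everything else falls back to the portable apply_binary_elementwise_fn path.

// Condensed restatement of the 'optimized' checks from the diff; names are
// illustrative only.
bool use_nnlib_fast_path(
    bool a_is_float,
    bool b_is_float,
    int a_dim,
    int b_dim,
    bool needs_broadcast,
    int max_dim,
    int nnlib_max_dim) {
  if (!a_is_float || !b_is_float)
    return false; // the xa_nn_elm_div kernels are f32-only
  if (a_dim == 0 || b_dim == 0)
    return false; // zero-dim (scalar) tensors take the portable path
  if (needs_broadcast && max_dim > nnlib_max_dim)
    return false; // broadcasting handled only up to kNnlibMaxDim dims
  return true;
}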