 #include <executorch/runtime/platform/assert.h>
 #include <xa_nnlib_kernels_api.h>

-using exec_aten::Scalar;
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
-using executorch::runtime::canCast;
-using torch::executor::Error;
-using torch::executor::KernelRuntimeContext;
+using ::executorch::aten::Scalar;
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::Tensor;
+using ::executorch::runtime::canCast;
+using ::executorch::runtime::Error;
+using ::executorch::runtime::KernelRuntimeContext;

 namespace cadence {
 namespace impl {
 namespace G3 {
 namespace native {

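+// Invoke a NNLib kernel and report InvalidArgument through ET_KERNEL_CHECK_MSG
+// if the kernel returns a nonzero status code.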
+#define XT_KERNEL_CHECK(ctx, out, kernel, ...) \
+  const auto ret = kernel(__VA_ARGS__);        \
+  ET_KERNEL_CHECK_MSG(                         \
+      ctx,                                     \
+      ret == 0,                                \
+      InvalidArgument,                         \
+      out,                                     \
+      "Failed to run kernel: " #kernel "(" #__VA_ARGS__ ")");
+
 Tensor& add_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
@@ -121,13 +130,30 @@ Tensor& add_out(
     torch::executor::native::utils::extract_scalar(alpha, &alpha_val);

     if ((a.numel() == 1) && (alpha_val == 1)) {
-      xa_nn_elm_add_scalar_32x32_32(
-          out_data, inp2_data, inp1_data[0], alpha_val, out.numel());
+      XT_KERNEL_CHECK(
+          ctx,
+          out,
+          xa_nn_elm_add_scalar_32x32_32,
+          out_data,
+          inp2_data,
+          inp1_data[0],
+          alpha_val,
+          out.numel());
     } else if (b.numel() == 1) {
-      xa_nn_elm_add_scalar_32x32_32(
-          out_data, inp1_data, inp2_data[0], alpha_val, out.numel());
+      XT_KERNEL_CHECK(
+          ctx,
+          out,
+          xa_nn_elm_add_scalar_32x32_32,
+          out_data,
+          inp1_data,
+          inp2_data[0],
+          alpha_val,
+          out.numel());
     } else if (broadcast) {
-      xa_nn_elm_add_broadcast_5D_32x32_32(
+      XT_KERNEL_CHECK(
+          ctx,
+          out,
+          xa_nn_elm_add_broadcast_5D_32x32_32,
           out_data,
           out_shape,
           inp1_data,
@@ -137,8 +163,15 @@ Tensor& add_out(
           max_dim,
           alpha_val);
     } else {
-      xa_nn_elm_add_32x32_32(
-          out_data, inp1_data, inp2_data, alpha_val, out.numel());
+      XT_KERNEL_CHECK(
+          ctx,
+          out,
+          xa_nn_elm_add_32x32_32,
+          out_data,
+          inp1_data,
+          inp2_data,
+          alpha_val,
+          out.numel());
     }
   } else if ((compute_type == ScalarType::Float) && (optimized)) {
     const float* const inp1_data = a.const_data_ptr<float>();
@@ -149,13 +182,30 @@ Tensor& add_out(
     torch::executor::native::utils::extract_scalar(alpha, &alpha_val);

     if ((a.numel() == 1) && (alpha_val == 1.0)) {
-      xa_nn_elm_add_scalar_f32xf32_f32(
-          out_data, inp2_data, inp1_data[0], alpha_val, out.numel());
+      XT_KERNEL_CHECK(
+          ctx,
+          out,
+          xa_nn_elm_add_scalar_f32xf32_f32,
+          out_data,
+          inp2_data,
+          inp1_data[0],
+          alpha_val,
+          out.numel());
     } else if (b.numel() == 1) {
-      xa_nn_elm_add_scalar_f32xf32_f32(
-          out_data, inp1_data, inp2_data[0], alpha_val, out.numel());
+      XT_KERNEL_CHECK(
+          ctx,
+          out,
+          xa_nn_elm_add_scalar_f32xf32_f32,
+          out_data,
+          inp1_data,
+          inp2_data[0],
+          alpha_val,
+          out.numel());
     } else if (broadcast) {
-      xa_nn_elm_add_broadcast_5D_f32xf32_f32(
+      XT_KERNEL_CHECK(
+          ctx,
+          out,
+          xa_nn_elm_add_broadcast_5D_f32xf32_f32,
           out_data,
           out_shape,
           inp1_data,
@@ -165,8 +215,15 @@ Tensor& add_out(
           max_dim,
           alpha_val);
     } else {
-      xa_nn_elm_add_f32xf32_f32(
-          out_data, inp1_data, inp2_data, alpha_val, out.numel());
+      XT_KERNEL_CHECK(
+          ctx,
+          out,
+          xa_nn_elm_add_f32xf32_f32,
+          out_data,
+          inp1_data,
+          inp2_data,
+          alpha_val,
+          out.numel());
     }
   } else {
     ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
@@ -242,8 +299,15 @@ Tensor& add_scalar_out(

     int* const out_data = out.mutable_data_ptr<int>();

-    xa_nn_elm_add_scalar_32x32_32(
-        out_data, inp1_data, inp2_val, alpha_val, out.numel());
+    XT_KERNEL_CHECK(
+        ctx,
+        out,
+        xa_nn_elm_add_scalar_32x32_32,
+        out_data,
+        inp1_data,
+        inp2_val,
+        alpha_val,
+        out.numel());

   } else if (compute_type == ScalarType::Float) {
     const float* const inp1_data = a.const_data_ptr<float>();
@@ -255,8 +319,15 @@ Tensor& add_scalar_out(

     float* const out_data = out.mutable_data_ptr<float>();

-    xa_nn_elm_add_scalar_f32xf32_f32(
-        out_data, inp1_data, inp2_val, alpha_val, out.numel());
+    XT_KERNEL_CHECK(
+        ctx,
+        out,
+        xa_nn_elm_add_scalar_f32xf32_f32,
+        out_data,
+        inp1_data,
+        inp2_val,
+        alpha_val,
+        out.numel());

   } else {
     ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
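For reference, the new macro wraps each NNLib call in a return-code check. A call such as

XT_KERNEL_CHECK(ctx, out, xa_nn_elm_add_32x32_32, out_data, inp1_data, inp2_data, alpha_val, out.numel());

expands to roughly the sketch below (derived from the macro definition above; it assumes the usual ET_KERNEL_CHECK_MSG behavior of logging the message, recording the error on ctx, and returning out from the enclosing operator when the condition is false):

// Approximate expansion of one XT_KERNEL_CHECK call site (sketch only).
const auto ret = xa_nn_elm_add_32x32_32(
    out_data, inp1_data, inp2_data, alpha_val, out.numel());
ET_KERNEL_CHECK_MSG(
    ctx,
    ret == 0,
    InvalidArgument,
    out,
    "Failed to run kernel: xa_nn_elm_add_32x32_32"
    "(out_data, inp1_data, inp2_data, alpha_val, out.numel())");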