
Commit 79555ea

hsharma35 authored and facebook-github-bot committed
Use macro XT_KERNEL_CHECK to handle errors returned by nnlib. (#7312)

Summary: Use the new XT_KERNEL_CHECK macro (which wraps ET_KERNEL_CHECK_MSG) to detect error codes returned by xa_nn* library calls.

Reviewed By: zonglinpeng

Differential Revision: D67128597
1 parent 61b9e1b commit 79555ea

File tree

2 files changed (+120, −32 lines)


backends/cadence/fusion_g3/operators/op_add.cpp

Lines changed: 95 additions & 24 deletions
@@ -13,18 +13,27 @@
 #include <executorch/runtime/platform/assert.h>
 #include <xa_nnlib_kernels_api.h>

-using exec_aten::Scalar;
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
-using executorch::runtime::canCast;
-using torch::executor::Error;
-using torch::executor::KernelRuntimeContext;
+using ::executorch::aten::Scalar;
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::Tensor;
+using ::executorch::runtime::canCast;
+using ::executorch::runtime::Error;
+using ::executorch::runtime::KernelRuntimeContext;

 namespace cadence {
 namespace impl {
 namespace G3 {
 namespace native {

+#define XT_KERNEL_CHECK(ctx, out, kernel, ...)                 \
+  const auto ret = kernel(__VA_ARGS__);                        \
+  ET_KERNEL_CHECK_MSG(                                         \
+      ctx,                                                     \
+      ret == 0,                                                \
+      InvalidArgument,                                         \
+      out,                                                     \
+      "Failed to run kernel: " #kernel "(" #__VA_ARGS__ ")");
+
 Tensor& add_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
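For orientation: each wrapped call below expands to the nnlib invocation followed by an ET_KERNEL_CHECK_MSG on its return code. As a rough sketch, and assuming the usual ET_KERNEL_CHECK_MSG behavior of logging, failing the context, and early-returning the given value, one of the calls updated below becomes approximately:

    // Approximate expansion of
    //   XT_KERNEL_CHECK(ctx, out, xa_nn_elm_add_32x32_32,
    //       out_data, inp1_data, inp2_data, alpha_val, out.numel());
    const auto ret = xa_nn_elm_add_32x32_32(
        out_data, inp1_data, inp2_data, alpha_val, out.numel());
    ET_KERNEL_CHECK_MSG(
        ctx,
        ret == 0,        // nnlib kernels signal success with a zero status
        InvalidArgument, // error reported when ret != 0
        out,             // value the enclosing operator returns on failure
        "Failed to run kernel: xa_nn_elm_add_32x32_32"
        "(out_data, inp1_data, inp2_data, alpha_val, out.numel())");

Because the expansion declares `ret`, the macro can appear at most once per enclosing scope; that holds throughout this diff since every wrapped call sits in its own if/else branch.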
@@ -121,13 +130,30 @@ Tensor& add_out(
     torch::executor::native::utils::extract_scalar(alpha, &alpha_val);

     if ((a.numel() == 1) && (alpha_val == 1)) {
-      xa_nn_elm_add_scalar_32x32_32(
-          out_data, inp2_data, inp1_data[0], alpha_val, out.numel());
+      XT_KERNEL_CHECK(
+          ctx,
+          out,
+          xa_nn_elm_add_scalar_32x32_32,
+          out_data,
+          inp2_data,
+          inp1_data[0],
+          alpha_val,
+          out.numel());
     } else if (b.numel() == 1) {
-      xa_nn_elm_add_scalar_32x32_32(
-          out_data, inp1_data, inp2_data[0], alpha_val, out.numel());
+      XT_KERNEL_CHECK(
+          ctx,
+          out,
+          xa_nn_elm_add_scalar_32x32_32,
+          out_data,
+          inp1_data,
+          inp2_data[0],
+          alpha_val,
+          out.numel());
     } else if (broadcast) {
-      xa_nn_elm_add_broadcast_5D_32x32_32(
+      XT_KERNEL_CHECK(
+          ctx,
+          out,
+          xa_nn_elm_add_broadcast_5D_32x32_32,
           out_data,
           out_shape,
           inp1_data,
@@ -137,8 +163,15 @@ Tensor& add_out(
           max_dim,
           alpha_val);
     } else {
-      xa_nn_elm_add_32x32_32(
-          out_data, inp1_data, inp2_data, alpha_val, out.numel());
+      XT_KERNEL_CHECK(
+          ctx,
+          out,
+          xa_nn_elm_add_32x32_32,
+          out_data,
+          inp1_data,
+          inp2_data,
+          alpha_val,
+          out.numel());
     }
   } else if ((compute_type == ScalarType::Float) && (optimized)) {
     const float* const inp1_data = a.const_data_ptr<float>();
@@ -149,13 +182,30 @@ Tensor& add_out(
     torch::executor::native::utils::extract_scalar(alpha, &alpha_val);

     if ((a.numel() == 1) && (alpha_val == 1.0)) {
-      xa_nn_elm_add_scalar_f32xf32_f32(
-          out_data, inp2_data, inp1_data[0], alpha_val, out.numel());
+      XT_KERNEL_CHECK(
+          ctx,
+          out,
+          xa_nn_elm_add_scalar_f32xf32_f32,
+          out_data,
+          inp2_data,
+          inp1_data[0],
+          alpha_val,
+          out.numel());
     } else if (b.numel() == 1) {
-      xa_nn_elm_add_scalar_f32xf32_f32(
-          out_data, inp1_data, inp2_data[0], alpha_val, out.numel());
+      XT_KERNEL_CHECK(
+          ctx,
+          out,
+          xa_nn_elm_add_scalar_f32xf32_f32,
+          out_data,
+          inp1_data,
+          inp2_data[0],
+          alpha_val,
+          out.numel());
     } else if (broadcast) {
-      xa_nn_elm_add_broadcast_5D_f32xf32_f32(
+      XT_KERNEL_CHECK(
+          ctx,
+          out,
+          xa_nn_elm_add_broadcast_5D_f32xf32_f32,
           out_data,
           out_shape,
           inp1_data,
@@ -165,8 +215,15 @@ Tensor& add_out(
           max_dim,
           alpha_val);
     } else {
-      xa_nn_elm_add_f32xf32_f32(
-          out_data, inp1_data, inp2_data, alpha_val, out.numel());
+      XT_KERNEL_CHECK(
+          ctx,
+          out,
+          xa_nn_elm_add_f32xf32_f32,
+          out_data,
+          inp1_data,
+          inp2_data,
+          alpha_val,
+          out.numel());
     }
   } else {
     ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
@@ -242,8 +299,15 @@ Tensor& add_scalar_out(

     int* const out_data = out.mutable_data_ptr<int>();

-    xa_nn_elm_add_scalar_32x32_32(
-        out_data, inp1_data, inp2_val, alpha_val, out.numel());
+    XT_KERNEL_CHECK(
+        ctx,
+        out,
+        xa_nn_elm_add_scalar_32x32_32,
+        out_data,
+        inp1_data,
+        inp2_val,
+        alpha_val,
+        out.numel());

   } else if (compute_type == ScalarType::Float) {
     const float* const inp1_data = a.const_data_ptr<float>();
@@ -255,8 +319,15 @@ Tensor& add_scalar_out(

     float* const out_data = out.mutable_data_ptr<float>();

-    xa_nn_elm_add_scalar_f32xf32_f32(
-        out_data, inp1_data, inp2_val, alpha_val, out.numel());
+    XT_KERNEL_CHECK(
+        ctx,
+        out,
+        xa_nn_elm_add_scalar_f32xf32_f32,
+        out_data,
+        inp1_data,
+        inp2_val,
+        alpha_val,
+        out.numel());

   } else {
     ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
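A side note on the message format: `#kernel` and `#__VA_ARGS__` stringify the kernel name and the argument expressions at the call site, so a failing scalar add above would log something along the lines of

    Failed to run kernel: xa_nn_elm_add_scalar_f32xf32_f32(out_data, inp1_data, inp2_val, alpha_val, out.numel())

i.e. the source text of the arguments, not their runtime values.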

backends/cadence/fusion_g3/operators/tests/test_op_add.cpp

Lines changed: 25 additions & 8 deletions
@@ -10,6 +10,8 @@
 #include <stdio.h>

 #include <executorch/backends/cadence/fusion_g3/operators/operators.h>
+#include <executorch/kernels/test/TestUtil.h>
+#include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
@@ -24,24 +26,19 @@ namespace {
 using ::executorch::aten::Scalar;
 using ::executorch::aten::ScalarType;
 using ::executorch::aten::Tensor;
+using ::executorch::aten::TensorImpl;
+using ::executorch::runtime::Error;
 using ::executorch::runtime::KernelRuntimeContext;
 using ::executorch::runtime::runtime_init;
 using ::executorch::runtime::testing::TensorFactory;
-using ::testing::Test;

-class FusionG3OperatorTest : public Test {
+class FusionG3OperatorTest : public OperatorTest {
  public:
-  void SetUp() override {
-    runtime_init();
-  }
-
  protected:
   Tensor&
   add_out(const Tensor& a, const Tensor& b, const Scalar& alpha, Tensor& out) {
     return cadence::impl::G3::native::add_out(context_, a, b, alpha, out);
   }
-
-  KernelRuntimeContext context_;
 };

 TEST_F(FusionG3OperatorTest, TwoDimFloatTensorAddTest) {
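The fixture shrinks because `OperatorTest` (from the newly included TestUtil.h) now supplies both the runtime setup and the `context_` member. Inferred from what the diff deletes, not from the actual TestUtil.h source, the base class presumably looks roughly like this hypothetical sketch:

    // Hypothetical shape of OperatorTest, inferred from the removed
    // SetUp/runtime_init code and the still-referenced context_ member.
    class OperatorTest : public ::testing::Test {
     public:
      void SetUp() override {
        executorch::runtime::runtime_init();
      }

     protected:
      executorch::runtime::KernelRuntimeContext context_;
    };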
@@ -77,6 +74,26 @@ TEST_F(FusionG3OperatorTest, AddWithBroadcastTest) {
   EXPECT_TENSOR_EQ(out, tf.full(size_a, 2));
 }

+TEST_F(FusionG3OperatorTest, KernelCheckTest) {
+  TensorFactory<ScalarType::Float> tf;
+  // Broadcast add.
+  const std::vector<TensorImpl::SizesType> sizeOfA{1, 3, 2, 4}, sizeOfB{2, 4};
+  const Tensor b = tf.ones(sizeOfB);
+  Tensor out = tf.zeros(sizeOfA);
+  // Create a null tensor to force kernel check failure.
+  TensorImpl nullTensorImpl(
+      b.scalar_type(),
+      b.dim(),
+      const_cast<TensorImpl::SizesType*>(b.sizes().data()),
+      // Use nullptr to force kernel check failure.
+      /*data=*/nullptr,
+      const_cast<TensorImpl::DimOrderType*>(b.dim_order().data()));
+  Tensor nullTensor(&nullTensorImpl);
+
+  ET_EXPECT_KERNEL_FAILURE(
+      context_, add_out(tf.ones(sizeOfA), nullTensor, 1, out));
+}
+
 } // namespace
 } // namespace native
 } // namespace G3
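Reading the new test against the macro (a hedged walk-through; the nnlib null-pointer check and the ET_KERNEL_CHECK_MSG internals live outside this diff):

    // 1. Sizes {1, 3, 2, 4} vs. {2, 4} send add_out down the broadcast
    //    branch, calling xa_nn_elm_add_broadcast_5D_f32xf32_f32 with the
    //    null data pointer from nullTensor.
    // 2. The kernel is expected to reject the null pointer and return a
    //    nonzero status instead of dereferencing it.
    // 3. XT_KERNEL_CHECK sees ret != 0, so ET_KERNEL_CHECK_MSG logs the
    //    stringified call, marks the context as failed, and returns out.
    // 4. ET_EXPECT_KERNEL_FAILURE(context_, ...) then asserts that the
    //    kernel call failed.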
