Skip to content

Commit a2b620a

Browse files
authored
Prim ops replace abort with error code (#11985)
### Summary We want to minimize the scenarios where ET fails fatally. If we need to check a precondition inside a kernel rather then calling ET_CHECK which internally dispatches to ABORT we should call ET_KERNEL_CHECK which sets an error state and returns. ### Test plan Replaced TestNegScalarWithTensorDies with TestNegScalarWithTensorFails. Before, it would expect an abort when negating a tensor. The new test expects a failure state instead. The test can be run with `test/run_oss_cpp_tests.sh`
1 parent 999b1b0 commit a2b620a

File tree

2 files changed

+57
-18
lines changed

2 files changed

+57
-18
lines changed

kernels/prim_ops/register_prim_ops.cpp

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,18 @@ namespace function {
2222

2323
namespace {
2424

25-
#define __ET_PRIM_OP_ERROR_IMPL(a, b, context) \
26-
else { \
27-
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag); \
25+
#define __ET_PRIM_OP_ERROR_IMPL(a, b, context) \
26+
else { \
27+
ET_KERNEL_CHECK_MSG( \
28+
context, \
29+
false, \
30+
InvalidType, \
31+
/* void */, \
32+
"%zu, %zu", \
33+
(size_t)a.tag, \
34+
(size_t)b.tag); \
2835
}
2936

30-
// TODO Fail using runtime context
3137
#define __NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
3238
(void)context; \
3339
EValue& a = *stack[0]; \
@@ -168,8 +174,14 @@ static Kernel prim_ops[] = {
168174
} else if (a.isDouble() && b.isInt()) {
169175
floor_div_double(a.toDouble(), static_cast<double>(b.toInt()), out);
170176
} else {
171-
// TODO Fail using runtime context
172-
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
177+
ET_KERNEL_CHECK_MSG(
178+
context,
179+
false,
180+
InvalidType,
181+
/* void */,
182+
"%zu, %zu",
183+
(size_t)a.tag,
184+
(size_t)b.tag);
173185
}
174186
}),
175187

@@ -193,8 +205,14 @@ static Kernel prim_ops[] = {
193205
} else if (a.isDouble() && b.isInt()) {
194206
out = EValue(a.toDouble() / b.toInt());
195207
} else {
196-
// TODO Fail using runtime context
197-
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
208+
ET_KERNEL_CHECK_MSG(
209+
context,
210+
false,
211+
InvalidType,
212+
/* void */,
213+
"%zu, %zu",
214+
(size_t)a.tag,
215+
(size_t)b.tag);
198216
}
199217
}),
200218

@@ -214,8 +232,8 @@ static Kernel prim_ops[] = {
214232
// TODO: This should be impossible
215233
out = EValue(a.toDouble());
216234
} else {
217-
// TODO Fail using runtime context
218-
ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
235+
ET_KERNEL_CHECK_MSG(
236+
context, false, InvalidType, /* void */, "%zu", (size_t)a.tag);
219237
}
220238
}),
221239

@@ -265,8 +283,8 @@ static Kernel prim_ops[] = {
265283
} else if (a.isDouble()) {
266284
out = EValue(-a.toDouble());
267285
} else {
268-
// TODO Fail using runtime context
269-
ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
286+
ET_KERNEL_CHECK_MSG(
287+
context, false, InvalidType, /* void */, "%zu", (size_t)a.tag);
270288
}
271289
}),
272290

@@ -303,7 +321,14 @@ static Kernel prim_ops[] = {
303321
if (a.isInt() && b.isInt()) {
304322
out = EValue(a.toInt() % b.toInt());
305323
} else {
306-
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
324+
ET_KERNEL_CHECK_MSG(
325+
context,
326+
false,
327+
InvalidType,
328+
/* void */,
329+
"%zu, %zu",
330+
(size_t)a.tag,
331+
(size_t)b.tag);
307332
}
308333
}),
309334

@@ -317,7 +342,13 @@ static Kernel prim_ops[] = {
317342
if (a.isDouble()) {
318343
out = EValue(static_cast<int64_t>(ceil(a.toDouble())));
319344
} else {
320-
ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
345+
ET_KERNEL_CHECK_MSG(
346+
context,
347+
false,
348+
InvalidType,
349+
/* void */,
350+
"Unsupported DType %zu",
351+
(size_t)a.tag);
321352
}
322353
}),
323354

@@ -348,7 +379,13 @@ static Kernel prim_ops[] = {
348379

349380
out = EValue(static_cast<int64_t>(res));
350381
} else {
351-
ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
382+
ET_KERNEL_CHECK_MSG(
383+
context,
384+
false,
385+
InvalidType,
386+
/* void */,
387+
"Unsupported DType %zu",
388+
(size_t)a.tag);
352389
}
353390
}),
354391

@@ -362,7 +399,8 @@ static Kernel prim_ops[] = {
362399
if (a.isDouble()) {
363400
out = EValue(static_cast<int64_t>(trunc(a.toDouble())));
364401
} else {
365-
ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
402+
ET_KERNEL_CHECK_MSG(
403+
context, false, InvalidType, /* void */, "%zu", (size_t)a.tag);
366404
}
367405
}),
368406

kernels/prim_ops/test/prim_ops_test.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ TEST_F(RegisterPrimOpsTest, NegScalarReturnsCorrectValue) {
308308
EXPECT_EQ(stack[1]->toInt(), -5l);
309309
}
310310

311-
TEST_F(RegisterPrimOpsTest, TestNegScalarWithTensorDies) {
311+
TEST_F(RegisterPrimOpsTest, TestNegScalarWithTensorFails) {
312312
testing::TensorFactory<ScalarType::Int> tf;
313313

314314
EValue values[2];
@@ -325,7 +325,8 @@ TEST_F(RegisterPrimOpsTest, TestNegScalarWithTensorDies) {
325325
}
326326

327327
// Try to negate a tensor, which should cause a runtime error.
328-
ET_EXPECT_DEATH(getOpsFn("executorch_prim::neg.Scalar")(context_, stack), "");
328+
ET_EXPECT_KERNEL_FAILURE(
329+
context_, getOpsFn("executorch_prim::neg.Scalar")(context_, stack));
329330
}
330331

331332
TEST_F(RegisterPrimOpsTest, TestETView) {

0 commit comments

Comments
 (0)