Skip to content

Commit af412bd

Browse files
committed
Modified prim ops to not abort on error
1 parent 9c9f665 commit af412bd

File tree

2 files changed

+14
-17
lines changed

2 files changed

+14
-17
lines changed

kernels/prim_ops/register_prim_ops.cpp

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,11 @@ namespace function {
2222

2323
namespace {
2424

25-
#define __ET_PRIM_OP_ERROR_IMPL(a, b, context) \
26-
else { \
27-
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag); \
25+
#define __ET_PRIM_OP_ERROR_IMPL(a, b, context) \
26+
else { \
27+
ET_KERNEL_CHECK(context, false, InvalidType, /* void */); \
2828
}
2929

30-
// TODO Fail using runtime context
3130
#define __NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
3231
(void)context; \
3332
EValue& a = *stack[0]; \
@@ -168,8 +167,7 @@ static Kernel prim_ops[] = {
168167
} else if (a.isDouble() && b.isInt()) {
169168
floor_div_double(a.toDouble(), static_cast<double>(b.toInt()), out);
170169
} else {
171-
// TODO Fail using runtime context
172-
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
170+
ET_KERNEL_CHECK(context, false, InvalidType, /* void */);
173171
}
174172
}),
175173

@@ -193,8 +191,7 @@ static Kernel prim_ops[] = {
193191
} else if (a.isDouble() && b.isInt()) {
194192
out = EValue(a.toDouble() / b.toInt());
195193
} else {
196-
// TODO Fail using runtime context
197-
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
194+
ET_KERNEL_CHECK(context, false, InvalidType, /* void */);
198195
}
199196
}),
200197

@@ -214,8 +211,7 @@ static Kernel prim_ops[] = {
214211
// TODO: This should be impossible
215212
out = EValue(a.toDouble());
216213
} else {
217-
// TODO Fail using runtime context
218-
ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
214+
ET_KERNEL_CHECK(context, false, InvalidType, /* void */);
219215
}
220216
}),
221217

@@ -265,8 +261,7 @@ static Kernel prim_ops[] = {
265261
} else if (a.isDouble()) {
266262
out = EValue(-a.toDouble());
267263
} else {
268-
// TODO Fail using runtime context
269-
ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
264+
ET_KERNEL_CHECK(context, false, InvalidType, /* void */);
270265
}
271266
}),
272267

@@ -303,7 +298,7 @@ static Kernel prim_ops[] = {
303298
if (a.isInt() && b.isInt()) {
304299
out = EValue(a.toInt() % b.toInt());
305300
} else {
306-
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
301+
ET_KERNEL_CHECK(context, false, InvalidType, /* void */);
307302
}
308303
}),
309304

@@ -317,7 +312,7 @@ static Kernel prim_ops[] = {
317312
if (a.isDouble()) {
318313
out = EValue(static_cast<int64_t>(ceil(a.toDouble())));
319314
} else {
320-
ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
315+
ET_KERNEL_CHECK(context, false, InvalidType, /* void */);
321316
}
322317
}),
323318

@@ -348,7 +343,7 @@ static Kernel prim_ops[] = {
348343

349344
out = EValue(static_cast<int64_t>(res));
350345
} else {
351-
ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
346+
ET_KERNEL_CHECK(context, false, InvalidType, /* void */);
352347
}
353348
}),
354349

@@ -362,7 +357,7 @@ static Kernel prim_ops[] = {
362357
if (a.isDouble()) {
363358
out = EValue(static_cast<int64_t>(trunc(a.toDouble())));
364359
} else {
365-
ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
360+
ET_KERNEL_CHECK(context, false, InvalidType, /* void */);
366361
}
367362
}),
368363

kernels/prim_ops/test/prim_ops_test.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,9 @@ TEST_F(RegisterPrimOpsTest, TestNegScalarWithTensorDies) {
325325
}
326326

327327
// Try to negate a tensor, which should cause a runtime error.
328-
ET_EXPECT_DEATH(getOpsFn("executorch_prim::neg.Scalar")(context_, stack), "");
328+
ET_EXPECT_KERNEL_FAILURE(
329+
context_,
330+
getOpsFn("executorch_prim::neg.Scalar")(context_, stack));
329331
}
330332

331333
TEST_F(RegisterPrimOpsTest, TestETView) {

0 commit comments

Comments
 (0)