Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 11 additions & 16 deletions kernels/prim_ops/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@ namespace function {

namespace {

#define __ET_PRIM_OP_ERROR_IMPL(a, b, context) \
else { \
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag); \
#define __ET_PRIM_OP_ERROR_IMPL(a, b, context) \
else { \
ET_KERNEL_CHECK(context, false, InvalidType, /* void */); \
Copy link
Contributor

@lucylq lucylq Jun 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use ET_KERNEL_CHECK_MSG to keep the format string/error message?

e.g.

ET_KERNEL_CHECK_MSG(context, false, InvalidType, "%zu, %zu", (size_t)a.tag, (size_t)b.tag); 

}

// TODO Fail using runtime context
#define __NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
(void)context; \
EValue& a = *stack[0]; \
Expand Down Expand Up @@ -168,8 +167,7 @@ static Kernel prim_ops[] = {
} else if (a.isDouble() && b.isInt()) {
floor_div_double(a.toDouble(), static_cast<double>(b.toInt()), out);
} else {
// TODO Fail using runtime context
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
ET_KERNEL_CHECK(context, false, InvalidType, /* void */);
}
}),

Expand All @@ -193,8 +191,7 @@ static Kernel prim_ops[] = {
} else if (a.isDouble() && b.isInt()) {
out = EValue(a.toDouble() / b.toInt());
} else {
// TODO Fail using runtime context
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
ET_KERNEL_CHECK(context, false, InvalidType, /* void */);
}
}),

Expand All @@ -214,8 +211,7 @@ static Kernel prim_ops[] = {
// TODO: This should be impossible
out = EValue(a.toDouble());
} else {
// TODO Fail using runtime context
ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
ET_KERNEL_CHECK(context, false, InvalidType, /* void */);
}
}),

Expand Down Expand Up @@ -265,8 +261,7 @@ static Kernel prim_ops[] = {
} else if (a.isDouble()) {
out = EValue(-a.toDouble());
} else {
// TODO Fail using runtime context
ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
ET_KERNEL_CHECK(context, false, InvalidType, /* void */);
}
}),

Expand Down Expand Up @@ -303,7 +298,7 @@ static Kernel prim_ops[] = {
if (a.isInt() && b.isInt()) {
out = EValue(a.toInt() % b.toInt());
} else {
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
ET_KERNEL_CHECK(context, false, InvalidType, /* void */);
}
}),

Expand All @@ -317,7 +312,7 @@ static Kernel prim_ops[] = {
if (a.isDouble()) {
out = EValue(static_cast<int64_t>(ceil(a.toDouble())));
} else {
ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
ET_KERNEL_CHECK(context, false, InvalidType, /* void */);
}
}),

Expand Down Expand Up @@ -348,7 +343,7 @@ static Kernel prim_ops[] = {

out = EValue(static_cast<int64_t>(res));
} else {
ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
ET_KERNEL_CHECK(context, false, InvalidType, /* void */);
}
}),

Expand All @@ -362,7 +357,7 @@ static Kernel prim_ops[] = {
if (a.isDouble()) {
out = EValue(static_cast<int64_t>(trunc(a.toDouble())));
} else {
ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
ET_KERNEL_CHECK(context, false, InvalidType, /* void */);
}
}),

Expand Down
6 changes: 4 additions & 2 deletions kernels/prim_ops/test/prim_ops_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ TEST_F(RegisterPrimOpsTest, NegScalarReturnsCorrectValue) {
EXPECT_EQ(stack[1]->toInt(), -5l);
}

TEST_F(RegisterPrimOpsTest, TestNegScalarWithTensorDies) {
TEST_F(RegisterPrimOpsTest, TestNegScalarWithTensorFails) {
testing::TensorFactory<ScalarType::Int> tf;

EValue values[2];
Expand All @@ -325,7 +325,9 @@ TEST_F(RegisterPrimOpsTest, TestNegScalarWithTensorDies) {
}

// Try to negate a tensor, which should cause a runtime error.
ET_EXPECT_DEATH(getOpsFn("executorch_prim::neg.Scalar")(context_, stack), "");
ET_EXPECT_KERNEL_FAILURE(
context_,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Linter failure here - can you try run lintrunner -a?

See: https://github.com/pytorch/executorch/blob/main/CONTRIBUTING.md#lintrunner

getOpsFn("executorch_prim::neg.Scalar")(context_, stack));
}

TEST_F(RegisterPrimOpsTest, TestETView) {
Expand Down
Loading