Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion runtime/core/evalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ class BoxedEvalueList {
* unwrapped vals.
*/
BoxedEvalueList(EValue** wrapped_vals, T* unwrapped_vals, int size)
: wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {}
: wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {
ET_CHECK_MSG(wrapped_vals != nullptr, "wrapped_vals cannot be null");
ET_CHECK_MSG(unwrapped_vals != nullptr, "unwrapped_vals cannot be null");
ET_CHECK_MSG(size >= 0, "size cannot be negative");
}
/*
* Constructs and returns the list of T specified by the EValue pointers
*/
Expand Down Expand Up @@ -280,6 +284,7 @@ struct EValue {

/****** String Type ******/
/*implicit*/ EValue(executorch::aten::ArrayRef<char>* s) : tag(Tag::String) {
ET_CHECK_MSG(s != nullptr, "ArrayRef<char> pointer cannot be null");
payload.copyable_union.as_string_ptr = s;
}

Expand All @@ -289,13 +294,18 @@ struct EValue {

std::string_view toString() const {
ET_CHECK_MSG(isString(), "EValue is not a String.");
ET_CHECK_MSG(
payload.copyable_union.as_string_ptr != nullptr,
"EValue string pointer is null.");
return std::string_view(
payload.copyable_union.as_string_ptr->data(),
payload.copyable_union.as_string_ptr->size());
}

/****** Int List Type ******/
/*implicit*/ EValue(BoxedEvalueList<int64_t>* i) : tag(Tag::ListInt) {
ET_CHECK_MSG(
i != nullptr, "BoxedEvalueList<int64_t> pointer cannot be null");
payload.copyable_union.as_int_list_ptr = i;
}

Expand All @@ -305,12 +315,16 @@ struct EValue {

executorch::aten::ArrayRef<int64_t> toIntList() const {
ET_CHECK_MSG(isIntList(), "EValue is not an Int List.");
ET_CHECK_MSG(
payload.copyable_union.as_int_list_ptr != nullptr,
"EValue int list pointer is null.");
return (payload.copyable_union.as_int_list_ptr)->get();
}

/****** Bool List Type ******/
/*implicit*/ EValue(executorch::aten::ArrayRef<bool>* b)
: tag(Tag::ListBool) {
ET_CHECK_MSG(b != nullptr, "ArrayRef<bool> pointer cannot be null");
payload.copyable_union.as_bool_list_ptr = b;
}

Expand All @@ -320,12 +334,16 @@ struct EValue {

executorch::aten::ArrayRef<bool> toBoolList() const {
ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List.");
ET_CHECK_MSG(
payload.copyable_union.as_bool_list_ptr != nullptr,
"EValue bool list pointer is null.");
return *(payload.copyable_union.as_bool_list_ptr);
}

/****** Double List Type ******/
/*implicit*/ EValue(executorch::aten::ArrayRef<double>* d)
: tag(Tag::ListDouble) {
ET_CHECK_MSG(d != nullptr, "ArrayRef<double> pointer cannot be null");
payload.copyable_union.as_double_list_ptr = d;
}

Expand All @@ -335,12 +353,17 @@ struct EValue {

executorch::aten::ArrayRef<double> toDoubleList() const {
ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List.");
ET_CHECK_MSG(
payload.copyable_union.as_double_list_ptr != nullptr,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should check this in the ctors no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kept the original checks and added in boxed ctor and each evalue ctor

"EValue double list pointer is null.");
return *(payload.copyable_union.as_double_list_ptr);
}

/****** Tensor List Type ******/
/*implicit*/ EValue(BoxedEvalueList<executorch::aten::Tensor>* t)
: tag(Tag::ListTensor) {
ET_CHECK_MSG(
t != nullptr, "BoxedEvalueList<Tensor> pointer cannot be null");
payload.copyable_union.as_tensor_list_ptr = t;
}

Expand All @@ -350,13 +373,19 @@ struct EValue {

executorch::aten::ArrayRef<executorch::aten::Tensor> toTensorList() const {
ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List.");
ET_CHECK_MSG(
payload.copyable_union.as_tensor_list_ptr != nullptr,
"EValue tensor list pointer is null.");
return payload.copyable_union.as_tensor_list_ptr->get();
}

/****** List Optional Tensor Type ******/
/*implicit*/ EValue(
BoxedEvalueList<std::optional<executorch::aten::Tensor>>* t)
: tag(Tag::ListOptionalTensor) {
ET_CHECK_MSG(
t != nullptr,
"BoxedEvalueList<optional<Tensor>> pointer cannot be null");
payload.copyable_union.as_list_optional_tensor_ptr = t;
}

Expand All @@ -366,6 +395,11 @@ struct EValue {

executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>
toListOptionalTensor() const {
ET_CHECK_MSG(
isListOptionalTensor(), "EValue is not a List Optional Tensor.");
ET_CHECK_MSG(
payload.copyable_union.as_list_optional_tensor_ptr != nullptr,
"EValue list optional tensor pointer is null.");
return payload.copyable_union.as_list_optional_tensor_ptr->get();
}

Expand Down
144 changes: 144 additions & 0 deletions runtime/core/test/evalue_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,147 @@ TEST_F(EValueTest, ConstructFromNullPtrAborts) {

ET_EXPECT_DEATH({ EValue evalue(null_ptr); }, "");
}

TEST_F(EValueTest, StringConstructorNullCheck) {
executorch::aten::ArrayRef<char>* null_string_ptr = nullptr;
ET_EXPECT_DEATH(
{ EValue evalue(null_string_ptr); },
"ArrayRef<char> pointer cannot be null");
}

TEST_F(EValueTest, BoolListConstructorNullCheck) {
executorch::aten::ArrayRef<bool>* null_bool_list_ptr = nullptr;
ET_EXPECT_DEATH(
{ EValue evalue(null_bool_list_ptr); },
"ArrayRef<bool> pointer cannot be null");
}

TEST_F(EValueTest, DoubleListConstructorNullCheck) {
executorch::aten::ArrayRef<double>* null_double_list_ptr = nullptr;
ET_EXPECT_DEATH(
{ EValue evalue(null_double_list_ptr); },
"ArrayRef<double> pointer cannot be null");
}

TEST_F(EValueTest, IntListConstructorNullCheck) {
BoxedEvalueList<int64_t>* null_int_list_ptr = nullptr;
ET_EXPECT_DEATH(
{ EValue evalue(null_int_list_ptr); },
"BoxedEvalueList<int64_t> pointer cannot be null");
}

TEST_F(EValueTest, TensorListConstructorNullCheck) {
BoxedEvalueList<executorch::aten::Tensor>* null_tensor_list_ptr = nullptr;
ET_EXPECT_DEATH(
{ EValue evalue(null_tensor_list_ptr); },
"BoxedEvalueList<Tensor> pointer cannot be null");
}

TEST_F(EValueTest, OptionalTensorListConstructorNullCheck) {
BoxedEvalueList<std::optional<executorch::aten::Tensor>>*
null_optional_tensor_list_ptr = nullptr;
ET_EXPECT_DEATH(
{ EValue evalue(null_optional_tensor_list_ptr); },
"BoxedEvalueList<optional<Tensor>> pointer cannot be null");
}

TEST_F(EValueTest, BoxedEvalueListConstructorNullChecks) {
int64_t storage[3] = {0, 0, 0};
EValue values[3] = {
EValue((int64_t)1), EValue((int64_t)2), EValue((int64_t)3)};
EValue* values_p[3] = {&values[0], &values[1], &values[2]};

// Test null wrapped_vals
ET_EXPECT_DEATH(
{ BoxedEvalueList<int64_t> list(nullptr, storage, 3); },
"wrapped_vals cannot be null");

// Test null unwrapped_vals
ET_EXPECT_DEATH(
{ BoxedEvalueList<int64_t> list(values_p, nullptr, 3); },
"unwrapped_vals cannot be null");

// Test negative size
ET_EXPECT_DEATH(
{ BoxedEvalueList<int64_t> list(values_p, storage, -1); },
"size cannot be negative");
}

TEST_F(EValueTest, toListOptionalTensorTypeCheck) {
// Create an EValue that's not a ListOptionalTensor
EValue e((int64_t)42);
EXPECT_TRUE(e.isInt());
EXPECT_FALSE(e.isListOptionalTensor());

// Should fail type check
ET_EXPECT_DEATH(
{ e.toListOptionalTensor(); }, "EValue is not a List Optional Tensor");
}

TEST_F(EValueTest, toStringNullPointerCheck) {
// Create an EValue with String tag but null pointer
EValue e;
e.tag = Tag::String;
e.payload.copyable_union.as_string_ptr = nullptr;

// Should pass isString() check but fail null pointer check
EXPECT_TRUE(e.isString());
ET_EXPECT_DEATH({ e.toString(); }, "EValue string pointer is null");
}

TEST_F(EValueTest, toIntListNullPointerCheck) {
// Create an EValue with ListInt tag but null pointer
EValue e;
e.tag = Tag::ListInt;
e.payload.copyable_union.as_int_list_ptr = nullptr;

// Should pass isIntList() check but fail null pointer check
EXPECT_TRUE(e.isIntList());
ET_EXPECT_DEATH({ e.toIntList(); }, "EValue int list pointer is null");
}

TEST_F(EValueTest, toBoolListNullPointerCheck) {
// Create an EValue with ListBool tag but null pointer
EValue e;
e.tag = Tag::ListBool;
e.payload.copyable_union.as_bool_list_ptr = nullptr;

// Should pass isBoolList() check but fail null pointer check
EXPECT_TRUE(e.isBoolList());
ET_EXPECT_DEATH({ e.toBoolList(); }, "EValue bool list pointer is null");
}

TEST_F(EValueTest, toDoubleListNullPointerCheck) {
// Create an EValue with ListDouble tag but null pointer
EValue e;
e.tag = Tag::ListDouble;
e.payload.copyable_union.as_double_list_ptr = nullptr;

// Should pass isDoubleList() check but fail null pointer check
EXPECT_TRUE(e.isDoubleList());
ET_EXPECT_DEATH({ e.toDoubleList(); }, "EValue double list pointer is null");
}

TEST_F(EValueTest, toTensorListNullPointerCheck) {
// Create an EValue with ListTensor tag but null pointer
EValue e;
e.tag = Tag::ListTensor;
e.payload.copyable_union.as_tensor_list_ptr = nullptr;

// Should pass isTensorList() check but fail null pointer check
EXPECT_TRUE(e.isTensorList());
ET_EXPECT_DEATH({ e.toTensorList(); }, "EValue tensor list pointer is null");
}

TEST_F(EValueTest, toListOptionalTensorNullPointerCheck) {
// Create an EValue with ListOptionalTensor tag but null pointer
EValue e;
e.tag = Tag::ListOptionalTensor;
e.payload.copyable_union.as_list_optional_tensor_ptr = nullptr;

// Should pass isListOptionalTensor() check but fail null pointer check
EXPECT_TRUE(e.isListOptionalTensor());
ET_EXPECT_DEATH(
{ e.toListOptionalTensor(); },
"EValue list optional tensor pointer is null");
}
Loading