Skip to content

Commit 87512b8

Browse files
authored
Make TensorPtr constructor check the data dize matches the shape. (#13591)
Summary: . Differential Revision: D80764139
1 parent dc1206d commit 87512b8

File tree

3 files changed

+28
-10
lines changed

3 files changed

+28
-10
lines changed

extension/tensor/tensor_ptr.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,10 @@ TensorPtr make_tensor_ptr(
148148
executorch::aten::ScalarType type,
149149
executorch::aten::TensorShapeDynamism dynamism) {
150150
ET_CHECK_MSG(
151-
data.size() >=
151+
data.size() ==
152152
executorch::aten::compute_numel(sizes.data(), sizes.size()) *
153153
executorch::aten::elementSize(type),
154-
"Data size is smaller than required by sizes and scalar type.");
154+
"Data size does not match tensor size.");
155155
auto data_ptr = data.data();
156156
return make_tensor_ptr(
157157
std::move(sizes),

extension/tensor/tensor_ptr.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ inline TensorPtr make_tensor_ptr(
106106
executorch::aten::ScalarType type = deduced_type,
107107
executorch::aten::TensorShapeDynamism dynamism =
108108
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
109+
ET_CHECK_MSG(
110+
data.size() ==
111+
executorch::aten::compute_numel(sizes.data(), sizes.size()),
112+
"Data size does not match tensor size.");
109113
if (type != deduced_type) {
110114
ET_CHECK_MSG(
111115
runtime::canCast(deduced_type, type),

extension/tensor/test/tensor_ptr_test.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -784,16 +784,30 @@ TEST_F(TensorPtrTest, TensorUint8BufferTooSmallExpectDeath) {
784784
{ auto tensor = make_tensor_ptr({2, 2}, std::move(data)); }, "");
785785
}
786786

787-
TEST_F(TensorPtrTest, TensorUint8BufferTooLarge) {
787+
TEST_F(TensorPtrTest, TensorUint8BufferTooLargeExpectDeath) {
788788
std::vector<uint8_t> data(
789-
4 * executorch::aten::elementSize(executorch::aten::ScalarType::Float));
790-
auto tensor = make_tensor_ptr({2, 2}, std::move(data));
789+
5 * executorch::aten::elementSize(executorch::aten::ScalarType::Float));
790+
ET_EXPECT_DEATH({ auto _ = make_tensor_ptr({2, 2}, std::move(data)); }, "");
791+
}
791792

792-
EXPECT_EQ(tensor->dim(), 2);
793-
EXPECT_EQ(tensor->size(0), 2);
794-
EXPECT_EQ(tensor->size(1), 2);
795-
EXPECT_EQ(tensor->strides()[0], 2);
796-
EXPECT_EQ(tensor->strides()[1], 1);
793+
TEST_F(TensorPtrTest, VectorFloatTooSmallExpectDeath) {
794+
std::vector<float> data(9, 1.f);
795+
ET_EXPECT_DEATH({ auto _ = make_tensor_ptr({2, 5}, std::move(data)); }, "");
796+
}
797+
798+
TEST_F(TensorPtrTest, VectorFloatTooLargeExpectDeath) {
799+
std::vector<float> data(11, 1.f);
800+
ET_EXPECT_DEATH({ auto _ = make_tensor_ptr({2, 5}, std::move(data)); }, "");
801+
}
802+
803+
TEST_F(TensorPtrTest, VectorIntToFloatCastTooSmallExpectDeath) {
804+
std::vector<int32_t> data(9, 1);
805+
ET_EXPECT_DEATH({ auto _ = make_tensor_ptr({2, 5}, std::move(data)); }, "");
806+
}
807+
808+
TEST_F(TensorPtrTest, VectorIntToFloatCastTooLargeExpectDeath) {
809+
std::vector<int32_t> data(11, 1);
810+
ET_EXPECT_DEATH({ auto _ = make_tensor_ptr({2, 5}, std::move(data)); }, "");
797811
}
798812

799813
TEST_F(TensorPtrTest, StridesAndDimOrderMustMatchSizes) {

0 commit comments

Comments
 (0)