Skip to content

Commit f65c28e

Browse files
authored
Handle null data edge case in data_is_close testing util.
Differential Revision: D61783890 Pull Request resolved: #4901
1 parent 3fb03dc commit f65c28e

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

runtime/core/exec_aten/testing_util/tensor_util.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ bool data_is_close(
4141
size_t numel,
4242
double rtol,
4343
double atol) {
44+
ET_CHECK_MSG(
45+
numel == 0 || (a != nullptr && b != nullptr),
46+
"Pointers must not be null when numel > 0: numel %zu, a 0x%p, b 0x%p",
47+
numel,
48+
a,
49+
b);
50+
if (a == b) {
51+
return true;
52+
}
4453
for (size_t i = 0; i < numel; i++) {
4554
const auto ai = a[i];
4655
const auto bi = b[i];

runtime/core/exec_aten/testing_util/test/tensor_util_test.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
using namespace ::testing;
2424
using exec_aten::ScalarType;
2525
using exec_aten::Tensor;
26+
using exec_aten::TensorImpl;
2627
using exec_aten::TensorList;
2728
using executorch::runtime::testing::IsCloseTo;
2829
using executorch::runtime::testing::IsDataCloseTo;
@@ -826,4 +827,22 @@ TEST(TensorUtilTest, TensorStreamBool) {
826827
"ETensor(sizes={2, 2}, dtype=Bool, data={1, 0, 1, 0})");
827828
}
828829

830+
TEST(TensorTest, TestZeroShapeTensorEquality) {
831+
TensorImpl::SizesType sizes[2] = {2, 2};
832+
TensorImpl::StridesType strides[2] = {2, 1};
833+
TensorImpl::DimOrderType dim_order[2] = {0, 1};
834+
835+
TensorImpl t1(ScalarType::Float, 2, sizes, nullptr, dim_order, strides);
836+
TensorImpl t2(ScalarType::Float, 2, sizes, nullptr, dim_order, strides);
837+
838+
ET_EXPECT_DEATH({ EXPECT_TENSOR_EQ(Tensor(&t1), Tensor(&t2)); }, "");
839+
840+
float data[] = {1.0, 2.0, 3.0, 4.0};
841+
842+
t1.set_data(data);
843+
t2.set_data(data);
844+
845+
EXPECT_TENSOR_EQ(Tensor(&t1), Tensor(&t2));
846+
}
847+
829848
#endif // !USE_ATEN_LIB

0 commit comments

Comments
 (0)