Skip to content

Commit dd43946

Browse files
authored
update the validity rule for regenerated stride from dim order
Differential Revision: D77759383 Pull Request resolved: #12280
1 parent f2f2a9d commit dd43946

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

extension/tensor/tensor_ptr.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,27 @@ TensorPtr make_tensor_ptr(
8080
}
8181
}
8282
std::vector<executorch::aten::StridesType> computed_strides(dim);
83+
8384
auto error = runtime::dim_order_to_stride(
8485
sizes.data(), dim_order.data(), dim, computed_strides.data());
8586
ET_CHECK_MSG(error == runtime::Error::Ok, "Failed to compute strides.");
8687

8788
if (!strides.empty()) {
88-
ET_CHECK_MSG(computed_strides == strides, "Invalid strides provided.");
89-
} else {
90-
strides = std::move(computed_strides);
89+
for (size_t i = 0; i < dim; i++) {
90+
ET_CHECK_MSG(
91+
strides[i] == computed_strides[i] || sizes[i] == 1,
92+
"invalid strides for dim %zu: %" ET_PRI_SIZES_AND_STRIDES
93+
"!= %" ET_PRI_SIZES_AND_STRIDES
94+
" while its size is %" ET_PRI_SIZES_AND_STRIDES " != 1",
95+
i,
96+
strides[i],
97+
computed_strides[i],
98+
sizes[i]);
99+
}
91100
}
101+
102+
strides = std::move(computed_strides);
103+
92104
#ifndef USE_ATEN_LIB
93105
executorch::aten::TensorImpl tensor_impl(
94106
type,

extension/tensor/test/tensor_ptr_maker_test.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <gtest/gtest.h>
1212

1313
#include <executorch/runtime/platform/runtime.h>
14+
#include <executorch/test/utils/DeathTest.h>
1415

1516
using namespace ::executorch::extension;
1617
using namespace ::executorch::runtime;
@@ -113,6 +114,31 @@ TEST_F(TensorPtrMakerTest, CreateTensorUsingFromBlobWithStrides) {
113114
EXPECT_EQ(tensor->const_data_ptr<float>()[0], 3);
114115
}
115116

117+
TEST_F(TensorPtrMakerTest, CreateTensorUsingFromBlobWithLegalStrides) {
118+
float data[20] = {3};
119+
auto tensor = from_blob(data, {1, 2, 2}, {10, 2, 1});
120+
121+
EXPECT_EQ(tensor->dim(), 3);
122+
EXPECT_EQ(tensor->size(0), 1);
123+
EXPECT_EQ(tensor->size(1), 2);
124+
EXPECT_EQ(tensor->size(2), 2);
125+
126+
// recalculated stride[0]t o 2 to meet ET's requirement while maintain the
127+
// same behavior as original tensor since size[0] == 1
128+
EXPECT_EQ(tensor->strides()[0], 4);
129+
EXPECT_EQ(tensor->strides()[1], 2);
130+
EXPECT_EQ(tensor->strides()[2], 1);
131+
EXPECT_EQ(tensor->const_data_ptr<float>(), data);
132+
EXPECT_EQ(tensor->const_data_ptr<float>()[0], 3);
133+
}
134+
135+
TEST_F(TensorPtrMakerTest, FailedCreateTensorUsingFromBlobWithIllegalStrides) {
136+
float data[20] = {3};
137+
ET_EXPECT_DEATH(
138+
from_blob(data, {2, 2, 2}, {10, 2, 1}),
139+
"invalid strides for dim 0: 10!= 4 while its size is 2 != 1");
140+
}
141+
116142
TEST_F(TensorPtrMakerTest, TensorMakerConversionOperator) {
117143
float data[20] = {2};
118144
TensorPtr tensor =

0 commit comments

Comments
 (0)