Skip to content

Commit 48e4d29

Browse files
authored
Add strides to ManagedTensor
Differential Revision: D61509079 Pull Request resolved: #4786
1 parent 75e6413 commit 48e4d29

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

extension/runner_util/managed_tensor.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,23 @@ class ManagedTensor {
4848
#ifdef USE_ATEN_LIB
4949
tensor_ = torch::from_blob(data, sizes, dtype);
5050
#else
51+
// Calculate strides.
52+
strides_ = std::vector<StridesType>(sizes_.size());
53+
if (sizes_.size() > 0) {
54+
strides_.back() = 1;
55+
for (size_t i = strides_.size() - 1; i > 0; --i) {
56+
strides_[i - 1] = strides_[i] * sizes_[i];
57+
}
58+
}
59+
60+
// Allocate TensorImpl.
5161
tensor_impl_ = std::make_unique<TensorImpl>(
5262
dtype,
5363
sizes_.size(),
5464
sizes_.data(),
5565
data,
56-
nullptr,
57-
nullptr,
66+
/*dim_order=*/nullptr,
67+
strides_.data(),
5868
TensorShapeDynamism::DYNAMIC_BOUND);
5969
#endif
6070
}
@@ -80,6 +90,7 @@ class ManagedTensor {
8090
private:
8191
std::unique_ptr<TensorImpl> tensor_impl_;
8292
std::vector<SizesType> sizes_;
93+
std::vector<StridesType> strides_;
8394
#ifdef USE_ATEN_LIB
8495
Tensor tensor_;
8596
#endif

extension/runner_util/test/managed_tensor_test.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,17 @@ class ManagedTensorTest : public ::testing::Test {
2525
void SetUp() override {
2626
torch::executor::runtime_init();
2727

28-
data_ = {1, 2, 3, 4, 5, 6};
29-
sizes_ = {2, 3};
28+
data_ = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
29+
sizes_ = {2, 3, 4};
30+
expected_strides_ = {12, 4, 1};
3031
managed_tensor_ =
3132
std::make_unique<ManagedTensor>(data_.data(), sizes_, ScalarType::Long);
3233
}
3334

3435
protected:
3536
std::vector<int64_t> data_;
3637
std::vector<SizesType> sizes_;
38+
std::vector<int> expected_strides_;
3739
std::unique_ptr<ManagedTensor> managed_tensor_;
3840
};
3941

@@ -43,24 +45,27 @@ TEST_F(ManagedTensorTest, Smoke) {
4345
EXPECT_EQ(tensor.sizes(), ArrayRef<SizesType>(sizes_.data(), sizes_.size()));
4446
EXPECT_EQ(tensor.scalar_type(), ScalarType::Long);
4547
EXPECT_EQ(tensor.const_data_ptr(), data_.data());
48+
for (size_t i = 0; i < expected_strides_.size(); ++i) {
49+
EXPECT_EQ(tensor.strides()[i], expected_strides_[i]);
50+
}
4651
}
4752

4853
TEST_F(ManagedTensorTest, ResizeWithUpdatedRank) {
4954
// gtest death test doesn't work on iOS:
5055
// https://github.com/google/googletest/issues/2834
5156
#if !GTEST_OS_IOS
5257
EXPECT_EXIT(
53-
managed_tensor_->resize(std::vector<SizesType>{2, 3, 4}),
58+
managed_tensor_->resize(std::vector<SizesType>{2, 3, 4, 5}),
5459
::testing::KilledBySignal(SIGABRT),
5560
"");
5661
#endif
5762
}
5863

5964
TEST_F(ManagedTensorTest, ResizeShrink) {
60-
managed_tensor_->resize(std::vector<SizesType>{2, 2});
65+
managed_tensor_->resize(std::vector<SizesType>{2, 2, 2});
6166
const auto tensor = managed_tensor_->get_aliasing_tensor();
6267

63-
std::vector<SizesType> expected_sizes = {2, 2};
68+
std::vector<SizesType> expected_sizes = {2, 2, 2};
6469
EXPECT_EQ(
6570
tensor.sizes(),
6671
ArrayRef<SizesType>(expected_sizes.data(), expected_sizes.size()));
@@ -69,10 +74,10 @@ TEST_F(ManagedTensorTest, ResizeShrink) {
6974
}
7075

7176
TEST_F(ManagedTensorTest, Resize) {
72-
managed_tensor_->resize(std::vector<SizesType>{3, 2});
77+
managed_tensor_->resize(std::vector<SizesType>{4, 3, 2});
7378
const auto tensor = managed_tensor_->get_aliasing_tensor();
7479

75-
std::vector<SizesType> expected_sizes = {3, 2};
80+
std::vector<SizesType> expected_sizes = {4, 3, 2};
7681
EXPECT_EQ(
7782
tensor.sizes(),
7883
ArrayRef<SizesType>(expected_sizes.data(), expected_sizes.size()));

0 commit comments

Comments
 (0)