
Commit 53936dc

swolchok authored and facebook-github-bot committed
Store the Tensor inline in TensorPtr (#5684)
Summary:
Pull Request resolved: #5684

We can preserve the existing interface (except release(), which is
problematic anyway!) and avoid an unnecessary heap allocation.

ghstack-source-id: 244967796
Reviewed By: shoumikhin
Differential Revision: D63468988
fbshipit-source-id: a0905fc8afa0970624f53109b4f92d791bec1d15
1 parent e6237f7 commit 53936dc
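
In outline (a schematic condensation of the tensor_ptr.h diff below, not any additional API), the change replaces a unique_ptr alias, whose custom deleter required heap-allocating each Tensor, with a small handle class that stores the Tensor by value alongside the reference-counted TensorImplPtr:

// Before: the managed Tensor is heap-allocated, and TensorPtrDeleter keeps
// the TensorImplPtr alive until that Tensor is deleted.
using TensorPtr =
    std::unique_ptr<exec_aten::Tensor, internal::TensorPtrDeleter>;

// After: the Tensor lives inline in the handle itself, so constructing a
// TensorPtr no longer allocates a separate Tensor object on the heap.
class TensorPtr {
  // ... smart-pointer interface (get, operator->, operator*, reset, swap) ...
 private:
  mutable exec_aten::Tensor tensor_{nullptr};
  TensorImplPtr tensor_impl_;
};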

3 files changed: +150, -35 lines

extension/tensor/tensor_ptr.h

Lines changed: 66 additions & 34 deletions
@@ -10,45 +10,79 @@
 
 #include <executorch/extension/tensor/tensor_impl_ptr.h>
 #include <executorch/runtime/core/error.h>
+#include <executorch/runtime/platform/assert.h>
 
 namespace executorch {
 namespace extension {
 
 #ifndef USE_ATEN_LIB
-namespace internal {
+
 /**
- * Custom deleter for TensorPtr that ensures proper management of the associated
- * TensorImplPtr.
- *
- * Since Tensor does not own its TensorImpl, this deleter manages the
- * TensorImplPtr lifecycle, ensuring dynamic metadata (sizes, dim_order,
- * strides) is released appropriately when the Tensor is destroyed.
+ * A smart pointer to a Tensor that owns and reference-counts its
+ * underlying TensorImpl, like torch::Tensor.
  */
-struct TensorPtrDeleter final {
-  TensorImplPtr tensor_impl;
-
-  void operator()(exec_aten::Tensor* pointer) {
-    // Release all resources immediately since the data held by the
-    // TensorPtrDeleter is tied to the managed object, not the smart pointer
-    // itself. We need to free this memory when the object is destroyed, not
-    // when the smart pointer (and deleter) are eventually destroyed or reset.
-    tensor_impl.reset();
-    delete pointer;
+class TensorPtr {
+ public:
+  constexpr TensorPtr() = default;
+  explicit constexpr TensorPtr(std::nullptr_t) {}
+  ~TensorPtr() = default;
+  TensorPtr(TensorPtr&& rhs) noexcept = default;
+  TensorPtr& operator=(TensorPtr&& rhs) noexcept = default;
+
+  explicit TensorPtr(TensorImplPtr p)
+      : tensor_(p.get()), tensor_impl_(std::move(p)) {}
+
+  operator bool() const {
+    return static_cast<bool>(tensor_impl_);
   }
-};
-} // namespace internal
 
-/**
- * A smart pointer for managing the lifecycle of a Tensor.
- *
- * TensorPtr uses a unique pointer to ensure each Tensor object has distinct
- * ownership. This abstraction simplifies memory management and serves as a
- * safer alternative to the standard Tensor, which does not manage its metadata
- * by design. It ensures that the underlying TensorImpl can be safely shared
- * among tensors as needed.
- */
-using TensorPtr =
-    std::unique_ptr<exec_aten::Tensor, internal::TensorPtrDeleter>;
+  exec_aten::Tensor* get() const {
+    return tensor_impl_ ? &tensor_ : nullptr;
+  }
+
+  exec_aten::Tensor* operator->() const {
+    return get();
+  }
+
+  exec_aten::Tensor& operator*() const {
+    ET_DCHECK(*this != nullptr);
+    return *get();
+  }
+
+  void reset() {
+    tensor_ = exec_aten::Tensor(nullptr);
+    tensor_impl_.reset();
+  }
+
+  void swap(TensorPtr& other) noexcept {
+    std::swap(tensor_, other.tensor_);
+    std::swap(tensor_impl_, other.tensor_impl_);
+  }
+
+  bool operator==(const TensorPtr& rhs) const {
+    ET_DCHECK(
+        (tensor_.unsafeGetTensorImpl() == rhs.tensor_.unsafeGetTensorImpl()) ==
+        (tensor_impl_ == rhs.tensor_impl_));
+    return tensor_impl_ == rhs.tensor_impl_;
+  }
+
+  bool operator!=(const TensorPtr& rhs) const {
+    return !(*this == rhs);
+  }
+
+  bool operator==(std::nullptr_t) const {
+    return !operator bool();
+  }
+
+  bool operator!=(std::nullptr_t) const {
+    return !(*this == nullptr);
+  }
+
+ private:
+  friend TensorPtr make_tensor_ptr(const TensorPtr& tensor);
+  mutable exec_aten::Tensor tensor_{nullptr};
+  TensorImplPtr tensor_impl_;
+};
 #else
 /**
  * A smart pointer type for managing the lifecycle of a Tensor.
@@ -74,9 +108,7 @@ using TensorPtr = std::unique_ptr<exec_aten::Tensor>;
  */
 inline TensorPtr make_tensor_ptr(TensorImplPtr tensor_impl) {
 #ifndef USE_ATEN_LIB
-  auto tensor = std::make_unique<exec_aten::Tensor>(tensor_impl.get());
-  return TensorPtr(
-      tensor.release(), internal::TensorPtrDeleter{std::move(tensor_impl)});
+  return TensorPtr(std::move(tensor_impl));
 #else
   return std::make_unique<exec_aten::Tensor>(std::move(tensor_impl));
 #endif // USE_ATEN_LIB
@@ -96,7 +128,7 @@ inline TensorPtr make_tensor_ptr(TensorImplPtr tensor_impl) {
  */
 inline TensorPtr make_tensor_ptr(const TensorPtr& tensor) {
 #ifndef USE_ATEN_LIB
-  return make_tensor_ptr(tensor.get_deleter().tensor_impl);
+  return make_tensor_ptr(tensor.tensor_impl_);
 #else
   return make_tensor_ptr(tensor->getIntrusivePtr());
 #endif // USE_ATEN_LIB
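
For orientation, a minimal usage sketch of the reworked TensorPtr follows. It relies only on the interface added above and on the make_tensor_ptr sizes/data/dim_order/strides factory call already exercised by the new tests in this commit, so treat it as an illustration rather than canonical documentation.

#include <executorch/extension/tensor/tensor_ptr.h>

using executorch::extension::make_tensor_ptr;
using executorch::extension::TensorPtr;

void tensor_ptr_usage_sketch() {
  // A default-constructed TensorPtr holds no TensorImpl: it converts to
  // false and compares equal to nullptr.
  TensorPtr empty;
  ET_DCHECK(empty == nullptr);

  // Same factory call as the new tests: a 1-D tensor with no data buffer.
  // The Tensor now lives inside the TensorPtr itself, so no separate heap
  // allocation is made for the Tensor object.
  TensorPtr t = make_tensor_ptr({1}, nullptr, {}, {});
  ET_DCHECK(t != nullptr);
  ET_DCHECK(t->dim() == 1); // operator-> forwards to the inline Tensor.
  ET_DCHECK((*t).dim() == 1); // operator* dereferences it.

  // swap() exchanges both the inline Tensor and the TensorImplPtr;
  // reset() clears the Tensor and drops the TensorImpl reference.
  empty.swap(t);
  ET_DCHECK(t == nullptr);
  empty.reset();
  ET_DCHECK(empty == nullptr);
}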

extension/tensor/test/tensor_ptr_test.cpp

Lines changed: 83 additions & 0 deletions
@@ -22,6 +22,89 @@ class TensorPtrTest : public ::testing::Test {
   }
 };
 
+TEST_F(TensorPtrTest, BasicSmartPointerAccess) {
+  TensorPtr p;
+  EXPECT_FALSE(p);
+  EXPECT_EQ(p, nullptr);
+  EXPECT_EQ(p.get(), nullptr);
+  EXPECT_EQ(p.operator->(), nullptr);
+  TensorPtr p2 = make_tensor_ptr({1}, nullptr, {}, {});
+  EXPECT_TRUE(p2);
+  EXPECT_NE(p2, nullptr);
+  ASSERT_NE(p2.get(), nullptr);
+  ASSERT_NE(p2.operator->(), nullptr);
+  EXPECT_EQ(p2.get(), p2.operator->());
+  EXPECT_EQ(p2->dim(), 1);
+  EXPECT_EQ((*p2).dim(), 1);
+  EXPECT_NE(p, p2);
+  p2.reset();
+  EXPECT_FALSE(p2);
+  EXPECT_EQ(p2, nullptr);
+  EXPECT_EQ(p2.get(), nullptr);
+  EXPECT_EQ(p2.operator->(), nullptr);
+  EXPECT_EQ(p, p2);
+}
+
+TEST_F(TensorPtrTest, Swap) {
+  TensorPtr p;
+  TensorPtr p2 = make_tensor_ptr({1}, nullptr, {}, {});
+  p.swap(p2);
+  EXPECT_FALSE(p2);
+  EXPECT_TRUE(p);
+  EXPECT_EQ(p->dim(), 1);
+}
+
+TEST_F(TensorPtrTest, MoveConstruction) {
+  TensorPtr empty;
+  TensorPtr emptyMoved(std::move(empty));
+  EXPECT_FALSE(empty); // NOLINT(bugprone-use-after-move)
+  EXPECT_FALSE(emptyMoved);
+
+  TensorPtr notEmpty = make_tensor_ptr({1}, nullptr, {}, {});
+  TensorPtr notEmptyMoved(std::move(notEmpty));
+  EXPECT_FALSE(notEmpty); // NOLINT(bugprone-use-after-move)
+  EXPECT_TRUE(notEmptyMoved);
+  EXPECT_EQ(notEmptyMoved->dim(), 1);
+}
+
+TEST_F(TensorPtrTest, MoveAssignment) {
+  {
+    TensorPtr empty, emptyMoved;
+
+    emptyMoved = std::move(empty);
+    EXPECT_FALSE(empty); // NOLINT(bugprone-use-after-move)
+    EXPECT_FALSE(emptyMoved);
+  }
+
+  {
+    TensorPtr empty;
+    TensorPtr emptyMoved = make_tensor_ptr({1}, nullptr, {}, {});
+    emptyMoved = std::move(empty);
+    EXPECT_FALSE(empty); // NOLINT(bugprone-use-after-move)
+    EXPECT_FALSE(emptyMoved);
+  }
+
+  {
+    TensorPtr full = make_tensor_ptr({1}, nullptr, {}, {});
+    TensorPtr fullMoved;
+
+    fullMoved = std::move(full);
+    EXPECT_FALSE(full); // NOLINT(bugprone-use-after-move)
+    EXPECT_TRUE(fullMoved);
+    EXPECT_EQ(fullMoved->dim(), 1);
+  }
+
+  {
+    TensorPtr full = make_tensor_ptr({1}, nullptr, {}, {});
+    TensorPtr fullMoved = make_tensor_ptr({2, 2}, nullptr, {}, {});
+
+    fullMoved = std::move(full);
+    EXPECT_FALSE(full); // NOLINT(bugprone-use-after-move)
+    EXPECT_TRUE(fullMoved);
+    EXPECT_EQ(fullMoved->dim(), 1);
+  }
+}
+
 TEST_F(TensorPtrTest, ScalarTensorCreation) {
   float scalar_data = 3.14f;
   auto tensor = make_tensor_ptr({}, &scalar_data);
runtime/core/portable_type/tensor.h

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ class Tensor {
   using StridesType = TensorImpl::StridesType;
 
   Tensor() = delete;
-  explicit Tensor(TensorImpl* impl) : impl_(impl) {}
+  explicit constexpr Tensor(TensorImpl* impl) : impl_(impl) {}
 
   /**
    * Returns a pointer to the underlying TensorImpl.
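
This one-line change is what allows the new TensorPtr above to declare constexpr TensorPtr() = default while brace-initializing its inline member as tensor_{nullptr}: a defaulted constructor may only be declared constexpr if every member initializer invokes a constexpr constructor. Below is a minimal, self-contained sketch of that pattern, using hypothetical stand-in types rather than the real ExecuTorch classes.

struct ImplSketch {}; // stand-in for TensorImpl

class TensorSketch { // stand-in for exec_aten::Tensor
 public:
  explicit constexpr TensorSketch(ImplSketch* impl) : impl_(impl) {}

 private:
  ImplSketch* impl_;
};

class TensorPtrSketch { // stand-in for the new TensorPtr
 public:
  // Well-formed as constexpr only because TensorSketch's pointer constructor
  // is constexpr, so tensor_{nullptr} is a constant initialization.
  constexpr TensorPtrSketch() = default;

 private:
  TensorSketch tensor_{nullptr};
};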

0 commit comments
