1010
1111#include < executorch/extension/tensor/tensor_impl_ptr.h>
1212#include < executorch/runtime/core/error.h>
13+ #include < executorch/runtime/platform/assert.h>
1314
1415namespace executorch {
1516namespace extension {
1617
1718#ifndef USE_ATEN_LIB
18- namespace internal {
19+
1920/* *
20- * Custom deleter for TensorPtr that ensures proper management of the associated
21- * TensorImplPtr.
22- *
23- * Since Tensor does not own its TensorImpl, this deleter manages the
24- * TensorImplPtr lifecycle, ensuring dynamic metadata (sizes, dim_order,
25- * strides) is released appropriately when the Tensor is destroyed.
21+ * A smart pointer to a Tensor that owns and reference-counts its
22+ * underlying TensorImpl, like torch::Tensor.
2623 */
27- struct TensorPtrDeleter final {
28- TensorImplPtr tensor_impl;
29-
30- void operator ()(exec_aten::Tensor* pointer) {
31- // Release all resources immediately since the data held by the
32- // TensorPtrDeleter is tied to the managed object, not the smart pointer
33- // itself. We need to free this memory when the object is destroyed, not
34- // when the smart pointer (and deleter) are eventually destroyed or reset.
35- tensor_impl.reset ();
36- delete pointer;
24+ class TensorPtr {
25+ public:
26+ constexpr TensorPtr () = default;
27+ explicit constexpr TensorPtr (std::nullptr_t ) {}
28+ ~TensorPtr () = default ;
29+ TensorPtr (TensorPtr&& rhs) noexcept = default ;
30+ TensorPtr& operator =(TensorPtr&& rhs) noexcept = default ;
31+
32+ explicit TensorPtr (TensorImplPtr p)
33+ : tensor_(p.get()), tensor_impl_(std::move(p)) {}
34+
35+ operator bool () const {
36+ return static_cast <bool >(tensor_impl_);
3737 }
38- };
39- } // namespace internal
4038
41- /* *
42- * A smart pointer for managing the lifecycle of a Tensor.
43- *
44- * TensorPtr uses a unique pointer to ensure each Tensor object has distinct
45- * ownership. This abstraction simplifies memory management and serves as a
46- * safer alternative to the standard Tensor, which does not manage its metadata
47- * by design. It ensures that the underlying TensorImpl can be safely shared
48- * among tensors as needed.
49- */
50- using TensorPtr =
51- std::unique_ptr<exec_aten::Tensor, internal::TensorPtrDeleter>;
39+ exec_aten::Tensor* get () const {
40+ return tensor_impl_ ? &tensor_ : nullptr ;
41+ }
42+
43+ exec_aten::Tensor* operator ->() const {
44+ return get ();
45+ }
46+
47+ exec_aten::Tensor& operator *() const {
48+ ET_DCHECK (*this != nullptr );
49+ return *get ();
50+ }
51+
52+ void reset () {
53+ tensor_ = exec_aten::Tensor (nullptr );
54+ tensor_impl_.reset ();
55+ }
56+
57+ void swap (TensorPtr& other) noexcept {
58+ std::swap (tensor_, other.tensor_ );
59+ std::swap (tensor_impl_, other.tensor_impl_ );
60+ }
61+
62+ bool operator ==(const TensorPtr& rhs) const {
63+ ET_DCHECK (
64+ (tensor_.unsafeGetTensorImpl () == rhs.tensor_ .unsafeGetTensorImpl ()) ==
65+ (tensor_impl_ == rhs.tensor_impl_ ));
66+ return tensor_impl_ == rhs.tensor_impl_ ;
67+ }
68+
69+ bool operator !=(const TensorPtr& rhs) const {
70+ return !(*this == rhs);
71+ }
72+
73+ bool operator ==(std::nullptr_t ) const {
74+ return !operator bool ();
75+ }
76+
77+ bool operator !=(std::nullptr_t ) const {
78+ return !(*this == nullptr );
79+ }
80+
81+ private:
82+ friend TensorPtr make_tensor_ptr (const TensorPtr& tensor);
83+ mutable exec_aten::Tensor tensor_{nullptr };
84+ TensorImplPtr tensor_impl_;
85+ };
5286#else
5387/* *
5488 * A smart pointer type for managing the lifecycle of a Tensor.
@@ -74,9 +108,7 @@ using TensorPtr = std::unique_ptr<exec_aten::Tensor>;
74108 */
75109inline TensorPtr make_tensor_ptr (TensorImplPtr tensor_impl) {
76110#ifndef USE_ATEN_LIB
77- auto tensor = std::make_unique<exec_aten::Tensor>(tensor_impl.get ());
78- return TensorPtr (
79- tensor.release (), internal::TensorPtrDeleter{std::move (tensor_impl)});
111+ return TensorPtr (std::move (tensor_impl));
80112#else
81113 return std::make_unique<exec_aten::Tensor>(std::move (tensor_impl));
82114#endif // USE_ATEN_LIB
@@ -96,7 +128,7 @@ inline TensorPtr make_tensor_ptr(TensorImplPtr tensor_impl) {
96128 */
97129inline TensorPtr make_tensor_ptr (const TensorPtr& tensor) {
98130#ifndef USE_ATEN_LIB
99- return make_tensor_ptr (tensor.get_deleter (). tensor_impl );
131+ return make_tensor_ptr (tensor.tensor_impl_ );
100132#else
101133 return make_tensor_ptr (tensor->getIntrusivePtr ());
102134#endif // USE_ATEN_LIB
0 commit comments