diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm index 3a2b640b7d7..98fec979afd 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm @@ -129,7 +129,7 @@ - (instancetype)initWithNativeInstance:(void *)nativeInstance { - (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor { ET_CHECK(otherTensor); auto tensor = make_tensor_ptr( - **reinterpret_cast(otherTensor.nativeInstance) + *reinterpret_cast(otherTensor.nativeInstance) ); return [self initWithNativeInstance:&tensor]; } diff --git a/extension/tensor/tensor_ptr.h b/extension/tensor/tensor_ptr.h index 900252109d3..0649ae971f7 100644 --- a/extension/tensor/tensor_ptr.h +++ b/extension/tensor/tensor_ptr.h @@ -338,13 +338,16 @@ inline TensorPtr make_tensor_ptr( * @param sizes Optional sizes override. * @param dim_order Optional dimension order override. * @param strides Optional strides override. + * @param deleter A custom deleter function for managing the lifetime of the + * original Tensor. * @return A TensorPtr aliasing the same storage with requested metadata. */ inline TensorPtr make_tensor_ptr( const executorch::aten::Tensor& tensor, std::vector sizes = {}, std::vector dim_order = {}, - std::vector strides = {}) { + std::vector strides = {}, + std::function deleter = nullptr) { if (sizes.empty()) { sizes.assign(tensor.sizes().begin(), tensor.sizes().end()); } @@ -372,16 +375,18 @@ inline TensorPtr make_tensor_ptr( tensor.mutable_data_ptr(), std::move(dim_order), std::move(strides), - tensor.scalar_type() + tensor.scalar_type(), #ifndef USE_ATEN_LIB - , - tensor.shape_dynamism() + tensor.shape_dynamism(), +#else // USE_ATEN_LIB + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND, #endif // USE_ATEN_LIB - ); + std::move(deleter)); } /** * Convenience overload identical to make_tensor_ptr(*tensor_ptr, ...). + * Keeps the original TensorPtr alive until the returned TensorPtr is destroyed. * * @param tensor_ptr The source tensor pointer to alias. * @param sizes Optional sizes override. @@ -395,7 +400,11 @@ inline TensorPtr make_tensor_ptr( std::vector dim_order = {}, std::vector strides = {}) { return make_tensor_ptr( - *tensor_ptr, std::move(sizes), std::move(dim_order), std::move(strides)); + *tensor_ptr, + std::move(sizes), + std::move(dim_order), + std::move(strides), + [tensor_ptr](void*) {}); } /** diff --git a/extension/tensor/test/tensor_ptr_test.cpp b/extension/tensor/test/tensor_ptr_test.cpp index 9156a0c4b10..a693354269f 100644 --- a/extension/tensor/test/tensor_ptr_test.cpp +++ b/extension/tensor/test/tensor_ptr_test.cpp @@ -1038,6 +1038,74 @@ TEST_F(TensorPtrTest, TensorUint8dataTooLargeExpectDeath) { ET_EXPECT_DEATH({ auto _ = make_tensor_ptr({2, 2}, std::move(data)); }, ""); } +TEST_F(TensorPtrTest, MakeViewFromTensorPtrKeepsSourceAlive) { + bool freed = false; + auto* data = new float[6]{1, 2, 3, 4, 5, 6}; + auto tensor = make_tensor_ptr( + {2, 3}, + data, + {}, + {}, + executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND, + [&freed](void* p) { + freed = true; + delete[] static_cast(p); + }); + auto view = make_tensor_ptr(tensor); + tensor.reset(); + EXPECT_FALSE(freed); + EXPECT_EQ(view->const_data_ptr()[0], 1.0f); + view->mutable_data_ptr()[0] = 42.0f; + EXPECT_EQ(view->const_data_ptr()[0], 42.0f); + view.reset(); + EXPECT_TRUE(freed); +} + +TEST_F(TensorPtrTest, MakeViewFromTensorDoesNotKeepAliveByDefault) { + bool freed = false; + auto* data = new float[2]{7.0f, 8.0f}; + auto tensor = make_tensor_ptr( + {2, 1}, + data, + {}, + {}, + executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND, + [&freed](void* p) { + freed = true; + delete[] static_cast(p); + }); + auto view = make_tensor_ptr(*tensor); + auto raw = view->const_data_ptr(); + EXPECT_EQ(raw, data); + tensor.reset(); + EXPECT_TRUE(freed); + view.reset(); +} + +TEST_F(TensorPtrTest, MakeViewFromTensorWithDeleterKeepsAlive) { + bool freed = false; + auto* data = new float[3]{1.0f, 2.0f, 3.0f}; + auto tensor = make_tensor_ptr( + {3}, + data, + {}, + {}, + executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND, + [&freed](void* p) { + freed = true; + delete[] static_cast(p); + }); + auto view = make_tensor_ptr(*tensor, {}, {}, {}, [tensor](void*) {}); + tensor.reset(); + EXPECT_FALSE(freed); + EXPECT_EQ(view->const_data_ptr()[2], 3.0f); + view.reset(); + EXPECT_TRUE(freed); +} + TEST_F(TensorPtrTest, VectorFloatTooSmallExpectDeath) { std::vector data(9, 1.f); ET_EXPECT_DEATH({ auto _ = make_tensor_ptr({2, 5}, std::move(data)); }, "");