Skip to content

Commit a0e054b

Browse files
Tensor view keeps original tensor alive. (#15063)
Summary: TensorPtr view created with TensorPtr should keep it alive to match ATen behavior. Differential Revision: D84512176 Co-authored-by: Anthony Shoumikhin <[email protected]>
1 parent 491654b commit a0e054b

File tree

3 files changed

+139
-16
lines changed

3 files changed

+139
-16
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ - (instancetype)initWithNativeInstance:(void *)nativeInstance {
129129
- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor {
130130
ET_CHECK(otherTensor);
131131
auto tensor = make_tensor_ptr(
132-
**reinterpret_cast<TensorPtr *>(otherTensor.nativeInstance)
132+
*reinterpret_cast<TensorPtr *>(otherTensor.nativeInstance)
133133
);
134134
return [self initWithNativeInstance:&tensor];
135135
}

extension/tensor/tensor_ptr.h

Lines changed: 70 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -327,29 +327,84 @@ inline TensorPtr make_tensor_ptr(
327327
* Creates a TensorPtr to manage a new Tensor with the same properties
328328
* as the given Tensor, sharing the same data without owning it.
329329
*
330-
* @param tensor The Tensor whose properties are used to create a new TensorPtr.
331-
* @return A new TensorPtr managing a Tensor with the same properties as the
332-
* original.
330+
* If an override is provided (non-empty), it is passed as-is. If an override is
331+
* empty, the corresponding metadata is reused from the source tensor when it
332+
* fits; otherwise it is left empty for the core factory to derive a valid
333+
* configuration. If `dim_order` is empty but `strides` is provided, `dim_order`
334+
* is left empty so the core may infer it from the provided strides.
335+
*
336+
* @param tensor The source tensor to alias.
337+
* @param sizes Optional sizes override.
338+
* @param dim_order Optional dimension order override.
339+
* @param strides Optional strides override.
340+
* @param deleter A custom deleter function for managing the lifetime of the
341+
* original Tensor.
342+
* @return A TensorPtr aliasing the same storage with requested metadata.
333343
*/
334-
inline TensorPtr make_tensor_ptr(const executorch::aten::Tensor& tensor) {
344+
inline TensorPtr make_tensor_ptr(
345+
const executorch::aten::Tensor& tensor,
346+
std::vector<executorch::aten::SizesType> sizes = {},
347+
std::vector<executorch::aten::DimOrderType> dim_order = {},
348+
std::vector<executorch::aten::StridesType> strides = {},
349+
std::function<void(void*)> deleter = nullptr) {
350+
if (sizes.empty()) {
351+
sizes.assign(tensor.sizes().begin(), tensor.sizes().end());
352+
}
353+
const auto same_rank = sizes.size() == static_cast<size_t>(tensor.dim());
354+
const auto same_shape = same_rank &&
355+
std::equal(sizes.begin(), sizes.end(), tensor.sizes().begin());
356+
const auto element_count =
357+
executorch::aten::compute_numel(sizes.data(), sizes.size());
358+
const auto parent_element_count = tensor.numel();
359+
ET_CHECK_MSG(
360+
element_count <= parent_element_count,
361+
"Requested view has %zd elements, but source tensor only has %zd.",
362+
static_cast<ssize_t>(element_count),
363+
static_cast<ssize_t>(parent_element_count));
364+
#ifndef USE_ATEN_LIB
365+
if (dim_order.empty() && strides.empty() && same_rank) {
366+
dim_order.assign(tensor.dim_order().begin(), tensor.dim_order().end());
367+
}
368+
#endif // USE_ATEN_LIB
369+
if (strides.empty() && dim_order.empty() && same_shape) {
370+
strides.assign(tensor.strides().begin(), tensor.strides().end());
371+
}
335372
return make_tensor_ptr(
336373
std::vector<executorch::aten::SizesType>(
337374
tensor.sizes().begin(), tensor.sizes().end()),
338375
tensor.mutable_data_ptr(),
339-
#ifndef USE_ATEN_LIB
340-
std::vector<executorch::aten::DimOrderType>(
341-
tensor.dim_order().begin(), tensor.dim_order().end()),
342-
std::vector<executorch::aten::StridesType>(
343-
tensor.strides().begin(), tensor.strides().end()),
376+
std::move(dim_order),
377+
std::move(strides),
344378
tensor.scalar_type(),
345-
tensor.shape_dynamism()
379+
#ifndef USE_ATEN_LIB
380+
tensor.shape_dynamism(),
346381
#else // USE_ATEN_LIB
347-
{},
348-
std::vector<executorch::aten::StridesType>(
349-
tensor.strides().begin(), tensor.strides().end()),
350-
tensor.scalar_type()
382+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
351383
#endif // USE_ATEN_LIB
352-
);
384+
std::move(deleter));
385+
}
386+
387+
/**
388+
* Convenience overload identical to make_tensor_ptr(*tensor_ptr, ...).
389+
* Keeps the original TensorPtr alive until the returned TensorPtr is destroyed.
390+
*
391+
* @param tensor_ptr The source tensor pointer to alias.
392+
* @param sizes Optional sizes override.
393+
* @param dim_order Optional dimension order override.
394+
* @param strides Optional strides override.
395+
* @return A TensorPtr aliasing the same storage with requested metadata.
396+
*/
397+
inline TensorPtr make_tensor_ptr(
398+
const TensorPtr& tensor_ptr,
399+
std::vector<executorch::aten::SizesType> sizes = {},
400+
std::vector<executorch::aten::DimOrderType> dim_order = {},
401+
std::vector<executorch::aten::StridesType> strides = {}) {
402+
return make_tensor_ptr(
403+
*tensor_ptr,
404+
std::move(sizes),
405+
std::move(dim_order),
406+
std::move(strides),
407+
[tensor_ptr](void*) {});
353408
}
354409

355410
/**

extension/tensor/test/tensor_ptr_test.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,74 @@ TEST_F(TensorPtrTest, TensorUint8BufferTooLargeExpectDeath) {
790790
ET_EXPECT_DEATH({ auto _ = make_tensor_ptr({2, 2}, std::move(data)); }, "");
791791
}
792792

793+
TEST_F(TensorPtrTest, MakeViewFromTensorPtrKeepsSourceAlive) {
794+
bool freed = false;
795+
auto* data = new float[6]{1, 2, 3, 4, 5, 6};
796+
auto tensor = make_tensor_ptr(
797+
{2, 3},
798+
data,
799+
{},
800+
{},
801+
executorch::aten::ScalarType::Float,
802+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
803+
[&freed](void* p) {
804+
freed = true;
805+
delete[] static_cast<float*>(p);
806+
});
807+
auto view = make_tensor_ptr(tensor);
808+
tensor.reset();
809+
EXPECT_FALSE(freed);
810+
EXPECT_EQ(view->const_data_ptr<float>()[0], 1.0f);
811+
view->mutable_data_ptr<float>()[0] = 42.0f;
812+
EXPECT_EQ(view->const_data_ptr<float>()[0], 42.0f);
813+
view.reset();
814+
EXPECT_TRUE(freed);
815+
}
816+
817+
TEST_F(TensorPtrTest, MakeViewFromTensorDoesNotKeepAliveByDefault) {
818+
bool freed = false;
819+
auto* data = new float[2]{7.0f, 8.0f};
820+
auto tensor = make_tensor_ptr(
821+
{2, 1},
822+
data,
823+
{},
824+
{},
825+
executorch::aten::ScalarType::Float,
826+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
827+
[&freed](void* p) {
828+
freed = true;
829+
delete[] static_cast<float*>(p);
830+
});
831+
auto view = make_tensor_ptr(*tensor);
832+
auto raw = view->const_data_ptr<float>();
833+
EXPECT_EQ(raw, data);
834+
tensor.reset();
835+
EXPECT_TRUE(freed);
836+
view.reset();
837+
}
838+
839+
TEST_F(TensorPtrTest, MakeViewFromTensorWithDeleterKeepsAlive) {
840+
bool freed = false;
841+
auto* data = new float[3]{1.0f, 2.0f, 3.0f};
842+
auto tensor = make_tensor_ptr(
843+
{3},
844+
data,
845+
{},
846+
{},
847+
executorch::aten::ScalarType::Float,
848+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
849+
[&freed](void* p) {
850+
freed = true;
851+
delete[] static_cast<float*>(p);
852+
});
853+
auto view = make_tensor_ptr(*tensor, {}, {}, {}, [tensor](void*) {});
854+
tensor.reset();
855+
EXPECT_FALSE(freed);
856+
EXPECT_EQ(view->const_data_ptr<float>()[2], 3.0f);
857+
view.reset();
858+
EXPECT_TRUE(freed);
859+
}
860+
793861
TEST_F(TensorPtrTest, VectorFloatTooSmallExpectDeath) {
794862
std::vector<float> data(9, 1.f);
795863
ET_EXPECT_DEATH({ auto _ = make_tensor_ptr({2, 5}, std::move(data)); }, "");

0 commit comments

Comments
 (0)