diff --git a/CMakeLists.txt b/CMakeLists.txt index 34fed923529..020cd2cb2f0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -727,7 +727,7 @@ if(EXECUTORCH_BUILD_PYBIND) util PUBLIC ${_common_include_directories} ${TORCH_INCLUDE_DIRS} ) target_compile_options(util PUBLIC ${_pybind_compile_options}) - target_link_libraries(util PRIVATE torch c10 executorch) + target_link_libraries(util PRIVATE torch c10 executorch extension_tensor) # pybind portable_lib pybind11_add_module(portable_lib SHARED extension/pybindings/pybindings.cpp) diff --git a/extension/aten_util/aten_bridge.cpp b/extension/aten_util/aten_bridge.cpp index fc167dd71e8..1305ae75ce0 100644 --- a/extension/aten_util/aten_bridge.cpp +++ b/extension/aten_util/aten_bridge.cpp @@ -170,5 +170,12 @@ at::Tensor alias_attensor_to_etensor(const torch::executor::Tensor& etensor) { return t; } +TensorPtr alias_tensor_ptr_to_attensor(at::Tensor& t) { + return make_tensor_ptr( + {t.sizes().begin(), t.sizes().end()}, + t.mutable_data_ptr(), + torch::executor::ScalarType(t.scalar_type())); +} + } // namespace extension } // namespace executorch diff --git a/extension/aten_util/aten_bridge.h b/extension/aten_util/aten_bridge.h index 0d6b697463c..62b07eee51d 100644 --- a/extension/aten_util/aten_bridge.h +++ b/extension/aten_util/aten_bridge.h @@ -8,6 +8,7 @@ #pragma once +#include #include #include // @manual=//caffe2/aten:ATen-cpu @@ -48,6 +49,8 @@ void alias_etensor_to_attensor(at::Tensor& at, torch::executor::Tensor& et); */ at::Tensor alias_attensor_to_etensor(const torch::executor::Tensor& et); +TensorPtr alias_tensor_ptr_to_attensor(at::Tensor& t); + } // namespace extension } // namespace executorch diff --git a/extension/aten_util/test/aten_bridge_test.cpp b/extension/aten_util/test/aten_bridge_test.cpp index cf6d2b85978..ba331162fca 100644 --- a/extension/aten_util/test/aten_bridge_test.cpp +++ b/extension/aten_util/test/aten_bridge_test.cpp @@ -18,6 +18,7 @@ using namespace ::testing; using namespace torch::executor; using namespace torch::executor::util; +using namespace executorch::extension; namespace { at::Tensor generate_at_tensor() { @@ -146,3 +147,10 @@ TEST(ATenBridgeTest, AliasATTensorToETensor) { auto aliased_at_tensor = alias_attensor_to_etensor(etensor); EXPECT_EQ(aliased_at_tensor.const_data_ptr(), etensor_data.data()); } + +TEST(ATenBridgeTest, AliasTensorPtrToATenTensor) { + auto at_tensor = generate_at_tensor(); + const auto& et_tensor_ptr = alias_tensor_ptr_to_attensor(at_tensor); + alias_etensor_to_attensor(at_tensor, *et_tensor_ptr); + EXPECT_EQ(at_tensor.const_data_ptr(), et_tensor_ptr->const_data_ptr()); +}