Skip to content

Commit 6923ae5

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
Use TensorPtr in aten_bridge (#5789)
Summary: Pull Request resolved: #5789 Want it to be very easy to create an ET Tensor from an at::Tensor if Im cool with the added deps Reviewed By: shoumikhin Differential Revision: D63705432 fbshipit-source-id: 4962be71e34392cb0e592512a6af337dc98fece7
1 parent aced6d7 commit 6923ae5

File tree

4 files changed

+19
-1
lines changed

4 files changed

+19
-1
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,7 @@ if(EXECUTORCH_BUILD_PYBIND)
727727
util PUBLIC ${_common_include_directories} ${TORCH_INCLUDE_DIRS}
728728
)
729729
target_compile_options(util PUBLIC ${_pybind_compile_options})
730-
target_link_libraries(util PRIVATE torch c10 executorch)
730+
target_link_libraries(util PRIVATE torch c10 executorch extension_tensor)
731731

732732
# pybind portable_lib
733733
pybind11_add_module(portable_lib SHARED extension/pybindings/pybindings.cpp)

extension/aten_util/aten_bridge.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,5 +170,12 @@ at::Tensor alias_attensor_to_etensor(const torch::executor::Tensor& etensor) {
170170
return t;
171171
}
172172

173+
TensorPtr alias_tensor_ptr_to_attensor(at::Tensor& t) {
174+
return make_tensor_ptr(
175+
{t.sizes().begin(), t.sizes().end()},
176+
t.mutable_data_ptr(),
177+
torch::executor::ScalarType(t.scalar_type()));
178+
}
179+
173180
} // namespace extension
174181
} // namespace executorch

extension/aten_util/aten_bridge.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#pragma once
1010

11+
#include <executorch/extension/tensor/tensor.h>
1112
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1213

1314
#include <ATen/Functions.h> // @manual=//caffe2/aten:ATen-cpu
@@ -48,6 +49,8 @@ void alias_etensor_to_attensor(at::Tensor& at, torch::executor::Tensor& et);
4849
*/
4950
at::Tensor alias_attensor_to_etensor(const torch::executor::Tensor& et);
5051

52+
TensorPtr alias_tensor_ptr_to_attensor(at::Tensor& t);
53+
5154
} // namespace extension
5255
} // namespace executorch
5356

extension/aten_util/test/aten_bridge_test.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
using namespace ::testing;
1919
using namespace torch::executor;
2020
using namespace torch::executor::util;
21+
using namespace executorch::extension;
2122

2223
namespace {
2324
at::Tensor generate_at_tensor() {
@@ -146,3 +147,10 @@ TEST(ATenBridgeTest, AliasATTensorToETensor) {
146147
auto aliased_at_tensor = alias_attensor_to_etensor(etensor);
147148
EXPECT_EQ(aliased_at_tensor.const_data_ptr(), etensor_data.data());
148149
}
150+
151+
TEST(ATenBridgeTest, AliasTensorPtrToATenTensor) {
152+
auto at_tensor = generate_at_tensor();
153+
const auto& et_tensor_ptr = alias_tensor_ptr_to_attensor(at_tensor);
154+
alias_etensor_to_attensor(at_tensor, *et_tensor_ptr);
155+
EXPECT_EQ(at_tensor.const_data_ptr(), et_tensor_ptr->const_data_ptr());
156+
}

0 commit comments

Comments
 (0)