Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ if(EXECUTORCH_BUILD_PYBIND)
util PUBLIC ${_common_include_directories} ${TORCH_INCLUDE_DIRS}
)
target_compile_options(util PUBLIC ${_pybind_compile_options})
target_link_libraries(util PRIVATE torch c10 executorch)
target_link_libraries(util PRIVATE torch c10 executorch extension_tensor)

# pybind portable_lib
pybind11_add_module(portable_lib SHARED extension/pybindings/pybindings.cpp)
Expand Down
7 changes: 7 additions & 0 deletions extension/aten_util/aten_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,5 +170,12 @@ at::Tensor alias_attensor_to_etensor(const torch::executor::Tensor& etensor) {
return t;
}

TensorPtr alias_tensor_ptr_to_attensor(at::Tensor& t) {
return make_tensor_ptr(
{t.sizes().begin(), t.sizes().end()},
t.mutable_data_ptr(),
torch::executor::ScalarType(t.scalar_type()));
}

} // namespace extension
} // namespace executorch
3 changes: 3 additions & 0 deletions extension/aten_util/aten_bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>

#include <ATen/Functions.h> // @manual=//caffe2/aten:ATen-cpu
Expand Down Expand Up @@ -48,6 +49,8 @@ void alias_etensor_to_attensor(at::Tensor& at, torch::executor::Tensor& et);
*/
at::Tensor alias_attensor_to_etensor(const torch::executor::Tensor& et);

TensorPtr alias_tensor_ptr_to_attensor(at::Tensor& t);

} // namespace extension
} // namespace executorch

Expand Down
8 changes: 8 additions & 0 deletions extension/aten_util/test/aten_bridge_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
using namespace ::testing;
using namespace torch::executor;
using namespace torch::executor::util;
using namespace executorch::extension;

namespace {
at::Tensor generate_at_tensor() {
Expand Down Expand Up @@ -146,3 +147,10 @@ TEST(ATenBridgeTest, AliasATTensorToETensor) {
auto aliased_at_tensor = alias_attensor_to_etensor(etensor);
EXPECT_EQ(aliased_at_tensor.const_data_ptr(), etensor_data.data());
}

TEST(ATenBridgeTest, AliasTensorPtrToATenTensor) {
auto at_tensor = generate_at_tensor();
const auto& et_tensor_ptr = alias_tensor_ptr_to_attensor(at_tensor);
alias_etensor_to_attensor(at_tensor, *et_tensor_ptr);
EXPECT_EQ(at_tensor.const_data_ptr(), et_tensor_ptr->const_data_ptr());
}
Loading