Skip to content

Commit 23a1422

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
Use TensorPtr in aten_bridge (#5789)
Summary: Pull Request resolved: #5789 Want it to be very easy to create an ET Tensor from an at::Tensor if Im cool with the added deps Differential Revision: D63705432
1 parent 7183f19 commit 23a1422

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

extension/aten_util/aten_bridge.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,5 +170,12 @@ at::Tensor alias_attensor_to_etensor(const torch::executor::Tensor& etensor) {
170170
return t;
171171
}
172172

173+
TensorPtr alias_tensor_ptr_to_attensor(at::Tensor& t) {
174+
return make_tensor_ptr(
175+
{t.sizes().begin(), t.sizes().end()},
176+
t.mutable_data_ptr(),
177+
torch::executor::ScalarType(t.scalar_type()));
178+
}
179+
173180
} // namespace extension
174181
} // namespace executorch

extension/aten_util/aten_bridge.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#pragma once
1010

11+
#include <executorch/extension/tensor/tensor.h>
1112
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1213

1314
#include <ATen/Functions.h> // @manual=//caffe2/aten:ATen-cpu
@@ -48,6 +49,8 @@ void alias_etensor_to_attensor(at::Tensor& at, torch::executor::Tensor& et);
4849
*/
4950
at::Tensor alias_attensor_to_etensor(const torch::executor::Tensor& et);
5051

52+
TensorPtr alias_tensor_ptr_to_attensor(at::Tensor& t);
53+
5154
} // namespace extension
5255
} // namespace executorch
5356

extension/aten_util/test/aten_bridge_test.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
using namespace ::testing;
1919
using namespace torch::executor;
2020
using namespace torch::executor::util;
21+
using namespace executorch::extension;
2122

2223
namespace {
2324
at::Tensor generate_at_tensor() {
@@ -146,3 +147,10 @@ TEST(ATenBridgeTest, AliasATTensorToETensor) {
146147
auto aliased_at_tensor = alias_attensor_to_etensor(etensor);
147148
EXPECT_EQ(aliased_at_tensor.const_data_ptr(), etensor_data.data());
148149
}
150+
151+
TEST(ATenBridgeTest, AliasTensorPtrToATenTensor) {
152+
auto at_tensor = generate_at_tensor();
153+
const auto& et_tensor_ptr = alias_tensor_ptr_to_attensor(at_tensor);
154+
alias_etensor_to_attensor(at_tensor, *et_tensor_ptr);
155+
EXPECT_EQ(at_tensor.const_data_ptr(), et_tensor_ptr->const_data_ptr());
156+
}

0 commit comments

Comments
 (0)