Skip to content

Commit 99e1ae1

Browse files
authored
Skip storing unnecessary metadata in ManagedTensor.
Differential Revision: D60854858
Pull Request resolved: #4572
1 parent 18b829c commit 99e1ae1

File tree

5 files changed

+61
-97
lines changed

5 files changed

+61
-97
lines changed

examples/llm_manual/managed_tensor.h

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,28 +30,21 @@ class ManagedTensor {
3030
using DimOrderType = exec_aten::DimOrderType;
3131
/// The type used for elements of `strides()`.
3232
using StridesType = exec_aten::StridesType;
33+
3334
ManagedTensor() = delete;
3435

3536
explicit ManagedTensor(
3637
void* data,
3738
const std::vector<SizesType>& sizes,
3839
ScalarType dtype)
39-
: dtype_(dtype), sizes_(sizes), data_ptr_(data) {
40-
ssize_t dim = sizes.size();
41-
dim_order_.resize(dim);
42-
strides_.resize(dim);
43-
for (size_t i = 0; i < dim; ++i) {
44-
dim_order_[i] = i;
45-
}
46-
dim_order_to_stride_nocheck(
47-
sizes.data(), dim_order_.data(), dim, strides_.data());
40+
: sizes_(sizes) {
4841
tensor_impl_ = std::make_unique<TensorImpl>(
49-
dtype_,
50-
dim,
42+
dtype,
43+
sizes_.size(),
5144
sizes_.data(),
52-
data_ptr_,
53-
dim_order_.data(),
54-
strides_.data(),
45+
data,
46+
nullptr,
47+
nullptr,
5548
TensorShapeDynamism::DYNAMIC_BOUND);
5649
}
5750

@@ -63,12 +56,9 @@ class ManagedTensor {
6356
}
6457

6558
private:
66-
void* data_ptr_ = nullptr;
6759
std::unique_ptr<TensorImpl> tensor_impl_;
6860
std::vector<SizesType> sizes_;
69-
std::vector<StridesType> strides_;
70-
std::vector<DimOrderType> dim_order_;
71-
ScalarType dtype_;
7261
};
62+
7363
} // namespace executor
7464
} // namespace torch

extension/aten_util/make_aten_functor_from_et_functor.h

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
#endif
2121
#include <ATen/native/Resize.h>
2222
#include <executorch/extension/kernel_util/type_list.h>
23-
#include <executorch/extension/runner_util/managed_tensor.h>
2423
#include <executorch/runtime/core/evalue.h>
24+
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
2525
#include <torch/torch.h>
2626

2727
namespace torch {
@@ -107,25 +107,39 @@ struct type_convert<
107107
typename remove_const_ref<ETensor>::type,
108108
torch::executor::Tensor>>>
109109
final {
110-
public:
111-
ATensor val;
112-
std::unique_ptr<ManagedTensor> managed_tensor;
113-
torch::executor::Tensor converted;
114-
std::vector<exec_aten::SizesType> sizes;
115-
explicit type_convert(ATensor value)
116-
: val(value), converted(torch::executor::Tensor(nullptr)) {
117-
for (auto size : val.sizes()) {
118-
sizes.push_back(size);
119-
}
120-
torch::executor::ScalarType scalar_type =
121-
static_cast<torch::executor::ScalarType>(val.scalar_type());
122-
managed_tensor = std::make_unique<ManagedTensor>(
123-
val.mutable_data_ptr(), sizes, scalar_type);
124-
converted = managed_tensor->get_aliasing_tensor();
110+
explicit type_convert(ATensor value) : value_(value) {
111+
auto sizes = std::make_shared<std::vector<Tensor::SizesType>>(
112+
value_.sizes().begin(), value_.sizes().end());
113+
const ssize_t dim = sizes->size();
114+
auto dim_order = std::make_shared<std::vector<Tensor::DimOrderType>>(dim);
115+
auto strides = std::make_shared<std::vector<Tensor::StridesType>>(dim);
116+
117+
std::iota(dim_order->begin(), dim_order->end(), 0);
118+
dim_order_to_stride_nocheck(
119+
sizes->data(), dim_order->data(), dim, strides->data());
120+
121+
auto tensor_impl = std::make_shared<TensorImpl>(
122+
static_cast<torch::executor::ScalarType>(value_.scalar_type()),
123+
sizes->size(),
124+
sizes->data(),
125+
value_.mutable_data_ptr(),
126+
dim_order->data(),
127+
strides->data());
128+
129+
converted_ = std::unique_ptr<Tensor, std::function<void(Tensor*)>>(
130+
new Tensor(tensor_impl.get()),
131+
[sizes, dim_order, strides, tensor_impl](Tensor* pointer) {
132+
delete pointer;
133+
});
125134
}
135+
126136
ETensor call() {
127-
return converted;
137+
return *converted_;
128138
}
139+
140+
private:
141+
ATensor value_;
142+
std::unique_ptr<Tensor, std::function<void(Tensor*)>> converted_;
129143
};
130144

131145
// Tensors: ETen to ATen.
@@ -139,21 +153,22 @@ struct type_convert<
139153
typename remove_const_ref<ETensor>::type,
140154
torch::executor::Tensor>>>
141155
final {
142-
public:
143-
ETensor val;
144-
at::Tensor converted;
145-
std::vector<int64_t> sizes;
146-
explicit type_convert(ETensor value) : val(value) {
147-
for (auto size : val.sizes()) {
148-
sizes.push_back(size);
149-
}
150-
c10::ScalarType scalar_type =
151-
static_cast<c10::ScalarType>(val.scalar_type());
152-
converted = at::from_blob(val.mutable_data_ptr(), sizes, scalar_type);
156+
explicit type_convert(ETensor value)
157+
: value_(value), sizes_(value_.sizes().begin(), value_.sizes().end()) {
158+
converted_ = at::from_blob(
159+
value_.mutable_data_ptr(),
160+
sizes_,
161+
static_cast<c10::ScalarType>(value_.scalar_type()));
153162
}
163+
154164
ATensor call() {
155-
return converted;
165+
return converted_;
156166
}
167+
168+
private:
169+
ETensor value_;
170+
at::Tensor converted_;
171+
std::vector<int64_t> sizes_;
157172
};
158173

159174
// Optionals: ATen to ETen.

extension/aten_util/targets.bzl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def define_common_targets():
2727
],
2828
exported_deps = [
2929
"//executorch/extension/kernel_util:kernel_util",
30-
"//executorch/extension/runner_util:managed_tensor",
3130
"//executorch/runtime/core:core",
3231
"//executorch/runtime/core:evalue",
3332
"//executorch/runtime/core/exec_aten:lib",

extension/runner_util/managed_tensor.h

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -37,39 +37,29 @@ class ManagedTensor {
3737
using DimOrderType = exec_aten::DimOrderType;
3838
/// The type used for elements of `strides()`.
3939
using StridesType = exec_aten::StridesType;
40+
4041
ManagedTensor() = delete;
4142

4243
explicit ManagedTensor(
4344
void* data,
4445
const std::vector<SizesType>& sizes,
4546
ScalarType dtype)
46-
: dtype_(dtype), sizes_(sizes), data_ptr_(data) {
47+
: sizes_(sizes) {
4748
#ifdef USE_ATEN_LIB
48-
tensor_ = torch::from_blob(data, sizes, dtype_);
49+
tensor_ = torch::from_blob(data, sizes, dtype);
4950
#else
50-
ssize_t dim = sizes.size();
51-
dim_order_.resize(dim);
52-
strides_.resize(dim);
53-
for (size_t i = 0; i < dim; ++i) {
54-
dim_order_[i] = i;
55-
}
56-
dim_order_to_stride_nocheck(
57-
sizes.data(), dim_order_.data(), dim, strides_.data());
5851
tensor_impl_ = std::make_unique<TensorImpl>(
59-
dtype_,
60-
dim,
52+
dtype,
53+
sizes_.size(),
6154
sizes_.data(),
62-
data_ptr_,
63-
dim_order_.data(),
64-
strides_.data(),
55+
data,
56+
nullptr,
57+
nullptr,
6558
TensorShapeDynamism::DYNAMIC_BOUND);
6659
#endif
6760
}
6861

6962
void resize(const std::vector<SizesType>& new_sizes) {
70-
ET_CHECK_MSG(
71-
new_sizes.size() == sizes_.size(),
72-
"Cannot change rank of a managed tensor");
7363
auto err = resize_tensor(
7464
this->get_aliasing_tensor(),
7565
exec_aten::ArrayRef<SizesType>(new_sizes.data(), new_sizes.size()));
@@ -88,15 +78,12 @@ class ManagedTensor {
8878
}
8979

9080
private:
91-
ScalarType dtype_;
9281
std::unique_ptr<TensorImpl> tensor_impl_;
9382
std::vector<SizesType> sizes_;
94-
std::vector<StridesType> strides_;
95-
std::vector<DimOrderType> dim_order_;
96-
void* data_ptr_ = nullptr;
9783
#ifdef USE_ATEN_LIB
9884
Tensor tensor_;
9985
#endif
10086
};
87+
10188
} // namespace executor
10289
} // namespace torch

extension/runner_util/test/managed_tensor_test.cpp

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,6 @@ TEST_F(ManagedTensorTest, Smoke) {
4242

4343
EXPECT_EQ(tensor.sizes(), ArrayRef<SizesType>(sizes_.data(), sizes_.size()));
4444
EXPECT_EQ(tensor.scalar_type(), ScalarType::Long);
45-
std::vector<DimOrderType> expected_dim_order = {0, 1};
46-
EXPECT_EQ(
47-
tensor.dim_order(),
48-
ArrayRef<DimOrderType>(
49-
expected_dim_order.data(), expected_dim_order.size()));
50-
std::vector<StridesType> expected_strides = {3, 1};
51-
EXPECT_EQ(
52-
tensor.strides(),
53-
ArrayRef<StridesType>(expected_strides.data(), expected_strides.size()));
5445
EXPECT_EQ(tensor.const_data_ptr(), data_.data());
5546
}
5647

@@ -74,15 +65,6 @@ TEST_F(ManagedTensorTest, ResizeShrink) {
7465
tensor.sizes(),
7566
ArrayRef<SizesType>(expected_sizes.data(), expected_sizes.size()));
7667
EXPECT_EQ(tensor.scalar_type(), ScalarType::Long);
77-
std::vector<DimOrderType> expected_dim_order = {0, 1};
78-
EXPECT_EQ(
79-
tensor.dim_order(),
80-
ArrayRef<DimOrderType>(
81-
expected_dim_order.data(), expected_dim_order.size()));
82-
std::vector<StridesType> expected_strides = {2, 1};
83-
EXPECT_EQ(
84-
tensor.strides(),
85-
ArrayRef<StridesType>(expected_strides.data(), expected_strides.size()));
8668
EXPECT_EQ(tensor.const_data_ptr(), data_.data());
8769
}
8870

@@ -95,14 +77,5 @@ TEST_F(ManagedTensorTest, Resize) {
9577
tensor.sizes(),
9678
ArrayRef<SizesType>(expected_sizes.data(), expected_sizes.size()));
9779
EXPECT_EQ(tensor.scalar_type(), ScalarType::Long);
98-
std::vector<DimOrderType> expected_dim_order = {0, 1};
99-
EXPECT_EQ(
100-
tensor.dim_order(),
101-
ArrayRef<DimOrderType>(
102-
expected_dim_order.data(), expected_dim_order.size()));
103-
std::vector<StridesType> expected_strides = {2, 1};
104-
EXPECT_EQ(
105-
tensor.strides(),
106-
ArrayRef<StridesType>(expected_strides.data(), expected_strides.size()));
10780
EXPECT_EQ(tensor.const_data_ptr(), data_.data());
10881
}

0 commit comments

Comments
 (0)