20
20
#endif
21
21
#include < ATen/native/Resize.h>
22
22
#include < executorch/extension/kernel_util/type_list.h>
23
+ #include < executorch/extension/tensor/tensor.h>
23
24
#include < executorch/runtime/core/evalue.h>
24
- #include < executorch/runtime/core/exec_aten/util/dim_order_util.h>
25
25
#include < torch/torch.h>
26
26
27
27
namespace executorch {
@@ -105,48 +105,20 @@ struct type_convert<
105
105
typename remove_const_ref<ETensor>::type,
106
106
torch::executor::Tensor>>>
107
107
final {
108
- explicit type_convert (ATensor value) : value_(value) {
109
- auto sizes =
110
- std::make_shared<std::vector<torch::executor::Tensor::SizesType>>(
111
- value_.sizes ().begin (), value_.sizes ().end ());
112
- const ssize_t dim = sizes->size ();
113
- auto dim_order =
114
- std::make_shared<std::vector<torch::executor::Tensor::DimOrderType>>(
115
- dim);
116
- auto strides =
117
- std::make_shared<std::vector<torch::executor::Tensor::StridesType>>(
118
- dim);
119
-
120
- std::iota (dim_order->begin (), dim_order->end (), 0 );
121
- ::executorch::runtime::dim_order_to_stride_nocheck (
122
- sizes->data (), dim_order->data(), dim, strides->data());
123
-
124
- auto tensor_impl = std::make_shared<torch::executor::TensorImpl>(
125
- static_cast <torch::executor::ScalarType>(value_.scalar_type()),
126
- sizes->size(),
127
- sizes->data(),
128
- value_.mutable_data_ptr(),
129
- dim_order->data(),
130
- strides->data());
131
-
132
- converted_ = std::unique_ptr<
133
- torch::executor::Tensor,
134
- std::function<void (torch::executor::Tensor*)>>(
135
- new torch::executor::Tensor (tensor_impl.get ()),
136
- [sizes, dim_order, strides, tensor_impl](
137
- torch::executor::Tensor* pointer) { delete pointer; });
138
- }
108
+ explicit type_convert (ATensor value)
109
+ : value_(value),
110
+ converted_(from_blob(
111
+ value_.mutable_data_ptr(),
112
+ {value_.sizes ().begin (), value_.sizes ().end ()},
113
+ ::torch::executor::ScalarType (value_.scalar_type()))) {}
139
114
140
115
ETensor call () {
141
116
return *converted_;
142
117
}
143
118
144
119
private:
145
120
ATensor value_;
146
- std::unique_ptr<
147
- torch::executor::Tensor,
148
- std::function<void (torch::executor::Tensor*)>>
149
- converted_;
121
+ TensorPtr converted_;
150
122
};
151
123
152
124
// Tensors: ETen to ATen.
@@ -158,15 +130,14 @@ struct type_convert<
158
130
std::is_same_v<typename remove_const_ref<ATensor>::type, at::Tensor> &&
159
131
std::is_same_v<
160
132
typename remove_const_ref<ETensor>::type,
161
- torch::executor::Tensor>>>
133
+ :: torch::executor::Tensor>>>
162
134
final {
163
135
explicit type_convert (ETensor value)
164
- : value_(value), sizes_(value_.sizes().begin(), value_.sizes().end()) {
165
- converted_ = at::from_blob (
166
- value_.mutable_data_ptr (),
167
- sizes_,
168
- static_cast <c10::ScalarType>(value_.scalar_type ()));
169
- }
136
+ : value_(value),
137
+ converted_(at::from_blob(
138
+ value_.mutable_data_ptr(),
139
+ std::vector<int64_t>{value_.sizes ().begin (), value_.sizes ().end ()},
140
+ c10::ScalarType (value_.scalar_type()))) {}
170
141
171
142
ATensor call () {
172
143
return converted_;
@@ -175,7 +146,6 @@ struct type_convert<
175
146
private:
176
147
ETensor value_;
177
148
at::Tensor converted_;
178
- std::vector<int64_t > sizes_;
179
149
};
180
150
181
151
// Optionals: ATen to ETen.
0 commit comments