@@ -21,15 +21,14 @@ namespace extension {
2121
2222#ifndef USE_ATEN_LIB
2323/* *
24- * A smart pointer type for managing the lifecycle of a TensorImpl.
24+ * A smart pointer for managing the lifecycle of a TensorImpl.
2525 *
26- * TensorImplPtr uses a shared pointer because multiple Tensor objects might
27- * share the same underlying data and metadata. This shared ownership model
28- * ensures that the TensorImpl is only destroyed when all references to it are
29- * gone, providing a safe and efficient way to manage shared tensor
30- * implementations. This abstraction is designed to be a safer and more
31- * convenient alternative to the original TensorImpl, which does not
32- * manage metadata by design.
26+ * TensorImplPtr uses a shared pointer since multiple Tensor objects may
27+ * share the same underlying data and metadata. This shared ownership ensures
28+ * that the TensorImpl is destroyed only when all references to it are gone,
29+ * providing a safe and efficient way to manage shared tensor implementations.
30+ * It serves as a safer, more convenient alternative to the original TensorImpl,
31+ * which does not manage its metadata by design.
3332 */
3433using TensorImplPtr = std::shared_ptr<exec_aten::TensorImpl>;
3534#else
@@ -48,23 +47,23 @@ using TensorImplPtr =
4847 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
4948 * specified properties.
5049 *
51- * @param type The scalar type of the tensor elements.
5250 * @param sizes A vector specifying the size of each dimension.
5351 * @param data A pointer to the data buffer.
5452 * @param dim_order A vector specifying the order of dimensions.
5553 * @param strides A vector specifying the strides of each dimension.
54+ * @param type The scalar type of the tensor elements.
5655 * @param dynamism Specifies the mutability of the tensor's shape.
5756 * @param deleter A custom deleter function for managing the lifetime of the
58- * data buffer. If provided, this deleter will be called when the managed
59- * TensorImpl object is destroyed.
57+ * data buffer. If provided, this deleter is called when the managed TensorImpl
58+ * is destroyed.
6059 * @return A TensorImplPtr managing the newly created TensorImpl.
6160 */
6261TensorImplPtr make_tensor_impl_ptr (
63- exec_aten::ScalarType type,
6462 std::vector<exec_aten::SizesType> sizes,
6563 void * data,
66- std::vector<exec_aten::DimOrderType> dim_order = {},
67- std::vector<exec_aten::StridesType> strides = {},
64+ std::vector<exec_aten::DimOrderType> dim_order,
65+ std::vector<exec_aten::StridesType> strides,
66+ exec_aten::ScalarType type = exec_aten::ScalarType::Float,
6867 exec_aten::TensorShapeDynamism dynamism =
6968 exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
7069 std::function<void (void *)> deleter = nullptr);
@@ -73,37 +72,64 @@ TensorImplPtr make_tensor_impl_ptr(
7372 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
7473 * specified properties.
7574 *
76- * This template overload is specialized for cases where the tensor data is
77- * provided as a vector. The scalar type is automatically deduced from the
78- * vector's data type. The deleter ensures that the data vector is properly
79- * managed and its lifetime is tied to the TensorImpl.
75+ * @param sizes A vector specifying the size of each dimension.
76+ * @param data A pointer to the data buffer.
77+ * @param type The scalar type of the tensor elements.
78+ * @param dynamism Specifies the mutability of the tensor's shape.
79+ * @param deleter A custom deleter function for managing the lifetime of the
80+ * data buffer. If provided, this deleter is called when the managed TensorImpl
81+ * is destroyed.
82+ * @return A TensorImplPtr managing the newly created TensorImpl.
83+ */
84+ inline TensorImplPtr make_tensor_impl_ptr (
85+ std::vector<exec_aten::SizesType> sizes,
86+ void * data,
87+ exec_aten::ScalarType type = exec_aten::ScalarType::Float,
88+ exec_aten::TensorShapeDynamism dynamism =
89+ exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
90+ std::function<void (void *)> deleter = nullptr) {
91+ return make_tensor_impl_ptr (
92+ std::move (sizes), data, {}, {}, type, dynamism, std::move (deleter));
93+ }
94+
95+ /* *
96+ * Creates a TensorImplPtr that manages a newly created TensorImpl with the
97+ * specified properties.
98+ *
99+ * This template overload is specialized for cases where tensor data is provided
100+ * as a vector. The scalar type is automatically deduced from the vector's data
101+ * type. The deleter ensures that the data vector is properly managed, with its
102+ * lifetime tied to the TensorImpl.
80103 *
81104 * @tparam T The C++ type of the tensor elements, deduced from the vector.
82105 * @param sizes A vector specifying the size of each dimension.
83106 * @param data A vector containing the tensor's data.
84107 * @param dim_order A vector specifying the order of dimensions.
85108 * @param strides A vector specifying the strides of each dimension.
109+ * @param type The scalar type of the tensor elements.
86110 * @param dynamism Specifies the mutability of the tensor's shape.
87111 * @return A TensorImplPtr that manages the newly created TensorImpl.
88112 */
89- template <typename T = float >
113+ template <
114+ typename T = float ,
115+ exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
90116inline TensorImplPtr make_tensor_impl_ptr (
91117 std::vector<exec_aten::SizesType> sizes,
92118 std::vector<T> data,
93119 std::vector<exec_aten::DimOrderType> dim_order = {},
94120 std::vector<exec_aten::StridesType> strides = {},
121+ exec_aten::ScalarType type = deduced_type,
95122 exec_aten::TensorShapeDynamism dynamism =
96123 exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
97- constexpr exec_aten::ScalarType scalar_type =
98- runtime::CppTypeToScalarType<T>::value;
124+ ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
99125 const auto raw_data_ptr = data.data ();
100126 auto data_ptr = std::make_shared<std::vector<T>>(std::move (data));
101127 return make_tensor_impl_ptr (
102- scalar_type,
103128 std::move (sizes),
104129 raw_data_ptr,
105130 std::move (dim_order),
106131 std::move (strides),
132+ type,
107133 dynamism,
108134 [data_ptr = std::move (data_ptr)](void *) {});
109135}
@@ -119,43 +145,159 @@ inline TensorImplPtr make_tensor_impl_ptr(
119145 *
120146 * @tparam T The C++ type of the tensor elements, deduced from the vector.
121147 * @param data A vector containing the tensor's data.
148+ * @param type The scalar type of the tensor elements.
122149 * @param dynamism Specifies the mutability of the tensor's shape.
123150 * @return A TensorImplPtr that manages the newly created TensorImpl.
124151 */
125- template <typename T = float >
152+ template <
153+ typename T = float ,
154+ exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
126155inline TensorImplPtr make_tensor_impl_ptr (
127156 std::vector<T> data,
157+ exec_aten::ScalarType type = deduced_type,
128158 exec_aten::TensorShapeDynamism dynamism =
129159 exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
160+ ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
130161 std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType (data.size ())};
131162 return make_tensor_impl_ptr (
132- std::move (sizes), std::move (data), {0 }, {1 }, dynamism);
163+ std::move (sizes), std::move (data), {0 }, {1 }, type, dynamism);
164+ }
165+
166+ /* *
167+ * Creates a TensorImplPtr that manages a newly created TensorImpl with the
168+ * specified properties.
169+ *
170+ * This template overload is specialized for cases where tensor data is provided
171+ * as an initializer list. The scalar type is automatically deduced from the
172+ * initializer list's data type. The deleter ensures that the data is properly
173+ * managed, with its lifetime tied to the TensorImpl.
174+ *
175+ * @tparam T The C++ type of the tensor elements, deduced from the initializer
176+ * list.
177+ * @param sizes A vector specifying the size of each dimension.
178+ * @param list An initializer list containing the tensor's data.
179+ * @param dim_order A vector specifying the order of dimensions.
180+ * @param strides A vector specifying the strides of each dimension.
181+ * @param type The scalar type of the tensor elements.
182+ * @param dynamism Specifies the mutability of the tensor's shape.
183+ * @return A TensorImplPtr that manages the newly created TensorImpl.
184+ */
185+ template <
186+ typename T = float ,
187+ exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
188+ inline TensorImplPtr make_tensor_impl_ptr (
189+ std::vector<exec_aten::SizesType> sizes,
190+ std::initializer_list<T> list,
191+ std::vector<exec_aten::DimOrderType> dim_order = {},
192+ std::vector<exec_aten::StridesType> strides = {},
193+ exec_aten::ScalarType type = deduced_type,
194+ exec_aten::TensorShapeDynamism dynamism =
195+ exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
196+ ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
197+ auto data = std::vector<T>(std::move (list));
198+ const auto raw_data_ptr = data.data ();
199+ auto data_ptr = std::make_shared<std::vector<T>>(std::move (data));
200+ return make_tensor_impl_ptr (
201+ std::move (sizes),
202+ raw_data_ptr,
203+ std::move (dim_order),
204+ std::move (strides),
205+ type,
206+ dynamism,
207+ [data_ptr = std::move (data_ptr)](void *) {});
208+ }
209+
210+ /* *
211+ * Creates a TensorImplPtr that manages a newly created TensorImpl with the
212+ * specified properties.
213+ *
214+ * This template overload is specialized for cases where the tensor data is
215+ * provided as an initializer list. The scalar type is automatically deduced
216+ * from the initializer list's data type. The deleter ensures that the data is
217+ * properly managed and its lifetime is tied to the TensorImpl.
218+ *
219+ * @tparam T The C++ type of the tensor elements, deduced from the initializer
220+ * list.
221+ * @param sizes A vector specifying the size of each dimension.
222+ * @param list An initializer list containing the tensor's data.
223+ * @param type The scalar type of the tensor elements.
224+ * @param dynamism Specifies the mutability of the tensor's shape.
225+ * @return A TensorImplPtr that manages the newly created TensorImpl.
226+ */
227+ template <
228+ typename T = float ,
229+ exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
230+ inline TensorImplPtr make_tensor_impl_ptr (
231+ std::initializer_list<T> list,
232+ exec_aten::ScalarType type = deduced_type,
233+ exec_aten::TensorShapeDynamism dynamism =
234+ exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
235+ ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
236+ std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType (list.size ())};
237+ return make_tensor_impl_ptr (
238+ std::move (sizes), std::move (list), {0 }, {1 }, type, dynamism);
239+ }
240+
241+ /* *
242+ * Creates a TensorImplPtr to manage a Tensor with a single scalar value.
243+ *
244+ * @tparam T The C++ type of the scalar value.
245+ * @param value The scalar value used for the Tensor.
246+ * @return A TensorImplPtr managing the newly created TensorImpl.
247+ */
248+ template <typename T>
249+ inline TensorImplPtr make_tensor_impl_ptr (T value) {
250+ return make_tensor_impl_ptr ({}, std::vector<T>{value});
133251}
134252
135253/* *
136254 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
137255 * specified properties.
138256 *
139257 * This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
140- * and a scalar type to interpret the data. The vector is managed, and the
141- * memory's lifetime is tied to the TensorImpl.
258+ * and a scalar type to interpret the data. The vector is managed, and its
259+ * lifetime is tied to the TensorImpl.
142260 *
143- * @param scalar_type The scalar type of the tensor elements.
144261 * @param sizes A vector specifying the size of each dimension.
145- * @param data A vector containing the raw memory for the tensor's data.
262+ * @param data A vector containing the raw memory buffer for the tensor's data.
146263 * @param dim_order A vector specifying the order of dimensions.
147264 * @param strides A vector specifying the strides of each dimension.
265+ * @param type The scalar type of the tensor elements.
148266 * @param dynamism Specifies the mutability of the tensor's shape.
149267 * @return A TensorImplPtr managing the newly created TensorImpl.
150268 */
151269TensorImplPtr make_tensor_impl_ptr (
152- exec_aten::ScalarType scalar_type,
153270 std::vector<exec_aten::SizesType> sizes,
154271 std::vector<uint8_t > data,
155- std::vector<exec_aten::DimOrderType> dim_order = {},
156- std::vector<exec_aten::StridesType> strides = {},
272+ std::vector<exec_aten::DimOrderType> dim_order,
273+ std::vector<exec_aten::StridesType> strides,
274+ exec_aten::ScalarType type = exec_aten::ScalarType::Float,
157275 exec_aten::TensorShapeDynamism dynamism =
158276 exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
159277
278+ /* *
279+ * Creates a TensorImplPtr that manages a newly created TensorImpl with the
280+ * specified properties.
281+ *
282+ * This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
283+ * and a scalar type to interpret the data. The vector is managed, and the
284+ * memory's lifetime is tied to the TensorImpl.
285+ *
286+ * @param sizes A vector specifying the size of each dimension.
287+ * @param data A vector containing the raw memory for the tensor's data.
288+ * @param type The scalar type of the tensor elements.
289+ * @param dynamism Specifies the mutability of the tensor's shape.
290+ * @return A TensorImplPtr managing the newly created TensorImpl.
291+ */
292+ inline TensorImplPtr make_tensor_impl_ptr (
293+ std::vector<exec_aten::SizesType> sizes,
294+ std::vector<uint8_t > data,
295+ exec_aten::ScalarType type = exec_aten::ScalarType::Float,
296+ exec_aten::TensorShapeDynamism dynamism =
297+ exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
298+ return make_tensor_impl_ptr (
299+ std::move (sizes), std::move (data), {}, {}, type, dynamism);
300+ }
301+
160302} // namespace extension
161303} // namespace executorch
0 commit comments