@@ -21,15 +21,14 @@ namespace extension {
2121
2222#ifndef  USE_ATEN_LIB
2323/* *
24-  * A smart pointer type  for managing the lifecycle of a TensorImpl. 
24+  * A smart pointer for managing the lifecycle of a TensorImpl. 
2525 * 
26-  * TensorImplPtr uses a shared pointer because multiple Tensor objects might 
27-  * share the same underlying data and metadata. This shared ownership model 
28-  * ensures that the TensorImpl is only destroyed when all references to it are 
29-  * gone, providing a safe and efficient way to manage shared tensor 
30-  * implementations. This abstraction is designed to be a safer and more 
31-  * convenient alternative to the original TensorImpl, which does not 
32-  * manage metadata by design. 
26+  * TensorImplPtr uses a shared pointer since multiple Tensor objects may 
27+  * share the same underlying data and metadata. This shared ownership ensures 
28+  * that the TensorImpl is destroyed only when all references to it are gone, 
29+  * providing a safe and efficient way to manage shared tensor implementations. 
30+  * It serves as a safer, more convenient alternative to the original TensorImpl, 
31+  * which does not manage its metadata by design. 
3332 */  
3433using  TensorImplPtr = std::shared_ptr<exec_aten::TensorImpl>;
3534#else 
@@ -48,23 +47,23 @@ using TensorImplPtr =
4847 * Creates a TensorImplPtr that manages a newly created TensorImpl with the 
4948 * specified properties. 
5049 * 
51-  * @param type The scalar type of the tensor elements. 
5250 * @param sizes A vector specifying the size of each dimension. 
5351 * @param data A pointer to the data buffer. 
5452 * @param dim_order A vector specifying the order of dimensions. 
5553 * @param strides A vector specifying the strides of each dimension. 
54+  * @param type The scalar type of the tensor elements. 
5655 * @param dynamism Specifies the mutability of the tensor's shape. 
5756 * @param deleter A custom deleter function for managing the lifetime of the 
58-  * data buffer. If provided, this deleter will be  called when the managed 
59-  * TensorImpl object  is destroyed. 
57+  * data buffer. If provided, this deleter is  called when the managed TensorImpl  
58+  * is destroyed. 
6059 * @return A TensorImplPtr managing the newly created TensorImpl. 
6160 */  
6261TensorImplPtr make_tensor_impl_ptr (
63-     exec_aten::ScalarType type,
6462    std::vector<exec_aten::SizesType> sizes,
6563    void * data,
66-     std::vector<exec_aten::DimOrderType> dim_order = {},
67-     std::vector<exec_aten::StridesType> strides = {},
64+     std::vector<exec_aten::DimOrderType> dim_order,
65+     std::vector<exec_aten::StridesType> strides,
66+     exec_aten::ScalarType type = exec_aten::ScalarType::Float,
6867    exec_aten::TensorShapeDynamism dynamism =
6968        exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
7069    std::function<void (void *)> deleter = nullptr);
@@ -73,37 +72,64 @@ TensorImplPtr make_tensor_impl_ptr(
7372 * Creates a TensorImplPtr that manages a newly created TensorImpl with the 
7473 * specified properties. 
7574 * 
76-  * This template overload is specialized for cases where the tensor data is 
77-  * provided as a vector. The scalar type is automatically deduced from the 
78-  * vector's data type. The deleter ensures that the data vector is properly 
79-  * managed and its lifetime is tied to the TensorImpl. 
75+  * @param sizes A vector specifying the size of each dimension. 
76+  * @param data A pointer to the data buffer. 
77+  * @param type The scalar type of the tensor elements. 
78+  * @param dynamism Specifies the mutability of the tensor's shape. 
79+  * @param deleter A custom deleter function for managing the lifetime of the 
80+  * data buffer. If provided, this deleter is called when the managed TensorImpl 
81+  * is destroyed. 
82+  * @return A TensorImplPtr managing the newly created TensorImpl. 
83+  */  
84+ inline  TensorImplPtr make_tensor_impl_ptr (
85+     std::vector<exec_aten::SizesType> sizes,
86+     void * data,
87+     exec_aten::ScalarType type = exec_aten::ScalarType::Float,
88+     exec_aten::TensorShapeDynamism dynamism =
89+         exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
90+     std::function<void (void *)> deleter = nullptr) {
91+   return  make_tensor_impl_ptr (
92+       std::move (sizes), data, {}, {}, type, dynamism, std::move (deleter));
93+ }
94+ 
95+ /* *
96+  * Creates a TensorImplPtr that manages a newly created TensorImpl with the 
97+  * specified properties. 
98+  * 
99+  * This template overload is specialized for cases where tensor data is provided 
100+  * as a vector. The scalar type is automatically deduced from the vector's data 
101+  * type. The deleter ensures that the data vector is properly managed, with its 
102+  * lifetime tied to the TensorImpl. 
80103 * 
81104 * @tparam T The C++ type of the tensor elements, deduced from the vector. 
82105 * @param sizes A vector specifying the size of each dimension. 
83106 * @param data A vector containing the tensor's data. 
84107 * @param dim_order A vector specifying the order of dimensions. 
85108 * @param strides A vector specifying the strides of each dimension. 
109+  * @param type The scalar type of the tensor elements. 
86110 * @param dynamism Specifies the mutability of the tensor's shape. 
87111 * @return A TensorImplPtr that manages the newly created TensorImpl. 
88112 */  
89- template  <typename  T = float >
113+ template  <
114+     typename  T = float ,
115+     exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
90116inline  TensorImplPtr make_tensor_impl_ptr (
91117    std::vector<exec_aten::SizesType> sizes,
92118    std::vector<T> data,
93119    std::vector<exec_aten::DimOrderType> dim_order = {},
94120    std::vector<exec_aten::StridesType> strides = {},
121+     exec_aten::ScalarType type = deduced_type,
95122    exec_aten::TensorShapeDynamism dynamism =
96123        exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
97-   constexpr  exec_aten::ScalarType scalar_type =
98-       runtime::CppTypeToScalarType<T>::value;
124+   ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type."  );
99125  const  auto  raw_data_ptr = data.data ();
100126  auto  data_ptr = std::make_shared<std::vector<T>>(std::move (data));
101127  return  make_tensor_impl_ptr (
102-       scalar_type,
103128      std::move (sizes),
104129      raw_data_ptr,
105130      std::move (dim_order),
106131      std::move (strides),
132+       type,
107133      dynamism,
108134      [data_ptr = std::move (data_ptr)](void *) {});
109135}
@@ -119,43 +145,159 @@ inline TensorImplPtr make_tensor_impl_ptr(
119145 * 
120146 * @tparam T The C++ type of the tensor elements, deduced from the vector. 
121147 * @param data A vector containing the tensor's data. 
148+  * @param type The scalar type of the tensor elements. 
122149 * @param dynamism Specifies the mutability of the tensor's shape. 
123150 * @return A TensorImplPtr that manages the newly created TensorImpl. 
124151 */  
125- template  <typename  T = float >
152+ template  <
153+     typename  T = float ,
154+     exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
126155inline  TensorImplPtr make_tensor_impl_ptr (
127156    std::vector<T> data,
157+     exec_aten::ScalarType type = deduced_type,
128158    exec_aten::TensorShapeDynamism dynamism =
129159        exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
160+   ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type."  );
130161  std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType (data.size ())};
131162  return  make_tensor_impl_ptr (
132-       std::move (sizes), std::move (data), {0 }, {1 }, dynamism);
163+       std::move (sizes), std::move (data), {0 }, {1 }, type, dynamism);
164+ }
165+ 
166+ /* *
167+  * Creates a TensorImplPtr that manages a newly created TensorImpl with the 
168+  * specified properties. 
169+  * 
170+  * This template overload is specialized for cases where tensor data is provided 
171+  * as an initializer list. The scalar type is automatically deduced from the 
172+  * initializer list's data type. The deleter ensures that the data is properly 
173+  * managed, with its lifetime tied to the TensorImpl. 
174+  * 
175+  * @tparam T The C++ type of the tensor elements, deduced from the initializer 
176+  * list. 
177+  * @param sizes A vector specifying the size of each dimension. 
178+  * @param list An initializer list containing the tensor's data. 
179+  * @param dim_order A vector specifying the order of dimensions. 
180+  * @param strides A vector specifying the strides of each dimension. 
181+  * @param type The scalar type of the tensor elements. 
182+  * @param dynamism Specifies the mutability of the tensor's shape. 
183+  * @return A TensorImplPtr that manages the newly created TensorImpl. 
184+  */  
185+ template  <
186+     typename  T = float ,
187+     exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
188+ inline  TensorImplPtr make_tensor_impl_ptr (
189+     std::vector<exec_aten::SizesType> sizes,
190+     std::initializer_list<T> list,
191+     std::vector<exec_aten::DimOrderType> dim_order = {},
192+     std::vector<exec_aten::StridesType> strides = {},
193+     exec_aten::ScalarType type = deduced_type,
194+     exec_aten::TensorShapeDynamism dynamism =
195+         exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
196+   ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type."  );
197+   auto  data = std::vector<T>(std::move (list));
198+   const  auto  raw_data_ptr = data.data ();
199+   auto  data_ptr = std::make_shared<std::vector<T>>(std::move (data));
200+   return  make_tensor_impl_ptr (
201+       std::move (sizes),
202+       raw_data_ptr,
203+       std::move (dim_order),
204+       std::move (strides),
205+       type,
206+       dynamism,
207+       [data_ptr = std::move (data_ptr)](void *) {});
208+ }
209+ 
210+ /* *
211+  * Creates a TensorImplPtr that manages a newly created TensorImpl with the 
212+  * specified properties. 
213+  * 
214+  * This template overload is specialized for cases where the tensor data is 
215+  * provided as an initializer list. The scalar type is automatically deduced 
216+  * from the initializer list's data type. The deleter ensures that the data is 
217+  * properly managed and its lifetime is tied to the TensorImpl. 
218+  * 
219+  * @tparam T The C++ type of the tensor elements, deduced from the initializer 
220+  * list. 
221+  * @param sizes A vector specifying the size of each dimension. 
222+  * @param list An initializer list containing the tensor's data. 
223+  * @param type The scalar type of the tensor elements. 
224+  * @param dynamism Specifies the mutability of the tensor's shape. 
225+  * @return A TensorImplPtr that manages the newly created TensorImpl. 
226+  */  
227+ template  <
228+     typename  T = float ,
229+     exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
230+ inline  TensorImplPtr make_tensor_impl_ptr (
231+     std::initializer_list<T> list,
232+     exec_aten::ScalarType type = deduced_type,
233+     exec_aten::TensorShapeDynamism dynamism =
234+         exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
235+   ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type."  );
236+   std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType (list.size ())};
237+   return  make_tensor_impl_ptr (
238+       std::move (sizes), std::move (list), {0 }, {1 }, type, dynamism);
239+ }
240+ 
241+ /* *
242+  * Creates a TensorImplPtr to manage a Tensor with a single scalar value. 
243+  * 
244+  * @tparam T The C++ type of the scalar value. 
245+  * @param value The scalar value used for the Tensor. 
246+  * @return A TensorImplPtr managing the newly created TensorImpl. 
247+  */  
248+ template  <typename  T>
249+ inline  TensorImplPtr make_tensor_impl_ptr (T value) {
250+   return  make_tensor_impl_ptr ({}, std::vector<T>{value});
133251}
134252
135253/* *
136254 * Creates a TensorImplPtr that manages a newly created TensorImpl with the 
137255 * specified properties. 
138256 * 
139257 * This overload accepts a raw memory buffer stored in a std::vector<uint8_t> 
140-  * and a scalar type to interpret the data. The vector is managed, and the  
141-  * memory's  lifetime is tied to the TensorImpl. 
258+  * and a scalar type to interpret the data. The vector is managed, and its  
259+  * lifetime is tied to the TensorImpl. 
142260 * 
143-  * @param scalar_type The scalar type of the tensor elements. 
144261 * @param sizes A vector specifying the size of each dimension. 
145-  * @param data A vector containing the raw memory for the tensor's data. 
262+  * @param data A vector containing the raw memory buffer  for the tensor's data. 
146263 * @param dim_order A vector specifying the order of dimensions. 
147264 * @param strides A vector specifying the strides of each dimension. 
265+  * @param type The scalar type of the tensor elements. 
148266 * @param dynamism Specifies the mutability of the tensor's shape. 
149267 * @return A TensorImplPtr managing the newly created TensorImpl. 
150268 */  
151269TensorImplPtr make_tensor_impl_ptr (
152-     exec_aten::ScalarType scalar_type,
153270    std::vector<exec_aten::SizesType> sizes,
154271    std::vector<uint8_t > data,
155-     std::vector<exec_aten::DimOrderType> dim_order = {},
156-     std::vector<exec_aten::StridesType> strides = {},
272+     std::vector<exec_aten::DimOrderType> dim_order,
273+     std::vector<exec_aten::StridesType> strides,
274+     exec_aten::ScalarType type = exec_aten::ScalarType::Float,
157275    exec_aten::TensorShapeDynamism dynamism =
158276        exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
159277
278+ /* *
279+  * Creates a TensorImplPtr that manages a newly created TensorImpl with the 
280+  * specified properties. 
281+  * 
282+  * This overload accepts a raw memory buffer stored in a std::vector<uint8_t> 
283+  * and a scalar type to interpret the data. The vector is managed, and the 
284+  * memory's lifetime is tied to the TensorImpl. 
285+  * 
286+  * @param sizes A vector specifying the size of each dimension. 
287+  * @param data A vector containing the raw memory for the tensor's data. 
288+  * @param type The scalar type of the tensor elements. 
289+  * @param dynamism Specifies the mutability of the tensor's shape. 
290+  * @return A TensorImplPtr managing the newly created TensorImpl. 
291+  */  
292+ inline  TensorImplPtr make_tensor_impl_ptr (
293+     std::vector<exec_aten::SizesType> sizes,
294+     std::vector<uint8_t > data,
295+     exec_aten::ScalarType type = exec_aten::ScalarType::Float,
296+     exec_aten::TensorShapeDynamism dynamism =
297+         exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
298+   return  make_tensor_impl_ptr (
299+       std::move (sizes), std::move (data), {}, {}, type, dynamism);
300+ }
301+ 
160302} //  namespace extension
161303} //  namespace executorch
0 commit comments