@@ -97,31 +97,57 @@ inline TensorImplPtr make_tensor_impl_ptr(
9797 * specified properties. 
9898 * 
9999 * This template overload is specialized for cases where tensor data is provided 
100-  * as a vector. The scalar type is automatically deduced from the vector's data 
101-  * type. The deleter ensures that the data vector is properly managed, with its 
102-  * lifetime tied to the TensorImpl. 
100+  * as a vector. If the specified `type` differs from the deduced type of the 
101+  * vector's elements, and casting is allowed, the data will be cast to the 
102+  * specified `type`. This allows for flexible creation of tensors with data 
103+  * vectors of one type and a different scalar type. 
103104 * 
104105 * @tparam T The C++ type of the tensor elements, deduced from the vector. 
105106 * @param sizes A vector specifying the size of each dimension. 
106107 * @param data A vector containing the tensor's data. 
107108 * @param dim_order A vector specifying the order of dimensions. 
108109 * @param strides A vector specifying the strides of each dimension. 
109-  * @param type The scalar type of the tensor elements. 
110+  * @param type The scalar type of the tensor elements. If it differs from the 
111+  * deduced type, the data will be cast to this type if allowed. 
110112 * @param dynamism Specifies the mutability of the tensor's shape. 
111113 * @return A TensorImplPtr that manages the newly created TensorImpl. 
112114 */  
113115template  <
114116    typename  T = float ,
115117    exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
116- inline   TensorImplPtr make_tensor_impl_ptr (
118+ TensorImplPtr make_tensor_impl_ptr (
117119    std::vector<exec_aten::SizesType> sizes,
118120    std::vector<T> data,
119121    std::vector<exec_aten::DimOrderType> dim_order = {},
120122    std::vector<exec_aten::StridesType> strides = {},
121123    exec_aten::ScalarType type = deduced_type,
122124    exec_aten::TensorShapeDynamism dynamism =
123125        exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
124-   ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type."  );
126+   if  (type != deduced_type) {
127+     ET_CHECK_MSG (
128+         runtime::canCast (deduced_type, type),
129+         " Cannot cast deduced type to specified type."  );
130+     std::vector<uint8_t > casted_data (data.size () * runtime::elementSize (type));
131+     ET_SWITCH_REALHBBF16_TYPES (
132+         type, nullptr , " make_tensor_impl_ptr"  , CTYPE, [&] {
133+           std::transform (
134+               data.begin (),
135+               data.end (),
136+               reinterpret_cast <CTYPE*>(casted_data.data ()),
137+               [](const  T& val) { return  static_cast <CTYPE>(val); });
138+         });
139+     const  auto  raw_data_ptr = casted_data.data ();
140+     auto  data_ptr =
141+         std::make_shared<std::vector<uint8_t >>(std::move (casted_data));
142+     return  make_tensor_impl_ptr (
143+         std::move (sizes),
144+         raw_data_ptr,
145+         std::move (dim_order),
146+         std::move (strides),
147+         type,
148+         dynamism,
149+         [data_ptr = std::move (data_ptr)](void *) {});
150+   }
125151  const  auto  raw_data_ptr = data.data ();
126152  auto  data_ptr = std::make_shared<std::vector<T>>(std::move (data));
127153  return  make_tensor_impl_ptr (
@@ -138,14 +164,16 @@ inline TensorImplPtr make_tensor_impl_ptr(
138164 * Creates a TensorImplPtr that manages a newly created TensorImpl with the 
139165 * specified properties. 
140166 * 
141-  * This template overload is specialized for cases where the tensor data is 
142-  * provided as a vector. The scalar type is automatically deduced from the 
143-  * vector's data type. The deleter ensures that the data vector is properly 
144-  * managed and its lifetime is tied to the TensorImpl. 
167+  * This template overload is specialized for cases where tensor data is provided 
168+  * as a vector. If the specified `type` differs from the deduced type of the 
169+  * vector's elements, and casting is allowed, the data will be cast to the 
170+  * specified `type`. This allows for flexible creation of tensors with data 
171+  * vectors of one type and a different scalar type. 
145172 * 
146173 * @tparam T The C++ type of the tensor elements, deduced from the vector. 
147174 * @param data A vector containing the tensor's data. 
148-  * @param type The scalar type of the tensor elements. 
175+  * @param type The scalar type of the tensor elements. If it differs from the 
176+  * deduced type, the data will be cast to this type if allowed. 
149177 * @param dynamism Specifies the mutability of the tensor's shape. 
150178 * @return A TensorImplPtr that manages the newly created TensorImpl. 
151179 */  
@@ -157,7 +185,6 @@ inline TensorImplPtr make_tensor_impl_ptr(
157185    exec_aten::ScalarType type = deduced_type,
158186    exec_aten::TensorShapeDynamism dynamism =
159187        exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
160-   ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type."  );
161188  std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType (data.size ())};
162189  return  make_tensor_impl_ptr (
163190      std::move (sizes), std::move (data), {0 }, {1 }, type, dynamism);
@@ -168,17 +195,19 @@ inline TensorImplPtr make_tensor_impl_ptr(
168195 * specified properties. 
169196 * 
170197 * This template overload is specialized for cases where tensor data is provided 
171-  * as an initializer list. The scalar type is automatically deduced from the 
172-  * initializer list's data type. The deleter ensures that the data is properly 
173-  * managed, with its lifetime tied to the TensorImpl. 
198+  * as an initializer list. If the specified `type` differs from the deduced type 
199+  * of the initializer list's elements, and casting is allowed, the data will be 
200+  * cast to the specified `type`. This allows for flexible creation of tensors 
201+  * with data initializer list of one type and a different scalar type. 
174202 * 
175203 * @tparam T The C++ type of the tensor elements, deduced from the initializer 
176204 * list. 
177205 * @param sizes A vector specifying the size of each dimension. 
178206 * @param list An initializer list containing the tensor's data. 
179207 * @param dim_order A vector specifying the order of dimensions. 
180208 * @param strides A vector specifying the strides of each dimension. 
181-  * @param type The scalar type of the tensor elements. 
209+  * @param type The scalar type of the tensor elements. If it differs from the 
210+  * deduced type, the data will be cast to this type if allowed. 
182211 * @param dynamism Specifies the mutability of the tensor's shape. 
183212 * @return A TensorImplPtr that manages the newly created TensorImpl. 
184213 */  
@@ -193,34 +222,30 @@ inline TensorImplPtr make_tensor_impl_ptr(
193222    exec_aten::ScalarType type = deduced_type,
194223    exec_aten::TensorShapeDynamism dynamism =
195224        exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
196-   ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type."  );
197-   auto  data = std::vector<T>(std::move (list));
198-   const  auto  raw_data_ptr = data.data ();
199-   auto  data_ptr = std::make_shared<std::vector<T>>(std::move (data));
200225  return  make_tensor_impl_ptr (
201226      std::move (sizes),
202-       raw_data_ptr ,
227+       std::vector<T>( std::move (list)) ,
203228      std::move (dim_order),
204229      std::move (strides),
205230      type,
206-       dynamism,
207-       [data_ptr = std::move (data_ptr)](void *) {});
231+       dynamism);
208232}
209233
210234/* *
211235 * Creates a TensorImplPtr that manages a newly created TensorImpl with the 
212236 * specified properties. 
213237 * 
214-  * This template overload is specialized for cases where the tensor data is 
215-  * provided as an initializer list. The scalar type is automatically deduced 
216-  * from the initializer list's data type. The deleter ensures that the data is 
217-  * properly managed and its lifetime is tied to the TensorImpl. 
238+  * This template overload is specialized for cases where tensor data is provided 
239+  * as an initializer list. If the specified `type` differs from the deduced type 
240+  * of the initializer list's elements, and casting is allowed, the data will be 
241+  * cast to the specified `type`. This allows for flexible creation of tensors 
242+  * with data initializer list of one type and a different scalar type. 
218243 * 
219244 * @tparam T The C++ type of the tensor elements, deduced from the initializer 
220245 * list. 
221-  * @param sizes A vector specifying the size of each dimension. 
222246 * @param list An initializer list containing the tensor's data. 
223-  * @param type The scalar type of the tensor elements. 
247+  * @param type The scalar type of the tensor elements. If it differs from the 
248+  * deduced type, the data will be cast to this type if allowed. 
224249 * @param dynamism Specifies the mutability of the tensor's shape. 
225250 * @return A TensorImplPtr that manages the newly created TensorImpl. 
226251 */  
@@ -232,7 +257,6 @@ inline TensorImplPtr make_tensor_impl_ptr(
232257    exec_aten::ScalarType type = deduced_type,
233258    exec_aten::TensorShapeDynamism dynamism =
234259        exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
235-   ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type."  );
236260  std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType (list.size ())};
237261  return  make_tensor_impl_ptr (
238262      std::move (sizes), std::move (list), {0 }, {1 }, type, dynamism);
0 commit comments