@@ -97,31 +97,57 @@ inline TensorImplPtr make_tensor_impl_ptr(
9797 * specified properties.
9898 *
9999 * This template overload is specialized for cases where tensor data is provided
100- * as a vector. The scalar type is automatically deduced from the vector's data
101- * type. The deleter ensures that the data vector is properly managed, with its
102- * lifetime tied to the TensorImpl.
100+ * as a vector. If the specified `type` differs from the deduced type of the
101+ * vector's elements, and casting is allowed, the data will be cast to the
102+ * specified `type`. This allows for flexible creation of tensors with data
103+ * vectors of one type and a different scalar type.
103104 *
104105 * @tparam T The C++ type of the tensor elements, deduced from the vector.
105106 * @param sizes A vector specifying the size of each dimension.
106107 * @param data A vector containing the tensor's data.
107108 * @param dim_order A vector specifying the order of dimensions.
108109 * @param strides A vector specifying the strides of each dimension.
109- * @param type The scalar type of the tensor elements.
110+ * @param type The scalar type of the tensor elements. If it differs from the
111+ * deduced type, the data will be cast to this type if allowed.
110112 * @param dynamism Specifies the mutability of the tensor's shape.
111113 * @return A TensorImplPtr that manages the newly created TensorImpl.
112114 */
113115template <
114116 typename T = float ,
115117 exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
116- inline TensorImplPtr make_tensor_impl_ptr (
118+ TensorImplPtr make_tensor_impl_ptr (
117119 std::vector<exec_aten::SizesType> sizes,
118120 std::vector<T> data,
119121 std::vector<exec_aten::DimOrderType> dim_order = {},
120122 std::vector<exec_aten::StridesType> strides = {},
121123 exec_aten::ScalarType type = deduced_type,
122124 exec_aten::TensorShapeDynamism dynamism =
123125 exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
124- ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
126+ if (type != deduced_type) {
127+ ET_CHECK_MSG (
128+ runtime::canCast (deduced_type, type),
129+ " Cannot cast deduced type to specified type." );
130+ std::vector<uint8_t > casted_data (data.size () * runtime::elementSize (type));
131+ ET_SWITCH_REALHBBF16_TYPES (
132+ type, nullptr , " make_tensor_impl_ptr" , CTYPE, [&] {
133+ std::transform (
134+ data.begin (),
135+ data.end (),
136+ reinterpret_cast <CTYPE*>(casted_data.data ()),
137+ [](const T& val) { return static_cast <CTYPE>(val); });
138+ });
139+ const auto raw_data_ptr = casted_data.data ();
140+ auto data_ptr =
141+ std::make_shared<std::vector<uint8_t >>(std::move (casted_data));
142+ return make_tensor_impl_ptr (
143+ std::move (sizes),
144+ raw_data_ptr,
145+ std::move (dim_order),
146+ std::move (strides),
147+ type,
148+ dynamism,
149+ [data_ptr = std::move (data_ptr)](void *) {});
150+ }
125151 const auto raw_data_ptr = data.data ();
126152 auto data_ptr = std::make_shared<std::vector<T>>(std::move (data));
127153 return make_tensor_impl_ptr (
@@ -138,14 +164,16 @@ inline TensorImplPtr make_tensor_impl_ptr(
138164 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
139165 * specified properties.
140166 *
141- * This template overload is specialized for cases where the tensor data is
142- * provided as a vector. The scalar type is automatically deduced from the
143- * vector's data type. The deleter ensures that the data vector is properly
144- * managed and its lifetime is tied to the TensorImpl.
167+ * This template overload is specialized for cases where tensor data is provided
168+ * as a vector. If the specified `type` differs from the deduced type of the
169+ * vector's elements, and casting is allowed, the data will be cast to the
170+ * specified `type`. This allows for flexible creation of tensors with data
171+ * vectors of one type and a different scalar type.
145172 *
146173 * @tparam T The C++ type of the tensor elements, deduced from the vector.
147174 * @param data A vector containing the tensor's data.
148- * @param type The scalar type of the tensor elements.
175+ * @param type The scalar type of the tensor elements. If it differs from the
176+ * deduced type, the data will be cast to this type if allowed.
149177 * @param dynamism Specifies the mutability of the tensor's shape.
150178 * @return A TensorImplPtr that manages the newly created TensorImpl.
151179 */
@@ -157,7 +185,6 @@ inline TensorImplPtr make_tensor_impl_ptr(
157185 exec_aten::ScalarType type = deduced_type,
158186 exec_aten::TensorShapeDynamism dynamism =
159187 exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
160- ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
161188 std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType (data.size ())};
162189 return make_tensor_impl_ptr (
163190 std::move (sizes), std::move (data), {0 }, {1 }, type, dynamism);
@@ -168,17 +195,19 @@ inline TensorImplPtr make_tensor_impl_ptr(
168195 * specified properties.
169196 *
170197 * This template overload is specialized for cases where tensor data is provided
171- * as an initializer list. The scalar type is automatically deduced from the
172- * initializer list's data type. The deleter ensures that the data is properly
173- * managed, with its lifetime tied to the TensorImpl.
198+ * as an initializer list. If the specified `type` differs from the deduced type
199+ * of the initializer list's elements, and casting is allowed, the data will be
200+ * cast to the specified `type`. This allows for flexible creation of tensors
201+ * with data initializer list of one type and a different scalar type.
174202 *
175203 * @tparam T The C++ type of the tensor elements, deduced from the initializer
176204 * list.
177205 * @param sizes A vector specifying the size of each dimension.
178206 * @param list An initializer list containing the tensor's data.
179207 * @param dim_order A vector specifying the order of dimensions.
180208 * @param strides A vector specifying the strides of each dimension.
181- * @param type The scalar type of the tensor elements.
209+ * @param type The scalar type of the tensor elements. If it differs from the
210+ * deduced type, the data will be cast to this type if allowed.
182211 * @param dynamism Specifies the mutability of the tensor's shape.
183212 * @return A TensorImplPtr that manages the newly created TensorImpl.
184213 */
@@ -193,34 +222,30 @@ inline TensorImplPtr make_tensor_impl_ptr(
193222 exec_aten::ScalarType type = deduced_type,
194223 exec_aten::TensorShapeDynamism dynamism =
195224 exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
196- ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
197- auto data = std::vector<T>(std::move (list));
198- const auto raw_data_ptr = data.data ();
199- auto data_ptr = std::make_shared<std::vector<T>>(std::move (data));
200225 return make_tensor_impl_ptr (
201226 std::move (sizes),
202- raw_data_ptr ,
227+ std::vector<T>( std::move (list)) ,
203228 std::move (dim_order),
204229 std::move (strides),
205230 type,
206- dynamism,
207- [data_ptr = std::move (data_ptr)](void *) {});
231+ dynamism);
208232}
209233
210234/* *
211235 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
212236 * specified properties.
213237 *
214- * This template overload is specialized for cases where the tensor data is
215- * provided as an initializer list. The scalar type is automatically deduced
216- * from the initializer list's data type. The deleter ensures that the data is
217- * properly managed and its lifetime is tied to the TensorImpl.
238+ * This template overload is specialized for cases where tensor data is provided
239+ * as an initializer list. If the specified `type` differs from the deduced type
240+ * of the initializer list's elements, and casting is allowed, the data will be
241+ * cast to the specified `type`. This allows for flexible creation of tensors
242+ * with data initializer list of one type and a different scalar type.
218243 *
219244 * @tparam T The C++ type of the tensor elements, deduced from the initializer
220245 * list.
221- * @param sizes A vector specifying the size of each dimension.
222246 * @param list An initializer list containing the tensor's data.
223- * @param type The scalar type of the tensor elements.
247+ * @param type The scalar type of the tensor elements. If it differs from the
248+ * deduced type, the data will be cast to this type if allowed.
224249 * @param dynamism Specifies the mutability of the tensor's shape.
225250 * @return A TensorImplPtr that manages the newly created TensorImpl.
226251 */
@@ -232,7 +257,6 @@ inline TensorImplPtr make_tensor_impl_ptr(
232257 exec_aten::ScalarType type = deduced_type,
233258 exec_aten::TensorShapeDynamism dynamism =
234259 exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
235- ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
236260 std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType (list.size ())};
237261 return make_tensor_impl_ptr (
238262 std::move (sizes), std::move (list), {0 }, {1 }, type, dynamism);
0 commit comments