55
66#include < cstdint>
77#include " core/common/common.h"
8+ #include " core/framework/tensor.h"
89#include " core/framework/tensor_shape.h"
910
1011namespace onnxruntime {
1112namespace webgpu {
1213
14+ /* *
15+ * Returns the maximum number of components `N` to be used as `vecN` for the given size.
16+ */
1317inline int GetMaxComponents (int64_t size) {
1418 if (size % 4 == 0 ) {
1519 return 4 ;
@@ -19,6 +23,11 @@ inline int GetMaxComponents(int64_t size) {
1923 return 1 ;
2024}
2125
26+ /* *
27+ * Returns a string representing a WGSL expression that sums the components of a value T.
28+ *
29+ * T can be a scalar S, vec2<S> or vec4<S>.
30+ */
2231inline std::string SumVector (std::string x, int components) {
2332 switch (components) {
2433 case 1 :
@@ -49,5 +58,36 @@ inline std::string MakeScalarOrVectorType(int components, std::string_view data_
4958
5059TensorShape ReduceShapeByComponents (const TensorShape& shape, int64_t components);
5160
61+ /* *
62+ * Create a reshaped tensor from an existing tensor.
63+ *
64+ * The specified new shape must have the same number of elements as the original tensor.
65+ *
66+ * The new tensor is a "view" of the original tensor. It uses the same data of the original tensor.
67+ * The new tensor does not take or share ownership of the underlying data. The original tensor must outlive the new tensor.
68+ */
69+ inline Tensor CreateTensorView (const Tensor& tensor, const TensorShape& new_shape) {
70+ ORT_ENFORCE (tensor.Shape ().Size () == new_shape.Size (), " Cannot reshape tensor " , tensor.Shape ().ToString (), " to " , new_shape.ToString ());
71+ return {tensor.DataType (), new_shape, const_cast <void *>(tensor.DataRaw ()), tensor.Location ()};
72+ }
73+
74+ /* *
75+ * Create a reinterpreted tensor from an existing tensor with a new data type and shape.
76+ *
77+ * The new data type and shape must match the original tensor's storage size.
78+ *
79+ * The new tensor is a "view" of the original tensor. It uses the same data of the original tensor.
80+ * The new tensor does not take or share ownership of the underlying data. The original tensor must outlive the new tensor.
81+ */
82+ inline Tensor CreateTensorView (const Tensor& tensor, MLDataType new_data_type, const TensorShape& new_shape) {
83+ auto byte_size = Tensor::CalculateTensorStorageSize (tensor.DataType (), tensor.Shape ());
84+ auto new_byte_size = Tensor::CalculateTensorStorageSize (new_data_type, new_shape);
85+ ORT_ENFORCE (byte_size == new_byte_size,
86+ " Cannot reshape tensor " , tensor.Shape ().ToString (), " to " , new_shape.ToString (),
87+ " with data type " , DataTypeImpl::ToString (new_data_type), " . The byte size of the original tensor is " ,
88+ byte_size, " and the byte size of the new tensor is " , new_byte_size);
89+ return {new_data_type, new_shape, const_cast <void *>(tensor.DataRaw ()), tensor.Location ()};
90+ }
91+
5292} // namespace webgpu
5393} // namespace onnxruntime
0 commit comments