Skip to content

Commit d543210

Browse files
authored
[webgpu] add util functions for creating tensor view (microsoft#24566)
### Description The added util functions can be used in 2 ways: - create a reshaped tensor from an existing one. - create a reinterpret view of a different type (will be useful in (u)int4/(u)int8 operator implementation)
1 parent 9d22547 commit d543210

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

onnxruntime/core/providers/webgpu/compute_context.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ class ComputeContext {
8888
//
8989
// Create CPU tensor.
9090
//
91+
// This method creates a tensor of the given data type and shape, using the CPU allocator.
92+
// The tensor owns the underlying CPU memory buffer.
93+
//
9194
template <typename TensorShapeType>
9295
Tensor CreateCPUTensor(MLDataType data_type, TensorShapeType&& shape) {
9396
AllocatorPtr allocator;
@@ -98,6 +101,9 @@ class ComputeContext {
98101
//
99102
// Create GPU tensor.
100103
//
104+
// This method creates a tensor of the given data type and shape, using the WebGPU allocator.
105+
// The tensor owns the underlying WebGPU storage buffer.
106+
//
101107
template <typename TensorShapeType>
102108
Tensor CreateGPUTensor(MLDataType data_type, TensorShapeType&& shape) {
103109
AllocatorPtr allocator;

onnxruntime/core/providers/webgpu/webgpu_utils.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@
55

66
#include <cstdint>
77
#include "core/common/common.h"
8+
#include "core/framework/tensor.h"
89
#include "core/framework/tensor_shape.h"
910

1011
namespace onnxruntime {
1112
namespace webgpu {
1213

14+
/**
15+
* Returns the maximum number of components `N` to be used as `vecN` for the given size.
16+
*/
1317
inline int GetMaxComponents(int64_t size) {
1418
if (size % 4 == 0) {
1519
return 4;
@@ -19,6 +23,11 @@ inline int GetMaxComponents(int64_t size) {
1923
return 1;
2024
}
2125

26+
/**
27+
* Returns a string representing a WGSL expression that sums the components of a value T.
28+
*
29+
* T can be a scalar S, vec2<S> or vec4<S>.
30+
*/
2231
inline std::string SumVector(std::string x, int components) {
2332
switch (components) {
2433
case 1:
@@ -49,5 +58,36 @@ inline std::string MakeScalarOrVectorType(int components, std::string_view data_
4958

5059
TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components);
5160

61+
/**
62+
* Create a reshaped tensor from an existing tensor.
63+
*
64+
* The specified new shape must have the same number of elements as the original tensor.
65+
*
66+
* The new tensor is a "view" of the original tensor. It uses the same data of the original tensor.
67+
* The new tensor does not take or share ownership of the underlying data. The original tensor must outlive the new tensor.
68+
*/
69+
inline Tensor CreateTensorView(const Tensor& tensor, const TensorShape& new_shape) {
70+
ORT_ENFORCE(tensor.Shape().Size() == new_shape.Size(), "Cannot reshape tensor ", tensor.Shape().ToString(), " to ", new_shape.ToString());
71+
return {tensor.DataType(), new_shape, const_cast<void*>(tensor.DataRaw()), tensor.Location()};
72+
}
73+
74+
/**
75+
* Create a reinterpreted tensor from an existing tensor with a new data type and shape.
76+
*
77+
* The new data type and shape must match the original tensor's storage size.
78+
*
79+
* The new tensor is a "view" of the original tensor. It uses the same data of the original tensor.
80+
* The new tensor does not take or share ownership of the underlying data. The original tensor must outlive the new tensor.
81+
*/
82+
inline Tensor CreateTensorView(const Tensor& tensor, MLDataType new_data_type, const TensorShape& new_shape) {
83+
auto byte_size = Tensor::CalculateTensorStorageSize(tensor.DataType(), tensor.Shape());
84+
auto new_byte_size = Tensor::CalculateTensorStorageSize(new_data_type, new_shape);
85+
ORT_ENFORCE(byte_size == new_byte_size,
86+
"Cannot reshape tensor ", tensor.Shape().ToString(), " to ", new_shape.ToString(),
87+
" with data type ", DataTypeImpl::ToString(new_data_type), ". The byte size of the original tensor is ",
88+
byte_size, " and the byte size of the new tensor is ", new_byte_size);
89+
return {new_data_type, new_shape, const_cast<void*>(tensor.DataRaw()), tensor.Location()};
90+
}
91+
5292
} // namespace webgpu
5393
} // namespace onnxruntime

0 commit comments

Comments
 (0)