Skip to content

Commit f83e661

Browse files
authored
webgpu support for DequantizeLinear (microsoft#24268)
webgpu support for DequantizeLinear
1 parent 04e0b50 commit f83e661

File tree

8 files changed

+353
-27
lines changed

8 files changed

+353
-27
lines changed

onnxruntime/core/providers/webgpu/program.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ constexpr std::string_view ProgramVariableDataTypeName[] = {
102102
"u8x4", // Uint8x4
103103
"u8x8", // Uint8x8
104104
"u8x16", // Uint8x16
105+
"i8x4", // Int8x4
106+
"i8x8", // Int8x8
107+
"i8x16", // Int8x16
105108
};
106109
std::ostream& operator<<(std::ostream& os, ProgramVariableDataType type) {
107110
os << ProgramVariableDataTypeName[std::underlying_type<decltype(type)>::type(type)];
@@ -129,6 +132,7 @@ int NumberOfComponents(ProgramVariableDataType type) {
129132
case ProgramVariableDataType::Float16x4:
130133
case ProgramVariableDataType::Boolx4:
131134
case ProgramVariableDataType::Uint8x4:
135+
case ProgramVariableDataType::Int8x4:
132136
return 4;
133137
case ProgramVariableDataType::Uint8x8:
134138
return 8;
@@ -142,6 +146,10 @@ int NumberOfComponents(ProgramVariableDataType type) {
142146
ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component /* = 1 */) {
143147
if (component == 1) {
144148
switch (element_type) {
149+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
150+
return ProgramVariableDataType::Uint8x4; // shader needs to be aware that only 1 value is valid
151+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
152+
return ProgramVariableDataType::Int8x4; // shader needs to be aware that only 1 value is valid
145153
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
146154
return ProgramVariableDataType::Float32;
147155
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
@@ -174,6 +182,8 @@ ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int comp
174182
switch (element_type) {
175183
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
176184
return ProgramVariableDataType::Uint8x4;
185+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
186+
return ProgramVariableDataType::Int8x4;
177187
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
178188
return ProgramVariableDataType::Float32x4;
179189
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:

onnxruntime/core/providers/webgpu/program.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,10 @@ enum class ProgramVariableDataType {
197197
Boolx4,
198198
Uint8x4,
199199
Uint8x8,
200-
Uint8x16
200+
Uint8x16,
201+
Int8x4,
202+
Int8x8,
203+
Int8x16,
201204
};
202205
#ifndef NDEBUG
203206
std::ostream& operator<<(std::ostream& os, ProgramVariableDataType);
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include <vector>
5+
6+
#include "core/util/math.h"
7+
#include "core/providers/webgpu/quantization/quantize_linear.h"
8+
#include "core/providers/webgpu/shader_helper.h"
9+
#include "core/providers/webgpu/webgpu_supported_types.h"
10+
#include "core/providers/webgpu/webgpu_utils.h"
11+
12+
namespace onnxruntime {
13+
namespace webgpu {
14+
15+
Status DequantizeLinearProgram::GenerateShaderCode(ShaderHelper& shader) const {
16+
const auto& x = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseElementTypeAlias);
17+
const auto& scale = shader.AddInput("scale", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
18+
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseValueTypeAlias);
19+
20+
shader.MainFunctionBody()
21+
<< shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
22+
<< "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n";
23+
24+
// Get x input
25+
if (packed_) {
26+
std::string unpack = (signed_) ? "unpack4xI8(x)" : "unpack4xU8(x)";
27+
if (output.NumComponents() == 1) {
28+
shader.MainFunctionBody()
29+
<< "let x = " << x.GetByOffset("global_idx / 4") << ";\n"
30+
<< "let x_vec = " << unpack << ";\n"
31+
<< "let x_value = x_vec[global_idx % 4];\n";
32+
} else {
33+
shader.MainFunctionBody()
34+
<< "let x = " << x.GetByOffset("global_idx") << ";\n"
35+
<< "let x_vec = " << unpack << ";\n"
36+
<< "let x_value = x_vec;\n";
37+
}
38+
} else {
39+
shader.MainFunctionBody()
40+
<< "let x_value = " << x.GetByOffset("global_idx") << ";\n";
41+
}
42+
43+
// Get scaler
44+
if (per_layer_) {
45+
// scale input is a scalar ()
46+
shader.MainFunctionBody()
47+
<< "let scale_value = " << scale.GetByOffset("0") << ";\n";
48+
} else if (per_axis_) {
49+
shader.MainFunctionBody()
50+
<< "let scale_index = " << output.IndicesGet("output_indices", "uniforms.axis") << ";\n"
51+
<< "let scale_value = " << scale.GetByOffset("scale_index") << ";\n";
52+
} else {
53+
// Block quantization. Scale input rank is same as input/output rank.
54+
shader.MainFunctionBody()
55+
<< "var scale_indices: scale_indices_t = output_indices;\n"
56+
<< "let index = " << scale.IndicesGet("scale_indices", "uniforms.axis") << "/ uniforms.block_size;\n"
57+
<< scale.IndicesSet("scale_indices", "uniforms.axis", "index") << ";\n"
58+
<< "let scale_value = " << scale.GetByIndices("scale_indices") << ";\n";
59+
}
60+
61+
// Get zero-point
62+
if (has_zeropoint_) {
63+
const auto& zero_point = shader.AddInput("zero_point", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
64+
65+
std::string unpack = (signed_) ? "unpack4xI8(zero_point_input)" : "unpack4xU8(zero_point_input)";
66+
if (per_layer_) {
67+
// zero-point input is a scalar
68+
if (packed_) {
69+
shader.MainFunctionBody()
70+
<< "let zero_point_input = " << zero_point.GetByOffset("0") << ";\n"
71+
<< "let zero_point_vec = " << unpack << ";\n"
72+
<< "let zero_point_value = zero_point_vec[0];\n";
73+
} else {
74+
shader.MainFunctionBody()
75+
<< "let zero_point_value = " << zero_point.GetByOffset("0") << ";\n";
76+
}
77+
} else if (per_axis_) {
78+
// zero-point input is a 1D tensor
79+
if (packed_) {
80+
shader.MainFunctionBody()
81+
<< "let zero_point_index = " << output.IndicesGet("output_indices", "uniforms.axis") << ";\n"
82+
<< "let zero_point_input = " << zero_point.GetByOffset("zero_point_index / 4") << ";\n"
83+
<< "let zero_point_vec = " << unpack << ";\n"
84+
<< "let zero_point_value = zero_point_vec[zero_point_index % 4];\n";
85+
} else {
86+
shader.MainFunctionBody()
87+
<< "let zero_point_index = " << output.IndicesGet("output_indices", "uniforms.axis") << ";\n"
88+
<< "let zero_point_value = " << zero_point.GetByOffset("zero_point_index") << ";\n";
89+
}
90+
} else {
91+
// BlockedQuantization. The zero-point input shape is same as the input shape except along axis.
92+
if (packed_) {
93+
shader.MainFunctionBody()
94+
<< "let zero_point_offset = " << scale.GetByIndices("scale_indices") << ";\n"
95+
<< "let zero_point_input = " << zero_point.GetByOffset("zero_point_offset / 4") << ";\n"
96+
<< "let zero_point_vec = " << unpack << ";\n"
97+
<< "let zero_point_value = zero_point_vec[zero_point_offset % 4];\n";
98+
} else {
99+
shader.MainFunctionBody()
100+
<< "let zero_point_value = " << zero_point.GetByIndices("scale_indices") << ";\n";
101+
}
102+
}
103+
} else {
104+
shader.MainFunctionBody()
105+
<< "let zero_point_value = input_element_t(0);\n";
106+
}
107+
108+
// compute and write output
109+
shader.MainFunctionBody()
110+
<< output.SetByOffset("global_idx", "(output_value_t(x_value) - scale_value_t(zero_point_value)) * scale_value");
111+
112+
return Status::OK();
113+
}
114+
115+
Status DequantizeLinear::ComputeInternal(ComputeContext& context) const {
116+
const auto* x = context.Input(0);
117+
const auto* x_scale = context.Input(1);
118+
const auto* x_zeropoint = context.Input(2);
119+
const auto x_shape = x->Shape();
120+
int64_t x_size = x_shape.Size();
121+
auto* output_tensor = context.Output(0, x_shape);
122+
int64_t x_scale_rank = x_scale->Shape().NumDimensions();
123+
124+
bool packed = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
125+
bool is_signed = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
126+
int64_t axis = (axis_ >= 0) ? axis_ : axis_ + x_shape.NumDimensions();
127+
128+
int max_components = GetMaxComponents(x_size);
129+
if (max_components != 4) {
130+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "DequantizeLinear: components must be 4, but got ", max_components);
131+
}
132+
133+
// scaler - single scaler for all elements
134+
bool per_layer = x_scale_rank == 0 || (x_scale_rank == 1 && x_scale->Shape()[0] == 1);
135+
136+
// 1D tensor - 1 scaler for per axis
137+
bool per_axis = per_layer == false && x_scale_rank == 1;
138+
139+
bool use_components = per_layer && (!packed || max_components == 4);
140+
int components = use_components ? max_components : 1;
141+
int input_component = use_components && !packed ? max_components : 1;
142+
143+
DequantizeLinearProgram program{packed, is_signed, per_layer, per_axis, x_zeropoint != nullptr};
144+
145+
program
146+
.AddInputs({{x, ProgramTensorMetadataDependency::TypeAndRank, input_component}})
147+
.AddInputs({{x_scale, ProgramTensorMetadataDependency::TypeAndRank}})
148+
.AddOutput({output_tensor, ProgramTensorMetadataDependency::None, components})
149+
.SetDispatchGroupSize((x_size / components + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
150+
.AddUniformVariables({{static_cast<uint32_t>(axis)}})
151+
.AddUniformVariables({{static_cast<uint32_t>(block_size_)}})
152+
.AddUniformVariables({{static_cast<uint32_t>(x_size / components)}})
153+
.CacheHint(std::to_string(axis), std::to_string(is_signed), std::to_string(per_layer), std::to_string(per_axis), std::to_string(block_size_));
154+
155+
if (x_zeropoint != nullptr) {
156+
program.AddInputs({{x_zeropoint, ProgramTensorMetadataDependency::TypeAndRank}});
157+
}
158+
159+
return context.RunProgram(program);
160+
}
161+
162+
namespace {
163+
const std::vector<MLDataType>& DequantizeLinearConstraints() {
164+
static std::vector<MLDataType> types{
165+
DataTypeImpl::GetTensorType<int8_t>(),
166+
DataTypeImpl::GetTensorType<uint8_t>(),
167+
DataTypeImpl::GetTensorType<int32_t>()};
168+
return types;
169+
}
170+
} // namespace
171+
172+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
173+
DequantizeLinear,
174+
kOnnxDomain,
175+
10, 12,
176+
kWebGpuExecutionProvider,
177+
(*KernelDefBuilder::Create())
178+
.TypeConstraint("T", DequantizeLinearConstraints()),
179+
DequantizeLinear);
180+
181+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
182+
DequantizeLinear,
183+
kOnnxDomain,
184+
13, 18,
185+
kWebGpuExecutionProvider,
186+
(*KernelDefBuilder::Create())
187+
.TypeConstraint("T", DequantizeLinearConstraints()),
188+
DequantizeLinear);
189+
190+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
191+
DequantizeLinear,
192+
kOnnxDomain,
193+
19, 20,
194+
kWebGpuExecutionProvider,
195+
(*KernelDefBuilder::Create())
196+
.TypeConstraint("T1", DequantizeLinearConstraints())
197+
.TypeConstraint("T2", WebGpuSupportedFloatTypes()),
198+
DequantizeLinear);
199+
200+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
201+
DequantizeLinear,
202+
kOnnxDomain,
203+
21, 22,
204+
kWebGpuExecutionProvider,
205+
(*KernelDefBuilder::Create())
206+
.TypeConstraint("T1", DequantizeLinearConstraints())
207+
.TypeConstraint("T2", WebGpuSupportedFloatTypes()),
208+
DequantizeLinear);
209+
210+
ONNX_OPERATOR_KERNEL_EX(
211+
DequantizeLinear,
212+
kOnnxDomain,
213+
23,
214+
kWebGpuExecutionProvider,
215+
(*KernelDefBuilder::Create())
216+
.TypeConstraint("T1", DequantizeLinearConstraints())
217+
.TypeConstraint("T2", WebGpuSupportedFloatTypes()),
218+
DequantizeLinear);
219+
220+
} // namespace webgpu
221+
} // namespace onnxruntime
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/webgpu/webgpu_kernel.h"
7+
8+
namespace onnxruntime {
9+
namespace webgpu {
10+
11+
class DequantizeLinearProgram final : public Program<DequantizeLinearProgram> {
12+
public:
13+
DequantizeLinearProgram(const bool packed, const bool issigned, const bool per_layer,
14+
const bool per_axis, bool has_zeropoint) : Program<DequantizeLinearProgram>{"DequantizeLinear"},
15+
packed_{packed},
16+
signed_{issigned},
17+
per_layer_{per_layer},
18+
per_axis_{per_axis},
19+
has_zeropoint_{has_zeropoint} {}
20+
21+
Status GenerateShaderCode(ShaderHelper& sh) const override;
22+
23+
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"axis", ProgramUniformVariableDataType::Uint32},
24+
{"block_size", ProgramUniformVariableDataType::Uint32},
25+
{"output_size", ProgramUniformVariableDataType::Uint32});
26+
27+
private:
28+
bool packed_;
29+
bool signed_;
30+
bool per_layer_;
31+
bool per_axis_;
32+
bool has_zeropoint_;
33+
};
34+
35+
class DequantizeLinear final : public WebGpuKernel {
36+
public:
37+
DequantizeLinear(const OpKernelInfo& info) : WebGpuKernel(info) {
38+
axis_ = info.GetAttrOrDefault<int64_t>("axis", 1);
39+
block_size_ = info.GetAttrOrDefault<int64_t>("block_size", 0);
40+
output_dtype_ = info.GetAttrOrDefault<int64_t>("output_dtype", 0);
41+
}
42+
43+
Status ComputeInternal(ComputeContext& context) const override;
44+
45+
private:
46+
int64_t axis_;
47+
int64_t block_size_;
48+
int64_t output_dtype_;
49+
};
50+
51+
} // namespace webgpu
52+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/shader_helper.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,12 @@ Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType va
168168
var_type == ProgramVariableDataType::Uint8x16,
169169
"Unexpected program variable type ", int(var_type), " for uint8 tensor");
170170
break;
171+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
172+
ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int8x4 ||
173+
var_type == ProgramVariableDataType::Int8x8 ||
174+
var_type == ProgramVariableDataType::Int8x16,
175+
"Unexpected program variable type ", int(var_type), " for int8 tensor");
176+
break;
171177
default:
172178
ORT_RETURN_IF(true, "Unsupported data type: ", element_type);
173179
// todo: add int4/uint4

onnxruntime/core/providers/webgpu/shader_variable.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ constexpr static const std::string_view STORAGE_TYPE_ARRAY[] = {
3232
"u32", // Uint8x4
3333
"vec2<u32>", // Uint8x8
3434
"vec4<u32>", // Uint8x16
35+
"u32", // Int8x4
3536
};
3637
constexpr static const auto STORAGE_TYPE = details::_to_std_array(STORAGE_TYPE_ARRAY);
3738

@@ -54,6 +55,7 @@ constexpr static const std::string_view VALUE_TYPE_ARRAY[] = {
5455
"u32", // Uint8x4 (u32 as 4 elements of uint8)
5556
"vec2<u32>", // Uint8x8 (vec2<u32> as 2x4 elements of uint8)
5657
"vec4<u32>", // Uint8x16 (vec4<u32> as 4x4 elements of uint8)
58+
"i32", // Int8x4
5759
};
5860
constexpr static const auto VALUE_TYPE = details::_to_std_array(VALUE_TYPE_ARRAY);
5961

@@ -76,6 +78,9 @@ constexpr static const std::string_view ELEMENT_TYPE_ARRAY[] = {
7678
"u32", // Uint8x4
7779
"u32", // Uint8x8
7880
"u32", // Uint8x16
81+
"i32", // Int8x4
82+
"i32", // Int8x8
83+
"i32", // Int8x16
7984
};
8085
constexpr static const auto ELEMENT_TYPE = details::_to_std_array(ELEMENT_TYPE_ARRAY);
8186

0 commit comments

Comments
 (0)