Skip to content

Commit 913573f

Browse files
fs-eiregedoensmax
authored andcommitted
[webgpu] support bool for binary operators (microsoft#25674)
### Description Currently boolean types are not supported as inputs of binary operators in WebGPU. This change adds the support. ### Motivation and Context In WebGPU, `bool` is not a valid type for storage. To make it work with a storage buffer, we have to store a `u32` value representing 4 bool values. To make it work with the existing WebGPU framework, we need to ensure all modes of the binary operator program always use components == 4 for bool.
1 parent 932e175 commit 913573f

File tree

3 files changed

+77
-32
lines changed

3 files changed

+77
-32
lines changed

onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc

Lines changed: 69 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const
1414
const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
1515
const auto& c = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
1616

17+
const bool a_is_bool = Inputs()[0].var_type == ProgramVariableDataType::Boolx4;
18+
const bool b_is_bool = Inputs()[1].var_type == ProgramVariableDataType::Boolx4;
19+
1720
shader.AdditionalImplementation() << additional_impl_;
1821

1922
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size");
@@ -37,58 +40,78 @@ Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const
3740
}
3841
} else {
3942
const auto& c_indices = shader.AddIndices("bcast_indices");
43+
// Use indices helpers to calculate the offset of A and B.
44+
const auto& a_indices = shader.AddIndices("a_indices");
45+
const auto& b_indices = shader.AddIndices("b_indices");
46+
4047
// check whether can use vectorize mode.
4148
// If either last dimension of A or B is divisible by 4, or the shared dimension is divisible by 4, vectorize mode
4249
// can be enabled.
4350
// In vectorize mode, the source data of A and B will be loaded only once to calculate 4 output values.
44-
// Use indices helpers to calculate the offset of A and B.
4551
if (vectorize_) {
46-
const auto& a_indices = shader.AddIndices("a_indices");
47-
const auto& b_indices = shader.AddIndices("b_indices");
48-
4952
shader.MainFunctionBody() << "let outputIndices = " << c_indices.OffsetToIndices("global_idx * 4") << ";\n"
5053
<< "let offset_a = " << a_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"
5154
<< "let offset_b = " << b_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n";
5255
// get A data
53-
if (a.NumComponents() == 4) {
56+
if (is_lhs_use_4_components_) {
5457
shader.MainFunctionBody() << "let a = " << a.GetByOffset("offset_a / 4") << ";\n";
58+
} else if (a_is_bool) {
59+
shader.MainFunctionBody() << "let a = " << a.GetByOffset("offset_a / 4") << "[offset_a % 4];\n";
5560
} else {
5661
shader.MainFunctionBody() << "let a = input_a_value_t(" << a.GetByOffset("offset_a") << ");\n";
5762
}
5863

5964
// get B data
60-
if (b.NumComponents() == 4) {
65+
if (is_rhs_use_4_components_) {
6166
shader.MainFunctionBody() << "let b = " << b.GetByOffset("offset_b / 4") << ";\n";
67+
} else if (b_is_bool) {
68+
shader.MainFunctionBody() << "let b = " << b.GetByOffset("offset_b / 4") << "[offset_b % 4];\n";
6269
} else {
6370
shader.MainFunctionBody() << "let b = input_b_value_t(" << b.GetByOffset("offset_b") << ");\n";
6471
}
6572
} else {
6673
// In broadcast mode, each element of the vec4 value of A and B will be loaded separately to calculate the output value.
6774
shader.MainFunctionBody() << "var outputIndices = " << c_indices.OffsetToIndices("global_idx * 4") << ";\n"
68-
<< "let offset_a0 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"
69-
<< "let offset_b0 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"
75+
<< "let offset_a0 = " << a_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"
76+
<< "let offset_b0 = " << b_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"
7077
<< "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 1") << ";\n"
71-
<< "let offset_a1 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"
72-
<< "let offset_b1 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"
78+
<< "let offset_a1 = " << a_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"
79+
<< "let offset_b1 = " << b_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"
7380
<< "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 2") << ";\n"
74-
<< "let offset_a2 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"
75-
<< "let offset_b2 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"
81+
<< "let offset_a2 = " << a_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"
82+
<< "let offset_b2 = " << b_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"
7683
<< "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 3") << ";\n"
77-
<< "let offset_a3 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"
78-
<< "let offset_b3 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n";
84+
<< "let offset_a3 = " << a_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"
85+
<< "let offset_b3 = " << b_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n";
7986

8087
// get A data
81-
shader.MainFunctionBody() << "let a = vec4<input_a_value_t>("
82-
<< a.GetByOffset("offset_a0") << ", "
83-
<< a.GetByOffset("offset_a1") << ", "
84-
<< a.GetByOffset("offset_a2") << ", "
85-
<< a.GetByOffset("offset_a3") << ");\n";
88+
if (a_is_bool) {
89+
shader.MainFunctionBody() << "let a = vec4<bool>("
90+
<< a.GetByOffset("offset_a0 / 4") << "[offset_a0 % 4], "
91+
<< a.GetByOffset("offset_a1 / 4") << "[offset_a1 % 4], "
92+
<< a.GetByOffset("offset_a2 / 4") << "[offset_a2 % 4], "
93+
<< a.GetByOffset("offset_a3 / 4") << "[offset_a3 % 4]);\n";
94+
} else {
95+
shader.MainFunctionBody() << "let a = vec4<input_a_value_t>("
96+
<< a.GetByOffset("offset_a0") << ", "
97+
<< a.GetByOffset("offset_a1") << ", "
98+
<< a.GetByOffset("offset_a2") << ", "
99+
<< a.GetByOffset("offset_a3") << ");\n";
100+
}
86101
// get B data
87-
shader.MainFunctionBody() << "let b = vec4<input_b_value_t>("
88-
<< b.GetByOffset("offset_b0") << ", "
89-
<< b.GetByOffset("offset_b1") << ", "
90-
<< b.GetByOffset("offset_b2") << ", "
91-
<< b.GetByOffset("offset_b3") << ");\n";
102+
if (b_is_bool) {
103+
shader.MainFunctionBody() << "let b = vec4<bool>("
104+
<< b.GetByOffset("offset_b0 / 4") << "[offset_b0 % 4], "
105+
<< b.GetByOffset("offset_b1 / 4") << "[offset_b1 % 4], "
106+
<< b.GetByOffset("offset_b2 / 4") << "[offset_b2 % 4], "
107+
<< b.GetByOffset("offset_b3 / 4") << "[offset_b3 % 4]);\n";
108+
} else {
109+
shader.MainFunctionBody() << "let b = vec4<input_b_value_t>("
110+
<< b.GetByOffset("offset_b0") << ", "
111+
<< b.GetByOffset("offset_b1") << ", "
112+
<< b.GetByOffset("offset_b2") << ", "
113+
<< b.GetByOffset("offset_b3") << ");\n";
114+
}
92115
}
93116
}
94117

@@ -114,6 +137,12 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
114137
bool is_lhs_scalar = lhs_shape.IsScalar();
115138
bool is_rhs_scalar = rhs_shape.IsScalar();
116139

140+
// Check if either input is boolean
141+
// For boolean inputs, we need to handle them differently in the shader. This is because `bool` is not a valid type in
142+
// storage buffer. We have to use a `u32` to represent 4 boolean values.
143+
bool is_lhs_bool = lhs_tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
144+
bool is_rhs_bool = rhs_tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
145+
117146
bool vectorize = is_lhs_scalar || is_rhs_scalar || !is_broadcast;
118147
bool a_last_dim_divisible_by_4 = false;
119148
bool b_last_dim_divisible_by_4 = false;
@@ -157,6 +186,8 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
157186
is_broadcast,
158187
is_lhs_scalar,
159188
is_rhs_scalar,
189+
shared_dimension_divisible_by_4 || a_last_dim_divisible_by_4,
190+
shared_dimension_divisible_by_4 || b_last_dim_divisible_by_4,
160191
vectorize};
161192
program
162193
.SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
@@ -169,8 +200,8 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
169200
// Mode Element-wise
170201
// cache hint: "E{is_a_scalar}{is_b_scalar}"
171202
program
172-
.AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::Type, {is_lhs_scalar ? 1 : vec_size}, 4},
173-
{rhs_tensor, ProgramTensorMetadataDependency::Type, {is_rhs_scalar ? 1 : vec_size}, 4}})
203+
.AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4},
204+
{rhs_tensor, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4}})
174205
.CacheHint("E" + std::to_string(is_lhs_scalar) + std::to_string(is_rhs_scalar));
175206
} else if (vectorize) {
176207
// reshape the dims to merge the shared dimension if available
@@ -187,13 +218,13 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
187218
reshaped_output_shape[reshaped_output_shape.NumDimensions() - 1] = output_shape.SizeFromDimension(output_shape.NumDimensions() - num_shared_dimension);
188219
}
189220

190-
if (shared_dimension_divisible_by_4 || a_last_dim_divisible_by_4) {
191-
program.AddInput({lhs_tensor, ProgramTensorMetadataDependency::Type, {(lhs_shape.Size() + 3) / 4}, 4});
221+
if (shared_dimension_divisible_by_4 || a_last_dim_divisible_by_4 || is_lhs_bool) {
222+
program.AddInput({lhs_tensor, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4});
192223
} else {
193224
program.AddInput({lhs_tensor, ProgramTensorMetadataDependency::Type});
194225
}
195-
if (shared_dimension_divisible_by_4 || b_last_dim_divisible_by_4) {
196-
program.AddInput({rhs_tensor, ProgramTensorMetadataDependency::Type, {(rhs_shape.Size() + 3) / 4}, 4});
226+
if (shared_dimension_divisible_by_4 || b_last_dim_divisible_by_4 || is_rhs_bool) {
227+
program.AddInput({rhs_tensor, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4});
197228
} else {
198229
program.AddInput({rhs_tensor, ProgramTensorMetadataDependency::Type});
199230
}
@@ -208,9 +239,11 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
208239
// Mode Broadcast
209240
// cache hint: "B"
210241
program
211-
.AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::TypeAndRank},
212-
{rhs_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
242+
.AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::TypeAndRank, ProgramInput::Flatten, is_lhs_bool ? 4 : 1},
243+
{rhs_tensor, ProgramTensorMetadataDependency::TypeAndRank, ProgramInput::Flatten, is_rhs_bool ? 4 : 1}})
213244
.AddIndices(output_tensor->Shape())
245+
.AddIndices(lhs_tensor->Shape())
246+
.AddIndices(rhs_tensor->Shape())
214247
.CacheHint("B");
215248
}
216249

@@ -343,5 +376,9 @@ WEBGPU_BINARY_IMPL(LessOrEqual, "vec4<u32>(vec4<input_a_element_t>(a) <= vec4<in
343376
WEBGPU_BINARY_VERSIONED_KERNEL(LessOrEqual, 12, 15, LessOrEqual, WebGpuSupportedNumberTypes())
344377
WEBGPU_BINARY_KERNEL(LessOrEqual, 16, LessOrEqual, WebGpuSupportedNumberTypes())
345378

379+
// And operator only supports tensor(bool).
380+
WEBGPU_BINARY_IMPL(And, "(vec4<input_a_element_t>(a) & vec4<input_b_element_t>(b))")
381+
WEBGPU_BINARY_KERNEL(And, 7, And, DataTypeImpl::GetTensorType<bool>())
382+
346383
} // namespace webgpu
347384
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,16 @@ class BinaryElementwiseProgram final : public Program<BinaryElementwiseProgram>
1818
const bool is_broadcast,
1919
const bool is_lhs_scalar,
2020
const bool is_rhs_scalar,
21+
const bool is_lhs_use_4_components,
22+
const bool is_rhs_use_4_components,
2123
const bool vectorize) : Program{kernel_name},
2224
expression_{expression},
2325
additional_impl_{additional_impl},
2426
is_broadcast_{is_broadcast},
2527
is_lhs_scalar_{is_lhs_scalar},
2628
is_rhs_scalar_{is_rhs_scalar},
29+
is_lhs_use_4_components_{is_lhs_use_4_components},
30+
is_rhs_use_4_components_{is_rhs_use_4_components},
2731
vectorize_{vectorize} {}
2832

2933
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -36,6 +40,8 @@ class BinaryElementwiseProgram final : public Program<BinaryElementwiseProgram>
3640
bool is_broadcast_;
3741
bool is_lhs_scalar_;
3842
bool is_rhs_scalar_;
43+
bool is_lhs_use_4_components_;
44+
bool is_rhs_use_4_components_;
3945
bool vectorize_;
4046
};
4147

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD
228228
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Less);
229229
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 15, LessOrEqual);
230230
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, LessOrEqual);
231+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, And);
231232

232233
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Shape);
233234
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 14, Shape);
@@ -511,6 +512,7 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
511512
KERNEL_CREATE_INFO(13, Less),
512513
KERNEL_CREATE_INFO_VERSIONED(12, 15, LessOrEqual),
513514
KERNEL_CREATE_INFO(16, LessOrEqual),
515+
KERNEL_CREATE_INFO(7, And),
514516

515517
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Shape)>,
516518
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 14, Shape)>,

0 commit comments

Comments
 (0)