@@ -14,6 +14,9 @@ Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const
1414 const auto & b = shader.AddInput (" input_b" , ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
1515 const auto & c = shader.AddOutput (" output" , ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
1616
17+ const bool a_is_bool = Inputs ()[0 ].var_type == ProgramVariableDataType::Boolx4;
18+ const bool b_is_bool = Inputs ()[1 ].var_type == ProgramVariableDataType::Boolx4;
19+
1720 shader.AdditionalImplementation () << additional_impl_;
1821
1922 shader.MainFunctionBody () << shader.GuardAgainstOutOfBoundsWorkgroupSizes (" uniforms.vec_size" );
@@ -37,58 +40,78 @@ Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const
3740 }
3841 } else {
3942 const auto & c_indices = shader.AddIndices (" bcast_indices" );
43+ // Use indices helpers to calculate the offset of A and B.
44+ const auto & a_indices = shader.AddIndices (" a_indices" );
45+ const auto & b_indices = shader.AddIndices (" b_indices" );
46+
4047 // check whether can use vectorize mode.
4148 // If either last dimension of A or B is divisible by 4, or the shared dimension is divisible by 4, vectorize mode
4249 // can be enabled.
4350 // In vectorize mode, the source data of A and B will be loaded only once to calculate 4 output values.
44- // Use indices helpers to calculate the offset of A and B.
4551 if (vectorize_) {
46- const auto & a_indices = shader.AddIndices (" a_indices" );
47- const auto & b_indices = shader.AddIndices (" b_indices" );
48-
4952 shader.MainFunctionBody () << " let outputIndices = " << c_indices.OffsetToIndices (" global_idx * 4" ) << " ;\n "
5053 << " let offset_a = " << a_indices.BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n "
5154 << " let offset_b = " << b_indices.BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n " ;
5255 // get A data
53- if (a. NumComponents () == 4 ) {
56+ if (is_lhs_use_4_components_ ) {
5457 shader.MainFunctionBody () << " let a = " << a.GetByOffset (" offset_a / 4" ) << " ;\n " ;
58+ } else if (a_is_bool) {
59+ shader.MainFunctionBody () << " let a = " << a.GetByOffset (" offset_a / 4" ) << " [offset_a % 4];\n " ;
5560 } else {
5661 shader.MainFunctionBody () << " let a = input_a_value_t(" << a.GetByOffset (" offset_a" ) << " );\n " ;
5762 }
5863
5964 // get B data
60- if (b. NumComponents () == 4 ) {
65+ if (is_rhs_use_4_components_ ) {
6166 shader.MainFunctionBody () << " let b = " << b.GetByOffset (" offset_b / 4" ) << " ;\n " ;
67+ } else if (b_is_bool) {
68+ shader.MainFunctionBody () << " let b = " << b.GetByOffset (" offset_b / 4" ) << " [offset_b % 4];\n " ;
6269 } else {
6370 shader.MainFunctionBody () << " let b = input_b_value_t(" << b.GetByOffset (" offset_b" ) << " );\n " ;
6471 }
6572 } else {
6673 // In broadcast mode, each element of the vec4 value of A and B will be loaded separately to calculate the output value.
6774 shader.MainFunctionBody () << " var outputIndices = " << c_indices.OffsetToIndices (" global_idx * 4" ) << " ;\n "
68- << " let offset_a0 = " << a .BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n "
69- << " let offset_b0 = " << b .BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n "
75+ << " let offset_a0 = " << a_indices .BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n "
76+ << " let offset_b0 = " << b_indices .BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n "
7077 << " outputIndices = " << c_indices.OffsetToIndices (" global_idx * 4 + 1" ) << " ;\n "
71- << " let offset_a1 = " << a .BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n "
72- << " let offset_b1 = " << b .BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n "
78+ << " let offset_a1 = " << a_indices .BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n "
79+ << " let offset_b1 = " << b_indices .BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n "
7380 << " outputIndices = " << c_indices.OffsetToIndices (" global_idx * 4 + 2" ) << " ;\n "
74- << " let offset_a2 = " << a .BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n "
75- << " let offset_b2 = " << b .BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n "
81+ << " let offset_a2 = " << a_indices .BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n "
82+ << " let offset_b2 = " << b_indices .BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n "
7683 << " outputIndices = " << c_indices.OffsetToIndices (" global_idx * 4 + 3" ) << " ;\n "
77- << " let offset_a3 = " << a .BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n "
78- << " let offset_b3 = " << b .BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n " ;
84+ << " let offset_a3 = " << a_indices .BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n "
85+ << " let offset_b3 = " << b_indices .BroadcastedIndicesToOffset (" outputIndices" , c_indices) << " ;\n " ;
7986
8087 // get A data
81- shader.MainFunctionBody () << " let a = vec4<input_a_value_t>("
82- << a.GetByOffset (" offset_a0" ) << " , "
83- << a.GetByOffset (" offset_a1" ) << " , "
84- << a.GetByOffset (" offset_a2" ) << " , "
85- << a.GetByOffset (" offset_a3" ) << " );\n " ;
88+ if (a_is_bool) {
89+ shader.MainFunctionBody () << " let a = vec4<bool>("
90+ << a.GetByOffset (" offset_a0 / 4" ) << " [offset_a0 % 4], "
91+ << a.GetByOffset (" offset_a1 / 4" ) << " [offset_a1 % 4], "
92+ << a.GetByOffset (" offset_a2 / 4" ) << " [offset_a2 % 4], "
93+ << a.GetByOffset (" offset_a3 / 4" ) << " [offset_a3 % 4]);\n " ;
94+ } else {
95+ shader.MainFunctionBody () << " let a = vec4<input_a_value_t>("
96+ << a.GetByOffset (" offset_a0" ) << " , "
97+ << a.GetByOffset (" offset_a1" ) << " , "
98+ << a.GetByOffset (" offset_a2" ) << " , "
99+ << a.GetByOffset (" offset_a3" ) << " );\n " ;
100+ }
86101 // get B data
87- shader.MainFunctionBody () << " let b = vec4<input_b_value_t>("
88- << b.GetByOffset (" offset_b0" ) << " , "
89- << b.GetByOffset (" offset_b1" ) << " , "
90- << b.GetByOffset (" offset_b2" ) << " , "
91- << b.GetByOffset (" offset_b3" ) << " );\n " ;
102+ if (b_is_bool) {
103+ shader.MainFunctionBody () << " let b = vec4<bool>("
104+ << b.GetByOffset (" offset_b0 / 4" ) << " [offset_b0 % 4], "
105+ << b.GetByOffset (" offset_b1 / 4" ) << " [offset_b1 % 4], "
106+ << b.GetByOffset (" offset_b2 / 4" ) << " [offset_b2 % 4], "
107+ << b.GetByOffset (" offset_b3 / 4" ) << " [offset_b3 % 4]);\n " ;
108+ } else {
109+ shader.MainFunctionBody () << " let b = vec4<input_b_value_t>("
110+ << b.GetByOffset (" offset_b0" ) << " , "
111+ << b.GetByOffset (" offset_b1" ) << " , "
112+ << b.GetByOffset (" offset_b2" ) << " , "
113+ << b.GetByOffset (" offset_b3" ) << " );\n " ;
114+ }
92115 }
93116 }
94117
@@ -114,6 +137,12 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
114137 bool is_lhs_scalar = lhs_shape.IsScalar ();
115138 bool is_rhs_scalar = rhs_shape.IsScalar ();
116139
140+ // Check if either input is boolean
141+ // For boolean inputs, we need to handle them differently in the shader. This is because `bool` is not a valid type in
142+ // storage buffer. We have to use a `u32` to represent 4 boolean values.
143+ bool is_lhs_bool = lhs_tensor->GetElementType () == ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
144+ bool is_rhs_bool = rhs_tensor->GetElementType () == ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
145+
117146 bool vectorize = is_lhs_scalar || is_rhs_scalar || !is_broadcast;
118147 bool a_last_dim_divisible_by_4 = false ;
119148 bool b_last_dim_divisible_by_4 = false ;
@@ -157,6 +186,8 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
157186 is_broadcast,
158187 is_lhs_scalar,
159188 is_rhs_scalar,
189+ shared_dimension_divisible_by_4 || a_last_dim_divisible_by_4,
190+ shared_dimension_divisible_by_4 || b_last_dim_divisible_by_4,
160191 vectorize};
161192 program
162193 .SetDispatchGroupSize ((vec_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE)
@@ -169,8 +200,8 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
169200 // Mode Element-wise
170201 // cache hint: "E{is_a_scalar}{is_b_scalar}"
171202 program
172- .AddInputs ({{lhs_tensor, ProgramTensorMetadataDependency::Type, {is_lhs_scalar ? 1 : vec_size} , 4 },
173- {rhs_tensor, ProgramTensorMetadataDependency::Type, {is_rhs_scalar ? 1 : vec_size} , 4 }})
203+ .AddInputs ({{lhs_tensor, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten , 4 },
204+ {rhs_tensor, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten , 4 }})
174205 .CacheHint (" E" + std::to_string (is_lhs_scalar) + std::to_string (is_rhs_scalar));
175206 } else if (vectorize) {
176207 // reshape the dims to merge the shared dimension if available
@@ -187,13 +218,13 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
187218 reshaped_output_shape[reshaped_output_shape.NumDimensions () - 1 ] = output_shape.SizeFromDimension (output_shape.NumDimensions () - num_shared_dimension);
188219 }
189220
190- if (shared_dimension_divisible_by_4 || a_last_dim_divisible_by_4) {
191- program.AddInput ({lhs_tensor, ProgramTensorMetadataDependency::Type, {(lhs_shape. Size () + 3 ) / 4 } , 4 });
221+ if (shared_dimension_divisible_by_4 || a_last_dim_divisible_by_4 || is_lhs_bool ) {
222+ program.AddInput ({lhs_tensor, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten , 4 });
192223 } else {
193224 program.AddInput ({lhs_tensor, ProgramTensorMetadataDependency::Type});
194225 }
195- if (shared_dimension_divisible_by_4 || b_last_dim_divisible_by_4) {
196- program.AddInput ({rhs_tensor, ProgramTensorMetadataDependency::Type, {(rhs_shape. Size () + 3 ) / 4 } , 4 });
226+ if (shared_dimension_divisible_by_4 || b_last_dim_divisible_by_4 || is_rhs_bool ) {
227+ program.AddInput ({rhs_tensor, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten , 4 });
197228 } else {
198229 program.AddInput ({rhs_tensor, ProgramTensorMetadataDependency::Type});
199230 }
@@ -208,9 +239,11 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
208239 // Mode Broadcast
209240 // cache hint: "B"
210241 program
211- .AddInputs ({{lhs_tensor, ProgramTensorMetadataDependency::TypeAndRank},
212- {rhs_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
242+ .AddInputs ({{lhs_tensor, ProgramTensorMetadataDependency::TypeAndRank, ProgramInput::Flatten, is_lhs_bool ? 4 : 1 },
243+ {rhs_tensor, ProgramTensorMetadataDependency::TypeAndRank, ProgramInput::Flatten, is_rhs_bool ? 4 : 1 }})
213244 .AddIndices (output_tensor->Shape ())
245+ .AddIndices (lhs_tensor->Shape ())
246+ .AddIndices (rhs_tensor->Shape ())
214247 .CacheHint (" B" );
215248 }
216249
@@ -343,5 +376,9 @@ WEBGPU_BINARY_IMPL(LessOrEqual, "vec4<u32>(vec4<input_a_element_t>(a) <= vec4<in
343376WEBGPU_BINARY_VERSIONED_KERNEL(LessOrEqual, 12 , 15 , LessOrEqual, WebGpuSupportedNumberTypes())
344377WEBGPU_BINARY_KERNEL(LessOrEqual, 16 , LessOrEqual, WebGpuSupportedNumberTypes())
345378
379+ // And operator only supports tensor(bool).
380+ WEBGPU_BINARY_IMPL(And, " (vec4<input_a_element_t>(a) & vec4<input_b_element_t>(b))" )
381+ WEBGPU_BINARY_KERNEL(And, 7 , And, DataTypeImpl::GetTensorType<bool >())
382+
346383} // namespace webgpu
347384} // namespace onnxruntime
0 commit comments