 #include "core/providers/cpu/tensor/utils.h"
 #include "core/providers/webgpu/shader_helper.h"
 #include "core/providers/webgpu/webgpu_supported_types.h"
-
+#include "core/providers/webgpu/nn/fuse_utils.h"
 #include "core/providers/webgpu/data_transfer.h"
+
 namespace onnxruntime {
 namespace webgpu {

@@ -54,11 +55,12 @@ Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
   std::string process_bias;
   if (has_bias_) {
     shader.AddInput("bias", ShaderUsage::UseUniform);
-    process_bias = "value += output_value_t(bias[row + i]);";
+    process_bias = is_channels_last_ ? "value += output_value_t(bias[col]);" : "value += output_value_t(bias[row + i]);";
   }

+  std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t");
   const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform |
-                                                      ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+                                                      ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   const auto& batch_dims = shader.AddIndices("batch_dims");

   int a_components = a.NumComponents();
@@ -90,6 +92,7 @@ Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
             << "for (var i = 0u; i < " << output_number_ << "u; i++) {\n"
             << "  var value = values[i];\n"
             << process_bias << "\n"
+            << apply_activation << "\n"
             << "  let cur_indices = output_indices_t(batch, row + i, col/" << components << ");\n"
             << "  let offset = " << output.IndicesToOffset("cur_indices") << ";\n"
             << output.SetByOffset("offset", "value")
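Note: the apply_activation statement injected above is produced by GetActivationSnippet() from fuse_utils.h, which returns a WGSL statement that rewrites `value` in place (or an empty string when no activation is fused). As a minimal sketch of what such a helper could return, with the enum value and member name below being assumptions for illustration rather than the actual fuse_utils.h implementation:

  // Illustrative sketch only; ActivationKind::Relu and activation_kind_ are assumed names.
  std::string GetActivationSnippetSketch(const Activation& activation, const std::string& value_type) {
    switch (activation.activation_kind_) {  // assumed member
      case ActivationKind::Relu:            // assumed enum value
        return "value = max(value, " + value_type + "(0.0));";
      default:
        return "";  // nothing to fuse
    }
  }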
@@ -127,7 +130,7 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
   const int64_t a_rows = a->Shape().NumDimensions() > 1 ? a->Shape()[a->Shape().NumDimensions() - 2] : 1;
   TensorShape output_shape_shader({batch_size, a_rows, helper.N() / components});

-  MatMulNaiveProgram program{output_rank, output_number, has_bias};
+  MatMulNaiveProgram program{Activation(), output_rank, output_number, has_bias};

   program
       .CacheHint(std::to_string(components), std::to_string(a_components), std::to_string(output_number))
@@ -147,11 +150,32 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
     return context.RunProgram(program);
   }

-  int64_t batchA = a->Shape().SizeToDimension(a->Shape().NumDimensions() - 2);
-  int64_t batchB = b->Shape().SizeToDimension(b->Shape().NumDimensions() - 2);
+  std::vector<const Tensor*> inputs(has_bias ? 3 : 2);
+  inputs[0] = a;
+  inputs[1] = b;
+  if (has_bias) {
+    const auto* bias = context.Input(2);
+    inputs[2] = bias;
+  }
+  auto program = CreateMatMulProgram(Activation(), inputs, output_tensor, false);
+
+  return context.RunProgram(program);
+}
+
+MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output_tensor, bool is_channels_last,
+                                  const TensorShape& input_a_reshape,
+                                  const TensorShape& input_b_reshape) {
+  const auto* a = inputs[0];
+  const auto* b = inputs[1];
+  bool has_bias = inputs.size() > 2;
+  TensorShape a_shape = input_a_reshape.NumDimensions() > 0 ? input_a_reshape : a->Shape();
+  TensorShape b_shape = input_b_reshape.NumDimensions() > 0 ? input_b_reshape : b->Shape();
+
+  MatMulComputeHelper helper;
+  ORT_THROW_IF_ERROR(helper.Compute(a_shape, b_shape));
+  int64_t batchA = a_shape.SizeToDimension(a_shape.NumDimensions() - 2);
+  int64_t batchB = b_shape.SizeToDimension(b_shape.NumDimensions() - 2);

-  TensorShape a_shape = a->Shape();
-  TensorShape b_shape = b->Shape();
   TensorShape output_shape = helper.OutputShape();

   const int64_t dim_output_outer = output_shape[output_shape.NumDimensions() - 2];
@@ -184,44 +208,46 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
   const int64_t batch_size = outer_dims.Size();

   // Get dimensions for matrix multiplication from TensorShape
-  const int32_t dim_a_outer = narrow<int32_t>(a_shape[a_shape.NumDimensions() - 2]);  // left matrix second dimension
-  const int32_t dim_inner = narrow<int32_t>(a_shape[a_shape.NumDimensions() - 1]);    // left matrix first dimension
-  const int32_t dim_b_outer = narrow<int32_t>(b_shape[b_shape.NumDimensions() - 1]);  // right matrix first dimension
+  const uint32_t dim_a_outer = narrow<uint32_t>(a_shape[a_shape.NumDimensions() - 2]);  // M: rows of the left matrix
+  const uint32_t dim_inner = narrow<uint32_t>(a_shape[a_shape.NumDimensions() - 1]);    // K: shared inner dimension
+  const uint32_t dim_b_outer = narrow<uint32_t>(b_shape[b_shape.NumDimensions() - 1]);  // N: columns of the right matrix

   const bool is_vec4 = dim_inner % 4 == 0 && dim_b_outer % 4 == 0;

   InlinedVector<int64_t> elements_per_thread = dim_a_outer <= 8
                                                    ? InlinedVector<int64_t>({4, 1, 1})
                                                    : InlinedVector<int64_t>({4, 4, 1});

-  const uint32_t dispatch_x = narrow<uint32_t>((dim_b_outer + MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) /
-                                               (MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
-  const uint32_t dispatch_y = narrow<uint32_t>((dim_a_outer + MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
-                                               (MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
-  const uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
-                                               (MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
+  const uint32_t dispatch_x = narrow<uint32_t>((dim_b_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) /
+                                               (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
+  const uint32_t dispatch_y = narrow<uint32_t>((dim_a_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
+                                               (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
+  const uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
+                                               (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));

   const int components = is_vec4 ? 4 : 1;
   const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components);
   const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components);
   const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components});

-  MatMulProgram program{has_bias, is_vec4, elements_per_thread};
+  MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread, is_channels_last};
   program
-      .CacheHint(absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4))
+      .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4))
       .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
                   {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
       .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}})
       .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}})
       .AddIndices(outer_dims)
       .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
-      .SetWorkgroupSize(MATMUL_PACKED_WORKGROUP_SIZE_X, MATMUL_PACKED_WORKGROUP_SIZE_Y, MATMUL_PACKED_WORKGROUP_SIZE_Z);
+      .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z);

   if (has_bias) {
-    const auto* bias = context.Input(2);
-    program.AddInput({bias, ProgramTensorMetadataDependency::Rank, 1});
+    auto bias_components = is_channels_last ? components : 1;
+    const auto* bias = inputs[2];
+    TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components);
+    program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components});
   }
-  return context.RunProgram(program);
+  return program;
 }

 }  // namespace webgpu