// Licensed under the MIT License.

#include "core/providers/webgpu/math/gemm.h"
- #include "core/providers/webgpu/math/gemm_vec4.h"
+ #include "core/providers/webgpu/math/gemm_packed.h"

#include <vector>

@@ -38,130 +38,52 @@ WEBGPU_GEMM_VERSIONED_KERNEL(9, 10)
WEBGPU_GEMM_VERSIONED_KERNEL(11, 12)
WEBGPU_GEMM_KERNEL(13)

- Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
-   const uint32_t TILE_SIZE = 16;
-
-   // Add shared memory arrays
-   shader.AdditionalImplementation() << "var<workgroup> tile_a: array<array<output_value_t, " << TILE_SIZE << ">, " << TILE_SIZE << ">;\n"
-                                     << "var<workgroup> tile_b: array<array<output_value_t, " << TILE_SIZE << ">, " << TILE_SIZE << ">;\n\n";
-
+ Status GemmNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);

-   shader.MainFunctionBody() << "  var value = output_value_t(0);\n\n"
-                             << "  let tile_col_start = (workgroup_idx % uniforms.num_tile_n) * " << TILE_SIZE << "u;\n"
-                             << "  let tile_row_start = (workgroup_idx / uniforms.num_tile_n) * " << TILE_SIZE << "u;\n";
+   shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
+                             << "  let m = global_idx / uniforms.N;\n"
+                             << "  let n = global_idx % uniforms.N;\n"
+                             << "  var value = output_value_t(0);\n"
+                             << "\n";

  // When A or B is empty, we don't bind A and B. Because WebGPU doesn't support binding a zero-sized buffer.
  if (need_handle_matmul_) {
    const ShaderVariableHelper& A = shader.AddInput("A", ShaderUsage::UseUniform);
    const ShaderVariableHelper& B = shader.AddInput("B", ShaderUsage::UseUniform);

-     shader.MainFunctionBody()
-         << "  let num_tiles = (uniforms.K - 1u) / " << TILE_SIZE << "u + 1u;\n"
-         << "  var k_start = 0u;\n"
-         << "  for (var t = 0u; t < num_tiles; t = t + 1u) {\n";
-
-     // Fill workgroup shared memory
-     if (transA_ && transB_) {
-       shader.MainFunctionBody() << "    var col = tile_row_start + local_id.x;\n"
-                                 << "    var row = k_start + local_id.y;\n"
-                                 << "    if (col < uniforms.M && row < uniforms.K) {\n"
-                                 << "      tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.M + col") << ";\n"
-                                 << "    } else {\n"
-                                 << "      tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
-                                 << "    }\n\n"
-                                 << "    col = k_start + local_id.x;\n"
-                                 << "    row = tile_col_start + local_id.y;\n"
-                                 << "    if (col < uniforms.K && row < uniforms.N) {\n"
-                                 << "      tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.K + col") << ";\n"
-                                 << "    } else {\n"
-                                 << "      tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
-                                 << "    }\n";
-     } else if (transA_ && !transB_) {
-       shader.MainFunctionBody() << "    var col = tile_row_start + local_id.x;\n"
-                                 << "    var row = k_start + local_id.y;\n"
-                                 << "    if (col < uniforms.M && row < uniforms.K) {\n"
-                                 << "      tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.M + col") << ";\n"
-                                 << "    } else {\n"
-                                 << "      tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
-                                 << "    }\n\n"
-                                 << "    col = tile_col_start + local_id.x;\n"
-                                 << "    row = k_start + local_id.y;\n"
-                                 << "    if (col < uniforms.N && row < uniforms.K) {\n"
-                                 << "      tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.N + col") << ";\n"
-                                 << "    } else {\n"
-                                 << "      tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
-                                 << "    }\n";
-     } else if (!transA_ && transB_) {
-       shader.MainFunctionBody() << "    var col = k_start + local_id.x;\n"
-                                 << "    var row = tile_row_start + local_id.y;\n"
-                                 << "    if (col < uniforms.K && row < uniforms.M) {\n"
-                                 << "      tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.K + col") << ";\n"
-                                 << "    } else {\n"
-                                 << "      tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
-                                 << "    }\n\n"
-                                 << "    col = k_start + local_id.x;\n"
-                                 << "    row = tile_col_start + local_id.y;\n"
-                                 << "    if (col < uniforms.K && row < uniforms.N) {\n"
-                                 << "      tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.K + col") << ";\n"
-                                 << "    } else {\n"
-                                 << "      tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
-                                 << "    }\n";
-     } else {
-       shader.MainFunctionBody() << "    var col = k_start + local_id.x;\n"
-                                 << "    var row = tile_row_start + local_id.y;\n"
-                                 << "    if (col < uniforms.K && row < uniforms.M) {\n"
-                                 << "      tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.K + col") << ";\n"
-                                 << "    } else {\n"
-                                 << "      tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
-                                 << "    }\n\n"
-                                 << "    col = tile_col_start + local_id.x;\n"
-                                 << "    row = k_start + local_id.y;\n"
-                                 << "    if (col < uniforms.N && row < uniforms.K) {\n"
-                                 << "      tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.N + col") << ";\n"
-                                 << "    } else {\n"
-                                 << "      tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
-                                 << "    }\n";
-     }
-
-     shader.MainFunctionBody() << "    k_start = k_start + " << TILE_SIZE << "u;\n"
-                               << "    workgroupBarrier();\n\n"
-                               << "    for (var k = 0u; k < " << TILE_SIZE << "u; k = k + 1u) {\n";
+     shader.MainFunctionBody() << "  for (var k = 0u; k < uniforms.K; k = k + 1u) {\n";

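+     // With transA, A is stored as (K, M), so element (m, k) is at offset k * uniforms.M + m;
+     // with transB, B is stored as (N, K), so element (k, n) is at offset n * uniforms.K + k.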
    if (transA_ && transB_) {
-       shader.MainFunctionBody() << "      value = value + tile_a[k][local_id.y] * tile_b[local_id.x][k];\n";
+       shader.MainFunctionBody() << "    value = value + " << A.GetByOffset("k * uniforms.M + m")
+                                 << " * " << B.GetByOffset("n * uniforms.K + k") << ";\n";
    } else if (transA_ && !transB_) {
-       shader.MainFunctionBody() << "      value = value + tile_a[k][local_id.y] * tile_b[k][local_id.x];\n";
+       shader.MainFunctionBody() << "    value = value + " << A.GetByOffset("k * uniforms.M + m")
+                                 << " * " << B.GetByOffset("k * uniforms.N + n") << ";\n";
    } else if (!transA_ && transB_) {
-       shader.MainFunctionBody() << "      value = value + tile_a[local_id.y][k] * tile_b[local_id.x][k];\n";
+       shader.MainFunctionBody() << "    value = value + " << A.GetByOffset("m * uniforms.K + k")
+                                 << " * " << B.GetByOffset("n * uniforms.K + k") << ";\n";
    } else {
-       shader.MainFunctionBody() << "      value = value + tile_a[local_id.y][k] * tile_b[k][local_id.x];\n";
+       shader.MainFunctionBody() << "    value = value + " << A.GetByOffset("m * uniforms.K + k")
+                                 << " * " << B.GetByOffset("k * uniforms.N + n") << ";\n";
    }
-
-     shader.MainFunctionBody() << "    }\n"
-                               << "    workgroupBarrier();\n"
-                               << "  }\n\n";
+     shader.MainFunctionBody() << "  }\n"
+                               << "\n";
  }

  // Calculate Alpha
  if (alpha_) {
    shader.MainFunctionBody() << "  value = value * output_value_t(uniforms.alpha);\n";
  }

-   shader.MainFunctionBody() << "  let m = tile_row_start + local_id.y;\n"
-                             << "  let n = tile_col_start + local_id.x;\n";
-
  // Calculate Bias
  if (need_handle_bias_) {
    const ShaderVariableHelper& C = shader.AddInput("C", ShaderUsage::UseUniform);
    shader.MainFunctionBody() << "  value = value + output_value_t(uniforms.beta) * "
                              << C.GetByOffset(C.BroadcastedIndicesToOffset("vec2(m, n)", output)) << ";\n";
  }

-   // Write output
-   shader.MainFunctionBody() << "  if (m < uniforms.M && n < uniforms.N) {\n"
-                             << "    " << output.SetByOffset("m * uniforms.N + n", "value") << "\n"
-                             << "  }\n";
+   shader.MainFunctionBody() << output.SetByOffset("global_idx", "value") << "\n";

  return Status::OK();
}
@@ -182,14 +104,14 @@ Status Gemm::ComputeInternal(ComputeContext& context) const {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input tensors A and B must be 2 dimensional.");
  }

-   uint32_t M = onnxruntime::narrow<uint32_t>(transA_ ? A_shape[1] : A_shape[0]);
-   uint32_t K = onnxruntime::narrow<uint32_t>(transA_ ? A_shape[0] : A_shape[1]);
-   uint32_t N = onnxruntime::narrow<uint32_t>(transB_ ? B_shape[0] : B_shape[1]);
-
  if ((transA_ ? A_shape[0] : A_shape[1]) != (transB_ ? B_shape[1] : B_shape[0])) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inner dimensions of A and B must match.");
  }

+   int64_t M = transA_ ? A_shape[1] : A_shape[0];
+   int64_t K = transA_ ? A_shape[0] : A_shape[1];
+   int64_t N = transB_ ? B_shape[0] : B_shape[1];
+
  std::vector<int64_t> output_dims{M, N};
  auto* Y = context.Output(0, output_dims);
  int64_t output_size = Y->Shape().Size();
@@ -198,42 +120,36 @@ Status Gemm::ComputeInternal(ComputeContext& context) const {
    return Status::OK();
  }

-   // First try vec4 optimization if possible
-   if (CanApplyGemmVec4(A, B)) {
-     return ApplyGemmVec4(A, B, C, transA_, transB_, alpha_, beta_, context, Y);
-   }
-
  // WebGPU doesn't support binding a zero-sized buffer, so we need to check if A or B is empty.
  bool need_handle_matmul = A_shape.Size() > 0 && B_shape.Size() > 0;
  bool need_handle_bias = C && beta_;

-   GemmProgram program{transA_, transB_, alpha_, need_handle_bias, need_handle_matmul};
+   if (M <= 8 && N <= 8 && K <= 8) {
+     // Use naive implementation for small matrices
+     GemmNaiveProgram program{transA_, transB_, alpha_, need_handle_bias, need_handle_matmul};
+     if (need_handle_matmul) {
+       program.AddInputs({{A, ProgramTensorMetadataDependency::Type},
+                          {B, ProgramTensorMetadataDependency::Type}});
+     }

-   if (need_handle_matmul) {
-     program.AddInputs({{A, ProgramTensorMetadataDependency::Type},
-                        {B, ProgramTensorMetadataDependency::Type}});
-   }
+     if (need_handle_bias) {
+       program.AddInput({C, ProgramTensorMetadataDependency::Rank});
+     }

-   if (need_handle_bias) {
-     program.AddInput({C, ProgramTensorMetadataDependency::Rank});
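+     // One invocation per output element; round the dispatch up to whole workgroups of WORKGROUP_SIZE threads.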
+     program.CacheHint(alpha_, transA_, transB_)
+         .AddOutputs({{Y, ProgramTensorMetadataDependency::Type}})
+         .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
+         .SetWorkgroupSize(WORKGROUP_SIZE)
+         .AddUniformVariables({{static_cast<uint32_t>(output_size)},
+                               {static_cast<uint32_t>(M)},
+                               {static_cast<uint32_t>(N)},
+                               {static_cast<uint32_t>(K)},
+                               {alpha_},
+                               {beta_}});
+     return context.RunProgram(program);
  }

-   const uint32_t TILE_SIZE = 16;
-   const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE;
-   const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE;
-
-   program.CacheHint(alpha_, transA_, transB_)
-       .AddOutputs({{Y, ProgramTensorMetadataDependency::Type}})
-       .SetDispatchGroupSize(num_tile_n * num_tile_m)
-       .SetWorkgroupSize(TILE_SIZE, TILE_SIZE)
-       .AddUniformVariables({{num_tile_n},
-                             {M},
-                             {N},
-                             {K},
-                             {alpha_},
-                             {beta_}});
-
-   return context.RunProgram(program);
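+   // Larger matrices are handled by the packed GEMM implementation (see gemm_packed.h).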
+   return ApplyGemmPacked(A, B, C, transA_, transB_, alpha_, beta_, context);
}

}  // namespace webgpu