Commit ef77435
[Native WebGPU] Fixed Conv2dMM and MatMul issues related to indexing, cache hints, etc. (microsoft#24527)

### Description
Fixed a few issues related to Conv2dMM and MatMul in the Native WebGPU backend.
1 parent 5c014e2 commit ef77435

3 files changed: +11 / -10 lines

onnxruntime/core/providers/webgpu/math/matmul.cc

Lines changed: 1 addition & 1 deletion

@@ -232,7 +232,7 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<cons

   MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread, is_channels_last};
   program
-      .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4))
+      .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last)
       .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
                   {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
       .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}})
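
The substance of this change is that the shader cache key now covers every value the generated WGSL actually depends on. Below is a minimal sketch, assuming a hypothetical key-builder helper (this is not ONNX Runtime's CacheHint implementation), of why leaving `components` and `is_channels_last` out of the key can let a pipeline compiled for one configuration be reused for an incompatible one.

```cpp
// Hypothetical cache-key builder; names and formatting are illustrative only.
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

std::string BuildMatMulCacheKey(const std::string& activation,
                                const std::vector<uint32_t>& elements_per_thread,
                                bool is_vec4, int components, bool is_channels_last) {
  std::ostringstream key;
  key << activation;
  for (uint32_t e : elements_per_thread) {
    key << "-" << e;  // mirrors absl::StrJoin(elements_per_thread, "-")
  }
  key << "|vec4=" << is_vec4
      // Before this commit the next two values were not part of the hint, so two
      // programs emitting different WGSL (components 1 vs 4, NHWC vs NCHW) could
      // resolve to the same cached pipeline.
      << "|components=" << components
      << "|channels_last=" << is_channels_last;
  return key.str();
}
```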

onnxruntime/core/providers/webgpu/math/matmul_packed.cc

Lines changed: 9 additions & 8 deletions

@@ -185,15 +185,16 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader,
       << "      " << (inner_elements_size == 3 ? "" : "acc[i] = BCached3 * ACached.w + acc[i];") << "\n"
       << "    }\n";
   }
-  shader.MainFunctionBody() << "  workgroupBarrier();\n"
-                            << "  }\n";  // main for loop
+  shader.MainFunctionBody()
+      << "  }\n"
+      << "  workgroupBarrier();\n"
+      << "  }\n";  // main for loop

   // Write the results to the output buffer
   shader.MainFunctionBody()
       << "  for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n"
       << "    mm_write(batch, globalRow + innerRow, globalCol, acc[innerRow]);\n"
-      << "  }\n"
-      << "}\n";
+      << "  }\n";

   return Status::OK();
 }
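
One plausible reading of the hunk above: the inner compute loop is now closed before the per-tile workgroupBarrier(), and the stray extra closing brace after the output-write loop is gone, so the barrier runs once per iteration of the main tile loop rather than inside a deeper scope. The self-contained sketch below assembles a simplified version of that WGSL skeleton; loop bounds, names, and comments are illustrative, not the exact generated shader.

```cpp
// Simplified illustration of the loop/barrier nesting the vec4 path is meant to
// emit after this change. This is a sketch, not the real generator.
#include <iostream>
#include <sstream>
#include <string>

std::string Vec4LoopSkeleton() {
  std::ostringstream wgsl;
  wgsl << "for (var t = 0; t < num_tiles; t = t + 1) {\n"  // main tile loop
       << "  // ...load A/B tiles into workgroup memory...\n"
       << "  workgroupBarrier();\n"
       << "  for (var k = 0; k < tileInner / innerElementSize; k = k + 1) {\n"
       << "    // ...accumulate into acc...\n"
       << "  }\n"                 // compute loop closed *before* the barrier
       << "  workgroupBarrier();\n"
       << "}\n"                   // main tile loop
       << "for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n"
       << "  mm_write(batch, globalRow + innerRow, globalCol, acc[innerRow]);\n"
       << "}\n";                  // no extra closing brace after the write loop
  return wgsl.str();
}

int main() { std::cout << Vec4LoopSkeleton(); }
```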
@@ -217,8 +218,8 @@ Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader,

   const auto tile_a_outer = workgroup_size_y * elements_per_thread_y;
   const auto tile_b_outer = workgroup_size_x * elements_per_thread_x;
-  const auto tile_a_width = tile_inner;
-  const auto tile_a_height = tile_a_outer;
+  const auto tile_a_width = transpose_a ? tile_a_outer : tile_inner;
+  const auto tile_a_height = transpose_a ? tile_inner : tile_a_outer;

   if (!(tile_a_height % workgroup_size_y == 0 && tile_a_width % workgroup_size_x == 0 && tile_inner % workgroup_size_y == 0)) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,

@@ -243,7 +244,7 @@ Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader,
       << (nullptr != batch_dims ? "  let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "")
       << "  let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
       << "  var kStart = 0;\n"
-      << "  var acc: array<vec4<" << data_type << ">, rowPerThread>;\n";
+      << "  var acc: array<array<" << data_type << ", colPerThread>, rowPerThread>;\n";

   if (sequentially_access_by_threads) {
     shader.MainFunctionBody() << "let localRow = i32(local_id.y);\n"

@@ -277,7 +278,7 @@ Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader,
       << "      BCached[inner] = mm_Bsub[k][localCol + inner * " << workgroup_size_x << "];\n"
       << "    }\n"
       << "    for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n"
-      << "      let ACached = " << (transpose_a ? "mm_Asub[k][localCol + innerRow * " + std::to_string(workgroup_size_y) + "];" : "mm_Asub[localRow + innerRow * " + std::to_string(workgroup_size_y) + "][k];") << "\n"
+      << "      let ACached = " << (transpose_a ? "mm_Asub[k][localRow + innerRow * " + std::to_string(workgroup_size_y) + "];" : "mm_Asub[localRow + innerRow * " + std::to_string(workgroup_size_y) + "][k];") << "\n"
       << "      for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n"
       << "        acc[innerRow][innerCol] = acc[innerRow][innerCol] +\n"
       << "          ACached * BCached[innerCol];\n"

onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.cc

Lines changed: 1 addition & 1 deletion

@@ -159,7 +159,7 @@ Status Conv2dMMProgram::GenerateShaderCode(ShaderHelper& shader) const {
       << declaration_functions.str()
       << Conv2dCommonSnippet(x, w, activation_, "x_element_t", element_size_[0], element_size_[1], element_size_[2]);
   std::string data_type = "x_element_t";
-  return is_vec4_ ? MatMulProgram::MakeMatMulPackedVec4Source(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, /* transpose_a = */ !is_channels_last_, tile_inner_) : MatMulProgram::MakeMatMulPackedSource(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, false, tile_inner_, false, 0, sequentially_access_by_threads_);
+  return is_vec4_ ? MatMulProgram::MakeMatMulPackedVec4Source(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, /* transpose_a = */ !is_channels_last_, tile_inner_) : MatMulProgram::MakeMatMulPackedSource(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, !is_channels_last_, tile_inner_, /* split_t = */ false, 0, sequentially_access_by_threads_);
 }

 Conv2dMMProgram CreateConv2dMMProgram(const Activation& activation, const std::vector<const Tensor*>& inputs, const std::vector<uint32_t>& pads, const std::vector<uint32_t>& strides, const std::vector<uint32_t>& dilations, Tensor* output, uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner, bool is_channels_last, bool sequentially_access_by_threads, const std::vector<TensorShape>& input_output_shapes) {
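
The conv2d change is a single argument: the scalar MakeMatMulPackedSource call now receives !is_channels_last_ for its transpose-A parameter, matching what the vec4 branch already passed, instead of a hard-coded false. A hedged sketch of that dispatch decision (simplified; the real calls take many more parameters such as the shader helper, workgroup sizes, and tile_inner_):

```cpp
// Illustrative dispatch only; not the real MakeMatMulPacked*Source signatures.
enum class MatMulVariant { PackedVec4, Packed };

struct MatMulSourcePlan {
  MatMulVariant variant;
  bool transpose_a;
};

MatMulSourcePlan PlanConv2dMatMulSource(bool is_vec4, bool is_channels_last) {
  // The generator treats A as transposed for channels-first input
  // (transpose_a = !is_channels_last); before this commit the scalar branch
  // hard-coded false here while the vec4 branch already passed the flag.
  const bool transpose_a = !is_channels_last;
  return {is_vec4 ? MatMulVariant::PackedVec4 : MatMulVariant::Packed,
          transpose_a};
}
```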
