-
Notifications
You must be signed in to change notification settings - Fork 3.6k
[webgpu] Optimize Conv by im2col-matmul #26603
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,223 @@ | ||||||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||||||
| // Licensed under the MIT License. | ||||||
| #include <string> | ||||||
| #include <vector> | ||||||
|
|
||||||
| #include "core/providers/webgpu/webgpu_utils.h" | ||||||
| #include "core/providers/webgpu/nn/im2col_matmul.h" | ||||||
| #include "core/providers/webgpu/nn/activation_util.h" | ||||||
|
|
||||||
| namespace onnxruntime { | ||||||
| namespace webgpu { | ||||||
|
|
||||||
| namespace { | ||||||
|
|
||||||
// TODO: move to common header.
// Integer ceiling division: smallest integral value >= numerator / denominator.
// Written as quotient-plus-remainder so it cannot wrap when `numerator` is near
// the maximum of T (the classic (n + d - 1) / d form overflows for unsigned T).
// Preconditions: denominator != 0; operands non-negative for signed T.
template <typename T>
inline T ceil_div(T numerator, T denominator) {
  return numerator / denominator + static_cast<T>(numerator % denominator != 0);
}
|
|
||||||
| // Chooses the optimal tile size (M, N) for the im2col operation. | ||||||
| // This tile size is performance-tuned and varies depending on the target device. | ||||||
| std::pair<uint32_t, uint32_t> ChooseTileSize(uint32_t im2col_m, uint32_t im2col_n) { | ||||||
| // Define a list of preferred (tile_m, tile_n) pairs in descending order of preference. | ||||||
| const std::vector<std::pair<uint32_t, uint32_t>> kTileSizes = { | ||||||
| std::make_pair(32, 64), | ||||||
| std::make_pair(16, 64), | ||||||
|
Check warning on line 27 in onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc
|
||||||
| }; | ||||||
|
|
||||||
| for (const auto& tile_pair : kTileSizes) { | ||||||
| const uint32_t tile_m = tile_pair.first; | ||||||
| const uint32_t tile_n = tile_pair.second; | ||||||
|
|
||||||
| const uint32_t dispatch_m = ceil_div(im2col_m, tile_m); | ||||||
| const uint32_t dispatch_n = ceil_div(im2col_n, tile_n); | ||||||
| const uint32_t dispatch = dispatch_m * dispatch_n; | ||||||
|
|
||||||
| if (dispatch >= 128) { | ||||||
| return tile_pair; | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // If none of the tile sizes meet the dispatch >=128 requirement, | ||||||
| return kTileSizes.back(); | ||||||
| } | ||||||
|
|
||||||
| // Add support for more devices. | ||||||
| bool IsDeviceSupported(ComputeContext& context) { | ||||||
| const wgpu::AdapterInfo& adapter_info = context.AdapterInfo(); | ||||||
|
|
||||||
| if (adapter_info.vendor == std::string_view("intel")) { | ||||||
| if (adapter_info.architecture == std::string_view("xe-2lpg")) { | ||||||
| return true; | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| return false; | ||||||
| } | ||||||
|
|
||||||
| } // namespace | ||||||
|
|
||||||
// Emits the WGSL for the weight-layout transpose (OIHW -> OHWI).
Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
  // Declare the shader bindings: the source weight and the transposed output.
  const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  // The shader body lives in the template file; bind the declared variables into it.
  return WGSL_TEMPLATE_APPLY(shader, "nn/oihw_to_ohwi.wgsl.template",
                             WGSL_TEMPLATE_VARIABLE(output, output),
                             WGSL_TEMPLATE_VARIABLE(src, src));
}
|
|
||||||
// Emits the WGSL for the fused im2col + matmul kernel.
Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
  // Declare the shader bindings: input activation, transposed (OHWI) weight,
  // optional bias, and the output tensor.
  const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  const auto& weight = shader.AddInput("weight", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  if (has_bias_) {
    shader.AddInput("bias", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  }
  const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  // The WGSL template only supports these tile shapes; fail fast on anything else.
  ORT_ENFORCE(tile_m_ == 16 || tile_m_ == 32, "tile_m must be 16 or 32.");
  ORT_ENFORCE(tile_n_ == 64, "tile_n must be 64.");

  // Instantiate the templated WGSL source with the compile-time parameters and
  // bind the declared variables into it.
  return WGSL_TEMPLATE_APPLY(shader, "nn/im2col_matmul.wgsl.template",
                             WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
                             WGSL_TEMPLATE_PARAMETER(tile_m, tile_m_),
                             WGSL_TEMPLATE_PARAMETER(tile_n, tile_n_),
                             WGSL_TEMPLATE_PARAMETER(use_subgroup, use_subgroup_),
                             WGSL_TEMPLATE_VARIABLE(output, output),
                             WGSL_TEMPLATE_VARIABLE(src, src),
                             WGSL_TEMPLATE_VARIABLE(weight, weight));
}
|
|
||||||
| Status ApplyIm2ColMatMulProgram(ComputeContext& context, | ||||||
| bool is_channels_last, | ||||||
| const std::vector<uint32_t>& dilations, | ||||||
| const std::vector<uint32_t>& pads, | ||||||
| const std::vector<uint32_t>& strides, | ||||||
| Tensor* output) { | ||||||
| const auto* input = context.Input<Tensor>(0); | ||||||
| const auto* kernel = context.Input<Tensor>(1); | ||||||
| const bool has_bias = context.InputCount() > 2; | ||||||
| const auto* bias = has_bias ? context.Input<Tensor>(2) : nullptr; | ||||||
|
|
||||||
| // Transpose OIHW Weight to OHWI | ||||||
| TensorShape kernel_shape = kernel->Shape(); | ||||||
| const uint32_t channel_output = onnxruntime::narrow<uint32_t>(kernel_shape[0]); | ||||||
| const uint32_t channel_input = onnxruntime::narrow<uint32_t>(kernel_shape[1]); | ||||||
| const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(kernel_shape[2]); | ||||||
| const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(kernel_shape[3]); | ||||||
|
|
||||||
| TensorShape nhwc_kernel_shape{channel_output, kernel_height, kernel_width, channel_input}; | ||||||
| Tensor nhwc_kernel = context.CreateGPUTensor(kernel->DataType(), nhwc_kernel_shape); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| OIHW2OHWIProgram transpose_program{}; | ||||||
| transpose_program.SetWorkgroupSize(64); | ||||||
|
|
||||||
| const uint32_t Ci_tiles = ceil_div(channel_input, 64u); | ||||||
| transpose_program.SetDispatchGroupSize(channel_output, Ci_tiles); | ||||||
|
|
||||||
| transpose_program.AddInput({kernel, | ||||||
| ProgramTensorMetadataDependency::TypeAndRank}); | ||||||
| transpose_program.AddOutput({&nhwc_kernel, | ||||||
| ProgramTensorMetadataDependency::TypeAndRank}); | ||||||
| transpose_program.AddUniformVariables({{channel_output}, | ||||||
| {channel_input}, | ||||||
| {kernel_height}, | ||||||
| {kernel_width}, | ||||||
| {Ci_tiles}, | ||||||
| {ceil_div(kernel_height * kernel_height, 4u)}}); | ||||||
| ORT_RETURN_IF_ERROR(context.RunProgram(transpose_program)); | ||||||
|
|
||||||
| // im2col-matmul | ||||||
| const TensorShape input_shape = input->Shape(); | ||||||
| const TensorShape output_shape = output->Shape(); | ||||||
|
|
||||||
| const uint32_t batch = onnxruntime::narrow<uint32_t>(input_shape[0]); | ||||||
| const uint32_t input_height = onnxruntime::narrow<uint32_t>(input_shape[is_channels_last ? 1 : 2]); | ||||||
| const uint32_t input_width = onnxruntime::narrow<uint32_t>(input_shape[is_channels_last ? 2 : 3]); | ||||||
| const uint32_t output_height = onnxruntime::narrow<uint32_t>(output_shape[is_channels_last ? 1 : 2]); | ||||||
| const uint32_t output_width = onnxruntime::narrow<uint32_t>(output_shape[is_channels_last ? 2 : 3]); | ||||||
|
|
||||||
| const uint32_t im2col_m = output_height * output_width; | ||||||
| const uint32_t im2col_k = kernel_height * kernel_width * channel_input; | ||||||
| const uint32_t im2col_n = channel_output; | ||||||
|
|
||||||
| const auto [tile_m, tile_n] = ChooseTileSize(im2col_m, im2col_n); | ||||||
| const uint32_t workgroup_size = tile_n; | ||||||
| const bool use_subgroup = true; | ||||||
| Im2ColMatMulProgram im2col_mm_program{has_bias, tile_m, tile_n, use_subgroup}; | ||||||
| im2col_mm_program.SetWorkgroupSize(workgroup_size); | ||||||
|
|
||||||
| const uint32_t M_tiles = ceil_div(im2col_m, tile_m); | ||||||
| const uint32_t N_tiles = ceil_div(im2col_n, tile_n); | ||||||
| im2col_mm_program.SetDispatchGroupSize(M_tiles, N_tiles, batch); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about enhancing the current TransposeProgram with shared path instead of adding a new one? |
||||||
|
|
||||||
| im2col_mm_program.AddInput({input, | ||||||
| ProgramTensorMetadataDependency::TypeAndRank, | ||||||
| 4}); | ||||||
| im2col_mm_program.AddInput({&nhwc_kernel, | ||||||
| ProgramTensorMetadataDependency::TypeAndRank, | ||||||
| 4}); | ||||||
| if (has_bias) { | ||||||
| im2col_mm_program.AddInput({bias, | ||||||
| ProgramTensorMetadataDependency::TypeAndRank}); | ||||||
| } | ||||||
| im2col_mm_program.AddOutput({output, | ||||||
| ProgramTensorMetadataDependency::TypeAndRank}); | ||||||
| im2col_mm_program.AddUniformVariables({{batch}, | ||||||
| {input_height}, | ||||||
| {input_width}, | ||||||
| {channel_input}, | ||||||
| {kernel_height}, | ||||||
| {kernel_width}, | ||||||
| {output_height}, | ||||||
| {output_width}, | ||||||
| {im2col_m}, | ||||||
| {im2col_k}, | ||||||
| {im2col_n}, | ||||||
| {M_tiles}, | ||||||
| {N_tiles}, | ||||||
| {ceil_div(ceil_div(im2col_k, 4u), 4u)}, | ||||||
| {dilations}, | ||||||
| {pads}, | ||||||
| {strides}}); | ||||||
| im2col_mm_program.CacheHint(has_bias, tile_m, tile_n, use_subgroup); | ||||||
|
|
||||||
| return context.RunProgram(im2col_mm_program); | ||||||
| } | ||||||
|
|
||||||
| bool CanApplyIm2ColMatMulProgram(ComputeContext& context, | ||||||
| const bool is_channels_last, | ||||||
| const ActivationKind activation_kind, | ||||||
| const TensorShape kernel_shape, | ||||||
| const AutoPadType auto_pad, | ||||||
| const uint32_t group) { | ||||||
| if (!IsDeviceSupported(context)) { | ||||||
| return false; | ||||||
| } | ||||||
|
|
||||||
| // TODO: Support !is_channels_last | ||||||
|
Check warning on line 198 in onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc
|
||||||
| // TODO: Support fuse | ||||||
|
Check warning on line 199 in onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc
|
||||||
| // TODO: Support auto pad | ||||||
|
Check warning on line 200 in onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc
|
||||||
| // TODO: Support group conv | ||||||
|
Check warning on line 201 in onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc
|
||||||
| if (!is_channels_last || activation_kind != ActivationKind::None || auto_pad != AutoPadType::NOTSET || group != 1) { | ||||||
| return false; | ||||||
| } | ||||||
|
|
||||||
| // TODO: Support conv2d_1x1 | ||||||
|
Check warning on line 206 in onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc
|
||||||
| const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(kernel_shape[2]); | ||||||
| const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(kernel_shape[3]); | ||||||
| if (kernel_height == 1 && kernel_width == 1) { | ||||||
| return false; | ||||||
| } | ||||||
|
|
||||||
| // TODO: Support channel input vec1 | ||||||
|
Check warning on line 213 in onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc
|
||||||
| const uint32_t channel_input = onnxruntime::narrow<uint32_t>(kernel_shape[1]); | ||||||
| if (channel_input % 4 != 0) { | ||||||
| return false; | ||||||
| } | ||||||
|
|
||||||
| return true; | ||||||
| } | ||||||
|
|
||||||
| } // namespace webgpu | ||||||
| } // namespace onnxruntime | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "core/framework/tensor_shape.h" | ||
| #include "core/framework/tensor.h" | ||
| #include "core/framework/op_kernel.h" | ||
| #include "core/providers/cpu/nn/conv_attributes.h" | ||
| #include "core/providers/webgpu/program.h" | ||
| #include "core/providers/webgpu/webgpu_supported_types.h" | ||
| #include "core/providers/webgpu/shader_helper.h" | ||
| #include "core/providers/webgpu/webgpu_kernel.h" | ||
| #include "core/providers/webgpu/nn/fuse_utils.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace webgpu { | ||
|
|
||
// Transpose OIHW Weight to OHWI
// Compute program that relayouts a conv weight tensor from OIHW to OHWI so the
// im2col-matmul kernel can read input channels contiguously.
class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
 public:
  OIHW2OHWIProgram() : Program("OIHW2OHWI") {}

  Status GenerateShaderCode(ShaderHelper& shader) const override;

  // O/I/H/W are the OIHW weight dimensions; Ci_tiles and H_W_tiles are tile
  // counts over the input-channel and spatial (H*W) extents. Order must match
  // the values pushed by the caller (ApplyIm2ColMatMulProgram).
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"O", ProgramUniformVariableDataType::Uint32},
      {"I", ProgramUniformVariableDataType::Uint32},
      {"H", ProgramUniformVariableDataType::Uint32},
      {"W", ProgramUniformVariableDataType::Uint32},
      {"Ci_tiles", ProgramUniformVariableDataType::Uint32},
      {"H_W_tiles", ProgramUniformVariableDataType::Uint32});
};
|
|
||
// Fused im2col + matmul conv program. Each workgroup computes a
// tile_m x tile_n tile of the implicit GEMM with
// M = output_h * output_w, K = kernel_h * kernel_w * channel_i, N = channel_o.
class Im2ColMatMulProgram final : public Program<Im2ColMatMulProgram> {
 public:
  // has_bias: whether the optional bias input is bound.
  // tile_m / tile_n: GEMM tile shape (validated in GenerateShaderCode).
  // use_subgroup: selects the subgroup-based shader variant.
  Im2ColMatMulProgram(bool has_bias,
                      uint32_t tile_m,
                      uint32_t tile_n,
                      bool use_subgroup) : Program("Im2ColMatMul"),
                                           has_bias_(has_bias),
                                           tile_m_(tile_m),
                                           tile_n_(tile_n),
                                           use_subgroup_(use_subgroup) {}

  Status GenerateShaderCode(ShaderHelper& shader) const override;

  // Uniforms consumed by the WGSL template; order must match the values
  // pushed in ApplyIm2ColMatMulProgram.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"batch", ProgramUniformVariableDataType::Uint32},
      {"src_h", ProgramUniformVariableDataType::Uint32},
      {"src_w", ProgramUniformVariableDataType::Uint32},
      {"channel_i", ProgramUniformVariableDataType::Uint32},
      {"kernel_h", ProgramUniformVariableDataType::Uint32},
      {"kernel_w", ProgramUniformVariableDataType::Uint32},
      {"output_h", ProgramUniformVariableDataType::Uint32},
      {"output_w", ProgramUniformVariableDataType::Uint32},
      {"im2col_m", ProgramUniformVariableDataType::Uint32},
      {"im2col_k", ProgramUniformVariableDataType::Uint32},
      {"im2col_n", ProgramUniformVariableDataType::Uint32},
      {"M_tiles", ProgramUniformVariableDataType::Uint32},
      {"N_tiles", ProgramUniformVariableDataType::Uint32},
      {"K_tiles", ProgramUniformVariableDataType::Uint32},
      {"dilations", ProgramUniformVariableDataType::Uint32},
      {"pads", ProgramUniformVariableDataType::Uint32},
      {"strides", ProgramUniformVariableDataType::Uint32});

 private:
  bool has_bias_;  // true when the optional bias input is present

  uint32_t tile_m_;     // GEMM tile rows per workgroup
  uint32_t tile_n_;     // GEMM tile columns (output channels) per workgroup
  bool use_subgroup_;   // emit the subgroup shader variant
};
|
|
||
// Returns true when the fused im2col-matmul path can handle this conv:
// supported device, channels-last layout, no fused activation, explicit pads
// (auto_pad == NOTSET), group == 1, a non-1x1 kernel, and input channels
// divisible by 4.
bool CanApplyIm2ColMatMulProgram(ComputeContext& context,
                                 const bool is_channels_last,
                                 const ActivationKind activation_kind,
                                 const TensorShape kernel_shape,
                                 const AutoPadType auto_pad,
                                 const uint32_t group);

// Runs conv as a weight transpose (OIHW -> OHWI) followed by the fused
// im2col-matmul program, writing the result into the preallocated `output`.
Status ApplyIm2ColMatMulProgram(ComputeContext& context,
                                const bool is_channels_last,
                                const std::vector<uint32_t>& dilations,
                                const std::vector<uint32_t>& pads,
                                const std::vector<uint32_t>& strides,
                                Tensor* output);
|
|
||
| } // namespace webgpu | ||
| } // namespace onnxruntime | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.