[webgpu] Optimize Conv by im2col-matmul #26603
Status: Open. daijh wants to merge 7 commits into microsoft:main from daijh:im2col-matmul.
Commits (7, all by daijh):

b1e5290 [webgpu] im2col matmul
2f7487e Update
4efeff4 Update
e6d48e4 Fix comment
6a4bede Distinguish between `weight` and `kernel` in naming.
07073e1 Do not support conv_1d; Can not use `bias_element_t` if no bias.
749c6e5 Address comment of reviewdog

New file (+230 lines): core/providers/webgpu/nn/im2col_matmul.cc (path inferred from its include of im2col_matmul.h)

```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <string>
#include <utility>
#include <vector>

#include "core/providers/webgpu/webgpu_utils.h"
#include "core/providers/webgpu/nn/im2col_matmul.h"
#include "core/providers/webgpu/nn/activation_util.h"

namespace onnxruntime {
namespace webgpu {

namespace {

// TODO: move to common header.
template <typename T>
inline T ceil_div(T numerator, T denominator) {
  return (numerator + denominator - 1) / denominator;
}
// Chooses the tile size (tile_m, tile_n) for the im2col-matmul program.
// The tile size is performance-tuned and varies with the target device.
std::pair<uint32_t, uint32_t> ChooseTileSize(uint32_t im2col_m, uint32_t im2col_n) {
  // Preferred (tile_m, tile_n) pairs, in descending order of preference.
  const std::vector<std::pair<uint32_t, uint32_t>> kTileSizes = {
      std::make_pair(32, 64),
      std::make_pair(16, 64),
  };

  for (const auto& tile_pair : kTileSizes) {
    const uint32_t tile_m = tile_pair.first;
    const uint32_t tile_n = tile_pair.second;

    const uint32_t dispatch_m = ceil_div(im2col_m, tile_m);
    const uint32_t dispatch_n = ceil_div(im2col_n, tile_n);
    const uint32_t dispatch = dispatch_m * dispatch_n;

    // Prefer the largest tile that still yields enough workgroups to keep
    // the GPU occupied.
    if (dispatch >= 128) {
      return tile_pair;
    }
  }
  // If no tile size meets the dispatch >= 128 requirement, fall back to the
  // smallest tile size.
  return kTileSizes.back();
}
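
// Worked example for ChooseTileSize (illustrative numbers, not from the PR):
// for a 56x56 output with 64 output channels, im2col_m = 3136 and
// im2col_n = 64. The (32, 64) tile gives ceil(3136 / 32) * ceil(64 / 64) = 98
// workgroups, which is < 128, so it is rejected; the (16, 64) tile gives 196
// workgroups and is selected.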

// TODO: Add support for more devices.
bool IsDeviceSupported(ComputeContext& context) {
  const wgpu::AdapterInfo& adapter_info = context.AdapterInfo();

  if (adapter_info.vendor == std::string_view("intel")) {
    if (adapter_info.architecture == std::string_view("xe-2lpg")) {
      return true;
    }
  }

  return false;
}

}  // namespace

Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  return WGSL_TEMPLATE_APPLY(shader, "nn/oihw_to_ohwi.wgsl.template",
                             WGSL_TEMPLATE_VARIABLE(output, output),
                             WGSL_TEMPLATE_VARIABLE(src, src));
}

Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  const auto& weight = shader.AddInput("weight", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  if (has_bias_) {
    shader.AddInput("bias", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  }
  const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  ORT_ENFORCE(tile_m_ == 16 || tile_m_ == 32, "tile_m must be 16 or 32.");
  ORT_ENFORCE(tile_n_ == 64, "tile_n must be 64.");

  return WGSL_TEMPLATE_APPLY(shader, "nn/im2col_matmul.wgsl.template",
                             WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
                             WGSL_TEMPLATE_PARAMETER(tile_m, tile_m_),
                             WGSL_TEMPLATE_PARAMETER(tile_n, tile_n_),
                             WGSL_TEMPLATE_PARAMETER(use_subgroup, use_subgroup_),
                             WGSL_TEMPLATE_VARIABLE(output, output),
                             WGSL_TEMPLATE_VARIABLE(src, src),
                             WGSL_TEMPLATE_VARIABLE(weight, weight));
}

Status ApplyIm2ColMatMulProgram(ComputeContext& context,
                                bool is_channels_last,
                                const std::vector<uint32_t>& dilations,
                                const std::vector<uint32_t>& pads,
                                const std::vector<uint32_t>& strides,
                                Tensor* output) {
  const auto* input = context.Input<Tensor>(0);
  const auto* weight = context.Input<Tensor>(1);
  const bool has_bias = context.InputCount() > 2;
  const auto* bias = has_bias ? context.Input<Tensor>(2) : nullptr;

  // Transpose OIHW Weight to OHWI
  TensorShape weight_shape = weight->Shape();
  const uint32_t channel_output = onnxruntime::narrow<uint32_t>(weight_shape[0]);
  const uint32_t channel_input = onnxruntime::narrow<uint32_t>(weight_shape[1]);
  const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(weight_shape[2]);
  const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(weight_shape[3]);

  TensorShape ohwi_weight_shape{channel_output, kernel_height, kernel_width, channel_input};
  Tensor ohwi_weight = context.CreateGPUTensor(weight->DataType(), ohwi_weight_shape);
  OIHW2OHWIProgram transpose_program{};
  transpose_program.SetWorkgroupSize(64);

  const uint32_t Ci_tiles = ceil_div(channel_input, 64u);
  transpose_program.SetDispatchGroupSize(channel_output, Ci_tiles);

  transpose_program.AddInput({weight,
                              ProgramTensorMetadataDependency::TypeAndRank});
  transpose_program.AddOutput({&ohwi_weight,
                               ProgramTensorMetadataDependency::TypeAndRank});
  transpose_program.AddUniformVariables({{channel_output},
                                         {channel_input},
                                         {kernel_height},
                                         {kernel_width},
                                         {Ci_tiles},
                                         {ceil_div(kernel_height * kernel_width, 4u)}});  // H_W_tiles
  ORT_RETURN_IF_ERROR(context.RunProgram(transpose_program));

  // im2col-matmul
  const TensorShape input_shape = input->Shape();
  const TensorShape output_shape = output->Shape();

  const uint32_t batch = onnxruntime::narrow<uint32_t>(input_shape[0]);
  const uint32_t input_height = onnxruntime::narrow<uint32_t>(input_shape[is_channels_last ? 1 : 2]);
  const uint32_t input_width = onnxruntime::narrow<uint32_t>(input_shape[is_channels_last ? 2 : 3]);
  const uint32_t output_height = onnxruntime::narrow<uint32_t>(output_shape[is_channels_last ? 1 : 2]);
  const uint32_t output_width = onnxruntime::narrow<uint32_t>(output_shape[is_channels_last ? 2 : 3]);
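
  // The convolution is computed as an implicit GEMM: each output pixel is a
  // row (M = output_height * output_width), each output channel is a column
  // (N = channel_output), and the reduction runs over the kernel window and
  // input channels (K = kernel_height * kernel_width * channel_input).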
  const uint32_t im2col_m = output_height * output_width;
  const uint32_t im2col_k = kernel_height * kernel_width * channel_input;
  const uint32_t im2col_n = channel_output;

  const auto [tile_m, tile_n] = ChooseTileSize(im2col_m, im2col_n);
  const uint32_t workgroup_size = tile_n;

  // Check the device's subgroup size before shader compilation to avoid the
  // performance penalty of conditional checks at shader runtime. The subgroup
  // size must be greater than or equal to `tile_m` to safely enable
  // `use_subgroup`; if this condition is not met, the feature must be disabled.
  const bool use_subgroup = true;
  Im2ColMatMulProgram im2col_mm_program{has_bias, tile_m, tile_n, use_subgroup};
  im2col_mm_program.SetWorkgroupSize(workgroup_size);

  const uint32_t M_tiles = ceil_div(im2col_m, tile_m);
  const uint32_t N_tiles = ceil_div(im2col_n, tile_n);
  im2col_mm_program.SetDispatchGroupSize(M_tiles, N_tiles, batch);

  im2col_mm_program.AddInput({input,
                              ProgramTensorMetadataDependency::TypeAndRank,
                              4});
  im2col_mm_program.AddInput({&ohwi_weight,
                              ProgramTensorMetadataDependency::TypeAndRank,
                              4});
  if (has_bias) {
    im2col_mm_program.AddInput({bias,
                                ProgramTensorMetadataDependency::TypeAndRank});
  }
  im2col_mm_program.AddOutput({output,
                               ProgramTensorMetadataDependency::TypeAndRank});
  im2col_mm_program.AddUniformVariables({{batch},
                                         {input_height},
                                         {input_width},
                                         {channel_input},
                                         {kernel_height},
                                         {kernel_width},
                                         {output_height},
                                         {output_width},
                                         {im2col_m},
                                         {im2col_k},
                                         {im2col_n},
                                         {M_tiles},
                                         {N_tiles},
                                         {ceil_div(ceil_div(im2col_k, 4u), 4u)},  // K_tiles
                                         {dilations},
                                         {pads},
                                         {strides}});
  im2col_mm_program.CacheHint(has_bias, tile_m, tile_n, use_subgroup);

  return context.RunProgram(im2col_mm_program);
}

bool CanApplyIm2ColMatMulProgram(ComputeContext& context,
                                 const bool is_channels_last,
                                 const ActivationKind activation_kind,
                                 const TensorShape weight_shape,
                                 const AutoPadType auto_pad,
                                 const uint32_t group) {
  if (!IsDeviceSupported(context)) {
    return false;
  }

  // TODO: Support !is_channels_last
  // TODO: Support fuse
  // TODO: Support auto pad
  // TODO: Support group conv
  if (!is_channels_last || activation_kind != ActivationKind::None || auto_pad != AutoPadType::NOTSET || group != 1) {
    return false;
  }

  // TODO: Support conv1d
  // TODO: Support conv2d_1x1
  const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(weight_shape[2]);
  const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(weight_shape[3]);
  if (kernel_height == 1 || kernel_width == 1) {
    return false;
  }

  // TODO: Support channel input vec1
  const uint32_t channel_input = onnxruntime::narrow<uint32_t>(weight_shape[1]);
  if (channel_input % 4 != 0) {
    return false;
  }

  return true;
}

}  // namespace webgpu
}  // namespace onnxruntime
```
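
The scheme above is the classic im2col lowering: the convolution becomes a GEMM with M = output_height * output_width, K = kernel_height * kernel_width * channel_input, and N = channel_output, where the pre-transposed OHWI weight serves as the [N, K] matrix. A minimal CPU reference of that formulation may help when validating the shader's output; all names below are illustrative, it assumes NHWC input and OHWI weight, and it is not code from this PR:

```cpp
#include <cstdint>
#include <vector>

// CPU reference for the im2col + matmul formulation (illustrative only).
// src:    NHWC input  [batch, src_h, src_w, channel_i]
// weight: OHWI weight [channel_o, kernel_h, kernel_w, channel_i]
// Returns an NHWC output [batch, out_h, out_w, channel_o].
std::vector<float> ConvIm2ColReference(
    const std::vector<float>& src, uint32_t batch,
    uint32_t src_h, uint32_t src_w, uint32_t channel_i,
    const std::vector<float>& weight,
    uint32_t channel_o, uint32_t kernel_h, uint32_t kernel_w,
    uint32_t stride, uint32_t pad, uint32_t dilation) {
  const uint32_t out_h = (src_h + 2 * pad - dilation * (kernel_h - 1) - 1) / stride + 1;
  const uint32_t out_w = (src_w + 2 * pad - dilation * (kernel_w - 1) - 1) / stride + 1;
  const uint32_t m = out_h * out_w;                    // im2col_m
  const uint32_t k = kernel_h * kernel_w * channel_i;  // im2col_k
  const uint32_t n = channel_o;                        // im2col_n

  std::vector<float> output(batch * m * n, 0.0f);
  for (uint32_t b = 0; b < batch; ++b) {
    // Build the im2col matrix: one row per output pixel, one column per
    // (kernel position, input channel) pair.
    std::vector<float> col(m * k, 0.0f);
    for (uint32_t row = 0; row < m; ++row) {
      const uint32_t oy = row / out_w, ox = row % out_w;
      uint32_t idx = 0;
      for (uint32_t ky = 0; ky < kernel_h; ++ky) {
        for (uint32_t kx = 0; kx < kernel_w; ++kx) {
          const int32_t iy = static_cast<int32_t>(oy * stride + ky * dilation) - static_cast<int32_t>(pad);
          const int32_t ix = static_cast<int32_t>(ox * stride + kx * dilation) - static_cast<int32_t>(pad);
          for (uint32_t c = 0; c < channel_i; ++c, ++idx) {
            if (iy >= 0 && iy < static_cast<int32_t>(src_h) &&
                ix >= 0 && ix < static_cast<int32_t>(src_w)) {
              col[row * k + idx] = src[((b * src_h + iy) * src_w + ix) * channel_i + c];
            }  // else: padding contributes zero
          }
        }
      }
    }
    // Plain GEMM: output[m, n] = col[m, k] x weight_T[k, n]. The OHWI weight
    // is laid out as [n, k], so weight[o * k + j] already indexes B transposed.
    for (uint32_t row = 0; row < m; ++row) {
      for (uint32_t o = 0; o < n; ++o) {
        float acc = 0.0f;
        for (uint32_t j = 0; j < k; ++j) {
          acc += col[row * k + j] * weight[o * k + j];
        }
        output[(b * m + row) * n + o] = acc;
      }
    }
  }
  return output;
}
```

Judging from the uniforms (`M_tiles`, `N_tiles`, `K_tiles`), the WGSL template computes this GEMM tile by tile, gathering input values on the fly rather than materializing the `col` matrix.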

New file (+92 lines): core/providers/webgpu/nn/im2col_matmul.h

```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <vector>

#include "core/framework/tensor_shape.h"
#include "core/framework/tensor.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/nn/conv_attributes.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/nn/fuse_utils.h"

namespace onnxruntime {
namespace webgpu {

// Transpose OIHW Weight to OHWI
class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
 public:
  OIHW2OHWIProgram() : Program("OIHW2OHWI") {}

  Status GenerateShaderCode(ShaderHelper& shader) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"O", ProgramUniformVariableDataType::Uint32},
      {"I", ProgramUniformVariableDataType::Uint32},
      {"H", ProgramUniformVariableDataType::Uint32},
      {"W", ProgramUniformVariableDataType::Uint32},
      {"Ci_tiles", ProgramUniformVariableDataType::Uint32},
      {"H_W_tiles", ProgramUniformVariableDataType::Uint32});
};

class Im2ColMatMulProgram final : public Program<Im2ColMatMulProgram> {
 public:
  Im2ColMatMulProgram(bool has_bias,
                      uint32_t tile_m,
                      uint32_t tile_n,
                      bool use_subgroup) : Program("Im2ColMatMul"),
                                           has_bias_(has_bias),
                                           tile_m_(tile_m),
                                           tile_n_(tile_n),
                                           use_subgroup_(use_subgroup) {}

  Status GenerateShaderCode(ShaderHelper& shader) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"batch", ProgramUniformVariableDataType::Uint32},
      {"src_h", ProgramUniformVariableDataType::Uint32},
      {"src_w", ProgramUniformVariableDataType::Uint32},
      {"channel_i", ProgramUniformVariableDataType::Uint32},
      {"kernel_h", ProgramUniformVariableDataType::Uint32},
      {"kernel_w", ProgramUniformVariableDataType::Uint32},
      {"output_h", ProgramUniformVariableDataType::Uint32},
      {"output_w", ProgramUniformVariableDataType::Uint32},
      {"im2col_m", ProgramUniformVariableDataType::Uint32},
      {"im2col_k", ProgramUniformVariableDataType::Uint32},
      {"im2col_n", ProgramUniformVariableDataType::Uint32},
      {"M_tiles", ProgramUniformVariableDataType::Uint32},
      {"N_tiles", ProgramUniformVariableDataType::Uint32},
      {"K_tiles", ProgramUniformVariableDataType::Uint32},
      {"dilations", ProgramUniformVariableDataType::Uint32},
      {"pads", ProgramUniformVariableDataType::Uint32},
      {"strides", ProgramUniformVariableDataType::Uint32});

 private:
  bool has_bias_;

  uint32_t tile_m_;
  uint32_t tile_n_;
  bool use_subgroup_;
};

bool CanApplyIm2ColMatMulProgram(ComputeContext& context,
                                 const bool is_channels_last,
                                 const ActivationKind activation_kind,
                                 const TensorShape kernel_shape,
                                 const AutoPadType auto_pad,
                                 const uint32_t group);

Status ApplyIm2ColMatMulProgram(ComputeContext& context,
                                const bool is_channels_last,
                                const std::vector<uint32_t>& dilations,
                                const std::vector<uint32_t>& pads,
                                const std::vector<uint32_t>& strides,
                                Tensor* output);

}  // namespace webgpu
}  // namespace onnxruntime
```
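
For context, here is a hypothetical call-site sketch of how a WebGPU Conv kernel could gate on the new path. The actual integration is not part of this diff; `ComputeConvSketch` and `FallbackConv` are placeholder names:

```cpp
#include "core/providers/webgpu/nn/im2col_matmul.h"

namespace onnxruntime {
namespace webgpu {

// Placeholder for the existing WebGPU conv implementation (not real API).
Status FallbackConv(ComputeContext& context, Tensor* output);

// Hypothetical dispatcher: take the im2col-matmul path only when the device
// and conv attributes satisfy its current restrictions (channels-last, no
// fusion, group == 1, kernel larger than 1x1, input channels divisible by 4).
Status ComputeConvSketch(ComputeContext& context,
                         bool is_channels_last,
                         ActivationKind activation_kind,
                         const TensorShape& weight_shape,
                         AutoPadType auto_pad,
                         uint32_t group,
                         const std::vector<uint32_t>& dilations,
                         const std::vector<uint32_t>& pads,
                         const std::vector<uint32_t>& strides,
                         Tensor* output) {
  if (CanApplyIm2ColMatMulProgram(context, is_channels_last, activation_kind,
                                  weight_shape, auto_pad, group)) {
    return ApplyIm2ColMatMulProgram(context, is_channels_last,
                                    dilations, pads, strides, output);
  }
  return FallbackConv(context, output);
}

}  // namespace webgpu
}  // namespace onnxruntime
```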

Review comment:

How about enhancing the current TransposeProgram's shared path instead of adding a new program? You are transposing from perm [0, 1, 2, 3] to perm [0, 2, 3, 1], which is equivalent to transposing [o, i, hw] to [o, hw, i]. You could extend DoTranspose's shared path to support any shape where only the last two dimensions are transposed and the preceding dimensions are unchanged. Currently, the shared path only supports a 2D transpose from perm [0, 1] to perm [1, 0]; it can be extended to transpose [0, 1, 2] to [0, 2, 1] whenever the permutation only swaps the last two dimensions, by reshaping the tensor into a 3D tensor [d0 * d1 * ... * dn-3, dn-2, dn-1].
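
A minimal sketch of the suggested equivalence (illustrative only, not repo code): viewing the OIHW weight as [O, I, H*W] and swapping the last two axes yields [O, H*W, I], which is exactly OHWI:

```cpp
#include <cstddef>
#include <vector>

// Batched 2D transpose over the trailing two dimensions. Any permutation that
// only swaps the last two axes of an N-D shape [d0, ..., dn-2, dn-1] reduces
// to this after folding the leading dims into one batch dim
// [d0 * ... * dn-3, dn-2, dn-1].
std::vector<float> TransposeLastTwoDims(const std::vector<float>& src,
                                        size_t batch, size_t rows, size_t cols) {
  std::vector<float> dst(src.size());
  for (size_t b = 0; b < batch; ++b)
    for (size_t r = 0; r < rows; ++r)
      for (size_t c = 0; c < cols; ++c)
        dst[(b * cols + c) * rows + r] = src[(b * rows + r) * cols + c];
  return dst;
}
// For this PR's case: TransposeLastTwoDims(oihw, O, I, H * W) yields OHWI.
```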

Reply (daijh):

Understood. I intend to improve the current Transpose path, as discussed in the previous PR #26501. Could I handle this as a separate task in a follow-up PR?