32 changes: 25 additions & 7 deletions onnxruntime/core/providers/webgpu/nn/conv.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "core/providers/webgpu/nn/conv.h"
#include "core/providers/webgpu/nn/conv2d_mm.h"
#include "core/providers/webgpu/nn/im2col_matmul.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/tensor/transpose.h"
@@ -99,10 +100,34 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
modified_input_output_shapes.push_back(bias->Shape());
}
modified_input_output_shapes.push_back(TensorShape(output_shape_vector));

const auto input_height = input_shape[is_channels_last ? 1 : 2];
const auto input_width = input_shape[is_channels_last ? 2 : 3];
const auto input_channels = input_shape[is_channels_last ? 3 : 1];
const auto kernel_height = kernel_shape[2];
const auto kernel_width = kernel_shape[3];
const auto output_height = output_shape_vector[is_channels_last ? 1 : 2];
const auto output_width = output_shape_vector[is_channels_last ? 2 : 3];

uint32_t auto_pad_adjust = conv_attrs_.auto_pad == AutoPadType::SAME_LOWER ? 1 : 0;
auto pad0 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[0] : (pads[0] + pads[2] + auto_pad_adjust) / 2;
auto pad1 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[1] : (pads[1] + pads[3] + auto_pad_adjust) / 2;
std::vector<uint32_t> updated_pads{pad0, pad1};
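Note: a quick illustration of the SAME_* pad split above (a sketch, not part of the diff). For a total padding of 3 along an axis, SAME_UPPER gives the extra element to the end side (begin pad 1) while SAME_LOWER gives it to the begin side (begin pad 2); the auto_pad_adjust term implements exactly this rounding.

#include <cstdint>
uint32_t BeginPad(uint32_t pad_begin, uint32_t pad_end, bool same_lower) {
  // SAME_LOWER rounds the split up so the begin side receives the extra element.
  return (pad_begin + pad_end + (same_lower ? 1u : 0u)) / 2u;
}
// BeginPad(1, 2, /*same_lower=*/false) == 1; BeginPad(1, 2, /*same_lower=*/true) == 2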

if (CanApplyIm2ColMatMulProgram(context,
is_channels_last,
activation_.activation_kind_,
kernel_shape,
conv_attrs_.auto_pad,
onnxruntime::narrow<uint32_t>(conv_attrs_.group))) {
return ApplyIm2ColMatMulProgram(context,
is_channels_last,
dilations,
pads,
strides,
output);
}

if (conv_attrs_.group > 1) {
Tensor transposed_kernel;
if (is_channels_last) {
@@ -128,13 +153,6 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
}
return context.RunProgram(program);
}
const auto input_height = input_shape[is_channels_last ? 1 : 2];
const auto input_width = input_shape[is_channels_last ? 2 : 3];
const auto input_channels = input_shape[is_channels_last ? 3 : 1];
const auto kernel_height = kernel_shape[2];
const auto kernel_width = kernel_shape[3];
const auto output_height = output_shape_vector[is_channels_last ? 1 : 2];
const auto output_width = output_shape_vector[is_channels_last ? 2 : 3];

const auto same_size = is_channels_last && input_height == kernel_height && input_width == kernel_width && pads[0] == 0 && pads[1] == 0;
if (same_size || (kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides[0] == 1 && strides[1] == 1)) {
230 changes: 230 additions & 0 deletions onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc
@@ -0,0 +1,230 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <string>
#include <utility>
#include <vector>

#include "core/providers/webgpu/webgpu_utils.h"
#include "core/providers/webgpu/nn/im2col_matmul.h"
#include "core/providers/webgpu/nn/activation_util.h"

namespace onnxruntime {
namespace webgpu {

namespace {

// TODO: move to common header.
template <typename T>
inline T ceil_div(T numerator, T denominator) {
return (numerator + denominator - 1) / denominator;
}

// Chooses a tile size (tile_m, tile_n) for the im2col-matmul program.
// The candidates are performance-tuned and vary with the target device.
std::pair<uint32_t, uint32_t> ChooseTileSize(uint32_t im2col_m, uint32_t im2col_n) {
// Candidate (tile_m, tile_n) pairs, in descending order of preference.
const std::vector<std::pair<uint32_t, uint32_t>> kTileSizes = {
std::make_pair(32, 64),
std::make_pair(16, 64),
};

for (const auto& tile_pair : kTileSizes) {
const uint32_t tile_m = tile_pair.first;
const uint32_t tile_n = tile_pair.second;

const uint32_t dispatch_m = ceil_div(im2col_m, tile_m);
const uint32_t dispatch_n = ceil_div(im2col_n, tile_n);
const uint32_t dispatch = dispatch_m * dispatch_n;

if (dispatch >= 128) {
return tile_pair;
}
}

// If none of the tile sizes meets the dispatch >= 128 requirement,
// fall back to the smallest candidate.
return kTileSizes.back();
}
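Note: a worked example of the heuristic (hypothetical shapes, not part of the diff): for a 56x56 output with 256 output channels, im2col_m = 56 * 56 = 3136 and im2col_n = 256. The first candidate (32, 64) dispatches ceil_div(3136, 32) * ceil_div(256, 64) = 98 * 4 = 392 workgroups, which meets the >= 128 occupancy threshold, so (32, 64) is selected.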

// TODO: Add support for more devices.
bool IsDeviceSupported(ComputeContext& context) {
const wgpu::AdapterInfo& adapter_info = context.AdapterInfo();

if (adapter_info.vendor == std::string_view("intel")) {
if (adapter_info.architecture == std::string_view("xe-2lpg")) {
return true;
}
}

return false;
}

} // namespace

Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

return WGSL_TEMPLATE_APPLY(shader, "nn/oihw_to_ohwi.wgsl.template",
WGSL_TEMPLATE_VARIABLE(output, output),
WGSL_TEMPLATE_VARIABLE(src, src));
}

Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const auto& weight = shader.AddInput("weight", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
if (has_bias_) {
shader.AddInput("bias", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
}
const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

ORT_ENFORCE(tile_m_ == 16 || tile_m_ == 32, "tile_m must be 16 or 32.");
ORT_ENFORCE(tile_n_ == 64, "tile_n must be 64.");

return WGSL_TEMPLATE_APPLY(shader, "nn/im2col_matmul.wgsl.template",
WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
WGSL_TEMPLATE_PARAMETER(tile_m, tile_m_),
WGSL_TEMPLATE_PARAMETER(tile_n, tile_n_),
WGSL_TEMPLATE_PARAMETER(use_subgroup, use_subgroup_),
WGSL_TEMPLATE_VARIABLE(output, output),
WGSL_TEMPLATE_VARIABLE(src, src),
WGSL_TEMPLATE_VARIABLE(weight, weight));
}

Status ApplyIm2ColMatMulProgram(ComputeContext& context,
bool is_channels_last,
const std::vector<uint32_t>& dilations,
const std::vector<uint32_t>& pads,
const std::vector<uint32_t>& strides,
Tensor* output) {
const auto* input = context.Input<Tensor>(0);
const auto* weight = context.Input<Tensor>(1);
const bool has_bias = context.InputCount() > 2;
const auto* bias = has_bias ? context.Input<Tensor>(2) : nullptr;

// Transpose OIHW Weight to OHWI
TensorShape weight_shape = weight->Shape();
const uint32_t channel_output = onnxruntime::narrow<uint32_t>(weight_shape[0]);
const uint32_t channel_input = onnxruntime::narrow<uint32_t>(weight_shape[1]);
const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(weight_shape[2]);
const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(weight_shape[3]);

TensorShape ohwi_weight_shape{channel_output, kernel_height, kernel_width, channel_input};
Tensor ohwi_weight = context.CreateGPUTensor(weight->DataType(), ohwi_weight_shape);
OIHW2OHWIProgram transpose_program{};
transpose_program.SetWorkgroupSize(64);

const uint32_t Ci_tiles = ceil_div(channel_input, 64u);
transpose_program.SetDispatchGroupSize(channel_output, Ci_tiles);

transpose_program.AddInput({weight,
ProgramTensorMetadataDependency::TypeAndRank});
transpose_program.AddOutput({&ohwi_weight,
ProgramTensorMetadataDependency::TypeAndRank});
transpose_program.AddUniformVariables({{channel_output},
{channel_input},
{kernel_height},
{kernel_width},
{Ci_tiles},
{ceil_div(kernel_height * kernel_width, 4u)}});  // H_W_tiles
ORT_RETURN_IF_ERROR(context.RunProgram(transpose_program));
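Note: the layout change above corresponds to this CPU-side reference (a sketch, not part of the diff), with O/I/H/W standing for channel_output, channel_input, kernel_height, kernel_width:

#include <cstdint>
// ohwi[o][h][w][i] = oihw[o][i][h][w]
void TransposeOIHW2OHWI(const float* oihw, float* ohwi,
                        uint32_t O, uint32_t I, uint32_t H, uint32_t W) {
  for (uint32_t o = 0; o < O; ++o)
    for (uint32_t i = 0; i < I; ++i)
      for (uint32_t h = 0; h < H; ++h)
        for (uint32_t w = 0; w < W; ++w)
          ohwi[((o * H + h) * W + w) * I + i] = oihw[((o * I + i) * H + h) * W + w];
}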

// im2col-matmul
const TensorShape input_shape = input->Shape();
const TensorShape output_shape = output->Shape();

const uint32_t batch = onnxruntime::narrow<uint32_t>(input_shape[0]);
const uint32_t input_height = onnxruntime::narrow<uint32_t>(input_shape[is_channels_last ? 1 : 2]);
const uint32_t input_width = onnxruntime::narrow<uint32_t>(input_shape[is_channels_last ? 2 : 3]);
const uint32_t output_height = onnxruntime::narrow<uint32_t>(output_shape[is_channels_last ? 1 : 2]);
const uint32_t output_width = onnxruntime::narrow<uint32_t>(output_shape[is_channels_last ? 2 : 3]);

const uint32_t im2col_m = output_height * output_width;
const uint32_t im2col_k = kernel_height * kernel_width * channel_input;
const uint32_t im2col_n = channel_output;
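Note: on concrete (hypothetical) shapes, the conv-to-GEMM mapping above works out as follows: an NHWC input of 1x58x58x64 with a 3x3 kernel, 256 output channels, stride 1, and no padding yields a 56x56 output, so im2col_m = 56 * 56 = 3136 (one GEMM row per output pixel), im2col_k = 3 * 3 * 64 = 576 (one GEMM column per kernel tap and input channel), and im2col_n = 256 (one GEMM output column per output channel).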

const auto [tile_m, tile_n] = ChooseTileSize(im2col_m, im2col_n);
const uint32_t workgroup_size = tile_n;

// Check the device's subgroup size before shader compilation to avoid potential performance
// penalties from conditional checks at shader runtime. Enabling `use_subgroup` is only safe
// when the subgroup size is greater than or equal to `tile_m`; if that condition is not met,
// the feature must be disabled.
const bool use_subgroup = true;
Im2ColMatMulProgram im2col_mm_program{has_bias, tile_m, tile_n, use_subgroup};
im2col_mm_program.SetWorkgroupSize(workgroup_size);

const uint32_t M_tiles = ceil_div(im2col_m, tile_m);
const uint32_t N_tiles = ceil_div(im2col_n, tile_n);
im2col_mm_program.SetDispatchGroupSize(M_tiles, N_tiles, batch);
Contributor:

How about enhancing the current TransposeProgram with the shared path instead of adding a new one?
You are transposing from perm [0, 1, 2, 3] to perm [0, 2, 3, 1], which is equivalent to transposing [o, i, hw] to [o, hw, i]. You could simply extend DoTranspose with a shared path that supports any shape where only the last two dimensions are transposed and the preceding dimensions are unchanged. Currently, the shared path only supports 2D transpose from perm [0, 1] to perm [1, 0]. It can be extended to transpose [0, 1, 2] to [0, 2, 1] by reshaping the tensor into a 3D tensor [d0 * d1 * ... * dn-3, dn-2, dn-1] whenever only the last two dimensions are transposed.

Contributor Author:

Understood.
I intend to improve the current Transpose path discussed in the previous PR #26501. Could I handle this as a separate task in a follow-up PR?
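A sketch of the reshape idea described above (illustrative only, not part of the diff): a transpose that swaps just the last two dimensions reduces to a batched 2D transpose over the flattened leading dimensions.

#include <cstddef>
// in has shape [B, R, C] (flattened, B = d0 * ... * dn-3); out gets shape [B, C, R].
void TransposeLast2D(const float* in, float* out, size_t B, size_t R, size_t C) {
  for (size_t b = 0; b < B; ++b)
    for (size_t r = 0; r < R; ++r)
      for (size_t c = 0; c < C; ++c)
        out[(b * C + c) * R + r] = in[(b * R + r) * C + c];
}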


im2col_mm_program.AddInput({input,
ProgramTensorMetadataDependency::TypeAndRank,
4});
im2col_mm_program.AddInput({&ohwi_weight,
ProgramTensorMetadataDependency::TypeAndRank,
4});
if (has_bias) {
im2col_mm_program.AddInput({bias,
ProgramTensorMetadataDependency::TypeAndRank});
}
im2col_mm_program.AddOutput({output,
ProgramTensorMetadataDependency::TypeAndRank});
im2col_mm_program.AddUniformVariables({{batch},
{input_height},
{input_width},
{channel_input},
{kernel_height},
{kernel_width},
{output_height},
{output_width},
{im2col_m},
{im2col_k},
{im2col_n},
{M_tiles},
{N_tiles},
{ceil_div(ceil_div(im2col_k, 4u), 4u)},  // K_tiles
{dilations},
{pads},
{strides}});
im2col_mm_program.CacheHint(has_bias, tile_m, tile_n, use_subgroup);

return context.RunProgram(im2col_mm_program);
}

bool CanApplyIm2ColMatMulProgram(ComputeContext& context,
const bool is_channels_last,
const ActivationKind activation_kind,
const TensorShape& weight_shape,
const AutoPadType auto_pad,
const uint32_t group) {
if (!IsDeviceSupported(context)) {
return false;
}

// TODO: Support !is_channels_last
// TODO: Support fuse
// TODO: Support auto pad
// TODO: Support group conv
if (!is_channels_last || activation_kind != ActivationKind::None || auto_pad != AutoPadType::NOTSET || group != 1) {
return false;
}

// TODO: Support conv1d
// TODO: Support conv2d_1x1
const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(weight_shape[2]);
const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(weight_shape[3]);
if (kernel_height == 1 || kernel_width == 1) {
return false;
}

// TODO: Support channel input vec1
const uint32_t channel_input = onnxruntime::narrow<uint32_t>(weight_shape[1]);
if (channel_input % 4 != 0) {
return false;
}

return true;
}

} // namespace webgpu
} // namespace onnxruntime
92 changes: 92 additions & 0 deletions onnxruntime/core/providers/webgpu/nn/im2col_matmul.h
@@ -0,0 +1,92 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <vector>

#include "core/framework/tensor_shape.h"
#include "core/framework/tensor.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/nn/conv_attributes.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/nn/fuse_utils.h"

namespace onnxruntime {
namespace webgpu {

// Transpose OIHW Weight to OHWI
class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
public:
OIHW2OHWIProgram() : Program("OIHW2OHWI") {}

Status GenerateShaderCode(ShaderHelper& shader) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"O", ProgramUniformVariableDataType::Uint32},
{"I", ProgramUniformVariableDataType::Uint32},
{"H", ProgramUniformVariableDataType::Uint32},
{"W", ProgramUniformVariableDataType::Uint32},
{"Ci_tiles", ProgramUniformVariableDataType::Uint32},
{"H_W_tiles", ProgramUniformVariableDataType::Uint32});
};

class Im2ColMatMulProgram final : public Program<Im2ColMatMulProgram> {
public:
Im2ColMatMulProgram(bool has_bias,
uint32_t tile_m,
uint32_t tile_n,
bool use_subgroup) : Program("Im2ColMatMul"),
has_bias_(has_bias),
tile_m_(tile_m),
tile_n_(tile_n),
use_subgroup_(use_subgroup) {}

Status GenerateShaderCode(ShaderHelper& shader) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"batch", ProgramUniformVariableDataType::Uint32},
{"src_h", ProgramUniformVariableDataType::Uint32},
{"src_w", ProgramUniformVariableDataType::Uint32},
{"channel_i", ProgramUniformVariableDataType::Uint32},
{"kernel_h", ProgramUniformVariableDataType::Uint32},
{"kernel_w", ProgramUniformVariableDataType::Uint32},
{"output_h", ProgramUniformVariableDataType::Uint32},
{"output_w", ProgramUniformVariableDataType::Uint32},
{"im2col_m", ProgramUniformVariableDataType::Uint32},
{"im2col_k", ProgramUniformVariableDataType::Uint32},
{"im2col_n", ProgramUniformVariableDataType::Uint32},
{"M_tiles", ProgramUniformVariableDataType::Uint32},
{"N_tiles", ProgramUniformVariableDataType::Uint32},
{"K_tiles", ProgramUniformVariableDataType::Uint32},
{"dilations", ProgramUniformVariableDataType::Uint32},
{"pads", ProgramUniformVariableDataType::Uint32},
{"strides", ProgramUniformVariableDataType::Uint32});

private:
bool has_bias_;

uint32_t tile_m_;
uint32_t tile_n_;
bool use_subgroup_;
};

bool CanApplyIm2ColMatMulProgram(ComputeContext& context,
const bool is_channels_last,
const ActivationKind activation_kind,
const TensorShape& weight_shape,
const AutoPadType auto_pad,
const uint32_t group);

Status ApplyIm2ColMatMulProgram(ComputeContext& context,
const bool is_channels_last,
const std::vector<uint32_t>& dilations,
const std::vector<uint32_t>& pads,
const std::vector<uint32_t>& strides,
Tensor* output);

} // namespace webgpu
} // namespace onnxruntime