32 changes: 25 additions & 7 deletions onnxruntime/core/providers/webgpu/nn/conv.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "core/providers/webgpu/nn/conv.h"
#include "core/providers/webgpu/nn/conv2d_mm.h"
#include "core/providers/webgpu/nn/im2col_matmul.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/tensor/transpose.h"
@@ -99,10 +100,34 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
modified_input_output_shapes.push_back(bias->Shape());
}
modified_input_output_shapes.push_back(TensorShape(output_shape_vector));

const auto input_height = input_shape[is_channels_last ? 1 : 2];
const auto input_width = input_shape[is_channels_last ? 2 : 3];
const auto input_channels = input_shape[is_channels_last ? 3 : 1];
const auto kernel_height = kernel_shape[2];
const auto kernel_width = kernel_shape[3];
const auto output_height = output_shape_vector[is_channels_last ? 1 : 2];
const auto output_width = output_shape_vector[is_channels_last ? 2 : 3];

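// For SAME_UPPER/SAME_LOWER auto-padding, the total padding along each spatial axis
// is split between begin and end; SAME_LOWER places the odd extra pixel at the
// beginning, hence the +1 adjustment before the division.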
uint32_t auto_pad_adjust = conv_attrs_.auto_pad == AutoPadType::SAME_LOWER ? 1 : 0;
auto pad0 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[0] : (pads[0] + pads[2] + auto_pad_adjust) / 2;
auto pad1 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[1] : (pads[1] + pads[3] + auto_pad_adjust) / 2;
std::vector<uint32_t> updated_pads{pad0, pad1};

if (CanApplyIm2ColMatMulProgram(context,
is_channels_last,
activation_.activation_kind_,
kernel_shape,
conv_attrs_.auto_pad,
onnxruntime::narrow<uint32_t>(conv_attrs_.group))) {
return ApplyIm2ColMatMulProgram(context,
is_channels_last,
dilations,
pads,
strides,
output);
}

if (conv_attrs_.group > 1) {
Tensor transposed_kernel;
if (is_channels_last) {
@@ -128,13 +153,6 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
}
return context.RunProgram(program);
}
const auto input_height = input_shape[is_channels_last ? 1 : 2];
const auto input_width = input_shape[is_channels_last ? 2 : 3];
const auto input_channels = input_shape[is_channels_last ? 3 : 1];
const auto kernel_height = kernel_shape[2];
const auto kernel_width = kernel_shape[3];
const auto output_height = output_shape_vector[is_channels_last ? 1 : 2];
const auto output_width = output_shape_vector[is_channels_last ? 2 : 3];

const auto same_size = is_channels_last && input_height == kernel_height && input_width == kernel_width && pads[0] == 0 && pads[1] == 0;
if (same_size || (kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides[0] == 1 && strides[1] == 1)) {
223 changes: 223 additions & 0 deletions onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc
@@ -0,0 +1,223 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "core/providers/webgpu/webgpu_utils.h"
#include "core/providers/webgpu/nn/im2col_matmul.h"
#include "core/providers/webgpu/nn/activation_util.h"

namespace onnxruntime {
namespace webgpu {
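// This file implements Conv as two GPU passes: (1) transpose the OIHW weight into
// OHWI layout, then (2) run a fused im2col + matmul program that gathers input
// patches on the fly rather than materializing the im2col matrix in memory.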

namespace {

// TODO: move to common header.
template <typename T>
inline T ceil_div(T numerator, T denominator) {
return (numerator + denominator - 1) / denominator;
}

// Chooses the optimal tile size (M, N) for the im2col operation.
// This tile size is performance-tuned and varies depending on the target device.
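// For example (illustrative numbers): with im2col_m = 12544 (a 112x112 output) and
// im2col_n = 64, the (32, 64) tile gives ceil(12544/32) * ceil(64/64) = 392 workgroups,
// which meets the dispatch >= 128 target below, so (32, 64) is selected.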
std::pair<uint32_t, uint32_t> ChooseTileSize(uint32_t im2col_m, uint32_t im2col_n) {
// Define a list of preferred (tile_m, tile_n) pairs in descending order of preference.
const std::vector<std::pair<uint32_t, uint32_t>> kTileSizes = {
std::make_pair(32, 64),
std::make_pair(16, 64),
};

for (const auto& tile_pair : kTileSizes) {
const uint32_t tile_m = tile_pair.first;
const uint32_t tile_n = tile_pair.second;

const uint32_t dispatch_m = ceil_div(im2col_m, tile_m);
const uint32_t dispatch_n = ceil_div(im2col_n, tile_n);
const uint32_t dispatch = dispatch_m * dispatch_n;

if (dispatch >= 128) {
return tile_pair;
}
}

// If none of the tile sizes meets the dispatch >= 128 requirement, fall back to the
// last (smallest) tile size.
return kTileSizes.back();
}

// Restricts this path to devices where it has been tuned; support for more devices
// can be added here.
bool IsDeviceSupported(ComputeContext& context) {
const wgpu::AdapterInfo& adapter_info = context.AdapterInfo();

if (adapter_info.vendor == std::string_view("intel")) {
if (adapter_info.architecture == std::string_view("xe-2lpg")) {
return true;
}
}

return false;
}

} // namespace

Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

return WGSL_TEMPLATE_APPLY(shader, "nn/oihw_to_ohwi.wgsl.template",
WGSL_TEMPLATE_VARIABLE(output, output),
WGSL_TEMPLATE_VARIABLE(src, src));
}

Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const auto& weight = shader.AddInput("weight", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
if (has_bias_) {
shader.AddInput("bias", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
}
const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

ORT_ENFORCE(tile_m_ == 16 || tile_m_ == 32, "tile_m must be 16 or 32.");
ORT_ENFORCE(tile_n_ == 64, "tile_n must be 64.");

return WGSL_TEMPLATE_APPLY(shader, "nn/im2col_matmul.wgsl.template",
WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
WGSL_TEMPLATE_PARAMETER(tile_m, tile_m_),
WGSL_TEMPLATE_PARAMETER(tile_n, tile_n_),
WGSL_TEMPLATE_PARAMETER(use_subgroup, use_subgroup_),
WGSL_TEMPLATE_VARIABLE(output, output),
WGSL_TEMPLATE_VARIABLE(src, src),
WGSL_TEMPLATE_VARIABLE(weight, weight));
}

Status ApplyIm2ColMatMulProgram(ComputeContext& context,
bool is_channels_last,
const std::vector<uint32_t>& dilations,
const std::vector<uint32_t>& pads,
const std::vector<uint32_t>& strides,
Tensor* output) {
const auto* input = context.Input<Tensor>(0);
const auto* kernel = context.Input<Tensor>(1);
const bool has_bias = context.InputCount() > 2;
const auto* bias = has_bias ? context.Input<Tensor>(2) : nullptr;

// Transpose OIHW Weight to OHWI
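// i.e. dst[o, h, w, i] = src[o, i, h, w], so each filter's input channels become the
// innermost, contiguous dimension that the vectorized (vec4) loads in the matmul expect.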
TensorShape kernel_shape = kernel->Shape();
const uint32_t channel_output = onnxruntime::narrow<uint32_t>(kernel_shape[0]);
const uint32_t channel_input = onnxruntime::narrow<uint32_t>(kernel_shape[1]);
const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(kernel_shape[2]);
const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(kernel_shape[3]);

TensorShape nhwc_kernel_shape{channel_output, kernel_height, kernel_width, channel_input};
[Contributor review comment — suggested change]
-  TensorShape nhwc_kernel_shape{channel_output, kernel_height, kernel_width, channel_input};
+  TensorShape ohwi_kernel_shape{channel_output, kernel_height, kernel_width, channel_input};

Tensor nhwc_kernel = context.CreateGPUTensor(kernel->DataType(), nhwc_kernel_shape);
[Contributor review comment — suggested change]
-  Tensor nhwc_kernel = context.CreateGPUTensor(kernel->DataType(), nhwc_kernel_shape);
+  Tensor ohwi_kernel = context.CreateGPUTensor(kernel->DataType(), nhwc_kernel_shape);

OIHW2OHWIProgram transpose_program{};
transpose_program.SetWorkgroupSize(64);

const uint32_t Ci_tiles = ceil_div(channel_input, 64u);
transpose_program.SetDispatchGroupSize(channel_output, Ci_tiles);
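// Dispatch layout (inferred from the sizes above): one workgroup per
// (output channel, 64-wide input-channel tile) pair, with 64 threads per workgroup.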

transpose_program.AddInput({kernel,
ProgramTensorMetadataDependency::TypeAndRank});
transpose_program.AddOutput({&nhwc_kernel,
ProgramTensorMetadataDependency::TypeAndRank});
transpose_program.AddUniformVariables({{channel_output},
{channel_input},
{kernel_height},
{kernel_width},
{Ci_tiles},
{ceil_div(kernel_height * kernel_width, 4u)}});
ORT_RETURN_IF_ERROR(context.RunProgram(transpose_program));

// im2col-matmul
const TensorShape input_shape = input->Shape();
const TensorShape output_shape = output->Shape();

const uint32_t batch = onnxruntime::narrow<uint32_t>(input_shape[0]);
const uint32_t input_height = onnxruntime::narrow<uint32_t>(input_shape[is_channels_last ? 1 : 2]);
const uint32_t input_width = onnxruntime::narrow<uint32_t>(input_shape[is_channels_last ? 2 : 3]);
const uint32_t output_height = onnxruntime::narrow<uint32_t>(output_shape[is_channels_last ? 1 : 2]);
const uint32_t output_width = onnxruntime::narrow<uint32_t>(output_shape[is_channels_last ? 2 : 3]);

const uint32_t im2col_m = output_height * output_width;
const uint32_t im2col_k = kernel_height * kernel_width * channel_input;
const uint32_t im2col_n = channel_output;
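// Viewed through im2col, the convolution becomes a single GEMM of shape
//   [im2col_m x im2col_k] * [im2col_k x im2col_n] -> [im2col_m x im2col_n],
// where each of the im2col_m output pixels gathers a kernel_h * kernel_w * channel_input
// input patch as one row of the left matrix. (Illustrative sizes: a 3x3 conv with
// Ci = 32, Co = 64 and a 112x112 output gives M = 12544, K = 288, N = 64.)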

const auto [tile_m, tile_n] = ChooseTileSize(im2col_m, im2col_n);
const uint32_t workgroup_size = tile_n;
const bool use_subgroup = true;
Im2ColMatMulProgram im2col_mm_program{has_bias, tile_m, tile_n, use_subgroup};
im2col_mm_program.SetWorkgroupSize(workgroup_size);

const uint32_t M_tiles = ceil_div(im2col_m, tile_m);
const uint32_t N_tiles = ceil_div(im2col_n, tile_n);
im2col_mm_program.SetDispatchGroupSize(M_tiles, N_tiles, batch);
[Contributor review comment]
How about enhancing the current TransposeProgram with a shared path instead of adding a new one? You are transposing from perm [0, 1, 2, 3] to perm [0, 2, 3, 1], which is equivalent to transposing [o, i, hw] to [o, hw, i]. DoTranspose could be extended so that its shared path supports any shape where only the last two dimensions are transposed and the preceding dimensions are unchanged. Currently the shared path only supports a 2-D transpose from perm [0, 1] to perm [1, 0]; it can be extended to transpose [0, 1, 2] to [0, 2, 1] by reshaping the tensor into a 3-D tensor [d0 * d1 * ... * dn-3, dn-2, dn-1] whenever only the last two dimensions are transposed. A sketch of this reshape trick follows.
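A minimal standalone sketch of the reshape trick described above (illustrative only: TransposeLast2D is a hypothetical helper, and DoTranspose's real interface is not shown in this diff). Any transpose that swaps only the last two dimensions can run as a batched 2-D transpose over a 3-D view [d0 * d1 * ... * dn-3, dn-2, dn-1]:

#include <cstddef>
#include <vector>

// Batched 2-D transpose: src is viewed as [batch, rows, cols] and each
// [rows, cols] slice is transposed into [cols, rows].
template <typename T>
std::vector<T> TransposeLast2D(const std::vector<T>& src,
                               std::size_t batch, std::size_t rows, std::size_t cols) {
  std::vector<T> dst(src.size());
  for (std::size_t b = 0; b < batch; ++b) {
    const T* s = src.data() + b * rows * cols;
    T* d = dst.data() + b * rows * cols;
    for (std::size_t r = 0; r < rows; ++r) {
      for (std::size_t c = 0; c < cols; ++c) {
        d[c * rows + r] = s[r * cols + c];  // dst[b, c, r] = src[b, r, c]
      }
    }
  }
  return dst;
}

// The OIHW -> O(HW)I weight transpose above is this pattern with
// batch = O, rows = I, cols = H * W.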


im2col_mm_program.AddInput({input,
ProgramTensorMetadataDependency::TypeAndRank,
4});
im2col_mm_program.AddInput({&nhwc_kernel,
ProgramTensorMetadataDependency::TypeAndRank,
4});
if (has_bias) {
im2col_mm_program.AddInput({bias,
ProgramTensorMetadataDependency::TypeAndRank});
}
im2col_mm_program.AddOutput({output,
ProgramTensorMetadataDependency::TypeAndRank});
im2col_mm_program.AddUniformVariables({{batch},
{input_height},
{input_width},
{channel_input},
{kernel_height},
{kernel_width},
{output_height},
{output_width},
{im2col_m},
{im2col_k},
{im2col_n},
{M_tiles},
{N_tiles},
{ceil_div(ceil_div(im2col_k, 4u), 4u)},
{dilations},
{pads},
{strides}});
im2col_mm_program.CacheHint(has_bias, tile_m, tile_n, use_subgroup);

return context.RunProgram(im2col_mm_program);
}

bool CanApplyIm2ColMatMulProgram(ComputeContext& context,
const bool is_channels_last,
const ActivationKind activation_kind,
const TensorShape kernel_shape,
const AutoPadType auto_pad,
const uint32_t group) {
if (!IsDeviceSupported(context)) {
return false;
}

// TODO: Support !is_channels_last
// TODO: Support fuse
// TODO: Support auto pad
// TODO: Support group conv
if (!is_channels_last || activation_kind != ActivationKind::None || auto_pad != AutoPadType::NOTSET || group != 1) {
return false;
}

// TODO: Support conv2d_1x1
const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(kernel_shape[2]);
const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(kernel_shape[3]);
if (kernel_height == 1 && kernel_width == 1) {
return false;
}

// TODO: Support channel input vec1
const uint32_t channel_input = onnxruntime::narrow<uint32_t>(kernel_shape[1]);
if (channel_input % 4 != 0) {
return false;
}

return true;
}

} // namespace webgpu
} // namespace onnxruntime
90 changes: 90 additions & 0 deletions onnxruntime/core/providers/webgpu/nn/im2col_matmul.h
@@ -0,0 +1,90 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/framework/tensor_shape.h"
#include "core/framework/tensor.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/nn/conv_attributes.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/nn/fuse_utils.h"

namespace onnxruntime {
namespace webgpu {

// Transpose OIHW Weight to OHWI
class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
public:
OIHW2OHWIProgram() : Program("OIHW2OHWI") {}

Status GenerateShaderCode(ShaderHelper& shader) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"O", ProgramUniformVariableDataType::Uint32},
{"I", ProgramUniformVariableDataType::Uint32},
{"H", ProgramUniformVariableDataType::Uint32},
{"W", ProgramUniformVariableDataType::Uint32},
{"Ci_tiles", ProgramUniformVariableDataType::Uint32},
{"H_W_tiles", ProgramUniformVariableDataType::Uint32});
};

class Im2ColMatMulProgram final : public Program<Im2ColMatMulProgram> {
public:
Im2ColMatMulProgram(bool has_bias,
uint32_t tile_m,
uint32_t tile_n,
bool use_subgroup) : Program("Im2ColMatMul"),
has_bias_(has_bias),
tile_m_(tile_m),
tile_n_(tile_n),
use_subgroup_(use_subgroup) {}

Status GenerateShaderCode(ShaderHelper& shader) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"batch", ProgramUniformVariableDataType::Uint32},
{"src_h", ProgramUniformVariableDataType::Uint32},
{"src_w", ProgramUniformVariableDataType::Uint32},
{"channel_i", ProgramUniformVariableDataType::Uint32},
{"kernel_h", ProgramUniformVariableDataType::Uint32},
{"kernel_w", ProgramUniformVariableDataType::Uint32},
{"output_h", ProgramUniformVariableDataType::Uint32},
{"output_w", ProgramUniformVariableDataType::Uint32},
{"im2col_m", ProgramUniformVariableDataType::Uint32},
{"im2col_k", ProgramUniformVariableDataType::Uint32},
{"im2col_n", ProgramUniformVariableDataType::Uint32},
{"M_tiles", ProgramUniformVariableDataType::Uint32},
{"N_tiles", ProgramUniformVariableDataType::Uint32},
{"K_tiles", ProgramUniformVariableDataType::Uint32},
{"dilations", ProgramUniformVariableDataType::Uint32},
{"pads", ProgramUniformVariableDataType::Uint32},
{"strides", ProgramUniformVariableDataType::Uint32});

private:
bool has_bias_;

uint32_t tile_m_;
uint32_t tile_n_;
bool use_subgroup_;
};

bool CanApplyIm2ColMatMulProgram(ComputeContext& context,
const bool is_channels_last,
const ActivationKind activation_kind,
const TensorShape kernel_shape,
const AutoPadType auto_pad,
const uint32_t group);

Status ApplyIm2ColMatMulProgram(ComputeContext& context,
const bool is_channels_last,
const std::vector<uint32_t>& dilations,
const std::vector<uint32_t>& pads,
const std::vector<uint32_t>& strides,
Tensor* output);

} // namespace webgpu
} // namespace onnxruntime