
Commit 9102aae

[Native WebGPU] Add Conv, ConvTranspose and FusedConv (microsoft#24186)
### Description

Add Conv, ConvTranspose, and FusedConv to the WebGPU execution provider.

### Motivation and Context

Required for operator coverage.
1 parent a4976e3 commit 9102aae

32 files changed: +1747 −183 lines changed
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/webgpu_supported_types.h"
+#include "core/providers/webgpu/nn/conv.h"
+#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
+#include "core/providers/webgpu/nn/fuse_utils.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace webgpu {
+using onnxruntime::webgpu::Conv;
+template <bool is_channels_last>
+class FusedConv final : public Conv<is_channels_last, true> {
+ public:
+  FusedConv(const OpKernelInfo& info) : Conv<is_channels_last, true>(info) {
+    ORT_ENFORCE(GetFusedActivationAttr(info, Conv<is_channels_last, true>::activation_).IsOK());
+  }
+};
+
+ONNX_OPERATOR_KERNEL_EX(
+    FusedConv,
+    kMSDomain,
+    1,
+    kWebGpuExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", onnxruntime::webgpu::WebGpuSupportedFloatTypes()),
+    FusedConv<false>);
+
+}  // namespace webgpu
+}  // namespace contrib
+}  // namespace onnxruntime
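Note (illustrative, not part of the commit): the kernel above reuses the templated WebGPU `Conv` implementation and only adds fused-activation parsing via `GetFusedActivationAttr`. The self-contained sketch below shows the kind of contract such a fused node is assumed to carry (an `activation` name plus optional numeric parameters, mirroring other ORT FusedConv kernels); every name in the sketch is a hypothetical stand-in.

```cpp
// Hedged sketch only: a stand-in for the fused-activation attribute contract.
// The real kernel reads this via GetFusedActivationAttr(info, activation_).
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct FusedActivation {
  std::string kind;           // e.g. "Relu", "Clip", "HardSigmoid"; empty means plain Conv
  std::vector<float> params;  // e.g. Clip min/max or HardSigmoid alpha/beta (optional)
};

// Stand-in for pulling node attributes out of an OpKernelInfo-like object.
FusedActivation ReadFusedActivation(const std::map<std::string, std::string>& attrs) {
  FusedActivation act;
  if (auto it = attrs.find("activation"); it != attrs.end()) {
    act.kind = it->second;
  }
  return act;
}

int main() {
  // A node produced by Conv+Relu fusion would carry activation="Relu".
  FusedActivation act = ReadFusedActivation({{"activation", "Relu"}});
  std::cout << "fused activation: " << (act.kind.empty() ? "<none>" : act.kind) << "\n";
}
```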

onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
-      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,

onnxruntime/core/optimizer/conv_activation_fusion.cc

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ class ConvActivationSelector : public NodeSelector {
       if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) {
         return std::nullopt;
       }
-    } else if (node_ep.empty() || node_ep == kCpuExecutionProvider || node_ep == kJsExecutionProvider) {
+    } else if (node_ep.empty() || node_ep == kCpuExecutionProvider || node_ep == kJsExecutionProvider || node_ep == kWebGpuExecutionProvider) {
       if (!is_supported_non_cuda_rocm_ep_activation(*next_node) &&
           !graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "HardSigmoid", {6})) {
         return std::nullopt;
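Note (illustrative, not part of the commit): with `kWebGpuExecutionProvider` added to this branch, Conv nodes assigned to the WebGPU EP can be selected for activation fusion like CPU and JS EP nodes, and the fusion action folds the activation into a `com.microsoft` FusedConv node such as the one registered above. The sketch below is a toy illustration of that rewrite idea, not the optimizer's actual code; the node representation is hypothetical.

```cpp
// Toy sketch of the Conv + Activation -> FusedConv rewrite idea (hypothetical types).
#include <iostream>
#include <string>
#include <vector>

struct ToyNode {
  std::string op_type;
  std::string domain;      // "" = ONNX domain, "com.microsoft" = contrib domain
  std::string activation;  // only meaningful for the fused node
};

// Replace a Conv immediately followed by a supported activation with one fused node.
std::vector<ToyNode> FuseConvActivation(const std::vector<ToyNode>& nodes) {
  std::vector<ToyNode> out;
  for (size_t i = 0; i < nodes.size(); ++i) {
    if (nodes[i].op_type == "Conv" && i + 1 < nodes.size() && nodes[i + 1].op_type == "Relu") {
      out.push_back({"FusedConv", "com.microsoft", "Relu"});
      ++i;  // consume the activation node
    } else {
      out.push_back(nodes[i]);
    }
  }
  return out;
}

int main() {
  auto fused = FuseConvActivation({{"Conv", "", ""}, {"Relu", "", ""}});
  std::cout << fused[0].domain << "." << fused[0].op_type
            << " activation=" << fused[0].activation << "\n";
}
```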

onnxruntime/core/optimizer/graph_transformer_utils.cc

Lines changed: 14 additions & 12 deletions
@@ -296,17 +296,19 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
                                                              onnxruntime::kCudaExecutionProvider,
                                                              onnxruntime::kRocmExecutionProvider,
                                                              onnxruntime::kDmlExecutionProvider};
-  const InlinedHashSet<std::string_view> cpu_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider,
-                                                                      onnxruntime::kRocmExecutionProvider,
-                                                                      onnxruntime::kAclExecutionProvider,
-                                                                      onnxruntime::kArmNNExecutionProvider,
-                                                                      onnxruntime::kJsExecutionProvider};
-  const InlinedHashSet<std::string_view> cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider,
-                                                                           onnxruntime::kCudaExecutionProvider,
-                                                                           onnxruntime::kRocmExecutionProvider,
-                                                                           onnxruntime::kAclExecutionProvider,
-                                                                           onnxruntime::kArmNNExecutionProvider,
-                                                                           onnxruntime::kJsExecutionProvider};
+  const InlinedHashSet<std::string_view> cpu_rocm_acl_armnn_js_webgpu_eps = {onnxruntime::kCpuExecutionProvider,
+                                                                             onnxruntime::kRocmExecutionProvider,
+                                                                             onnxruntime::kAclExecutionProvider,
+                                                                             onnxruntime::kArmNNExecutionProvider,
+                                                                             onnxruntime::kJsExecutionProvider,
+                                                                             onnxruntime::kWebGpuExecutionProvider};
+  const InlinedHashSet<std::string_view> cpu_cuda_rocm_acl_armnn_js_webgpu_eps = {onnxruntime::kCpuExecutionProvider,
+                                                                                  onnxruntime::kCudaExecutionProvider,
+                                                                                  onnxruntime::kRocmExecutionProvider,
+                                                                                  onnxruntime::kAclExecutionProvider,
+                                                                                  onnxruntime::kArmNNExecutionProvider,
+                                                                                  onnxruntime::kJsExecutionProvider,
+                                                                                  onnxruntime::kWebGpuExecutionProvider};
   const InlinedHashSet<std::string_view> cpu_dml_acl_eps = {onnxruntime::kCpuExecutionProvider,
                                                             onnxruntime::kDmlExecutionProvider,
                                                             onnxruntime::kAclExecutionProvider};
@@ -338,7 +340,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
   transformers.emplace_back(std::make_unique<MatMulIntegerToFloatFusion>(cpu_dml_acl_eps));
   transformers.emplace_back(std::make_unique<DynamicQuantizeMatMulFusion>(cpu_acl_eps));

-  transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_rocm_acl_armnn_js_eps));
+  transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_rocm_acl_armnn_js_webgpu_eps));

   transformers.emplace_back(std::make_unique<GeluFusion>(cpu_acl_cuda_dml_rocm_eps, level));
   transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_acl_cuda_dml_rocm_eps, level));
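Note (illustrative, not part of the commit): each transformer is constructed with a set of compatible EP names and only rewrites nodes assigned to one of those EPs, so adding `kWebGpuExecutionProvider` to the set is what lets ConvActivationFusion touch WebGPU-assigned Conv nodes. A minimal sketch of that gating pattern, with `std::unordered_set` standing in for ORT's `InlinedHashSet` and illustrative EP name strings:

```cpp
// Minimal sketch of the EP-gating pattern used by graph transformers.
#include <iostream>
#include <string_view>
#include <unordered_set>

constexpr std::string_view kCpuExecutionProvider = "CPUExecutionProvider";      // illustrative values
constexpr std::string_view kJsExecutionProvider = "JsExecutionProvider";
constexpr std::string_view kWebGpuExecutionProvider = "WebGpuExecutionProvider";

bool NodeIsEligible(std::string_view node_ep,
                    const std::unordered_set<std::string_view>& compatible_eps) {
  // A transformer constructed with a set of compatible EPs only rewrites nodes
  // whose assigned EP is a member of that set.
  return compatible_eps.count(node_ep) > 0;
}

int main() {
  const std::unordered_set<std::string_view> cpu_js_webgpu_eps = {
      kCpuExecutionProvider, kJsExecutionProvider, kWebGpuExecutionProvider};
  std::cout << std::boolalpha
            << NodeIsEligible(kWebGpuExecutionProvider, cpu_js_webgpu_eps) << "\n";  // true
}
```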

onnxruntime/core/providers/webgpu/math/matmul.cc

Lines changed: 49 additions & 23 deletions
@@ -6,8 +6,9 @@
 #include "core/providers/cpu/tensor/utils.h"
 #include "core/providers/webgpu/shader_helper.h"
 #include "core/providers/webgpu/webgpu_supported_types.h"
-
+#include "core/providers/webgpu/nn/fuse_utils.h"
 #include "core/providers/webgpu/data_transfer.h"
+
 namespace onnxruntime {
 namespace webgpu {

@@ -54,11 +55,12 @@ Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
   std::string process_bias;
   if (has_bias_) {
     shader.AddInput("bias", ShaderUsage::UseUniform);
-    process_bias = "value += output_value_t(bias[row + i]);";
+    process_bias = is_channels_last_ ? "value += output_value_t(bias[col])" : "value += output_value_t(bias[row + i]);";
   }

+  std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t");
   const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform |
-                                                      ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+                                                      ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   const auto& batch_dims = shader.AddIndices("batch_dims");

   int a_components = a.NumComponents();
@@ -90,6 +92,7 @@ Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
           << "for (var i = 0u; i < " << output_number_ << "u; i++) {\n"
           << " var value = values[i];\n"
           << process_bias << "\n"
+          << apply_activation << "\n"
           << " let cur_indices = output_indices_t(batch, row + i, col/ " << components << ");\n"
           << " let offset = " << output.IndicesToOffset("cur_indices") << ";\n"
           << output.SetByOffset("offset", "value")
@@ -127,7 +130,7 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
     const int64_t a_rows = a->Shape().NumDimensions() > 1 ? a->Shape()[a->Shape().NumDimensions() - 2] : 1;
     TensorShape output_shape_shader({batch_size, a_rows, helper.N() / components});

-    MatMulNaiveProgram program{output_rank, output_number, has_bias};
+    MatMulNaiveProgram program{Activation(), output_rank, output_number, has_bias};

     program
         .CacheHint(std::to_string(components), std::to_string(a_components), std::to_string(output_number))
@@ -147,11 +150,32 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
     return context.RunProgram(program);
   }

-  int64_t batchA = a->Shape().SizeToDimension(a->Shape().NumDimensions() - 2);
-  int64_t batchB = b->Shape().SizeToDimension(b->Shape().NumDimensions() - 2);
+  std::vector<const Tensor*> inputs(has_bias ? 3 : 2);
+  inputs[0] = a;
+  inputs[1] = b;
+  if (has_bias) {
+    const auto* bias = context.Input(2);
+    inputs.push_back(bias);
+  }
+  auto program = CreateMatMulProgram(Activation(), inputs, output_tensor, false);
+
+  return context.RunProgram(program);
+}
+
+MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output_tensor, bool is_channels_last,
+                                  const TensorShape& input_a_reshape,
+                                  const TensorShape& input_b_reshape) {
+  const auto* a = inputs[0];
+  const auto* b = inputs[1];
+  bool has_bias = inputs.size() > 2;
+  TensorShape a_shape = input_a_reshape.NumDimensions() > 0 ? input_a_reshape : a->Shape();
+  TensorShape b_shape = input_b_reshape.NumDimensions() > 0 ? input_b_reshape : b->Shape();
+
+  MatMulComputeHelper helper;
+  ORT_THROW_IF_ERROR(helper.Compute(a_shape, b_shape));
+  int64_t batchA = a_shape.SizeToDimension(a_shape.NumDimensions() - 2);
+  int64_t batchB = b_shape.SizeToDimension(b_shape.NumDimensions() - 2);

-  TensorShape a_shape = a->Shape();
-  TensorShape b_shape = b->Shape();
   TensorShape output_shape = helper.OutputShape();

   const int64_t dim_output_outer = output_shape[output_shape.NumDimensions() - 2];
@@ -184,44 +208,46 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
   const int64_t batch_size = outer_dims.Size();

   // Get dimensions for matrix multiplication from TensorShape
-  const int32_t dim_a_outer = narrow<int32_t>(a_shape[a_shape.NumDimensions() - 2]);  // left matrix second dimension
-  const int32_t dim_inner = narrow<int32_t>(a_shape[a_shape.NumDimensions() - 1]);    // left matrix first dimension
-  const int32_t dim_b_outer = narrow<int32_t>(b_shape[b_shape.NumDimensions() - 1]);  // right matrix first dimension
+  const uint32_t dim_a_outer = narrow<uint32_t>(a_shape[a_shape.NumDimensions() - 2]);  // left matrix second dimension
+  const uint32_t dim_inner = narrow<uint32_t>(a_shape[a_shape.NumDimensions() - 1]);    // left matrix first dimension
+  const uint32_t dim_b_outer = narrow<uint32_t>(b_shape[b_shape.NumDimensions() - 1]);  // right matrix first dimension

   const bool is_vec4 = dim_inner % 4 == 0 && dim_b_outer % 4 == 0;

   InlinedVector<int64_t> elements_per_thread = dim_a_outer <= 8
                                                    ? InlinedVector<int64_t>({4, 1, 1})
                                                    : InlinedVector<int64_t>({4, 4, 1});

-  const uint32_t dispatch_x = narrow<uint32_t>((dim_b_outer + MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) /
-                                               (MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
-  const uint32_t dispatch_y = narrow<uint32_t>((dim_a_outer + MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
-                                               (MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
-  const uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
-                                               (MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
+  const uint32_t dispatch_x = narrow<uint32_t>((dim_b_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) /
+                                               (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
+  const uint32_t dispatch_y = narrow<uint32_t>((dim_a_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
+                                               (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
+  const uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
+                                               (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));

   const int components = is_vec4 ? 4 : 1;
   const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components);
   const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components);
   const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components});

-  MatMulProgram program{has_bias, is_vec4, elements_per_thread};
+  MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread, is_channels_last};
   program
-      .CacheHint(absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4))
+      .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4))
       .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
                   {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
       .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}})
       .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}})
       .AddIndices(outer_dims)
       .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
-      .SetWorkgroupSize(MATMUL_PACKED_WORKGROUP_SIZE_X, MATMUL_PACKED_WORKGROUP_SIZE_Y, MATMUL_PACKED_WORKGROUP_SIZE_Z);
+      .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z);

   if (has_bias) {
-    const auto* bias = context.Input(2);
-    program.AddInput({bias, ProgramTensorMetadataDependency::Rank, 1});
+    auto bias_components = is_channels_last ? components : 1;
+    const auto* bias = inputs[2];
+    TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components);
+    program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components});
   }
-  return context.RunProgram(program);
+  return program;
 }

 }  // namespace webgpu
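Note (illustrative, not part of the commit): in the hunks above, `GetActivationSnippet` produces a WGSL statement that is spliced into the generated shader right after the bias add, so the fused activation runs in the same pass as the matmul. Below is a simplified, self-contained stand-in for that snippet generation; the real helper covers more activation kinds, activation parameters, and the vectorized value types, so treat the exact strings as assumptions.

```cpp
// Simplified stand-in for GetActivationSnippet: map an activation kind to a WGSL
// statement applied to `value` after the bias add. Assumption-laden sketch only.
#include <iostream>
#include <string>

std::string ActivationSnippetSketch(const std::string& activation,
                                    const std::string& value_type) {
  if (activation == "Relu") {
    return "value = max(value, " + value_type + "(0.0));";
  }
  if (activation == "Sigmoid") {
    return "value = " + value_type + "(1.0) / (" + value_type + "(1.0) + exp(-value));";
  }
  return "";  // no fused activation: emit nothing, value passes through unchanged
}

int main() {
  // For a fused Conv+Relu, the matmul loop body would gain a line like this one.
  std::cout << ActivationSnippetSketch("Relu", "output_value_t") << "\n";
}
```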

onnxruntime/core/providers/webgpu/math/matmul.h

Lines changed: 9 additions & 3 deletions
@@ -9,25 +9,29 @@
 #include "core/providers/webgpu/math/matmul_utils.h"
 #include "core/providers/webgpu/math/matmul_packed.h"
 #include "core/providers/webgpu/webgpu_utils.h"
+#include "core/providers/webgpu/nn/fuse_utils.h"

 namespace onnxruntime {
 namespace webgpu {

+MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output, bool is_channels_last,
+                                  const TensorShape& input_a_reshape = TensorShape(),
+                                  const TensorShape& input_b_reshape = TensorShape());
+
 class MatMul final : public WebGpuKernel {
  public:
   MatMul(const OpKernelInfo& info) : WebGpuKernel{info} {}

   Status ComputeInternal(ComputeContext& context) const override;
-
   constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_X = 8;
   constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Y = 8;
   constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Z = 1;
 };

 class MatMulNaiveProgram final : public Program<MatMulNaiveProgram> {
  public:
-  MatMulNaiveProgram(const size_t output_rank, int64_t output_number, bool has_bias)
-      : Program{"MatMulNaive"}, output_rank_(output_rank), output_number_(output_number), has_bias_{has_bias} {
+  MatMulNaiveProgram(const Activation& activation, const size_t output_rank, int64_t output_number, bool has_bias, bool is_channels_last = false)
+      : Program{"MatMulNaive"}, activation_(activation), output_rank_(output_rank), output_number_(output_number), has_bias_{has_bias}, is_channels_last_(is_channels_last) {
   }

   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -38,9 +42,11 @@ class MatMulNaiveProgram final : public Program<MatMulNaiveProgram> {
                                           {"K", ProgramUniformVariableDataType::Uint32});

  private:
+  const Activation& activation_;
   const size_t output_rank_;
   const int64_t output_number_;
   const bool has_bias_;
+  const bool is_channels_last_;
 };

 }  // namespace webgpu
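Note (illustrative, not part of the commit): the workgroup-size constants above, now referenced as `MatMul::MATMUL_PACKED_WORKGROUP_SIZE_*` from the free `CreateMatMulProgram`, feed the dispatch-size arithmetic in matmul.cc: each workgroup covers `workgroup_size * elements_per_thread` output elements per axis, and the dispatch count is the ceiling division of the output dimensions by that tile. A standalone sketch of that arithmetic, with example shapes chosen purely for illustration:

```cpp
// Standalone sketch of the dispatch-group arithmetic used by CreateMatMulProgram:
// ceil(dim / (workgroup_size * elements_per_thread)) per axis. The constants mirror
// matmul.h (8, 8, 1) and the {4, 4, 1} per-thread tiling from matmul.cc.
#include <cstdint>
#include <iostream>

constexpr uint32_t kWorkgroupSizeX = 8, kWorkgroupSizeY = 8, kWorkgroupSizeZ = 1;

uint32_t CeilDiv(uint32_t numerator, uint32_t denominator) {
  return (numerator + denominator - 1) / denominator;
}

int main() {
  const uint32_t dim_a_outer = 512, dim_b_outer = 768, batch_size = 2;  // example shapes
  const uint32_t elements_per_thread[3] = {4, 4, 1};

  const uint32_t dispatch_x = CeilDiv(dim_b_outer, kWorkgroupSizeX * elements_per_thread[0]);
  const uint32_t dispatch_y = CeilDiv(dim_a_outer, kWorkgroupSizeY * elements_per_thread[1]);
  const uint32_t dispatch_z = CeilDiv(batch_size, kWorkgroupSizeZ * elements_per_thread[2]);

  std::cout << dispatch_x << " x " << dispatch_y << " x " << dispatch_z << "\n";  // 24 x 16 x 2
}
```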
