Skip to content

Commit cfb0a72

Browse files
[WebGPU EP] introduce BiasAdd contrib op (#23861)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 834adde commit cfb0a72

File tree

3 files changed

+114
-2
lines changed

3 files changed

+114
-2
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/shader_helper.h"
5+
#include "core/providers/webgpu/webgpu_supported_types.h"
6+
#include "contrib_ops/webgpu/bert/bias_add.h"
7+
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
8+
9+
namespace onnxruntime {
10+
namespace contrib {
11+
namespace webgpu {
12+
13+
ONNX_OPERATOR_KERNEL_EX(
14+
BiasAdd,
15+
kMSDomain,
16+
1,
17+
kWebGpuExecutionProvider,
18+
(*KernelDefBuilder::Create())
19+
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
20+
BiasAdd);
21+
22+
Status BiasAddProgram::GenerateShaderCode(ShaderHelper& shader) const {
23+
const ShaderVariableHelper& input = shader.AddInput("input");
24+
const ShaderVariableHelper& bias = shader.AddInput("bias");
25+
const ShaderVariableHelper& residual = shader.AddInput("residual");
26+
const ShaderVariableHelper& output = shader.AddOutput("output");
27+
28+
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
29+
<< "let value = " << input.GetByOffset("global_idx")
30+
<< " + " << bias.GetByOffset("global_idx % uniforms.channels")
31+
<< " + " << residual.GetByOffset("global_idx") << ";\n"
32+
<< output.SetByOffset("global_idx", "value");
33+
34+
return Status::OK();
35+
}
36+
37+
static int64_t GetMaxComponents(int64_t size) {
38+
if (size % 4 == 0) {
39+
return 4;
40+
} else if (size % 2 == 0) {
41+
return 2;
42+
}
43+
return 1;
44+
}
45+
46+
Status BiasAdd::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
47+
const auto* input = context.Input(0);
48+
const auto* bias = context.Input(1);
49+
const auto* residual = context.Input(2);
50+
51+
TensorShape input_shape = input->Shape();
52+
53+
if (input_shape.NumDimensions() != 3) {
54+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasAdd input should have 3 dimensions.");
55+
}
56+
57+
int64_t channels = input_shape[2];
58+
int64_t components = GetMaxComponents(channels);
59+
channels /= components;
60+
61+
TensorShape bias_shape = bias->Shape();
62+
if (bias_shape.NumDimensions() != 1 || bias_shape[0] != channels) {
63+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasAdd bias should have 1 dimension with size equal to the number of channels.");
64+
}
65+
66+
auto* output = context.Output(0, input_shape);
67+
int64_t output_size = output->Shape().Size() / components;
68+
69+
BiasAddProgram program{};
70+
program.AddInputs({{input}, {bias}, {residual}})
71+
.AddOutput({output})
72+
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
73+
.AddUniformVariables({{static_cast<uint32_t>(output_size)},
74+
{static_cast<uint32_t>(channels)}});
75+
return context.RunProgram(program);
76+
}
77+
78+
} // namespace webgpu
79+
} // namespace contrib
80+
} // namespace onnxruntime
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/webgpu/program.h"
7+
#include "core/providers/webgpu/webgpu_kernel.h"
8+
9+
namespace onnxruntime {
10+
namespace contrib {
11+
namespace webgpu {
12+
13+
using namespace onnxruntime::webgpu;
14+
using onnxruntime::webgpu::ComputeContext;
15+
16+
class BiasAddProgram final : public Program<BiasAddProgram> {
17+
public:
18+
BiasAddProgram() : Program{"BiasAdd"} {}
19+
Status GenerateShaderCode(ShaderHelper& sh) const override;
20+
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
21+
{"channels", ProgramUniformVariableDataType::Uint32});
22+
};
23+
24+
class BiasAdd final : public WebGpuKernel {
25+
public:
26+
BiasAdd(const OpKernelInfo& info) : WebGpuKernel(info) {}
27+
Status ComputeInternal(ComputeContext& context) const override;
28+
};
29+
30+
} // namespace webgpu
31+
} // namespace contrib
32+
} // namespace onnxruntime

onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
3737
static const BuildKernelCreateInfoFn function_table[] = {
3838
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
3939
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention)>,
40-
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd)>,
41-
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
40+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd)>,
41+
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
4242
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
4343
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
4444
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,

0 commit comments

Comments
 (0)