
Commit 0acb048

[webgpu] optimize SkipLayerNormalization operator (microsoft#24164)
If batch_size and sequence_length are both 1, split the hidden_size dimension across workgroups to improve parallelism.
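In the default path each workgroup normalizes one row (batch_size × sequence_length rows in total), so a 1×1 input keeps only a single workgroup busy. The split path instead dispatches one workgroup per 512-element slice of hidden_size; every workgroup redundantly reduces the whole row, then normalizes only its own slice. A minimal CPU sketch of that scheme (illustrative only; the function name and the slice parameter are ours, and the skip/bias inputs are omitted for brevity):

#include <cmath>
#include <cstddef>
#include <vector>

void split_hidden_layer_norm(const std::vector<float>& row,
                             const std::vector<float>& gamma,
                             std::vector<float>& out,
                             std::size_t slice, float epsilon) {
  const std::size_t n = row.size();  // hidden_size; assume n % slice == 0
  for (std::size_t g = 0; g < n / slice; ++g) {  // each g models one workgroup
    // Every "workgroup" redundantly reduces the full row...
    double sum = 0.0, sum_sq = 0.0;
    for (float v : row) {
      sum += v;
      sum_sq += static_cast<double>(v) * v;
    }
    const double mean = sum / n;
    const double inv_std = 1.0 / std::sqrt(sum_sq / n - mean * mean + epsilon);
    // ...but normalizes only its own slice, so on a GPU the slices run in parallel.
    for (std::size_t i = g * slice; i < (g + 1) * slice; ++i) {
      out[i] = static_cast<float>((row[i] - mean) * inv_std) * gamma[i];
    }
  }
}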
1 parent 34abb8b commit 0acb048

File tree

2 files changed: +121 -57 lines


onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc

Lines changed: 118 additions & 54 deletions
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/string_macros.h"
 #include "core/providers/webgpu/webgpu_utils.h"
 #include "core/providers/webgpu/webgpu_supported_types.h"
 #include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
@@ -12,7 +13,7 @@ namespace contrib {
 namespace webgpu {
 
 Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+  const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   shader.AddInput("skip", ShaderUsage::UseUniform);
   shader.AddInput("gamma", ShaderUsage::UseUniform);
   if (hasBeta_) {
@@ -26,57 +27,112 @@ Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
     shader.AddOutput("input_skip_bias_sum", ShaderUsage::UseUniform);
   }
 
-  int components = x.NumComponents();
-
-  std::string bias = (hasBias_) ? " + bias[offset1d + i] " : "";
   std::string simpl1 = (simplified_) ? "" : "- mean * mean ";
-  std::string simpl2 = (simplified_) ? "" : "- element_t(mean) ";
-  std::string beta = (hasBeta_) ? " + beta[offset1d + i] " : "";
-  std::string input_skip_bias_sum = (has_input_skip_bias_sum_) ? "input_skip_bias_sum[offset + i] = value;\n" : "";
-
-  shader.AdditionalImplementation()
-      << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n")
-      << "alias f32_val_t = " << (components == 4 ? "vec4<f32>" : (components == 2 ? "vec2<f32>" : "f32")) << ";\n"
-      << "var<workgroup> sum_shared : array<f32_val_t, workgroup_size_x>;\n"
-      << "var<workgroup> sum_squared_shared : array<f32_val_t, workgroup_size_x>;\n";
-
-  shader.MainFunctionBody()
-      << "let ix = local_idx;\n"
-      << "let iy = global_idx / workgroup_size_x;\n"
-      << "let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;\n"
-      << "var stride = hidden_size_vectorized / workgroup_size_x;\n"
-      << "let offset = ix * stride + iy * hidden_size_vectorized;\n"
-      << "let offset1d = stride * ix;\n"
-      << "if (ix == workgroup_size_x - 1) {\n"
-      << " stride = hidden_size_vectorized - stride * ix;\n"
-      << "}\n"
-      << "for (var i: u32 = 0; i < stride; i++) {\n"
-      << " let skip_value = skip[offset + i];\n"
-      << " let input_value = x[offset + i];\n"
-      << " let value = input_value + skip_value" << bias << ";\n"
-      << " output[offset + i] = value;\n"
-      << input_skip_bias_sum
-      << " let f32_value = f32_val_t(value);\n"
-      << " sum_shared[ix] += f32_value;\n"
-      << " sum_squared_shared[ix] += f32_value * f32_value;\n"
-      << "}\n"
-      << "workgroupBarrier();\n"
-      << "var reduce_size : u32 = workgroup_size_x;\n"
-      << "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n"
-      << " reduce_size = curr_size + (reduce_size & 1);\n"
-      << " if (ix < curr_size) {\n"
-      << "  sum_shared[ix] += sum_shared[ix + reduce_size];\n"
-      << "  sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];\n"
-      << " }\n"
-      << " workgroupBarrier();\n"
-      << "}\n"
-      << "let sum = sum_shared[0];\n"
-      << "let square_sum = sum_squared_shared[0];\n"
-      << "let mean = " << SumVector("sum", components) << " / f32(uniforms.hidden_size);\n"
-      << "let inv_std_dev = inverseSqrt(" << SumVector("square_sum", components) << " / f32(uniforms.hidden_size) " << simpl1 << "+ uniforms.epsilon);\n"
-      << "for (var i: u32 = 0; i < stride; i++) {\n"
-      << " output[offset + i] = (output[offset + i] " << simpl2 << ") * element_t(inv_std_dev) * gamma[offset1d + i]" << beta << ";\n"
-      << "};\n";
+  std::string simpl2 = (simplified_) ? "" : "- x_element_t(mean) ";
+  if (split_hidden_dim_) {
+    shader.AdditionalImplementation()
+        << "var<workgroup> sum_shared : array<f32, workgroup_size_x>;\n"
+        << "var<workgroup> sum_squared_shared : array<f32, workgroup_size_x>;\n";
+
+    SS(input_skip_bias_sum_ss, 512);
+    if (has_input_skip_bias_sum_) {
+      input_skip_bias_sum_ss
+          << "  let workgroup_half_idx = uniforms.hidden_size / (workgroup_size_x * 4);\n"
+          << "  if (workgroup_idx >= workgroup_half_idx) {\n"
+          << "    offset = (workgroup_idx - workgroup_half_idx) * workgroup_size_x + local_idx;\n"
+          << "    let skip_value = skip[offset];\n"
+          << "    let input_value = x[offset];\n"
+          << "    let value = input_value + skip_value" << (hasBias_ ? " + bias[offset]" : "") << ";\n"
+          << "    input_skip_bias_sum[offset] = value;\n"
+          << "    return;\n"
+          << "  }\n";
+    }
+
+    shader.MainFunctionBody()
+        << "  var offset: u32 = 0;\n"
+        << (has_input_skip_bias_sum_ ? SS_GET(input_skip_bias_sum_ss) : "")
+        << "  var sum_vec4 = vec4<f32>(0);\n"
+        << "  var sum_squared_vec4 = vec4<f32>(0);\n"
+        << "  var cur_input_skip_bias_sum = x_value_t(0);\n"
+        << "  for (var i: u32 = 0; i < uniforms.hidden_size / (workgroup_size_x * 4); i++) {\n"
+        << "    let input_offset = i * workgroup_size_x + local_idx;\n"
+        << "    let skip_value = skip[input_offset];\n"
+        << "    let input_value = x[input_offset];\n"
+        << "    let value = input_value + skip_value" << (hasBias_ ? " + bias[input_offset]" : "") << ";\n"
+        << "    if (i == workgroup_idx) {\n"
+        << "      cur_input_skip_bias_sum = value;\n"
+        << "    }\n"
+        << "    let f32_value = vec4<f32>(value);\n"
+        << "    sum_vec4 += f32_value;\n"
+        << "    sum_squared_vec4 += f32_value * f32_value;\n"
+        << "  }\n"
+        << "  var sum = " << SumVector("sum_vec4", 4) << ";\n"
+        << "  var sum_squared = " << SumVector("sum_squared_vec4", 4) << ";\n"
+        << "  sum_shared[local_idx] = sum;\n"
+        << "  sum_squared_shared[local_idx] = sum_squared;\n"
+        << "  workgroupBarrier();\n"
+        << "  var reduce_size : u32 = workgroup_size_x;\n"
+        << "  for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n"
+        << "    reduce_size = curr_size + (reduce_size & 1);\n"
+        << "    if (local_idx < curr_size) {\n"
+        << "      sum_shared[local_idx] += sum_shared[local_idx + reduce_size];\n"
+        << "      sum_squared_shared[local_idx] += sum_squared_shared[local_idx + reduce_size];\n"
+        << "    }\n"
+        << "    workgroupBarrier();\n"
+        << "  }\n"
+        << "  let mean = sum_shared[0] / f32(uniforms.hidden_size);\n"
+        << "  let inv_std_dev = inverseSqrt(sum_squared_shared[0] / f32(uniforms.hidden_size) " << simpl1 << "+ uniforms.epsilon);\n"
+        << "  offset = workgroup_idx * workgroup_size_x + local_idx;\n"
+        << "  output[offset] = ((cur_input_skip_bias_sum " << simpl2 << ") * x_element_t(inv_std_dev) * gamma[offset]" << (hasBeta_ ? " + beta[offset] " : "") << ");\n";
+  } else {
+    int components = x.NumComponents();
+    std::string bias = (hasBias_) ? " + bias[offset1d + i] " : "";
+    std::string beta = (hasBeta_) ? " + beta[offset1d + i] " : "";
+    std::string input_skip_bias_sum = (has_input_skip_bias_sum_) ? "input_skip_bias_sum[offset + i] = value;\n" : "";
+
+    shader.AdditionalImplementation()
+        << "alias f32_val_t = " << (components == 4 ? "vec4<f32>" : (components == 2 ? "vec2<f32>" : "f32")) << ";\n"
+        << "var<workgroup> sum_shared : array<f32_val_t, workgroup_size_x>;\n"
+        << "var<workgroup> sum_squared_shared : array<f32_val_t, workgroup_size_x>;\n";
+
+    shader.MainFunctionBody()
+        << "let ix = local_idx;\n"
+        << "let iy = global_idx / workgroup_size_x;\n"
+        << "let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;\n"
+        << "var stride = hidden_size_vectorized / workgroup_size_x;\n"
+        << "let offset = ix * stride + iy * hidden_size_vectorized;\n"
+        << "let offset1d = stride * ix;\n"
+        << "if (ix == workgroup_size_x - 1) {\n"
+        << " stride = hidden_size_vectorized - stride * ix;\n"
+        << "}\n"
+        << "for (var i: u32 = 0; i < stride; i++) {\n"
+        << " let skip_value = skip[offset + i];\n"
+        << " let input_value = x[offset + i];\n"
+        << " let value = input_value + skip_value" << bias << ";\n"
+        << " output[offset + i] = value;\n"
+        << input_skip_bias_sum
+        << " let f32_value = f32_val_t(value);\n"
+        << " sum_shared[ix] += f32_value;\n"
+        << " sum_squared_shared[ix] += f32_value * f32_value;\n"
+        << "}\n"
+        << "workgroupBarrier();\n"
+        << "var reduce_size : u32 = workgroup_size_x;\n"
+        << "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n"
+        << " reduce_size = curr_size + (reduce_size & 1);\n"
+        << " if (ix < curr_size) {\n"
+        << "  sum_shared[ix] += sum_shared[ix + reduce_size];\n"
+        << "  sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];\n"
+        << " }\n"
+        << " workgroupBarrier();\n"
+        << "}\n"
+        << "let sum = sum_shared[0];\n"
+        << "let square_sum = sum_squared_shared[0];\n"
+        << "let mean = " << SumVector("sum", components) << " / f32(uniforms.hidden_size);\n"
+        << "let inv_std_dev = inverseSqrt(" << SumVector("square_sum", components) << " / f32(uniforms.hidden_size) " << simpl1 << "+ uniforms.epsilon);\n"
+        << "for (var i: u32 = 0; i < stride; i++) {\n"
+        << " output[offset + i] = (output[offset + i] " << simpl2 << ") * x_element_t(inv_std_dev) * gamma[offset1d + i]" << beta << ";\n"
+        << "};\n";
+  }
 
   return Status::OK();
 }
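Both paths emit the same shared-memory tree reduction, and its reduce_size = curr_size + (reduce_size & 1) step is what makes it correct for non-power-of-two workgroup sizes. A serial C++ model of that loop (a sketch for illustration; on the GPU the inner loop runs across lanes with a barrier per round):

#include <vector>

// Folds shared[] down to shared[0] the same way the emitted WGSL loop does.
float workgroup_reduce(std::vector<float> shared) {
  unsigned reduce_size = static_cast<unsigned>(shared.size());
  for (unsigned curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {
    // With an odd number of live elements, the middle one stays live one more round.
    reduce_size = curr_size + (reduce_size & 1);
    for (unsigned ix = 0; ix < curr_size; ++ix) {  // all lanes at once on the GPU
      shared[ix] += shared[ix + reduce_size];
    }
    // workgroupBarrier() would sit here on the GPU
  }
  return shared[0];
}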
@@ -100,14 +156,15 @@ Status SkipLayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeCo
     return Status::OK();
   }
 
-  const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
   const uint32_t hidden_size = onnxruntime::narrow<uint32_t>(x_shape[x_shape.NumDimensions() - 1]);
   const int components = GetMaxComponents(hidden_size);
   const bool has_input_skip_bias_sum = input_skip_bias_sum != nullptr;
+  const uint32_t norm_count = onnxruntime::narrow<uint32_t>(x_shape.SizeToDimension(x_shape.NumDimensions() - 1));
+  const bool split_hidden_dim = hidden_size % 512 == 0 && norm_count == 1;
 
-  SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, has_input_skip_bias_sum, is_fp16, simplified};
+  SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, has_input_skip_bias_sum, simplified, split_hidden_dim};
   program
-      .CacheHint(simplified, has_input_skip_bias_sum)
+      .CacheHint(simplified, has_input_skip_bias_sum, split_hidden_dim)
       .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}})
       .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}})
       .AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}})
@@ -123,6 +180,13 @@ Status SkipLayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeCo
       {static_cast<float>(epsilon_)},
   });
 
+  if (split_hidden_dim) {
+    const uint32_t workgroup_size_x = 128;
+    const uint32_t dispatch_size_x = (has_input_skip_bias_sum ? 2 : 1) * hidden_size / (workgroup_size_x * components);
+    program.SetDispatchGroupSize(dispatch_size_x, 1, 1)
+        .SetWorkgroupSize(workgroup_size_x);
+  }
+
   if (beta != nullptr) {
     program.AddInput({beta, ProgramTensorMetadataDependency::Type, components});
   }
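Note that norm_count is batch_size × sequence_length, so split_hidden_dim fires exactly in the commit message's batch_size = sequence_length = 1 case, and hidden_size % 512 == 0 guarantees GetMaxComponents picks vec4 packing, so each 128-thread workgroup covers exactly 128 × 4 = 512 elements. A quick check of the dispatch arithmetic (hidden_size = 4096 is our example value, not from the diff):

#include <cstdint>
#include <cstdio>

int main() {
  const std::uint32_t hidden_size = 4096;      // any multiple of 512 qualifies
  const std::uint32_t workgroup_size_x = 128;  // fixed in the split path
  const std::uint32_t components = 4;          // vec4 packing
  for (int has_sum = 0; has_sum <= 1; ++has_sum) {
    // With input_skip_bias_sum the dispatch doubles: the extra half of the
    // workgroups only materializes that output and returns early.
    const std::uint32_t dispatch_x =
        (has_sum ? 2u : 1u) * hidden_size / (workgroup_size_x * components);
    std::printf("has_input_skip_bias_sum=%d -> dispatch_size_x=%u\n", has_sum, dispatch_x);
  }
  return 0;  // prints 8 workgroups, then 16
}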

onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h

Lines changed: 3 additions & 3 deletions
@@ -15,15 +15,15 @@ using onnxruntime::webgpu::ComputeContext;
 
 class SkipLayerNormProgram final : public Program<SkipLayerNormProgram> {
  public:
-  SkipLayerNormProgram(bool hasBeta, bool hasBias, float epsilon, uint32_t hidden_size, bool has_input_skip_bias_sum, bool is_fp16, bool simplified) : Program{"SkipLayerNorm"} {
+  SkipLayerNormProgram(bool hasBeta, bool hasBias, float epsilon, uint32_t hidden_size, bool has_input_skip_bias_sum, bool simplified, bool split_hidden_dim) : Program{"SkipLayerNorm"} {
     epsilon_ = epsilon;
     hasBeta_ = hasBeta;
     hasBias_ = hasBias;
     epsilon_ = epsilon;
     hidden_size_ = hidden_size;
     has_input_skip_bias_sum_ = has_input_skip_bias_sum;
     simplified_ = simplified;
-    is_fp16_ = is_fp16;
+    split_hidden_dim_ = split_hidden_dim;
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -39,8 +39,8 @@ class SkipLayerNormProgram final : public Program<SkipLayerNormProgram> {
   float epsilon_;
   uint32_t hidden_size_;
   bool has_input_skip_bias_sum_;
-  bool is_fp16_;
   bool simplified_;
+  bool split_hidden_dim_;
 };
 
 template <bool simplified>
