 // Licensed under the MIT License.
 
 #include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/string_macros.h"
 #include "core/providers/webgpu/webgpu_utils.h"
 #include "core/providers/webgpu/webgpu_supported_types.h"
 #include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
@@ -12,7 +13,7 @@ namespace contrib {
 namespace webgpu {
 
 Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+  const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   shader.AddInput("skip", ShaderUsage::UseUniform);
   shader.AddInput("gamma", ShaderUsage::UseUniform);
   if (hasBeta_) {
@@ -26,57 +27,112 @@ Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
     shader.AddOutput("input_skip_bias_sum", ShaderUsage::UseUniform);
   }
 
-  int components = x.NumComponents();
-
-  std::string bias = (hasBias_) ? " + bias[offset1d + i] " : "";
   std::string simpl1 = (simplified_) ? "" : " - mean * mean ";
-  std::string simpl2 = (simplified_) ? "" : " - element_t(mean) ";
-  std::string beta = (hasBeta_) ? " + beta[offset1d + i] " : "";
-  std::string input_skip_bias_sum = (has_input_skip_bias_sum_) ? "input_skip_bias_sum[offset + i] = value;\n" : "";
-
-  shader.AdditionalImplementation()
-      << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n")
-      << "alias f32_val_t = " << (components == 4 ? "vec4<f32>" : (components == 2 ? "vec2<f32>" : "f32")) << ";\n"
-      << "var<workgroup> sum_shared : array<f32_val_t, workgroup_size_x>;\n"
-      << "var<workgroup> sum_squared_shared : array<f32_val_t, workgroup_size_x>;\n";
-
-  shader.MainFunctionBody()
-      << "let ix = local_idx;\n"
-      << "let iy = global_idx / workgroup_size_x;\n"
-      << "let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;\n"
-      << "var stride = hidden_size_vectorized / workgroup_size_x;\n"
-      << "let offset = ix * stride + iy * hidden_size_vectorized;\n"
-      << "let offset1d = stride * ix;\n"
-      << "if (ix == workgroup_size_x - 1) {\n"
-      << "  stride = hidden_size_vectorized - stride * ix;\n"
-      << "}\n"
-      << "for (var i: u32 = 0; i < stride; i++) {\n"
-      << "  let skip_value = skip[offset + i];\n"
-      << "  let input_value = x[offset + i];\n"
-      << "  let value = input_value + skip_value" << bias << ";\n"
-      << "  output[offset + i] = value;\n"
-      << input_skip_bias_sum
-      << "  let f32_value = f32_val_t(value);\n"
-      << "  sum_shared[ix] += f32_value;\n"
-      << "  sum_squared_shared[ix] += f32_value * f32_value;\n"
-      << "}\n"
-      << "workgroupBarrier();\n"
-      << "var reduce_size : u32 = workgroup_size_x;\n"
-      << "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n"
-      << "  reduce_size = curr_size + (reduce_size & 1);\n"
-      << "  if (ix < curr_size) {\n"
-      << "    sum_shared[ix] += sum_shared[ix + reduce_size];\n"
-      << "    sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];\n"
-      << "  }\n"
-      << "  workgroupBarrier();\n"
-      << "}\n"
-      << "let sum = sum_shared[0];\n"
-      << "let square_sum = sum_squared_shared[0];\n"
-      << "let mean = " << SumVector("sum", components) << " / f32(uniforms.hidden_size);\n"
-      << "let inv_std_dev = inverseSqrt(" << SumVector("square_sum", components) << " / f32(uniforms.hidden_size) " << simpl1 << " + uniforms.epsilon);\n"
-      << "for (var i: u32 = 0; i < stride; i++) {\n"
-      << "  output[offset + i] = (output[offset + i] " << simpl2 << ") * element_t(inv_std_dev) * gamma[offset1d + i]" << beta << ";\n"
-      << "};\n";
+  std::string simpl2 = (simplified_) ? "" : " - x_element_t(mean) ";
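+  // Split-hidden-dimension path: the single normalized row is distributed over several
+  // workgroups. Every workgroup accumulates sum and sum-of-squares over the whole hidden
+  // dimension in vec4 chunks, reduces them in workgroup memory, and then normalizes only the
+  // workgroup_size_x-wide slice of vec4 elements that it owns.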
+  if (split_hidden_dim_) {
+    shader.AdditionalImplementation()
+        << "var<workgroup> sum_shared : array<f32, workgroup_size_x>;\n"
+        << "var<workgroup> sum_squared_shared : array<f32, workgroup_size_x>;\n";
+
+    SS(input_skip_bias_sum_ss, 512);
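+    // When the optional input_skip_bias_sum output is present, ComputeInternal doubles the
+    // dispatch; workgroups in the upper half only compute input + skip (+ bias), store it into
+    // input_skip_bias_sum, and return early.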
+    if (has_input_skip_bias_sum_) {
+      input_skip_bias_sum_ss
+          << "let workgroup_half_idx = uniforms.hidden_size / (workgroup_size_x * 4);\n"
+          << "if (workgroup_idx >= workgroup_half_idx) {\n"
+          << "  offset = (workgroup_idx - workgroup_half_idx) * workgroup_size_x + local_idx;\n"
+          << "  let skip_value = skip[offset];\n"
+          << "  let input_value = x[offset];\n"
+          << "  let value = input_value + skip_value" << (hasBias_ ? " + bias[offset]" : "") << ";\n"
+          << "  input_skip_bias_sum[offset] = value;\n"
+          << "  return;\n"
+          << "}\n";
+    }
+
+    shader.MainFunctionBody()
+        << "var offset: u32 = 0;\n"
+        << (has_input_skip_bias_sum_ ? SS_GET(input_skip_bias_sum_ss) : "")
+        << "var sum_vec4 = vec4<f32>(0);\n"
+        << "var sum_squared_vec4 = vec4<f32>(0);\n"
+        << "var cur_input_skip_bias_sum = x_value_t(0);\n"
57+ << " for (var i: u32 = 0; i < uniforms.hidden_size / (workgroup_size_x * 4); i++) {\n "
58+ << " let input_offset = i * workgroup_size_x + local_idx;\n "
59+ << " let skip_value = skip[input_offset];\n "
60+ << " let input_value = x[input_offset];\n "
61+ << " let value = input_value + skip_value" << (hasBias_ ? " + bias[input_offset]" : " " ) << " ;\n "
62+ << " if (i == workgroup_idx) {\n "
63+ << " cur_input_skip_bias_sum = value;\n "
64+ << " }\n "
65+ << " let f32_value = vec4<f32>(value);\n "
66+ << " sum_vec4 += f32_value;\n "
67+ << " sum_squared_vec4 += f32_value * f32_value;\n "
68+ << " }\n "
69+ << " var sum = " << SumVector (" sum_vec4" , 4 ) << " ;\n "
70+ << " var sum_squared = " << SumVector (" sum_squared_vec4" , 4 ) << " ;\n "
71+ << " sum_shared[local_idx] = sum;\n "
72+ << " sum_squared_shared[local_idx] = sum_squared;\n "
73+ << " workgroupBarrier();\n "
74+ << " var reduce_size : u32 = workgroup_size_x;\n "
75+ << " for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n "
76+ << " reduce_size = curr_size + (reduce_size & 1);\n "
77+ << " if (local_idx < curr_size) {\n "
78+ << " sum_shared[local_idx] += sum_shared[local_idx + reduce_size];\n "
79+ << " sum_squared_shared[local_idx] += sum_squared_shared[local_idx + reduce_size];\n "
80+ << " }\n "
81+ << " workgroupBarrier();\n "
82+ << " }\n "
83+ << " let mean = sum_shared[0] / f32(uniforms.hidden_size);\n "
84+ << " let inv_std_dev = inverseSqrt(sum_squared_shared[0] / f32(uniforms.hidden_size) " << simpl1 << " + uniforms.epsilon);\n "
85+ << " offset = workgroup_idx * workgroup_size_x + local_idx;\n "
86+ << " output[offset] = ((cur_input_skip_bias_sum " << simpl2 << " ) * x_element_t(inv_std_dev) * gamma[offset]" << (hasBeta_ ? " + beta[offset] " : " " ) << " );\n " ;
87+ } else {
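+    // General path: one workgroup per normalized row; each invocation handles a contiguous
+    // `stride`-sized chunk of the vectorized hidden dimension.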
+    int components = x.NumComponents();
+    std::string bias = (hasBias_) ? " + bias[offset1d + i] " : "";
+    std::string beta = (hasBeta_) ? " + beta[offset1d + i] " : "";
+    std::string input_skip_bias_sum = (has_input_skip_bias_sum_) ? "input_skip_bias_sum[offset + i] = value;\n" : "";
+
+    shader.AdditionalImplementation()
+        << "alias f32_val_t = " << (components == 4 ? "vec4<f32>" : (components == 2 ? "vec2<f32>" : "f32")) << ";\n"
+        << "var<workgroup> sum_shared : array<f32_val_t, workgroup_size_x>;\n"
+        << "var<workgroup> sum_squared_shared : array<f32_val_t, workgroup_size_x>;\n";
+
+    shader.MainFunctionBody()
+        << "let ix = local_idx;\n"
+        << "let iy = global_idx / workgroup_size_x;\n"
+        << "let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;\n"
+        << "var stride = hidden_size_vectorized / workgroup_size_x;\n"
+        << "let offset = ix * stride + iy * hidden_size_vectorized;\n"
+        << "let offset1d = stride * ix;\n"
+        << "if (ix == workgroup_size_x - 1) {\n"
+        << "  stride = hidden_size_vectorized - stride * ix;\n"
+        << "}\n"
+        << "for (var i: u32 = 0; i < stride; i++) {\n"
+        << "  let skip_value = skip[offset + i];\n"
+        << "  let input_value = x[offset + i];\n"
+        << "  let value = input_value + skip_value" << bias << ";\n"
+        << "  output[offset + i] = value;\n"
+        << input_skip_bias_sum
+        << "  let f32_value = f32_val_t(value);\n"
+        << "  sum_shared[ix] += f32_value;\n"
+        << "  sum_squared_shared[ix] += f32_value * f32_value;\n"
+        << "}\n"
+        << "workgroupBarrier();\n"
+        << "var reduce_size : u32 = workgroup_size_x;\n"
+        << "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n"
+        << "  reduce_size = curr_size + (reduce_size & 1);\n"
+        << "  if (ix < curr_size) {\n"
+        << "    sum_shared[ix] += sum_shared[ix + reduce_size];\n"
+        << "    sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];\n"
+        << "  }\n"
+        << "  workgroupBarrier();\n"
+        << "}\n"
+        << "let sum = sum_shared[0];\n"
+        << "let square_sum = sum_squared_shared[0];\n"
+        << "let mean = " << SumVector("sum", components) << " / f32(uniforms.hidden_size);\n"
+        << "let inv_std_dev = inverseSqrt(" << SumVector("square_sum", components) << " / f32(uniforms.hidden_size) " << simpl1 << " + uniforms.epsilon);\n"
+        << "for (var i: u32 = 0; i < stride; i++) {\n"
+        << "  output[offset + i] = (output[offset + i] " << simpl2 << ") * x_element_t(inv_std_dev) * gamma[offset1d + i]" << beta << ";\n"
+        << "};\n";
+  }
 
   return Status::OK();
 }
@@ -100,14 +156,15 @@ Status SkipLayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeCo
     return Status::OK();
   }
 
-  const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
   const uint32_t hidden_size = onnxruntime::narrow<uint32_t>(x_shape[x_shape.NumDimensions() - 1]);
   const int components = GetMaxComponents(hidden_size);
   const bool has_input_skip_bias_sum = input_skip_bias_sum != nullptr;
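+  // Use the split-hidden-dimension shader only when a single row is normalized and hidden_size
+  // is a multiple of 512, so 128 invocations working on vec4 elements cover the row evenly.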
+  const uint32_t norm_count = onnxruntime::narrow<uint32_t>(x_shape.SizeToDimension(x_shape.NumDimensions() - 1));
+  const bool split_hidden_dim = hidden_size % 512 == 0 && norm_count == 1;
 
-  SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, has_input_skip_bias_sum, is_fp16, simplified};
+  SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, has_input_skip_bias_sum, simplified, split_hidden_dim};
   program
-      .CacheHint(simplified, has_input_skip_bias_sum)
+      .CacheHint(simplified, has_input_skip_bias_sum, split_hidden_dim)
       .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}})
       .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}})
      .AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}})
@@ -123,6 +180,13 @@ Status SkipLayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeCo
           {static_cast<float>(epsilon_)},
       });
 
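+  // One workgroup of 128 invocations per 128 * components elements of the row (components is 4
+  // here since hidden_size is a multiple of 512); the dispatch is doubled when the extra
+  // input_skip_bias_sum output is produced so the upper half of the workgroups can write it.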
+  if (split_hidden_dim) {
+    const uint32_t workgroup_size_x = 128;
+    const uint32_t dispatch_size_x = (has_input_skip_bias_sum ? 2 : 1) * hidden_size / (workgroup_size_x * components);
+    program.SetDispatchGroupSize(dispatch_size_x, 1, 1)
+        .SetWorkgroupSize(workgroup_size_x);
+  }
+
   if (beta != nullptr) {
     program.AddInput({beta, ProgramTensorMetadataDependency::Type, components});
   }