@@ -13,23 +13,25 @@ namespace onnxruntime {
1313namespace webgpu {
1414
1515Status ComputeChannelScaleShiftProgram::GenerateShaderCode (ShaderHelper& shader) const {
16- const auto & input = shader.AddInput (" x" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseIndicesTypeAlias );
16+ const auto & input = shader.AddInput (" x" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
1717 const auto & scale = shader.AddInput (" scale" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
1818 const auto & bias = shader.AddInput (" bias" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
19- const ShaderVariableHelper& output = shader.AddOutput (" output" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
19+ const ShaderVariableHelper& output = shader.AddOutput (" output" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias );
2020
21- shader.AdditionalImplementation () << " var<workgroup> workgroup_shared_sum : array<x_value_t, " << workgroup_size_ << " >;\n "
22- << " var<workgroup> workgroup_shared_squared_sum : array<x_value_t, " << workgroup_size_ << " >;\n "
21+ shader.AdditionalImplementation () << " alias f32_val_t = " << (components_ == 4 ? " vec4<f32>" : (components_ == 2 ? " vec2<f32>" : " f32" )) << " ;\n "
22+ << " var<workgroup> workgroup_shared_sum : array<f32_val_t, " << workgroup_size_ << " >;\n "
23+ << " var<workgroup> workgroup_shared_squared_sum : array<f32_val_t, " << workgroup_size_ << " >;\n "
2324 << " const workgroup_size = " << workgroup_size_ << " ;\n " ;
25+
2426 shader.MainFunctionBody () << " let batch = workgroup_idx / uniforms.x_shape[1];\n "
2527 << " let channel = workgroup_idx % uniforms.x_shape[1];\n "
2628 << " let hight = uniforms.x_shape[2];\n "
2729 << " // initialize workgroup memory<< \n "
28- << " var sum = x_value_t (0);\n "
29- << " var squared_sum = x_value_t (0);\n "
30+ << " var sum = f32_val_t (0);\n "
31+ << " var squared_sum = f32_val_t (0);\n "
3032 << " for (var h = local_idx; h < hight; h += workgroup_size) {\n "
3133 << " let indices = x_indices_t(batch, channel, h);\n "
32- << " let value =" << input.GetByIndices (" indices" ) << " ;\n "
34+ << " let value = f32_val_t( " << input.GetByIndices (" indices" ) << " ) ;\n "
3335 << " sum += value;\n "
3436 << " squared_sum += value * value;\n "
3537 << " }\n "
@@ -44,12 +46,12 @@ Status ComputeChannelScaleShiftProgram::GenerateShaderCode(ShaderHelper& shader)
4446 << " workgroupBarrier();\n "
4547 << " }\n "
4648 << " if (local_idx == 0) {\n "
47- << " let sum_final = " << SumVector (" workgroup_shared_sum[0]" , components_) << " / x_element_t (hight * " << components_ << " );\n "
48- << " let squared_sum_final = " << SumVector (" workgroup_shared_squared_sum[0]" , components_) << " / x_element_t (hight * " << components_ << " );\n "
49- << " let inv_std_dev = inverseSqrt(squared_sum_final - sum_final * sum_final + x_element_t (" << std::to_string (epsilon_) << " ));\n "
50- << " let channel_scale = inv_std_dev * " << scale.GetByOffset (" channel" ) << " ;\n "
51- << " let channel_shift = " << bias.GetByOffset (" channel" ) << " - sum_final * channel_scale;\n "
52- << " " << output.SetByOffset (" workgroup_idx" , " output_value_t(channel_scale, channel_shift)" ) << " ;\n "
49+ << " let sum_final = " << SumVector (" workgroup_shared_sum[0]" , components_) << " / f32 (hight * " << components_ << " );\n "
50+ << " let squared_sum_final = " << SumVector (" workgroup_shared_squared_sum[0]" , components_) << " / f32 (hight * " << components_ << " );\n "
51+ << " let inv_std_dev = inverseSqrt(squared_sum_final - sum_final * sum_final + f32 (" << std::to_string (epsilon_) << " ));\n "
52+ << " let channel_scale = inv_std_dev * f32( " << scale.GetByOffset (" channel" ) << " ) ;\n "
53+ << " let channel_shift = f32( " << bias.GetByOffset (" channel" ) << " ) - sum_final * channel_scale;\n "
54+ << " " << output.SetByOffset (" workgroup_idx" , " output_value_t(output_element_t( channel_scale), output_element_t( channel_shift) )" ) << " ;\n "
5355 << " }\n " ;
5456 return Status::OK ();
5557}
@@ -110,7 +112,7 @@ Status InstanceNormProgramNHWC::GenerateShaderCode(ShaderHelper& shader) const {
110112 << " let input_value = " << input.GetByOffset (" global_idx" ) << " ;\n " ;
111113 if (components_ > 1 ) {
112114 shader.MainFunctionBody () << " for (var i : u32 = 0; i < uniforms.components; i = i + 1) {\n "
113- << " let scale_sift = " << channel_scale_shift.GetByOffset (" scale_offset + i" ) << " ;\n "
115+ << " let scale_sift = " << channel_scale_shift.GetByOffset (" uniforms.components * scale_offset + i" ) << " ;\n "
114116 << " scale[i] = input_element_t(scale_sift.x);\n "
115117 << " shift[i] = input_element_t(scale_sift.y);\n "
116118 << " }\n " ;
0 commit comments