Skip to content

Commit e9bb150

Browse files
authored
webgpu: fix InstanceNorm errors (microsoft#24514)
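### Description
Fixes errors in the WebGPU InstanceNorm shader generation in `onnxruntime/core/providers/webgpu/nn/instance_norm.cc`:
- `ComputeChannelScaleShiftProgram`: the per-channel sum and squared sum are now accumulated in `f32` through a new `f32_val_t` alias (`f32`, `vec2<f32>` or `vec4<f32>` depending on `components_`) instead of the input's value type. Input values, `scale`, `bias` and epsilon are converted to `f32` before use, and the resulting scale/shift pair is converted to `output_element_t` when it is written to the output.
- `ComputeChannelScaleShiftProgram`: the input `x` no longer requests the element/value type aliases (nor the duplicated `UseIndicesTypeAlias` flag); the output now additionally requests `UseElementTypeAlias`.
- `InstanceNormProgramNHWC`: when `components_ > 1`, `channel_scale_shift` is now read at `uniforms.components * scale_offset + i` instead of `scale_offset + i`.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->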
1 parent ef77435 commit e9bb150

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

onnxruntime/core/providers/webgpu/nn/instance_norm.cc

Lines changed: 16 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -13,23 +13,25 @@ namespace onnxruntime {
1313
namespace webgpu {
1414

1515
Status ComputeChannelScaleShiftProgram::GenerateShaderCode(ShaderHelper& shader) const {
16-
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseIndicesTypeAlias);
16+
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
1717
const auto& scale = shader.AddInput("scale", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
1818
const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
19-
const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
19+
const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
2020

21-
shader.AdditionalImplementation() << "var<workgroup> workgroup_shared_sum : array<x_value_t, " << workgroup_size_ << ">;\n"
22-
<< "var<workgroup> workgroup_shared_squared_sum : array<x_value_t, " << workgroup_size_ << ">;\n"
21+
shader.AdditionalImplementation() << "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n"
22+
<< "var<workgroup> workgroup_shared_sum : array<f32_val_t, " << workgroup_size_ << ">;\n"
23+
<< "var<workgroup> workgroup_shared_squared_sum : array<f32_val_t, " << workgroup_size_ << ">;\n"
2324
<< "const workgroup_size = " << workgroup_size_ << ";\n";
25+
2426
shader.MainFunctionBody() << " let batch = workgroup_idx / uniforms.x_shape[1];\n"
2527
<< " let channel = workgroup_idx % uniforms.x_shape[1];\n"
2628
<< " let hight = uniforms.x_shape[2];\n"
2729
<< " // initialize workgroup memory<< \n"
28-
<< " var sum = x_value_t(0);\n"
29-
<< " var squared_sum = x_value_t(0);\n"
30+
<< " var sum = f32_val_t(0);\n"
31+
<< " var squared_sum = f32_val_t(0);\n"
3032
<< " for (var h = local_idx; h < hight; h += workgroup_size) {\n"
3133
<< " let indices = x_indices_t(batch, channel, h);\n"
32-
<< " let value =" << input.GetByIndices("indices") << ";\n"
34+
<< " let value = f32_val_t(" << input.GetByIndices("indices") << ");\n"
3335
<< " sum += value;\n"
3436
<< " squared_sum += value * value;\n"
3537
<< " }\n"
@@ -44,12 +46,12 @@ Status ComputeChannelScaleShiftProgram::GenerateShaderCode(ShaderHelper& shader)
4446
<< " workgroupBarrier();\n"
4547
<< " }\n"
4648
<< " if (local_idx == 0) {\n"
47-
<< " let sum_final = " << SumVector("workgroup_shared_sum[0]", components_) << " / x_element_t(hight * " << components_ << ");\n"
48-
<< " let squared_sum_final = " << SumVector("workgroup_shared_squared_sum[0]", components_) << " / x_element_t(hight * " << components_ << ");\n"
49-
<< " let inv_std_dev = inverseSqrt(squared_sum_final - sum_final * sum_final + x_element_t(" << std::to_string(epsilon_) << "));\n"
50-
<< " let channel_scale = inv_std_dev * " << scale.GetByOffset("channel") << ";\n"
51-
<< " let channel_shift = " << bias.GetByOffset("channel") << " - sum_final * channel_scale;\n"
52-
<< " " << output.SetByOffset("workgroup_idx", "output_value_t(channel_scale, channel_shift)") << ";\n"
49+
<< " let sum_final = " << SumVector("workgroup_shared_sum[0]", components_) << " / f32(hight * " << components_ << ");\n"
50+
<< " let squared_sum_final = " << SumVector("workgroup_shared_squared_sum[0]", components_) << " / f32(hight * " << components_ << ");\n"
51+
<< " let inv_std_dev = inverseSqrt(squared_sum_final - sum_final * sum_final + f32(" << std::to_string(epsilon_) << "));\n"
52+
<< " let channel_scale = inv_std_dev * f32(" << scale.GetByOffset("channel") << ");\n"
53+
<< " let channel_shift = f32(" << bias.GetByOffset("channel") << ") - sum_final * channel_scale;\n"
54+
<< " " << output.SetByOffset("workgroup_idx", "output_value_t(output_element_t(channel_scale), output_element_t(channel_shift))") << ";\n"
5355
<< " }\n";
5456
return Status::OK();
5557
}
@@ -110,7 +112,7 @@ Status InstanceNormProgramNHWC::GenerateShaderCode(ShaderHelper& shader) const {
110112
<< "let input_value = " << input.GetByOffset("global_idx") << ";\n";
111113
if (components_ > 1) {
112114
shader.MainFunctionBody() << "for (var i : u32 = 0; i < uniforms.components; i = i + 1) {\n"
113-
<< " let scale_sift = " << channel_scale_shift.GetByOffset("scale_offset + i") << ";\n"
115+
<< " let scale_sift = " << channel_scale_shift.GetByOffset("uniforms.components * scale_offset + i") << ";\n"
114116
<< " scale[i] = input_element_t(scale_sift.x);\n"
115117
<< " shift[i] = input_element_t(scale_sift.y);\n"
116118
<< "}\n";

0 commit comments

Comments
 (0)