Skip to content

Commit 2265613

Browse files
authored
[webgpu] Fix bias_split_gelu (microsoft#24342)
"channels" should be validated before being divided by "components". "components" should be passed to the program inputs and outputs. Rename "input" to "x" to match "ErfImpl". Correct the last dimension of the output shape.
1 parent 4edada6 commit 2265613

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

onnxruntime/contrib_ops/webgpu/bert/bias_split_gelu.cc

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,21 @@ ONNX_OPERATOR_KERNEL_EX(
2222
BiasSplitGelu);
2323

2424
Status BiasSplitGeluProgram::GenerateShaderCode(ShaderHelper& shader) const {
25-
const ShaderVariableHelper& input = shader.AddInput("input");
26-
const ShaderVariableHelper& bias = shader.AddInput("bias");
27-
const ShaderVariableHelper& output = shader.AddOutput("output");
25+
const ShaderVariableHelper& x =
26+
shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
27+
const ShaderVariableHelper& bias = shader.AddInput("bias", ShaderUsage::UseUniform);
28+
const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform);
2829

2930
shader.AdditionalImplementation() << ErfImpl;
3031

3132
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
3233
<< "const M_SQRT2: f32 = sqrt(2.0);\n"
33-
<< "const halfChannels = uniforms.channels / 2u;\n"
34+
<< "let halfChannels = uniforms.channels / 2u;\n"
3435
<< "let biasIdx = global_idx % halfChannels;\n"
3536
<< "let batchIndex = global_idx / halfChannels;\n"
3637
<< "let inputOffset = biasIdx + batchIndex * halfChannels * 2;\n"
37-
<< "let valueLeft = " << input.GetByOffset("inputOffset") << " + " << bias.GetByOffset("biasIdx") << ";\n"
38-
<< "let valueRight = " << input.GetByOffset("inputOffset + halfChannels") << " + " << bias.GetByOffset("biasIdx + halfChannels") << ";\n"
38+
<< "let valueLeft = " << x.GetByOffset("inputOffset") << " + " << bias.GetByOffset("biasIdx") << ";\n"
39+
<< "let valueRight = " << x.GetByOffset("inputOffset + halfChannels") << " + " << bias.GetByOffset("biasIdx + halfChannels") << ";\n"
3940
<< "let geluRight = valueRight * 0.5 * (erf_v(valueRight / M_SQRT2) + 1);\n"
4041
<< output.SetByOffset("global_idx", "valueLeft * geluRight");
4142

@@ -53,25 +54,27 @@ Status BiasSplitGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& conte
5354
}
5455

5556
int64_t channels = input_shape[2];
56-
int64_t components = GetMaxComponents(channels);
57-
channels /= components;
5857
input_shape[2] = channels / 2; // for output shape calculation (N,S,D) -> (N,S,D/2)
5958

6059
TensorShape bias_shape = bias->Shape();
6160
if (bias_shape.NumDimensions() != 1 || bias_shape[0] != channels) {
62-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasSplitGelu bias should have 1 dimension with size equal to the number of channels.");
61+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
62+
"BiasSplitGelu bias should have 1 dimension with size equal to the number of channels.");
6363
}
6464

65+
int components = GetMaxComponents(channels);
66+
channels /= components;
67+
6568
auto* output = context.Output(0, input_shape);
6669
int64_t output_size = output->Shape().Size() / components;
6770

6871
BiasSplitGeluProgram program{};
69-
program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
70-
{bias}})
71-
.AddOutput({output})
72+
program
73+
.AddInputs({{input, ProgramTensorMetadataDependency::None, components},
74+
{bias, ProgramTensorMetadataDependency::None, components}})
75+
.AddOutput({output, ProgramTensorMetadataDependency::None, components})
7276
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
73-
.AddUniformVariables({{static_cast<uint32_t>(output_size)},
74-
{static_cast<uint32_t>(channels)}});
77+
.AddUniformVariables({{static_cast<uint32_t>(output_size)}, {static_cast<uint32_t>(channels)}});
7578
return context.RunProgram(program);
7679
}
7780

0 commit comments

Comments (0)