@@ -22,20 +22,21 @@ ONNX_OPERATOR_KERNEL_EX(
2222 BiasSplitGelu);
2323
2424Status BiasSplitGeluProgram::GenerateShaderCode (ShaderHelper& shader) const {
25- const ShaderVariableHelper& input = shader.AddInput (" input" );
26- const ShaderVariableHelper& bias = shader.AddInput (" bias" );
27- const ShaderVariableHelper& output = shader.AddOutput (" output" );
25+ const ShaderVariableHelper& x =
26+ shader.AddInput (" x" , ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
27+ const ShaderVariableHelper& bias = shader.AddInput (" bias" , ShaderUsage::UseUniform);
28+ const ShaderVariableHelper& output = shader.AddOutput (" output" , ShaderUsage::UseUniform);
2829
2930 shader.AdditionalImplementation () << ErfImpl;
3031
3132 shader.MainFunctionBody () << shader.GuardAgainstOutOfBoundsWorkgroupSizes (" uniforms.output_size" )
3233 << " const M_SQRT2: f32 = sqrt(2.0);\n "
33- << " const halfChannels = uniforms.channels / 2u;\n "
34+ << " let halfChannels = uniforms.channels / 2u;\n "
3435 << " let biasIdx = global_idx % halfChannels;\n "
3536 << " let batchIndex = global_idx / halfChannels;\n "
3637 << " let inputOffset = biasIdx + batchIndex * halfChannels * 2;\n "
37- << " let valueLeft = " << input .GetByOffset (" inputOffset" ) << " + " << bias.GetByOffset (" biasIdx" ) << " ;\n "
38- << " let valueRight = " << input .GetByOffset (" inputOffset + halfChannels" ) << " + " << bias.GetByOffset (" biasIdx + halfChannels" ) << " ;\n "
38+ << " let valueLeft = " << x .GetByOffset (" inputOffset" ) << " + " << bias.GetByOffset (" biasIdx" ) << " ;\n "
39+ << " let valueRight = " << x .GetByOffset (" inputOffset + halfChannels" ) << " + " << bias.GetByOffset (" biasIdx + halfChannels" ) << " ;\n "
3940 << " let geluRight = valueRight * 0.5 * (erf_v(valueRight / M_SQRT2) + 1);\n "
4041 << output.SetByOffset (" global_idx" , " valueLeft * geluRight" );
4142
@@ -53,25 +54,27 @@ Status BiasSplitGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& conte
5354 }
5455
5556 int64_t channels = input_shape[2 ];
56- int64_t components = GetMaxComponents (channels);
57- channels /= components;
5857 input_shape[2 ] = channels / 2 ; // for output shape calculation (N,S,D) -> (N,S,D/2)
5958
6059 TensorShape bias_shape = bias->Shape ();
6160 if (bias_shape.NumDimensions () != 1 || bias_shape[0 ] != channels) {
62- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " BiasSplitGelu bias should have 1 dimension with size equal to the number of channels." );
61+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
62+ " BiasSplitGelu bias should have 1 dimension with size equal to the number of channels." );
6363 }
6464
65+ int components = GetMaxComponents (channels);
66+ channels /= components;
67+
6568 auto * output = context.Output (0 , input_shape);
6669 int64_t output_size = output->Shape ().Size () / components;
6770
6871 BiasSplitGeluProgram program{};
69- program.AddInputs ({{input, ProgramTensorMetadataDependency::TypeAndRank},
70- {bias}})
71- .AddOutput ({output})
72+ program
73+ .AddInputs ({{input, ProgramTensorMetadataDependency::None, components},
74+ {bias, ProgramTensorMetadataDependency::None, components}})
75+ .AddOutput ({output, ProgramTensorMetadataDependency::None, components})
7276 .SetDispatchGroupSize ((output_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE)
73- .AddUniformVariables ({{static_cast <uint32_t >(output_size)},
74- {static_cast <uint32_t >(channels)}});
77+ .AddUniformVariables ({{static_cast <uint32_t >(output_size)}, {static_cast <uint32_t >(channels)}});
7578 return context.RunProgram (program);
7679}
7780
0 commit comments