@@ -21,16 +21,16 @@ ONNX_OPERATOR_KERNEL_EX(
2121 BiasAdd);
2222
2323Status BiasAddProgram::GenerateShaderCode (ShaderHelper& shader) const {
24- const ShaderVariableHelper& input = shader.AddInput (" input" );
25- const ShaderVariableHelper& bias = shader.AddInput (" bias" );
26- const ShaderVariableHelper& residual = shader.AddInput (" residual" );
27- const ShaderVariableHelper& output = shader.AddOutput (" output" );
24+ const ShaderVariableHelper& input = shader.AddInput (" input" , ShaderUsage::UseUniform );
25+ const ShaderVariableHelper& bias = shader.AddInput (" bias" , ShaderUsage::UseUniform );
26+ const ShaderVariableHelper& residual = shader.AddInput (" residual" , ShaderUsage::UseUniform );
27+ const ShaderVariableHelper& output = shader.AddOutput (" output" , ShaderUsage::UseUniform );
2828
2929 shader.MainFunctionBody () << shader.GuardAgainstOutOfBoundsWorkgroupSizes (" uniforms.output_size" )
30- << " let value = " << input.GetByOffset (" global_idx" )
30+ << " let value = " << input.GetByOffset (" global_idx" )
3131 << " + " << bias.GetByOffset (" global_idx % uniforms.channels" )
3232 << " + " << residual.GetByOffset (" global_idx" ) << " ;\n "
33- << output.SetByOffset (" global_idx" , " value" );
33+ << " " + output.SetByOffset (" global_idx" , " value" );
3434
3535 return Status::OK ();
3636}
@@ -47,23 +47,26 @@ Status BiasAdd::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) co
4747 }
4848
4949 int64_t channels = input_shape[2 ];
50- int64_t components = GetMaxComponents (channels);
51- channels /= components;
52-
5350 TensorShape bias_shape = bias->Shape ();
5451 if (bias_shape.NumDimensions () != 1 || bias_shape[0 ] != channels) {
55- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " BiasAdd bias should have 1 dimension with size equal to the number of channels." );
52+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
53+ " BiasAdd bias should have 1 dimension with size equal to the number of channels." );
5654 }
5755
56+ int components = GetMaxComponents (channels);
57+ channels /= components;
58+
5859 auto * output = context.Output (0 , input_shape);
5960 int64_t output_size = output->Shape ().Size () / components;
6061
6162 BiasAddProgram program{};
62- program.AddInputs ({{input}, {bias}, {residual}})
63- .AddOutput ({output})
63+ program
64+ .AddInputs ({{input, ProgramTensorMetadataDependency::None, components},
65+ {bias, ProgramTensorMetadataDependency::None, components},
66+ {residual, ProgramTensorMetadataDependency::None, components}})
67+ .AddOutput ({output, ProgramTensorMetadataDependency::None, components})
6468 .SetDispatchGroupSize ((output_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE)
65- .AddUniformVariables ({{static_cast <uint32_t >(output_size)},
66- {static_cast <uint32_t >(channels)}});
69+ .AddUniformVariables ({{static_cast <uint32_t >(output_size)}, {static_cast <uint32_t >(channels)}});
6770 return context.RunProgram (program);
6871}
6972
0 commit comments