Skip to content

Commit 34abb8b

Browse files
authored
[webgpu] fix bias-add (microsoft#24336)
"channels" should be validated before divided by "components". "components" should be passed to program inputs and outputs.
1 parent 2265613 commit 34abb8b

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

onnxruntime/contrib_ops/webgpu/bert/bias_add.cc

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@ ONNX_OPERATOR_KERNEL_EX(
2121
BiasAdd);
2222

2323
Status BiasAddProgram::GenerateShaderCode(ShaderHelper& shader) const {
24-
const ShaderVariableHelper& input = shader.AddInput("input");
25-
const ShaderVariableHelper& bias = shader.AddInput("bias");
26-
const ShaderVariableHelper& residual = shader.AddInput("residual");
27-
const ShaderVariableHelper& output = shader.AddOutput("output");
24+
const ShaderVariableHelper& input = shader.AddInput("input", ShaderUsage::UseUniform);
25+
const ShaderVariableHelper& bias = shader.AddInput("bias", ShaderUsage::UseUniform);
26+
const ShaderVariableHelper& residual = shader.AddInput("residual", ShaderUsage::UseUniform);
27+
const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform);
2828

2929
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
30-
<< "let value = " << input.GetByOffset("global_idx")
30+
<< " let value = " << input.GetByOffset("global_idx")
3131
<< " + " << bias.GetByOffset("global_idx % uniforms.channels")
3232
<< " + " << residual.GetByOffset("global_idx") << ";\n"
33-
<< output.SetByOffset("global_idx", "value");
33+
<< " " + output.SetByOffset("global_idx", "value");
3434

3535
return Status::OK();
3636
}
@@ -47,23 +47,26 @@ Status BiasAdd::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) co
4747
}
4848

4949
int64_t channels = input_shape[2];
50-
int64_t components = GetMaxComponents(channels);
51-
channels /= components;
52-
5350
TensorShape bias_shape = bias->Shape();
5451
if (bias_shape.NumDimensions() != 1 || bias_shape[0] != channels) {
55-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasAdd bias should have 1 dimension with size equal to the number of channels.");
52+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
53+
"BiasAdd bias should have 1 dimension with size equal to the number of channels.");
5654
}
5755

56+
int components = GetMaxComponents(channels);
57+
channels /= components;
58+
5859
auto* output = context.Output(0, input_shape);
5960
int64_t output_size = output->Shape().Size() / components;
6061

6162
BiasAddProgram program{};
62-
program.AddInputs({{input}, {bias}, {residual}})
63-
.AddOutput({output})
63+
program
64+
.AddInputs({{input, ProgramTensorMetadataDependency::None, components},
65+
{bias, ProgramTensorMetadataDependency::None, components},
66+
{residual, ProgramTensorMetadataDependency::None, components}})
67+
.AddOutput({output, ProgramTensorMetadataDependency::None, components})
6468
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
65-
.AddUniformVariables({{static_cast<uint32_t>(output_size)},
66-
{static_cast<uint32_t>(channels)}});
69+
.AddUniformVariables({{static_cast<uint32_t>(output_size)}, {static_cast<uint32_t>(channels)}});
6770
return context.RunProgram(program);
6871
}
6972

0 commit comments

Comments
 (0)