Skip to content

Commit 5c014e2

Browse files
authored
[webgpu] Fix bug in 1D dispatch workgroups (microsoft#24519)
Fixed the bug in microsoft#24228 which causes the incorrect result for phi models when flash attention is disabled.
1 parent 0e407d7 commit 5c014e2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

onnxruntime/contrib_ops/webgpu/bert/attention.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
155155

156156
shader.MainFunctionBody() << "if (m + local_id.y < uniforms.M && n + local_id.x < total_sequence_length) {\n"
157157
<< " let headOffset = batch_head_idx * uniforms.M * uniforms.N;\n"
158-
<< " let outputIdx = headOffset + m + local_id.y * uniforms.N + n + local_id.x;\n"
158+
<< " let outputIdx = headOffset + (m + local_id.y) * uniforms.N + n + local_id.x;\n"
159159
<< " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n";
160160

161161
shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)";

0 commit comments

Comments
 (0)