@@ -108,9 +108,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
108108 std::ostringstream oss;
109109 InitVarStub (oss, seqlen_k_);
110110 shader.MainFunctionBody () << oss.str ();
111- shader.MainFunctionBody () << " let kOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.kv_sequence_length * uniforms.K;\n " ;
111+ shader.MainFunctionBody () << " let kOffset = (workgroup_id.z / uniforms.n_reps ) * uniforms.kv_sequence_length * uniforms.K;\n " ;
112112 if (has_present_key_) {
113- shader.MainFunctionBody () << " let presentKeyOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.present_sequence_length * uniforms.K;\n " ;
113+ shader.MainFunctionBody () << " let presentKeyOffset = (workgroup_id.z / uniforms.n_reps ) * uniforms.present_sequence_length * uniforms.K;\n " ;
114114 }
115115
116116 shader.MainFunctionBody () << " var value = f32_val_t(0);\n "
@@ -123,7 +123,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
123123
124124 if ((feed_past_key_ && has_present_key_) || (past_present_share_buffer_ && !is_first_prompt_)) {
125125 shader.MainFunctionBody () << " if (n + local_id.y < past_sequence_length) {\n "
126- << " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.past_sequence_length * uniforms.K;\n "
126+ << " let pastKeyOffset = (workgroup_id.z / uniforms.n_reps ) * uniforms.past_sequence_length * uniforms.K;\n "
127127 << " tileK[idx] = " << (past_present_share_buffer_ ? " present_key" : " past_key" ) << " [pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n "
128128 << " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n "
129129 << " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n "
@@ -181,7 +181,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
181181 const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1 );
182182
183183 AttentionProbsProgram program{" AttentionProbs" , feed_past_key, has_present_key, has_attention_bias, tile_size,
184- components, parameters.is_first_prompt_ , parameters. n_reps , seqlen_k, parameters.past_present_share_buffer_ };
184+ components, parameters.is_first_prompt_ , seqlen_k, parameters.past_present_share_buffer_ };
185185 program.AddInputs ({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
186186 {K, ProgramTensorMetadataDependency::TypeAndRank, components}});
187187 if (feed_past_key) {
@@ -331,9 +331,9 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
331331 std::ostringstream oss;
332332 InitVarStub (oss, seqlen_k_);
333333 shader.MainFunctionBody () << oss.str ();
334- shader.MainFunctionBody () << " let vOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.N * uniforms.kv_sequence_length + n;\n " ;
334+ shader.MainFunctionBody () << " let vOffset = (workgroup_id.z / uniforms.n_reps ) * uniforms.N * uniforms.kv_sequence_length + n;\n " ;
335335 if (has_present_value_) {
336- shader.MainFunctionBody () << " let presentValueOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.N * uniforms.present_sequence_length + n;\n " ;
336+ shader.MainFunctionBody () << " let presentValueOffset = (workgroup_id.z / uniforms.n_reps ) * uniforms.N * uniforms.present_sequence_length + n;\n " ;
337337 }
338338
339339 shader.MainFunctionBody () << " var value = output_value_t(0);\n "
@@ -346,7 +346,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
346346
347347 if ((feed_past_value_ && has_present_value_) || (past_present_share_buffer_ && !is_first_prompt_)) {
348348 shader.MainFunctionBody () << " if (w + local_id.y < past_sequence_length) {\n "
349- << " let pastValueOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.N * uniforms.past_sequence_length + n;\n "
349+ << " let pastValueOffset = (workgroup_id.z / uniforms.n_reps ) * uniforms.N * uniforms.past_sequence_length + n;\n "
350350 << " tileK[idx] = " << (past_present_share_buffer_ ? " present_value" : " past_value" ) << " [pastValueOffset + (w + local_id.y) * uniforms.N];\n "
351351 << " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n "
352352 << " tileK[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];\n "
@@ -400,7 +400,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
400400 const int components = parameters.v_head_size_ % 4 == 0 ? 4 : (parameters.v_head_size_ % 2 == 0 ? 2 : 1 );
401401 constexpr int tile_size = 12 ;
402402 int tile_n_size = tile_size * components;
403- VxAttentionScoreProgram program{" VxAttentionScore" , feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_ , parameters. n_reps , seqlen_k, parameters.past_present_share_buffer_ };
403+ VxAttentionScoreProgram program{" VxAttentionScore" , feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_ , seqlen_k, parameters.past_present_share_buffer_ };
404404 program.AddInputs ({{probs, ProgramTensorMetadataDependency::TypeAndRank},
405405 {V, ProgramTensorMetadataDependency::TypeAndRank, components}});
406406 if (feed_past_value) {
0 commit comments