@@ -99,31 +99,32 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
9999 << " var<workgroup> tileK: array<key_value_t, " << tile_size_ * tile_size_ << " >;\n "
100100 << " alias f32_val_t = " << (components_ == 4 ? " vec4<f32>" : (components_ == 2 ? " vec2<f32>" : " f32" )) << " ;\n " ;
101101 shader.MainFunctionBody () << " // x holds the N and y holds the M\n "
102- << " let m = workgroup_id.y * TILE_SIZE;\n "
103- << " let n = workgroup_id.x * TILE_SIZE;\n "
104- << " let batch_idx = workgroup_id.z / uniforms.num_heads;\n "
105- << " let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n "
102+ << " let m = u32(workgroup_idx / uniforms.num_total_seq_length_tile) % uniforms.num_seq_length_tile * TILE_SIZE;\n "
103+ << " let n = (workgroup_idx % uniforms.num_total_seq_length_tile) * TILE_SIZE;\n "
104+ << " let batch_head_idx = u32(workgroup_idx / (uniforms.num_total_seq_length_tile * uniforms.num_seq_length_tile));\n "
105+ << " let batch_idx = batch_head_idx / uniforms.num_heads;\n "
106+ << " let qOffset = batch_head_idx * uniforms.M * uniforms.K + m * uniforms.K;\n "
106107 << " let sequence_length = uniforms.M;\n "
107108 << " var total_sequence_length = uniforms.N;\n " ;
108109 std::ostringstream oss;
109110 InitVarStub (oss, seqlen_k_);
110111 shader.MainFunctionBody () << oss.str ();
111- shader.MainFunctionBody () << " let kOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n " ;
112+ shader.MainFunctionBody () << " let kOffset = (batch_head_idx / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n " ;
112113 if (has_present_key_) {
113- shader.MainFunctionBody () << " let presentKeyOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.K;\n " ;
114+ shader.MainFunctionBody () << " let presentKeyOffset = (batch_head_idx / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.K;\n " ;
114115 }
115116
116117 shader.MainFunctionBody () << " var value = f32_val_t(0);\n "
117118 " for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n "
118- " if (global_id .y < uniforms.M && w + local_id.x < uniforms.K) {\n "
119+ " if (m + local_id .y < uniforms.M && w + local_id.x < uniforms.K) {\n "
119120 " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n "
120121 " }\n "
121122 " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n "
122123 " var idx = TILE_SIZE * local_id.y + local_id.x;\n " ;
123124
124125 if ((feed_past_key_ && has_present_key_) || (past_present_share_buffer_ && !is_first_prompt_)) {
125126 shader.MainFunctionBody () << " if (n + local_id.y < past_sequence_length) {\n "
126- << " let pastKeyOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.past_sequence_length * uniforms.K;\n "
127+ << " let pastKeyOffset = (batch_head_idx / uniforms.n_reps) * uniforms.past_sequence_length * uniforms.K;\n "
127128 << " tileK[idx] = " << (past_present_share_buffer_ ? " present_key" : " past_key" ) << " [pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n "
128129 << " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n "
129130 << " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n "
@@ -152,9 +153,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
152153 << " workgroupBarrier();\n "
153154 << " }\n " ;
154155
155- shader.MainFunctionBody () << " if (global_id .y < uniforms.M && global_id .x < total_sequence_length) {\n "
156- << " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n "
157- << " let outputIdx = headOffset + global_id .y * uniforms.N + global_id .x;\n "
156+ shader.MainFunctionBody () << " if (m + local_id .y < uniforms.M && n + local_id .x < total_sequence_length) {\n "
157+ << " let headOffset = batch_head_idx * uniforms.M * uniforms.N;\n "
158+ << " let outputIdx = headOffset + (m + local_id .y) * uniforms.N + n + local_id .x;\n "
158159 << " var sum: f32 = " << (components_ == 4 ? " value.x + value.y + value.z + value.w" : (components_ == 2 ? " value.x + value.y" : " value" )) << " ;\n " ;
159160
160161 shader.MainFunctionBody () << " output[outputIdx] = output_value_t(sum * uniforms.alpha)" ;
@@ -199,9 +200,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
199200 }
200201
201202 const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1 ) / components;
202- program. SetDispatchGroupSize (( total_sequence_length + tile_size - 1 ) / tile_size,
203- (parameters.sequence_length_ + tile_size - 1 ) / tile_size,
204- parameters.batch_size_ * parameters.num_heads_ )
203+ const uint32_t num_total_seq_length_tile = ( total_sequence_length + tile_size - 1 ) / tile_size;
204+ const uint32_t num_seq_length_tile = (parameters.sequence_length_ + tile_size - 1 ) / tile_size;
205+ program. SetDispatchGroupSize ( parameters.batch_size_ * parameters.num_heads_ * num_seq_length_tile * num_total_seq_length_tile )
205206 .SetWorkgroupSize (tile_size, tile_size)
206207 .CacheHint (std::to_string (tile_size), parameters.past_present_share_buffer_ , feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr , components, parameters.is_first_prompt_ )
207208 .AddUniformVariables ({{static_cast <uint32_t >(parameters.sequence_length_ )},
@@ -214,7 +215,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
214215 {static_cast <uint32_t >(parameters.kv_sequence_length_ )},
215216 {static_cast <uint32_t >(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_ )},
216217 {static_cast <uint32_t >(parameters.n_reps )},
217- {static_cast <uint32_t >(parameters.is_first_prompt_ ? 1 : 0 )}})
218+ {static_cast <uint32_t >(parameters.is_first_prompt_ ? 1 : 0 )},
219+ {num_total_seq_length_tile},
220+ {num_seq_length_tile}})
218221 .SetOverridableConstants ({{static_cast <uint32_t >(tile_size)}});
219222
220223 return context.RunProgram (program);
@@ -228,15 +231,15 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
228231 shader.AdditionalImplementation () << " var<workgroup> thread_max: array<f32, " << work_group_size_ << " >;\n "
229232 << " var<workgroup> thread_sum: array<f32, " << work_group_size_ << " >;\n "
230233 << " alias f32_val_t = " << (components_ == 4 ? " vec4<f32>" : (components_ == 2 ? " vec2<f32>" : " f32" )) << " ;\n " ;
231- shader.MainFunctionBody () << " let batch_idx = workgroup_id.z / uniforms.num_heads ;\n "
232- << " let sequence_length = uniforms.sequence_length ;\n "
234+ shader.MainFunctionBody () << " let sequence_length = uniforms.sequence_length ;\n "
235+ << " let batch_idx = u32(workgroup_idx / sequence_length) / uniforms.num_heads ;\n "
233236 << " var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << " ;\n " ;
234237 std::ostringstream oss;
235238 InitVarStub (oss, seqlen_k_);
236239 shader.MainFunctionBody () << oss.str ()
237240 << " let local_offset = local_idx * uniforms.elements_per_thread;\n "
238- << " let offset = (global_idx / " << work_group_size_ << " ) * uniforms.total_sequence_length_comp + local_offset;\n "
239- << " let seq_causal_length = " << (seqlen_k_ ? " past_sequence_length + workgroup_id.y + 1" : " uniforms.total_sequence_length_comp" ) << " ;\n "
241+ << " let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n "
242+ << " let seq_causal_length = " << (seqlen_k_ ? " past_sequence_length + workgroup_idx % sequence_length + 1" : " uniforms.total_sequence_length_comp" ) << " ;\n "
240243 << " var thread_max_vector = f32_val_t(-3.402823e+38f);\n "
241244 << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n "
242245 << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n "
@@ -292,7 +295,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
292295 }
293296 program.AddOutputs ({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
294297 .CacheHint (work_group_size)
295- .SetDispatchGroupSize (1 , sequence_length, batch_size * num_heads)
298+ .SetDispatchGroupSize (batch_size * num_heads * sequence_length )
296299 .SetWorkgroupSize (work_group_size)
297300 .AddUniformVariables ({{static_cast <uint32_t >(batch_size)},
298301 {static_cast <uint32_t >(num_heads)},
@@ -321,19 +324,20 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
321324
322325 shader.AdditionalImplementation () << " var<workgroup> tileQ: array<probs_value_t, " << tile_size_ * tile_size_ << " >;\n "
323326 << " var<workgroup> tileK: array<v_value_t, " << tile_size_ * tile_size_ << " >;\n " ;
324- shader.MainFunctionBody () << " let head_idx = workgroup_id.z % uniforms.num_heads;\n "
325- << " let batch_idx = workgroup_id.z / uniforms.num_heads;\n "
326- << " let m = global_id.y;\n "
327- << " let n = global_id.x;\n "
328- << " let offsetA = workgroup_id.z * (uniforms.M * uniforms.K) + m * uniforms.K;\n "
327+ shader.MainFunctionBody () << " let batch_head_idx = u32(workgroup_idx / (uniforms.num_head_size_tile * uniforms.num_seq_length_tile));\n "
328+ << " let head_idx = batch_head_idx % uniforms.num_heads;\n "
329+ << " let batch_idx = batch_head_idx / uniforms.num_heads;\n "
330+ << " let m = (u32(workgroup_idx / uniforms.num_head_size_tile) % uniforms.num_seq_length_tile) * TILE_SIZE + local_id.y;\n "
331+ << " let n = (workgroup_idx % uniforms.num_head_size_tile) * TILE_SIZE + local_id.x;\n "
332+ << " let offsetA = batch_head_idx * (uniforms.M * uniforms.K) + m * uniforms.K;\n "
329333 << " let sequence_length = uniforms.M;\n "
330334 << " var total_sequence_length = uniforms.K;\n " ;
331335 std::ostringstream oss;
332336 InitVarStub (oss, seqlen_k_);
333337 shader.MainFunctionBody () << oss.str ();
334- shader.MainFunctionBody () << " let vOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.kv_sequence_length + n;\n " ;
338+ shader.MainFunctionBody () << " let vOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.kv_sequence_length + n;\n " ;
335339 if (has_present_value_) {
336- shader.MainFunctionBody () << " let presentValueOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.present_sequence_length + n;\n " ;
340+ shader.MainFunctionBody () << " let presentValueOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.present_sequence_length + n;\n " ;
337341 }
338342
339343 shader.MainFunctionBody () << " var value = output_value_t(0);\n "
@@ -346,7 +350,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
346350
347351 if ((feed_past_value_ && has_present_value_) || (past_present_share_buffer_ && !is_first_prompt_)) {
348352 shader.MainFunctionBody () << " if (w + local_id.y < past_sequence_length) {\n "
349- << " let pastValueOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.past_sequence_length + n;\n "
353+ << " let pastValueOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.past_sequence_length + n;\n "
350354 << " tileK[idx] = " << (past_present_share_buffer_ ? " present_value" : " past_value" ) << " [pastValueOffset + (w + local_id.y) * uniforms.N];\n "
351355 << " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n "
352356 << " tileK[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];\n "
@@ -414,9 +418,9 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
414418 program.AddOutput ({present_value, ProgramTensorMetadataDependency::TypeAndRank, components});
415419 }
416420
417- program. SetDispatchGroupSize (( parameters.v_head_size_ + tile_n_size - 1 ) / tile_n_size,
418- (parameters.sequence_length_ + tile_size - 1 ) / tile_size,
419- parameters.batch_size_ * parameters.num_heads_ )
421+ const uint32_t num_head_size_tile = ( parameters.v_head_size_ + tile_n_size - 1 ) / tile_n_size;
422+ const uint32_t num_seq_length_tile = (parameters.sequence_length_ + tile_size - 1 ) / tile_size;
423+ program. SetDispatchGroupSize ( parameters.batch_size_ * parameters.num_heads_ * num_head_size_tile * num_seq_length_tile )
420424 .CacheHint (std::to_string (tile_size), parameters.past_present_share_buffer_ , feed_past_value, has_present_value, seqlen_k != nullptr , parameters.is_first_prompt_ )
421425 .SetWorkgroupSize (tile_size, tile_size)
422426 .AddUniformVariables ({{static_cast <uint32_t >(parameters.sequence_length_ )},
@@ -429,7 +433,9 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
429433 {static_cast <uint32_t >(parameters.kv_sequence_length_ )},
430434 {static_cast <uint32_t >(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_ )},
431435 {static_cast <uint32_t >(parameters.n_reps )},
432- {static_cast <uint32_t >(parameters.is_first_prompt_ )}})
436+ {static_cast <uint32_t >(parameters.is_first_prompt_ )},
437+ {num_head_size_tile},
438+ {num_seq_length_tile}})
433439 .SetOverridableConstants ({{static_cast <uint32_t >(tile_size)}});
434440
435441 return context.RunProgram (program);
0 commit comments