Skip to content

Commit a1186f6

Browse files
authored
[webgpu] Use 1D dispatch groups for attention (microsoft#24228)
This PR uses 1d disptach group size and uses workgroup_idx instead of workgroup.x|workgroup.y in case they are normalized.
1 parent d6df4f2 commit a1186f6

File tree

2 files changed

+44
-34
lines changed

2 files changed

+44
-34
lines changed

onnxruntime/contrib_ops/webgpu/bert/attention.cc

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -99,31 +99,32 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
9999
<< "var<workgroup> tileK: array<key_value_t, " << tile_size_ * tile_size_ << ">;\n"
100100
<< "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n";
101101
shader.MainFunctionBody() << "// x holds the N and y holds the M\n"
102-
<< "let m = workgroup_id.y * TILE_SIZE;\n"
103-
<< "let n = workgroup_id.x * TILE_SIZE;\n"
104-
<< "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
105-
<< "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n"
102+
<< "let m = u32(workgroup_idx / uniforms.num_total_seq_length_tile) % uniforms.num_seq_length_tile * TILE_SIZE;\n"
103+
<< "let n = (workgroup_idx % uniforms.num_total_seq_length_tile) * TILE_SIZE;\n"
104+
<< "let batch_head_idx = u32(workgroup_idx / (uniforms.num_total_seq_length_tile * uniforms.num_seq_length_tile));\n"
105+
<< "let batch_idx = batch_head_idx / uniforms.num_heads;\n"
106+
<< "let qOffset = batch_head_idx * uniforms.M * uniforms.K + m * uniforms.K;\n"
106107
<< "let sequence_length = uniforms.M;\n"
107108
<< "var total_sequence_length = uniforms.N;\n";
108109
std::ostringstream oss;
109110
InitVarStub(oss, seqlen_k_);
110111
shader.MainFunctionBody() << oss.str();
111-
shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n";
112+
shader.MainFunctionBody() << "let kOffset = (batch_head_idx / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n";
112113
if (has_present_key_) {
113-
shader.MainFunctionBody() << "let presentKeyOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.K;\n";
114+
shader.MainFunctionBody() << "let presentKeyOffset = (batch_head_idx / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.K;\n";
114115
}
115116

116117
shader.MainFunctionBody() << "var value = f32_val_t(0);\n"
117118
"for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
118-
" if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n"
119+
" if (m + local_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n"
119120
" tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n"
120121
" }\n"
121122
" if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n"
122123
" var idx = TILE_SIZE * local_id.y + local_id.x;\n";
123124

124125
if ((feed_past_key_ && has_present_key_) || (past_present_share_buffer_ && !is_first_prompt_)) {
125126
shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n"
126-
<< " let pastKeyOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.past_sequence_length * uniforms.K;\n"
127+
<< " let pastKeyOffset = (batch_head_idx / uniforms.n_reps) * uniforms.past_sequence_length * uniforms.K;\n"
127128
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
128129
<< " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
129130
<< " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
@@ -152,9 +153,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
152153
<< " workgroupBarrier();\n"
153154
<< "}\n";
154155

155-
shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n"
156-
<< " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n"
157-
<< " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n"
156+
shader.MainFunctionBody() << "if (m + local_id.y < uniforms.M && n + local_id.x < total_sequence_length) {\n"
157+
<< " let headOffset = batch_head_idx * uniforms.M * uniforms.N;\n"
158+
<< " let outputIdx = headOffset + m + local_id.y * uniforms.N + n + local_id.x;\n"
158159
<< " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n";
159160

160161
shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)";
@@ -199,9 +200,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
199200
}
200201

201202
const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components;
202-
program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size,
203-
(parameters.sequence_length_ + tile_size - 1) / tile_size,
204-
parameters.batch_size_ * parameters.num_heads_)
203+
const uint32_t num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
204+
const uint32_t num_seq_length_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
205+
program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_length_tile * num_total_seq_length_tile)
205206
.SetWorkgroupSize(tile_size, tile_size)
206207
.CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_)
207208
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
@@ -214,7 +215,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
214215
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
215216
{static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
216217
{static_cast<uint32_t>(parameters.n_reps)},
217-
{static_cast<uint32_t>(parameters.is_first_prompt_ ? 1 : 0)}})
218+
{static_cast<uint32_t>(parameters.is_first_prompt_ ? 1 : 0)},
219+
{num_total_seq_length_tile},
220+
{num_seq_length_tile}})
218221
.SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});
219222

220223
return context.RunProgram(program);
@@ -228,15 +231,15 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
228231
shader.AdditionalImplementation() << "var<workgroup> thread_max: array<f32, " << work_group_size_ << ">;\n"
229232
<< "var<workgroup> thread_sum: array<f32, " << work_group_size_ << ">;\n"
230233
<< "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n";
231-
shader.MainFunctionBody() << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
232-
<< "let sequence_length = uniforms.sequence_length;\n"
234+
shader.MainFunctionBody() << "let sequence_length = uniforms.sequence_length;\n"
235+
<< "let batch_idx = u32(workgroup_idx / sequence_length) / uniforms.num_heads;\n"
233236
<< "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n";
234237
std::ostringstream oss;
235238
InitVarStub(oss, seqlen_k_);
236239
shader.MainFunctionBody() << oss.str()
237240
<< "let local_offset = local_idx * uniforms.elements_per_thread;\n"
238-
<< "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n"
239-
<< "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "uniforms.total_sequence_length_comp") << ";\n"
241+
<< "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n"
242+
<< "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n"
240243
<< "var thread_max_vector = f32_val_t(-3.402823e+38f);\n"
241244
<< "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
242245
<< " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n"
@@ -292,7 +295,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
292295
}
293296
program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
294297
.CacheHint(work_group_size)
295-
.SetDispatchGroupSize(1, sequence_length, batch_size * num_heads)
298+
.SetDispatchGroupSize(batch_size * num_heads * sequence_length)
296299
.SetWorkgroupSize(work_group_size)
297300
.AddUniformVariables({{static_cast<uint32_t>(batch_size)},
298301
{static_cast<uint32_t>(num_heads)},
@@ -321,19 +324,20 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
321324

322325
shader.AdditionalImplementation() << "var<workgroup> tileQ: array<probs_value_t, " << tile_size_ * tile_size_ << ">;\n"
323326
<< "var<workgroup> tileK: array<v_value_t, " << tile_size_ * tile_size_ << ">;\n";
324-
shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n"
325-
<< "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
326-
<< "let m = global_id.y;\n"
327-
<< "let n = global_id.x;\n"
328-
<< "let offsetA = workgroup_id.z * (uniforms.M * uniforms.K) + m * uniforms.K;\n"
327+
shader.MainFunctionBody() << "let batch_head_idx = u32(workgroup_idx / (uniforms.num_head_size_tile * uniforms.num_seq_length_tile));\n"
328+
<< "let head_idx = batch_head_idx % uniforms.num_heads;\n"
329+
<< "let batch_idx = batch_head_idx / uniforms.num_heads;\n"
330+
<< "let m = (u32(workgroup_idx / uniforms.num_head_size_tile) % uniforms.num_seq_length_tile) * TILE_SIZE + local_id.y;\n"
331+
<< "let n = (workgroup_idx % uniforms.num_head_size_tile) * TILE_SIZE + local_id.x;\n"
332+
<< "let offsetA = batch_head_idx * (uniforms.M * uniforms.K) + m * uniforms.K;\n"
329333
<< "let sequence_length = uniforms.M;\n"
330334
<< "var total_sequence_length = uniforms.K;\n";
331335
std::ostringstream oss;
332336
InitVarStub(oss, seqlen_k_);
333337
shader.MainFunctionBody() << oss.str();
334-
shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.kv_sequence_length + n;\n";
338+
shader.MainFunctionBody() << "let vOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.kv_sequence_length + n;\n";
335339
if (has_present_value_) {
336-
shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.present_sequence_length + n;\n";
340+
shader.MainFunctionBody() << "let presentValueOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.present_sequence_length + n;\n";
337341
}
338342

339343
shader.MainFunctionBody() << "var value = output_value_t(0);\n"
@@ -346,7 +350,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
346350

347351
if ((feed_past_value_ && has_present_value_) || (past_present_share_buffer_ && !is_first_prompt_)) {
348352
shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n"
349-
<< " let pastValueOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.past_sequence_length + n;\n"
353+
<< " let pastValueOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.past_sequence_length + n;\n"
350354
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n"
351355
<< " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
352356
<< " tileK[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];\n"
@@ -414,9 +418,9 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
414418
program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank, components});
415419
}
416420

417-
program.SetDispatchGroupSize((parameters.v_head_size_ + tile_n_size - 1) / tile_n_size,
418-
(parameters.sequence_length_ + tile_size - 1) / tile_size,
419-
parameters.batch_size_ * parameters.num_heads_)
421+
const uint32_t num_head_size_tile = (parameters.v_head_size_ + tile_n_size - 1) / tile_n_size;
422+
const uint32_t num_seq_length_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
423+
program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_head_size_tile * num_seq_length_tile)
420424
.CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_)
421425
.SetWorkgroupSize(tile_size, tile_size)
422426
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
@@ -429,7 +433,9 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
429433
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
430434
{static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
431435
{static_cast<uint32_t>(parameters.n_reps)},
432-
{static_cast<uint32_t>(parameters.is_first_prompt_)}})
436+
{static_cast<uint32_t>(parameters.is_first_prompt_)},
437+
{num_head_size_tile},
438+
{num_seq_length_tile}})
433439
.SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});
434440

435441
return context.RunProgram(program);

onnxruntime/contrib_ops/webgpu/bert/attention.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
5050
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
5151
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
5252
{"n_reps", ProgramUniformVariableDataType::Uint32},
53-
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});
53+
{"is_first_prompt", ProgramUniformVariableDataType::Uint32},
54+
{"num_total_seq_length_tile", ProgramUniformVariableDataType::Uint32},
55+
{"num_seq_length_tile", ProgramUniformVariableDataType::Uint32});
5456

5557
WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
5658

@@ -105,7 +107,9 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
105107
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
106108
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
107109
{"n_reps", ProgramUniformVariableDataType::Uint32},
108-
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});
110+
{"is_first_prompt", ProgramUniformVariableDataType::Uint32},
111+
{"num_head_size_tile", ProgramUniformVariableDataType::Uint32},
112+
{"num_seq_length_tile", ProgramUniformVariableDataType::Uint32});
109113

110114
WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
111115

0 commit comments

Comments
 (0)