Commit 81a8920

[webgpu] Use 1d dispatch group size (microsoft#24084)
This PR uses a 1d dispatch group size and uses workgroup_idx instead of workgroup_id.x/workgroup_id.y, since those values may be normalized.
1 parent 8d21bf7 commit 81a8920
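
The idea behind the change: instead of Dispatch(x, y, 1), the program is dispatched with a single flattened group count x * y, and the shader derives its logical (x, y) coordinates from workgroup_idx with an integer division and a modulo, so the indices stay correct even if the backend re-shapes the dispatch dimensions. A minimal stand-alone sketch of that mapping (plain C++, outside the onnxruntime Program API; dim_x and dim_y are illustrative names):

#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical 2d grid that used to be Dispatch(dim_x, dim_y, 1).
  const uint32_t dim_x = 8;                    // e.g. num_heads
  const uint32_t dim_y = 4;                    // e.g. number of sequence tiles
  const uint32_t group_count = dim_x * dim_y;  // the new 1d dispatch size

  for (uint32_t workgroup_idx = 0; workgroup_idx < group_count; ++workgroup_idx) {
    const uint32_t x = workgroup_idx / dim_y;  // what workgroup_id.x used to provide
    const uint32_t y = workgroup_idx % dim_y;  // what workgroup_id.y used to provide
    assert(x < dim_x && y < dim_y);            // every (x, y) pair is covered exactly once
  }
  return 0;
}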

4 files changed: +17 / -12 lines

onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Lines changed: 6 additions & 4 deletions
@@ -224,12 +224,12 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
   // Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / workgroup_size_x, 1)
   // Each lane/thread is responsible for a single q.
   shader.MainFunctionBody() << R"MAIN_FN(
-  let head_idx = workgroup_id.x;
+  let head_idx = u32(workgroup_idx / uniforms.num_seq_tile);
   let capped_sg_id = min(sg_id, max_k_step);
   let capped_sg_size = min(sg_size, max_k_step);
 
   // Load Q
-  let q_idx_global = workgroup_id.y * workgroup_size_x + local_idx;
+  let q_idx_global = (workgroup_idx % uniforms.num_seq_tile) * workgroup_size_x + local_idx;
   let valid_q = q_idx_global < uniforms.new_sequence_length;
   if (valid_q)
   {
@@ -445,7 +445,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   std::string cache_hint = std::to_string(has_attention_bias) +
                            std::to_string(parameters.head_size_) +
                            std::to_string(parameters.num_heads_);
-  program.SetDispatchGroupSize(parameters.num_heads_, (parameters.sequence_length_ + tile_size - 1) / tile_size, 1)
+  const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
+  program.SetDispatchGroupSize(parameters.num_heads_ * num_seq_tile)
       .SetWorkgroupSize(tile_size)
       .CacheHint(cache_hint)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
@@ -454,7 +455,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
                             {static_cast<uint32_t>(parameters.total_sequence_length_ - parameters.kv_sequence_length_)},
                             {static_cast<uint32_t>(parameters.is_gqa_ ? 1 : 0)},
                             {static_cast<uint32_t>(parameters.n_reps)},
-                            {alpha}});
+                            {alpha},
+                            {num_seq_tile}});
 
   return context.RunProgram(program);
 }
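
For the flash-attention program the flattened count is num_heads * num_seq_tile, and the shader recovers head_idx and the sequence tile from workgroup_idx before computing q_idx_global. A quick host-side sanity check of that arithmetic (plain C++; num_heads, tile_size, and new_sequence_length are illustrative values, not taken from the commit):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t num_heads = 12;              // illustrative
  const uint32_t tile_size = 64;              // plays the role of workgroup_size_x
  const uint32_t new_sequence_length = 200;   // illustrative
  const uint32_t num_seq_tile = (new_sequence_length + tile_size - 1) / tile_size;  // ceil-div, as in the commit

  for (uint32_t workgroup_idx = 0; workgroup_idx < num_heads * num_seq_tile; ++workgroup_idx) {
    const uint32_t head_idx = workgroup_idx / num_seq_tile;  // replaces workgroup_id.x
    const uint32_t seq_tile = workgroup_idx % num_seq_tile;  // replaces workgroup_id.y
    assert(head_idx < num_heads && seq_tile < num_seq_tile);
    for (uint32_t local_idx = 0; local_idx < tile_size; ++local_idx) {
      const uint32_t q_idx_global = seq_tile * tile_size + local_idx;
      const bool valid_q = q_idx_global < new_sequence_length;  // the tail tile is partially masked
      (void)valid_q;
    }
  }
  return 0;
}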

onnxruntime/contrib_ops/webgpu/bert/flash_attention.h

Lines changed: 2 additions & 1 deletion
@@ -52,7 +52,8 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
                                           {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"is_gqa", ProgramUniformVariableDataType::Uint32},
                                           {"n_reps", ProgramUniformVariableDataType::Uint32},
-                                          {"alpha", ProgramUniformVariableDataType::Float32});
+                                          {"alpha", ProgramUniformVariableDataType::Float32},
+                                          {"num_seq_tile", ProgramUniformVariableDataType::Uint32});
 
  private:
   bool has_attention_bias_;

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc

Lines changed: 7 additions & 6 deletions
@@ -138,8 +138,8 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   shader.MainFunctionBody() << R"MAIN_FN(
     // During the load phase we use all 256 threads to load 64 rows of A/B.
     // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K.
-    let a_global_base = workgroup_id.x * tile_size;
-    let b_global_base = workgroup_id.y * tile_size;
+    let a_global_base = u32(workgroup_idx / uniforms.num_N_tile) * tile_size;
+    let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size;
     let load_AorB = u32(local_idx/128);
     let load_row = u32((local_idx%128)/2);
     let load_col = u32(local_idx%2);
@@ -275,11 +275,11 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
 
   constexpr uint32_t kTileSize = 64;
   TensorShape reshaped_y_shape{1, M, N / kVec4Components};
+  uint32_t num_M_tile = (M + kTileSize - 1) / kTileSize;
+  uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize;
   DP4AMatMulNBitsProgram mul_program{block_size};
   mul_program.SetWorkgroupSize(256);
-  mul_program.SetDispatchGroupSize(
-      (M + kTileSize - 1) / kTileSize,
-      (N + kTileSize - 1) / kTileSize, 1);
+  mul_program.SetDispatchGroupSize(num_M_tile * num_N_tile);
   mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
                          {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1},
                          {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec2Components * kU32Components)},
@@ -288,7 +288,8 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
                          {static_cast<uint32_t>(N)},
                          {static_cast<uint32_t>(K)},
                          {static_cast<uint32_t>(K / 8)},
-                         {static_cast<uint32_t>(K / 16)}})
+                         {static_cast<uint32_t>(K / 16)},
+                         {num_N_tile}})
       .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast<int>(kVec4Components)})
       .CacheHint("Block" + std::to_string(block_size));
   return context.RunProgram(mul_program);
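
In the DP4A matmul kernel the M-by-N output is covered by 64x64 tiles, the dispatch count becomes num_M_tile * num_N_tile, and the shader recovers its row and column tile from workgroup_idx. A small stand-alone sketch of the tiling arithmetic (plain C++; M and N are illustrative values, not taken from the commit):

#include <cassert>
#include <cstdint>

int main() {
  constexpr uint32_t kTileSize = 64;
  const uint32_t M = 300;  // illustrative
  const uint32_t N = 500;  // illustrative

  // Ceil-division, matching num_M_tile / num_N_tile in the commit.
  const uint32_t num_M_tile = (M + kTileSize - 1) / kTileSize;
  const uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize;

  for (uint32_t workgroup_idx = 0; workgroup_idx < num_M_tile * num_N_tile; ++workgroup_idx) {
    // Row tile along M (was workgroup_id.x), column tile along N (was workgroup_id.y).
    const uint32_t a_global_base = (workgroup_idx / num_N_tile) * kTileSize;
    const uint32_t b_global_base = (workgroup_idx % num_N_tile) * kTileSize;
    assert(a_global_base < num_M_tile * kTileSize);
    assert(b_global_base < num_N_tile * kTileSize);
  }
  return 0;
}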

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h

Lines changed: 2 additions & 1 deletion
@@ -27,7 +27,8 @@ class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
                                           {"N", ProgramUniformVariableDataType::Uint32},
                                           {"K", ProgramUniformVariableDataType::Uint32},
                                           {"K8", ProgramUniformVariableDataType::Uint32},
-                                          {"K16", ProgramUniformVariableDataType::Uint32});
+                                          {"K16", ProgramUniformVariableDataType::Uint32},
+                                          {"num_N_tile", ProgramUniformVariableDataType::Uint32});
 
  private:
   uint32_t block_size_;
