
Commit 99975c1

qjia7 authored and fs-eire committed
[webgpu] Enable indirect dispatch for flash attention (#26207)
This pull request adds support for indirect dispatch to the WebGPU FlashAttention implementation, enabling kernel launches whose workgroup counts are determined at runtime from the actual sequence lengths. The changes propagate sequence-length information and indirect dispatch buffers through the attention pipeline, with conditional code paths that keep the existing direct-dispatch behavior intact. This is part of the work to enable graph capture for phi4 (#25868).
1 parent 3153871 commit 99975c1
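Note: "indirect dispatch" here means the workgroup count is read from a GPU buffer when the dispatch executes, rather than being computed on the CPU when the command is encoded; that is what lets the kernels be replayed under graph capture while the sequence length changes. A minimal sketch of the distinction at the WebGPU API level, using the standard wgpu C++ API rather than ONNX Runtime's Program abstraction (the pipeline, bind group, and argument-buffer setup are assumed):

// Sketch only: direct vs. indirect dispatch with the WebGPU C++ API.
// Assumes `pipeline`, `bindGroup`, and `indirectArgs` were created elsewhere.
#include <webgpu/webgpu_cpp.h>

void EncodeAttentionPass(wgpu::CommandEncoder& encoder,
                         wgpu::ComputePipeline pipeline,
                         wgpu::BindGroup bindGroup,
                         wgpu::Buffer indirectArgs,       // 3 x u32 (x, y, z), written on the GPU
                         uint32_t direct_num_workgroups,  // CPU-computed fallback
                         bool use_indirect_dispatch) {
  wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
  pass.SetPipeline(pipeline);
  pass.SetBindGroup(0, bindGroup);
  if (use_indirect_dispatch) {
    // Workgroup counts come from a GPU buffer at execution time, so the recorded
    // command stream no longer depends on the CPU-side sequence length.
    pass.DispatchWorkgroupsIndirect(indirectArgs, /*indirectOffset=*/0);
  } else {
    // Existing path: the workgroup count is baked in when the pass is encoded.
    pass.DispatchWorkgroups(direct_num_workgroups);
  }
  pass.End();
}

The rest of this commit threads the information needed to fill such an argument buffer (seqlens_k, tile sizes, head counts) through the FlashAttention programs.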

File tree: 7 files changed (+193, −64 lines)


onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Lines changed: 137 additions & 38 deletions
Large diffs are not rendered by default.

onnxruntime/contrib_ops/webgpu/bert/flash_attention.h

Lines changed: 19 additions & 12 deletions
@@ -17,19 +17,24 @@ using namespace onnxruntime::webgpu;
 
 class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
  public:
-  CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, bool past_present_share_buffer)
-      : Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), past_present_share_buffer_(past_present_share_buffer) {
+  CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, bool past_present_share_buffer,
+                     bool prepare_indirect_dispatch = false)
+      : Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), past_present_share_buffer_(past_present_share_buffer), prepare_indirect_dispatch_(prepare_indirect_dispatch) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
 
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"copy_size", ProgramUniformVariableDataType::Uint32},
-                                          {"past_sequence_length", ProgramUniformVariableDataType::Uint32});
+                                          {"total_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"tile_size", ProgramUniformVariableDataType::Uint32},
+                                          {"num_heads", ProgramUniformVariableDataType::Uint32});
 
  private:
   bool has_past_;
   bool kv_BNSH_;
   bool past_present_share_buffer_;
+  bool prepare_indirect_dispatch_;
 };
 
 class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
@@ -75,8 +80,8 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
 class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecodeQKTProgram> {
  public:
   FlashAttentionDecodeQKTProgram(const std::string& kernel_name,
-                                 bool has_attention_bias, uint32_t tile_size)
-      : Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size) {
+                                 bool has_attention_bias, uint32_t tile_size, bool use_indirect_dispatch)
+      : Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -86,19 +91,19 @@ class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecode
                                           {"alpha", ProgramUniformVariableDataType::Float32},
                                           {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"n_reps", ProgramUniformVariableDataType::Uint32},
-                                          {"num_total_seq_length_tile", ProgramUniformVariableDataType::Uint32},
                                           {"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
                                           {"num_heads", ProgramUniformVariableDataType::Uint32});
 
  private:
   bool has_attention_bias_;
   uint32_t tile_size_;
+  bool use_indirect_dispatch_;
 };
 
 class FlashAttentionDecodeSplitVxProgram final : public Program<FlashAttentionDecodeSplitVxProgram> {
  public:
-  FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec)
-      : Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec) {
+  FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch)
+      : Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -107,19 +112,19 @@ class FlashAttentionDecodeSplitVxProgram final : public Program<FlashAttentionDe
                                           {"head_size_vec", ProgramUniformVariableDataType::Uint32},
                                           {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"n_reps", ProgramUniformVariableDataType::Uint32},
-                                          {"num_total_seq_length_tile", ProgramUniformVariableDataType::Uint32},
                                           {"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
                                           {"num_heads", ProgramUniformVariableDataType::Uint32});
 
  private:
   uint32_t tile_size_;
   int head_size_vec_;
+  bool use_indirect_dispatch_;
 };
 
 class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionDecodeVxReduceProgram> {
  public:
-  FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size)
-      : Program{kernel_name}, tile_size_(tile_size) {
+  FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch)
+      : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -132,11 +137,13 @@ class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionD
 
  private:
   uint32_t tile_size_;
+  uint32_t seq_tile_size_;
+  bool use_indirect_dispatch_;
 };
 
 Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                            Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
-                           const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
+                           const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr);
 
 bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
                             const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
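The flash_attention.cc changes that actually fill the indirect dispatch buffer are not rendered above. The sketch below only illustrates the 3×u32 argument layout WebGPU expects and the workgroup count implied by the decode shaders (one workgroup per sequence-length tile per head); the helper name and exact computation are assumptions for illustration, not code from this commit.

// Layout of a WebGPU indirect dispatch buffer plus the tile math implied by the
// new tile_size/num_heads uniforms above. Assumed helper, not ORT code.
#include <cstdint>

struct DispatchIndirectArgs {
  uint32_t workgroups_x;
  uint32_t workgroups_y;
  uint32_t workgroups_z;
};

inline DispatchIndirectArgs MakeDecodeDispatchArgs(uint32_t total_sequence_length,
                                                   uint32_t tile_size,
                                                   uint32_t num_heads) {
  // One workgroup per (sequence-length tile, head) pair, matching
  // head_idx = workgroup_idx / num_total_seq_length_tile in the decode shaders.
  const uint32_t num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
  return {num_total_seq_length_tile * num_heads, 1u, 1u};
}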

onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template

Lines changed: 10 additions & 4 deletions
@@ -5,6 +5,7 @@
 #param tile_size
 #param tile_size_k_vec
 #param sub_tile_count
+#param use_indirect_dispatch
 
 // Note that this shader adopts similar algorithm with dp4a generation shader.
 //
@@ -48,10 +49,15 @@ var<workgroup> tile_qk: array<q_element_t, tile_size>;
 $MAIN {
   let local_row = u32(local_idx / tile_size_k_vec);
   let local_col = local_idx % tile_size_k_vec;
-  let total_seq_offset = (workgroup_idx % uniforms.num_total_seq_length_tile) * tile_size;
-  let head_idx = u32(workgroup_idx / uniforms.num_total_seq_length_tile);
+#if use_indirect_dispatch
+  let total_sequence_length = u32(seqlens_k[0]) + 1u;
+#else
+  let total_sequence_length = uniforms.total_sequence_length;
+#endif
+  let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
+  let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size;
+  let head_idx = u32(workgroup_idx / num_total_seq_length_tile);
   let q_offset = head_idx * uniforms.head_size_vec;
-  var total_sequence_length = uniforms.total_sequence_length;
   let present_offset = u32(head_idx / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.head_size_vec;
   for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) {
     if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) {
@@ -95,7 +101,7 @@ $MAIN {
     for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
      l_sum += exp(f32(tile_qk[i]) - l_max);
     }
-    let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % uniforms.num_total_seq_length_tile;
+    let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile;
     metadata[meta_offset] = metadata_value_t(l_max, l_sum);
   }
 }
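The "+ 1u" follows the GroupQueryAttention convention that seqlens_k holds total_sequence_length − 1 per batch entry; under indirect dispatch the shader derives the tile count itself instead of receiving the removed num_total_seq_length_tile uniform. A host-side equivalent of that derivation, for illustration only (not code from this commit):

#include <cstdint>

inline uint32_t NumTotalSeqLengthTiles(uint32_t seqlens_k_value, uint32_t tile_size) {
  const uint32_t total_sequence_length = seqlens_k_value + 1u;      // seqlens_k stores length - 1
  return (total_sequence_length + tile_size - 1u) / tile_size;      // ceil division
}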

onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template

Lines changed: 12 additions & 6 deletions
@@ -5,6 +5,7 @@
 #param head_size_vec
 #param tile_size_k_vec
 #param sub_tile_count
+#param use_indirect_dispatch
 
 // Note that this shader adopts similar algorithm with dp4a generation shader.
 //
@@ -40,22 +41,27 @@ var<workgroup> qkv_values: array<array<present_value_value_t, tile_size_k_vec>,
 $MAIN {
   let local_row = u32(local_idx / tile_size_k_vec);
   let local_col = local_idx % tile_size_k_vec;
-  let total_seq_offset = (workgroup_idx % uniforms.num_total_seq_length_tile) * tile_size;
-  let head_idx = u32(workgroup_idx / uniforms.num_total_seq_length_tile);
-  var total_sequence_length = uniforms.total_sequence_length;
+#if use_indirect_dispatch
+  let total_sequence_length = u32(seqlens_k[0]) + 1u;
+#else
+  let total_sequence_length = uniforms.total_sequence_length;
+#endif
+  let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
+  let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size;
+  let head_idx = u32(workgroup_idx / num_total_seq_length_tile);
   let present_offset = u32(head_idx / uniforms.n_reps) * head_size_vec * uniforms.present_sequence_length;
 
   // Calculate the global max and sum in qk.
   if (head_idx < uniforms.num_heads)
   {
     var g_max = f32(-3.402823e+38f);
     var g_sum = f32(0);
-    for (var i = 0u; i < uniforms.num_total_seq_length_tile; i++)
+    for (var i = 0u; i < num_total_seq_length_tile; i++)
     {
       let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + i;
       g_max = max(g_max, metadata[meta_offset].x);
     }
-    for (var i = 0u; i < uniforms.num_total_seq_length_tile; i++)
+    for (var i = 0u; i < num_total_seq_length_tile; i++)
     {
       let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + i;
       let m_value = metadata[meta_offset];
@@ -95,7 +101,7 @@ $MAIN {
   }
 
   for (var i = local_idx; i < head_size_vec; i += workgroup_size_x) {
-    let out_offset = head_idx * uniforms.num_present_sequence_length_tile * head_size_vec + (workgroup_idx % uniforms.num_total_seq_length_tile) * head_size_vec + i;
+    let out_offset = head_idx * uniforms.num_present_sequence_length_tile * head_size_vec + (workgroup_idx % num_total_seq_length_tile) * head_size_vec + i;
     out_split_vx[out_offset] = tile_output[i];
   }
 }

onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template

Lines changed: 10 additions & 2 deletions
@@ -1,7 +1,9 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#param seq_tile_size
 #param tile_size
+#param use_indirect_dispatch
 
 // Inputs are splits of the GQA output, split into num_total_seq_length_tiles
 // rows. This shader needs to add these splits across the row dimension to
@@ -23,10 +25,16 @@ $MAIN {
   var value = output_value_t(0);
   let local_row = u32(local_idx / tile_size);
   let local_col = local_idx % tile_size;
+#if use_indirect_dispatch
+  let total_sequence_length = u32(seqlens_k[0]) + 1u;
+  let num_total_seq_length_tile = (total_sequence_length + seq_tile_size - 1) / seq_tile_size;
+#else
+  let num_total_seq_length_tile = uniforms.num_total_seq_length_tile;
+#endif
 
   if (head_size_offset + local_col < uniforms.head_size_vec) {
-    for (var r = 0u; r < uniforms.num_total_seq_length_tile; r += tile_size) {
-      if (r + local_row < uniforms.num_total_seq_length_tile) {
+    for (var r = 0u; r < num_total_seq_length_tile; r += tile_size) {
+      if (r + local_row < num_total_seq_length_tile) {
        value += input[in_offset + (r + local_row) * uniforms.head_size_vec + head_size_offset + local_col];
      }
    }
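In the reduce shader, tile_size is the reduction tile, so the new seq_tile_size constant carries the sequence tile used by the QKT/SplitVx passes; it is what lets the shader recover how many split rows exist once that count is no longer a uniform. Illustrative only, with hypothetical values (not code from this commit):

#include <cstdint>

// Two distinct tile sizes in the reduce pass (values are hypothetical):
constexpr uint32_t kReduceTileSize = 8u;   // tile_size: rows accumulated per loop step
constexpr uint32_t kSeqTileSize = 64u;     // seq_tile_size: sequence tile from QKT/SplitVx

// Split-row count the reduce shader recomputes under indirect dispatch.
inline uint32_t NumSplitRows(uint32_t total_sequence_length) {
  return (total_sequence_length + kSeqTileSize - 1u) / kSeqTileSize;
}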

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc

Lines changed: 1 addition & 1 deletion
@@ -206,7 +206,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
       !use_sliding_window &&
       CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) {
     return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
-                               present_value, parameters, context);
+                               present_value, parameters, context, seqlen_k);
   }
 
   Tensor qSplit;

onnxruntime/core/providers/webgpu/compute_context.h

Lines changed: 4 additions & 1 deletion
@@ -8,6 +8,7 @@
 #include <utility>
 
 #include "core/framework/execution_provider.h"
+#include "core/providers/webgpu/webgpu_execution_provider.h"
 
 #include "core/providers/webgpu/program.h"
 #include "core/providers/webgpu/webgpu_context.h"
@@ -16,7 +17,6 @@
 namespace onnxruntime {
 
 class Tensor;
-class WebGpuExecutionProvider;
 
 namespace webgpu {
 
@@ -42,6 +42,9 @@ class ComputeContext {
   inline bool HasFeature(wgpu::FeatureName feature) const {
     return webgpu_context_.DeviceHasFeature(feature);
   }
+  inline bool IsGraphCaptureEnabled() const {
+    return ep_.IsGraphCaptureEnabled();
+  }
 #if !defined(__wasm__)
   inline const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs() const {
     return webgpu_context_.SubgroupMatrixConfigs();
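IsGraphCaptureEnabled() forwards to the execution provider, which is why the forward declaration is replaced by the full webgpu_execution_provider.h include. How flash_attention.cc consumes the accessor is not visible in this commit view (that diff is not rendered above); a plausible gating sketch, with an invented helper name:

// Hypothetical sketch: pick the indirect-dispatch path only when the EP is
// capturing a graph and a GPU-resident seqlens_k tensor is available.
#include "core/providers/webgpu/compute_context.h"

inline bool ShouldUseIndirectDispatch(onnxruntime::webgpu::ComputeContext& context,
                                      const onnxruntime::Tensor* seqlen_k) {
  return context.IsGraphCaptureEnabled() && seqlen_k != nullptr;
}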
