@@ -17,19 +17,24 @@ using namespace onnxruntime::webgpu;
 
 class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
  public:
-  CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, bool past_present_share_buffer)
-      : Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), past_present_share_buffer_(past_present_share_buffer) {
+  CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, bool past_present_share_buffer,
+                     bool prepare_indirect_dispatch = false)
+      : Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), past_present_share_buffer_(past_present_share_buffer), prepare_indirect_dispatch_(prepare_indirect_dispatch) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
 
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"copy_size", ProgramUniformVariableDataType::Uint32},
-                                          {"past_sequence_length", ProgramUniformVariableDataType::Uint32});
+                                          {"total_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"tile_size", ProgramUniformVariableDataType::Uint32},
+                                          {"num_heads", ProgramUniformVariableDataType::Uint32});
 
  private:
   bool has_past_;
   bool kv_BNSH_;
   bool past_present_share_buffer_;
+  bool prepare_indirect_dispatch_;
 };
 
 class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
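For context, a minimal sketch of how a caller might drive the extended copy program; the surrounding variable names and flag values are assumptions for illustration, not part of this change:

// Hypothetical call site: when prepare_indirect_dispatch is true, the same copy
// pass can also produce workgroup counts for the later decode passes.
CopyKVCacheProgram program{"CopyKVCache", has_past, /*kv_BNSH=*/true,
                           parameters.past_present_share_buffer_,
                           /*prepare_indirect_dispatch=*/true};
program.AddUniformVariables({{copy_size},              // must match the order in
                             {total_sequence_length},  // WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES
                             {kv_sequence_length},
                             {tile_size},
                             {num_heads}});
ORT_RETURN_IF_ERROR(context.RunProgram(program));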
@@ -75,8 +80,8 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
 class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecodeQKTProgram> {
  public:
   FlashAttentionDecodeQKTProgram(const std::string& kernel_name,
-                                 bool has_attention_bias, uint32_t tile_size)
-      : Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size) {
+                                 bool has_attention_bias, uint32_t tile_size, bool use_indirect_dispatch)
+      : Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -86,19 +91,19 @@ class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecodeQKTProgram> {
8691 {" alpha" , ProgramUniformVariableDataType::Float32},
8792 {" present_sequence_length" , ProgramUniformVariableDataType::Uint32},
8893 {" n_reps" , ProgramUniformVariableDataType::Uint32},
89- {" num_total_seq_length_tile" , ProgramUniformVariableDataType::Uint32},
9094 {" num_present_sequence_length_tile" , ProgramUniformVariableDataType::Uint32},
9195 {" num_heads" , ProgramUniformVariableDataType::Uint32});
9296
9397 private:
9498 bool has_attention_bias_;
9599 uint32_t tile_size_;
100+ bool use_indirect_dispatch_;
96101};
97102
98103class FlashAttentionDecodeSplitVxProgram final : public Program<FlashAttentionDecodeSplitVxProgram> {
99104 public:
100- FlashAttentionDecodeSplitVxProgram (const std::string& kernel_name, uint32_t tile_size, int head_size_vec)
101- : Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec) {
105+ FlashAttentionDecodeSplitVxProgram (const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch )
106+ : Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch) {
102107 }
103108
104109 Status GenerateShaderCode (ShaderHelper& sh) const override ;
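With indirect dispatch, the total-sequence-length tile count is no longer known on the CPU at encode time, which is presumably why the num_total_seq_length_tile uniform is removed here. A hedged sketch of decode-path construction (the flag value and CacheHint contents are assumptions):

// Hypothetical construction: the flag selects the indirect variant in
// GenerateShaderCode and should participate in the cache key so the two
// shader variants do not collide in the program cache.
FlashAttentionDecodeQKTProgram program{"FlashAttentionDecodeQKT",
                                       has_attention_bias, tile_size,
                                       /*use_indirect_dispatch=*/true};
program.CacheHint(tile_size, has_attention_bias, /*use_indirect_dispatch=*/true);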
@@ -107,19 +112,19 @@ class FlashAttentionDecodeSplitVxProgram final : public Program<FlashAttentionDecodeSplitVxProgram> {
                                           {"head_size_vec", ProgramUniformVariableDataType::Uint32},
                                           {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"n_reps", ProgramUniformVariableDataType::Uint32},
-                                          {"num_total_seq_length_tile", ProgramUniformVariableDataType::Uint32},
                                           {"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
                                           {"num_heads", ProgramUniformVariableDataType::Uint32});
 
  private:
   uint32_t tile_size_;
   int head_size_vec_;
+  bool use_indirect_dispatch_;
 };
 
 class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionDecodeVxReduceProgram> {
  public:
-  FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size)
-      : Program{kernel_name}, tile_size_(tile_size) {
+  FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch)
+      : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -132,11 +137,13 @@ class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionDecodeVxReduceProgram> {
 
  private:
   uint32_t tile_size_;
+  uint32_t seq_tile_size_;
+  bool use_indirect_dispatch_;
 };
 
 Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                            Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
-                           const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
+                           const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr);
 
 bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
                             const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
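The trailing seqlen_k = nullptr default keeps existing call sites source-compatible; a decode caller that has the per-batch sequence-length tensor would pass it explicitly. A minimal sketch, with caller-side names assumed:

// Hypothetical caller: seqlen_k carries the real sequence lengths on the GPU,
// enabling the indirect-dispatch path; passing nullptr preserves old behavior.
ORT_RETURN_IF_ERROR(ApplyFlashAttention(Q, K, V, attention_bias, output,
                                        past_key, present_key, past_value, present_value,
                                        parameters, context, seqlen_k));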