From cb37e187916dda84cdf5331c3c2af04e686aa83c Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 30 Sep 2025 15:39:35 +0800 Subject: [PATCH 1/2] Enable indirect dispatch for flash attention --- .../webgpu/bert/flash_attention.cc | 172 ++++++++++++++---- .../contrib_ops/webgpu/bert/flash_attention.h | 30 +-- .../flash_attention_decode_qkt.wgsl.template | 14 +- ...sh_attention_decode_split_vx.wgsl.template | 18 +- ...h_attention_decode_vx_reduce.wgsl.template | 11 +- .../webgpu/bert/group_query_attention.cc | 2 +- .../core/providers/webgpu/compute_context.h | 5 +- 7 files changed, 188 insertions(+), 64 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index b5c1f73d1678d..d52e29eac96e9 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -31,6 +31,11 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& present_key = shader.AddOutput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); const auto& present_value = shader.AddOutput("present_value", ShaderUsage::UseUniform); const auto& copy_kv_shape = shader.AddIndices("copy_kv_shape"); + // If prepare_indirect_dispatch is enabled, add seqlen_k input and indirect_buffer output + if (prepare_indirect_dispatch_) { + shader.AddInput("seqlen_k", ShaderUsage::None); + shader.AddOutput("indirect_buffer", ShaderUsage::None); + } shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.copy_size") << " let output_indices = " << copy_kv_shape.OffsetToIndices("global_idx") << ";\n" @@ -38,8 +43,25 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { " let sequence_id = output_indices[2];\n" " let num_head_id = output_indices[1];\n" " let batch = output_indices[0];\n"; + if (prepare_indirect_dispatch_) { + shader.MainFunctionBody() << " let total_seq_length = u32(seqlen_k[0u]) + 1u;\n"; + } else { + shader.MainFunctionBody() << " let total_seq_length = uniforms.total_sequence_length;\n"; + } + + // Add indirect dispatch logic for thread 0 + if (prepare_indirect_dispatch_) { + shader.MainFunctionBody() << " // Prepare indirect dispatch buffer for thread 0\n" + << " if (global_idx == 0u) {\n" + << " let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n" + << " indirect_buffer[0] = num_total_seq_length_tile;\n" + << " indirect_buffer[1] = uniforms.num_heads;\n" + << " indirect_buffer[2] = 1u;\n" + << " }\n\n"; + } + if (has_past_) { - shader.MainFunctionBody() << "let past_sequence_length = uniforms.past_sequence_length;\n"; + shader.MainFunctionBody() << "let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n"; if (past_present_share_buffer_) { shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n" << " let offset = " << key.IndicesToOffset(kv_BNSH_ ? 
"key_indices_t(batch, num_head_id, sequence_id, head_size_id)" : "key_indices_t(batch, sequence_id, num_head_id, head_size_id)") << ";\n" @@ -70,10 +92,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters, const Tensor* K, const Tensor* past_key, Tensor* present_key, - const Tensor* V, const Tensor* past_value, Tensor* present_value) { + const Tensor* V, const Tensor* past_value, Tensor* present_value, + uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer) { // CopyKVCache takes past key/value and current key/value and copies them to present key and value. // This makes it so that FlashAttention only needs to look at present key and value, and saves // number of input buffers in the shader, which we run out of (<=8) without this optimization. + // If indirect_buffer is provided, also prepare indirect dispatch buffer for flash attention. const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1); bool has_past = (parameters.total_sequence_length_ - parameters.kv_sequence_length_) > 0; // parameters.total_sequence_length_ is past_sequence_length + kv_sequence_length. @@ -83,7 +107,12 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt int copy_sequence_length = has_past && parameters.past_present_share_buffer_ ? parameters.kv_sequence_length_ : parameters.total_sequence_length_; TensorShape copy_kv_shape{parameters.batch_size_, num_heads, copy_sequence_length, parameters.head_size_ / components}; int64_t copy_size = copy_kv_shape.Size(); - CopyKVCacheProgram program{"CopyKVCache", has_past, parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH, parameters.past_present_share_buffer_}; + + // Determine if we need to prepare indirect dispatch + bool prepare_indirect_dispatch = (indirect_buffer != nullptr); + + CopyKVCacheProgram program{"CopyKVCache", has_past, parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH, parameters.past_present_share_buffer_, + prepare_indirect_dispatch}; if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components}, {V, ProgramTensorMetadataDependency::TypeAndRank, components}}); @@ -94,20 +123,31 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, reshaped_KV_shape, components}, {V, ProgramTensorMetadataDependency::TypeAndRank, reshaped_KV_shape, components}}); } + + if (prepare_indirect_dispatch) { + program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); + } + if (has_past && !parameters.past_present_share_buffer_) { program.AddInputs({{past_key, ProgramTensorMetadataDependency::TypeAndRank, components}, {past_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); } program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components}, - {present_value, ProgramTensorMetadataDependency::Rank, components}}) - .AddIndices(std::move(copy_kv_shape)); + {present_value, ProgramTensorMetadataDependency::Rank, components}}); + + if (prepare_indirect_dispatch) { + program.AddOutput({indirect_buffer, ProgramTensorMetadataDependency::None}); + } + + program.AddIndices(std::move(copy_kv_shape)); program.SetDispatchGroupSize(static_cast((copy_size + 63) / 64)) .SetWorkgroupSize(64) - .CacheHint(has_past, parameters.qkv_format_, 
parameters.past_present_share_buffer_)
+      .CacheHint(has_past, parameters.qkv_format_, parameters.past_present_share_buffer_, prepare_indirect_dispatch)
       .AddUniformVariables({{static_cast<uint32_t>(copy_size)},
-                            // Note that when parameters.past_present_share_buffer_ is true, parameters.past_sequence_length_ will become to
-                            // max_sequence_length. To get a valid past_sequence_length, we use total_sequence_length - kv_sequence_length.
-                            {static_cast<uint32_t>(parameters.total_sequence_length_ - parameters.kv_sequence_length_)}});
+                            {static_cast<uint32_t>(parameters.total_sequence_length_)},
+                            {static_cast<uint32_t>(parameters.kv_sequence_length_)},
+                            {tile_size},
+                            {static_cast<uint32_t>(parameters.num_heads_)}});
 
   return context.RunProgram(program);
 }
@@ -147,6 +187,9 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
 Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const {
   shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+  if (use_indirect_dispatch_) {
+    shader.AddInput("seqlens_k", ShaderUsage::None);
+  }
   if (has_attention_bias_) {
     shader.AddInput("attention_bias", ShaderUsage::UseUniform);
   }
@@ -159,23 +202,25 @@ Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader)
                              WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_),
                              WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count),
                              WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
-                             WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec));
+                             WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec),
+                             WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
 }
 
 Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q,
-                                      const Tensor* attention_bias, Tensor* output, Tensor* present_key, Tensor* metadata,
-                                      const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile,
-                                      uint32_t num_present_sequence_length_tile, uint32_t tile_size,
-                                      uint32_t present_sequence_length) {
+                                      const Tensor* attention_bias, Tensor* output, Tensor* present_key, Tensor* metadata, const Tensor* seqlen_k,
+                                      const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length) {
   const float alpha = parameters.scale_ == 0.0f ? 
1.f / sqrt(static_cast<float>(parameters.head_size_)) : parameters.scale_;
   const bool has_attention_bias = attention_bias != nullptr;
   const int components = 4;
-  FlashAttentionDecodeQKTProgram program{"FlashAttentionDecodeQKT", has_attention_bias, tile_size};
+  FlashAttentionDecodeQKTProgram program{"FlashAttentionDecodeQKT", has_attention_bias, tile_size, use_indirect_dispatch};
   program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
                      {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}});
+  if (use_indirect_dispatch) {
+    program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
+  }
   if (has_attention_bias) {
     program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank});
   }
@@ -183,15 +228,18 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
                       {metadata, ProgramTensorMetadataDependency::Rank, 2}});
 
   const uint32_t vectorized_head_size = parameters.head_size_ / components;
-  program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile)
-      .SetWorkgroupSize(64)
-      .CacheHint(tile_size, has_attention_bias)
+  if (use_indirect_dispatch) {
+    program.SetIndirectDispatchTensor(indirect_buffer);
+  } else {
+    program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile);
+  }
+  program.SetWorkgroupSize(64)
+      .CacheHint(tile_size, has_attention_bias, use_indirect_dispatch)
       .AddUniformVariables({{static_cast<uint32_t>(vectorized_head_size)},
                             {static_cast<uint32_t>(parameters.total_sequence_length_)},
                             {static_cast<float>(alpha)},
                             present_sequence_length,
                             {static_cast<uint32_t>(parameters.n_reps)},
-                            {num_total_seq_length_tile},
                             {num_present_sequence_length_tile},
                             {static_cast<uint32_t>(parameters.num_heads_)}});
 
@@ -202,6 +250,9 @@ Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shad
   shader.AddInput("metadata", ShaderUsage::UseUniform);
   shader.AddInput("qk", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
   shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+  if (use_indirect_dispatch_) {
+    shader.AddInput("seqlens_k", ShaderUsage::None);
+  }
   shader.AddOutput("out_split_vx", ShaderUsage::UseUniform);
 
   const uint32_t tile_size_k_vec = 8u;
@@ -210,7 +261,8 @@ Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shad
                              WGSL_TEMPLATE_PARAMETER(head_size_vec, head_size_vec_),
                              WGSL_TEMPLATE_PARAMETER(sub_tile_count, WorkgroupSizeX() / tile_size_k_vec),
                              WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
-                             WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec));
+                             WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec),
+                             WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
 }
 
 Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeContext& context,
@@ -218,26 +270,33 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
                                                const Tensor* qk,
                                                Tensor* out_split_vx,
                                                Tensor* present_value,
+                                               const Tensor* seqlen_k,
                                                const WebgpuAttentionParameters& parameters,
+                                               const Tensor* indirect_buffer,
                                                uint32_t num_total_seq_length_tile,
                                                uint32_t num_present_sequence_length_tile,
                                                uint32_t tile_size,
+                                               bool use_indirect_dispatch,
                                                uint32_t present_sequence_length) {
   const int components = 4;
   int head_size_vec = parameters.v_head_size_ / components;
-  FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec};
+  FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec, 
use_indirect_dispatch};
   program.AddInputs({{metadata, ProgramTensorMetadataDependency::TypeAndRank, 2},
                      {qk, ProgramTensorMetadataDependency::TypeAndRank},
                      {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
   program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}});  // [B, N, split_k, head_size]
-  program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile)
-      .CacheHint(tile_size, head_size_vec)
+  if (use_indirect_dispatch) {
+    program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None})
+        .SetIndirectDispatchTensor(indirect_buffer);
+  } else {
+    program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile);
+  }
+  program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch)
       .SetWorkgroupSize(64)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.total_sequence_length_)},
                             {static_cast<uint32_t>(head_size_vec)},
                             present_sequence_length,
                             {static_cast<uint32_t>(parameters.n_reps)},
-                            num_total_seq_length_tile,
                             num_present_sequence_length_tile,
                             {static_cast<uint32_t>(parameters.num_heads_)}});
 
@@ -246,27 +305,36 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
 
 Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const {
   shader.AddInput("input", ShaderUsage::UseUniform);
+  if (use_indirect_dispatch_) {
+    shader.AddInput("seqlens_k", ShaderUsage::None);
+  }
   shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
 
   return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_vx_reduce.wgsl.template",
-                             WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_));
+                             WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
+                             WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
 }
 
 Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& context,
                                            const Tensor* out_split_vx,
                                            Tensor* output,
+                                           const Tensor* seqlen_k,
                                            const WebgpuAttentionParameters& parameters,
                                            uint32_t num_total_seq_length_tile,
-                                           uint32_t num_present_sequence_length_tile) {
+                                           uint32_t num_present_sequence_length_tile,
+                                           bool use_indirect_dispatch) {
   const int components = 4;
   constexpr int tile_size = 8;
   int tile_head_size = tile_size * components;
-  FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size};
+  FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, use_indirect_dispatch};
   program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}});
+  if (use_indirect_dispatch) {
+    program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
+  }
   program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}});
   const uint32_t num_head_size_tile = static_cast<uint32_t>((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size);
   program.SetDispatchGroupSize(parameters.num_heads_ * num_head_size_tile)
-      .CacheHint(tile_size)
+      .CacheHint(tile_size, use_indirect_dispatch)
       .SetWorkgroupSize(tile_size * tile_size)
      .AddUniformVariables({{static_cast<uint32_t>(parameters.v_head_size_ / components)},
                            num_total_seq_length_tile,
@@ -279,14 +347,15 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&
 Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                            Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value,
                            Tensor* present_value,
-                           const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
-  
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value));
-
+                           const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
   // Extract present_sequence_length directly from present_key tensor shape:
   // (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size)
   const uint32_t present_sequence_length = static_cast<uint32_t>(present_key->Shape()[2]);
+
   if (parameters.sequence_length_ > 1) {
     const uint32_t tile_size = 64;
+    // For the encode path, use the original CopyKVCache without indirect dispatch preparation.
+    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr));
     bool has_attention_bias = attention_bias != nullptr;
     bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
     bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
@@ -323,7 +392,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
     return context.RunProgram(program);
   }
 
-  // Use present_sequence_length instead of total_sequence_length to make sure the |qk| buffer is static when static qv cache is enabled.
+  // For decode path (sequence_length == 1)
   const TensorShapeVector qk_dims({parameters.batch_size_, parameters.num_heads_,
                                    parameters.sequence_length_, present_sequence_length});
   const TensorShape qk_shape(qk_dims);
@@ -331,21 +400,48 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   constexpr uint32_t tile_size = 64;
   const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size;
   const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size;
+
+  // Determine if we should use indirect dispatch
+  const bool use_indirect_dispatch = parameters.past_present_share_buffer_ &&
+                                     seqlen_k != nullptr &&
+                                     context.IsGraphCaptureEnabled();
+
+  // Create indirect dispatch buffer if using indirect dispatch
+  Tensor* indirect_buffer_ptr = nullptr;
+  Tensor indirect_buffer;
+  if (use_indirect_dispatch) {
+    const TensorShape indirect_buffer_shape{3};  // 3 uint32 values for dispatch dimensions
+    indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
+    indirect_buffer_ptr = &indirect_buffer;
+    // Use the fused CopyKVCache that also prepares the indirect dispatch buffer
+    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, indirect_buffer_ptr));
+  } else {
+    // Use the original CopyKVCache without indirect dispatch preparation
+    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr));
+  }
+
   // The metadata is used to store the max and sum of each tile.
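   // (Illustrative note, not part of this patch: with total_sequence_length = 1000 and tile_size = 64,
   // num_total_seq_length_tile = ceil(1000 / 64) = 16, so each head stores 16 (max, sum) pairs that the
   // split-Vx and reduce passes below fold into the final softmax-weighted output.)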
const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_, num_present_sequence_length_tile, 2});
   const TensorShape metadata_shape(metadata_dims);
   Tensor metadata = context.CreateGPUTensor(DataTypeImpl::GetType<float>(), metadata_shape);
-  ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKT(context, Q, attention_bias, &qk, present_key, &metadata,
-                                                     parameters, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size,
+  ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKT(context, Q, attention_bias, &qk, present_key, &metadata, seqlen_k,
+                                                     parameters, indirect_buffer_ptr, num_total_seq_length_tile,
+                                                     num_present_sequence_length_tile, tile_size, use_indirect_dispatch,
                                                      present_sequence_length));
 
-  const TensorShapeVector out_split_vx_dims({parameters.batch_size_, parameters.num_heads_, num_present_sequence_length_tile, parameters.head_size_});
+  const TensorShapeVector out_split_vx_dims({parameters.batch_size_, parameters.num_heads_,
+                                             num_present_sequence_length_tile, parameters.head_size_});
   const TensorShape out_split_vx_shape(out_split_vx_dims);
   Tensor out_split_vx = context.CreateGPUTensor(Q->DataType(), out_split_vx_shape);
-  ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value, parameters,
-                                                              num_total_seq_length_tile, num_present_sequence_length_tile, tile_size, present_sequence_length));
-  ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, parameters, num_total_seq_length_tile, num_present_sequence_length_tile));
+  ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value,
+                                                              seqlen_k, parameters, indirect_buffer_ptr,
+                                                              num_total_seq_length_tile,
+                                                              num_present_sequence_length_tile, tile_size,
+                                                              use_indirect_dispatch, present_sequence_length));
+  ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, seqlen_k, parameters,
+                                                          num_total_seq_length_tile,
+                                                          num_present_sequence_length_tile, use_indirect_dispatch));
 
   return Status::OK();
 }
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
index c75494df253c1..6481f50ab34f6 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -17,19 +17,24 @@ using namespace onnxruntime::webgpu;
 
 class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
  public:
-  CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, bool past_present_share_buffer)
-      : Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), past_present_share_buffer_(past_present_share_buffer) {
+  CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, bool past_present_share_buffer,
+                     bool prepare_indirect_dispatch = false)
+      : Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), past_present_share_buffer_(past_present_share_buffer), prepare_indirect_dispatch_(prepare_indirect_dispatch) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
 
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"copy_size", ProgramUniformVariableDataType::Uint32},
-                                          {"past_sequence_length", ProgramUniformVariableDataType::Uint32});
+                                          {"total_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"tile_size", ProgramUniformVariableDataType::Uint32},
+                                          {"num_heads", ProgramUniformVariableDataType::Uint32});
 
 private:
  bool 
has_past_;
  bool kv_BNSH_;
  bool past_present_share_buffer_;
+  bool prepare_indirect_dispatch_;
 };
 
 class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
@@ -75,8 +80,8 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
 class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecodeQKTProgram> {
  public:
   FlashAttentionDecodeQKTProgram(const std::string& kernel_name,
-                                 bool has_attention_bias, uint32_t tile_size)
-      : Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size) {
+                                 bool has_attention_bias, uint32_t tile_size, bool use_indirect_dispatch)
+      : Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -86,19 +91,19 @@ class FlashAttentionDecodeSplitVxProgram final : public Program<FlashAttentionDecodeSplitVxProgram> {
  public:
-  FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec)
-      : Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec) {
+  FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch)
+      : Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -107,19 +112,19 @@ class FlashAttentionDecodeSplitVxProgram final : public Program<FlashAttentionDecodeSplitVxProgram> {
  public:
-  FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size)
-      : Program{kernel_name}, tile_size_(tile_size) {
+  FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, bool use_indirect_dispatch)
+      : Program{kernel_name}, tile_size_(tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -132,11 +137,12 @@ class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionDecodeVxReduceProgram> {
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template
+#param use_indirect_dispatch
 var<workgroup> tile_qk: array<q_element_t, tile_size>;
 $MAIN {
   let local_row = u32(local_idx / tile_size_k_vec);
   let local_col = local_idx % tile_size_k_vec;
-  let total_seq_offset = (workgroup_idx % uniforms.num_total_seq_length_tile) * tile_size;
-  let head_idx = u32(workgroup_idx / uniforms.num_total_seq_length_tile);
+#if use_indirect_dispatch
+  let total_sequence_length = u32(seqlens_k[0]) + 1u;
+#else
+  let total_sequence_length = uniforms.total_sequence_length;
+#endif
+  let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
+  let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size;
+  let head_idx = u32(workgroup_idx / num_total_seq_length_tile);
   let q_offset = head_idx * uniforms.head_size_vec;
-  var total_sequence_length = uniforms.total_sequence_length;
   let present_offset = u32(head_idx / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.head_size_vec;
   for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) {
     if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) {
@@ -95,7 +101,7 @@ $MAIN {
     for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
       l_sum += exp(f32(tile_qk[i]) - l_max);
     }
-    let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % uniforms.num_total_seq_length_tile;
+    let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile;
     metadata[meta_offset] = metadata_value_t(l_max, l_sum);
   }
 }
diff --git 
a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template
index c7593af311ce2..37cf7e8f11b1f 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template
@@ -5,6 +5,7 @@
 #param head_size_vec
 #param tile_size_k_vec
 #param sub_tile_count
+#param use_indirect_dispatch
 
 // Note that this shader adopts a similar algorithm to the dp4a generation shader.
 //
@@ -40,9 +41,14 @@ var<workgroup> qkv_values: array<array<present_value_value_t, tile_size_k_vec>, sub_tile_count>;
 $MAIN {
   let local_row = u32(local_idx / tile_size_k_vec);
   let local_col = local_idx % tile_size_k_vec;
-  let total_seq_offset = (workgroup_idx % uniforms.num_total_seq_length_tile) * tile_size;
-  let head_idx = u32(workgroup_idx / uniforms.num_total_seq_length_tile);
-  var total_sequence_length = uniforms.total_sequence_length;
+  #if use_indirect_dispatch
+  let total_sequence_length = u32(seqlens_k[0]) + 1u;
+  #else
+  let total_sequence_length = uniforms.total_sequence_length;
+  #endif
+  let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
+  let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size;
+  let head_idx = u32(workgroup_idx / num_total_seq_length_tile);
   let present_offset = u32(head_idx / uniforms.n_reps) * head_size_vec * uniforms.present_sequence_length;
 
   // Calculate the global max and sum in qk.
@@ -50,12 +56,12 @@ $MAIN {
   {
     var g_max = f32(-3.402823e+38f);
     var g_sum = f32(0);
-    for (var i = 0u; i < uniforms.num_total_seq_length_tile; i++)
+    for (var i = 0u; i < num_total_seq_length_tile; i++)
     {
       let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + i;
       g_max = max(g_max, metadata[meta_offset].x);
     }
-    for (var i = 0u; i < uniforms.num_total_seq_length_tile; i++)
+    for (var i = 0u; i < num_total_seq_length_tile; i++)
     {
       let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + i;
       let m_value = metadata[meta_offset];
@@ -95,7 +101,7 @@ $MAIN {
   }
 
   for (var i = local_idx; i < head_size_vec; i += workgroup_size_x) {
-    let out_offset = head_idx * uniforms.num_present_sequence_length_tile * head_size_vec + (workgroup_idx % uniforms.num_total_seq_length_tile) * head_size_vec + i;
+    let out_offset = head_idx * uniforms.num_present_sequence_length_tile * head_size_vec + (workgroup_idx % num_total_seq_length_tile) * head_size_vec + i;
     out_split_vx[out_offset] = tile_output[i];
   }
 }
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template
index a4381baa638ce..08ed21810337a 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #param tile_size
+#param use_indirect_dispatch
 
 // Inputs are splits of the GQA output, split into num_total_seq_length_tiles
 // rows. 
This shader needs to add these splits across the row dimension to @@ -23,10 +24,16 @@ $MAIN { var value = output_value_t(0); let local_row = u32(local_idx / tile_size); let local_col = local_idx % tile_size; + #if use_indirect_dispatch + let total_sequence_length = u32(seqlens_k[0]) + 1u; + let num_total_seq_length_tile = (total_sequence_length + 63u) / 64u; + #else + let num_total_seq_length_tile = uniforms.num_total_seq_length_tile; + #endif if (head_size_offset + local_col < uniforms.head_size_vec) { - for (var r = 0u; r < uniforms.num_total_seq_length_tile; r += tile_size) { - if (r + local_row < uniforms.num_total_seq_length_tile) { + for (var r = 0u; r < num_total_seq_length_tile; r += tile_size) { + if (r + local_row < num_total_seq_length_tile) { value += input[in_offset + (r + local_row) * uniforms.head_size_vec + head_size_offset + local_col]; } } diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 8b7b257dd2852..cb845061404f3 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -206,7 +206,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& !use_sliding_window && CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, - present_value, parameters, context); + present_value, parameters, context, seqlen_k); } Tensor qSplit; diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index fe95917e4e906..6bf7df74ea861 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -8,6 +8,7 @@ #include #include "core/framework/execution_provider.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" #include "core/providers/webgpu/program.h" #include "core/providers/webgpu/webgpu_context.h" @@ -16,7 +17,6 @@ namespace onnxruntime { class Tensor; -class WebGpuExecutionProvider; namespace webgpu { @@ -42,6 +42,9 @@ class ComputeContext { inline bool HasFeature(wgpu::FeatureName feature) const { return webgpu_context_.DeviceHasFeature(feature); } + inline bool IsGraphCaptureEnabled() const { + return ep_.IsGraphCaptureEnabled(); + } #if !defined(__wasm__) inline const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs() const { return webgpu_context_.SubgroupMatrixConfigs(); From 0a94dcbbdcc180cd09d5e06ee4f1ca5531ade0d7 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 30 Sep 2025 17:04:00 +0800 Subject: [PATCH 2/2] TODOs and code clean up --- onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | 9 ++++++--- onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | 5 +++-- .../bert/flash_attention_decode_vx_reduce.wgsl.template | 3 ++- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index d52e29eac96e9..a9bd4afc5cd09 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -51,6 +51,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { // Add indirect dispatch logic for thread 0 if (prepare_indirect_dispatch_) { + // TODO: Add NormalizeDispatchGroupSize logic 
here to avoid exceeding max dispatch size.
     shader.MainFunctionBody() << "  // Prepare indirect dispatch buffer for thread 0\n"
                               << "  if (global_idx == 0u) {\n"
                               << "    let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n"
@@ -311,6 +312,7 @@ Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& sha
   shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
 
   return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_vx_reduce.wgsl.template",
+                             WGSL_TEMPLATE_PARAMETER(seq_tile_size, seq_tile_size_),
                              WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
                              WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
 }
@@ -322,11 +324,12 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&
                                            const WebgpuAttentionParameters& parameters,
                                            uint32_t num_total_seq_length_tile,
                                            uint32_t num_present_sequence_length_tile,
+                                           uint32_t seq_tile_size,
                                            bool use_indirect_dispatch) {
   const int components = 4;
   constexpr int tile_size = 8;
   int tile_head_size = tile_size * components;
-  FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, use_indirect_dispatch};
+  FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch};
   program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}});
   if (use_indirect_dispatch) {
     program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
@@ -334,7 +337,7 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&
   program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}});
   const uint32_t num_head_size_tile = static_cast<uint32_t>((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size);
   program.SetDispatchGroupSize(parameters.num_heads_ * num_head_size_tile)
-      .CacheHint(tile_size, use_indirect_dispatch)
+      .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch)
       .SetWorkgroupSize(tile_size * tile_size)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.v_head_size_ / components)},
                             num_total_seq_length_tile,
@@ -441,7 +444,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
                                                               use_indirect_dispatch, present_sequence_length));
   ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, seqlen_k, parameters,
                                                           num_total_seq_length_tile,
-                                                          num_present_sequence_length_tile, use_indirect_dispatch));
+                                                          num_present_sequence_length_tile, tile_size, use_indirect_dispatch));
 
   return Status::OK();
 }
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
index 6481f50ab34f6..7d71dc0f4d42d 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -123,8 +123,8 @@ class FlashAttentionDecodeSplitVxProgram final : public Program<FlashAttentionDecodeSplitVxProgram> {
  public:
-  FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, bool use_indirect_dispatch)
-      : Program{kernel_name}, tile_size_(tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
+  FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch)
+      : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -137,6 +137,7 
@@ class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionDecodeVxReduceProgram>
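Note for reviewers (illustrative, not part of the patch): the three u32 values that the fused CopyKVCache writes (num_total_seq_length_tile, num_heads, 1) are exactly the (x, y, z) workgroup counts that WebGPU reads from a buffer during an indirect compute dispatch. That is what lets the decode programs size themselves from seqlen_k on the GPU while a captured command stream replays unchanged. The sketch below shows the host-side shape of such a dispatch using Dawn's C++ WebGPU API; `device`, `pipeline`, and `bindGroup` are assumed to already exist, and the helper names are made up for illustration rather than taken from this patch or from ONNX Runtime.

#include <webgpu/webgpu_cpp.h>

#include <cstdint>

// Create a buffer that a producer shader can fill with (x, y, z) workgroup
// counts and that a later pass consumes as the source of an indirect dispatch.
wgpu::Buffer MakeIndirectBuffer(const wgpu::Device& device) {
  wgpu::BufferDescriptor desc;
  desc.size = 3 * sizeof(uint32_t);            // x, y, z workgroup counts
  desc.usage = wgpu::BufferUsage::Storage |    // written by the producer shader
               wgpu::BufferUsage::Indirect;    // read by DispatchWorkgroupsIndirect
  return device.CreateBuffer(&desc);
}

// Encode a compute pass whose workgroup counts come from `indirect` instead of
// being baked into the command buffer on the CPU.
void EncodeIndirectDispatch(const wgpu::Device& device,
                            const wgpu::ComputePipeline& pipeline,
                            const wgpu::BindGroup& bindGroup,
                            const wgpu::Buffer& indirect) {
  wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
  wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
  pass.SetPipeline(pipeline);
  pass.SetBindGroup(0, bindGroup);
  // The GPU reads the three u32 counts at execution time, so the same command
  // buffer stays valid as total_sequence_length grows between decode steps.
  pass.DispatchWorkgroupsIndirect(indirect, /*indirectOffset=*/0);
  pass.End();
  wgpu::CommandBuffer commands = encoder.Finish();
  device.GetQueue().Submit(1, &commands);
}

This is also why use_indirect_dispatch is gated on past_present_share_buffer_ and IsGraphCaptureEnabled(): indirect dispatch pays off when the command stream is recorded once and replayed, which in turn requires the present KV buffers (and hence the bind groups) to stay stable across decode steps.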