175 changes: 137 additions & 38 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Large diffs are not rendered by default.

31 changes: 19 additions & 12 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -17,19 +17,24 @@

class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
public:
CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, bool past_present_share_buffer)
: Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), past_present_share_buffer_(past_present_share_buffer) {
CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, bool past_present_share_buffer,
bool prepare_indirect_dispatch = false)
: Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), past_present_share_buffer_(past_present_share_buffer), prepare_indirect_dispatch_(prepare_indirect_dispatch) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"copy_size", ProgramUniformVariableDataType::Uint32},
{"past_sequence_length", ProgramUniformVariableDataType::Uint32});
{"total_sequence_length", ProgramUniformVariableDataType::Uint32},
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"tile_size", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32});

private:
bool has_past_;
bool kv_BNSH_;
bool past_present_share_buffer_;
bool prepare_indirect_dispatch_;
};

class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
@@ -75,8 +80,8 @@
class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecodeQKTProgram> {
public:
FlashAttentionDecodeQKTProgram(const std::string& kernel_name,
bool has_attention_bias, uint32_t tile_size)
: Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size) {
bool has_attention_bias, uint32_t tile_size, bool use_indirect_dispatch)
: Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -86,19 +91,19 @@
{"alpha", ProgramUniformVariableDataType::Float32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"num_total_seq_length_tile", ProgramUniformVariableDataType::Uint32},
{"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32});

private:
bool has_attention_bias_;
uint32_t tile_size_;
bool use_indirect_dispatch_;
};

class FlashAttentionDecodeSplitVxProgram final : public Program<FlashAttentionDecodeSplitVxProgram> {
public:
FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec)
: Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec) {
FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch)
: Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -107,19 +112,19 @@
{"head_size_vec", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"num_total_seq_length_tile", ProgramUniformVariableDataType::Uint32},
{"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32});

private:
uint32_t tile_size_;
int head_size_vec_;
bool use_indirect_dispatch_;
};

class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionDecodeVxReduceProgram> {
public:
FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size)
: Program{kernel_name}, tile_size_(tile_size) {
FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch)

[cpplint, flash_attention.h:126] Add #include <string> for string [build/include_what_you_use] [4] (GitHub Actions / Optional Lint C++)
: Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -132,11 +137,13 @@

private:
uint32_t tile_size_;
uint32_t seq_tile_size_;
bool use_indirect_dispatch_;
};

Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr);

bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
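As a reading aid for the header changes above: the new use_indirect_dispatch flag and the tile_size / total_sequence_length uniforms move the tile-count computation away from host-side uniforms. A minimal host-side sketch of the relationship, with illustrative helper names that are not part of this PR:

#include <cstdint>

// Ceil-division mirroring the (x + tile_size - 1) / tile_size pattern used by the decode shaders.
constexpr uint32_t CeilDiv(uint32_t x, uint32_t y) { return (x + y - 1) / y; }

// Without indirect dispatch, the host computes the workgroup count from a
// total_sequence_length it already knows. With use_indirect_dispatch, the same
// ceil-division is re-derived inside the shader from seqlens_k[0] + 1, so the
// CPU never has to read the current sequence length back.
uint32_t NumTotalSeqLengthTiles(uint32_t total_sequence_length, uint32_t tile_size) {
  return CeilDiv(total_sequence_length, tile_size);
}

For example, total_sequence_length = 130 with tile_size = 64 gives 3 tiles, which matches the num_total_seq_length_tile value the QKT and SplitVx shaders below now compute in-shader.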
@@ -5,6 +5,7 @@
#param tile_size
#param tile_size_k_vec
#param sub_tile_count
#param use_indirect_dispatch

// Note that this shader adopts similar algorithm with dp4a generation shader.
//
@@ -48,10 +49,15 @@ var<workgroup> tile_qk: array<q_element_t, tile_size>;
$MAIN {
let local_row = u32(local_idx / tile_size_k_vec);
let local_col = local_idx % tile_size_k_vec;
let total_seq_offset = (workgroup_idx % uniforms.num_total_seq_length_tile) * tile_size;
let head_idx = u32(workgroup_idx / uniforms.num_total_seq_length_tile);
#if use_indirect_dispatch
let total_sequence_length = u32(seqlens_k[0]) + 1u;
#else
let total_sequence_length = uniforms.total_sequence_length;
#endif
let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size;
let head_idx = u32(workgroup_idx / num_total_seq_length_tile);
let q_offset = head_idx * uniforms.head_size_vec;
var total_sequence_length = uniforms.total_sequence_length;
let present_offset = u32(head_idx / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.head_size_vec;
for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) {
if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) {
@@ -95,7 +101,7 @@ $MAIN {
for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
l_sum += exp(f32(tile_qk[i]) - l_max);
}
let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % uniforms.num_total_seq_length_tile;
let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile;
metadata[meta_offset] = metadata_value_t(l_max, l_sum);
}
}
@@ -5,6 +5,7 @@
#param head_size_vec
#param tile_size_k_vec
#param sub_tile_count
#param use_indirect_dispatch

// Note that this shader adopts similar algorithm with dp4a generation shader.
//
@@ -40,22 +41,27 @@ var<workgroup> qkv_values: array<array<present_value_value_t, tile_size_k_vec>,
$MAIN {
let local_row = u32(local_idx / tile_size_k_vec);
let local_col = local_idx % tile_size_k_vec;
let total_seq_offset = (workgroup_idx % uniforms.num_total_seq_length_tile) * tile_size;
let head_idx = u32(workgroup_idx / uniforms.num_total_seq_length_tile);
var total_sequence_length = uniforms.total_sequence_length;
#if use_indirect_dispatch
let total_sequence_length = u32(seqlens_k[0]) + 1u;
#else
let total_sequence_length = uniforms.total_sequence_length;
#endif
let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size;
let head_idx = u32(workgroup_idx / num_total_seq_length_tile);
let present_offset = u32(head_idx / uniforms.n_reps) * head_size_vec * uniforms.present_sequence_length;

// Calculate the global max and sum in qk.
if (head_idx < uniforms.num_heads)
{
var g_max = f32(-3.402823e+38f);
var g_sum = f32(0);
for (var i = 0u; i < uniforms.num_total_seq_length_tile; i++)
for (var i = 0u; i < num_total_seq_length_tile; i++)
{
let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + i;
g_max = max(g_max, metadata[meta_offset].x);
}
for (var i = 0u; i < uniforms.num_total_seq_length_tile; i++)
for (var i = 0u; i < num_total_seq_length_tile; i++)
{
let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + i;
let m_value = metadata[meta_offset];
@@ -95,7 +101,7 @@ $MAIN {
}

for (var i = local_idx; i < head_size_vec; i += workgroup_size_x) {
let out_offset = head_idx * uniforms.num_present_sequence_length_tile * head_size_vec + (workgroup_idx % uniforms.num_total_seq_length_tile) * head_size_vec + i;
let out_offset = head_idx * uniforms.num_present_sequence_length_tile * head_size_vec + (workgroup_idx % num_total_seq_length_tile) * head_size_vec + i;
out_split_vx[out_offset] = tile_output[i];
}
}
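For context on the reduction above: the metadata buffer written by the QKT shader holds one (tile max, tile sum) pair per sequence-length tile, and the loops here merge them into a global softmax max and denominator (the second loop is truncated above and presumably rescales each tile's partial sum). A scalar C++ sketch of that standard online-softmax merge, for illustration only:

#include <algorithm>
#include <cmath>
#include <vector>

struct TileMeta {
  float max_qk;   // l_max written by the QKT shader: max score within the tile
  float sum_exp;  // l_sum written by the QKT shader: sum of exp(score - l_max) over the tile
};

// Merge per-tile (max, sum) pairs into the global softmax denominator,
// mirroring the two-pass loop over num_total_seq_length_tile tiles.
float MergeSoftmaxDenominator(const std::vector<TileMeta>& tiles) {
  float g_max = -3.402823e+38f;  // same lowest-float sentinel the shader uses
  for (const auto& t : tiles) g_max = std::max(g_max, t.max_qk);
  float g_sum = 0.0f;
  for (const auto& t : tiles) g_sum += std::exp(t.max_qk - g_max) * t.sum_exp;
  return g_sum;
}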
@@ -1,7 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#param seq_tile_size
#param tile_size
#param use_indirect_dispatch

// Inputs are splits of the GQA output, split into num_total_seq_length_tiles
// rows. This shader needs to add these splits across the row dimension to
@@ -23,10 +25,16 @@ $MAIN {
var value = output_value_t(0);
let local_row = u32(local_idx / tile_size);
let local_col = local_idx % tile_size;
#if use_indirect_dispatch
let total_sequence_length = u32(seqlens_k[0]) + 1u;
let num_total_seq_length_tile = (total_sequence_length + seq_tile_size - 1) / seq_tile_size;
#else
let num_total_seq_length_tile = uniforms.num_total_seq_length_tile;
#endif

if (head_size_offset + local_col < uniforms.head_size_vec) {
for (var r = 0u; r < uniforms.num_total_seq_length_tile; r += tile_size) {
if (r + local_row < uniforms.num_total_seq_length_tile) {
for (var r = 0u; r < num_total_seq_length_tile; r += tile_size) {
if (r + local_row < num_total_seq_length_tile) {
value += input[in_offset + (r + local_row) * uniforms.head_size_vec + head_size_offset + local_col];
}
}
@@ -206,7 +206,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
!use_sliding_window &&
CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) {
return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
present_value, parameters, context);
present_value, parameters, context, seqlen_k);
}

Tensor qSplit;
5 changes: 4 additions & 1 deletion onnxruntime/core/providers/webgpu/compute_context.h
@@ -8,6 +8,7 @@
#include <utility>

#include "core/framework/execution_provider.h"
#include "core/providers/webgpu/webgpu_execution_provider.h"

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_context.h"
@@ -16,7 +17,6 @@
namespace onnxruntime {

class Tensor;
class WebGpuExecutionProvider;

namespace webgpu {

@@ -42,6 +42,9 @@ class ComputeContext {
inline bool HasFeature(wgpu::FeatureName feature) const {
return webgpu_context_.DeviceHasFeature(feature);
}
inline bool IsGraphCaptureEnabled() const {
return ep_.IsGraphCaptureEnabled();
}
#if !defined(__wasm__)
inline const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs() const {
return webgpu_context_.SubgroupMatrixConfigs();
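The new ComputeContext accessor simply forwards to the execution provider. A hypothetical usage sketch (this helper is not part of the PR; the .cc changes that actually consume the accessor are not rendered above):

#include "core/providers/webgpu/compute_context.h"

// Hypothetical: a kernel could consult the accessor when choosing between a
// host-sized dispatch and an indirect dispatch whose workgroup count is
// produced on the GPU (e.g. derived from seqlens_k by CopyKVCacheProgram's
// prepare_indirect_dispatch pass).
bool ShouldUseIndirectDispatch(const onnxruntime::webgpu::ComputeContext& context) {
  // Under graph capture the CPU should not read tensors back to size a
  // dispatch, so the indirect path is the natural choice.
  return context.IsGraphCaptureEnabled();
}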