
Commit 2132530

[webgpu] Unify the present_sequence_length in flash attention (#25945)
### Description

This PR unifies the present_sequence_length in flash attention and removes the dependency on total_sequence_length. This is preparation to support graph capture. #25868
1 parent 5d17734 commit 2132530
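
Why this matters for graph capture: a captured graph replays recorded dispatches with fixed uniform values, so anything fed to a shader must be stable across decode steps. `parameters.total_sequence_length_` grows by one token per step, while a statically allocated KV cache keeps one shape for its whole lifetime. Below is a minimal sketch of the idea, not code from this commit: a plain `std::vector` stands in for ORT's tensor shape, and the dimensions are hypothetical.

```cpp
// Sketch only: illustrates why the commit derives present_sequence_length
// from the KV-cache tensor shape instead of from per-step parameters.
#include <cstdint>
#include <vector>

// Stand-in for onnxruntime::TensorShape (assumption, not the real type).
using Shape = std::vector<int64_t>;

// present_key layout: (batch_size, num_heads, max_sequence_length, head_size).
// For a static KV cache, dim 2 is the max sequence length and is identical on
// every decode step, so a captured graph can safely reuse it.
uint32_t PresentSequenceLength(const Shape& present_key_shape) {
  return static_cast<uint32_t>(present_key_shape[2]);
}

int main() {
  const Shape present_key_shape = {1, 32, 4096, 128};  // hypothetical dims
  const uint32_t present_sequence_length = PresentSequenceLength(present_key_shape);
  // present_sequence_length == 4096 regardless of how many tokens have been
  // decoded, unlike parameters.total_sequence_length_, which changes per step.
  return present_sequence_length == 4096 ? 0 : 1;
}
```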

File tree

1 file changed: +15 −8 lines


onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Lines changed: 15 additions & 8 deletions
@@ -164,7 +164,9 @@ Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader)
 
 Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q,
                                       const Tensor* attention_bias, Tensor* output, Tensor* present_key, Tensor* metadata,
-                                      const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size) {
+                                      const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile,
+                                      uint32_t num_present_sequence_length_tile, uint32_t tile_size,
+                                      uint32_t present_sequence_length) {
   const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
                                                 : parameters.scale_;
 
@@ -187,8 +189,7 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
       .AddUniformVariables({{static_cast<uint32_t>(vectorized_head_size)},
                             {static_cast<uint32_t>(parameters.total_sequence_length_)},
                             {static_cast<float>(alpha)},
-                            // present_sequence_length is used to index into the KV cache, for static kv cache it is the max sequence length.
-                            {static_cast<uint32_t>(parameters.is_gqa_ ? parameters.seqlen_present_kv_cache_ : parameters.total_sequence_length_)},
+                            present_sequence_length,
                             {static_cast<uint32_t>(parameters.n_reps)},
                             {num_total_seq_length_tile},
                             {num_present_sequence_length_tile},
@@ -220,7 +221,8 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
                                                const WebgpuAttentionParameters& parameters,
                                                uint32_t num_total_seq_length_tile,
                                                uint32_t num_present_sequence_length_tile,
-                                               uint32_t tile_size) {
+                                               uint32_t tile_size,
+                                               uint32_t present_sequence_length) {
   const int components = 4;
   int head_size_vec = parameters.v_head_size_ / components;
   FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec};
@@ -233,7 +235,7 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
       .SetWorkgroupSize(64)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.total_sequence_length_)},
                             {static_cast<uint32_t>(head_size_vec)},
-                            {static_cast<uint32_t>(parameters.is_gqa_ ? parameters.seqlen_present_kv_cache_ : parameters.total_sequence_length_)},
+                            present_sequence_length,
                             {static_cast<uint32_t>(parameters.n_reps)},
                             num_total_seq_length_tile,
                             num_present_sequence_length_tile,
@@ -279,7 +281,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
                            Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
                            const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
   ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value));
-  const int present_sequence_length = parameters.is_gqa_ ? parameters.seqlen_present_kv_cache_ : parameters.total_sequence_length_;
+
+  // Extract present_sequence_length directly from present_key tensor shape:
+  // (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size)
+  const uint32_t present_sequence_length = static_cast<uint32_t>(present_key->Shape()[2]);
   if (parameters.sequence_length_ > 1) {
     const uint32_t tile_size = 64;
     bool has_attention_bias = attention_bias != nullptr;
@@ -332,12 +337,14 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
     const TensorShape metadata_shape(metadata_dims);
     Tensor metadata = context.CreateGPUTensor(DataTypeImpl::GetType<float>(), metadata_shape);
     ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKT(context, Q, attention_bias, &qk, present_key, &metadata,
-                                                       parameters, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size));
+                                                       parameters, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size,
+                                                       present_sequence_length));
 
     const TensorShapeVector out_split_vx_dims({parameters.batch_size_, parameters.num_heads_, num_present_sequence_length_tile, parameters.head_size_});
     const TensorShape out_split_vx_shape(out_split_vx_dims);
     Tensor out_split_vx = context.CreateGPUTensor(Q->DataType(), out_split_vx_shape);
-    ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value, parameters, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size));
+    ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value, parameters,
+                                                                num_total_seq_length_tile, num_present_sequence_length_tile, tile_size, present_sequence_length));
     ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, parameters, num_total_seq_length_tile, num_present_sequence_length_tile));
 
     return Status::OK();
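
The resulting decode path computes the value once in `ApplyFlashAttention` and passes the same number to both split-stage programs. Below is a stubbed sketch of that call flow; the real functions take ORT tensors and a `ComputeContext`, and everything except the one parameter this commit threads through is omitted for illustration.

```cpp
// Simplified sketch of the decode call flow after this commit; all functions
// here are stubs, only the threading of present_sequence_length mirrors the
// real flash_attention.cc code.
#include <cstdint>
#include <cstdio>

namespace sketch {

void ComputeFlashAttentionDecodeQKT(uint32_t present_sequence_length) {
  std::printf("QKT: present_sequence_length=%u\n", present_sequence_length);
}

void ComputeFlashAttentionDecodeSplitVxScore(uint32_t present_sequence_length) {
  std::printf("SplitVx: present_sequence_length=%u\n", present_sequence_length);
}

void ComputeFlashAttentionDecodeVxReduce() {
  std::printf("VxReduce\n");
}

}  // namespace sketch

int main() {
  // Stand-in for present_key->Shape()[2]: computed once per call to
  // ApplyFlashAttention, constant across decode steps for a static KV cache.
  const uint32_t present_sequence_length = 4096;  // hypothetical value

  // Both decode stages receive the same shape-derived value; no stage falls
  // back to parameters.total_sequence_length_ on its own anymore.
  sketch::ComputeFlashAttentionDecodeQKT(present_sequence_length);
  sketch::ComputeFlashAttentionDecodeSplitVxScore(present_sequence_length);
  sketch::ComputeFlashAttentionDecodeVxReduce();
  return 0;
}
```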
