@@ -164,7 +164,9 @@ Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader)
Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q,
                                      const Tensor* attention_bias, Tensor* output, Tensor* present_key, Tensor* metadata,
-                                      const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size) {
+                                      const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile,
+                                      uint32_t num_present_sequence_length_tile, uint32_t tile_size,
+                                      uint32_t present_sequence_length) {
  const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
                                                : parameters.scale_;
@@ -187,8 +189,7 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
      .AddUniformVariables({{static_cast<uint32_t>(vectorized_head_size)},
                            {static_cast<uint32_t>(parameters.total_sequence_length_)},
                            {static_cast<float>(alpha)},
-                            // present_sequence_length is used to index into the KV cache, for static kv cache it is the max sequence length.
-                            {static_cast<uint32_t>(parameters.is_gqa_ ? parameters.seqlen_present_kv_cache_ : parameters.total_sequence_length_)},
+                            present_sequence_length,
                            {static_cast<uint32_t>(parameters.n_reps)},
                            {num_total_seq_length_tile},
                            {num_present_sequence_length_tile},
@@ -220,7 +221,8 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
                                               const WebgpuAttentionParameters& parameters,
                                               uint32_t num_total_seq_length_tile,
                                               uint32_t num_present_sequence_length_tile,
-                                               uint32_t tile_size) {
+                                               uint32_t tile_size,
+                                               uint32_t present_sequence_length) {
  const int components = 4;
  int head_size_vec = parameters.v_head_size_ / components;
  FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec};
@@ -233,7 +235,7 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
      .SetWorkgroupSize(64)
      .AddUniformVariables({{static_cast<uint32_t>(parameters.total_sequence_length_)},
                            {static_cast<uint32_t>(head_size_vec)},
-                            {static_cast<uint32_t>(parameters.is_gqa_ ? parameters.seqlen_present_kv_cache_ : parameters.total_sequence_length_)},
+                            present_sequence_length,
                            {static_cast<uint32_t>(parameters.n_reps)},
                            num_total_seq_length_tile,
                            num_present_sequence_length_tile,
@@ -279,7 +281,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
                           Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
                           const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
  ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value));
-  const int present_sequence_length = parameters.is_gqa_ ? parameters.seqlen_present_kv_cache_ : parameters.total_sequence_length_;
+
+  // Extract present_sequence_length directly from present_key tensor shape:
+  // (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size)
+  const uint32_t present_sequence_length = static_cast<uint32_t>(present_key->Shape()[2]);
  if (parameters.sequence_length_ > 1) {
    const uint32_t tile_size = 64;
    bool has_attention_bias = attention_bias != nullptr;
@@ -332,12 +337,14 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
  const TensorShape metadata_shape(metadata_dims);
  Tensor metadata = context.CreateGPUTensor(DataTypeImpl::GetType<float>(), metadata_shape);
  ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKT(context, Q, attention_bias, &qk, present_key, &metadata,
-                                                     parameters, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size));
+                                                     parameters, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size,
+                                                     present_sequence_length));

  const TensorShapeVector out_split_vx_dims({parameters.batch_size_, parameters.num_heads_, num_present_sequence_length_tile, parameters.head_size_});
  const TensorShape out_split_vx_shape(out_split_vx_dims);
  Tensor out_split_vx = context.CreateGPUTensor(Q->DataType(), out_split_vx_shape);
-  ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value, parameters, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size));
+  ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value, parameters,
+                                                              num_total_seq_length_tile, num_present_sequence_length_tile, tile_size, present_sequence_length));
  ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, parameters, num_total_seq_length_tile, num_present_sequence_length_tile));

  return Status::OK();
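
The substance of the change: `present_sequence_length` is now read once from dim 2 of the `present_key` shape in `ApplyFlashAttention` and threaded through to the decode kernels, replacing the `is_gqa_ ? seqlen_present_kv_cache_ : total_sequence_length_` ternary previously repeated at each call site. Below is a minimal standalone sketch of the shape convention this relies on; the helper name and shape values are hypothetical illustrations, not part of this PR.

```cpp
// Standalone sketch, not ORT code: models the layout the diff depends on.
// present_key is assumed to be (batch_size, num_heads, sequence, head_size),
// where dim 2 is the cache capacity (max sequence length) for a static GQA
// KV cache and the total sequence length so far otherwise.
#include <cassert>
#include <cstdint>
#include <vector>

uint32_t PresentSequenceLength(const std::vector<int64_t>& present_key_dims) {
  assert(present_key_dims.size() == 4);
  return static_cast<uint32_t>(present_key_dims[2]);  // mirrors present_key->Shape()[2]
}

int main() {
  // Hypothetical shapes: static GQA cache reports its capacity in dim 2 ...
  assert(PresentSequenceLength({1, 32, 2048, 128}) == 2048u);
  // ... while a growing cache reports the total sequence length.
  assert(PresentSequenceLength({1, 32, 17, 128}) == 17u);
  return 0;
}
```

Deriving the length from the tensor that `CopyKVCache` actually produced keeps the uniform value consistent with the allocated cache by construction, rather than duplicating the GQA/non-GQA distinction in three places.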