
Commit 156d521

optimize gqa cpu (#20598)
### Description

Optimize the GQA implementation on CPU. The main optimizations are:

1. Compute attention over the real total sequence length instead of the maximum sequence length when past/present share the same buffer.
2. Remove the mask buffer.
3. Remove the transpose after the attention x value matmul.

This improves the phi3 model (https://github.com/microsoft/onnxruntime-genai/blob/main/examples/python/phi3-qa.py) with max sequence length 2k/4k from 10 tps to 20 tps.
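To illustrate item 1, here is a minimal sketch (illustrative only, not the ONNX Runtime kernel; `ScoresForValidKeys` and its parameters are made up for this example): when past and present share a preallocated KV buffer, only the first `seqlens_k[b] + 1` key rows are real, so the Q x K' scores only need to cover those rows instead of the full buffer width.

```cpp
// Sketch only: restrict the score computation to the valid key rows.
#include <cstdio>
#include <vector>

// Hypothetical helper: scores[j] = scale * dot(q, key_cache row j) for the
// first total_seqlen rows only; padded rows of the buffer are never touched.
void ScoresForValidKeys(const float* q, const float* key_cache, int head_size,
                        int total_seqlen, float scale, float* scores) {
  for (int j = 0; j < total_seqlen; ++j) {
    float dot = 0.f;
    for (int h = 0; h < head_size; ++h) {
      dot += q[h] * key_cache[j * head_size + h];
    }
    scores[j] = scale * dot;
  }
}

int main() {
  const int head_size = 4;
  const int buffer_len = 8;    // KV cache preallocated for the max sequence length
  const int total_seqlen = 3;  // real length: seqlens_k[b] + 1
  std::vector<float> q(head_size, 1.f);
  std::vector<float> key_cache(buffer_len * head_size, 1.f);
  std::vector<float> scores(buffer_len, 0.f);

  ScoresForValidKeys(q.data(), key_cache.data(), head_size, total_seqlen, 0.5f, scores.data());

  for (int j = 0; j < buffer_len; ++j) std::printf("%.1f ", scores[j]);  // trailing zeros stay zero
  std::printf("\n");
  return 0;
}
```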
1 parent 1f50921 commit 156d521

3 files changed: +46 -97 lines changed

onnxruntime/contrib_ops/cpu/bert/attention_helper.h

Lines changed: 0 additions & 43 deletions
@@ -153,49 +153,6 @@ void PrepareMask(const int32_t* mask_index,
   }
 }
 
-// Applies causal mask and past seqlens right pad to the mask_data buffer.
-template <typename T>
-void PrepareMaskGQA(T* mask_data,
-                    int batch_size,
-                    int sequence_length,
-                    int buffer_sequence_length,
-                    int local_window_size,
-                    const int32_t* seqlens_k) {
-  // mask_data has been filled with 0, and its shape is BxSxT
-  T* p_mask = mask_data;
-  // TODO: parallelize this
-  for (int b_i = 0; b_i < batch_size; b_i++) {
-    if (sequence_length > 1) {
-      // Apply causal/local mask for prompt case.
-      for (int s_i = 0; s_i < sequence_length; s_i++) {
-        for (int m_i = s_i + 1; m_i < buffer_sequence_length; m_i++) {
-          p_mask[s_i * buffer_sequence_length + m_i] = std::numeric_limits<T>::lowest();
-        }
-        // Apply local mask.
-        if (local_window_size > 0) {
-          for (int m_i = 0; m_i < s_i - local_window_size; m_i++) {
-            p_mask[s_i * buffer_sequence_length + m_i] = std::numeric_limits<T>::lowest();
-          }
-        }
-      }
-    } else if (sequence_length == 1) {
-      // Apply right padding to mask for token gen case.
-      int total_seqlen = seqlens_k[b_i] + 1;
-      for (int m_i = total_seqlen; m_i < buffer_sequence_length; m_i++) {
-        p_mask[m_i] = std::numeric_limits<T>::lowest();
-      }
-      // Apply local mask.
-      if (local_window_size > 0) {
-        for (int m_i = 0; m_i < total_seqlen - local_window_size - 1; m_i++) {
-          p_mask[m_i] = std::numeric_limits<T>::lowest();
-        }
-      }
-    }
-    ptrdiff_t mask_to_advance = SafeInt<ptrdiff_t>(sequence_length) * buffer_sequence_length;
-    p_mask += mask_to_advance;
-  }
-}
-
 // Concatenate a past state chunk PxH with input state chunk LxH into present state chunk TxH
 // Returns a pointer to the start of present state chunk.
 template <typename T>
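The deleted `PrepareMaskGQA` filled a full BxSxT buffer with `std::numeric_limits<T>::lowest()` to mask out non-causal and padded positions. The replacement (see the gqa_attention_base.h diff below) instead applies softmax per score row over only the causally valid prefix and zeroes the tail. A minimal sketch of that idea, assuming a hypothetical `CausalSoftmaxRow` helper rather than ORT's `ComputeAttentionSoftmaxInplace`:

```cpp
// Sketch only: per-row causal softmax without an explicit -inf mask buffer.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical helper: softmax over scores[0, valid_len), then zero the tail
// [valid_len, total_len). The zeros contribute nothing to the later probs x V
// matmul, so no additive mask is needed before the softmax.
void CausalSoftmaxRow(float* scores, int valid_len, int total_len) {
  const float max_val = *std::max_element(scores, scores + valid_len);
  float sum = 0.f;
  for (int i = 0; i < valid_len; ++i) {
    scores[i] = std::exp(scores[i] - max_val);
    sum += scores[i];
  }
  for (int i = 0; i < valid_len; ++i) scores[i] /= sum;
  std::fill(scores + valid_len, scores + total_len, 0.f);
}

int main() {
  // One score row of a buffer 6 entries wide, with only 3 causally valid entries.
  std::vector<float> row = {0.5f, 1.5f, -0.2f, 9.f, 9.f, 9.f};
  CausalSoftmaxRow(row.data(), /*valid_len=*/3, /*total_len=*/6);
  for (float v : row) std::printf("%.4f ", v);  // first 3 sum to 1, rest are 0
  std::printf("\n");
  return 0;
}
```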

onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h

Lines changed: 44 additions & 49 deletions
@@ -55,12 +55,6 @@ class GQAAttentionBase : public AttentionBase {
     auto attention_probs = allocator->Alloc(bytes);
     BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));
 
-    void* mask_data = nullptr;
-    size_t mask_data_bytes = SafeInt<size_t>(batch_size) * sequence_length * seqlen_present_kv_cache * sizeof(T);
-    mask_data = allocator->Alloc(mask_data_bytes);
-    memset(mask_data, 0, mask_data_bytes);
-    BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator));
-
     const T* past_key_data = past_key != nullptr ? past_key->Data<T>() : nullptr;
     T* present_key_data = present_key != nullptr ? present_key->MutableData<T>() : nullptr;
     const T* past_value_data = past_value != nullptr ? past_value->Data<T>() : nullptr;
@@ -70,17 +64,13 @@ class GQAAttentionBase : public AttentionBase {
 
     const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
     ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, k,
-                             seqlens_k->Data<int32_t>(), static_cast<T*>(mask_data),
+                             seqlens_k->Data<int32_t>(),
                              batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache,
                              head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, tp);
 
-    // Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
-    auto out_tmp_data =
-        allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * head_size * sizeof(T));
-    BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(allocator));
-
+    // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
     const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
-    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data), static_cast<T*>(attention_probs),
+    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs),
                             v, seqlens_k->Data<int32_t>(), batch_size, sequence_length, seqlen_past_kv_cache,
                             seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data,
                             past_present_share_buffer, packed_qkv, tp);
@@ -90,15 +80,13 @@ class GQAAttentionBase : public AttentionBase {
 
  private:
   // Helper function to compute the attention probs. It does 2 things:
-  //  attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) +
-  //                                1 x mask_data(B, N, S, T)
+  //  attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T)
   //  attention_probs(B, N, S, T) = Softmax(attention_probs)
   template <typename T>
   void ComputeAttentionProbs(T* attention_probs,               // output buffer with size BxNxSxT
                              const T* Q,                       // Q data. Its size is BxNxSxH
                              const T* K,                       // k data. Its size is BxNxLxH
                              const int32_t* seqlens_k,         // past sequence lengths tensor
-                             T* mask_data,                     // buffer for mask data.
                              int batch_size,                   // batch size of self-attention
                              int sequence_length,              // sequence length of self-attention (S)
                              int past_buffer_sequence_length,  // sequence length of past state
@@ -117,7 +105,6 @@ class GQAAttentionBase : public AttentionBase {
     const size_t past_buff_chunk_length = static_cast<size_t>(past_buffer_sequence_length) * head_size;        // L x H
     const size_t present_buff_chunk_length = static_cast<size_t>(present_buffer_sequence_length) * head_size;  // T x H
 
-    PrepareMaskGQA(mask_data, batch_size, sequence_length, present_buffer_sequence_length, local_window_size_, seqlens_k);
     if (!past_present_share_buffer) {
       memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
     }
@@ -146,16 +133,11 @@ class GQAAttentionBase : public AttentionBase {
         const int head_index = static_cast<int>(i) % num_heads_;
         const int past_seqlen = sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
         const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
+        const int total_seqlen = seqlens_k[batch_index] + 1;
 
         const int output_offset = static_cast<int>(i) * sequence_length * present_buffer_sequence_length;
-        const int mask_offset = batch_index * sequence_length * present_buffer_sequence_length;
         T* output = attention_probs + output_offset;
 
-        // Broadcast mask data: (Bx)SxT -> (BxNx)SxT
-        memcpy(output,
-               mask_data + mask_offset,
-               probs_matrix_bytes);
-
         const T* k;
         if (packed_qkv) {
           k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
@@ -168,7 +150,6 @@ class GQAAttentionBase : public AttentionBase {
                                        i / kv_num_heads_factor);
         }
 
-        // TODO: CblasTrans stuff what do?
         // Compute Q*K' + AttentionMask
         //                     original                 transposed             each iteration
         // A: Q                (B x N x) S x H          (B x N x) S x H        S x H
@@ -180,20 +161,38 @@ class GQAAttentionBase : public AttentionBase {
         } else {
           q = Q + q_input_chunk_length * i;
         }
-        math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, present_buffer_sequence_length, head_size, alpha,
-                                  q, k, mask_data != nullptr ? 1.0f : 0.0f, output, nullptr);
+        math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasTrans,
+                                    sequence_length, total_seqlen, head_size, alpha,
+                                    q, head_size, k, head_size,
+                                    0.0f /*bata*/,
+                                    output, present_buffer_sequence_length, nullptr);
+
+        // compute Softmax
+        T* output_softmax = output;
+        for (int seq = 0; seq < sequence_length; seq++) {
+          int seq_causal_length = sequence_length == 1 ? total_seqlen : seq + 1;
+          if (local_window_size_ > 0 && seq_causal_length > local_window_size_ + 1) {
+            for (int total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) {
+              output_softmax[total_seq_id] = 0.f;
+            }
+            ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, local_window_size_ + 1, nullptr);
+          } else {
+            ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr);
+          }
+
+          // set causal [seq_causal_length, total_seqlen) to 0.f
+          for (int total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) {
+            output_softmax[total_seq_id] = 0.f;
+          }
+
+          output_softmax += present_buffer_sequence_length;
+        }
       }
     });
-
-    // attention_probs(B, N, S, T) = Softmax(attention_probs)
-    const int N = batch_size * num_heads_ * sequence_length;
-    const int D = present_buffer_sequence_length;
-    ComputeAttentionSoftmaxInplace(attention_probs, N, D, tp);
   }
 
   template <typename T>
   void ComputeVxAttentionScore(T* output,                 // buffer for the result with size BxSxNxH
-                               T* tmp_buffer,             // buffer for temp use with size is BxNxSxH
                                const T* attention_probs,  // Attention probs with size BxNxSxT
                                const T* V,                // V value with size BxN_kvxSxH
                                const int32_t* seqlens_k,  // past sequence lengths tensor
@@ -211,8 +210,7 @@ class GQAAttentionBase : public AttentionBase {
     const bool is_prompt = sequence_length != 1;
     const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0;
     const int kv_num_heads_factor = num_heads_ / kv_num_heads_;
-    const size_t q_input_chunk_length = static_cast<size_t>(sequence_length) * head_size;   // S x H
-    const size_t kv_input_chunk_length = static_cast<size_t>(sequence_length) * head_size;  // L x H
+    const int kv_input_chunk_length = sequence_length * head_size;  // L x H
     const size_t past_buff_chunk_length = static_cast<size_t>(past_buffer_sequence_length) * head_size;        // L x H
     const size_t present_buff_chunk_length = static_cast<size_t>(present_buffer_sequence_length) * head_size;  // T x H
 
@@ -243,6 +241,7 @@ class GQAAttentionBase : public AttentionBase {
         const int head_index = static_cast<int>(i % num_heads_);
         const int past_seqlen = sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
         const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
+        const int total_seqlen = seqlens_k[batch_index] + 1;
 
         const T* v;
         if (packed_qkv) {
@@ -256,21 +255,17 @@ class GQAAttentionBase : public AttentionBase {
                                        i / kv_num_heads_factor);
         }
 
-        T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * static_cast<int>(i);
-        const int attention_probs_offset = sequence_length * present_buffer_sequence_length * static_cast<int>(i);
-        math::MatMul<T>(sequence_length, head_size, present_buffer_sequence_length,
-                        attention_probs + attention_probs_offset,
-                        v, current_tmp_data, nullptr);
-
-        // Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v)
-        T* src = current_tmp_data;
-        const int dest_offset = (batch_index * sequence_length * num_heads_ + head_index) * head_size;
-        T* dest = output + dest_offset;
-        for (int j = 0; j < sequence_length; j++) {
-          memcpy(dest, src, bytes_to_copy_trans);
-          src += head_size;
-          dest += hidden_size;
-        }
+        T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
+        ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;
+
+        math::GemmEx<T, ThreadPool>(CblasNoTrans,
+                                    CblasNoTrans,
+                                    sequence_length, head_size, total_seqlen,
+                                    1.f, /*alpha*/
+                                    attention_probs + attention_probs_offset, present_buffer_sequence_length,
+                                    v, head_size,
+                                    0.0f /*beta*/,
+                                    output_current, hidden_size, nullptr);
       }
     });
   }
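Item 3 of the description (removing the transpose) comes from the second `GemmEx` call above: giving the output matrix a leading dimension of `hidden_size` writes each head's S x H block directly into its slot of the B x S x N x H output, so the old `out_tmp` buffer and the memcpy-based transpose are no longer needed. A rough sketch of the layout trick using a naive GEMM (a stand-in, not ORT's `math::GemmEx`):

```cpp
// Sketch only: a strided-output GEMM writes each head straight into BxSxNxH.
#include <cstdio>
#include <vector>

// Hypothetical naive GEMM: C(M x N) = A(M x K) * B(K x N), with leading
// dimensions lda/ldb/ldc so C can be a strided view into a larger buffer.
void NaiveGemm(int M, int N, int K, const float* A, int lda, const float* B,
               int ldb, float* C, int ldc) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[m * lda + k] * B[k * ldb + n];
      C[m * ldc + n] = acc;
    }
  }
}

int main() {
  const int seq_len = 2, total_seqlen = 3, head_size = 2, num_heads = 2;
  const int hidden_size = num_heads * head_size;
  // For brevity, both heads share the same probs and V values here.
  std::vector<float> probs(seq_len * total_seqlen, 0.5f);  // S x T attention probs
  std::vector<float> v(total_seqlen * head_size, 1.f);     // T x H values
  std::vector<float> out(seq_len * hidden_size, 0.f);      // S x (N*H) output, no temp buffer

  for (int h = 0; h < num_heads; ++h) {
    // ldc = hidden_size drops head h's S x H result into columns [h*H, (h+1)*H).
    NaiveGemm(seq_len, head_size, total_seqlen, probs.data(), total_seqlen,
              v.data(), head_size, out.data() + h * head_size, hidden_size);
  }

  for (float x : out) std::printf("%.1f ", x);  // already in S x N x H order
  std::printf("\n");
  return 0;
}
```

Because each head writes with stride `hidden_size`, the heads end up interleaved in memory exactly as the BxSxNxH output layout expects, which is why no separate reordering pass is required.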

onnxruntime/test/python/transformers/test_gqa_cpu.py

Lines changed: 2 additions & 5 deletions
@@ -1294,7 +1294,7 @@ def parity_check_gqa_prompt_no_buff(
         None,
         cos,
         sin,
-        cache_seqlens,
+        cache_seqlens - 1,
         left_window_size,
         past_format,
         False,
@@ -1310,7 +1310,7 @@ def parity_check_gqa_prompt_no_buff(
         new_v,
         cos,
         sin,
-        cache_seqlens,
+        cache_seqlens - 1,
         left_window_size,
         past_format,
         False,
@@ -1766,9 +1766,6 @@ def test_gqa_no_past(self):
         seqs = (
             [
                 (127, 127),
-                (35, 35),
-                (2000, 2000),
-                (200, 200),
                 (240, 240),
             ]
             if pipeline_mode
