
Commit d530b29

Fix Attention GQA implementation on CPU (#25966)
### Description

The Attention operator on CPU follows the ONNX specification. This change replicates the updates introduced by onnx/onnx#7274: for GQA, query heads now map to KV heads in contiguous blocks rather than round-robin, and masked logits are filled with negative infinity instead of the type's lowest finite value.
1 parent 2132530 · commit d530b29

5 files changed: +129 −57 lines

onnxruntime/core/providers/cpu/llm/attention.cc

Lines changed: 18 additions & 17 deletions
@@ -62,14 +62,14 @@ void make_copy<MLFloat16, MLFloat16>(MLFloat16* mask_data, const MLFloat16* mask
 template <>
 void make_copy<float, bool>(float* mask_data, const bool* mask_index, size_t size) {
   for (size_t i = 0; i < size; ++i) {
-    mask_data[i] = mask_index[i] ? 0.0f : std::numeric_limits<float>::lowest();
+    mask_data[i] = mask_index[i] ? 0.0f : negative_infinity<float>();
   }
 }
 
 template <>
 void make_copy<MLFloat16, bool>(MLFloat16* mask_data, const bool* mask_index, size_t size) {
   for (size_t i = 0; i < size; ++i) {
-    mask_data[i] = mask_index[i] ? MLFloat16(0.f) : std::numeric_limits<MLFloat16>::lowest();
+    mask_data[i] = mask_index[i] ? MLFloat16(0.f) : negative_infinity<MLFloat16>();
   }
 }
 
@@ -251,7 +251,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
     mask_data = static_cast<T*>(allocated_ptr);
     for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) {
       for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) {
-        mask_data[s_i * parameters.total_sequence_length + m_i] = std::numeric_limits<T>::lowest();
+        mask_data[s_i * parameters.total_sequence_length + m_i] = negative_infinity<T>();
       }
     }
     delete_mask_data = true;
@@ -277,7 +277,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
     for (int i = 0; i < n_iter; ++i) {
       for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) {
         for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) {
-          mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = std::numeric_limits<T>::lowest();
+          mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = negative_infinity<T>();
         }
       }
     }
@@ -332,7 +332,8 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
     }
 
     // handling GQA
-    std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_i % parameters.kv_num_heads;
+    std::ptrdiff_t head_ki = head_i * parameters.kv_num_heads / parameters.q_num_heads;
+    std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_ki;
     const T* k = K + k_input_chunk_length * ki;
 
     if (nullptr != present_key) {
@@ -362,7 +363,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
                             alpha,
                             Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size,
                             parameters.head_size * parameters.q_num_heads,  // lda
-                            transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + (head_i % parameters.kv_num_heads) * parameters.head_size : k,
+                            transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k,
                             transposed_k ? parameters.head_size * parameters.kv_num_heads : parameters.head_size,  // ldb
                             beta,
                             output,
@@ -568,7 +569,8 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output, // bu
     // handling GQA
     std::ptrdiff_t batch_i = i / num_heads;
     std::ptrdiff_t head_i = i % num_heads;
-    std::ptrdiff_t vi = batch_i * kv_num_heads + head_i % kv_num_heads;
+    std::ptrdiff_t head_vi = head_i * kv_num_heads / num_heads;
+    std::ptrdiff_t vi = batch_i * kv_num_heads + head_vi;
     const T* v = V + v_input_chunk_length * vi;
 
     if (nullptr != present_value) {
@@ -592,16 +594,15 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output, // bu
     // V is transposed but not QK. We use GemmEx with a different value for ldb.
     math::GemmEx<T, ThreadPool>(CblasNoTrans,
                                 CblasNoTrans,
-                                sequence_length,                           // M
-                                v_head_size,                               // N
-                                total_sequence_length,                     // K
-                                1.f,                                       // alpha
-                                attention_probs + attention_probs_offset,  // QK
-                                total_sequence_length,                     // lda
-                                transposed_v ? V + (head_i % kv_num_heads) * v_head_size + v_input_chunk_length * kv_num_heads * batch_i
-                                             : v,
-                                transposed_v ? v_head_size * kv_num_heads : v_head_size,  // ldb
-                                0.f,                                       // beta
+                                sequence_length,                           // M
+                                v_head_size,                               // N
+                                total_sequence_length,                     // K
+                                1.f,                                       // alpha
+                                attention_probs + attention_probs_offset,  // QK
+                                total_sequence_length,                     // lda
+                                transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v,  // V
+                                transposed_v ? v_head_size * kv_num_heads : v_head_size,  // ldb
+                                0.f,                                       // beta
                                 output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size),
                                 v_head_size * num_heads,  // ldc
                                 nullptr);
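The substantive change in this file is the query-head to KV-head mapping for grouped-query attention: `head_i % kv_num_heads` (round-robin) becomes `head_i * kv_num_heads / q_num_heads` (contiguous blocks), matching the ONNX reference after onnx/onnx#7274. A minimal standalone sketch of the two mappings follows; the head counts are illustrative assumptions, not values from this PR:

```cpp
// Standalone sketch (not ORT code) contrasting the old and new GQA mappings.
// q_num_heads = 8 and kv_num_heads = 2 are illustrative assumptions.
#include <cstdio>

int main() {
  const int q_num_heads = 8;
  const int kv_num_heads = 2;
  for (int head_i = 0; head_i < q_num_heads; ++head_i) {
    const int old_ki = head_i % kv_num_heads;                // round-robin: 0,1,0,1,0,1,0,1
    const int new_ki = head_i * kv_num_heads / q_num_heads;  // blocked:     0,0,0,0,1,1,1,1
    std::printf("query head %d -> kv head: old=%d new=%d\n", head_i, old_ki, new_ki);
  }
  return 0;
}
```

With 8 query heads and 2 KV heads, the old mapping sends query heads 0, 2, 4, 6 to KV head 0, while the new mapping sends the contiguous block 0–3 to KV head 0 and 4–7 to KV head 1.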

onnxruntime/core/providers/cpu/llm/attention.h

Lines changed: 10 additions & 0 deletions
@@ -9,6 +9,16 @@
 
 namespace onnxruntime {
 
+template <typename T>
+inline T negative_infinity() {
+  return -std::numeric_limits<T>::infinity();
+}
+
+template <>
+inline MLFloat16 negative_infinity() {
+  return MLFloat16(-std::numeric_limits<float>::infinity());
+}
+
 template <typename T>
 class AttentionBase : public OpKernel {
  public:
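The new `negative_infinity<T>()` helper, used throughout attention.cc above, replaces `std::numeric_limits<T>::lowest()` as the additive mask value. The practical difference: `-inf` plus any finite bias is still `-inf`, and `exp(-inf)` is exactly zero, so a masked position can never regain softmax weight, whereas `lowest()` is a finite value that a later addition can pull back into range. A small sketch of that behavior, independent of ONNX Runtime (the bias value is an assumption for illustration):

```cpp
// Sketch: why -infinity is a safer additive attention mask than lowest().
#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  const float lowest = std::numeric_limits<float>::lowest();      // ~ -3.4e38, finite
  const float neg_inf = -std::numeric_limits<float>::infinity();
  const float bias = 1e38f;  // illustrative large attention bias (assumption)

  std::printf("lowest + bias = %g (finite: can leak probability)\n", lowest + bias);
  std::printf("-inf + bias   = %g (stays fully masked)\n", neg_inf + bias);
  std::printf("exp(-inf)     = %g (exactly zero softmax weight)\n", std::exp(neg_inf));
  return 0;
}
```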

onnxruntime/test/onnx/main.cc

Lines changed: 18 additions & 0 deletions
@@ -795,6 +795,24 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
 // Please make no more changes to the list
 static const ORTCHAR_T* immutable_broken_tests[] =
     {
+        // pending ONNX update
+        ORT_TSTR("attention_3d_gqa"),
+        ORT_TSTR("attention_3d_gqa_attn_mask"),
+        ORT_TSTR("attention_3d_gqa_causal"),
+        ORT_TSTR("attention_3d_gqa_scaled"),
+        ORT_TSTR("attention_3d_gqa_softcap"),
+        ORT_TSTR("attention_3d_gqa_with_past_and_present"),
+        ORT_TSTR("attention_4d_gqa"),
+        ORT_TSTR("attention_4d_gqa_attn_mask"),
+        ORT_TSTR("attention_4d_gqa_causal"),
+        ORT_TSTR("attention_4d_gqa_scaled"),
+        ORT_TSTR("attention_4d_gqa_softcap"),
+        ORT_TSTR("attention_4d_gqa_with_past_and_present"),
+        ORT_TSTR("attention_4d_diff_heads_mask4d_padded_kv"),
+        ORT_TSTR("attention_4d_gqa_with_past_and_present_fp16"),
+        ORT_TSTR("attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal"),
+        ORT_TSTR("attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal"),
+        // unsupported case
         ORT_TSTR("AvgPool1d"),
         ORT_TSTR("AvgPool1d_stride"),
         ORT_TSTR("AvgPool2d"),
