@@ -62,14 +62,14 @@ void make_copy<MLFloat16, MLFloat16>(MLFloat16* mask_data, const MLFloat16* mask
template <>
void make_copy<float, bool>(float* mask_data, const bool* mask_index, size_t size) {
  for (size_t i = 0; i < size; ++i) {
-    mask_data[i] = mask_index[i] ? 0.0f : std::numeric_limits<float>::lowest();
+    mask_data[i] = mask_index[i] ? 0.0f : negative_infinity<float>();
  }
}

template <>
void make_copy<MLFloat16, bool>(MLFloat16* mask_data, const bool* mask_index, size_t size) {
  for (size_t i = 0; i < size; ++i) {
-    mask_data[i] = mask_index[i] ? MLFloat16(0.f) : std::numeric_limits<MLFloat16>::lowest();
+    mask_data[i] = mask_index[i] ? MLFloat16(0.f) : negative_infinity<MLFloat16>();
  }
}
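Note: the negative_infinity<T>() helper called by the new lines is defined outside this hunk. The following is only a minimal sketch of a shape that would be consistent with how it is used above; the MLFloat16 conversion and where the helper lives are assumptions, and the PR's actual definition may differ.

// Hedged sketch only: one plausible negative_infinity<T>() implementation.
// MLFloat16 comes from ONNX Runtime (include path not shown here; an assumption).
#include <limits>

template <typename T>
T negative_infinity() {
  // Return a true IEEE -inf instead of std::numeric_limits<T>::lowest().
  return -std::numeric_limits<T>::infinity();
}

// Assumption: MLFloat16 is constructible from a float value; depending on the
// MLFloat16 API version the helper might instead build the raw bit pattern 0xFC00.
template <>
MLFloat16 negative_infinity<MLFloat16>() {
  return MLFloat16(-std::numeric_limits<float>::infinity());
}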
@@ -251,7 +251,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
      mask_data = static_cast<T*>(allocated_ptr);
      for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) {
        for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) {
-          mask_data[s_i * parameters.total_sequence_length + m_i] = std::numeric_limits<T>::lowest();
+          mask_data[s_i * parameters.total_sequence_length + m_i] = negative_infinity<T>();
        }
      }
      delete_mask_data = true;
@@ -277,7 +277,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
      for (int i = 0; i < n_iter; ++i) {
        for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) {
          for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) {
-            mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = std::numeric_limits<T>::lowest();
+            mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = negative_infinity<T>();
          }
        }
      }
@@ -332,7 +332,8 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
        }

        // handling GQA
-        std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_i % parameters.kv_num_heads;
+        std::ptrdiff_t head_ki = head_i * parameters.kv_num_heads / parameters.q_num_heads;
+        std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_ki;
        const T* k = K + k_input_chunk_length * ki;

        if (nullptr != present_key) {
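The change above swaps the query-head to KV-head mapping from a modulo to an integer-division grouping. A small standalone sketch (illustrative head counts only, not values taken from this PR) shows the difference:

// Hedged sketch: old vs. new query-head -> KV-head mapping, with illustrative
// counts (q_num_heads = 8, kv_num_heads = 2 are assumptions, not from the PR).
#include <cstdio>

int main() {
  const int q_num_heads = 8;
  const int kv_num_heads = 2;
  for (int head_i = 0; head_i < q_num_heads; ++head_i) {
    const int old_ki = head_i % kv_num_heads;                // interleaved: 0,1,0,1,0,1,0,1
    const int new_ki = head_i * kv_num_heads / q_num_heads;  // grouped:     0,0,0,0,1,1,1,1
    std::printf("q head %d -> kv head: old %d, new %d\n", head_i, old_ki, new_ki);
  }
  return 0;
}

The grouped mapping assigns each block of q_num_heads / kv_num_heads consecutive query heads to one KV head, which is the layout grouped-query attention expects; the same change is applied to the value heads in ComputeVxAttentionScore below.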
@@ -362,7 +363,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
                                    alpha,
                                    Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size,
                                    parameters.head_size * parameters.q_num_heads,  // lda
-                                   transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + (head_i % parameters.kv_num_heads) * parameters.head_size : k,
+                                   transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k,
                                    transposed_k ? parameters.head_size * parameters.kv_num_heads : parameters.head_size,  // ldb
                                    beta,
                                    output,
@@ -568,7 +569,8 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output, // bu
        // handling GQA
        std::ptrdiff_t batch_i = i / num_heads;
        std::ptrdiff_t head_i = i % num_heads;
-        std::ptrdiff_t vi = batch_i * kv_num_heads + head_i % kv_num_heads;
+        std::ptrdiff_t head_vi = head_i * kv_num_heads / num_heads;
+        std::ptrdiff_t vi = batch_i * kv_num_heads + head_vi;
        const T* v = V + v_input_chunk_length * vi;

        if (nullptr != present_value) {
@@ -592,16 +594,15 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output, // bu
        // V is transposed but not QK. We use GemmEx with a different value for ldb.
        math::GemmEx<T, ThreadPool>(CblasNoTrans,
                                    CblasNoTrans,
-                                   sequence_length,                           // M
-                                   v_head_size,                               // N
-                                   total_sequence_length,                     // K
-                                   1.f,                                       // alpha
-                                   attention_probs + attention_probs_offset,  // QK
-                                   total_sequence_length,                     // lda
-                                   transposed_v ? V + (head_i % kv_num_heads) * v_head_size + v_input_chunk_length * kv_num_heads * batch_i
-                                                : v,
-                                   transposed_v ? v_head_size * kv_num_heads : v_head_size,  // ldb
-                                   0.f,                                       // beta
+                                   sequence_length,                                                                               // M
+                                   v_head_size,                                                                                   // N
+                                   total_sequence_length,                                                                         // K
+                                   1.f,                                                                                           // alpha
+                                   attention_probs + attention_probs_offset,                                                      // QK
+                                   total_sequence_length,                                                                         // lda
+                                   transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v,  // V
+                                   transposed_v ? v_head_size * kv_num_heads : v_head_size,                                       // ldb
+                                   0.f,                                                                                           // beta
                                    output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size),
                                    v_head_size * num_heads,  // ldc
                                    nullptr);
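For context on the ldb argument above: when V is in the transposed (interleaved-head) layout, consecutive sequence positions of one KV head are separated by all KV heads, so the row stride is v_head_size * kv_num_heads; in the contiguous per-head layout the rows follow each other and the stride is just v_head_size. The addressing below is a hedged sketch with illustrative sizes, and it assumes v_input_chunk_length equals total_sequence_length * v_head_size; it is not code from this PR.

// Hedged sketch: element (s, h) of KV head head_vi within one batch, for both V layouts.
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t kv_num_heads = 2, v_head_size = 4, total_sequence_length = 3;  // illustrative
  const std::size_t head_vi = 1, s = 1, h = 0;

  // Interleaved layout: base pointer is V + head_vi * v_head_size, and GemmEx walks
  // rows with ldb = v_head_size * kv_num_heads.
  const std::size_t ldb_interleaved = v_head_size * kv_num_heads;
  const std::size_t off_interleaved = head_vi * v_head_size + s * ldb_interleaved + h;

  // Contiguous per-head layout: base pointer is v = V + v_input_chunk_length * vi,
  // and rows are back to back, so ldb = v_head_size.
  const std::size_t v_input_chunk_length = total_sequence_length * v_head_size;  // assumption
  const std::size_t ldb_contiguous = v_head_size;
  const std::size_t off_contiguous = head_vi * v_input_chunk_length + s * ldb_contiguous + h;

  std::printf("interleaved offset: %zu, contiguous offset: %zu\n", off_interleaved, off_contiguous);
  return 0;
}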