@@ -55,12 +55,6 @@ class GQAAttentionBase : public AttentionBase {
 auto attention_probs = allocator->Alloc(bytes);
 BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));

-void* mask_data = nullptr;
-size_t mask_data_bytes = SafeInt<size_t>(batch_size) * sequence_length * seqlen_present_kv_cache * sizeof(T);
-mask_data = allocator->Alloc(mask_data_bytes);
-memset(mask_data, 0, mask_data_bytes);
-BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator));
-
 const T* past_key_data = past_key != nullptr ? past_key->Data<T>() : nullptr;
 T* present_key_data = present_key != nullptr ? present_key->MutableData<T>() : nullptr;
 const T* past_value_data = past_value != nullptr ? past_value->Data<T>() : nullptr;
@@ -70,17 +64,13 @@ class GQAAttentionBase : public AttentionBase {

 const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
 ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, k,
-                         seqlens_k->Data<int32_t>(), static_cast<T*>(mask_data),
+                         seqlens_k->Data<int32_t>(),
                          batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache,
                          head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, tp);

-// Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
-auto out_tmp_data =
-    allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * head_size * sizeof(T));
-BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(allocator));
-
+// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
 const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
-ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data), static_cast<T*>(attention_probs),
+ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs),
                         v, seqlens_k->Data<int32_t>(), batch_size, sequence_length, seqlen_past_kv_cache,
                         seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data,
                         past_present_share_buffer, packed_qkv, tp);
@@ -90,15 +80,13 @@ class GQAAttentionBase : public AttentionBase {

 private:
 // Helper function to compute the attention probs. It does 2 things:
-//  attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) +
-//                                1 x mask_data(B, N, S, T)
+//  attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T)
 //  attention_probs(B, N, S, T) = Softmax(attention_probs)
 template <typename T>
 void ComputeAttentionProbs(T* attention_probs,              // output buffer with size BxNxSxT
                            const T* Q,                       // Q data. Its size is BxNxSxH
                            const T* K,                       // k data. Its size is BxNxLxH
                            const int32_t* seqlens_k,         // past sequence lengths tensor
-                           T* mask_data,                     // buffer for mask data.
                            int batch_size,                   // batch size of self-attention
                            int sequence_length,              // sequence length of self-attention (S)
                            int past_buffer_sequence_length,  // sequence length of past state
@@ -117,7 +105,6 @@ class GQAAttentionBase : public AttentionBase {
 const size_t past_buff_chunk_length = static_cast<size_t>(past_buffer_sequence_length) * head_size;        // L x H
 const size_t present_buff_chunk_length = static_cast<size_t>(present_buffer_sequence_length) * head_size;  // T x H

-PrepareMaskGQA(mask_data, batch_size, sequence_length, present_buffer_sequence_length, local_window_size_, seqlens_k);
 if (!past_present_share_buffer) {
   memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
 }
@@ -146,16 +133,11 @@ class GQAAttentionBase : public AttentionBase {
 const int head_index = static_cast<int>(i) % num_heads_;
 const int past_seqlen = sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
 const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
+const int total_seqlen = seqlens_k[batch_index] + 1;

 const int output_offset = static_cast<int>(i) * sequence_length * present_buffer_sequence_length;
-const int mask_offset = batch_index * sequence_length * present_buffer_sequence_length;
 T* output = attention_probs + output_offset;

-// Broadcast mask data: (Bx)SxT -> (BxNx)SxT
-memcpy(output,
-       mask_data + mask_offset,
-       probs_matrix_bytes);
-
 const T* k;
 if (packed_qkv) {
   k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
@@ -168,7 +150,6 @@ class GQAAttentionBase : public AttentionBase {
   i / kv_num_heads_factor);
 }

-// TODO: CblasTrans stuff what do?
 // Compute Q*K' + AttentionMask
 //                     original                 transposed             each iteration
 // A: Q                (B x N x) S x H          (B x N x) S x H            S x H
@@ -180,20 +161,38 @@ class GQAAttentionBase : public AttentionBase {
 } else {
   q = Q + q_input_chunk_length * i;
 }
-math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, present_buffer_sequence_length, head_size, alpha,
-                          q, k, mask_data != nullptr ? 1.0f : 0.0f, output, nullptr);
+math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasTrans,
+                            sequence_length, total_seqlen, head_size, alpha,
+                            q, head_size, k, head_size,
+                            0.0f /*beta*/,
+                            output, present_buffer_sequence_length, nullptr);
+
+// compute Softmax
+T* output_softmax = output;
+for (int seq = 0; seq < sequence_length; seq++) {
+  int seq_causal_length = sequence_length == 1 ? total_seqlen : seq + 1;
+  if (local_window_size_ > 0 && seq_causal_length > local_window_size_ + 1) {
+    for (int total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) {
+      output_softmax[total_seq_id] = 0.f;
+    }
+    ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, local_window_size_ + 1, nullptr);
+  } else {
+    ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr);
+  }
+
+  // set causal [seq_causal_length, total_seqlen) to 0.f
+  for (int total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) {
+    output_softmax[total_seq_id] = 0.f;
+  }
+
+  output_softmax += present_buffer_sequence_length;
+}
 }
 });
-
-// attention_probs(B, N, S, T) = Softmax(attention_probs)
-const int N = batch_size * num_heads_ * sequence_length;
-const int D = present_buffer_sequence_length;
-ComputeAttentionSoftmaxInplace(attention_probs, N, D, tp);
 }

 template <typename T>
 void ComputeVxAttentionScore(T* output,                  // buffer for the result with size BxSxNxH
-                             T* tmp_buffer,               // buffer for temp use with size is BxNxSxH
                              const T* attention_probs,   // Attention probs with size BxNxSxT
                              const T* V,                  // V value with size BxN_kvxSxH
                              const int32_t* seqlens_k,    // past sequence lengths tensor
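The softmax now runs inside the parallel loop, one query row at a time, which is what makes the separate mask buffer unnecessary: scores that fall outside the sliding window (when local_window_size_ > 0) are zeroed, only the causal slice is normalized, and the non-causal tail is set to zero. A minimal standalone sketch of that per-row treatment, using a naive float softmax and illustrative names rather than the onnxruntime helpers:

#include <algorithm>
#include <cmath>

// Illustrative only: mirrors the per-row masking + softmax added above.
// `row` points at one row of the probs matrix and has `present_seqlen` entries;
// the first `seq_causal_length` of them are valid causal scores.
void MaskedRowSoftmax(float* row, int present_seqlen, int seq_causal_length, int local_window_size) {
  int start = 0;
  if (local_window_size > 0 && seq_causal_length > local_window_size + 1) {
    start = seq_causal_length - local_window_size - 1;
    std::fill(row, row + start, 0.f);  // tokens that fell out of the sliding window
  }
  // Numerically stable softmax over [start, seq_causal_length)
  const float max_val = *std::max_element(row + start, row + seq_causal_length);
  float sum = 0.f;
  for (int i = start; i < seq_causal_length; ++i) {
    row[i] = std::exp(row[i] - max_val);
    sum += row[i];
  }
  for (int i = start; i < seq_causal_length; ++i) {
    row[i] /= sum;
  }
  std::fill(row + seq_causal_length, row + present_seqlen, 0.f);  // future (non-causal) positions
}

Zeroing the [seq_causal_length, total_seqlen) range keeps the subsequent probs x V product correct when it multiplies all total_seqlen probability columns against V.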
@@ -211,8 +210,7 @@ class GQAAttentionBase : public AttentionBase {
 const bool is_prompt = sequence_length != 1;
 const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0;
 const int kv_num_heads_factor = num_heads_ / kv_num_heads_;
-const size_t q_input_chunk_length = static_cast<size_t>(sequence_length) * head_size;   // S x H
-const size_t kv_input_chunk_length = static_cast<size_t>(sequence_length) * head_size;  // L x H
+const int kv_input_chunk_length = sequence_length * head_size;                          // L x H
 const size_t past_buff_chunk_length = static_cast<size_t>(past_buffer_sequence_length) * head_size;        // L x H
 const size_t present_buff_chunk_length = static_cast<size_t>(present_buffer_sequence_length) * head_size;  // T x H

@@ -243,6 +241,7 @@ class GQAAttentionBase : public AttentionBase {
 const int head_index = static_cast<int>(i % num_heads_);
 const int past_seqlen = sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
 const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
+const int total_seqlen = seqlens_k[batch_index] + 1;

 const T* v;
 if (packed_qkv) {
@@ -256,21 +255,17 @@ class GQAAttentionBase : public AttentionBase {
   i / kv_num_heads_factor);
 }

-T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * static_cast<int>(i);
-const int attention_probs_offset = sequence_length * present_buffer_sequence_length * static_cast<int>(i);
-math::MatMul<T>(sequence_length, head_size, present_buffer_sequence_length,
-                attention_probs + attention_probs_offset,
-                v, current_tmp_data, nullptr);
-
-// Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v)
-T* src = current_tmp_data;
-const int dest_offset = (batch_index * sequence_length * num_heads_ + head_index) * head_size;
-T* dest = output + dest_offset;
-for (int j = 0; j < sequence_length; j++) {
-  memcpy(dest, src, bytes_to_copy_trans);
-  src += head_size;
-  dest += hidden_size;
-}
+T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
+ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;
+
+math::GemmEx<T, ThreadPool>(CblasNoTrans,
+                            CblasNoTrans,
+                            sequence_length, head_size, total_seqlen,
+                            1.f, /*alpha*/
+                            attention_probs + attention_probs_offset, present_buffer_sequence_length,
+                            v, head_size,
+                            0.0f /*beta*/,
+                            output_current, hidden_size, nullptr);
 }
 });
 }
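Dropping out_tmp and the transpose loop hinges on the output leading dimension: writing each head's S x H_v block with a row stride equal to hidden_size places the values directly into the B x S x N x H output layout. A naive loop-based sketch of that strided write, as a hypothetical stand-in for the strided GEMM call (names are illustrative, not the onnxruntime API):

// Illustrative only: one head's probs (S x total_seqlen, row stride ld_probs)
// times V (total_seqlen x H, densely packed) written straight into the
// interleaved output at row stride ld_out (= num_heads * H = hidden_size).
void HeadProbsTimesV(const float* probs, int ld_probs,
                     const float* v,
                     float* out, int ld_out,
                     int S, int total_seqlen, int H) {
  for (int s = 0; s < S; ++s) {
    for (int h = 0; h < H; ++h) {
      float acc = 0.f;
      for (int t = 0; t < total_seqlen; ++t) {
        acc += probs[s * ld_probs + t] * v[t * H + h];
      }
      out[s * ld_out + h] = acc;  // lands in the (b, s, n, h) slot, so no transpose pass is needed
    }
  }
}

With out pointing at output + (batch_index * sequence_length * num_heads_ + head_index) * head_size, consecutive s values skip over the other heads' slots, which is exactly what the output leading-dimension argument of the GEMM expresses.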