@@ -138,23 +138,41 @@ inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) {
   }
 }
 
+void reshape_attn_mask_to_4d(
+    Tensor& attn_mask,
+    int64_t batchSize,
+    int64_t num_head,
+    int64_t qSize,
+    int64_t kvSize) {
+  // Support mask shapes:
+  // 2d: ({Q_seq_len, 1} x {KV_seq_len, 1})
+  // 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})
+  // Guaranteed in check_attn_mask_shape
+  int64_t attn_mask_size_0 = 1;
+  int64_t attn_mask_size_1 = 1;
+  if (attn_mask.dim() == 4) {
+    if (attn_mask.size(0) == batchSize) {
+      attn_mask_size_0 = batchSize;
+    }
+    if (attn_mask.size(1) == num_head) {
+      attn_mask_size_1 = num_head;
+    }
+  }
+  attn_mask = attn_mask
+      .view({attn_mask_size_0, attn_mask_size_1, attn_mask.size(-2), attn_mask.size(-1)})
+      .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize});
+}
+
 template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
 void cpu_flash_attention(
     const Tensor& output,
     const Tensor& logsumexp,
-    const Tensor& cum_seq_q,
-    const Tensor& cum_seq_k,
-    int64_t& max_q,
-    int64_t& max_k,
-    const Tensor& philox_seed,
-    const Tensor& philox_offset,
-    const Tensor& debug_attn_mask,
     const at::Tensor& q,
     const at::Tensor& k,
     const at::Tensor& v,
     double dropout_p,
     bool is_causal,
-    bool return_debug_mask,
+    c10::optional<Tensor> attn_mask,
    c10::optional<double> scale) {
   // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
   //    -> (Batch x Q_seq_len x Num_heads x Dim_per_head)
@@ -181,6 +199,14 @@ void cpu_flash_attention(
   int64_t num_head = query.size(2);
   int64_t headSize = query.size(3);
 
+  bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
+  if (has_attn_mask) {
+    if (is_reduced_type) {
+      attn_mask.value() = attn_mask.value().to(at::kFloat);
+    }
+    reshape_attn_mask_to_4d(attn_mask.value(), batchSize, num_head, qSize, kvSize);
+  }
+
   // Strides
   int64_t qStrideB = query.stride(0);
   int64_t qStrideM = query.stride(1);
@@ -197,6 +223,16 @@ void cpu_flash_attention(
   int64_t lStrideB = logsumexp.stride(0);
   int64_t lStrideM = logsumexp.stride(1);
   int64_t lStrideH = logsumexp.stride(2);
+  int64_t mStrideB =
+      (has_attn_mask && attn_mask.value().size(0) > 1)
+      ? attn_mask.value().stride(0)
+      : 0;
+  int64_t mStrideH =
+      (has_attn_mask && attn_mask.value().size(1) > 1)
+      ? attn_mask.value().stride(1)
+      : 0;
+  int64_t mStrideM =
+      has_attn_mask ? attn_mask.value().stride(2) : 0;
 
   int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
   int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
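The mStride* values above come straight from the 4-D view produced by reshape_attn_mask_to_4d: sequence dimensions that are broadcast by expand get stride 0, while size-1 batch/head dimensions keep their original stride and are therefore zeroed explicitly by the size(0) > 1 / size(1) > 1 checks. The following standalone sketch illustrates that behavior; the shapes and variable names are made up for illustration and are not part of the patch.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  int64_t qSize = 8, kvSize = 16;
  // A 2-D mask shared across all query rows: {1, KV_seq_len}.
  at::Tensor mask = at::randn({1, kvSize});
  // Same view + expand that reshape_attn_mask_to_4d performs for a 2-D mask
  // (the batch and head dimensions collapse to size 1).
  at::Tensor mask4d = mask
      .view({1, 1, mask.size(-2), mask.size(-1)})
      .expand({1, 1, qSize, kvSize});
  // The broadcast query dimension reports stride 0, so every row reads the
  // same kvSize values without any copy ...
  std::cout << "stride(-2) = " << mask4d.stride(2) << "\n";  // 0
  // ... while the size-1 batch dimension keeps a non-zero stride, which is
  // why the kernel forces mStrideB/mStrideH to 0 when size(0)/size(1) == 1.
  std::cout << "stride(0)  = " << mask4d.stride(0) << "\n";  // 16
  return 0;
}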
@@ -220,6 +256,9 @@ void cpu_flash_attention(
   scalar_t* q_data = query.data_ptr<scalar_t>();
   scalar_t* k_data = key.data_ptr<scalar_t>();
   scalar_t* v_data = value.data_ptr<scalar_t>();
+  accum_t* mask_data = has_attn_mask
+      ? attn_mask.value().data_ptr<accum_t>()
+      : nullptr;
   scalar_t* out_data = output.data_ptr<scalar_t>();
   accum_t* lse_data = logsumexp.data_ptr<accum_t>();
   accum_t* buf_data = buf.data_ptr<accum_t>();
@@ -275,13 +314,41 @@ void cpu_flash_attention(
                 kvBlockSize - last_col - 1);
           }
         }
+        // Update attention weights with attention mask
+        // And apply scaling factor
+        // qk <- qk * scaling + attn_mask
+        if (has_attn_mask) {
+          for (int64_t row = 0; row < qBlockSize; ++row) {
+            at::vec::map2<accum_t>(
+                [scaling_factor](Vec x, Vec y) {
+                  return x * Vec(scaling_factor) + y;
+                },
+                qk_data + row * kvBlockSize,
+                qk_data + row * kvBlockSize,
+                mask_data + i * mStrideB + j * mStrideH +
+                    (m + row) * mStrideM + n,
+                kvBlockSize);
+          }
+        }
         // Update coefficients with Softmax
         accum_t tmp_max = 0, tmp_sum = 0, sum_old = 0, exp_tmp = 0;
         for (int64_t row = 0; row < qBlockSize; ++row) {
           sum_old = qk_sum_data[row];
-          // scale and max per row
-          _mul_reduce_max_fusion_kernel(qk_data + row * kvBlockSize, scaling_factor, kvBlockSize,
-              qk_data + row * kvBlockSize, tmp_max);
+          if (has_attn_mask) {
+            // max per row
+            tmp_max = at::vec::reduce_all<accum_t>(
+                [](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
+                qk_data + row * kvBlockSize,
+                kvBlockSize);
+          } else {
+            // apply scaling factor and max per row in fusion
+            _mul_reduce_max_fusion_kernel(
+                qk_data + row * kvBlockSize,
+                scaling_factor,
+                kvBlockSize,
+                qk_data + row * kvBlockSize,
+                tmp_max);
+          }
           tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
           // qk <- exp(qk - max) and sum per row
           tmp_sum = tmp_max;
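For reference, the masked branch above is numerically equivalent to the following scalar loop: each raw q @ k.T value in the tile is scaled, the (float) mask value is added, and the per-row maximum is tracked for the numerically stable softmax. This is only an illustrative sketch with standalone names, not code from the patch; the kernel does the same work with the vectorized at::vec::map2 and at::vec::reduce_all calls.

#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

// Scalar sketch of "qk <- qk * scaling_factor + mask" followed by the
// per-row max used by the online softmax. The mask tile is assumed to be
// already gathered with the mStrideB/mStrideH/mStrideM offsets.
void scale_add_mask_rowmax(
    std::vector<float>& qk,          // qBlockSize x kvBlockSize, row-major
    const std::vector<float>& mask,  // same tile shape
    float scaling_factor,
    int64_t qBlockSize,
    int64_t kvBlockSize,
    std::vector<float>& row_max) {   // one entry per query row
  for (int64_t row = 0; row < qBlockSize; ++row) {
    float m = -std::numeric_limits<float>::infinity();
    for (int64_t col = 0; col < kvBlockSize; ++col) {
      float& v = qk[row * kvBlockSize + col];
      v = v * scaling_factor + mask[row * kvBlockSize + col];
      m = std::max(m, v);
    }
    row_max[row] = m;
  }
}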
@@ -354,14 +421,9 @@ void cpu_flash_attention_backward(
     const at::Tensor& value,
     const at::Tensor& out,
     const at::Tensor& logsumexp,
-    const Tensor& cumulative_sequence_length_q,
-    const Tensor& cumulative_sequence_length_k,
-    const int64_t max_seqlen_batch_q,
-    const int64_t max_seqlen_batch_k,
     double dropout_p,
     bool is_causal,
-    const at::Tensor& philox_seed,
-    const at::Tensor& philox_offset,
+    c10::optional<Tensor> attn_mask,
     c10::optional<double> scale) {
   constexpr bool is_reduced_type = is_reduced_floating_point_v<scalar_t>;
   using accum_t = at::opmath_type<scalar_t>;
@@ -381,6 +443,14 @@ void cpu_flash_attention_backward(
   int64_t num_head = query.size(2);
   int64_t headSize = query.size(3);
 
+  bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
+  if (has_attn_mask) {
+    if (is_reduced_type) {
+      attn_mask.value() = attn_mask.value().to(at::kFloat);
+    }
+    reshape_attn_mask_to_4d(attn_mask.value(), batchSize, num_head, qSize, kvSize);
+  }
+
   // Strides
   int64_t qStrideB = query.stride(0);
   int64_t qStrideM = query.stride(1);
@@ -397,6 +467,16 @@ void cpu_flash_attention_backward(
   int64_t lStrideB = logsumexp.stride(0);
   int64_t lStrideM = logsumexp.stride(1);
   int64_t lStrideH = logsumexp.stride(2);
+  int64_t mStrideB =
+      (has_attn_mask && attn_mask.value().size(0) > 1)
+      ? attn_mask.value().stride(0)
+      : 0;
+  int64_t mStrideH =
+      (has_attn_mask && attn_mask.value().size(1) > 1)
+      ? attn_mask.value().stride(1)
+      : 0;
+  int64_t mStrideM =
+      has_attn_mask ? attn_mask.value().stride(2) : 0;
 
   int64_t grad_qStrideB = grad_q.stride(0);
   int64_t grad_qStrideM = grad_q.stride(1);
@@ -440,6 +520,9 @@ void cpu_flash_attention_backward(
   scalar_t* q_data = query.data_ptr<scalar_t>();
   scalar_t* k_data = key.data_ptr<scalar_t>();
   scalar_t* v_data = value.data_ptr<scalar_t>();
+  accum_t* mask_data = has_attn_mask
+      ? attn_mask.value().data_ptr<accum_t>()
+      : nullptr;
   scalar_t* out_data = out.data_ptr<scalar_t>();
   accum_t* lse_data = logsumexp.data_ptr<accum_t>();
   accum_t* buf_data = buf.data_ptr<accum_t>();
@@ -492,6 +575,20 @@ void cpu_flash_attention_backward(
               static_cast<accum_t>(0),
               attn_data,
               kvBlockSize);
+          // attn <- attn + mask
+          if (has_attn_mask) {
+            for (const auto row : c10::irange(qBlockSize)) {
+              at::vec::map2<accum_t>(
+                  [](Vec x, Vec y) {
+                    return x + y;
+                  },
+                  attn_data + row * kvBlockSize,
+                  attn_data + row * kvBlockSize,
+                  mask_data + i * mStrideB + j * mStrideH +
+                      (m + row) * mStrideM + n,
+                  kvBlockSize);
+            }
+          }
           // restore self attention after softmax from logsumexp
           // attn <- exp(attn - normalizer)
           for (const auto row : c10::irange(qBlockSize)) {
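In the backward pass the mask is only added, not scaled, because the scaling factor has already been applied when attn was recomputed as scale * q @ k.T; the subsequent exp(attn - logsumexp) then reproduces the forward softmax probabilities. A scalar sketch of that recomputation follows, with illustrative standalone names and the same row-major tile layout assumed as in the sketch above.

#include <cmath>
#include <cstdint>
#include <vector>

// Recompute softmax(scale * q @ k.T + mask) for one tile from the saved
// logsumexp, mirroring "attn <- attn + mask" followed by
// "attn <- exp(attn - normalizer)" in the kernel above.
void recompute_attn_from_lse(
    std::vector<float>& attn,        // holds scale * q @ k.T on entry
    const std::vector<float>& mask,  // same tile shape
    const std::vector<float>& lse,   // one logsumexp value per query row
    int64_t qBlockSize,
    int64_t kvBlockSize) {
  for (int64_t row = 0; row < qBlockSize; ++row) {
    for (int64_t col = 0; col < kvBlockSize; ++col) {
      float v = attn[row * kvBlockSize + col] + mask[row * kvBlockSize + col];
      attn[row * kvBlockSize + col] = std::exp(v - lse[row]);
    }
  }
}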
@@ -615,38 +712,28 @@ void cpu_flash_attention_backward(
 void flash_attention_kernel_impl(
     const Tensor& output,
     const Tensor& logsumexp,
-    const Tensor& cum_seq_q,
-    const Tensor& cum_seq_k,
-    int64_t& max_q,
-    int64_t& max_k,
-    const Tensor& philox_seed,
-    const Tensor& philox_offset,
-    const Tensor& debug_attn_mask,
     const at::Tensor& query,
     const at::Tensor& key,
     const at::Tensor& value,
     double dropout_p,
     bool is_causal,
-    bool return_debug_mask,
+    c10::optional<Tensor> attn_mask,
     c10::optional<double> scale) {
   auto q_seq_len = query.size(2);
 
   AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, query.scalar_type(), "flash_attention", [&] {
     if (q_seq_len >= 768) {
       cpu_flash_attention<scalar_t, 256, 512>(
-          output, logsumexp, cum_seq_q, cum_seq_k,
-          max_q, max_k, philox_seed, philox_offset, debug_attn_mask,
-          query, key, value, dropout_p, is_causal, return_debug_mask, scale);
+          output, logsumexp, query, key, value,
+          dropout_p, is_causal, attn_mask, scale);
     } else if (q_seq_len >= 192) {
       cpu_flash_attention<scalar_t, 64, 512>(
-          output, logsumexp, cum_seq_q, cum_seq_k,
-          max_q, max_k, philox_seed, philox_offset, debug_attn_mask,
-          query, key, value, dropout_p, is_causal, return_debug_mask, scale);
+          output, logsumexp, query, key, value,
+          dropout_p, is_causal, attn_mask, scale);
     } else {
       cpu_flash_attention<scalar_t, 32, 512>(
-          output, logsumexp, cum_seq_q, cum_seq_k,
-          max_q, max_k, philox_seed, philox_offset, debug_attn_mask,
-          query, key, value, dropout_p, is_causal, return_debug_mask, scale);
+          output, logsumexp, query, key, value,
+          dropout_p, is_causal, attn_mask, scale);
     }
   });
 }
@@ -661,14 +748,9 @@ void flash_attention_backward_kernel_impl(
     const at::Tensor& value,
     const at::Tensor& out,
     const at::Tensor& logsumexp,
-    const Tensor& cum_seq_q,
-    const Tensor& cum_seq_k,
-    const int64_t max_q,
-    const int64_t max_k,
     double dropout_p,
     bool is_causal,
-    const at::Tensor& philox_seed,
-    const at::Tensor& philox_offset,
+    c10::optional<Tensor> attn_mask,
     c10::optional<double> scale) {
   // make sure grad_out has no zero strides (broadcasted dimensions)
   // since we are going to call gemm next
@@ -681,20 +763,17 @@ void flash_attention_backward_kernel_impl(
       cpu_flash_attention_backward<scalar_t, 256, 512>(
           grad_q, grad_k, grad_v, grad_out_contig,
           query, key, value, out, logsumexp,
-          cum_seq_q, cum_seq_k, max_q, max_k, dropout_p,
-          is_causal, philox_seed, philox_offset, scale);
+          dropout_p, is_causal, attn_mask, scale);
     } else if (q_seq_len >= 192) {
       cpu_flash_attention_backward<scalar_t, 64, 512>(
           grad_q, grad_k, grad_v, grad_out_contig,
           query, key, value, out, logsumexp,
-          cum_seq_q, cum_seq_k, max_q, max_k, dropout_p,
-          is_causal, philox_seed, philox_offset, scale);
+          dropout_p, is_causal, attn_mask, scale);
     } else {
       cpu_flash_attention_backward<scalar_t, 32, 512>(
           grad_q, grad_k, grad_v, grad_out_contig,
           query, key, value, out, logsumexp,
-          cum_seq_q, cum_seq_k, max_q, max_k, dropout_p,
-          is_causal, philox_seed, philox_offset, scale);
+          dropout_p, is_causal, attn_mask, scale);
     }
   });
 }