
Commit 20c2ec9

Valentine233 authored and pytorchmergebot committed
[CPU] Add flash attention mask version (pytorch#115913)
Add a masked version of flash attention for CPU.
Pull Request resolved: pytorch#115913
Approved by: https://github.com/jgong5, https://github.com/drisspg
1 parent b847290 commit 20c2ec9
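
For orientation, here is a minimal Python usage sketch of the path this commit targets: passing an explicit attn_mask to torch.nn.functional.scaled_dot_product_attention with CPU tensors. Whether a particular call is actually routed to the new masked CPU flash-attention kernel depends on PyTorch's SDPA backend selection, so treat this as an illustration of the user-facing API rather than a dispatch guarantee.

import torch
import torch.nn.functional as F

# (batch, num_heads, seq_len, head_dim) layout expected by SDPA
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# Additive float mask, broadcastable to (batch, num_heads, q_len, kv_len);
# -inf entries exclude the corresponding key positions from attention.
mask = torch.zeros(128, 128)
mask[:, 64:] = float("-inf")

# On a build containing this change, a masked CPU call like this can be served
# by the new flash-attention-with-mask kernel (backend selection permitting).
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([2, 8, 128, 64])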

File tree

16 files changed: +1140 -137 lines changed

aten/src/ATen/native/cpu/FlashAttentionKernel.cpp

Lines changed: 125 additions & 46 deletions
@@ -138,23 +138,41 @@ inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) {
   }
 }
 
+void reshape_attn_mask_to_4d(
+    Tensor& attn_mask,
+    int64_t batchSize,
+    int64_t num_head,
+    int64_t qSize,
+    int64_t kvSize) {
+  // Support mask shapes:
+  // 2d: ({Q_seq_len, 1} x {KV_seq_len, 1})
+  // 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})
+  // Guaranteed in check_attn_mask_shape
+  int64_t attn_mask_size_0 = 1;
+  int64_t attn_mask_size_1 = 1;
+  if (attn_mask.dim() == 4) {
+    if (attn_mask.size(0) == batchSize) {
+      attn_mask_size_0 = batchSize;
+    }
+    if (attn_mask.size(1) == num_head) {
+      attn_mask_size_1 = num_head;
+    }
+  }
+  attn_mask = attn_mask
+                .view({attn_mask_size_0, attn_mask_size_1, attn_mask.size(-2), attn_mask.size(-1)})
+                .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize});
+}
+
 template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
 void cpu_flash_attention(
     const Tensor& output,
     const Tensor& logsumexp,
-    const Tensor& cum_seq_q,
-    const Tensor& cum_seq_k,
-    int64_t& max_q,
-    int64_t& max_k,
-    const Tensor& philox_seed,
-    const Tensor& philox_offset,
-    const Tensor& debug_attn_mask,
     const at::Tensor& q,
     const at::Tensor& k,
     const at::Tensor& v,
     double dropout_p,
     bool is_causal,
-    bool return_debug_mask,
+    c10::optional<Tensor> attn_mask,
     c10::optional<double> scale) {
   // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
   //    -> (Batch x Q_seq_len x Num_heads x Dim_per_head)
@@ -181,6 +199,14 @@ void cpu_flash_attention(
   int64_t num_head = query.size(2);
   int64_t headSize = query.size(3);
 
+  bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
+  if (has_attn_mask) {
+    if (is_reduced_type) {
+      attn_mask.value() = attn_mask.value().to(at::kFloat);
+    }
+    reshape_attn_mask_to_4d(attn_mask.value(), batchSize, num_head, qSize, kvSize);
+  }
+
   // Strides
   int64_t qStrideB = query.stride(0);
   int64_t qStrideM = query.stride(1);
@@ -197,6 +223,16 @@ void cpu_flash_attention(
   int64_t lStrideB = logsumexp.stride(0);
   int64_t lStrideM = logsumexp.stride(1);
   int64_t lStrideH = logsumexp.stride(2);
+  int64_t mStrideB =
+      (has_attn_mask && attn_mask.value().size(0) > 1)
+      ? attn_mask.value().stride(0)
+      : 0;
+  int64_t mStrideH =
+      (has_attn_mask && attn_mask.value().size(1) > 1)
+      ? attn_mask.value().stride(1)
+      : 0;
+  int64_t mStrideM =
+      has_attn_mask ? attn_mask.value().stride(2) : 0;
 
   int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
   int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
@@ -220,6 +256,9 @@ void cpu_flash_attention(
   scalar_t* q_data = query.data_ptr<scalar_t>();
   scalar_t* k_data = key.data_ptr<scalar_t>();
   scalar_t* v_data = value.data_ptr<scalar_t>();
+  accum_t* mask_data = has_attn_mask
+      ? attn_mask.value().data_ptr<accum_t>()
+      : nullptr;
   scalar_t* out_data = output.data_ptr<scalar_t>();
   accum_t* lse_data = logsumexp.data_ptr<accum_t>();
   accum_t* buf_data = buf.data_ptr<accum_t>();
@@ -275,13 +314,41 @@ void cpu_flash_attention(
                 kvBlockSize - last_col - 1);
           }
         }
+        // Update attention weights with attention mask
+        // And apply scaling factor
+        // qk <- qk * scaling + attn_mask
+        if (has_attn_mask) {
+          for (int64_t row = 0; row < qBlockSize; ++row) {
+            at::vec::map2<accum_t>(
+                [scaling_factor](Vec x, Vec y) {
+                  return x * Vec(scaling_factor) + y;
+                },
+                qk_data + row * kvBlockSize,
+                qk_data + row * kvBlockSize,
+                mask_data + i * mStrideB + j * mStrideH +
+                    (m + row) * mStrideM + n,
+                kvBlockSize);
+          }
+        }
         // Update coefficients with Softmax
         accum_t tmp_max = 0, tmp_sum = 0, sum_old = 0, exp_tmp = 0;
         for (int64_t row = 0; row < qBlockSize; ++row) {
           sum_old = qk_sum_data[row];
-          // scale and max per row
-          _mul_reduce_max_fusion_kernel(qk_data + row * kvBlockSize, scaling_factor, kvBlockSize,
-              qk_data + row * kvBlockSize, tmp_max);
+          if (has_attn_mask) {
+            // max per row
+            tmp_max = at::vec::reduce_all<accum_t>(
+                [](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
+                qk_data + row * kvBlockSize,
+                kvBlockSize);
+          } else {
+            // apply scaling factor and max per row in fusion
+            _mul_reduce_max_fusion_kernel(
+                qk_data + row * kvBlockSize,
+                scaling_factor,
+                kvBlockSize,
+                qk_data + row * kvBlockSize,
+                tmp_max);
+          }
           tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
           // qk <- exp(qk - max) and sum per row
           tmp_sum = tmp_max;
@@ -354,14 +421,9 @@ void cpu_flash_attention_backward(
     const at::Tensor& value,
     const at::Tensor& out,
     const at::Tensor& logsumexp,
-    const Tensor& cumulative_sequence_length_q,
-    const Tensor& cumulative_sequence_length_k,
-    const int64_t max_seqlen_batch_q,
-    const int64_t max_seqlen_batch_k,
     double dropout_p,
     bool is_causal,
-    const at::Tensor& philox_seed,
-    const at::Tensor& philox_offset,
+    c10::optional<Tensor> attn_mask,
     c10::optional<double> scale) {
   constexpr bool is_reduced_type = is_reduced_floating_point_v<scalar_t>;
   using accum_t = at::opmath_type<scalar_t>;
@@ -381,6 +443,14 @@ void cpu_flash_attention_backward(
   int64_t num_head = query.size(2);
   int64_t headSize = query.size(3);
 
+  bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
+  if (has_attn_mask) {
+    if (is_reduced_type) {
+      attn_mask.value() = attn_mask.value().to(at::kFloat);
+    }
+    reshape_attn_mask_to_4d(attn_mask.value(), batchSize, num_head, qSize, kvSize);
+  }
+
   // Strides
   int64_t qStrideB = query.stride(0);
   int64_t qStrideM = query.stride(1);
@@ -397,6 +467,16 @@ void cpu_flash_attention_backward(
   int64_t lStrideB = logsumexp.stride(0);
   int64_t lStrideM = logsumexp.stride(1);
   int64_t lStrideH = logsumexp.stride(2);
+  int64_t mStrideB =
+      (has_attn_mask && attn_mask.value().size(0) > 1)
+      ? attn_mask.value().stride(0)
+      : 0;
+  int64_t mStrideH =
+      (has_attn_mask && attn_mask.value().size(1) > 1)
+      ? attn_mask.value().stride(1)
+      : 0;
+  int64_t mStrideM =
+      has_attn_mask ? attn_mask.value().stride(2) : 0;
 
   int64_t grad_qStrideB = grad_q.stride(0);
   int64_t grad_qStrideM = grad_q.stride(1);
@@ -440,6 +520,9 @@ void cpu_flash_attention_backward(
   scalar_t* q_data = query.data_ptr<scalar_t>();
   scalar_t* k_data = key.data_ptr<scalar_t>();
   scalar_t* v_data = value.data_ptr<scalar_t>();
+  accum_t* mask_data = has_attn_mask
+      ? attn_mask.value().data_ptr<accum_t>()
+      : nullptr;
   scalar_t* out_data = out.data_ptr<scalar_t>();
   accum_t* lse_data = logsumexp.data_ptr<accum_t>();
   accum_t* buf_data = buf.data_ptr<accum_t>();
@@ -492,6 +575,20 @@ void cpu_flash_attention_backward(
             static_cast<accum_t>(0),
             attn_data,
             kvBlockSize);
+        // attn <- attn + mask
+        if (has_attn_mask) {
+          for (const auto row : c10::irange(qBlockSize)) {
+            at::vec::map2<accum_t>(
+                [](Vec x, Vec y) {
+                  return x + y;
+                },
+                attn_data + row * kvBlockSize,
+                attn_data + row * kvBlockSize,
+                mask_data + i * mStrideB + j * mStrideH +
+                    (m + row) * mStrideM + n,
+                kvBlockSize);
+          }
+        }
         // restore self attention after softmax from logsumexp
         // attn <- exp(attn - normalizer)
         for (const auto row : c10::irange(qBlockSize)) {
@@ -615,38 +712,28 @@ void cpu_flash_attention_backward(
 void flash_attention_kernel_impl(
     const Tensor& output,
     const Tensor& logsumexp,
-    const Tensor& cum_seq_q,
-    const Tensor& cum_seq_k,
-    int64_t& max_q,
-    int64_t& max_k,
-    const Tensor& philox_seed,
-    const Tensor& philox_offset,
-    const Tensor& debug_attn_mask,
     const at::Tensor& query,
     const at::Tensor& key,
     const at::Tensor& value,
     double dropout_p,
     bool is_causal,
-    bool return_debug_mask,
+    c10::optional<Tensor> attn_mask,
    c10::optional<double> scale) {
   auto q_seq_len = query.size(2);
 
   AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, query.scalar_type(), "flash_attention", [&] {
     if (q_seq_len >= 768) {
       cpu_flash_attention<scalar_t, 256, 512>(
-          output, logsumexp, cum_seq_q, cum_seq_k,
-          max_q, max_k, philox_seed, philox_offset, debug_attn_mask,
-          query, key, value, dropout_p, is_causal, return_debug_mask, scale);
+          output, logsumexp, query, key, value,
+          dropout_p, is_causal, attn_mask, scale);
     } else if (q_seq_len >= 192) {
       cpu_flash_attention<scalar_t, 64, 512>(
-          output, logsumexp, cum_seq_q, cum_seq_k,
-          max_q, max_k, philox_seed, philox_offset, debug_attn_mask,
-          query, key, value, dropout_p, is_causal, return_debug_mask, scale);
+          output, logsumexp, query, key, value,
+          dropout_p, is_causal, attn_mask, scale);
     } else {
       cpu_flash_attention<scalar_t, 32, 512>(
-          output, logsumexp, cum_seq_q, cum_seq_k,
-          max_q, max_k, philox_seed, philox_offset, debug_attn_mask,
-          query, key, value, dropout_p, is_causal, return_debug_mask, scale);
+          output, logsumexp, query, key, value,
+          dropout_p, is_causal, attn_mask, scale);
     }
   });
 }
@@ -661,14 +748,9 @@ void flash_attention_backward_kernel_impl(
     const at::Tensor& value,
     const at::Tensor& out,
     const at::Tensor& logsumexp,
-    const Tensor& cum_seq_q,
-    const Tensor& cum_seq_k,
-    const int64_t max_q,
-    const int64_t max_k,
     double dropout_p,
     bool is_causal,
-    const at::Tensor& philox_seed,
-    const at::Tensor& philox_offset,
+    c10::optional<Tensor> attn_mask,
     c10::optional<double> scale) {
   // make sure grad_out has no zero strides (broadcasted dimensions)
   // since we are going to call gemm next
@@ -681,20 +763,17 @@ void flash_attention_backward_kernel_impl(
       cpu_flash_attention_backward<scalar_t, 256, 512>(
           grad_q, grad_k, grad_v, grad_out_contig,
           query, key, value, out, logsumexp,
-          cum_seq_q, cum_seq_k, max_q, max_k, dropout_p,
-          is_causal, philox_seed, philox_offset, scale);
+          dropout_p, is_causal, attn_mask, scale);
     } else if (q_seq_len >= 192) {
       cpu_flash_attention_backward<scalar_t, 64, 512>(
           grad_q, grad_k, grad_v, grad_out_contig,
           query, key, value, out, logsumexp,
-          cum_seq_q, cum_seq_k, max_q, max_k, dropout_p,
-          is_causal, philox_seed, philox_offset, scale);
+          dropout_p, is_causal, attn_mask, scale);
     } else {
       cpu_flash_attention_backward<scalar_t, 32, 512>(
           grad_q, grad_k, grad_v, grad_out_contig,
           query, key, value, out, logsumexp,
-          cum_seq_q, cum_seq_k, max_q, max_k, dropout_p,
-          is_causal, philox_seed, philox_offset, scale);
+          dropout_p, is_causal, attn_mask, scale);
     }
   });
 }
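
The blocked kernel changes above are easier to read next to a plain, non-blocked reference of the same math. The sketch below is only an illustration (the helper name masked_sdpa_reference is not part of this PR): it computes qk = q @ k^T * scaling + mask and then a numerically stable row-wise softmax. The kernel does the same work tile-by-tile with running max/sum statistics, folding the scaling into the mask-add (the map2 lambda) when a mask is present and otherwise using the fused _mul_reduce_max_fusion_kernel; the zero mStrideB/mStrideH values let a 2-D or otherwise broadcastable mask be reused across batch and heads, mirroring reshape_attn_mask_to_4d.

import torch

def masked_sdpa_reference(q, k, v, attn_mask=None, scale=None):
    # q, k, v: (batch, num_heads, seq_len, head_dim); attn_mask is additive.
    scaling = scale if scale is not None else q.size(-1) ** -0.5
    qk = q @ k.transpose(-2, -1) * scaling
    if attn_mask is not None:
        # mask broadcasts over batch/head, like the expanded 4-D mask in the kernel
        qk = qk + attn_mask
    qk = qk - qk.amax(dim=-1, keepdim=True)  # subtract row max for stability
    attn = qk.exp()
    attn = attn / attn.sum(dim=-1, keepdim=True)
    return attn @ v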

aten/src/ATen/native/native_functions.yaml

Lines changed: 11 additions & 2 deletions
@@ -14476,19 +14476,28 @@
 
 - func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   dispatch:
-    CPU: _scaled_dot_product_flash_attention_cpu
     CUDA: _scaled_dot_product_flash_attention_cuda
     NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
   tags: nondeterministic_seeded
 
+- func: _scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)
+  dispatch:
+    CPU: _scaled_dot_product_flash_attention_cpu
+  tags: nondeterministic_seeded
+
 - func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
   device_check: NoCheck
   variants: function
   dispatch:
-    CPU: _scaled_dot_product_flash_attention_backward_cpu
     CUDA: _scaled_dot_product_flash_attention_backward_cuda
     NestedTensorCUDA: _scaled_dot_product_flash_attention_backward_nested
 
+- func: _scaled_dot_product_flash_attention_for_cpu_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, float dropout_p, bool is_causal, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CPU: _scaled_dot_product_flash_attention_cpu_backward
+
 - func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
   dispatch:
     CUDA: _scaled_dot_product_efficient_attention_cuda
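
For completeness, a sketch of exercising the newly declared CPU op directly, following the schema above and assuming a build that contains this commit. Internal aten ops like this are not a stable public API, so this only illustrates the argument order and the (output, logsumexp) return; the supported entry point remains scaled_dot_product_attention.

import torch

q = torch.randn(1, 4, 32, 64)
k = torch.randn(1, 4, 32, 64)
v = torch.randn(1, 4, 32, 64)
mask = torch.zeros(32, 32)  # additive 2-D mask, accepted per the kernel's supported shapes

# Positional args follow the declared schema: query, key, value, dropout_p, is_causal;
# attn_mask and scale are keyword-only. Private op, subject to change between releases.
out, logsumexp = torch.ops.aten._scaled_dot_product_flash_attention_for_cpu(
    q, k, v, 0.0, False, attn_mask=mask)
print(out.shape)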
