
Commit 7afb83c

feat: add deterministic feature for flash attention (PaddlePaddle#76496)
* feat: add deterministic feature for flash attention
* Merge branch 'develop' into flash_determ
* support deterministic for flashattn & flashmask
* update format
* add assert blockmask
* support deterministic for flashattn & flashmask
1 parent 50a842e commit 7afb83c
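The net effect: deterministic mode (FLAGS_cudnn_deterministic) is now allowed together with flash attention 3, and the kernels keep the FA3 path only for head sizes up to 128, falling back to FA2 otherwise. A minimal sketch of that selection rule in plain Python (a hypothetical helper, not part of this commit; the real checks live in the .cu kernels and in flash_attention.py below):

def select_fa_version(requested: int, deterministic: bool, head_size: int) -> int:
    # Hypothetical helper mirroring the kernel-side rule introduced here:
    # FA3 + deterministic + head_size > 128 falls back to FA2; everything
    # else keeps the version requested via FLAGS_flash_attn_version.
    if requested == 3 and deterministic and head_size > 128:
        return 2
    return requested

assert select_fa_version(3, True, 256) == 2   # deterministic FA3 only kept up to head_size 128
assert select_fa_version(3, True, 128) == 3
assert select_fa_version(3, False, 256) == 3
assert select_fa_version(2, True, 64) == 2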

File tree

5 files changed, +79 -69 lines changed


paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu

Lines changed: 4 additions & 5 deletions

@@ -625,11 +625,10 @@ void FlashAttnGradBaseKernel(
   const float softmax_scale = 1.0f / std::sqrt(head_size);
   const float softmax_unscale = std::sqrt(head_size);

-  int version =
-      FLAGS_flash_attn_version == 3 && !FLAGS_cudnn_deterministic &&
-              (head_size == 64 || head_size == 128 || head_size == 256)
-          ? FLAGS_flash_attn_version
-          : 2;
+  int version = FLAGS_flash_attn_version == 3 && FLAGS_cudnn_deterministic &&
+                        head_size > 128
+                    ? 2
+                    : FLAGS_flash_attn_version;
   FlashAttnBwdParamsV2 params =
       FlashAttnBwdParamsV2(dev_ctx,
                            version,

paddle/phi/kernels/gpu/flash_attn_kernel.cu

Lines changed: 4 additions & 5 deletions

@@ -378,11 +378,10 @@ void FlashAttnBaseKernel(
   const float softmax_scale = 1.0f / std::sqrt(head_size);
   const float softmax_unscale = std::sqrt(head_size);

-  int version =
-      FLAGS_flash_attn_version == 3 && !FLAGS_cudnn_deterministic &&
-              (head_size == 64 || head_size == 128 || head_size == 256)
-          ? FLAGS_flash_attn_version
-          : 2;
+  int version = FLAGS_flash_attn_version == 3 && FLAGS_cudnn_deterministic &&
+                        head_size > 128
+                    ? 2
+                    : FLAGS_flash_attn_version;
   FlashAttnFwdParamsV2<T> params = FlashAttnFwdParamsV2<T>(dev_ctx,
                                                            version,
                                                            batch_size,

paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu

Lines changed: 51 additions & 53 deletions

@@ -511,13 +511,16 @@ void FlashAttnV3GradBaseKernel(
       dev_ctx, {(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads});
   dynload::fa3_bwd_params_set_dq_semaphore(params_handle,
                                            dq_semaphore.data<int>());
+  DenseTensor dk_semaphore = phi::Empty<int32_t>(
+      dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
+  DenseTensor dv_semaphore = phi::Empty<int32_t>(
+      dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
   if (num_heads_k != num_heads &&
       dynload::fa3_bwd_params_get_deterministic(params_handle)) {
-    // TODO(tridao): do we need to zero them out?
-    DenseTensor dk_semaphore = phi::Empty<int32_t>(
-        dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
-    DenseTensor dv_semaphore = phi::Empty<int32_t>(
-        dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
+    phi::funcs::SetConstant<Context, int32_t> set_zero_dk;
+    set_zero_dk(dev_ctx, &dk_semaphore, static_cast<int32_t>(0));
+    phi::funcs::SetConstant<Context, int32_t> set_zero_dv;
+    set_zero_dv(dev_ctx, &dv_semaphore, static_cast<int32_t>(0));
     dynload::fa3_bwd_params_set_dk_semaphore(params_handle,
                                              dk_semaphore.data<int>());
     dynload::fa3_bwd_params_set_dv_semaphore(params_handle,
@@ -599,11 +602,6 @@ void FlashAttnV3GradKernel(const Context &dev_ctx,
       0,
       common::errors::InvalidArgument(
           "sm_margin is not supported, please set sm_margin to 0"));
-  PADDLE_ENFORCE_EQ(FLAGS_cudnn_deterministic,
-                    false,
-                    common::errors::InvalidArgument(
-                        "deterministic is not supported in flash attention 3, "
-                        "please set FLAGS_cudnn_deterministic to false"));
   // umiswing: fake grad tensor for FlashAttnV3GradBaseKernel
   DenseTensor softmax_d;
   DenseTensor softmax_lse_log2;
@@ -737,11 +735,6 @@ void FlashAttnV3VarlenGradKernel(const Context &dev_ctx,
       0,
       common::errors::InvalidArgument(
           "sm_margin is not supported, please set sm_margin to 0"));
-  PADDLE_ENFORCE_EQ(FLAGS_cudnn_deterministic,
-                    false,
-                    common::errors::InvalidArgument(
-                        "deterministic is not supported in flash attention 3, "
-                        "please set FLAGS_cudnn_deterministic to false"));

   PADDLE_ENFORCE_EQ(
       q.dims()[q.dims().size() - 1],
@@ -1437,13 +1430,17 @@ void FlashMaskV2GradBaseKernel(
       dev_ctx, {(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads});
   dynload::flashmaskv2_bwd_params_set_dq_semaphore(params_handle,
                                                    dq_semaphore.data<int>());
+  DenseTensor dk_semaphore = phi::Empty<int32_t>(
+      dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
+  DenseTensor dv_semaphore = phi::Empty<int32_t>(
+      dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
   if (num_heads_k != num_heads &&
       dynload::flashmaskv2_bwd_params_get_deterministic(params_handle)) {
-    // TODO(tridao): do we need to zero them out?
-    DenseTensor dk_semaphore = phi::Empty<int32_t>(
-        dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
-    DenseTensor dv_semaphore = phi::Empty<int32_t>(
-        dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
+    // xiangrui: we need to zero them out
+    phi::funcs::SetConstant<Context, int32_t> set_zero_dk;
+    set_zero_dk(dev_ctx, &dk_semaphore, static_cast<int32_t>(0));
+    phi::funcs::SetConstant<Context, int32_t> set_zero_dv;
+    set_zero_dv(dev_ctx, &dv_semaphore, static_cast<int32_t>(0));
     dynload::flashmaskv2_bwd_params_set_dk_semaphore(params_handle,
                                                      dk_semaphore.data<int>());
     dynload::flashmaskv2_bwd_params_set_dv_semaphore(params_handle,
@@ -1573,39 +1570,40 @@ void FlashMaskV2GradKernel(
   DenseTensor dq_accum;
   DenseTensor dk_accum;
   DenseTensor dv_accum;
-  FlashMaskV2GradBaseKernel<T, Context>(dev_ctx,
-                                        out_grad,
-                                        q,
-                                        k,
-                                        v,
-                                        out,
-                                        softmax_lse,
-                                        paddle::none,  // dq_
-                                        paddle::none,  // dk_
-                                        paddle::none,  // dv_
-                                        paddle::none,
-                                        paddle::none,
-                                        paddle::none,
-                                        paddle::none,
-                                        startend_row_indices,
-                                        block_mask,
-                                        0,  // max_seqlen_q,
-                                        0,  // max_seqlen_k,
-                                        softmax_scale,
-                                        is_causal,
-                                        -1,  // window_size_left,
-                                        -1,  // window_size_right,
-                                        0,   // softcap,
-                                        false,  // deterministic,
-                                        0,      // sm_margin,
-                                        dq,
-                                        dk,
-                                        dv,
-                                        &softmax_d,
-                                        &softmax_lse_log2,
-                                        &dq_accum,
-                                        &dk_accum,
-                                        &dv_accum);
+  FlashMaskV2GradBaseKernel<T, Context>(
+      dev_ctx,
+      out_grad,
+      q,
+      k,
+      v,
+      out,
+      softmax_lse,
+      paddle::none,  // dq_
+      paddle::none,  // dk_
+      paddle::none,  // dv_
+      paddle::none,
+      paddle::none,
+      paddle::none,
+      paddle::none,
+      startend_row_indices,
+      block_mask,
+      0,  // max_seqlen_q,
+      0,  // max_seqlen_k,
+      softmax_scale,
+      is_causal,
+      -1,  // window_size_left,
+      -1,  // window_size_right,
+      0,   // softcap,
+      FLAGS_cudnn_deterministic,  // deterministic,
+      0,                          // sm_margin,
+      dq,
+      dk,
+      dv,
+      &softmax_d,
+      &softmax_lse_log2,
+      &dq_accum,
+      &dk_accum,
+      &dv_accum);

   // umiswing: some branch in upstream fa3 could have padded the head dimension
   PADDLE_ENFORCE_EQ(
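With the FA3 deterministic ban removed and the dk/dv semaphores zero-initialized, repeated backward passes should now produce bit-wise identical gradients under FLAGS_cudnn_deterministic. A minimal check sketch, assuming a CUDA build where paddle.nn.functional.scaled_dot_product_attention dispatches to these flash-attention kernels; shapes, dtype, and seed are illustrative:

import paddle
import paddle.nn.functional as F

# Opt into deterministic kernels before building the graph.
paddle.set_flags({"FLAGS_cudnn_deterministic": True})

def attention_grad(seed: int = 0) -> paddle.Tensor:
    paddle.seed(seed)
    shape = [2, 1024, 8, 128]  # [batch, seqlen, num_heads, head_dim]
    q = paddle.randn(shape).astype("float16").detach()
    q.stop_gradient = False  # track gradients w.r.t. q
    k = paddle.randn(shape).astype("float16")
    v = paddle.randn(shape).astype("float16")
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out.sum().backward()
    return q.grad

# Two identical runs should match exactly, not just approximately.
g1, g2 = attention_grad(), attention_grad()
assert bool(paddle.equal_all(g1, g2))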

python/paddle/nn/functional/flash_attention.py

Lines changed: 19 additions & 5 deletions

@@ -2318,14 +2318,28 @@ def flashmask_attention(
                 f"Invalid shape of startend_row_indices, when causal is False, the last dimension should be either 2 or 4 but got {startend_row_indices.shape[-1]}"
             )

-    if "xpu" in paddle.get_device():
-        fa_version = 2
-    elif paddle.get_flags(["FLAGS_cudnn_deterministic"])[
-        "FLAGS_cudnn_deterministic"
-    ]:
+    if (
+        "xpu" not in paddle.get_device()
+        and paddle.get_flags(["FLAGS_cudnn_deterministic"])[
+            "FLAGS_cudnn_deterministic"
+        ]
+    ):
         assert block_mask is None, (
             " blockmask attention no supports deterministic now ."
         )
+
+    if "xpu" in paddle.get_device():
+        fa_version = 2
+    elif (
+        paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])[
+            "FLAGS_flash_attn_version"
+        ]
+        == 3
+        and paddle.base.framework.get_flags(["FLAGS_cudnn_deterministic"])[
+            "FLAGS_cudnn_deterministic"
+        ]
+        and query.shape[3] > 128
+    ):
         fa_version = 2
     else:
         fa_version = paddle.base.framework.get_flags(
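On the Python side, flashmask_attention still reads the existing flags; the user-visible changes are the block_mask assertion under deterministic mode and the silent FA2 fallback for head dims above 128. A minimal usage sketch, assuming a CUDA build, that flashmask_attention is importable from paddle.nn.functional with the (query, key, value, startend_row_indices, causal) arguments used in this file, and illustrative shapes and mask values:

import paddle
import paddle.nn.functional as F

# Deterministic kernels; with FLAGS_flash_attn_version=3 the FA3 path is kept
# only while head_dim <= 128, otherwise the check above falls back to FA2.
paddle.set_flags({"FLAGS_cudnn_deterministic": True})

batch, seqlen, num_heads, head_dim = 2, 1024, 8, 128
q = paddle.randn([batch, seqlen, num_heads, head_dim]).astype("float16")
k = paddle.randn([batch, seqlen, num_heads, head_dim]).astype("float16")
v = paddle.randn([batch, seqlen, num_heads, head_dim]).astype("float16")

# Illustrative causal FlashMask: one lower-triangle start row per position;
# a value of seqlen adds no masking beyond the causal structure itself.
startend_row_indices = paddle.full(
    [batch, num_heads, seqlen, 1], seqlen, dtype="int32"
)

out = F.flashmask_attention(
    q, k, v, startend_row_indices=startend_row_indices, causal=True
)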
