
Commit 6f80d5c

[API Compatibility] Add paddle.compat.nn.functional.sdpa (PaddlePaddle#76446)
* Implement paddle.nn.functional.sdpa
* Enable flash attention test and disable test_compat_attention on Windows
* Refactor sdpa
* Check dtype for mem_efficient_attention, support 3-D attn_mask, refine tests
* Fix test_flash_attn error
* feat: refactor GQA implementation and improve tensor handling
  - Move GQA logic from the compat module to the main scaled_dot_product_attention function
  - Add enable_gqa parameter to the function signature with proper documentation
  - Simplify tensor dimension handling using an is_batched check
  - Remove duplicate GQA validation and expansion code from the compat module
  - Improve code organization by centralizing GQA functionality in the main implementation
* feat(test): update attention test shape for better alignment
  Change the shape parameter in TestSDPAttentionWithScale from (2, 32, 8, 32) to (2, 8, 8, 32) to improve test alignment and validate the attention mechanism with more realistic tensor dimensions.
* feat(attention): update documentation and mask handling logic
  - Update scaled_dot_product_attention documentation to clarify dtype support and remove the GQA mode mention
  - Simplify mask padding logic in MultiheadAttention to always use the input dtype
  - Add tensor shape comments for better code readability
  - Refactor attention mask generation logic to improve efficiency
  - Remove unused device capability checking functions
  These changes improve code clarity and maintainability while ensuring consistent behavior across different input types.
* feat(transformer): initialize bias parameters with None and conditionally create bias parameters
  Initialize all bias parameters (in_proj_bias, q_proj_bias, k_proj_bias, v_proj_bias) to None at class initialization and create them only when bias=True, moving the bias creation logic into the appropriate conditional branches. This ensures bias parameters are always defined and only created when needed.
* feat(nn): remove __all__ from compat nn module
* feat: fix CUDA availability check in scaled dot product attention
  Change `paddle.device.is_available()` to `paddle.cuda.is_available()` in the CUDA availability check so that CUDA availability is detected correctly for GPU operations in the scaled dot product attention implementation.
* feat: update shape output format in docstrings and rename attention module
  - Change shape output format from list to paddle.Size in AvgPool1D, AvgPool2D, AvgPool3D, and Unfold docstrings
  - Rename attention.py to sdpa.py and update import paths
  - Remove the debug parameter from check_all_tensors_on_device
  - Replace the debug warning with info logging for tensor device placement checks
  - Update MultiheadAttention documentation regarding optimized implementation conditions
* feat: reduce log verbosity in attention validation functions
  Change logger calls from info to debug level in SDPA validation functions to reduce noise in production logs. The validation logic is unchanged; detailed messages appear only when debug logging is enabled.
* feat: add bfloat16 support check for MHA tests on CUDA
  Add a paddle.device.is_bf16_supported() check so bfloat16 tests only run on CUDA devices that support bfloat16, falling back to float32 elsewhere to prevent test failures.
* feat: add runtime flags for attention backends and fix bf16 support check
  - Add FLAGS_memory_efficient_attention_available and FLAGS_flash_attention_available to conditionally enable attention backends at runtime
  - Update SDPA backend selection to use runtime flags instead of hardcoded values
  - Fix bf16 support detection in multihead attention tests by checking CUDA compute capability
  - Remove a redundant scale check in the flash attention constraints
  - Improve test coverage by using consistent bf16 capability checks
* feat: add global flags for attention kernel availability
  Add global boolean flags `memory_efficient_attention_available` and `flash_attention_available` to centralize availability checks for the memory-efficient and flash attention kernels. Move the flag definitions from individual kernel files to flags.cc for better maintainability and to avoid duplication. The flags are automatically set to true when the corresponding compilation macros (PADDLE_WITH_MEMORY_EFFICIENT_ATTENTION and PADDLE_WITH_FLASHATTN) are defined, allowing runtime detection of the available attention implementations.
* Fix compile error on Windows
* Fix build error
* feat(nn): use safe dict get for attention backend flags
  Replace direct dictionary access with get() to handle missing flags gracefully. This prevents KeyError exceptions when the global flags dictionary does not contain the expected flash attention and memory-efficient attention availability flags, returning False by default instead.
1 parent 55e2134 commit 6f80d5c
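
As context for the diffs below, here is a minimal usage sketch of the API this commit adds. The function name, parameter names, and the [batch_size, num_heads, seq_len, head_dim] layout are taken from the new sdpa.py in this commit; the shapes, dtype, and the assumption of a CUDA-enabled build are illustrative only, and the backend actually chosen depends on the runtime availability flags introduced in flags.cc.

import paddle

# Q/K/V layout for the compat API: [batch_size, num_heads, seq_len, head_dim]
# (illustrative shapes; float16 assumes a GPU build).
q = paddle.rand((2, 8, 128, 16), dtype=paddle.float16)
k = paddle.rand((2, 8, 128, 16), dtype=paddle.float16)
v = paddle.rand((2, 8, 128, 16), dtype=paddle.float16)

out = paddle.compat.nn.functional.scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,    # bool or float mask; mutually exclusive with is_causal
    dropout_p=0.0,
    is_causal=True,
    scale=None,        # defaults to 1 / sqrt(head_dim)
    enable_gqa=False,  # set True when key/value carry fewer heads than query
)
print(out.shape)       # [2, 8, 128, 16], same layout as the inputs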

File tree: 18 files changed, +1626 −489 lines


ci/h-test.sh

Lines changed: 4 additions & 1 deletion
@@ -161,7 +161,10 @@ concurrency_list="^test_fp8_deep_gemm$|\
 ^test_dist_fuse_gemm_epilogue_pass$|\
 ^test_fuse_allreduce_split_to_reducescatter_pass$|\
 ^test_ps_server_pass$|\
-^test_white_lists$"
+^test_white_lists$|\
+^test_scaled_dot_product_attention$|\
+^test_compat_scaled_dot_product_attention$|\
+^test_flash_attention$"
 
 cd ${work_dir}/build
 tmp_dir=`mktemp -d`

paddle/phi/core/flags.cc

Lines changed: 22 additions & 0 deletions
@@ -32,3 +32,25 @@ PHI_DEFINE_EXPORTED_int64(conv_workspace_size_limit,
                           phi::backends::gpu::kDefaultConvWorkspaceSizeLimitMB,
                           "cuDNN convolution workspace limit in MB unit.");
 #endif
+
+#ifdef PADDLE_WITH_MEMORY_EFFICIENT_ATTENTION
+static const bool kMemEffAttnDefault = true;
+#else
+static const bool kMemEffAttnDefault = false;
+#endif
+
+PHI_DEFINE_EXPORTED_bool(
+    mem_efficient_attn_available,
+    kMemEffAttnDefault,
+    "Whether memory efficient attention is available on the current device.");
+
+#ifdef PADDLE_WITH_FLASHATTN
+static const bool kFlashAttnDefault = true;
+#else
+static const bool kFlashAttnDefault = false;
+#endif
+
+PHI_DEFINE_EXPORTED_bool(
+    flash_attn_available,
+    kFlashAttnDefault,
+    "Whether flash attention is available on the current device.");

python/paddle/compat/nn/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -400,9 +400,6 @@ def __setstate__(self, state):
         self.__dict__.setdefault("count_include_pad", True)
 
 
-__all__ = ['Unfold', 'Linear', 'MultiheadAttention']
-
-
 class Unfold(nn.Unfold):
     """
     A compatible version of paddle.nn.Unfold:

python/paddle/compat/nn/functional/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -25,6 +25,8 @@
 from paddle.tensor import softmax
 from paddle.utils.decorator_utils import ForbidKeywordsDecorator
 
+from .sdpa import scaled_dot_product_attention
+
 if TYPE_CHECKING:
     from typing_extensions import TypeAlias
 
@@ -39,7 +41,7 @@
 ]
 
 
-__all__ = ['pad', 'softmax', 'linear', 'unfold']
+__all__ = ['pad', 'softmax', 'linear', 'scaled_dot_product_attention', 'unfold']
 
 
 def _check_valid_pad_len(pad_len, x_dim, is_constant):

python/paddle/compat/nn/functional/sdpa.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import paddle.nn.functional as F
+
+if TYPE_CHECKING:
+    from paddle import Tensor
+
+
+def scaled_dot_product_attention(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    attn_mask: Tensor | None = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: float | None = None,
+    enable_gqa: bool = False,
+) -> Tensor:
+    r"""
+    The equation is:
+
+    .. math::
+
+        result = softmax(\frac{Q * K^T}{\sqrt{d}}) * V
+
+    where ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
+    The dimensions of the three parameters are the same.
+    ``d`` represents the size of the last dimension of the three parameters.
+
+
+    Warning:
+        This API only verifies inputs with dtype float16 and bfloat16; other dtypes may fall back to the math
+        implementation, which is less optimized.
+
+    Note:
+        This API differs from :ref:`api_paddle_nn_functional_scaled_dot_product_attention` in that
+        the QKV layout of this API is [batch_size, num_heads, seq_len, head_dim] or [num_heads, seq_len, head_dim].
+
+    Args:
+        query(Tensor): The query tensor in the attention module.
+            4-D tensor with shape:
+            [batch_size, num_heads, seq_len, head_dim].
+            3-D tensor with shape:
+            [num_heads, seq_len, head_dim].
+            The dtype can be float16 or bfloat16.
+        key(Tensor): The key tensor in the attention module.
+            4-D tensor with shape:
+            [batch_size, num_heads, seq_len, head_dim].
+            3-D tensor with shape:
+            [num_heads, seq_len, head_dim].
+            The dtype can be float16 or bfloat16.
+        value(Tensor): The value tensor in the attention module.
+            4-D tensor with shape:
+            [batch_size, num_heads, seq_len, head_dim].
+            3-D tensor with shape:
+            [num_heads, seq_len, head_dim].
+            The dtype can be float16 or bfloat16.
+        attn_mask(Tensor, optional): The attention mask tensor. The shape should be broadcastable to
+            [batch_size, num_heads, seq_len_query, seq_len_key]. The dtype can be bool
+            or the same dtype as query. A bool mask indicates which positions may take part
+            in attention. A non-bool mask is added to the attention score.
+
+        is_causal(bool, optional): Whether to enable causal mode. If True, the attention masking is a lower
+            triangular matrix when the mask is a square matrix, and takes the
+            form of the upper-left causal bias when the mask is a non-square matrix.
+            An error is thrown if both attn_mask and is_causal are set.
+        scale(float, optional): The scaling factor used in the calculation of attention weights.
+            If None, scale = 1 / sqrt(head_dim).
+        enable_gqa(bool, optional): Whether to enable GQA mode. Default False.
+
+    Returns:
+        out(Tensor): The attention tensor.
+            4-D tensor with shape: [batch_size, num_heads, seq_len, head_dim].
+            3-D tensor with shape: [num_heads, seq_len, head_dim].
+            The dtype can be float16 or bfloat16.
+
+    Examples:
+        .. code-block:: python
+
+            >>> # doctest: +SKIP('bfloat need V100 compile')
+            >>> import paddle
+            >>> q = paddle.rand((1, 2, 128, 16), dtype=paddle.bfloat16)
+            >>> output = paddle.compat.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
+            >>> print(output)
+            >>> # doctest: -SKIP
+    """
+    if is_causal and attn_mask is not None:
+        raise RuntimeError(
+            "Explicit attn_mask should not be set when is_causal=True"
+        )
+
+    query, key, value = (
+        query.swapaxes(-3, -2),
+        key.swapaxes(-3, -2),
+        value.swapaxes(-3, -2),
+    )
+    out = F.scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+        True,  # training
+        None,  # backend
+        scale,
+        enable_gqa,
+        None,  # name
+    )
+    return out.swapaxes(-3, -2)
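
The wrapper above only reorders axes: it swaps axes -3 and -2 on the way in and out, so callers keep the [batch_size, num_heads, seq_len, head_dim] layout while the core kernel sees [batch_size, seq_len, num_heads, head_dim]. Below is a hedged grouped-query (GQA) sketch built on the enable_gqa parameter added here; the 8-to-2 head grouping and the requirement that the query head count be a multiple of the key/value head count follow the usual GQA convention and are assumptions, not something spelled out in this diff.

import paddle

# Grouped-query attention: the query keeps 8 heads while key/value carry only
# 2, assuming the usual constraint that query heads are a multiple of KV heads.
q = paddle.rand((2, 8, 128, 16), dtype=paddle.float16)   # [batch, q_heads, seq, dim]
k = paddle.rand((2, 2, 128, 16), dtype=paddle.float16)   # [batch, kv_heads, seq, dim]
v = paddle.rand((2, 2, 128, 16), dtype=paddle.float16)

out = paddle.compat.nn.functional.scaled_dot_product_attention(
    q, k, v, is_causal=True, enable_gqa=True
)
print(out.shape)  # [2, 8, 128, 16]: the output follows the query layout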
