Commit 6f80d5c
[API Compatibility] Add paddle.compat.nn.functional.sdpa (PaddlePaddle#76446)
* Implement paddle.nn.functional.sdpa
* Enable flash attention test and disable test_compat_attention on Windows
* Refactor sdpa
* check dtype for mem_efficient_attention, support 3d attn_mask, refine tests
* fix test_flash_attn error
* feat: refactor GQA implementation and improve tensor handling
- Move GQA logic from compat module to main scaled_dot_product_attention function
- Add enable_gqa parameter to function signature with proper documentation
- Simplify tensor dimension handling using is_batched check
- Remove duplicate GQA validation and expansion code from compat module
- Improve code organization by centralizing GQA functionality in main implementation
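A minimal sketch of the KV-head expansion that `enable_gqa` implies, assuming a `[batch, num_heads, seq_len, head_dim]` layout; the helper name is illustrative, not the commit's actual code:

```python
import paddle

def _expand_kv_for_gqa(key, value, num_query_heads):
    # key/value: [batch, num_kv_heads, seq_len, head_dim] (assumed layout)
    num_kv_heads = key.shape[1]
    assert num_query_heads % num_kv_heads == 0, (
        "GQA requires the query head count to be a multiple of the KV head count"
    )
    group_size = num_query_heads // num_kv_heads
    # Repeat each KV head group_size times so K/V match the query head count.
    key = paddle.repeat_interleave(key, group_size, axis=1)
    value = paddle.repeat_interleave(value, group_size, axis=1)
    return key, value
```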
* feat(test): update attention test shape for better alignment
Change the shape parameter in TestSDPAttentionWithScale from (2, 32, 8, 32) to (2, 8, 8, 32) to improve test alignment and ensure proper attention mechanism validation with more realistic tensor dimensions.
* feat(attention): update documentation and mask handling logic
- Update scaled_dot_product_attention documentation to clarify dtype support and remove GQA mode mention
- Simplify mask padding logic in MultiheadAttention to always use input dtype
- Add tensor shape comments for better code readability
- Refactor attention mask generation logic to improve efficiency
- Remove unused device capability checking functions
These changes improve code clarity and maintainability while ensuring consistent behavior across different input types.
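For the "always use input dtype" mask handling, a hedged sketch of converting a boolean mask into an additive mask in the query's dtype (the helper name and exact semantics are assumptions, not the commit's code):

```python
import paddle

def _additive_mask(attn_mask, query):
    # Boolean masks become additive masks in the query's dtype:
    # True -> 0.0 (keep), False -> -inf (mask out).
    if attn_mask.dtype == paddle.bool:
        zeros = paddle.zeros_like(attn_mask, dtype=query.dtype)
        neg_inf = paddle.full_like(zeros, float("-inf"))
        return paddle.where(attn_mask, zeros, neg_inf)
    # Float masks are simply cast to the query's dtype.
    return attn_mask.astype(query.dtype)
```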
* feat(transformer): initialize bias parameters with None and conditionally create bias parameters
Initialize all bias parameters (in_proj_bias, q_proj_bias, k_proj_bias, v_proj_bias) to None at class initialization. Conditionally create bias parameters only when bias=True, moving the bias parameter creation logic to the appropriate conditional branches. This improves code clarity by ensuring bias parameters are always defined and only created when needed.
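A trimmed-down sketch of this initialization pattern (a stand-in layer, not the actual MultiheadAttention code):

```python
import paddle.nn as nn

class _AttnProjection(nn.Layer):
    def __init__(self, embed_dim, bias=True):
        super().__init__()
        # Bias attributes are always defined, so later code can test them safely.
        self.in_proj_bias = None
        self.q_proj_bias = None
        self.k_proj_bias = None
        self.v_proj_bias = None
        if bias:
            # Parameters are only created when bias is requested.
            self.in_proj_bias = self.create_parameter(
                shape=[3 * embed_dim], is_bias=True
            )
```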
* feat(nn): remove __all__ from compat nn module
* feat: fix CUDA availability check in scaled dot product attention
Change `paddle.device.is_available()` to `paddle.cuda.is_available()` in the CUDA availability check function. This ensures proper detection of CUDA availability specifically for GPU operations in the scaled dot product attention implementation.
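As a rough, non-authoritative equivalent using core Paddle APIs (not the commit's helper, which calls `paddle.cuda.is_available()` directly):

```python
import paddle

def _cuda_available():
    # Compiled with CUDA support and at least one visible GPU device.
    return paddle.is_compiled_with_cuda() and paddle.device.cuda.device_count() > 0
```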
* feat: update shape output format in docstrings and rename attention module
- Change shape output format from list to paddle.Size in AvgPool1D, AvgPool2D, AvgPool3D, and Unfold docstrings
- Rename attention.py to sdpa.py and update import paths
- Remove debug parameter from check_all_tensors_on_device function
- Replace debug warning with info logging for tensor device placement checks
- Update MultiheadAttention documentation regarding optimized implementation conditions
* feat: reduce log verbosity in attention validation functions
Changed logger calls from info to debug level in SDPA validation functions to reduce noise in production logs. This maintains the same validation logic but only shows detailed validation messages when debug logging is enabled.
* feat: add bfloat16 support check for MHA tests on CUDA
Add paddle.device.is_bf16_supported() check to ensure bfloat16 tests
only run on CUDA devices that support bfloat16. This prevents test
failures on CUDA devices without bfloat16 support by falling back to
float32 dtype in those cases.
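One plausible way to express this capability gate (compute capability 8.0+, i.e. Ampere, is the usual bfloat16 cutoff); the helper is an assumption, not the test's exact code:

```python
import paddle

def _bf16_supported_on_cuda():
    if not (paddle.is_compiled_with_cuda() and paddle.device.cuda.device_count() > 0):
        return False
    major, _ = paddle.device.cuda.get_device_capability()
    return major >= 8

dtype = "bfloat16" if _bf16_supported_on_cuda() else "float32"
```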
* feat: add runtime flags for attention backends and fix bf16 support check
- Add FLAGS_memory_efficient_attention_available and FLAGS_flash_attention_available
to conditionally enable attention backends at runtime
- Update SDPA backend selection to use runtime flags instead of hardcoded values
- Fix bf16 support detection in multihead attention tests by checking CUDA compute capability
- Remove redundant scale check in flash attention constraints
- Improve test coverage by using consistent bf16 capability checks
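A hedged illustration of flag-driven backend selection; the function and backend names are placeholders rather than the commit's real selection logic:

```python
def _pick_sdpa_backend(flash_available, mem_efficient_available):
    # Prefer the flash attention kernel when it was compiled in, then the
    # memory-efficient kernel, and finally the plain math implementation.
    if flash_available:
        return "flash"
    if mem_efficient_available:
        return "mem_efficient"
    return "math"
```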
* feat: add global flags for attention kernel availability
Add global boolean flags `memory_efficient_attention_available` and `flash_attention_available` to centralize availability checks for memory efficient and flash attention kernels. Move flag definitions from individual kernel files to flags.cc for better maintainability and to avoid code duplication. The flags automatically set to true when corresponding compilation macros (PADDLE_WITH_MEMORY_EFFICIENT_ATTENTION and PADDLE_WITH_FLASHATTN) are defined, allowing runtime detection of available attention implementations.
* Fix compile error on windows
* Fix build error
* feat(nn): use safe dict get for attention backend flags
Replace direct dictionary access with the get() method to handle missing flags gracefully. This prevents KeyError exceptions when the global flags dictionary doesn't contain the expected flash attention and memory efficient attention availability flags, providing default False values instead.
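A minimal illustration of that safe-access pattern; the real code reads Paddle's global flag registry, so `flags` here is only a stand-in dict:

```python
# Stand-in for the global flags dictionary queried at runtime.
flags = {"FLAGS_flash_attention_available": True}

# .get() with a default avoids a KeyError when a flag key is absent,
# e.g. on builds compiled without that attention backend.
flash_ok = flags.get("FLAGS_flash_attention_available", False)
mem_eff_ok = flags.get("FLAGS_memory_efficient_attention_available", False)
```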
File tree (18 files changed: +1626 −489 lines)
- ci
- paddle/phi/core
- python/paddle
  - compat/nn/functional
  - device
  - nn
    - attention
    - functional
- test/legacy_test
- tools/windows