Add tensor operation utilities and performance enhancements #165
Conversation
- Introduced `static_switch.h` with BOOL_SWITCH, EVENK_SWITCH, SOFTCAP_SWITCH, FP16_SWITCH, and HEADDIM_SWITCH macros for compile-time conditional execution (the dispatch pattern is sketched below).
- Added `utils.h` containing various utility functions for tensor operations, including relu, max, sum, and GEMM implementations.
- Implemented specialized relu functions for half and bfloat16 types using inline PTX assembly for performance optimization.
- Enhanced tensor layout conversion functions to support different configurations for GEMM operations.
- Included support for asynchronous copy operations and softmax calculations within the FLASH_NAMESPACE.
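For context, switch macros of this kind convert a runtime flag into a compile-time constant so that kernels can be template-specialized on it. The following is a hedged sketch of the common BOOL_SWITCH pattern (as popularized by upstream FlashAttention); the exact definitions in this PR's `static_switch.h` may differ, and `run_flash_attention` is a hypothetical stand-in.

```cpp
#include <cstdio>

// Sketch of the BOOL_SWITCH pattern assumed here; the actual macro in
// csrc/flash_dmattn/src/static_switch.h may differ in detail.
// A runtime boolean selects one of two lambda instantiations, each of which
// sees the condition as a constexpr, so templates can specialize on it.
#define BOOL_SWITCH(COND, CONST_NAME, ...)              \
    [&] {                                               \
        if (COND) {                                     \
            constexpr static bool CONST_NAME = true;    \
            return __VA_ARGS__();                       \
        } else {                                        \
            constexpr static bool CONST_NAME = false;   \
            return __VA_ARGS__();                       \
        }                                               \
    }()

// Hypothetical stand-in for a templated kernel launcher.
template <bool Is_causal>
void run_flash_attention() {
    printf("Is_causal = %s\n", Is_causal ? "true" : "false");
}

int main() {
    bool is_causal = true;  // runtime value
    // Inside the lambda body, Is_causal is a compile-time constant.
    BOOL_SWITCH(is_causal, Is_causal, [&] { run_flash_attention<Is_causal>(); });
    return 0;
}
```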
Pull Request Overview
This PR introduces tensor operation utilities and performance enhancements to the flash attention implementation. The changes reorganize source file paths to a dedicated flash_dmattn directory, refine backward kernel launch parameters for better performance on specific GPU architectures, and streamline mask processing while expanding test coverage.
Key changes:
- Reorganizes CUDA source files from `csrc/` to the `csrc/flash_dmattn/` subdirectory
- Optimizes backward pass kernel configurations for different GPU architectures (sm86/sm89 vs A100/H100)
- Simplifies mask processor initialization by removing the template parameter dependency (a sketch of this pattern follows below)
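As a purely hypothetical illustration of that last point (the real `Mask` struct in `csrc/flash_dmattn/src/mask.h` is not shown on this page, so names and members below are assumptions), dropping a template parameter shrinks both the declaration and every initialization call in `flash_fwd_kernel.h`:

```cpp
// Hypothetical sketch only; names and members are assumptions, not the
// actual Mask struct from csrc/flash_dmattn/src/mask.h.

// Before: callers had to spell out an extra compile-time flag.
//   template <bool Is_causal, bool Has_mask>
//   struct Mask { ... };
//   Mask<Is_causal, /*Has_mask=*/true> mask(seqlen_k, seqlen_q);

// After: the flag is gone, so initialization at each call site simplifies.
template <bool Is_causal>
struct Mask {
    const int max_seqlen_k, max_seqlen_q;
    __device__ __forceinline__ Mask(const int max_seqlen_k, const int max_seqlen_q)
        : max_seqlen_k(max_seqlen_k), max_seqlen_q(max_seqlen_q) {}
};

// Usage in a kernel body:
//   Mask<Is_causal> mask(params.seqlen_k, params.seqlen_q);
```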
Reviewed Changes
Copilot reviewed 7 out of 92 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `setup.py` | Updates source file paths to the new `flash_dmattn` subdirectory structure |
| `csrc/flash_dmattn/src/mask.h` | Removes a template parameter from the Mask struct declaration |
| `csrc/flash_dmattn/src/flash_fwd_kernel.h` | Updates mask processor initialization calls to drop the removed template parameter |
| `csrc/flash_dmattn/src/flash_bwd_launch_template.h` | Refines kernel launch parameters and shared memory usage comments for different architectures |
| `benchmarks/forward_equivalence.py` | Removes duplicate test configurations, adds head_dim 192/256 support, and increases keep_window_size |
| `benchmarks/backward_performance.py` | Simplifies backward pass testing by calling sum() instead of passing custom gradient tensors |
| `benchmarks/backward_equivalence.py` | Updates the backward pass testing approach and removes gradient tensor parameters |
```cpp
} else {  // sm86 and sm89
    // 96KB, 2 CTAs in sm86 and sm 89.
    run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
    // 96KB, 1 CTAs in sm86 and sm 89.
```
Copilot AI commented on Sep 11, 2025:
The comment states '96KB' but based on the kernel traits and the context of other comments, this should be updated to reflect the actual shared memory usage. The comment appears inconsistent with the pattern of other memory usage comments in the file.
Suggested change:
```diff
- // 96KB, 1 CTAs in sm86 and sm 89.
+ // 88KB, 1 CTAs in sm86 and sm 89.
```
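The numbers being debated in these comments are per-CTA shared memory budgets: sm86/sm89 allow roughly 99 KB of opt-in shared memory per block (100 KB per SM), versus about 163 KB on A100 and 227 KB on H100, which is why the launch template picks smaller tiles on those parts. As a hedged sanity check (not part of this PR, and independent of the kernel traits themselves), the limits can be queried at runtime:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Standalone check of the shared memory limits that the "xxKB, N CTAs"
// comments in flash_bwd_launch_template.h are describing.
int main() {
    int device = 0;
    cudaGetDevice(&device);

    int smem_per_block_optin = 0;  // max opt-in shared memory per block (bytes)
    cudaDeviceGetAttribute(&smem_per_block_optin,
                           cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

    int smem_per_sm = 0;           // total shared memory per SM (bytes)
    cudaDeviceGetAttribute(&smem_per_sm,
                           cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);

    // A kernel that uses S bytes of shared memory can co-reside at most
    // smem_per_sm / S CTAs per SM, which is what the "2 CTAs" / "1 CTAs"
    // annotations estimate.
    printf("opt-in smem per block: %d KB, smem per SM: %d KB\n",
           smem_per_block_optin / 1024, smem_per_sm / 1024);
    return 0;
}
```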
```cpp
} else {  // sm86 and sm89
    // 72KB, 1 CTAs in sm86 and sm 89.
    run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
    // 88KB, 1 CTAs in sm86 and sm 89.
```
Copilot AI commented on Sep 11, 2025:
The comment states '88KB' but the original comment was '72KB'. If the memory usage calculation has changed due to the kernel trait modifications, the comment should accurately reflect the new memory requirements.
Suggested change:
```diff
- // 88KB, 1 CTAs in sm86 and sm 89.
+ // 72KB, 1 CTAs in sm86 and sm 89.
```
```cpp
} else {  // sm86 and sm89
    // 92KB, 1 CTAs in sm86 and sm 89.
    run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
    // 76KB, 1 CTAs in sm86 and sm 89.
```
Copilot AI commented on Sep 11, 2025:
The comment states '76KB' but the original was '92KB'. This significant reduction in shared memory usage should be verified to ensure the comment accurately reflects the actual memory requirements of the new kernel configuration.
Suggested change:
```diff
- // 76KB, 1 CTAs in sm86 and sm 89.
+ // 92KB, 1 CTAs in sm86 and sm 89.
```
```cpp
} else {  // sm86 and sm89
    // 88KB, 1 CTAs in sm86 and sm 89.
    run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_causal>(params, stream);
    // 80KB, 1 CTAs in sm86 and sm 89.
```
Copilot AI commented on Sep 11, 2025:
The comment states '80KB' but the original was '88KB'. Please verify this memory usage calculation matches the actual shared memory requirements of the modified kernel traits configuration.
Suggested change:
```diff
- // 80KB, 1 CTAs in sm86 and sm 89.
+ // 88KB, 1 CTAs in sm86 and sm 89.
```
Introduce a static switch utility for compile-time conditionals along with various tensor operation utilities. Refine backward kernel launch parameters for improved performance on specific architectures. Refactor backward pass computations and mask processor initialization for efficiency. Update test configurations to accommodate additional head dimensions.