
Conversation

@LoserCheems (Collaborator)

Introduce a static switch utility for compile-time conditionals, along with various tensor operation utilities. Refine backward kernel launch parameters for improved performance on specific architectures. Refactor backward pass computations and mask processor initialization for efficiency. Update test configurations to cover additional head dimensions.

- Introduced `static_switch.h` with BOOL_SWITCH, EVENK_SWITCH, SOFTCAP_SWITCH, FP16_SWITCH, and HEADDIM_SWITCH macros for compile-time conditional execution (see the sketch after this list).
- Added `utils.h` containing various utility functions for tensor operations, including relu, max, sum, and GEMM implementations.
- Implemented specialized relu functions for half and bfloat16 types using inline PTX assembly for performance optimization.
- Enhanced tensor layout conversion functions to support different configurations for GEMM operations.
- Included support for asynchronous copy operations and softmax calculations within the FLASH_NAMESPACE.
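These switch macros follow a common pattern: an immediately invoked lambda turns a runtime boolean into a `constexpr` constant so each branch can pass it as a template argument. Below is a minimal sketch of a BOOL_SWITCH-style macro, assuming the usual shape of such headers; the actual `static_switch.h` may differ in details, and the call site shown is hypothetical.

```cpp
// Sketch of a BOOL_SWITCH-style macro (assumed shape, not the file's exact contents).
// COND is evaluated at runtime; CONST_NAME becomes a compile-time constant inside
// each branch, so the lambda body can use it as a template argument.
#define BOOL_SWITCH(COND, CONST_NAME, ...)       \
  [&] {                                          \
    if (COND) {                                  \
      constexpr bool CONST_NAME = true;          \
      return __VA_ARGS__();                      \
    } else {                                     \
      constexpr bool CONST_NAME = false;         \
      return __VA_ARGS__();                      \
    }                                            \
  }()

// Hypothetical call site: dispatch a kernel on a runtime flag.
// BOOL_SWITCH(params.is_causal, Is_causal, [&] {
//     run_flash_bwd<Kernel_traits, Is_causal>(params, stream);
// });
```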
…ditional head dimensions and adjust keep_window_size
…dout parameter and use sum() for gradient computation
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR introduces tensor operation utilities and performance enhancements to the flash attention implementation. The changes reorganize source file paths to a dedicated flash_dmattn directory, refine backward kernel launch parameters for better performance on specific GPU architectures, and streamline mask processing while expanding test coverage.

Key changes:

  • Reorganizes CUDA source files from csrc/ to csrc/flash_dmattn/ subdirectory
  • Optimizes backward pass kernel configurations for different GPU architectures (sm86/sm89 vs A100/H100); a hedged sketch of this dispatch appears after the list
  • Simplifies mask processor initialization by removing template parameter dependency
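The architecture split above is typically implemented by querying the device's opt-in shared memory per block and choosing kernel traits accordingly. The fragment below is a hedged sketch of that dispatch, assuming a hypothetical 144 KB threshold; the sm86/sm89 branch mirrors the hunks quoted later in this review, while the surrounding names (run_flash_bwd, Flash_bwd_kernel_traits, Headdim, T, Is_causal, params, stream) come from the launch template, and the exact values may differ from flash_bwd_launch_template.h.

```cpp
// Illustrative fragment: pick backward kernel traits by available shared memory.
// The 144 KB threshold is an assumption; sm86/sm89 expose roughly 99 KB per block.
int device, max_smem_per_block;
cudaGetDevice(&device);
cudaDeviceGetAttribute(&max_smem_per_block,
                       cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (max_smem_per_block >= 144 * 1024) {
    // A100/H100-class devices: a larger-tile configuration would go here.
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, /* larger tiles */ ..., T>, Is_causal>(params, stream);
} else {
    // sm86/sm89: smaller tiles to fit the tighter shared-memory budget.
    run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
}
```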

Reviewed Changes

Copilot reviewed 7 out of 92 changed files in this pull request and generated 4 comments.

Summary per file:

| File | Description |
| --- | --- |
| setup.py | Updates source file paths to the new flash_dmattn subdirectory structure |
| csrc/flash_dmattn/src/mask.h | Removes a template parameter from the Mask struct declaration |
| csrc/flash_dmattn/src/flash_fwd_kernel.h | Updates mask processor initialization calls to remove the template parameter |
| csrc/flash_dmattn/src/flash_bwd_launch_template.h | Refines kernel launch parameters and memory usage comments for different architectures |
| benchmarks/forward_equivalence.py | Removes duplicate test configurations, adds head_dim 192/256 support, and increases keep_window_size |
| benchmarks/backward_performance.py | Simplifies backward pass testing by using sum() instead of custom gradient tensors |
| benchmarks/backward_equivalence.py | Updates the backward pass testing approach and removes gradient tensor parameters |


} else { // sm86 and sm89
// 96KB, 2 CTAs in sm86 and sm 89.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
// 96KB, 1 CTAs in sm86 and sm 89.

Copilot AI Sep 11, 2025

The comment states '96KB' but based on the kernel traits and the context of other comments, this should be updated to reflect the actual shared memory usage. The comment appears inconsistent with the pattern of other memory usage comments in the file.

Suggested change
// 96KB, 1 CTAs in sm86 and sm 89.
// 88KB, 1 CTAs in sm86 and sm 89.

} else { // sm86 and sm89
// 72KB, 1 CTAs in sm86 and sm 89.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
// 88KB, 1 CTAs in sm86 and sm 89.

Copilot AI Sep 11, 2025

The comment states '88KB' but the original comment was '72KB'. If the memory usage calculation has changed due to the kernel trait modifications, the comment should accurately reflect the new memory requirements.

Suggested change
// 88KB, 1 CTAs in sm86 and sm 89.
// 72KB, 1 CTAs in sm86 and sm 89.

} else { // sm86 and sm89
// 92KB, 1 CTAs in sm86 and sm 89.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
// 76KB, 1 CTAs in sm86 and sm 89.

Copilot AI Sep 11, 2025

The comment states '76KB' but the original was '92KB'. This significant reduction in shared memory usage should be verified to ensure the comment accurately reflects the actual memory requirements of the new kernel configuration.

Suggested change
// 76KB, 1 CTAs in sm86 and sm 89.
// 92KB, 1 CTAs in sm86 and sm 89.

} else { // sm86 and sm89
// 88KB, 1 CTAs in sm86 and sm 89.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_causal>(params, stream);
// 80KB, 1 CTAs in sm86 and sm 89.

Copilot AI Sep 11, 2025

The comment states '80KB' but the original was '88KB'. Please verify this memory usage calculation matches the actual shared memory requirements of the modified kernel traits configuration.

Suggested change
// 80KB, 1 CTAs in sm86 and sm 89.
// 88KB, 1 CTAs in sm86 and sm 89.

LoserCheems merged commit ab06c18 into main on Sep 11, 2025
1 check passed
LoserCheems deleted the fix-bwd-launch-templates branch on November 13, 2025