[FEATURE] Enhance forward combine kernel and split attention #227
LoserCheems merged 5 commits into main
Conversation
Improves split-attention merging with stable normalization. Supports variable-length sequences and autotuning.
…d optimize output handling
Pull request overview
This PR introduces a forward combine kernel for efficiently merging split attention outputs and enhances the FlashDecoding mechanism by adding support for KV-split parallelization. The changes enable better GPU utilization through parallel processing of attention across the KV sequence dimension, particularly beneficial for long-context scenarios.
Changes:
- Introduces a new combine kernel for numerically stable merging of split attention outputs using log-sum-exp normalization
- Adds autotuning configurations and heuristics for determining optimal KV split counts
- Refactors forward attention kernels to support split KV mechanism with intermediate float32 accumulation for numerical precision
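The log-sum-exp merge described above can be sketched in plain NumPy. This is a hedged illustration of the technique, not the PR's Triton kernel: each split produces a partial output and a per-row log-sum-exp (LSE), and the combine step reweights each partial output by `exp(lse_i - lse_global)` so the result matches attention over the full KV sequence. The function name and shapes are assumptions for the sketch.

```python
import numpy as np

def combine_splits(o_partial, lse_partial):
    """Merge per-split attention outputs with log-sum-exp normalization.

    o_partial:   (num_splits, seqlen, head_dim) float32 partial outputs
    lse_partial: (num_splits, seqlen) per-split log-sum-exp values
    """
    # Global LSE across splits, computed stably around the running max.
    lse_max = lse_partial.max(axis=0)                    # (seqlen,)
    sumexp = np.exp(lse_partial - lse_max).sum(axis=0)   # (seqlen,)
    lse = lse_max + np.log(sumexp)                       # (seqlen,)
    # Reweight each split by its share of the global normalizer and sum.
    scale = np.exp(lse_partial - lse)                    # (num_splits, seqlen)
    return (o_partial * scale[..., None]).sum(axis=0)    # (seqlen, head_dim)
```

Because each split's softmax denominator `Z_i` satisfies `exp(lse_i - lse) = Z_i / Z`, the weighted sum reproduces exactly the softmax over the concatenated KV range, regardless of how the sequence was partitioned.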
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| flash_sparse_attn/ops/triton/utils.py | Adds num_splits_heuristic function, FWD_COMBINE_AUTOTUNE_KEYS, get_fwd_combine_autotune_configs, updates get_fwd_base_grid to support num_splits parameter, adds get_fwd_combine_grid, and extends input validation for num_splits |
| flash_sparse_attn/ops/triton/flash_fwd_combine.py | New file implementing the combine kernel for merging split attention outputs with stable softmax normalization across splits |
| flash_sparse_attn/ops/triton/flash_fwd.py | Modifies _fwd_base_kernel and forward functions to support split KV mechanism, including stride calculations, tensor allocation for partial outputs, and integration with the combine kernel |
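The table above mentions a `num_splits_heuristic` function for choosing the KV split count. The PR does not show its body here; the following is a hypothetical sketch of how such a heuristic commonly works (modeled on the split-KV heuristic popularized by FlashAttention's decoding path): split only when the grid is too small to fill the GPU, and prefer the smallest split count whose wave efficiency is close to the best achievable, so combine overhead is not paid for marginal gains. The signature and thresholds are assumptions.

```python
import math

def num_splits_heuristic(total_blocks, num_sms, max_splits=128):
    """Hypothetical sketch: pick a KV split count that fills the SMs.

    total_blocks: grid size without splitting (batch * heads * M-blocks)
    num_sms:      number of streaming multiprocessors on the device
    """
    if total_blocks >= 0.8 * num_sms:
        return 1  # already enough parallelism; splitting only adds overhead
    max_splits = min(max_splits, num_sms, total_blocks)

    def efficiency(n):
        # Fraction of the last wave that does useful work at n splits.
        waves = total_blocks * n / num_sms
        return waves / math.ceil(waves)

    best = max(efficiency(n) for n in range(1, max_splits + 1))
    # Smallest split count within 85% of the best achievable efficiency.
    for n in range(1, max_splits + 1):
        if efficiency(n) >= 0.85 * best:
            return n
```

For a decode-time grid of 10 blocks on a 108-SM GPU this picks a large split count, while an already-saturated grid returns 1 and skips the combine pass entirely.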
```python
if arch == "cuda:sm80":
    return [
        triton.Config(
            {"TILE_M": 32, "TILE_N": 128},
            num_warps=4,
            num_stages=1,
        )
    ]
elif arch == "cuda:sm90":
    return [
        triton.Config(
            {"TILE_M": 32, "TILE_N": 128},
            num_warps=4,
            num_stages=1,
        )
    ]
elif arch == "cuda:sm100":
    return [
        triton.Config(
            {"TILE_M": 32, "TILE_N": 128},
            num_warps=4,
            num_stages=1,
        )
    ]
elif arch == "cuda:sm120":
    return [
        triton.Config(
            {"TILE_M": 32, "TILE_N": 128},
            num_warps=4,
            num_stages=1,
        )
    ]
```
The non-autotune configurations specify "TILE_N" in the config dictionary (lines 191, 199, 207, 215), but the autotune configurations use "TILE_K" (line 237). Since the combine kernel only uses TILE_M and TILE_K parameters (as seen in the kernel signature), the TILE_N in the non-autotune configs should be renamed to TILE_K for consistency. This mismatch could cause the kernel to fail when autotune is disabled.
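A minimal sketch of the suggested fix, with plain dicts standing in for `triton.Config` so it stays self-contained (the function name and dict layout are assumptions for illustration): rename `"TILE_N"` to `"TILE_K"` so the non-autotune path hands the kernel the parameter it actually reads, and collapse the four identical per-arch branches into one lookup.

```python
def get_fwd_combine_configs(arch):
    """Hypothetical fixed version: every supported arch currently shares one
    configuration, keyed by TILE_K (the name the combine kernel reads)."""
    supported = {"cuda:sm80", "cuda:sm90", "cuda:sm100", "cuda:sm120"}
    if arch not in supported:
        raise ValueError(f"unsupported arch: {arch}")
    # In the real code this would be triton.Config({...}, num_warps=4,
    # num_stages=1); a dict keeps the sketch runnable without Triton.
    return [{"kwargs": {"TILE_M": 32, "TILE_K": 128},
             "num_warps": 4, "num_stages": 1}]
```

Collapsing the branches also means a future tile-size change only has to be made once, instead of four times in lockstep.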
…rward combine kernel
Summary
Root Cause
Changes
Reproduction
Tests
Compatibility
Checklist