
[FEATURE] Enhance forward combine kernel and split attention #227

Merged
LoserCheems merged 5 commits into main from optime-triton-kernels
Feb 27, 2026

Conversation

@LoserCheems
Collaborator

Summary

  • Introduces a forward combine kernel for split outputs and improves split attention merging with stable normalization.

Root Cause

  • The existing forward path offered limited performance and flexibility when handling variable-length sequences and autotuning configurations.

Changes

  • Added forward combine kernel, autotuning configurations, and optimized KV splits for FlashDecoding. Refactored forward attention functions to support the split KV mechanism.
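
The stable merge described above can be sketched in NumPy (illustrative reference code under assumed shapes, not the actual Triton kernel): each split produces a locally normalized partial output plus the log-sum-exp of its scores, and the combine step renormalizes across splits.

```python
import numpy as np

def combine_splits(partial_out, partial_lse):
    """Merge per-split attention outputs via log-sum-exp renormalization.

    partial_out: (num_splits, seqlen_q, head_dim); each split's output is
                 already softmax-normalized within its own KV range.
    partial_lse: (num_splits, seqlen_q); log-sum-exp of scores per split.
    """
    # Subtract the max over splits before exponentiating, for stability.
    lse_max = partial_lse.max(axis=0)
    scale = np.exp(partial_lse - lse_max)     # (num_splits, seqlen_q)
    weights = scale / scale.sum(axis=0)       # per-split softmax weights
    # Weighted sum over splits -> (seqlen_q, head_dim)
    return np.einsum("ks,ksd->sd", weights, partial_out)
```

Weighting by the renormalized log-sum-exp reproduces the softmax over the full KV sequence exactly (up to floating-point error), which is why the splits can be computed independently and merged afterwards.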

Reproduction

  • Exercise the new forward combine kernel and autotuning configurations in scenarios that split attention across the KV sequence.

Tests

  • Validated changes through existing and new tests for the forward combine kernel and split attention mechanisms.

Compatibility

  • No backward compatibility issues identified.

Checklist

Copilot AI review requested due to automatic review settings February 27, 2026 15:03
Contributor

Copilot AI left a comment

Pull request overview

This PR introduces a forward combine kernel for efficiently merging split attention outputs and enhances the FlashDecoding mechanism by adding support for KV-split parallelization. The changes enable better GPU utilization through parallel processing of attention across the KV sequence dimension, particularly beneficial for long-context scenarios.

Changes:

  • Introduces a new combine kernel for numerically stable merging of split attention outputs using log-sum-exp normalization
  • Adds autotuning configurations and heuristics for determining optimal KV split counts
  • Refactors forward attention kernels to support split KV mechanism with intermediate float32 accumulation for numerical precision
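
The float32 accumulation in the last bullet is easy to motivate with a toy example (illustrative only, not taken from the PR): a running half-precision sum silently stalls once the accumulator outgrows fp16's resolution.

```python
import numpy as np

n = 10_000
ones = np.ones(n, dtype=np.float16)

# Naive fp16 accumulation: above 2048 the spacing between representable
# fp16 values is 2, so adding 1.0 rounds back to the same value.
acc16 = np.float16(0.0)
for v in ones:
    acc16 = np.float16(acc16 + v)

# Accumulating in float32 (as the partial attention outputs do) stays exact.
acc32 = ones.astype(np.float32).sum()

print(acc16, acc32)  # 2048.0 10000.0
```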

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

Files reviewed:

  • flash_sparse_attn/ops/triton/utils.py: Adds num_splits_heuristic function, FWD_COMBINE_AUTOTUNE_KEYS, get_fwd_combine_autotune_configs, updates get_fwd_base_grid to support num_splits parameter, adds get_fwd_combine_grid, and extends input validation for num_splits
  • flash_sparse_attn/ops/triton/flash_fwd_combine.py: New file implementing the combine kernel for merging split attention outputs with stable softmax normalization across splits
  • flash_sparse_attn/ops/triton/flash_fwd.py: Modifies _fwd_base_kernel and forward functions to support split KV mechanism, including stride calculations, tensor allocation for partial outputs, and integration with the combine kernel
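
The body of num_splits_heuristic is not shown on this page; the sketch below (hypothetical, loosely modeled on FlashAttention's split-KV heuristic, so the signature and thresholds are assumptions) illustrates the usual idea of picking the split count that best fills the device's SMs:

```python
def num_splits_heuristic(total_blocks: int, num_sms: int, max_splits: int = 128) -> int:
    """Hypothetical sketch: choose a KV split count that maximizes occupancy.

    total_blocks: CTAs launched without splitting (batch * heads * q_blocks).
    num_sms:      streaming multiprocessors available on the device.
    """
    # If the grid already keeps most SMs busy, splitting only adds the
    # overhead of the combine kernel.
    if total_blocks >= 0.8 * num_sms:
        return 1
    best_splits, best_eff = 1, 0.0
    for splits in range(1, max_splits + 1):
        blocks = total_blocks * splits
        waves = -(-blocks // num_sms)       # ceiling division
        eff = blocks / (waves * num_sms)    # fraction of SM slots used
        if eff > best_eff + 1e-6:           # strict improvement: ties keep fewer splits
            best_splits, best_eff = splits, eff
    return best_splits
```

Real implementations also penalize very large split counts, since every extra split adds combine work; this sketch omits that refinement.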


Comment on lines +188 to +219
if arch == "cuda:sm80":
    return [
        triton.Config(
            {"TILE_M": 32, "TILE_N": 128},
            num_warps=4,
            num_stages=1,
        )
    ]
elif arch == "cuda:sm90":
    return [
        triton.Config(
            {"TILE_M": 32, "TILE_N": 128},
            num_warps=4,
            num_stages=1,
        )
    ]
elif arch == "cuda:sm100":
    return [
        triton.Config(
            {"TILE_M": 32, "TILE_N": 128},
            num_warps=4,
            num_stages=1,
        )
    ]
elif arch == "cuda:sm120":
    return [
        triton.Config(
            {"TILE_M": 32, "TILE_N": 128},
            num_warps=4,
            num_stages=1,
        )
    ]

Copilot AI Feb 27, 2026

The non-autotune configurations specify "TILE_N" in the config dictionary (lines 191, 199, 207, 215), but the autotune configurations use "TILE_K" (line 237). Since the combine kernel only uses TILE_M and TILE_K parameters (as seen in the kernel signature), the TILE_N in the non-autotune configs should be renamed to TILE_K for consistency. This mismatch could cause the kernel to fail when autotune is disabled.
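
A minimal sketch of the suggested fix (hypothetical code; the helper name is invented and plain dicts stand in for triton.Config): rename the key so the non-autotune configs match the kernel's TILE_M/TILE_K parameters, which also lets the four identical branches collapse into one.

```python
def fwd_combine_tile_config(arch: str) -> dict:
    # All four architectures currently return the same tile shape, and the
    # combine kernel's signature takes TILE_M and TILE_K, so "TILE_N" is
    # renamed to "TILE_K" here (hypothetical helper; the real code wraps
    # this dict in triton.Config).
    if arch in ("cuda:sm80", "cuda:sm90", "cuda:sm100", "cuda:sm120"):
        return {"TILE_M": 32, "TILE_K": 128, "num_warps": 4, "num_stages": 1}
    raise ValueError(f"unsupported arch: {arch}")
```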

@LoserCheems LoserCheems merged commit 6ec02a5 into main Feb 27, 2026
@LoserCheems LoserCheems deleted the optime-triton-kernels branch March 1, 2026 03:54