Skip to content

Commit 7395ca9

Browse files
heyuhhhlfr-0531
andauthored
[None][doc] Add Sparse Attention feature doc (#9648)
Signed-off-by: yuhangh <[email protected]> Signed-off-by: Fanrong Li <[email protected]> Co-authored-by: Fanrong Li <[email protected]>
1 parent c059e6c commit 7395ca9

File tree

3 files changed

+255
-0
lines changed

3 files changed

+255
-0
lines changed
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
# Sparse Attention
2+
3+
- [Background and Motivation](#background-and-motivation)
4+
- [Quick Start](#quick-start)
5+
- [Python API](#python-api)
6+
- [Usage with trtllm-bench or trtllm-serve](#usage-with-trtllm-bench-or-trtllm-serve)
7+
- [Developer Guide](#developer-guide)
8+
- [Architecture Overview](#architecture-overview)
9+
- [Framework Implementation](#framework-implementation)
10+
- [Implementing a New Algorithm](#implementing-a-new-algorithm)
11+
- [1. Configuration Class](#1-configuration-class)
12+
- [2. Implement the prediction module in Attention Backend](#2-implement-the-prediction-module-in-attention-backend)
13+
- [3. Manage Auxiliary Memory Pool](#3-manage-auxiliary-memory-pool)
14+
- [4. Registration and Dispatch](#4-registration-and-dispatch)
15+
- [Summary and Future Work](#summary-and-future-work)
16+
- [Current Status](#current-status)
17+
- [Future Work](#future-work)
18+
19+
## Background and Motivation
20+
21+
As Large Language Models (LLMs) are applied to increasingly complex tasks such as long-document summarization, code generation, and autonomous agents, the demand for processing long contexts and extended generation has surged. In Transformer-based models, the attention mechanism's computational complexity and memory usage grow quadratically and linearly with sequence length, respectively. This creates significant bottlenecks in both the **Context (Prefill)** and **Generation (Decode)** phases:
22+
23+
* **Context Phase**: Processing long prompts requires substantial memory bandwidth and computation, affecting time-to-first-token (TTFT). Since the context phase is typically compute-bound, reducing the computational load here is critical.
24+
* **Generation Phase**: The Key-Value (KV) cache grows with every generated token, consuming vast amounts of GPU memory and bandwidth. Since the generation phase is usually memory-bound, reducing the memory footprint directly alleviates memory pressure, improves token-to-token latency (TPOT), and allows for larger batch sizes.
25+
26+
Fortunately, key observations indicate that attention scores naturally exhibit sparsity, meaning not all K/V tokens are necessary for attention computation. To enhance the efficiency of long-sequence LLMs, numerous methods have been proposed to optimize performance by leveraging approximate sparse attention. Among those methods, sparsity can be applied to different dimensions of the attention: head dimension, hidden dimension, and sequence dimension. When applying sparsity to the sequence dimension, those methods selectively compute only the most important query-key pairs. This approach can be referred to as token sparsity. Token sparsity has been widely explored in lots of recent academic works, and it is also a kind of structured sparse method that is friendly for GPU. Currently, TensorRT LLM focuses on the sparse attention methods that leverages token sparsity.
27+
28+
Token sparsity can be applied to two distinct aspects of LLM inference:
29+
* **Sparse Computation**: If a query token does not require the entire history, just skip the computation for irrelevant tokens, thereby reducing attention computational costs.
30+
* **Sparse KV cache**: Evicts KV tokens from the cache that are not required for future generation steps. This reduces GPU memory usage and lowers computation overhead for subsequent steps.
31+
32+
Both methods can be enabled simultaneously to achieve better performance.
33+
34+
To support these emerging techniques, TensorRT LLM has designed a general, extensible and flexible **sparse attention framework** (which is continuously being optimized) to compatibly integrate advanced sparse algorithms. Currently we can support [RocketKV](https://arxiv.org/pdf/2502.14051) and [DSA](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/DeepSeek_V3_2.pdf).
35+
36+
## Quick Start
37+
38+
This section provides a brief guide on enabling sparse attention in TensorRT LLM, using RocketKV as an example. For more details, please refer to [RocketKV sparse attention](../../examples/sparse_attention/RocketKV.md).
39+
40+
### Python API
41+
42+
To use sparse attention, you need to configure a specific `SparseAttentionConfig` (for example, `RocketSparseAttentionConfig`) and pass it to the `LLM` constructor.
43+
44+
```python
45+
from tensorrt_llm import LLM, SamplingParams
46+
from tensorrt_llm.llmapi import RocketSparseAttentionConfig, KvCacheConfig
47+
48+
# 1. Configure Sparse Attention
49+
# Example: RocketKV configuration
50+
rocket_config = RocketSparseAttentionConfig(
51+
prompt_budget=2048,
52+
kt_cache_dtype='float8_e5m2'
53+
)
54+
55+
# 2. Configure KV Cache
56+
# Note: Some sparse algorithms (like RocketKV) may require disabling block reuse
57+
kv_config = KvCacheConfig(enable_block_reuse=False)
58+
59+
# 3. Initialize LLM
60+
llm = LLM(
61+
model="<path_to_model>",
62+
backend='pytorch', # Currently requires the PyTorch backend
63+
sparse_attention_config=rocket_config,
64+
kv_cache_config=kv_config,
65+
)
66+
67+
# 4. Generate
68+
prompts = ["To be or not to be..."]
69+
outputs = llm.generate(prompts, SamplingParams(max_tokens=128))
70+
```
71+
72+
### Usage with `trtllm-bench` or `trtllm-serve`
73+
74+
You can enable sparse attention in benchmarking and serving tools by providing a `sparse_attention_config` in an `extra_config.yaml` file.
75+
76+
**extra_config.yaml:**
77+
```yaml
78+
backend: pytorch
79+
attn_backend: TRTLLM
80+
sparse_attention_config: # RocketKV as an example
81+
algorithm: rocket
82+
kt_cache_dtype: float8_e5m2
83+
prompt_budget: 2048
84+
kv_cache_config:
85+
enable_block_reuse: false
86+
enable_chunked_prefill: false
87+
```
88+
89+
Run the command with the config file:
90+
```bash
91+
trtllm-bench/trtllm-serve --model <model_path> --extra_llm_api_options extra_config.yaml ...
92+
```
93+
94+
For example, users can evaluate a model with trtllm-eval on LongBenchV2 task like this:
95+
96+
```bash
97+
trtllm-eval --model <path_to_model> --extra_llm_api_options extra_config.yaml longbench_v2 --max_output_length 1024 ...
98+
```
99+
100+
## Developer Guide
101+
102+
This section describes the sparse attention framework architecture and guides developers on how to implement new sparse attention algorithms in TensorRT LLM. Unless otherwise specified, this framework primarily targets **MQA/GQA/MLA-based** attention mechanisms.
103+
104+
### Architecture Overview
105+
106+
<div align="center">
107+
<figure>
108+
<img src="https://github.com/NVIDIA/TensorRT-LLM/raw/main/docs/source/media/sparse_attention_framework.png" width="800">
109+
</figure>
110+
</div>
111+
<p align="center"><sub><em>Figure 1: The sparse attention framework in TensorRT LLM.</em></sub></p>
112+
113+
Our goal is to design a general, extensible, and flexible sparse attention framework. In this framework, the attention operator provides the unified APIs to support both **sparse computation** and **sparse KV cache** that leverage token sparsity, while the users/developers can only focus on the algorithm of sparse attentions, i.e. how to accurately identify important query-key pairs.
114+
115+
For the generality, TensorRT LLM abstracts sparse attention into a prediction-based workflow: *a prediction module first identifies the sparse indices (tokens/blocks to keep or attend to), which are then used by the subsequent attention operator*. Currently, for standard attention (MQA/GQA), TensorRT LLM supports **sparse KV cache** in the context phase and **sparse computation** in the generation phase. Different KV heads are allowed to use different sparse indices, while Q heads that map to the same KV head share the same sparse pattern. It does **not** yet support sparse computation in the context phase or sparse KV cache in the generation phase.
116+
117+
For the scalability, figure 1 illustrates the overall design. The architecture is built by inheriting from the existing `AttentionBackend` to define algorithm-specific sparse attention backends. Within these backends, `prediction` methods are implemented to generate the corresponding sparse indices. These indices are then passed as arguments to the `AttentionOp` to perform the sparse attention computation. This approach balances system flexibility with extensibility, allowing new algorithms to be integrated by simply defining their prediction logic **without** modifying the core attention kernels.
118+
119+
TensorRT LLM currently supports the following features:
120+
121+
1. **Context Phase**:
122+
* **sparse computation**: MLA
123+
* **sparse KV cache**: MQA/GQA
124+
125+
2. **Generation Phase**:
126+
* **sparse computation**: MLA/MQA/GQA
127+
* **sparse KV cache**: no support yet
128+
129+
### Framework Implementation
130+
131+
To hide the complexity of sparse algorithms, the main prediction logic is encapsulated within the `tensorrt_llm._torch.attention_backend` module.
132+
133+
We have extended the existing `AttentionBackend` to include a prediction step that retrieves sparse indices before the attention operation. These indices are generated using two prediction methods:
134+
135+
```python
136+
# Predict indices for sparse KV Cache
137+
sparse_kv_indices, sparse_kv_offsets = self.sparse_kv_predict(
138+
q, k, metadata, **kwargs)
139+
140+
# Predict indices for sparse computation
141+
sparse_attn_indices, sparse_attn_offsets = self.sparse_attn_predict(
142+
q, k, metadata, **kwargs)
143+
```
144+
145+
The specific prediction logic is hidden in the subclasses, where developers implement `sparse_kv_predict` and `sparse_attn_predict`.
146+
147+
The key files located in `tensorrt_llm/_torch/attention_backend/sparse/` are:
148+
149+
* `rocket.py`, `dsa.py`: Implementations of specific algorithms (e.g., RocketKV, DSA).
150+
* `kernel.py`: Custom Triton kernels for importance scoring or selection.
151+
* `utils.py`: Dispatch related logic.
152+
153+
<div align="center">
154+
<figure>
155+
<img src="https://github.com/NVIDIA/TensorRT-LLM/raw/main/docs/source/media/sparse_attention_op.png" width="800">
156+
</figure>
157+
</div>
158+
<p align="center"><sub><em>Figure 2: Sparse attention operator workflow in TensorRT LLM.</em></sub></p>
159+
160+
In `AttentionOp`, currently, the MQA/GQA sparse attention only supports sparse computation at block granularity in the generation phase, where the block size equals to the page size of the KV cache. It means that we can skip the attention computation of those unimportant pages. In addition, we provide a sparse MLA kernel that supports token-level sparse computation in both the context and generation phases.
161+
162+
To support those features, as illustrated in figure 2, we have implemented two kernels for the MQA/GQA path, `updateSparseKvCacheAfterFmha` and `gatherKvPageOffsetsKernel`, applied in the context and generation phases respectively:
163+
164+
* **`updateSparseKvCacheAfterFmha`**: Invoked in the post-processing stage after the context attention computation. It selects the important KV tokens and write those K/V vectors to the KV cache to reduce the KV cache size.
165+
166+
* **`gatherKvPageOffsetsKernel`**: Executed before the attention computation in the generation phase. It converts the input sparse indices (which can be of arbitrary granularity) into page-aligned indices. This means that if a single token is selected, the entire page it is included in the attention computation. After this conversion, we will get a new `kv_page_offsets` and also an updated `kv_len` that is the number of those selected KV tokens. Then these new metadata are fed into the subsequent attention kernel for computation.
167+
168+
For sparse MLA, the kernel supports token sparsity directly, eliminating the need for `gatherKvPageOffsetsKernel`. However, please note that sparse KV cache support is not yet available.
169+
170+
Many sparse attention algorithms also require additional auxiliary memory. In the current system, there are two paths to support this feature:
171+
172+
* Implement a simple, custom CacheManager at the Python level, inheriting from `KVCacheManager`.
173+
174+
* Use `KVCacheManagerCpp` to simultaneously manage both the KV Cache and auxiliary memory.
175+
176+
Each option has its own advantages and disadvantages, please refer to the [Manage Auxiliary Memory Pool](#3-manage-auxiliary-memory-pool) for more details.
177+
178+
### Implementing a New Algorithm
179+
180+
#### 1. Configuration Class
181+
182+
Define a configuration class in `tensorrt_llm/llmapi/llm_args.py` inheriting from `BaseSparseAttentionConfig`. This class should hold user-tunable parameters for your algorithm.
183+
184+
```python
185+
@dataclass
186+
class MySparseAttentionConfig(BaseSparseAttentionConfig):
187+
topk: int = 64
188+
# ... other parameters
189+
```
190+
191+
#### 2. Implement the prediction module in Attention Backend
192+
193+
Create a new class inheriting from `TrtllmAttention` (in `tensorrt_llm/_torch/attention_backend/trtllm.py`). You typically need to override two main prediction methods:
194+
195+
**`sparse_kv_predict(self, q, k, metadata, **kwargs)`**
196+
* **Behavior**: This function performs prediction to return the indices of tokens to be preserved in the KV cache.
197+
* **Output**:
198+
- `sparse_kv_indices`: The token indices of the important tokens on sequence dimension, shape `(nHeads, nTokens)`, where `nHeads` is the number of KV heads and `nTokens` is the total number of selected tokens across all samples in the batch.
199+
- `sparse_kv_offsets`: The offset for the `sparse_kv_indices`, shape `(nBatch + 1)`, where `nBatch` is the number of the batch size. The index for head `h` and sample `n` can be obtained via `sparse_kv_indices[h, sparse_kv_offsets[n]]`.
200+
* **Constraint**: Returned indices must be **sorted** to ensure safe in-place gathering in memory. Note that this post-processing "gather" step introduces some overhead, but significantly improves flexibility, allowing compatibility with features in context like chunked prefill.
201+
202+
**`sparse_attn_predict(self, q, k, metadata, **kwargs)`**
203+
* **Behavior**: For the current query tokens, predict and return the sparse indices for sparse computation.
204+
* **Output**:
205+
- `sparse_attn_indices`: The block indices of the block sparse attention on the KV sequence dimension, shape `(nHeads, nBlocks)`, where `nHeads` is the number of KV heads and `nBlocks` is the total number of selected blocks across all samples in the batch. For block sparse attention, the block size is defined by `sparse_attn_indices_block_size`, which supports arbitrary values.
206+
- `sparse_attn_offsets`: The offset for the `sparse_attn_indices`, shape `(nBatch + 1)`, where `nBatch` is the number of the batch size. The index for head `h` and sample `n` can be obtained via `sparse_attn_indices[h, sparse_attn_offsets[n]]`.
207+
* **Constraint**: The generation phase sparse computation is supported for NVIDIA Blackwell GPUs and newer (SM 100+) using TRTLLM-GEN kernels. However, it is flexible enough to extend to different architectures. Currently, only KV cache's **page-level** granularity is supported for sparse computation.
208+
209+
**Note**: The prediction process can be time-consuming, especially in low-latency scenarios where it might account for a significant portion of the attention time. It is highly recommended to optimize this step using custom kernels.
210+
211+
#### 3. Manage Auxiliary Memory Pool
212+
213+
Many sparse algorithms (like RocketKV or DSA) require auxiliary structures (e.g., a "KT cache" in RocketKV) to select relevant tokens. There are two primary ways to manage this memory in TensorRT LLM:
214+
215+
**Option A: Python-level Custom Manager**
216+
217+
You can implement a custom manager in Python.
218+
* **Use Case**: Algorithms like RocketKV use this approach to store the KT cache (e.g., `RocketKVCacheManager` in `rocket.py`).
219+
* **Implementation**: Create a Python level cache manager that handles the allocation and lifecycle of the auxiliary tensors. It is recommended to use the existing `BlockManager` to manage the auxiliary pools if possible. This allows the auxiliary pool to share block manager logics, reducing implementation overhead.
220+
* **Key Methods to Override**:
221+
* `get_cache_size_per_token` / `get_cache_bytes_per_token`: Update `kv_factor` correctly to include the size of the auxiliary structures so TensorRT LLM allocates sufficient GPU memory.
222+
* `add_dummy_requests` / `prepare_resources`: Ensure the auxiliary pool allocates correct resources/tokens for new requests.
223+
* **Pros**: The custom cache manager is more flexible and easier to implement because it can share the same blocks managed by the `KVCacheManager`.
224+
* **Cons**: This approach operates at the Python level, making it difficult to share features of the KV cache managed at the C++ level (e.g., advanced transmission or kvcache reuse features tied to the C++ manager).
225+
226+
**Option B: C++ Integrated Manager**
227+
228+
For tighter integration, you can manage the auxiliary memory within the C++ `KVCacheManager`.
229+
* **Use Case**: Algorithms like DSA use this approach to store the indexer Kcache.
230+
* **Pros**: Enables compatibility with advanced features such as KV cache reuse and disagg-serving. For example, DSA's low-rank indexer Kcache can be reused or transmitted between context and generation engines.
231+
* **Cons**: Higher implementation complexity. The current C++ `KVCacheManager` is optimized for the standard KV cache pool. Adding custom pools often requires significant modifications or manual implementation of the pool management logic within the C++ level.
232+
233+
**Note**: If your algorithm involves sparse KV cache, standard KV cache block reuse is generally incompatible because eviction modifies the block content uniquely for each request. However, algorithms like DSA that use low-rank approximation without eviction can support block reuse.
234+
235+
#### 4. Registration and Dispatch
236+
237+
* Register your config and backend in `tensorrt_llm/_torch/attention_backend/sparse/utils.py` and `tensorrt_llm/_torch/pyexecutor/_util.py` to ensure the system routes the request to your new backend when the config is present.
238+
* Add initialization logic in `cpp/tensorrt_llm/thop/attentionOp.cpp` and `cpp/tensorrt_llm/kernels/sparseAttentionKernels.h` if new C++ level parameters are required.
239+
240+
## Summary and Future Work
241+
242+
### Current Status
243+
244+
Currently, the status of the sparse attention framework is as follows:
245+
246+
1. **Supported Operations**: The `AttentionOp` currently supports **sparse KV cache** in the context phase and **sparse computation** in the generation phase. Other combinations (for example, sparse computation in the context phase) are not yet supported for MQA/GQA. For MLA, sparse computation is supported in both the context and generation phases.
247+
2. **Algorithm Support**: RocketKV is supported in both the vanilla (PyTorch) backend and the TRTLLM backend, while DSA is supported in the TRTLLM backend. These implementations validate the generality and scalability of the framework.
248+
249+
### Future Work
250+
251+
* **Sparse Computation in Context Phase**: We plan to introduce sparse computation support for the context phase for MQA/GQA, allowing the TensorRT LLM sparse attention framework to cover more scenarios.
252+
* **Dynamic Eviction in Generation Phase**: Dynamically evicting KV cache blocks during the generation phase poses significant challenges to KV cache flexibility. While difficult to implement in the current framework, block-level eviction appears to be a promising compromise and is under further exploration.
253+
* **Unified Auxiliary Memory Management**: We are exploring a unified mechanism to manage auxiliary memory pools. This would allow users to define custom auxiliary spaces more flexibly while automatically inheriting advanced features from the KV cache, such as reuse and offloading.
254+
* **Code Refactoring**: As more sparse attention algorithms are integrated, the framework will undergo refactoring to unify code and improve maintainability.
255+
* **Optimizations**: We are discussing further optimizations, such as improving DSA performance.
291 KB
Loading
253 KB
Loading

0 commit comments

Comments
 (0)