# Sparse Attention for Large Language Models

This example demonstrates how to apply sparse attention optimization to Large Language Models (LLMs) using TensorRT-Model-Optimizer's attention sparsity module.

## Overview

Sparse attention reduces the computational complexity of the attention mechanism by selectively computing only the most important attention scores. This can significantly speed up inference and reduce memory usage, especially for long sequences.

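As a purely conceptual illustration (not ModelOpt's internal implementation), the softmax-skip idea can be sketched in a few lines of PyTorch: post-softmax attention scores below a threshold are treated as zero, so their contributions can be skipped:

```python
# Conceptual sketch of threshold-based attention sparsity (softmax skip).
# Illustration only: ModelOpt applies this inside the model's attention modules.
import torch

def softmax_skip_attention(q, k, v, threshold=1e-4):
    scale = q.shape[-1] ** -0.5
    probs = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    keep = probs >= threshold  # keep only the "important" scores
    sparsity = 1.0 - keep.float().mean().item()
    return (probs * keep) @ v, sparsity

q, k, v = (torch.randn(1, 8, 1024, 64) for _ in range(3))
out, sparsity = softmax_skip_attention(q, k, v)
print(f"Fraction of attention scores skipped: {sparsity:.1%}")
```
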
## Features

- **Sparse Attention Method**:
  - Softmax Skip: Threshold-based masking for efficient attention computation
  - Extensible architecture: Easy to add new sparse attention methods in the future
- **Calibration Support**: Automatically find optimal sparsity parameters
- **HuggingFace Integration**: Works with any HuggingFace transformer model
- **Composable**: Can be combined with quantization and other optimizations

## Installation

```bash
pip install nvidia-modelopt transformers torch
```

## Quick Start

```python
import modelopt.torch.sparsity as mts  # Similar to mtq for quantization
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

# Define sparse attention config
sparse_config = {
    "method": "softmax_skip",
    "sparse_cfg": {
        "*attn*": {"threshold": 1e-4, "enable": True},
        "default": {"enable": False},
    },
}

# Apply sparse attention
sparse_model = mts.attention_sparsity.sparsify(model, config=sparse_config)

# Use the model as usual
input_ids = tokenizer("Sparse attention helps with long contexts because", return_tensors="pt").input_ids
output = sparse_model.generate(input_ids, max_new_tokens=100)
```

### Command Line Usage

The `hf_spar_attn.py` script supports two modes:

```bash
# Basic mode: Apply sparse attention and test generation
python hf_spar_attn.py --mode basic --model_name Qwen/Qwen3-8B --show_stats

# Export mode: Export to unified HuggingFace checkpoint
python hf_spar_attn.py --mode export --model_name Qwen/Qwen3-8B --export_dir ./sparse_model
```

## Examples

### 1. Basic Usage (--mode basic)

Apply sparse attention to a model and test generation quality:

```bash
python hf_spar_attn.py --mode basic \
    --model_name Qwen/Qwen3-8B \
    --sparse_attn softmax_skip \
    --show_stats \
    --benchmark
```

Options for basic mode:

- `--show_stats`: Display sparsity statistics for each attention layer
- `--benchmark`: Compare performance before and after sparse attention
- `--show_memory`: Display GPU memory usage

### 2. Export Model (--mode export)

Export the sparse attention model to the unified HuggingFace checkpoint format for deployment:

```bash
python hf_spar_attn.py --mode export \
    --model_name Qwen/Qwen3-8B \
    --sparse_attn softmax_skip \
    --export_dir ./sparse_model
```

The exported model will contain:

- Model weights with sparse attention applied
- `config.json` with a `sparse_attention_config` section
- Tokenizer files

### Exported Config Format

The `config.json` includes a `sparse_attention_config` section using the `config_groups` pattern (similar to `quantization_config`):

**For calibrated models:**

```json
{
  "sparse_attention_config": {
    "config_groups": {
      "group_0": {
        "sparse_algo": "softmax_skip",
        "targets": ["LlamaAttention"]
      }
    },
    "threshold_scale_factor": 437.7,
    "target_sparsity": 0.3,
    "producer": {
      "name": "modelopt",
      "version": "0.37.0"
    }
  }
}
```

**For non-calibrated models:**

```json
{
  "sparse_attention_config": {
    "config_groups": {
      "group_0": {
        "sparse_algo": "softmax_skip",
        "threshold": 0.0001,
        "targets": ["LlamaAttention"]
      }
    },
    "producer": {
      "name": "modelopt",
      "version": "0.37.0"
    }
  }
}
```

This format enables inference engines to reconstruct the sparse attention configuration from the checkpoint.

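As a quick sanity check after export, the embedded section can be read back from the checkpoint's `config.json` with only the standard library (the field names follow the examples above; the `./sparse_model` path matches the export command earlier in this README):

```python
# Inspect the sparse attention section of an exported checkpoint's config.json.
import json
from pathlib import Path

cfg = json.loads(Path("./sparse_model/config.json").read_text())
sparse_cfg = cfg.get("sparse_attention_config", {})

for name, group in sparse_cfg.get("config_groups", {}).items():
    print(name, group.get("sparse_algo"), group.get("targets"), group.get("threshold"))

# Present only for calibrated checkpoints
if "threshold_scale_factor" in sparse_cfg:
    print("calibrated:", sparse_cfg["threshold_scale_factor"], sparse_cfg.get("target_sparsity"))
```
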
## Configuration Options

### Pre-defined Configuration

ModelOpt provides a unified configuration that supports both simple and phase-aware thresholds:

```python
import modelopt.torch.sparsity as mts

# The default config supports phase-aware thresholds
SOFTMAX_SKIP_CFG = {
    "method": "softmax_skip",
    "sparse_cfg": {
        "*attn*": {
            "threshold": {
                "prefill": 1e-3,  # More aggressive during prefill
                "decode": 1e-5,  # Conservative during decode
            },
            "enable": True,
        },
        "default": {"enable": False},
    },
}

# Use the config
model = mts.attention_sparsity.sparsify(model, config=SOFTMAX_SKIP_CFG)
```
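
The intuition behind the two phases: prefill processes the whole prompt in a single pass, where a more aggressive threshold is usually tolerable, while decode generates one token at a time and is typically more sensitive. Conceptually, resolving a phase-aware threshold is just a per-phase lookup; here is a minimal standalone sketch of that idea (not ModelOpt's internal code):

```python
# Minimal sketch: resolve a threshold that may be a single float or a per-phase dict.
def resolve_threshold(threshold_cfg, phase):
    if isinstance(threshold_cfg, dict):
        return threshold_cfg[phase]  # e.g. {"prefill": 1e-3, "decode": 1e-5}
    return threshold_cfg  # a single float applies to all phases

assert resolve_threshold({"prefill": 1e-3, "decode": 1e-5}, "decode") == 1e-5
assert resolve_threshold(1e-4, "prefill") == 1e-4
```
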

### Custom Configuration

You can create custom configurations with simple or phase-aware thresholds:

```python
# Simple threshold (same for all phases)
simple_config = {
    "method": "softmax_skip",
    "sparse_cfg": {
        "*attn*": {
            "threshold": 1e-4,  # Single threshold for all phases
            "enable": True,
        },
        "default": {"enable": False},
    },
}

# Phase-aware threshold
phase_aware_config = {
    "method": "softmax_skip",
    "sparse_cfg": {
        "*attn*": {
            "threshold": {
                "prefill": 1e-3,  # Prefill phase
                "decode": 1e-5,  # Decode phase
            },
            "enable": True,
        },
        "default": {"enable": False},
    },
}
```

### Adding Custom Methods

The architecture is designed to make it easy to add new sparse attention methods. Refer to the [`FlashSoftmaxSkipMethod`](../../modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py) source code for an example of how a method is implemented.

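As a rough, hypothetical sketch of what a new method boils down to (the class and method names below are illustrative assumptions, not the real ModelOpt interface; the linked source file defines the actual one), the core of a method is a rule that decides which attention scores to keep:

```python
# Hypothetical sketch only -- see flash_softmax_skip.py for the real interface.
import torch

class TopKSkipMethod:
    """Illustrative custom method: keep only the top-k scores per query."""

    def __init__(self, k: int = 64):
        self.k = k

    def sparsify_probs(self, probs: torch.Tensor) -> torch.Tensor:
        # probs: (..., q_len, kv_len) post-softmax attention probabilities
        k = min(self.k, probs.shape[-1])
        topk = torch.topk(probs, k, dim=-1)
        mask = torch.zeros_like(probs).scatter_(-1, topk.indices, 1.0)
        return probs * mask
```
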
### Pattern-Based Configuration

Apply different configurations to different layers:

```python
config = {
    "sparse_cfg": {
        "*layers.[0-12].*attention*": {"enable": True, "threshold": 1e-3},  # More aggressive for early layers
        "*layers.[13-24].*attention*": {"enable": True, "threshold": 1e-4},  # Conservative for later layers
    }
}
```

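Before committing to a pattern, it can help to preview which modules it actually selects. The sketch below assumes patterns are Unix-style wildcards matched against module names (an assumption; check the ModelOpt documentation for the exact matching rules). Note that a bracket expression such as `[0-12]` is a character class matching a single character (0, 1, or 2), not the range 0 through 12:

```python
# Preview which module names a wildcard pattern would select (assumes fnmatch-style matching).
from fnmatch import fnmatch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B")
pattern = "*layers.[0-12].*attn*"  # [0-12] matches a single character: 0, 1, or 2
matched = [name for name, _ in model.named_modules() if fnmatch(name, pattern)]
print(f"{len(matched)} modules match, e.g. {matched[:3]}")
```
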

## Performance Considerations

1. **Threshold Tuning**:
   - Lower thresholds (e.g., 1e-5) preserve more accuracy but yield less sparsity
   - Higher thresholds (e.g., 1e-3) provide more sparsity but may impact accuracy
   - Use calibration to find optimal values (see the standalone sweep after this list for intuition)

2. **Memory Usage**:
   - Sparse attention reduces peak memory usage during inference
   - Especially beneficial for long sequences (>1024 tokens)

3. **Model Compatibility**:
   - Works best with models using the eager attention implementation
   - Compatible with all HuggingFace transformer models

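For intuition on the threshold trade-off from item 1, the standalone sweep below measures what fraction of post-softmax scores fall under different thresholds on random inputs; it is illustrative only and is not ModelOpt's calibration procedure:

```python
# Illustrative threshold sweep on random attention scores (not ModelOpt calibration).
import torch

torch.manual_seed(0)
scores = torch.randn(1, 8, 512, 512)  # (batch, heads, q_len, kv_len)
probs = torch.softmax(scores, dim=-1)

for threshold in (1e-5, 1e-4, 1e-3):
    sparsity = (probs < threshold).float().mean().item()
    print(f"threshold={threshold:.0e} -> {sparsity:.1%} of scores below threshold")
```
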
## References

- [TensorRT-Model-Optimizer Documentation](https://nvidia.github.io/TensorRT-Model-Optimizer)
- [Sparse Attention Papers Collection](https://github.com/topics/sparse-attention)