
Commit 2ab2a02

add hf unified ckpt export for sparse attention
Signed-off-by: Kai Xu <[email protected]>
1 parent 4fc1bd3 commit 2ab2a02

File tree: 9 files changed, +938 -11 lines changed
Lines changed: 241 additions & 0 deletions
@@ -0,0 +1,241 @@
# Sparse Attention for Large Language Models

This example demonstrates how to apply sparse attention optimization to Large Language Models (LLMs) using TensorRT-Model-Optimizer's attention sparsity module.

## Overview

Sparse attention reduces the computational complexity of attention mechanisms by selectively computing only the most important attention scores. This can significantly speed up inference and reduce memory usage, especially for long sequences.
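
To build intuition for the softmax-skip idea described above, here is a standalone sketch that masks attention probabilities falling below a threshold. It is illustrative only (the function name and the post-masking renormalization are assumptions; the actual ModelOpt method operates inside the attention kernel rather than on materialized probability matrices):

```python
import torch
import torch.nn.functional as F


def softmax_skip_reference(attn_scores: torch.Tensor, threshold: float = 1e-4) -> torch.Tensor:
    """Illustrative sketch: drop attention probabilities below a threshold.

    attn_scores: raw attention logits of shape (batch, heads, q_len, kv_len).
    """
    probs = F.softmax(attn_scores, dim=-1)
    keep = probs >= threshold              # "important" entries survive
    sparse_probs = probs * keep            # skipped entries contribute nothing
    # Renormalize so each row still sums to 1
    return sparse_probs / sparse_probs.sum(dim=-1, keepdim=True).clamp_min(1e-12)
```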

## Features

- **Sparse Attention Method**:
  - Softmax Skip: Threshold-based masking for efficient attention computation
  - Extensible architecture: Easy to add new sparse attention methods in the future
- **Calibration Support**: Automatically find optimal sparsity parameters
- **HuggingFace Integration**: Works with any HuggingFace transformer model
- **Composable**: Can be combined with quantization and other optimizations

## Installation

```bash
pip install nvidia-modelopt transformers torch
```

## Quick Start

```python
import modelopt.torch.sparsity as mts  # Similar to mtq for quantization
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

# Define sparse attention config
sparse_config = {
    "method": "softmax_skip",
    "sparse_cfg": {
        "*attn*": {"threshold": 1e-4, "enable": True},
        "default": {"enable": False},
    },
}

# Apply sparse attention
sparse_model = mts.attention_sparsity.sparsify(model, config=sparse_config)

# Use the model as usual
input_ids = tokenizer("Hello, world!", return_tensors="pt").input_ids
output = sparse_model.generate(input_ids, max_new_tokens=100)
```

### Command Line Usage

The `hf_spar_attn.py` script supports two modes:

```bash
# Basic mode: Apply sparse attention and test generation
python hf_spar_attn.py --mode basic --model_name Qwen/Qwen3-8B --show_stats

# Export mode: Export to unified HuggingFace checkpoint
python hf_spar_attn.py --mode export --model_name Qwen/Qwen3-8B --export_dir ./sparse_model
```

## Examples

### 1. Basic Usage (--mode basic)

Apply sparse attention to a model and test generation quality:

```bash
python hf_spar_attn.py --mode basic \
    --model_name Qwen/Qwen3-8B \
    --sparse_attn skip_softmax \
    --show_stats \
    --benchmark
```

Options for basic mode:

- `--show_stats`: Display sparsity statistics for each attention layer
- `--benchmark`: Compare performance before and after sparse attention
- `--show_memory`: Display GPU memory usage

### 2. Export Model (--mode export)

Export the sparse attention model to unified HuggingFace checkpoint format for deployment:

```bash
python hf_spar_attn.py --mode export \
    --model_name Qwen/Qwen3-8B \
    --sparse_attn skip_softmax \
    --export_dir ./sparse_model
```

The exported model will contain:

- Model weights with sparse attention applied
- `config.json` with `sparse_attention_config` section
- Tokenizer files
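
For programmatic export without the CLI script, the unified HF export API touched by this commit can be used directly. A minimal sketch, assuming `export_hf_checkpoint` is importable from `modelopt.torch.export` and accepts an `export_dir` argument (check `modelopt/torch/export/unified_export_hf.py` for the exact signature):

```python
import modelopt.torch.sparsity as mts
from modelopt.torch.export import export_hf_checkpoint
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B")

sparse_config = {
    "method": "softmax_skip",
    "sparse_cfg": {
        "*attn*": {"threshold": 1e-4, "enable": True},
        "default": {"enable": False},
    },
}
sparse_model = mts.attention_sparsity.sparsify(model, config=sparse_config)

# Write the checkpoint; config.json will include the sparse_attention_config section
export_hf_checkpoint(sparse_model, export_dir="./sparse_model")
```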

### Exported Config Format

The `config.json` includes a `sparse_attention_config` section using the `config_groups` pattern (similar to `quantization_config`):

**For calibrated models:**

```json
{
    "sparse_attention_config": {
        "config_groups": {
            "group_0": {
                "sparse_algo": "softmax_skip",
                "targets": ["LlamaAttention"]
            }
        },
        "threshold_scale_factor": 437.7,
        "target_sparsity": 0.3,
        "producer": {
            "name": "modelopt",
            "version": "0.37.0"
        }
    }
}
```

**For non-calibrated models:**

```json
{
    "sparse_attention_config": {
        "config_groups": {
            "group_0": {
                "sparse_algo": "softmax_skip",
                "threshold": 0.0001,
                "targets": ["LlamaAttention"]
            }
        },
        "producer": {
            "name": "modelopt",
            "version": "0.37.0"
        }
    }
}
```

This format enables inference engines to reconstruct the sparse attention configuration from the checkpoint.
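
As an illustration, a consumer could read the section back from an exported checkpoint roughly as follows (a minimal sketch; the path and the way groups are consumed are assumptions, not a documented inference-engine API):

```python
import json

with open("./sparse_model/config.json") as f:
    config = json.load(f)

sparse_cfg = config.get("sparse_attention_config", {})
scale = sparse_cfg.get("threshold_scale_factor")  # present only for calibrated models

for group_name, group in sparse_cfg.get("config_groups", {}).items():
    threshold = group.get("threshold")  # present only for non-calibrated models
    print(f"{group_name}: {group['sparse_algo']} on {group['targets']} "
          f"(threshold={threshold}, scale={scale})")
```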

## Configuration Options

### Pre-defined Configuration

ModelOpt provides a unified configuration that supports both simple and phase-aware thresholds:

```python
import modelopt.torch.sparsity as mts

# The default config supports phase-aware thresholds
SOFTMAX_SKIP_CFG = {
    "method": "softmax_skip",
    "sparse_cfg": {
        "*attn*": {
            "threshold": {
                "prefill": 1e-3,  # More aggressive during prefill
                "decode": 1e-5,  # Conservative during decode
            },
            "enable": True,
        },
        "default": {"enable": False},
    },
}

# Use the config
model = mts.attention_sparsity.sparsify(model, config=SOFTMAX_SKIP_CFG)
```

### Custom Configuration

You can create custom configurations with simple or phase-aware thresholds:

```python
# Simple threshold (same for all phases)
simple_config = {
    "method": "softmax_skip",
    "sparse_cfg": {
        "*attn*": {
            "threshold": 1e-4,  # Single threshold for all phases
            "enable": True,
        },
        "default": {"enable": False},
    },
}

# Phase-aware threshold
phase_aware_config = {
    "method": "softmax_skip",
    "sparse_cfg": {
        "*attn*": {
            "threshold": {
                "prefill": 1e-3,  # Prefill phase
                "decode": 1e-5,  # Decode phase
            },
            "enable": True,
        },
        "default": {"enable": False},
    },
}
```

### Adding Custom Methods

The architecture is designed to make it easy to add new sparse attention methods. Refer to the [`FlashSoftmaxSkipMethod`](../../modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py) source code as a reference when implementing custom methods.
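
Purely as a hypothetical illustration of the kind of logic a custom method could encapsulate (this class is not the ModelOpt interface; the real base class, hook names, and registration mechanism are defined in the module referenced above):

```python
import torch


class TopKSkipSketch:
    """Hypothetical custom method: keep only the top-k scores per query."""

    name = "topk_skip"

    def __init__(self, k: int = 64):
        self.k = k

    def build_keep_mask(self, attn_scores: torch.Tensor) -> torch.Tensor:
        # Mark the k largest scores along the key dimension as kept.
        k = min(self.k, attn_scores.shape[-1])
        topk_idx = attn_scores.topk(k, dim=-1).indices
        mask = torch.zeros_like(attn_scores, dtype=torch.bool)
        return mask.scatter(-1, topk_idx, True)
```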

### Pattern-Based Configuration

Apply different configurations to different layers:

```python
config = {
    "sparse_cfg": {
        "*layers.[0-12].*attention*": {"enable": True, "threshold": 1e-3},  # More aggressive for early layers
        "*layers.[13-24].*attention*": {"enable": True, "threshold": 1e-4},  # Conservative for later layers
    }
}
```

## Performance Considerations

1. **Threshold Tuning**:
   - Lower thresholds (e.g., 1e-5) preserve more accuracy but yield less sparsity
   - Higher thresholds (e.g., 1e-3) provide more sparsity but may impact accuracy
   - Use calibration to find optimal values (see the sketch after this list)

2. **Memory Usage**:
   - Sparse attention reduces peak memory usage during inference
   - Especially beneficial for long sequences (>1024 tokens)

3. **Model Compatibility**:
   - Works best with models using the eager attention implementation
   - Compatible with all HuggingFace transformer models
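
The calibration entry point itself is not documented in this README. As a minimal sketch of what calibration produces, assuming the `SparseAttentionModule` wrapper and `_sparse_method_instance` attributes used elsewhere in this commit (the commented-out call is hypothetical; see `calibration/calibrate.py` for the real `calibrate_sparse_attention` signature and its calibration config, e.g. `target_sparse_ratio`):

```python
from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule

# calibrate_sparse_attention(sparse_model, ...)  # hypothetical call; arguments omitted

# After calibration, every enabled sparse attention module carries the calibrated
# parameters that the HF export later writes into config.json.
for name, module in sparse_model.named_modules():
    if isinstance(module, SparseAttentionModule) and module.is_enabled:
        method = module._sparse_method_instance
        print(name, getattr(method, "threshold_scale_factor", None),
              getattr(method, "target_sparsity", None))
```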

## References

- [TensorRT-Model-Optimizer Documentation](https://nvidia.github.io/TensorRT-Model-Optimizer)
- [Sparse Attention Papers Collection](https://github.com/topics/sparse-attention)

modelopt/torch/export/unified_export_hf.py

Lines changed: 128 additions & 0 deletions
@@ -337,6 +337,129 @@ def _export_quantized_weight(
    sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale)


def _get_sparse_attention_config(model: nn.Module) -> dict[str, Any]:
    """Extract sparse attention configuration from model for export.

    Args:
        model: Model with sparse attention modules

    Returns:
        Dictionary with sparse attention config in format:
        {
            "config_groups": {
                "group_0": {
                    "sparse_algo": "softmax_skip",
                    "threshold": 1e-4,  # only if not calibrated
                    "targets": ["LlamaAttention"]
                }
            },
            "threshold_scale_factor": 0.001234,  # global, if calibrated
            "target_sparsity": 0.5,  # global, if calibrated
            "producer": {"name": "modelopt", "version": "..."}
        }
    """
    from modelopt import __version__
    from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule

    # Collect all enabled sparse attention modules
    sparse_modules = []
    for name, module in model.named_modules():
        if isinstance(module, SparseAttentionModule) and module.is_enabled:
            sparse_modules.append((name, module))

    if not sparse_modules:
        return {}

    sparse_config = {
        "config_groups": {},
        "producer": {
            "name": "modelopt",
            "version": __version__,
        },
    }

    # Check first module for global calibration parameters
    # (all modules share the same calibration parameters)
    first_module = sparse_modules[0][1]
    method_instance = first_module._sparse_method_instance
    threshold_scale_factor = getattr(method_instance, "threshold_scale_factor", None)

    if threshold_scale_factor is not None:
        # Model was calibrated: add global calibration parameters
        sparse_config["threshold_scale_factor"] = float(threshold_scale_factor)

        target_sparsity = getattr(method_instance, "target_sparsity", None)
        if target_sparsity is not None:
            sparse_config["target_sparsity"] = float(target_sparsity)

    # Group modules by configuration
    # Key: (sparse_algo, threshold_repr), Value: list of module class names
    config_to_targets = {}

    for name, module in sparse_modules:
        method_instance = module._sparse_method_instance

        # Extract sparse algorithm name from method name
        # e.g., "flash_softmax_skip" -> "softmax_skip"
        method_name = method_instance.name
        if method_name.startswith("flash_"):
            sparse_algo = method_name[6:]  # Remove "flash_" prefix
        else:
            sparse_algo = method_name

        # Get module's original class name for targets
        # Get the class name before SparseAttentionModule wrapping
        original_cls = module.get_original_cls_by_level(level=0)
        target_class_name = original_cls.__name__

        # Build config key for grouping
        if threshold_scale_factor is None:
            # Not calibrated: include threshold in grouping
            threshold_config = getattr(method_instance, "threshold_config", None)
            if isinstance(threshold_config, dict):
                # Convert dict to tuple for hashable key
                threshold_repr = tuple(sorted(threshold_config.items()))
            else:
                threshold_repr = threshold_config
        else:
            # Calibrated: no threshold in per-layer config
            threshold_repr = None

        config_key = (sparse_algo, threshold_repr)

        if config_key not in config_to_targets:
            config_to_targets[config_key] = {
                "sparse_algo": sparse_algo,
                "threshold_config": threshold_config if threshold_scale_factor is None else None,
                "targets": set(),
            }

        config_to_targets[config_key]["targets"].add(target_class_name)

    # Convert grouped configs to config_groups format
    for group_idx, ((sparse_algo, threshold_repr), group_data) in enumerate(
        config_to_targets.items()
    ):
        group_name = f"group_{group_idx}"
        group_config = {
            "sparse_algo": group_data["sparse_algo"],
            "targets": sorted(group_data["targets"]),
        }

        # Add threshold only if not calibrated
        if group_data["threshold_config"] is not None:
            threshold_config = group_data["threshold_config"]
            if isinstance(threshold_config, dict):
                # Convert to JSON-serializable format
                group_config["threshold"] = {k: float(v) for k, v in threshold_config.items()}
            else:
                group_config["threshold"] = float(threshold_config)

        sparse_config["config_groups"][group_name] = group_config

    return sparse_config


def _export_hf_checkpoint(
    model: nn.Module, dtype: torch.dtype | None = None
) -> tuple[dict[str, Any], dict[str, Any]]:

@@ -538,6 +661,11 @@ def export_hf_checkpoint(

    config_data = set_config_if_spec_decoding(model, config_data)

    # Add sparse attention config if model has sparse attention
    sparse_attention_config = _get_sparse_attention_config(model)
    if sparse_attention_config:
        config_data["sparse_attention_config"] = sparse_attention_config

    with open(original_config, "w") as file:
        json.dump(config_data, file, indent=4)

modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py

Lines changed: 1 addition & 0 deletions
@@ -172,5 +172,6 @@ def calibrate_sparse_attention(

    for module_name, module in sparse_modules:
        module._sparse_method_instance.threshold_scale_factor = scale_factor
        module._sparse_method_instance.target_sparsity = calib_config.target_sparse_ratio

    return {"calibration_results": {name: calibration_result for name, _ in sparse_modules}}
