Commit 9fa8991

add initial support for sparse attention
Signed-off-by: Kai Xu <[email protected]>
1 parent 657482e commit 9fa8991

17 files changed: +3500 -0 lines changed
Lines changed: 380 additions & 0 deletions
@@ -0,0 +1,380 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example script for applying sparse attention to HuggingFace models."""
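
# Example invocation (the script filename is a placeholder; all flags are
# defined in the argument parser at the bottom of this file):
#
#   python hf_sparse_attention_example.py \
#       --pyt_ckpt_path <model_name_or_path> \
#       --sparse_attn skip_softmax_calib \
#       --backend pytorch \
#       --verify_output \
#       --export_dir ./exported_model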

import argparse
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.export import export_hf_checkpoint
from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig
from modelopt.torch.sparsity.attention_sparsity.config import (
    SKIP_SOFTMAX_CALIB,
    SKIP_SOFTMAX_DEFAULT,
)
from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule
from modelopt.torch.utils.memory_monitor import launch_memory_monitor

RAND_SEED = 1234

# You can define custom configurations or use the default
SPARSE_ATTN_CFG_CHOICES = {
    "skip_softmax": SKIP_SOFTMAX_DEFAULT,
    "skip_softmax_calib": SKIP_SOFTMAX_CALIB,
}
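
# A custom configuration can be added to SPARSE_ATTN_CFG_CHOICES alongside the
# presets above. The sketch below is illustrative only: the keys mirror the ones
# this script reads elsewhere (sparse_cfg, backend, calibration,
# target_sparse_ratio, samples, max_seqlen), while the pattern key and method
# value are placeholders; the authoritative schema is defined by
# SparseAttentionConfig.
#
# CUSTOM_SPARSE_ATTN_CFG = {
#     "method": ...,                      # same method name as the presets
#     "sparse_cfg": {
#         "<module_name_pattern>": {      # per-module pattern key (placeholder)
#             "backend": "pytorch",       # or "triton"
#             "calibration": {            # optional: enables threshold calibration
#                 "target_sparse_ratio": 0.5,
#                 "samples": 48,
#                 "max_seqlen": 32768,
#             },
#         },
#     },
# }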


def print_sparsity_stats(model: nn.Module):
    """Print sparsity statistics if available."""
    module_stats = []
    for name, module in model.named_modules():
        if hasattr(module, "get_stats"):
            stats = module.get_stats()
            if stats and "average_sparsity" in stats:
                module_stats.append((name, stats["average_sparsity"]))

    if not module_stats:
        print("No sparsity statistics available")
        return

    # Check if all modules have the same sparsity
    sparsities = [s for _, s in module_stats]
    if len(set(sparsities)) == 1:
        # All identical - show summary
        print(f"Average sparsity across all {len(module_stats)} modules: {sparsities[0]:.2%}")
    else:
        # Different sparsities - show individual values
        avg_sparsity = sum(sparsities) / len(sparsities)
        print(f"Average sparsity: {avg_sparsity:.2%}")
        print("Per-module breakdown:")
        for name, sparsity in module_stats:
            print(f" {name}: {sparsity:.2%} sparse")


def get_narrativeqa_samples(num_samples=3):
    """Load samples from NarrativeQA dataset for testing.

    Args:
        num_samples: Number of samples to generate
    """
    # Load NarrativeQA dataset
    dataset = load_dataset("narrativeqa", split="test", streaming=True)

    samples = []
    for i, item in enumerate(dataset):
        if i >= num_samples:
            break

        # Combine document context and question
        context = item.get("document", {}).get("text", "")
        question = item.get("question", {}).get("text", "")

        if context and question:
            # Use the full context as-is
            prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
            samples.append(prompt)

    if not samples:
        raise ValueError("Could not load NarrativeQA samples")

    print(f"Loaded {len(samples)} NarrativeQA samples")
    return samples


def truncate_text(text: str, tokenizer, max_length: int):
    """Truncate text from the middle to preserve beginning and end.

    Args:
        text: Input text to truncate
        tokenizer: Tokenizer to use for encoding
        max_length: Maximum number of tokens

    Returns:
        Truncated text that fits within max_length tokens
    """
    # First tokenize to see if truncation is needed
    tokens = tokenizer.encode(text, add_special_tokens=True)

    if len(tokens) <= max_length:
        return text

    # Need to truncate - preserve beginning and end
    # Reserve some tokens for special tokens
    available_tokens = max_length - 2  # Account for special tokens

    # Split tokens roughly in half for beginning and end
    begin_tokens = available_tokens // 2
    end_tokens = available_tokens - begin_tokens

    # Decode beginning and end parts
    begin_text = tokenizer.decode(tokens[:begin_tokens], skip_special_tokens=True)
    end_text = tokenizer.decode(tokens[-end_tokens:], skip_special_tokens=True)

    # Combine with ellipsis marker
    return begin_text + " [...] " + end_text

def verify_outputs(model, tokenizer, args):
    """Compare outputs between baseline and sparse attention models."""
    # Load and prepare a single test prompt
    print(f"\nLoading test sample (will be tokenized up to {args.seq_len} tokens)")
    prompts = get_narrativeqa_samples(num_samples=1)
    prompt = prompts[0]

    # Prepare inputs
    truncated_prompt = truncate_text(prompt, tokenizer, args.seq_len)
    display_prompt = (
        truncated_prompt[:150] + "..." if len(truncated_prompt) > 150 else truncated_prompt
    )

    inputs = tokenizer(
        truncated_prompt,
        return_tensors="pt",
        max_length=args.seq_len,
        truncation=True,
        padding=False,
    )
    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}

    print("\n" + "=" * 60)
    print("BASELINE vs SPARSE ATTENTION COMPARISON")
    print("=" * 60)
    print(f"\nTest prompt: {display_prompt}")
    print(f"Input tokens: {inputs['input_ids'].shape[1]} (max: {args.seq_len})")
    if "[...]" in truncated_prompt:
        print("Note: Text was middle-truncated to fit token limit")

    # Helper function to generate text
    def generate_text(model, inputs, args, tokenizer):
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=args.do_sample,
                temperature=args.temperature if args.do_sample else 1.0,
                pad_token_id=tokenizer.pad_token_id,
            )
        input_length = inputs["input_ids"].shape[1]
        generated_ids = outputs[0][input_length:]
        return tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Find all sparse attention modules
    sparse_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)]

    # Generate baseline by temporarily disabling sparse attention
    print("\n" + "-" * 60)
    print("Generating baseline (sparse attention disabled)...")
    for module in sparse_modules:
        module.disable()
    baseline_text = generate_text(model, inputs, args, tokenizer)

    # Generate with sparse attention enabled
    print("\nGenerating with sparse attention (calibrated thresholds)...")
    for module in sparse_modules:
        module.enable()
    sparse_text = generate_text(model, inputs, args, tokenizer)

    # Display comparison
    print("\n" + "-" * 60)
    print("RESULTS:")
    baseline_display = baseline_text[:300] + "..." if len(baseline_text) > 300 else baseline_text
    sparse_display = sparse_text[:300] + "..." if len(sparse_text) > 300 else sparse_text

    print(f"\nBaseline: {baseline_display}")
    print(f"With Sparse: {sparse_display}")

    if baseline_text == sparse_text:
        print("\nOutputs are identical")
    else:
        print("\nOutputs differ")


def sparsify_model(model, args):
    """Apply sparse attention to the model with optional calibration."""
    print(f"\nApplying sparse attention: {args.sparse_attn} with backend: {args.backend}")
    base_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn]

    # Create modified config with selected backend
    modified_sparse_cfg = {}
    for pattern, cfg in base_config["sparse_cfg"].items():
        modified_cfg = cfg.copy()
        modified_cfg["backend"] = args.backend
        modified_sparse_cfg[pattern] = modified_cfg

    # Create new config with modified settings
    sparse_config = SparseAttentionConfig(
        method=base_config["method"], sparse_cfg=modified_sparse_cfg
    )

    # Check if calibration is present in config
    has_calibration = any(
        "calibration" in cfg for cfg in modified_sparse_cfg.values() if isinstance(cfg, dict)
    )

    if has_calibration:
        print("\n" + "=" * 60)
        print("CALIBRATION")
        print("=" * 60)
        print("Config includes calibration - running automatic threshold calibration...")

        # Display calibration settings
        for cfg in modified_sparse_cfg.values():
            if isinstance(cfg, dict) and "calibration" in cfg:
                calib = cfg["calibration"]
                print(f" Target sparsity: {calib.get('target_sparse_ratio', 0.5)}")
                print(f" Samples: {calib.get('samples', 48)}")
                print(f" Max sequence length: {calib.get('max_seqlen', 32768)}")
                print(" Tokenizer: Auto-extracted from model")
                print(" Dataset: RULER (6 default tasks)")
                break

        # Sparsify with calibration - framework will auto-generate RULER dataset
        model = mtsa.sparsify(model, config=sparse_config)
        print("\nCalibration complete! Model now uses dynamic threshold: λ = a / context_length")
    else:
        model = mtsa.sparsify(model, config=sparse_config)

    print("Sparse attention applied successfully!")

    # Show sparsity statistics
    print("\n" + "=" * 60)
    print("Sparsity Statistics")
    print("=" * 60)
    print_sparsity_stats(model)

    return model


def main(args):
    """Main function to run the selected mode."""
    if not torch.cuda.is_available():
        raise OSError("GPU is required for inference.")

    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)
    launch_memory_monitor()

    print(f"Loading model: {args.pyt_ckpt_path}")

    # Load model and tokenizer
    # Note: attn_implementation="eager" is required for calibration to work properly
    # (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection)
    model = AutoModelForCausalLM.from_pretrained(
        args.pyt_ckpt_path,
        attn_implementation="eager",
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)

    # Set pad token if not set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Move model to GPU if available
    if torch.cuda.is_available():
        model = model.cuda()
        print("Model moved to CUDA")

    # Apply sparse attention to the model (with calibration if configured)
    model = sparsify_model(model, args)

    # Verify outputs if requested (compares baseline vs calibrated sparse model)
    if args.verify_output:
        verify_outputs(model, tokenizer, args)

    # Export if requested
    if args.export_dir:
        print(f"\nExporting model to: {args.export_dir}")
        export_dir = Path(args.export_dir)
        export_dir.mkdir(parents=True, exist_ok=True)

        with torch.inference_mode():
            export_hf_checkpoint(model, export_dir=export_dir)

        tokenizer.save_pretrained(export_dir)
        print(f"Model exported successfully to: {export_dir}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)

    # Model arguments
    parser.add_argument(
        "--pyt_ckpt_path",
        type=str,
        required=True,
        help="Path to the PyTorch (HuggingFace) checkpoint to load",
    )
    parser.add_argument(
        "--sparse_attn",
        type=str,
        default="skip_softmax",
        choices=list(SPARSE_ATTN_CFG_CHOICES.keys()),
        help="Sparse attention configuration to apply.",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="pytorch",
        choices=["pytorch", "triton"],
        help="Backend to use for sparse attention computation (default: pytorch)",
    )

    # Sequence length arguments
    parser.add_argument(
        "--seq_len",
        type=int,
        default=2048,
        help="Maximum sequence length for input prompts (longer prompts are truncated)",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=3,
        help="Number of samples to use from the NarrativeQA dataset",
    )

    # Generation arguments
    parser.add_argument(
        "--max_new_tokens", type=int, default=50, help="Maximum new tokens to generate"
    )
    parser.add_argument("--do_sample", action="store_true", help="Use sampling for generation")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling")

    # Operation arguments
    parser.add_argument(
        "--verify_output",
        action="store_true",
        help="Verify that sparse attention outputs match the baseline",
    )
    parser.add_argument(
        "--export_dir",
        type=str,
        default=None,
        help="Directory to export the model with sparse attention applied",
    )

    args = parser.parse_args()
    main(args)
