|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | + |
| 17 | +"""Example script for applying sparse attention to HuggingFace models.""" |
| 18 | + |
| 19 | +import argparse |
| 20 | +import random |
| 21 | +from pathlib import Path |
| 22 | + |
| 23 | +import numpy as np |
| 24 | +import torch |
| 25 | +import torch.nn as nn |
| 26 | +from datasets import load_dataset |
| 27 | +from transformers import AutoModelForCausalLM, AutoTokenizer |
| 28 | + |
| 29 | +import modelopt.torch.sparsity.attention_sparsity as mtsa |
| 30 | +from modelopt.torch.export import export_hf_checkpoint |
| 31 | +from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig |
| 32 | +from modelopt.torch.sparsity.attention_sparsity.config import ( |
| 33 | + SKIP_SOFTMAX_CALIB, |
| 34 | + SKIP_SOFTMAX_DEFAULT, |
| 35 | +) |
| 36 | +from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule |
| 37 | +from modelopt.torch.utils.memory_monitor import launch_memory_monitor |
| 38 | + |
| 39 | +RAND_SEED = 1234 |
| 40 | + |
| 41 | +# You can define custom configurations or use the default |
| 42 | +SPARSE_ATTN_CFG_CHOICES = { |
| 43 | + "skip_softmax": SKIP_SOFTMAX_DEFAULT, |
| 44 | + "skip_softmax_calib": SKIP_SOFTMAX_CALIB, |
| 45 | +} |
| 46 | + |
| 47 | + |
| 48 | +def print_sparsity_stats(model: nn.Module): |
| 49 | + """Print sparsity statistics if available.""" |
| 50 | + module_stats = [] |
| 51 | + for name, module in model.named_modules(): |
| 52 | + if hasattr(module, "get_stats"): |
| 53 | + stats = module.get_stats() |
| 54 | + if stats and "average_sparsity" in stats: |
| 55 | + module_stats.append((name, stats["average_sparsity"])) |
| 56 | + |
| 57 | + if not module_stats: |
| 58 | + print("No sparsity statistics available") |
| 59 | + return |
| 60 | + |
| 61 | + # Check if all modules have the same sparsity |
| 62 | + sparsities = [s for _, s in module_stats] |
| 63 | + if len(set(sparsities)) == 1: |
| 64 | + # All identical - show summary |
| 65 | + print(f"Average sparsity across all {len(module_stats)} modules: {sparsities[0]:.2%}") |
| 66 | + else: |
| 67 | + # Different sparsities - show individual values |
| 68 | + avg_sparsity = sum(sparsities) / len(sparsities) |
| 69 | + print(f"Average sparsity: {avg_sparsity:.2%}") |
| 70 | + print("Per-module breakdown:") |
| 71 | + for name, sparsity in module_stats: |
| 72 | + print(f" {name}: {sparsity:.2%} sparse") |
| 73 | + |
| 74 | + |
| 75 | +def get_narrativeqa_samples(num_samples=3): |
| 76 | + """Load samples from NarrativeQA dataset for testing. |
| 77 | +
|
| 78 | + Args: |
| 79 | + num_samples: Number of samples to generate |
| 80 | + """ |
| 81 | + # Load NarrativeQA dataset |
| 82 | + dataset = load_dataset("narrativeqa", split="test", streaming=True) |
| 83 | + |
| 84 | + samples = [] |
| 85 | + for i, item in enumerate(dataset): |
| 86 | + if i >= num_samples: |
| 87 | + break |
| 88 | + |
| 89 | + # Combine document context and question |
| 90 | + context = item.get("document", {}).get("text", "") |
| 91 | + question = item.get("question", {}).get("text", "") |
| 92 | + |
| 93 | + if context and question: |
| 94 | + # Use the full context as-is |
| 95 | + prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:" |
| 96 | + samples.append(prompt) |
| 97 | + |
| 98 | + if not samples: |
| 99 | + raise ValueError("Could not load NarrativeQA samples") |
| 100 | + |
| 101 | + print(f"Loaded {len(samples)} NarrativeQA samples") |
| 102 | + return samples |
| 103 | + |
| 104 | + |
| 105 | +def truncate_text(text: str, tokenizer, max_length: int): |
| 106 | + """Truncate text from the middle to preserve beginning and end. |
| 107 | +
|
| 108 | + Args: |
| 109 | + text: Input text to truncate |
| 110 | + tokenizer: Tokenizer to use for encoding |
| 111 | + max_length: Maximum number of tokens |
| 112 | +
|
| 113 | + Returns: |
| 114 | + Truncated text that fits within max_length tokens |
| 115 | + """ |
| 116 | + # First tokenize to see if truncation is needed |
| 117 | + tokens = tokenizer.encode(text, add_special_tokens=True) |
| 118 | + |
| 119 | + if len(tokens) <= max_length: |
| 120 | + return text |
| 121 | + |
| 122 | + # Need to truncate - preserve beginning and end |
| 123 | + # Reserve some tokens for special tokens |
| 124 | + available_tokens = max_length - 2 # Account for special tokens |
| 125 | + |
| 126 | + # Split tokens roughly in half for beginning and end |
| 127 | + begin_tokens = available_tokens // 2 |
| 128 | + end_tokens = available_tokens - begin_tokens |
| 129 | + |
| 130 | + # Decode beginning and end parts |
| 131 | + begin_text = tokenizer.decode(tokens[:begin_tokens], skip_special_tokens=True) |
| 132 | + end_text = tokenizer.decode(tokens[-end_tokens:], skip_special_tokens=True) |
| 133 | + |
| 134 | + # Combine with ellipsis marker |
| 135 | + return begin_text + " [...] " + end_text |
| 136 | + |
| 137 | + |
| 138 | +def verify_outputs(model, tokenizer, args): |
| 139 | + """Compare outputs between baseline and sparse attention models.""" |
| 140 | + # Load and prepare a single test prompt |
| 141 | + print(f"\nLoading test sample (will be tokenized up to {args.seq_len} tokens)") |
| 142 | + prompts = get_narrativeqa_samples(num_samples=1) |
| 143 | + prompt = prompts[0] |
| 144 | + |
| 145 | + # Prepare inputs |
| 146 | + truncated_prompt = truncate_text(prompt, tokenizer, args.seq_len) |
| 147 | + display_prompt = ( |
| 148 | + truncated_prompt[:150] + "..." if len(truncated_prompt) > 150 else truncated_prompt |
| 149 | + ) |
| 150 | + |
| 151 | + inputs = tokenizer( |
| 152 | + truncated_prompt, |
| 153 | + return_tensors="pt", |
| 154 | + max_length=args.seq_len, |
| 155 | + truncation=True, |
| 156 | + padding=False, |
| 157 | + ) |
| 158 | + if torch.cuda.is_available(): |
| 159 | + inputs = {k: v.cuda() for k, v in inputs.items()} |
| 160 | + |
| 161 | + print("\n" + "=" * 60) |
| 162 | + print("BASELINE vs SPARSE ATTENTION COMPARISON") |
| 163 | + print("=" * 60) |
| 164 | + print(f"\nTest prompt: {display_prompt}") |
| 165 | + print(f"Input tokens: {inputs['input_ids'].shape[1]} (max: {args.seq_len})") |
| 166 | + if "[...]" in truncated_prompt: |
| 167 | + print("Note: Text was middle-truncated to fit token limit") |
| 168 | + |
| 169 | + # Helper function to generate text |
| 170 | + def generate_text(model, inputs, args, tokenizer): |
| 171 | + with torch.no_grad(): |
| 172 | + outputs = model.generate( |
| 173 | + **inputs, |
| 174 | + max_new_tokens=args.max_new_tokens, |
| 175 | + do_sample=args.do_sample, |
| 176 | + temperature=args.temperature if args.do_sample else 1.0, |
| 177 | + pad_token_id=tokenizer.pad_token_id, |
| 178 | + ) |
| 179 | + input_length = inputs["input_ids"].shape[1] |
| 180 | + generated_ids = outputs[0][input_length:] |
| 181 | + return tokenizer.decode(generated_ids, skip_special_tokens=True) |
| 182 | + |
| 183 | + # Find all sparse attention modules |
| 184 | + sparse_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] |
| 185 | + |
| 186 | + # Generate baseline by temporarily disabling sparse attention |
| 187 | + print("\n" + "-" * 60) |
| 188 | + print("Generating baseline (sparse attention disabled)...") |
| 189 | + for module in sparse_modules: |
| 190 | + module.disable() |
| 191 | + baseline_text = generate_text(model, inputs, args, tokenizer) |
| 192 | + |
| 193 | + # Generate with sparse attention enabled |
| 194 | + print("\nGenerating with sparse attention (calibrated thresholds)...") |
| 195 | + for module in sparse_modules: |
| 196 | + module.enable() |
| 197 | + sparse_text = generate_text(model, inputs, args, tokenizer) |
| 198 | + |
| 199 | + # Display comparison |
| 200 | + print("\n" + "-" * 60) |
| 201 | + print("RESULTS:") |
| 202 | + baseline_display = baseline_text[:300] + "..." if len(baseline_text) > 300 else baseline_text |
| 203 | + sparse_display = sparse_text[:300] + "..." if len(sparse_text) > 300 else sparse_text |
| 204 | + |
| 205 | + print(f"\nBaseline: {baseline_display}") |
| 206 | + print(f"With Sparse: {sparse_display}") |
| 207 | + |
| 208 | + if baseline_text == sparse_text: |
| 209 | + print("\nOutputs are identical") |
| 210 | + else: |
| 211 | + print("\nOutputs differ") |
| 212 | + |
| 213 | + |
| 214 | +def sparsify_model(model, args): |
| 215 | + """Apply sparse attention to the model with optional calibration.""" |
| 216 | + print(f"\nApplying sparse attention: {args.sparse_attn} with backend: {args.backend}") |
| 217 | + base_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn] |
| 218 | + |
| 219 | + # Create modified config with selected backend |
| 220 | + modified_sparse_cfg = {} |
| 221 | + for pattern, cfg in base_config["sparse_cfg"].items(): |
| 222 | + modified_cfg = cfg.copy() |
| 223 | + modified_cfg["backend"] = args.backend |
| 224 | + modified_sparse_cfg[pattern] = modified_cfg |
| 225 | + |
| 226 | + # Create new config with modified settings |
| 227 | + sparse_config = SparseAttentionConfig( |
| 228 | + method=base_config["method"], sparse_cfg=modified_sparse_cfg |
| 229 | + ) |
| 230 | + |
| 231 | + # Check if calibration is present in config |
| 232 | + has_calibration = any( |
| 233 | + "calibration" in cfg for cfg in modified_sparse_cfg.values() if isinstance(cfg, dict) |
| 234 | + ) |
| 235 | + |
| 236 | + if has_calibration: |
| 237 | + print("\n" + "=" * 60) |
| 238 | + print("CALIBRATION") |
| 239 | + print("=" * 60) |
| 240 | + print("Config includes calibration - running automatic threshold calibration...") |
| 241 | + |
| 242 | + # Display calibration settings |
| 243 | + for cfg in modified_sparse_cfg.values(): |
| 244 | + if isinstance(cfg, dict) and "calibration" in cfg: |
| 245 | + calib = cfg["calibration"] |
| 246 | + print(f" Target sparsity: {calib.get('target_sparse_ratio', 0.5)}") |
| 247 | + print(f" Samples: {calib.get('samples', 48)}") |
| 248 | + print(f" Max sequence length: {calib.get('max_seqlen', 32768)}") |
| 249 | + print(" Tokenizer: Auto-extracted from model") |
| 250 | + print(" Dataset: RULER (6 default tasks)") |
| 251 | + break |
| 252 | + |
| 253 | + # Sparsify with calibration - framework will auto-generate RULER dataset |
| 254 | + model = mtsa.sparsify(model, config=sparse_config) |
| 255 | + print("\nCalibration complete! Model now uses dynamic threshold: λ = a / context_length") |
| 256 | + else: |
| 257 | + model = mtsa.sparsify(model, config=sparse_config) |
| 258 | + |
| 259 | + print("Sparse attention applied successfully!") |
| 260 | + |
| 261 | + # Show sparsity statistics |
| 262 | + print("\n" + "=" * 60) |
| 263 | + print("Sparsity Statistics") |
| 264 | + print("=" * 60) |
| 265 | + print_sparsity_stats(model) |
| 266 | + |
| 267 | + return model |
| 268 | + |
| 269 | + |
| 270 | +def main(args): |
| 271 | + """Main function to run the selected mode.""" |
| 272 | + if not torch.cuda.is_available(): |
| 273 | + raise OSError("GPU is required for inference.") |
| 274 | + |
| 275 | + random.seed(RAND_SEED) |
| 276 | + np.random.seed(RAND_SEED) |
| 277 | + launch_memory_monitor() |
| 278 | + |
| 279 | + print(f"Loading model: {args.pyt_ckpt_path}") |
| 280 | + |
| 281 | + # Load model and tokenizer |
| 282 | + # Note: attn_implementation="eager" is required for calibration to work properly |
| 283 | + # (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection) |
| 284 | + model = AutoModelForCausalLM.from_pretrained( |
| 285 | + args.pyt_ckpt_path, |
| 286 | + attn_implementation="eager", |
| 287 | + torch_dtype=torch.bfloat16, |
| 288 | + ) |
| 289 | + tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path) |
| 290 | + |
| 291 | + # Set pad token if not set |
| 292 | + if tokenizer.pad_token is None: |
| 293 | + tokenizer.pad_token = tokenizer.eos_token |
| 294 | + |
| 295 | + # Move model to GPU if available |
| 296 | + if torch.cuda.is_available(): |
| 297 | + model = model.cuda() |
| 298 | + print("Model moved to CUDA") |
| 299 | + |
| 300 | + # Apply sparse attention to the model (with calibration if configured) |
| 301 | + model = sparsify_model(model, args) |
| 302 | + |
| 303 | + # Verify outputs if requested (compares baseline vs calibrated sparse model) |
| 304 | + if args.verify_output: |
| 305 | + verify_outputs(model, tokenizer, args) |
| 306 | + |
| 307 | + # Export if requested |
| 308 | + if args.export_dir: |
| 309 | + print(f"\nExporting model to: {args.export_dir}") |
| 310 | + export_dir = Path(args.export_dir) |
| 311 | + export_dir.mkdir(parents=True, exist_ok=True) |
| 312 | + |
| 313 | + with torch.inference_mode(): |
| 314 | + export_hf_checkpoint(model, export_dir=export_dir) |
| 315 | + |
| 316 | + tokenizer.save_pretrained(export_dir) |
| 317 | + print(f"Model exported successfully to: {export_dir}") |
| 318 | + |
| 319 | + |
| 320 | +if __name__ == "__main__": |
| 321 | + parser = argparse.ArgumentParser(description=__doc__) |
| 322 | + |
| 323 | + # Model arguments |
| 324 | + parser.add_argument( |
| 325 | + "--pyt_ckpt_path", |
| 326 | + type=str, |
| 327 | + required=True, |
| 328 | + help="Specify where the PyTorch checkpoint path is", |
| 329 | + ) |
| 330 | + parser.add_argument( |
| 331 | + "--sparse_attn", |
| 332 | + type=str, |
| 333 | + default="skip_softmax", |
| 334 | + choices=list(SPARSE_ATTN_CFG_CHOICES.keys()), |
| 335 | + help="Sparse attention configuration to apply.", |
| 336 | + ) |
| 337 | + parser.add_argument( |
| 338 | + "--backend", |
| 339 | + type=str, |
| 340 | + default="pytorch", |
| 341 | + choices=["pytorch", "triton"], |
| 342 | + help="Backend to use for sparse attention computation (default: pytorch)", |
| 343 | + ) |
| 344 | + |
| 345 | + # Sequence length arguments |
| 346 | + parser.add_argument( |
| 347 | + "--seq_len", |
| 348 | + type=int, |
| 349 | + default=2048, |
| 350 | + help="Maximum sequence length for input prompts (will be truncated if longer)", |
| 351 | + ) |
| 352 | + parser.add_argument( |
| 353 | + "--num_samples", |
| 354 | + type=int, |
| 355 | + default=3, |
| 356 | + help="Number of samples to use from NarrativeQA dataset", |
| 357 | + ) |
| 358 | + |
| 359 | + # Generation arguments |
| 360 | + parser.add_argument( |
| 361 | + "--max_new_tokens", type=int, default=50, help="Maximum new tokens to generate" |
| 362 | + ) |
| 363 | + parser.add_argument("--do_sample", action="store_true", help="Use sampling for generation") |
| 364 | + parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling") |
| 365 | + |
| 366 | + # Operation arguments |
| 367 | + parser.add_argument( |
| 368 | + "--verify_output", |
| 369 | + action="store_true", |
| 370 | + help="Verify that sparse attention outputs match baseline", |
| 371 | + ) |
| 372 | + parser.add_argument( |
| 373 | + "--export_dir", |
| 374 | + type=str, |
| 375 | + default=None, |
| 376 | + help="Directory to export the model with sparse attention applied", |
| 377 | + ) |
| 378 | + |
| 379 | + args = parser.parse_args() |
| 380 | + main(args) |
0 commit comments