diff --git a/examples/llm_sparse_attention/README.md b/examples/llm_sparse_attention/README.md new file mode 100644 index 000000000..907b7a456 --- /dev/null +++ b/examples/llm_sparse_attention/README.md @@ -0,0 +1,398 @@ +# Sparse Attention for Large Language Models + +This example demonstrates how to apply sparse attention optimization to Large Language Models (LLMs) using TensorRT-Model-Optimizer's attention sparsity module. + +
+ +| **Section** | **Description** | **Link** | **Docs** | +| :------------: | :------------: | :------------: | :------------: | +| Pre-Requisites | Required & optional packages to use this technique | \[[Link](#pre-requisites)\] | | +| Getting Started | Learn how to apply sparse attention to optimize inference efficiency | \[[Link](#getting-started)\] | \[[docs](https://nvidia.github.io/TensorRT-Model-Optimizer/)\] | +| Support Matrix | View the support matrix to see sparse attention compatibility across different models | \[[Link](#support-matrix)\] | | +| Framework Scripts | Example scripts demonstrating sparse attention techniques for optimizing models | \[[Link](#framework-scripts)\] | | +| Evaluate Accuracy | Evaluate your model's accuracy with sparse attention! | \[[Link](#evaluate-accuracy)\] | | +| Exporting Checkpoints | Export to Hugging Face Unified Checkpoint and deploy on TRT-LLM/vLLM/SGLang | \[[Link](#exporting-checkpoints)\] | \[[docs](https://nvidia.github.io/TensorRT-Model-Optimizer/deployment/3_unified_hf.html)\] | +| Pre-Sparsified Checkpoints | Ready to deploy Hugging Face pre-sparsified checkpoints | \[[Link](#pre-sparsified-checkpoints)\] | | +| Resources | Extra links to relevant resources | \[[Link](#resources)\] | | + +
+ +## Overview + +Sparse attention reduces the computational complexity of attention mechanisms by selectively computing only the most important attention scores. This can significantly speed up inference and reduce memory usage, especially for long sequences. + +## Features + +- **Sparse Attention Method**: + - Softmax Skip: Threshold-based masking for efficient attention computation + - Extensible architecture: Easy to add new sparse attention methods in the future +- **Calibration Support**: Automatically find optimal sparsity parameters +- **HuggingFace Integration**: Works with any HuggingFace transformer model +- **Composable**: Can be combined with quantization and other optimizations + +## Pre-Requisites + +### Docker + +For Hugging Face models, please use the TensorRT-LLM docker image (e.g., `nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2.post2`). +For NeMo models, use the NeMo container (e.g., `nvcr.io/nvidia/nemo:25.07`). +Visit our [installation docs](https://nvidia.github.io/TensorRT-Model-Optimizer/getting_started/2_installation.html) for more information. + +Also follow the installation steps below to upgrade to the latest version of Model Optimizer and install example-specific dependencies. + +### Local Installation + +For Hugging Face models, install Model Optimizer with `hf` dependencies using `pip` from [PyPI](https://pypi.org/project/nvidia-modelopt/) and install the requirements for the example: + +```bash +pip install -U nvidia-modelopt[hf] +pip install -r requirements.txt +``` + +For TensorRT-LLM deployment, please use the TensorRT-LLM docker image or follow their [installation docs](https://nvidia.github.io/TensorRT-LLM/installation/index.html). +Similarly, for vLLM or SGLang deployment, please use their installation docs. + +> *When loading models from HuggingFace, `trust_remote_code=False` is used by default for security. If your model requires custom code, you may need to modify the script to set `trust_remote_code=True` in the `AutoModelForCausalLM.from_pretrained()` call.* + +> *If model loading fails on a multi-GPU system due to mismatched tensor placement, try setting CUDA_VISIBLE_DEVICES to limit the number of visible GPUs.* + +> *For large models with limited GPU memory, adjust `--seq_len` or `--num_samples` parameters. 
You can also modify the script to use HuggingFace's `device_map="auto"` feature in model loading to automatically distribute across GPUs.* + +## Getting Started + +```python +import modelopt.torch.sparsity as mts # Similar to mtq for quantization +from transformers import AutoModelForCausalLM + +# Load model +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B") + +# Define sparse attention config +sparse_config = { + "method": "softmax_skip", + "sparse_cfg": { + "*attn*": {"threshold": 1e-4, "enable": True}, + "default": {"enable": False} + } +} + +# Apply sparse attention +sparse_model = mts.attention_sparsity.sparsify(model, config=sparse_config) + +# Use the model as usual +output = sparse_model.generate(input_ids, max_new_tokens=100) +``` + +### Command Line Usage + +The `hf_spar_attn.py` script applies sparse attention to HuggingFace models: + +```bash +# Basic usage: Apply sparse attention and test generation +python hf_spar_attn.py --pyt_ckpt_path Qwen/Qwen3-8B + +# With output verification: Compare baseline vs sparse attention outputs +python hf_spar_attn.py --pyt_ckpt_path Qwen/Qwen3-8B --verify_output + +# Export to unified HuggingFace checkpoint +python hf_spar_attn.py --pyt_ckpt_path Qwen/Qwen3-8B --export_dir ./sparse_model +``` + +## Examples + +### Basic Usage + +Apply sparse attention to a model and test generation quality: + +```bash +python hf_spar_attn.py --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax \ + --verify_output +``` + +Available Options: + +- `--pyt_ckpt_path`: Model checkpoint path or HuggingFace model card (required) +- `--sparse_attn`: Sparse attention method (default: skip_softmax) +- `--verify_output`: Compare baseline vs sparse attention outputs +- `--export_dir`: Export model to specified directory +- `--backend`: Backend for computation - pytorch or triton (default: pytorch) +- `--seq_len`: Maximum sequence length (default: 2048) +- `--num_samples`: Number of test samples from NarrativeQA (default: 3) +- `--max_new_tokens`: Maximum new tokens to generate (default: 50) +- `--do_sample`: Use sampling for generation +- `--temperature`: Temperature for sampling (default: 0.7) + +Note: Sparsity statistics are automatically displayed after applying sparse attention. 
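+
+To read these statistics from your own code rather than relying on the script's printout, the minimal sketch below mirrors the `print_sparsity_stats` helper in `hf_spar_attn.py`. It assumes `sparse_model` is the model returned by the sparsify call shown in [Getting Started](#getting-started) and that statistics collection is enabled (the script enables it via `collect_stats=True`); the exact keys returned by `get_stats()` may vary.
+
+```python
+from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule
+
+# Report the observed sparsity of every wrapped attention module.
+for name, module in sparse_model.named_modules():
+    if isinstance(module, SparseAttentionModule) and hasattr(module, "get_stats"):
+        stats = module.get_stats()
+        if stats and "average_sparsity" in stats:
+            print(f"{name}: {stats['average_sparsity']:.2%} sparse")
+```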
+ +## Exporting Checkpoints + +### Export Model + +Export the sparse attention model to unified HuggingFace checkpoint format for deployment: + +```bash +python hf_spar_attn.py --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax \ + --export_dir ./sparse_model +``` + +The exported model will contain: + +- Model weights with sparse attention applied +- `config.json` with `sparse_attention_config` section +- Tokenizer files + +### Exported Config Format + +The `config.json` includes a `sparse_attention_config` section using the `config_groups` pattern (similar to `quantization_config`): + +**For calibrated models:** + +```json +{ + "sparse_attention_config": { + "config_groups": { + "group_0": { + "sparse_algo": "softmax_skip", + "targets": ["LlamaAttention"] + } + }, + "threshold_scale_factor": 437.7, + "target_sparsity": 0.3, + "producer": { + "name": "modelopt", + "version": "0.37.0" + } + } +} +``` + +**For non-calibrated models:** + +```json +{ + "sparse_attention_config": { + "config_groups": { + "group_0": { + "sparse_algo": "softmax_skip", + "threshold": 0.0001, + "targets": ["LlamaAttention"] + } + }, + "producer": { + "name": "modelopt", + "version": "0.37.0" + } + } +} +``` + +This format enables inference engines to reconstruct the sparse attention configuration from the checkpoint. + +### Deployment + +Deployment examples for TensorRT-LLM, vLLM, and SGLang will be added soon. + +### Unified HF Checkpoint Deployment Model Support Matrix + +Support matrix showing which models and sparse attention methods work with each deployment framework will be added. + +## Evaluate Accuracy + +### Accuracy Validation + +Evaluating the impact of sparse attention on model accuracy is crucial. The `hf_spar_attn.py` script provides built-in support for validation through the `--verify_output` flag, which compares outputs between baseline and sparse attention models. + +Sparsity statistics are automatically displayed to help monitor sparsity levels across different attention layers: + +```bash +python hf_spar_attn.py --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax \ + --verify_output +``` + +For comprehensive accuracy evaluation, additional benchmarks are available in the [llm_eval](../llm_eval/README.md) directory, including: + +- MMLU (Massive Multitask Language Understanding) +- lm_evaluation_harness for various language modeling tasks + +Please refer to the [llm_eval README](../llm_eval/README.md) for detailed instructions on running these evaluation benchmarks. + +> *Sparsity statistics are automatically displayed for each attention layer. Monitor these carefully to ensure the threshold settings are achieving the desired balance between performance and accuracy.* + +> *Different models may have varying sensitivity to sparse attention. It's recommended to evaluate on task-specific benchmarks relevant to your use case.* + +## Support Matrix + +### Hugging Face Supported Models + +Support matrix will be added as testing is completed for various models and sparse attention methods. + +> *This section is under active development. 
The sparse attention feature is currently being validated across different model architectures.* + +## Framework Scripts + +### Hugging Face Example + +For LLM models like [Llama](https://huggingface.co/meta-llama) or [Qwen](https://huggingface.co/Qwen): + +```bash +# Apply sparse attention and test generation +python hf_spar_attn.py --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax \ + --verify_output + +# Export to unified HuggingFace checkpoint +python hf_spar_attn.py --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax \ + --export_dir ./sparse_model +``` + +**Key Command-Line Flags:** + +- `--pyt_ckpt_path`: Model checkpoint path or HuggingFace model card (required) +- `--sparse_attn`: Sparse attention method - currently supports `skip_softmax` (default) and `skip_softmax_calib` +- `--verify_output`: Compare baseline vs sparse attention outputs for validation +- `--export_dir`: Directory to save the exported sparse model +- `--backend`: Backend for computation - `pytorch` or `triton` (default: pytorch) +- `--seq_len`: Maximum sequence length for input prompts (default: 2048) +- `--num_samples`: Number of test samples from NarrativeQA dataset (default: 3) +- `--max_new_tokens`: Maximum new tokens to generate (default: 50) + +Note: Sparsity statistics are automatically displayed after applying sparse attention - no flag needed. + +> *When loading models from HuggingFace, `trust_remote_code=False` is used by default for security. If the model requires custom code, you'll need to manually set `trust_remote_code=True` in the model loading code.* + +> *If GPU out-of-memory error is reported, try reducing `--seq_len` or `--num_samples`. For very large models, consider using HuggingFace's `device_map="auto"` feature in the model loading code to distribute across GPUs.* + +> *Sparse attention works best with models using eager attention implementation. Models with fused attention kernels may require modifications.* + +### NeMo Example Script + +NeMo framework sparse attention examples will be added in future releases. + +### Megatron-LM Example Script + +Megatron-LM framework sparse attention examples will be added in future releases. + +## Configuration Options + +### Pre-defined Configuration + +ModelOpt provides a unified configuration that supports both simple and phase-aware thresholds: + +```python +import modelopt.torch.sparsity as mts + +# The default config supports phase-aware thresholds +SOFTMAX_SKIP_CFG = { + "method": "softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": { + "prefill": 1e-3, # More aggressive during prefill + "decode": 1e-5, # Conservative during decode + }, + "enable": True, + }, + "default": {"enable": False}, + }, +} + +# Use the config +model = mts.attention_sparsity.sparsify(model, config=SOFTMAX_SKIP_CFG) +``` + +### Custom Configuration + +You can create custom configurations with simple or phase-aware thresholds: + +```python +# Simple threshold (same for all phases) +simple_config = { + "method": "softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-4, # Single threshold for all phases + "enable": True, + }, + "default": {"enable": False}, + } +} + +# Phase-aware threshold +phase_aware_config = { + "method": "softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": { + "prefill": 1e-3, # Prefill phase + "decode": 1e-5, # Decode phase + }, + "enable": True, + }, + "default": {"enable": False}, + } +} +``` + +### Adding Custom Methods + +The architecture is designed to easily support new sparse attention methods. 
Refer to [`FlashSoftmaxSkipMethod`](../../modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py) source code for implementing custom methods. + +### Pattern-Based Configuration + +Apply different configurations to different layers: + +```python +config = { + "sparse_cfg": { + "*layers.[0-12].*attention*": {"enable": True, "threshold": 1e-3}, # More aggressive for early layers + "*layers.[13-24].*attention*": {"enable": True, "threshold": 1e-4}, # Conservative for later layers + } +} +``` + +## Performance Considerations + +1. **Threshold Tuning**: + - Lower thresholds (e.g., 1e-5) preserve more accuracy but less sparsity + - Higher thresholds (e.g., 1e-3) provide more sparsity but may impact accuracy + - Use calibration to find optimal values + +2. **Memory Usage**: + - Sparse attention reduces peak memory usage during inference + - Especially beneficial for long sequences (>1024 tokens) + +3. **Model Compatibility**: + - Works best with models using eager attention implementation + - Compatible with all HuggingFace transformer models + +## Resources + +- 📅 [Roadmap](https://github.com/NVIDIA/TensorRT-Model-Optimizer/issues/146) +- 📖 [Documentation](https://nvidia.github.io/TensorRT-Model-Optimizer) +- 🎯 [Benchmarks](../benchmark.md) +- 💡 [Release Notes](https://nvidia.github.io/TensorRT-Model-Optimizer/reference/0_changelog.html) +- 🐛 [File a bug](https://github.com/NVIDIA/TensorRT-Model-Optimizer/issues/new?template=1_bug_report.md) +- ✨ [File a Feature Request](https://github.com/NVIDIA/TensorRT-Model-Optimizer/issues/new?template=2_feature_request.md) + +### Technical Resources + +Sparse attention reduces the computational cost of attention mechanisms by selectively computing attention scores. The primary method currently supported is: + +1. **Softmax Skip**: A threshold-based approach that skips computation of attention scores below a certain threshold. This method is particularly effective for long sequences where many attention scores are near zero. The implementation is available in [`FlashSoftmaxSkipMethod`](../../modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py). + +**Further Reading:** + +- [Sparse Attention Papers Collection](https://github.com/topics/sparse-attention) +- [TensorRT-Model-Optimizer Sparse Attention Documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/) + +## Pre-Sparsified Checkpoints + +Pre-sparsified model checkpoints will be made available on Hugging Face in the future. + +- Ready-to-deploy checkpoints will be published to the [🤗 Hugging Face - Nvidia TensorRT Model Optimizer Collection](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4) +- Deployable on [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [vLLM](https://github.com/vllm-project/vllm) and [SGLang](https://github.com/sgl-project/sglang) +- More models coming soon! diff --git a/examples/llm_sparse_attention/hf_spar_attn.py b/examples/llm_sparse_attention/hf_spar_attn.py new file mode 100644 index 000000000..461af581e --- /dev/null +++ b/examples/llm_sparse_attention/hf_spar_attn.py @@ -0,0 +1,368 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example script for applying sparse attention to HuggingFace models.""" + +import argparse +import random +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.export import export_hf_checkpoint +from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig +from modelopt.torch.sparsity.attention_sparsity.config import ( + SKIP_SOFTMAX_CALIB, + SKIP_SOFTMAX_DEFAULT, +) +from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule +from modelopt.torch.utils.memory_monitor import launch_memory_monitor + +RAND_SEED = 1234 + +# You can define custom configurations or use the default +SPARSE_ATTN_CFG_CHOICES = { + "skip_softmax": SKIP_SOFTMAX_DEFAULT, + "skip_softmax_calib": SKIP_SOFTMAX_CALIB, +} + + +def print_sparsity_stats(model: nn.Module): + """Print sparsity statistics if available.""" + module_stats = [] + for name, module in model.named_modules(): + if hasattr(module, "get_stats"): + stats = module.get_stats() + if stats and "average_sparsity" in stats: + module_stats.append((name, stats["average_sparsity"])) + + if not module_stats: + print("No sparsity statistics available") + return + + # Check if all modules have the same sparsity + sparsities = [s for _, s in module_stats] + if len(set(sparsities)) == 1: + # All identical - show summary + print(f"Average sparsity across all {len(module_stats)} modules: {sparsities[0]:.2%}") + else: + # Different sparsities - show individual values + avg_sparsity = sum(sparsities) / len(sparsities) + print(f"Average sparsity: {avg_sparsity:.2%}") + print("Per-module breakdown:") + for name, sparsity in module_stats: + print(f" {name}: {sparsity:.2%} sparse") + + +def get_narrativeqa_samples(num_samples=3): + """Load samples from NarrativeQA dataset for testing. + + Args: + num_samples: Number of samples to generate + """ + # Load NarrativeQA dataset + dataset = load_dataset("narrativeqa", split="test", streaming=True) + + samples = [] + for i, item in enumerate(dataset): + if i >= num_samples: + break + + # Combine document context and question + context = item.get("document", {}).get("text", "") + question = item.get("question", {}).get("text", "") + + if context and question: + # Use the full context as-is + prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:" + samples.append(prompt) + + if not samples: + raise ValueError("Could not load NarrativeQA samples") + + print(f"Loaded {len(samples)} NarrativeQA samples") + return samples + + +def truncate_text(text: str, tokenizer, max_length: int): + """Truncate text from the middle to preserve beginning and end. 
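+
+    Tokens beyond max_length are dropped from the middle of the text and replaced
+    with a "[...]" marker, so both the beginning and the end of the prompt are kept.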
+ + Args: + text: Input text to truncate + tokenizer: Tokenizer to use for encoding + max_length: Maximum number of tokens + + Returns: + Truncated text that fits within max_length tokens + """ + # First tokenize to see if truncation is needed + tokens = tokenizer.encode(text, add_special_tokens=True) + + if len(tokens) <= max_length: + return text + + # Need to truncate - preserve beginning and end + # Reserve some tokens for special tokens + available_tokens = max_length - 2 # Account for special tokens + + # Split tokens roughly in half for beginning and end + begin_tokens = available_tokens // 2 + end_tokens = available_tokens - begin_tokens + + # Decode beginning and end parts + begin_text = tokenizer.decode(tokens[:begin_tokens], skip_special_tokens=True) + end_text = tokenizer.decode(tokens[-end_tokens:], skip_special_tokens=True) + + # Combine with ellipsis marker + return begin_text + " [...] " + end_text + + +def verify_outputs(model, tokenizer, args): + """Compare outputs between baseline and sparse attention models.""" + # Update seq_len to match calibration max_seqlen if calibration was used + base_config = SPARSE_ATTN_CFG_CHOICES.get(args.sparse_attn, {}) + if "calibration" in base_config and "max_seqlen" in base_config["calibration"]: + calib_max_seqlen = base_config["calibration"]["max_seqlen"] + if args.seq_len != calib_max_seqlen: + print( + f"\nNote: Updating test seq_len from {args.seq_len} to {calib_max_seqlen} " + f"to match calibration config" + ) + args.seq_len = calib_max_seqlen + + # Load and prepare a single test prompt + print(f"\nLoading test sample (will be tokenized up to {args.seq_len} tokens)") + prompts = get_narrativeqa_samples(num_samples=1) + prompt = prompts[0] + + # Prepare inputs + truncated_prompt = truncate_text(prompt, tokenizer, args.seq_len) + display_prompt = ( + truncated_prompt[:150] + "..." 
if len(truncated_prompt) > 150 else truncated_prompt + ) + + inputs = tokenizer( + truncated_prompt, + return_tensors="pt", + max_length=args.seq_len, + truncation=True, + padding=False, + ) + if torch.cuda.is_available(): + inputs = {k: v.cuda() for k, v in inputs.items()} + + print("\n" + "=" * 60) + print("BASELINE vs SPARSE ATTENTION COMPARISON") + print("=" * 60) + print(f"\nTest prompt: {display_prompt}") + print(f"Input tokens: {inputs['input_ids'].shape[1]} (max: {args.seq_len})") + if "[...]" in truncated_prompt: + print("Note: Text was middle-truncated to fit token limit") + + # Helper function to generate text + def generate_text(model, inputs, args, tokenizer): + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=args.max_new_tokens, + do_sample=args.do_sample, + temperature=args.temperature if args.do_sample else 1.0, + pad_token_id=tokenizer.pad_token_id, + ) + input_length = inputs["input_ids"].shape[1] + generated_ids = outputs[0][input_length:] + return tokenizer.decode(generated_ids, skip_special_tokens=True) + + # Find all sparse attention modules + sparse_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] + + # Generate baseline by temporarily disabling sparse attention + print("\n" + "-" * 60) + print("Generating baseline (sparse attention disabled)...") + for module in sparse_modules: + module.disable() + baseline_text = generate_text(model, inputs, args, tokenizer) + + # Generate with sparse attention enabled + print("\nGenerating with sparse attention (calibrated thresholds)...") + for module in sparse_modules: + module.enable() + sparse_text = generate_text(model, inputs, args, tokenizer) + + # Display comparison + print("\n" + "-" * 60) + print("RESULTS:") + baseline_display = baseline_text[:300] + "..." if len(baseline_text) > 300 else baseline_text + sparse_display = sparse_text[:300] + "..." 
if len(sparse_text) > 300 else sparse_text + + print(f"\nBaseline: {baseline_display}") + print(f"With Sparse: {sparse_display}") + + if baseline_text == sparse_text: + print("\nOutputs are identical") + else: + print("\nOutputs differ") + + +def sparsify_model(model, args): + """Apply sparse attention to the model with optional calibration.""" + print(f"\nApplying sparse attention: {args.sparse_attn} with backend: {args.backend}") + base_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn] + + # Create modified config with selected backend + modified_sparse_cfg = {} + for pattern, cfg in base_config["sparse_cfg"].items(): + modified_cfg = cfg.copy() + modified_cfg["backend"] = args.backend + modified_sparse_cfg[pattern] = modified_cfg + + # Create new config with modified settings + sparse_config = SparseAttentionConfig( + method=base_config["method"], + sparse_cfg=modified_sparse_cfg, + collect_stats=True, # Enable stats collection for monitoring + ) + + # Sparsify with optional calibration - framework handles calibration automatically + model = mtsa.sparsify(model, config=sparse_config) + + print("Sparse attention applied successfully!") + + # Show sparsity statistics + print("\n" + "=" * 60) + print("Sparsity Statistics") + print("=" * 60) + print_sparsity_stats(model) + + return model + + +def main(args): + """Main function to run the selected mode.""" + if not torch.cuda.is_available(): + raise OSError("GPU is required for inference.") + + random.seed(RAND_SEED) + np.random.seed(RAND_SEED) + launch_memory_monitor() + + print(f"Loading model: {args.pyt_ckpt_path}") + + # Load model and tokenizer + # Note: attn_implementation="eager" is required for calibration to work properly + # (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection) + model = AutoModelForCausalLM.from_pretrained( + args.pyt_ckpt_path, + attn_implementation="eager", + torch_dtype=torch.bfloat16, + ) + tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path) + + # Set pad token if not set + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Move model to GPU if available + if torch.cuda.is_available(): + model = model.cuda() + print("Model moved to CUDA") + + # Apply sparse attention to the model (with calibration if configured) + model = sparsify_model(model, args) + + # Verify outputs if requested (compares baseline vs calibrated sparse model) + if args.verify_output: + verify_outputs(model, tokenizer, args) + + # Export if requested + if args.export_dir: + print(f"\nExporting model to: {args.export_dir}") + export_dir = Path(args.export_dir) + export_dir.mkdir(parents=True, exist_ok=True) + + with torch.inference_mode(): + export_hf_checkpoint(model, export_dir=export_dir) + + tokenizer.save_pretrained(export_dir) + print(f"Model exported successfully to: {export_dir}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + + # Model arguments + parser.add_argument( + "--pyt_ckpt_path", + type=str, + required=True, + help="Specify where the PyTorch checkpoint path is", + ) + parser.add_argument( + "--sparse_attn", + type=str, + default="skip_softmax", + choices=list(SPARSE_ATTN_CFG_CHOICES.keys()), + help="Sparse attention configuration to apply.", + ) + parser.add_argument( + "--backend", + type=str, + default="pytorch", + choices=["pytorch", "triton"], + help="Backend to use for sparse attention computation (default: pytorch)", + ) + + # Sequence length arguments + parser.add_argument( + "--seq_len", + 
type=int, + default=2048, + help="Maximum sequence length for input prompts (will be truncated if longer)", + ) + parser.add_argument( + "--num_samples", + type=int, + default=3, + help="Number of samples to use from NarrativeQA dataset", + ) + + # Generation arguments + parser.add_argument( + "--max_new_tokens", type=int, default=50, help="Maximum new tokens to generate" + ) + parser.add_argument("--do_sample", action="store_true", help="Use sampling for generation") + parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling") + + # Operation arguments + parser.add_argument( + "--verify_output", + action="store_true", + help="Verify that sparse attention outputs match baseline", + ) + parser.add_argument( + "--export_dir", + type=str, + default=None, + help="Directory to export the model with sparse attention applied", + ) + + args = parser.parse_args() + main(args) diff --git a/examples/llm_sparse_attention/requirements.txt b/examples/llm_sparse_attention/requirements.txt new file mode 100644 index 000000000..186c0f9f1 --- /dev/null +++ b/examples/llm_sparse_attention/requirements.txt @@ -0,0 +1,4 @@ +accelerate +datasets +transformers + diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index f966ffac6..1ab640608 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -338,6 +338,129 @@ def _export_quantized_weight( sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale) +def _get_sparse_attention_config(model: nn.Module) -> dict[str, Any]: + """Extract sparse attention configuration from model for export. + + Args: + model: Model with sparse attention modules + + Returns: + Dictionary with sparse attention config in format: + { + "config_groups": { + "group_0": { + "sparse_algo": "softmax_skip", + "threshold": 1e-4, # only if not calibrated + "targets": ["LlamaAttention"] + } + }, + "threshold_scale_factor": 0.001234, # global, if calibrated + "target_sparsity": 0.5, # global, if calibrated + "producer": {"name": "modelopt", "version": "..."} + } + """ + from modelopt import __version__ + from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule + + # Collect all enabled sparse attention modules + sparse_modules = [] + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule) and module.is_enabled: + sparse_modules.append((name, module)) + + if not sparse_modules: + return {} + + sparse_config = { + "config_groups": {}, + "producer": { + "name": "modelopt", + "version": __version__, + }, + } + + # Check first module for global calibration parameters + # (all modules share the same calibration parameters) + first_module = sparse_modules[0][1] + method_instance = first_module._sparse_method_instance + threshold_scale_factor = getattr(method_instance, "threshold_scale_factor", None) + + if threshold_scale_factor is not None: + # Model was calibrated: add global calibration parameters + sparse_config["threshold_scale_factor"] = float(threshold_scale_factor) + + target_sparsity = getattr(method_instance, "target_sparsity", None) + if target_sparsity is not None: + sparse_config["target_sparsity"] = float(target_sparsity) + + # Group modules by configuration + # Key: (sparse_algo, threshold_repr), Value: list of module class names + config_to_targets = {} + + for name, module in sparse_modules: + method_instance = module._sparse_method_instance + + # Extract sparse 
algorithm name from method name + # e.g., "flash_softmax_skip" -> "softmax_skip" + method_name = method_instance.name + if method_name.startswith("flash_"): + sparse_algo = method_name[6:] # Remove "flash_" prefix + else: + sparse_algo = method_name + + # Get module's original class name for targets + # Get the class name before SparseAttentionModule wrapping + original_cls = module.get_original_cls_by_level(level=0) + target_class_name = original_cls.__name__ + + # Build config key for grouping + if threshold_scale_factor is None: + # Not calibrated: include threshold in grouping + threshold_config = getattr(method_instance, "threshold_config", None) + if isinstance(threshold_config, dict): + # Convert dict to tuple for hashable key + threshold_repr = tuple(sorted(threshold_config.items())) + else: + threshold_repr = threshold_config + else: + # Calibrated: no threshold in per-layer config + threshold_repr = None + + config_key = (sparse_algo, threshold_repr) + + if config_key not in config_to_targets: + config_to_targets[config_key] = { + "sparse_algo": sparse_algo, + "threshold_config": threshold_config if threshold_scale_factor is None else None, + "targets": set(), + } + + config_to_targets[config_key]["targets"].add(target_class_name) + + # Convert grouped configs to config_groups format + for group_idx, ((sparse_algo, threshold_repr), group_data) in enumerate( + config_to_targets.items() + ): + group_name = f"group_{group_idx}" + group_config = { + "sparse_algo": group_data["sparse_algo"], + "targets": sorted(group_data["targets"]), + } + + # Add threshold only if not calibrated + if group_data["threshold_config"] is not None: + threshold_config = group_data["threshold_config"] + if isinstance(threshold_config, dict): + # Convert to JSON-serializable format + group_config["threshold"] = {k: float(v) for k, v in threshold_config.items()} + else: + group_config["threshold"] = float(threshold_config) + + sparse_config["config_groups"][group_name] = group_config + + return sparse_config + + def _export_hf_checkpoint( model: nn.Module, dtype: torch.dtype | None = None ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -543,6 +666,11 @@ def export_hf_checkpoint( config_data["quantization_config"] = hf_quant_config + # Add sparse attention config if model has sparse attention + sparse_attention_config = _get_sparse_attention_config(model) + if sparse_attention_config: + config_data["sparse_attention_config"] = sparse_attention_config + with open(original_config, "w") as file: json.dump(config_data, file, indent=4) diff --git a/modelopt/torch/sparsity/attention_sparsity/__init__.py b/modelopt/torch/sparsity/attention_sparsity/__init__.py new file mode 100644 index 000000000..150f93a3a --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/__init__.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Extensible sparse attention optimization for transformer models.""" + +# Initialize mode +from . import mode + +# Add methods to namespace +from .config import * +from .conversion import * +from .model_sparsify import * diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py new file mode 100644 index 000000000..3b616e8e3 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration framework for sparse attention methods.""" + +from .calibrate import calibrate_sparse_attention +from .calibrator import DynamicThresholdCalibrator +from .dataset import RulerDatasetBuilder + +__all__ = [ + "DynamicThresholdCalibrator", + "RulerDatasetBuilder", + "calibrate_sparse_attention", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py new file mode 100644 index 000000000..d7ec0387b --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration functions for sparse attention.""" + +import warnings +from collections.abc import Callable +from typing import Any + +import torch +import torch.nn as nn +from transformers import AutoTokenizer + +from ..config import CalibrationConfig +from ..nn.sparse_attention import SparseAttentionModule +from .calibrator import DynamicThresholdCalibrator +from .dataset import RulerDatasetBuilder + + +def _extract_tokenizer_from_model(model: nn.Module) -> str: + """Extract tokenizer name/path from model config. 
+
+    Args:
+        model: Model to extract tokenizer from
+
+    Returns:
+        Tokenizer name or path
+
+    Raises:
+        ValueError: If tokenizer path cannot be determined from model
+    """
+    # Extract tokenizer path from model config
+    tokenizer_path = getattr(getattr(model, "config", None), "_name_or_path", None)
+
+    if not tokenizer_path:
+        raise ValueError("Could not determine tokenizer path from model config.")
+
+    return tokenizer_path
+
+
+def _extract_calibration_config(config: dict[str, Any]) -> CalibrationConfig | None:
+    """Extract and validate calibration config from sparse_cfg patterns.
+
+    Args:
+        config: Sparse attention configuration dict
+
+    Returns:
+        Validated CalibrationConfig or None if not found
+    """
+    # Extract sparse_cfg and search for calibration
+    sparse_cfg = config.get("sparse_cfg", {})
+
+    calib_dict = next(
+        (
+            cfg["calibration"]
+            for cfg in sparse_cfg.values()
+            if isinstance(cfg, dict) and "calibration" in cfg
+        ),
+        None,
+    )
+
+    # Create and validate the calibration config
+    return CalibrationConfig(**calib_dict) if calib_dict else None
+
+
+def create_calibration_forward_loop(
+    calibration_data: list[dict[str, Any]],
+    tokenizer_name_or_path: str,
+    batch_size: int = 1,
+) -> Callable:
+    """Create forward loop for calibration.
+
+    Args:
+        calibration_data: List of samples with 'input' and 'length' fields
+        tokenizer_name_or_path: HuggingFace tokenizer path
+        batch_size: Batch size (currently unused, always 1)
+
+    Returns:
+        Forward loop function that takes model as argument
+    """
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
+    if not tokenizer.pad_token:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    def forward_loop(model: nn.Module) -> None:
+        device = next(model.parameters()).device
+
+        for sample in calibration_data:
+            inputs = tokenizer(
+                sample["input"], return_tensors="pt", truncation=True, max_length=sample["length"]
+            )
+            inputs = {k: v.to(device) for k, v in inputs.items()}
+
+            with torch.no_grad():
+                model(**inputs)
+
+    return forward_loop
+
+
+def calibrate_sparse_attention(
+    model: nn.Module,
+    config: dict[str, Any],
+    forward_loop: Callable | None = None,
+) -> dict[str, Any]:
+    """Calibrate sparse attention parameters for optimal sparsity.
+
+    Args:
+        model: Model with sparse attention modules
+        config: Sparse attention configuration dict
+        forward_loop: Callable that forwards calibration data through model.
+            If None, auto-generates RULER dataset.
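+            A custom loop should accept the model as its only argument and run
+            representative long-context prompts through it; the calibrator simply
+            calls forward_loop(model) for each threshold trial.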
+ + Returns: + Dictionary with calibration results + """ + # Extract and validate calibration config + calib_config = _extract_calibration_config(config) + if not calib_config: + return {} + + # Generate forward_loop if not provided + if not forward_loop: + tokenizer = _extract_tokenizer_from_model(model) + builder = RulerDatasetBuilder( + samples=calib_config.samples, + max_seqlen=calib_config.max_seqlen, + tokenizer_name_or_path=tokenizer, + num_length_bins=calib_config.num_length_bins, + max_length_filter=int(calib_config.max_seqlen * 1.2), + ) + calibration_data = builder.build_calibration_dataset() + print(f"Generated {len(calibration_data)} calibration samples") + forward_loop = create_calibration_forward_loop(calibration_data, tokenizer) + + # Get sparse attention modules + sparse_modules = [ + (name, m) for name, m in model.named_modules() if isinstance(m, SparseAttentionModule) + ] + + if not sparse_modules: + print("No sparse attention modules found for calibration") + return {} + + print(f"Calibrating {len(sparse_modules)} sparse attention modules together...") + + # Run calibration + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=calib_config.target_sparse_ratio, + threshold_trials=calib_config.threshold_trials, + ) + calibration_result = calibrator.calibrate(model, forward_loop) + + if "scale_factor" not in calibration_result: + warnings.warn("Calibration did not produce valid results") + return {} + + # Apply calibrated scale factor to all modules + scale_factor = calibration_result["scale_factor"] + print(f"\nApplying calibrated scale factor={scale_factor:.6f} to {len(sparse_modules)} modules") + + for module_name, module in sparse_modules: + module._sparse_method_instance.threshold_scale_factor = scale_factor + module._sparse_method_instance.target_sparsity = calib_config.target_sparse_ratio + + return {"calibration_results": {name: calibration_result for name, _ in sparse_modules}} diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py new file mode 100644 index 000000000..43440ac13 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -0,0 +1,308 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration framework for sparse attention methods.""" + +import warnings +from collections.abc import Callable +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from tqdm import tqdm + +from ..nn.sparse_attention import SparseAttentionModule +from ..nn.stats_manager import SparseAttentionStatsManager + + +class DynamicThresholdCalibrator: + """Dynamic threshold calibrator using length-based linear relationship. + + Implements calibration algorithm: + 1. Find hyperparameter 'a' where threshold λ = a / context_length + 2. 
Use dataset with different lengths and test multiple thresholds + 3. For each sample, find optimal threshold closest to target sparsity + 4. Use linear regression to fit: threshold = a * (1/length) + """ + + def __init__( + self, + target_sparse_ratio: float = 0.5, + threshold_trials: list[float] | None = None, + ): + """Initialize dynamic threshold calibrator. + + Args: + target_sparse_ratio: Target sparsity ratio (0.0 to 1.0) + threshold_trials: List of thresholds to try during calibration + + Note: + Calibration only supports prefill phase (seq_len > 1). + Decode phase uses the same calibrated threshold. + """ + self.target_sparse_ratio = target_sparse_ratio + + # Default threshold trials if not provided + self.threshold_trials = threshold_trials or [ + 1e-6, + 5e-6, + 1e-5, + 5e-5, + 1e-4, + 5e-4, + 1e-3, + 5e-3, + 1e-2, + 5e-2, + 1e-1, + 5e-1, + ] + + # Statistics tracking + self.sparsity_results = [] + + def calibrate(self, model: nn.Module, forward_loop: Callable) -> dict[str, Any]: + """Find optimal 'a' parameter for length-based threshold. + + Algorithm: + 1. Test all threshold trials by running forward_loop multiple times + 2. For each sample, find optimal threshold closest to target sparsity + 3. Use regression to find 'a' in: threshold = a / length + + Args: + model: The model with sparse attention modules + forward_loop: Callable that takes model and forwards calibration data + """ + # Extract attention modules + attention_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] + + if not attention_modules: + warnings.warn("No sparse attention modules found for calibration") + return {} + + print("Starting dynamic threshold calibration") + print(f"Target sparsity: {self.target_sparse_ratio}") + print(f"Threshold trials: {len(self.threshold_trials)}") + + # Stage 1: Collect sparsity for all sample-threshold pairs + print("\nStage 1: Collecting sparsity data...") + self.sparsity_results = [] + + # For each threshold, run forward_loop and collect per-sample statistics + for threshold_idx, threshold in enumerate( + tqdm(self.threshold_trials, desc="Testing thresholds") + ): + # Set threshold and enable calibration mode + self._set_threshold(attention_modules, threshold) + self._enable_calibration_mode(attention_modules) + + # Run forward loop and collect stats + with torch.no_grad(): + forward_loop(model) + per_sample_stats = self._extract_calibration_stats(attention_modules) + self._disable_calibration_mode(attention_modules) + + # Store results + for sample_idx, sample_stat in enumerate(per_sample_stats): + if threshold_idx == 0: + # Initialize sample entry on first threshold + sample_length = sample_stat.get("sample_length", 0) + if sample_length > 0: + self.sparsity_results.append( + { + "sample_index": sample_idx, + "length": sample_length, + "threshold_sparsities": {}, + } + ) + + # Add sparsity for this threshold + if sample_idx < len(self.sparsity_results): + sparsity = sample_stat.get("sparsity", 0.0) + self.sparsity_results[sample_idx]["threshold_sparsities"][threshold] = sparsity + + if not self.sparsity_results: + warnings.warn("No valid sparsity measurements collected during calibration") + return {} + + print(f"Collected statistics for {len(self.sparsity_results)} samples") + + # Stage 2: Find optimal threshold for each sample and compute scale factor + print( + f"\nStage 2: Finding threshold scale factor for target sparsity {self.target_sparse_ratio:.2f}" + ) + + # Find optimal threshold for each sample + optimal_pairs = [] + for sample_result 
in self.sparsity_results: + # Find threshold closest to target sparsity + best_threshold, achieved_sparsity = min( + sample_result["threshold_sparsities"].items(), + key=lambda item: abs(item[1] - self.target_sparse_ratio), + ) + + optimal_pairs.append( + { + "length": sample_result["length"], + "optimal_threshold": best_threshold, + "achieved_sparsity": achieved_sparsity, + "target_sparsity": self.target_sparse_ratio, + } + ) + + if not optimal_pairs: + warnings.warn( + f"No optimal threshold pairs found for target sparsity {self.target_sparse_ratio}. " + f"Collected {len(self.sparsity_results)} samples but none achieved target sparsity." + ) + return {} + + # Linear regression: threshold = a * (1/length) + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + # X = 1/length, Y = threshold + x = 1.0 / lengths + y = thresholds + + # Least squares: scale_factor = sum(x*y) / sum(x^2) + scale_factor = np.sum(x * y) / np.sum(x**2) + + # Calculate statistics + scale_factors_per_sample = y * lengths + scale_factor_std = np.std(scale_factors_per_sample) + + # Calculate R-squared for quality metric + y_pred = scale_factor * x + ss_res = np.sum((y - y_pred) ** 2) + ss_tot = np.sum((y - np.mean(y)) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 + + # Calculate average achieved sparsity + avg_achieved_sparsity = np.mean([p["achieved_sparsity"] for p in optimal_pairs]) + + print("\nCalibration Results:") + print(f" Threshold scale factor: {scale_factor:.6f} (std: {scale_factor_std:.6f})") + print(f" R-squared: {r_squared:.4f}") + print( + f" Average achieved sparsity: {avg_achieved_sparsity:.2%} (target: {self.target_sparse_ratio:.2%})" + ) + print(f"\nExample thresholds with λ = {scale_factor:.6f} / length:") + for length in [1024, 2048, 4096, 8192, 16384]: + print(f" Length {length:5d}: threshold = {scale_factor / length:.2e}") + + # Apply the calibrated scale factor to modules + self._apply_length_based_calibration(attention_modules, scale_factor) + + return { + "scale_factor": scale_factor, + "scale_factor_std": scale_factor_std, + "r_squared": r_squared, + "num_samples": len(optimal_pairs), + "target_sparsity": self.target_sparse_ratio, + "avg_achieved_sparsity": avg_achieved_sparsity, + "optimal_pairs": optimal_pairs, + "calibration_type": "length_based_dynamic", + } + + def _apply_length_based_calibration(self, modules: list[nn.Module], scale_factor: float): + """Apply calibrated threshold scale factor to modules. 
+ + Args: + modules: List of attention modules + scale_factor: Calibrated scale factor for λ = scale_factor / length + """ + for module in modules: + module._sparse_method_instance.threshold_scale_factor = scale_factor + + def _enable_calibration_mode(self, modules: list[nn.Module]): + """Enable calibration mode on sparse attention modules.""" + for idx, module in enumerate(modules): + # Create stats manager if needed + if not module._stats_manager: + module._stats_manager = SparseAttentionStatsManager( + module_name=f"sparse_attn_{idx}", enabled=True + ) + else: + # Re-enable if disabled + module._stats_manager.enabled = True + + # Enable calibration mode with fresh stats + module._stats_manager.set_calibration_mode(enabled=True, reset_history=True) + module._sparse_method_instance.set_calibration_mode(True) + + def _disable_calibration_mode(self, modules: list[nn.Module]): + """Disable calibration mode (but keep stats enabled if collect_stats=True).""" + for module in modules: + if module._stats_manager: + module._stats_manager.set_calibration_mode(enabled=False) + + module._sparse_method_instance.set_calibration_mode(False) + + def _extract_calibration_stats(self, modules: list[nn.Module]) -> list[dict]: + """Extract per-sample calibration statistics from modules. + + Args: + modules: List of attention modules + + Returns: + List of per-sample statistics across all modules + """ + # Collect from all stats managers + all_per_sample_stats = [] + + for module in modules: + # Skip modules without stats manager + if not hasattr(module, "_stats_manager") or module._stats_manager is None: + continue + + manager_stats = module._stats_manager.get_calibration_stats() + if manager_stats: + all_per_sample_stats.append(manager_stats) + + if not all_per_sample_stats: + return [] + + # Aggregate across modules by sample index + num_samples = len(all_per_sample_stats[0]) + aggregated_stats = [] + + for sample_idx in range(num_samples): + sparsities = [] + sample_length = 0 + + for module_stats in all_per_sample_stats: + if sample_idx < len(module_stats): + sample_stat = module_stats[sample_idx] + sparsities.append(sample_stat.get("sparsity", 0.0)) + if not sample_length and "sample_length" in sample_stat: + sample_length = sample_stat["sample_length"] + + avg_sparsity = np.mean(sparsities) if sparsities else 0.0 + + aggregated_stats.append( + { + "sparsity": avg_sparsity, + "sample_length": sample_length, + } + ) + + return aggregated_stats + + def _set_threshold(self, modules: list[nn.Module], threshold: float): + """Set threshold on sparse attention modules.""" + for module in modules: + module._sparse_method_instance.threshold = threshold diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py new file mode 100644 index 000000000..727fb189a --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py @@ -0,0 +1,602 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RULER dataset builder for sparse attention calibration.""" + +import random +import string +import uuid +from dataclasses import dataclass +from typing import Any + +from tqdm import tqdm +from transformers import AutoTokenizer + + +def _generate_target_lengths( + max_seqlen: int, num_length_bins: int = 4, min_seqlen: int = 1024 +) -> list[int]: + """Generate target lengths as descending powers of 2. + + Args: + max_seqlen: Maximum sequence length + num_length_bins: Maximum number of length bins to generate + min_seqlen: Minimum sequence length threshold + + Returns: + List of target lengths in descending order + + Examples: + >>> _generate_target_lengths(32768, 4) + [32768, 16384, 8192, 4096] + >>> _generate_target_lengths(2048, 4) + [2048, 1024] + """ + target_lengths = [] + current = max_seqlen + + for _ in range(num_length_bins): + if current < min_seqlen: + break + target_lengths.append(current) + current = current // 2 + + return target_lengths + + +@dataclass +class RulerTask: + """Configuration for a RULER task.""" + + name: str + task_type: str # niah, variable_tracking, freq_words_extraction, qa + tokens_to_generate: int + template: str + answer_prefix: str + args: dict[str, Any] + + +# Task configurations based on RULER benchmark +RULER_TASKS = { + "niah_multikey_2": RulerTask( + name="niah_multikey_2", + task_type="niah", + tokens_to_generate=128, + template=( + "Some special magic {type_needle_v} are hidden within the following text. " + "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" + "{context}\n" + "What are all the special magic {type_needle_v} for {query} mentioned in the provided text?" + ), + answer_prefix=( + " The special magic {type_needle_v} for {query} mentioned in the provided text are" + ), + args={ + "type_haystack": "needle", + "type_needle_k": "words", + "type_needle_v": "numbers", + "num_needle_k": 1, + "num_needle_v": 1, + "num_needle_q": 1, + }, + ), + "niah_multikey_3": RulerTask( + name="niah_multikey_3", + task_type="niah", + tokens_to_generate=128, + template=( + "Some special magic {type_needle_v} are hidden within the following text. " + "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" + "{context}\n" + "What are all the special magic {type_needle_v} for {query} mentioned in the provided text?" + ), + answer_prefix=( + " The special magic {type_needle_v} for {query} mentioned in the provided text are" + ), + args={ + "type_haystack": "needle", + "type_needle_k": "uuids", + "type_needle_v": "uuids", + "num_needle_k": 1, + "num_needle_v": 1, + "num_needle_q": 1, + }, + ), + "vt": RulerTask( + name="vt", + task_type="variable_tracking", + tokens_to_generate=30, + template=( + "Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n" + "{context}\n" + "Question: Find all variables that are assigned the value {query} in the text above." 
+ ), + answer_prefix=( + " Answer: According to the chain(s) of variable assignment in the text above, " + "{num_v} variables are assgined the value {query}, they are: " + ), + args={"num_chains": 1, "num_hops": 4}, + ), + "fwe": RulerTask( + name="fwe", + task_type="freq_words_extraction", + tokens_to_generate=50, + template=( + "Read the following coded text and track the frequency of each coded word. " + "Find the three most frequently appeared coded words. {context}\n" + "Question: Do not provide any explanation. Please ignore the dots '....'. " + "What are the three most frequently appeared words in the above coded text?" + ), + answer_prefix=( + " Answer: According to the coded text above, " + "the three most frequently appeared words are:" + ), + args={}, + ), + "qa_1": RulerTask( + name="qa_1", + task_type="qa", + tokens_to_generate=32, + template=( + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "The following are given documents.\n\n{context}\n\n" + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {query}" + ), + answer_prefix=" Answer:", + args={"dataset": "squad", "demo_selection": "topk_lt_ctx"}, + ), + "qa_2": RulerTask( + name="qa_2", + task_type="qa", + tokens_to_generate=32, + template=( + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "The following are given documents.\n\n{context}\n\n" + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {query}" + ), + answer_prefix=" Answer:", + args={"dataset": "hotpotqa", "demo_selection": "topk_lt_ctx"}, + ), +} + + +class RulerDatasetBuilder: + """Builder for RULER calibration datasets.""" + + def __init__( + self, + samples: int, + max_seqlen: int, + tokenizer_name_or_path: str, + seed: int = 42, + num_length_bins: int = 4, + max_length_filter: int = 65536, + ): + """Initialize RULER dataset builder. + + Args: + samples: Total number of samples to generate (distributed evenly across length bins) + max_seqlen: Maximum sequence length (length bins auto-generated as powers of 2) + tokenizer_name_or_path: HuggingFace tokenizer path + seed: Random seed for reproducibility + num_length_bins: Number of length bins to generate (default: 4) + max_length_filter: Maximum sequence length to keep (default: 65536) + + Note: + Length bins are auto-generated as descending powers of 2: + [max_seqlen, max_seqlen/2, max_seqlen/4, ...] + Generation stops when num_length_bins is reached or length < 1024. + Subtasks are set to all the difficult tasks defined in RULER_TASKS. 
+ """ + # Validate inputs + if samples <= 0: + raise ValueError(f"samples must be positive, got {samples}") + if max_seqlen < 1024: + raise ValueError(f"max_seqlen must be >= 1024, got {max_seqlen}") + + # Store parameters + self.total_samples = samples + self.max_seqlen = max_seqlen + self.num_length_bins = num_length_bins + self.subtasks = list(RULER_TASKS.keys()) + self.tokenizer_name_or_path = tokenizer_name_or_path + self.seed = seed + self.max_length_filter = max_length_filter + + # Generate target lengths and validate + self.target_lengths = _generate_target_lengths(max_seqlen, num_length_bins, min_seqlen=1024) + if not self.target_lengths: + raise ValueError(f"No valid target lengths generated from max_seqlen={max_seqlen}") + + # Distribute samples evenly across lengths + self.samples_per_length = [samples // len(self.target_lengths)] * len(self.target_lengths) + + # Initialize tokenizer and seed + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + random.seed(seed) + + def build_calibration_dataset(self) -> list[dict[str, Any]]: + """Build the complete calibration dataset. + + Returns: + List of calibration samples with 'input' and 'length' fields + """ + all_samples = [] + + # Generate calibration samples + for num_samples, target_length in tqdm( + zip(self.samples_per_length, self.target_lengths), + desc="Generating RULER calibration samples", + total=len(self.target_lengths), + ): + samples_per_task = num_samples // len(self.subtasks) + if samples_per_task <= 0: + continue + + # Generate equal samples for each task + for task_name in self.subtasks: + for sample_idx in range(samples_per_task): + sample = self._generate_sample(task_name, target_length, sample_idx) + if sample and sample["length"] <= self.max_length_filter: + all_samples.append(sample) + + random.shuffle(all_samples) + return all_samples + + def _generate_sample( + self, task_name: str, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a single RULER sample. 
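+
+        Dispatches to the task-type-specific generator (needle-in-a-haystack,
+        variable tracking, frequency word extraction, or QA) based on
+        ``task.task_type``.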
+ + Args: + task_name: Name of the RULER task + target_length: Target sequence length in tokens + sample_idx: Index of the sample (for uniqueness) + + Returns: + Dict with 'input', 'length', and metadata fields + """ + task = RULER_TASKS[task_name] + + if task.task_type == "niah": + return self._generate_niah_sample(task, target_length, sample_idx) + elif task.task_type == "variable_tracking": + return self._generate_vt_sample(task, target_length, sample_idx) + elif task.task_type == "freq_words_extraction": + return self._generate_fwe_sample(task, target_length, sample_idx) + elif task.task_type == "qa": + return self._generate_qa_sample(task, target_length, sample_idx) + else: + raise ValueError(f"Unknown task type: {task.task_type}") + + def _generate_niah_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a needle-in-haystack sample.""" + args = task.args + + # Generate needles based on type + if args["type_needle_k"] == "words": + needle_keys = [self._generate_random_word() for _ in range(args["num_needle_k"])] + else: # uuids + needle_keys = [str(uuid.uuid4()) for _ in range(args["num_needle_k"])] + + if args["type_needle_v"] == "numbers": + needle_values = [ + str(random.randint(100000, 999999)) for _ in range(args["num_needle_v"]) + ] + else: # uuids + needle_values = [str(uuid.uuid4()) for _ in range(args["num_needle_v"])] + + # Select query needles + query_keys = random.sample(needle_keys, min(args["num_needle_q"], len(needle_keys))) + + # Generate context with needles + context = self._generate_niah_context( + target_length, needle_keys, needle_values, args["type_haystack"] + ) + + # Format template + template = task.template.format( + type_needle_v=args["type_needle_v"], + context=context, + query=query_keys[0] if query_keys else needle_keys[0], + ) + + # Add answer prefix + full_input = template + task.answer_prefix.format( + type_needle_v=args["type_needle_v"], + query=query_keys[0] if query_keys else needle_keys[0], + ) + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _generate_vt_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a variable tracking sample.""" + args = task.args + num_chains = args["num_chains"] + num_hops = args["num_hops"] + + # Generate variable chains + variables = [] + chains = [] + for _ in range(num_chains): + chain = [self._generate_random_variable() for _ in range(num_hops + 1)] + variables.extend(chain) + chains.append(chain) + + # Generate assignments + assignments = [ + f"VAR {chain[i]} = {chain[i + 1]}" for chain in chains for i in range(len(chain) - 1) + ] + + # Create context with padding + context = self._pad_context_with_text( + "\n".join(assignments), target_length, "variable tracking context" + ) + + # Select a query value + query_value = random.choice([chain[-1] for chain in chains]) + + # Format template + template = task.template.format(context=context, query=query_value) + + # Count variables with the query value + num_v = sum(1 for chain in chains if chain[-1] == query_value) + + # Add answer prefix + full_input = template + task.answer_prefix.format(num_v=num_v, query=query_value) + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": 
full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _generate_fwe_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a frequency word extraction sample.""" + # Generate coded words with frequencies + num_unique_words = 50 + coded_words = [self._generate_coded_word() for _ in range(num_unique_words)] + + # Assign frequencies (make top 3 clearly more frequent) + frequencies = {} + for i, word in enumerate(coded_words): + if i < 3: + frequencies[word] = random.randint(20, 30) # High frequency + else: + frequencies[word] = random.randint(1, 10) # Low frequency + + # Generate the coded text + word_list = [] + for word, freq in frequencies.items(): + word_list.extend([word] * freq) + random.shuffle(word_list) + + # Add dots for separation + coded_text = " .... ".join(word_list) + + # Pad to target length + context = self._pad_context_with_text(coded_text, target_length, "coded text padding") + + # Format template + template = task.template.format(context=context) + full_input = template + task.answer_prefix + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _generate_qa_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a QA sample.""" + # Generate synthetic documents + num_docs = 5 + documents = [] + + # Create a simple QA pair + answer = self._generate_random_phrase() + question = f"What is the special code mentioned in document {random.randint(1, num_docs)}?" + + for i in range(num_docs): + doc_text = self._generate_document_text(200) # Base document + if i == 2: # Insert answer in one document + doc_text += f" The special code is {answer}. " + documents.append(f"Document {i + 1}:\n{doc_text}\n") + + # Combine documents + context_base = "\n".join(documents) + + # Pad to target length + context = self._pad_context_with_text( + context_base, target_length, "additional document text" + ) + + # Format template + template = task.template.format(context=context, query=question) + full_input = template + task.answer_prefix + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _generate_niah_context( + self, + target_length: int, + needle_keys: list[str], + needle_values: list[str], + haystack_type: str, + ) -> str: + """Generate context for needle-in-haystack tasks.""" + # Create needle sentences + needles = [] + for key, value in zip(needle_keys, needle_values): + needles.append(f"The magic number for {key} is {value}.") + + # Generate haystack based on type + if haystack_type == "repeat": + base_text = "The grass is green. The sky is blue. The sun is yellow. 
" * 100 + elif haystack_type == "essay": + base_text = self._generate_essay_text(500) + else: # needle type - more needles as distractors + base_text = self._generate_needle_haystack(500) + + # Insert needles at random positions + words = base_text.split() + for needle in needles: + insert_pos = random.randint(0, len(words)) + words.insert(insert_pos, needle) + + context = " ".join(words) + + # Adjust to target length + tokens = self.tokenizer.encode(context, add_special_tokens=False) + while len(tokens) < target_length * 0.7: # Leave room for template + context += " " + self._generate_essay_text(100) + tokens = self.tokenizer.encode(context, add_special_tokens=False) + + if len(tokens) > target_length * 0.9: + # Truncate if too long + context = self.tokenizer.decode(tokens[: int(target_length * 0.8)]) + + return context + + def _pad_context_with_text( + self, base_context: str, target_length: int, padding_type: str + ) -> str: + """Pad context to approach target length.""" + tokens = self.tokenizer.encode(base_context, add_special_tokens=False) + + while len(tokens) < target_length * 0.7: # Leave room for template + if padding_type == "variable tracking context": + padding = ( + f" VAR {self._generate_random_variable()} = {self._generate_random_variable()}." + ) + elif padding_type == "coded text padding": + padding = f" .... {self._generate_coded_word()} .... " + else: + padding = " " + self._generate_essay_text(50) + + base_context += padding + tokens = self.tokenizer.encode(base_context, add_special_tokens=False) + + if len(tokens) > target_length * 0.9: + # Truncate if too long + base_context = self.tokenizer.decode(tokens[: int(target_length * 0.8)]) + + return base_context + + def _generate_random_word(self) -> str: + """Generate a random word.""" + return "".join(random.choices(string.ascii_lowercase, k=random.randint(5, 10))) + + def _generate_random_variable(self) -> str: + """Generate a random variable name.""" + return "".join(random.choices(string.ascii_uppercase, k=1)) + "".join( + random.choices(string.digits, k=3) + ) + + def _generate_coded_word(self) -> str: + """Generate a coded word.""" + return "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) + + def _generate_random_phrase(self) -> str: + """Generate a random phrase.""" + words = [self._generate_random_word() for _ in range(random.randint(2, 4))] + return " ".join(words) + + def _generate_essay_text(self, num_words: int) -> str: + """Generate essay-like text.""" + topics = [ + "technology", + "science", + "nature", + "history", + "culture", + "education", + "health", + "economics", + "politics", + "philosophy", + "art", + "literature", + ] + + sentences = [] + words_generated = 0 + + while words_generated < num_words: + topic = random.choice(topics) + word1 = self._generate_random_word() + word2 = self._generate_random_word() + word3 = self._generate_random_word() + sentence = f"The {topic} of {word1} is {word2} and {word3}. 
" + sentences.append(sentence) + words_generated += len(sentence.split()) + + return " ".join(sentences) + + def _generate_document_text(self, num_words: int) -> str: + """Generate document-like text.""" + return self._generate_essay_text(num_words) + + def _generate_needle_haystack(self, num_items: int) -> str: + """Generate haystack with many needle-like distractors.""" + items = [] + for _ in range(num_items): + key = self._generate_random_word() + value = str(random.randint(100000, 999999)) + items.append(f"The value for {key} is {value}.") + + return " ".join(items) diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py new file mode 100644 index 000000000..b0f342d29 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -0,0 +1,358 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration classes for sparse attention optimization.""" + +from collections.abc import Callable +from typing import Any + +from pydantic import Field, field_validator + +from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField + +# Type definitions for sparse configuration +SparseAttributeConfig = dict[str, Any] # Configuration for a specific pattern + +SparseAttentionCfgType = dict[ + str | Callable, # Pattern or callable for matching modules + SparseAttributeConfig, # Configuration dict with threshold, enable, etc. +] + + +class SparseAttentionAttributeConfig(ModeloptBaseConfig): + """Sparse attention attribute configuration for pattern-based module config.""" + + enable: bool = ModeloptField( + default=True, + title="Enable sparse attention.", + description="If True, enables sparse attention. If False, bypasses sparsity.", + ) + + method: str = ModeloptField( + default="flash_softmax_skip", + title="Sparse attention method.", + description="The sparse attention method to use (e.g., 'flash_softmax_skip').", + ) + + threshold: float | dict[str, float] = ModeloptField( + default=1e-3, + title="Sparsity threshold.", + description=( + "Threshold for determining which attention values to skip. " + "Can be a float or dict with phase-specific values." + ), + ) + + br: int = ModeloptField( + default=128, + title="Block row size.", + description="Block row size for block-wise sparsity in Flash Attention.", + ) + + bc: int = ModeloptField( + default=128, + title="Block column size.", + description="Block column size for block-wise sparsity in Flash Attention.", + ) + + collect_stats: bool = ModeloptField( + default=False, + title="Collect statistics.", + description="Whether to collect sparsity statistics during forward pass.", + ) + + backend: str = ModeloptField( + default="pytorch", + title="Backend implementation.", + description=( + "Backend to use for sparse attention computation. " + "Only 'pytorch' is supported, which uses softmax patching with F.softmax. 
" + "Requires model to be loaded with attn_implementation='eager'." + ), + ) + + is_causal: bool = ModeloptField( + default=True, + title="Causal attention flag.", + description=( + "Whether the model uses causal (autoregressive) attention. " + "If True, sparsity statistics are calculated over the lower triangle only. " + "Defaults to True for decoder-only models like GPT, LLaMA, etc." + ), + ) + + calibration: dict | None = ModeloptField( + default=None, + title="Calibration configuration", + description=( + "Calibration settings for this pattern. " + "If provided, enables automatic threshold calibration. " + "Only one pattern should have calibration enabled." + ), + ) + + @field_validator("method") + @classmethod + def validate_method(cls, v): + """Validate method is a string.""" + if not isinstance(v, str): + raise ValueError("method must be a string") + return v + + @field_validator("backend") + @classmethod + def validate_backend(cls, v): + """Validate backend is pytorch.""" + if v != "pytorch": + raise ValueError( + f"Invalid backend: {v}. Only 'pytorch' backend is supported. " + f"Model must be loaded with attn_implementation='eager'." + ) + return v + + @field_validator("br", "bc") + @classmethod + def validate_block_size(cls, v): + """Validate block sizes are positive integers.""" + if v <= 0: + raise ValueError(f"Block size must be positive, got {v}") + return v + + @field_validator("threshold") + @classmethod + def validate_threshold(cls, v): + """Validate threshold is in valid range (0, 1) or dict with valid phases.""" + if isinstance(v, dict): + # Validate phase keys + valid_phases = {"prefill", "decode", "default"} + invalid_keys = set(v.keys()) - valid_phases + if invalid_keys: + raise ValueError( + f"Invalid threshold phases: {invalid_keys}. Valid phases: {valid_phases}" + ) + # Validate all values are in range (0, 1) + for phase, threshold in v.items(): + if not isinstance(threshold, (int, float)) or threshold <= 0 or threshold >= 1: + raise ValueError( + f"Threshold for phase '{phase}' must be in range (0, 1), got {threshold}" + ) + elif isinstance(v, (int, float)): + if v <= 0 or v >= 1: + raise ValueError(f"Threshold must be in range (0, 1), got {v}") + else: + raise ValueError(f"Threshold must be a number in range (0, 1) or dict, got {type(v)}") + return v + + +class CalibrationConfig(ModeloptBaseConfig): + """Configuration for automatic threshold calibration using RULER dataset. + + Calibration learns a dynamic threshold λ = scale_factor / sequence_length that + achieves target sparsity. Only supports prefill phase (seq_len > 1). + """ + + target_sparse_ratio: float = ModeloptField( + default=0.5, + title="Target sparsity ratio", + description="Target ratio of sparse attention blocks (0.0 to 1.0).", + ) + + samples: int = ModeloptField( + default=24, + title="Calibration samples", + description="Total number of RULER samples for calibration (distributed across length bins).", + ) + + max_seqlen: int = ModeloptField( + default=32768, + title="Maximum sequence length", + description="Maximum sequence length for calibration (length bins auto-generated as powers of 2).", + ) + + num_length_bins: int = ModeloptField( + default=4, + title="Number of length bins", + description="Number of length bins to generate (hidden parameter, default: 4).", + ) + + threshold_trials: list[float] | None = ModeloptField( + default=None, + title="Threshold trials", + description=( + "List of threshold values to test during calibration. 
" + "If None, uses default: [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1]" + ), + ) + + @field_validator("threshold_trials") + @classmethod + def validate_threshold_trials(cls, v): + """Validate threshold_trials are in valid range.""" + if v is not None: + if not isinstance(v, list): + raise ValueError(f"threshold_trials must be a list, got {type(v)}") + if len(v) == 0: + raise ValueError("threshold_trials must not be empty") + for threshold in v: + if not isinstance(threshold, (int, float)): + raise ValueError(f"All threshold_trials must be numbers, got {type(threshold)}") + if threshold <= 0 or threshold >= 1: + raise ValueError( + f"All threshold_trials must be in range (0, 1), got {threshold}" + ) + return v + + @field_validator("target_sparse_ratio") + @classmethod + def validate_target_sparse_ratio(cls, v): + """Validate target sparsity ratio is between 0 and 1.""" + if not 0.0 <= v <= 1.0: + raise ValueError(f"target_sparse_ratio must be between 0.0 and 1.0, got {v}") + return v + + @field_validator("samples") + @classmethod + def validate_samples(cls, v): + """Validate samples is positive.""" + if v <= 0: + raise ValueError(f"samples must be positive, got {v}") + return v + + @field_validator("max_seqlen") + @classmethod + def validate_max_seqlen(cls, v): + """Validate max_seqlen is at least 1024.""" + if v < 1024: + raise ValueError(f"max_seqlen must be >= 1024, got {v}") + return v + + @field_validator("num_length_bins") + @classmethod + def validate_num_length_bins(cls, v): + """Validate num_length_bins is positive.""" + if v <= 0: + raise ValueError(f"num_length_bins must be positive, got {v}") + return v + + +# Pre-defined Sparse Attention Configuration +# Default configuration with block-wise sparsity optimized for Flash Attention +SKIP_SOFTMAX_DEFAULT = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": { + "prefill": 1e-3, # More aggressive during prefill + "decode": 1e-4, # Conservative during decode + }, + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", # Only pytorch backend supported + "enable": True, + }, + "default": {"enable": False}, + }, +} + + +# Configuration with RULER calibration +# Note: threshold field is omitted - calibration determines dynamic threshold λ = a / length +# The calibrated threshold adapts to sequence length for optimal sparsity +SKIP_SOFTMAX_CALIB = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "br": 128, + "bc": 128, + "backend": "pytorch", # Only pytorch backend supported + "enable": True, + "calibration": { + "target_sparse_ratio": 0.3, + "samples": 12, + "max_seqlen": 1024, + }, + }, + "default": {"enable": False}, + }, +} + + +class SparseAttentionConfig(ModeloptBaseConfig): + """Base configuration for sparse attention optimization. + + This base configuration provides the common structure for all sparse + attention methods and supports pattern-based layer configuration. 
+ """ + + # Method selection + method: str = Field("flash_softmax_skip", description="Sparse attention method to use") + + # Statistics collection + collect_stats: bool = Field( + False, description="Whether to collect sparsity statistics during forward pass" + ) + + # Pattern-based sparse configuration (similar to quant_cfg in quantization) + sparse_cfg: SparseAttentionCfgType = ModeloptField( + default={"*attention*": {"enable": True}, "default": {"enable": False}}, + title="Sparse attention configuration", + description="Pattern-based configuration for sparse attention. Keys are patterns to match module names, " + "values are configuration dicts with parameters like 'threshold', 'enable', and 'calibration'.", + validate_default=True, + ) + + # Export configuration + export_format: str | None = Field( + None, description="Export format for sparse attention (e.g., 'onnx', 'tensorrt')" + ) + + +class FlashSoftmaxSkipConfig(SparseAttentionConfig): + """Configuration for Flash Attention-aware softmax skip sparse attention.""" + + # Override method to default to flash_softmax_skip + method: str = Field( + "flash_softmax_skip", description="Sparse attention method (fixed to flash_softmax_skip)" + ) + + # Override sparse_cfg with flash_softmax_skip specific defaults + sparse_cfg: SparseAttentionCfgType = ModeloptField( + default={ + "*attention*": { + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", # Only pytorch backend supported + "enable": True, + }, + "default": {"enable": False}, + }, + title="Flash softmax skip sparse configuration", + description="Pattern-based configuration with flash_softmax_skip specific defaults. " + "Includes FA block sizes (br, bc) and correction factor settings.", + validate_default=True, + ) + + +__all__ = [ + "SKIP_SOFTMAX_CALIB", + "SKIP_SOFTMAX_DEFAULT", + "CalibrationConfig", + "FlashSoftmaxSkipConfig", + "SparseAttentionAttributeConfig", + "SparseAttentionCfgType", + "SparseAttentionConfig", + "SparseAttributeConfig", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py new file mode 100644 index 000000000..04af2dc98 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -0,0 +1,406 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Conversion and restoration utilities for sparse attention.""" + +import fnmatch +from collections.abc import Callable +from typing import Any + +import torch.nn as nn + +from modelopt.torch.opt.conversion import ModelLikeModule, ModeloptStateManager +from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict +from modelopt.torch.utils import get_unwrapped_name + +from .config import SparseAttentionConfig +from .nn.sparse_attention import SparseAttentionModule, SparseAttentionRegistry +from .plugins.huggingface import register_sparse_attention_on_the_fly + + +def is_attn_sparsified(model: nn.Module) -> bool: + """Check if a model has sparse attention applied. + + Similar to quantization's is_quantized for API consistency. + + Args: + model: Model to check + + Returns: + True if model contains any SparseAttentionModule instances + """ + return any(isinstance(module, SparseAttentionModule) for module in model.modules()) + + +def convert_to_sparse_attention_model( + model: ModelLikeModule, config: SparseAttentionConfig +) -> ConvertReturnType: + """Convert model to use sparse attention. + + Args: + model: Model to convert + config: Sparse attention configuration + + Returns: + Tuple of (converted_model, metadata) + """ + # Initialize the true module if necessary + model = model.init_modellike() if isinstance(model, ModelLikeModule) else model + + # Register sparse attention modules dynamically + register_sparse_attention_on_the_fly(model) + + # Replace attention modules with sparse versions + replace_sparse_attention_modules(model, version=ModeloptStateManager(model).state_version) + + # Apply configuration to sparse attention modules + sparse_cfg = config.sparse_cfg if hasattr(config, "sparse_cfg") else {} + set_sparse_attention_by_cfg(model, sparse_cfg, config) + + # Create metadata + metadata = {} + update_sparse_attention_metadata(model, config, metadata) + + return model, metadata + + +def replace_sparse_attention_modules(model: nn.Module, version=None): + """Replace regular attention modules with sparse attention modules. + + Recursively replace all attention modules in the model with their sparse attention counterparts. + + Args: + model: Model to process + version: State version for tracking (optional) + """ + # Recursively replace modules + _replace_sparse_attention_modules(model, version=version) + + # Count and report replaced modules + replaced_count = sum(isinstance(m, SparseAttentionModule) for _, m in model.named_modules()) + if replaced_count > 0: + print(f"Inserted {replaced_count} sparse attention modules") + + +def _replace_sparse_attention_modules(model: nn.Module, version=None): + """Helper function for replace_sparse_attention_modules.""" + for name, child in model.named_children(): + if type(child) in SparseAttentionRegistry: + # REPLACE on the parent (model), not on child + sparse_module = SparseAttentionRegistry.convert(child) + setattr(model, name, sparse_module) + + # Now recurse into whichever module is now at `model.name` + _replace_sparse_attention_modules(getattr(model, name), version=version) + + +def set_sparse_attention_by_cfg(model: nn.Module, sparse_cfg: dict, config: SparseAttentionConfig): + """Apply sparse attention configuration to model. + + Similar to quantization's set_quantizer_by_cfg. 
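+
+    The ``"default"`` entry, if present, is applied to every sparse attention module
+    first (as pattern ``"*"``); the remaining pattern-specific entries are applied
+    afterwards and override it.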
+ + Args: + model: Model with sparse attention modules + sparse_cfg: Sparse configuration dictionary + config: Global sparse attention configuration + """ + sparse_cfg = sparse_cfg.copy() + + # Apply default first if exists + if "default" in sparse_cfg: + set_sparse_attention_attribute(model, "*", sparse_cfg["default"], config) + sparse_cfg.pop("default") + + # Apply pattern-specific configs + for pattern, cfg in sparse_cfg.items(): + set_sparse_attention_attribute(model, pattern, cfg, config) + + +def set_sparse_attention_attribute( + model: nn.Module, + wildcard_or_filter: str | Callable, + attribute_cfg: dict[str, Any], + global_config: SparseAttentionConfig, +): + """Set sparse attention attributes for modules matching pattern. + + Similar to quantization's set_quantizer_attribute. + + Args: + model: Model to configure + wildcard_or_filter: Pattern to match module names + attribute_cfg: Attributes to apply + global_config: Global sparse attention configuration + """ + # Merge global config fields with pattern config + # Filter out model-level configs that shouldn't be passed to modules + module_cfg = {k: v for k, v in attribute_cfg.items() if k != "calibration"} + + full_cfg = { + "method": global_config.method, + "collect_stats": global_config.collect_stats, + **module_cfg, + } + + for name, module in model.named_modules(): + if not isinstance(module, SparseAttentionModule): + continue + + # Check pattern match + matched = False + if isinstance(wildcard_or_filter, str): + matched = fnmatch.fnmatch(name, wildcard_or_filter) + elif callable(wildcard_or_filter): + matched = wildcard_or_filter(name) + else: + continue + + if matched: + # Apply config using the same method as TensorQuantizer + module.set_from_attribute_config(full_cfg) + + +def restore_sparse_attention_model( + model: ModelLikeModule, config: SparseAttentionConfig, metadata: MetadataDict +) -> nn.Module: + """Restore sparse attention model from saved state. + + Args: + model: Model to restore + config: Sparse attention configuration + metadata: Saved metadata + + Returns: + Restored model + """ + # Convert to sparse attention model + model, _ = convert_to_sparse_attention_model(model, config) + + # Restore sparse attention state from metadata + if "sparse_attention_state" in metadata: + restore_sparse_attention_state(model, metadata["sparse_attention_state"]) + + return model + + +def restore_sparse_attention_state(model: nn.Module, state_dict: dict[str, Any]): + """Restore sparse attention state from state dict. 
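+
+    For each module, the saved method name and configuration attributes are
+    re-applied and the module is re-initialized via ``_setup()``; calibrated
+    attributes such as ``threshold_scale_factor`` and ``target_sparsity`` are then
+    restored on the method instance.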
+ + Args: + model: Model with sparse attention modules + state_dict: Saved state dictionary + """ + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule): + module_name = get_unwrapped_name(name, model) + if module_name in state_dict: + module_state = state_dict[module_name] + + # Restore method and config + if "method" in module_state: + module._method = module_state["method"] + if "method_config" in module_state: + # Restore config attributes + # Separate method instance attributes from module attributes + method_instance_attrs = {"threshold_scale_factor", "target_sparsity"} + + for key, val in module_state["method_config"].items(): + if key not in method_instance_attrs: + # Set on module + setattr(module, f"_{key}", val) + + # Re-setup with restored config + module._setup() + + # Restore method instance attributes after _setup + if "method_config" in module_state: + for key, val in module_state["method_config"].items(): + if key in {"threshold_scale_factor", "target_sparsity"}: + setattr(module._sparse_method_instance, key, val) + + +def update_sparse_attention_metadata( + model: nn.Module, config: SparseAttentionConfig, metadata: MetadataDict +) -> None: + """Update metadata with sparse attention state. + + Args: + model: Model with sparse attention + config: Configuration used + metadata: Metadata dict to update + """ + sparse_state = {} + + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule): + module_name = get_unwrapped_name(name, model) + + # Collect method config from module attributes + method_config = { + k[1:]: v + for k, v in module.__dict__.items() + if k.startswith("_") and k not in ("_method", "_enabled", "_sparse_method_instance") + } + + # Also collect calibration-related attributes from method instance + method_instance = module._sparse_method_instance + for attr in ["threshold_scale_factor", "target_sparsity"]: + if hasattr(method_instance, attr): + val = getattr(method_instance, attr) + if val is not None: + method_config[attr] = val + + module_state = { + "method": method_instance.name, + "method_config": method_config, + } + + sparse_state[module_name] = module_state + + metadata["sparse_attention_state"] = sparse_state + metadata["sparse_attention_config"] = ( + config.model_dump() if hasattr(config, "model_dump") else vars(config) + ) + + +def disable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable): + """Disable sparse attention for matching modules. + + Similar to mtq.disable_quantizer for API consistency. + + Args: + model: Model with sparse attention applied + wildcard_or_filter_func: Wildcard string or filter function to match module names. + For example: "*lm_head*", "*layer_0*", etc. + + Example: + >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn + >>> model = sparse_attn.sparsify(model, config) + >>> # Disable sparse attention for lm_head + >>> sparse_attn.disable_sparse_attention(model, "*lm_head*") + """ + for name, module in model.named_modules(): + if not isinstance(module, SparseAttentionModule): + continue + + matched = False + if isinstance(wildcard_or_filter_func, str): + matched = fnmatch.fnmatch(name, wildcard_or_filter_func) + elif callable(wildcard_or_filter_func): + matched = wildcard_or_filter_func(name) + + if matched: + module.disable() + + +def enable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable): + """Enable sparse attention for matching modules. 
+ + Similar to mtq.enable_quantizer for API consistency. + + Args: + model: Model with sparse attention applied + wildcard_or_filter_func: Wildcard string or filter function to match module names. + For example: "*attention*", "*attn*", etc. + + Example: + >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn + >>> model = sparse_attn.sparsify(model, config) + >>> # Re-enable sparse attention for all attention modules + >>> sparse_attn.enable_sparse_attention(model, "*attention*") + """ + for name, module in model.named_modules(): + if not isinstance(module, SparseAttentionModule): + continue + + matched = False + if isinstance(wildcard_or_filter_func, str): + matched = fnmatch.fnmatch(name, wildcard_or_filter_func) + elif callable(wildcard_or_filter_func): + matched = wildcard_or_filter_func(name) + + if matched: + module.enable() + + +def print_sparse_attention_summary(model: nn.Module): + """Print summary of sparse attention modules in the model. + + Similar to mtq.print_quant_summary for API consistency. + + Args: + model: Model with sparse attention applied + + Prints: + - Total sparse attention modules + - Enabled vs disabled count + - Method distribution + - Configuration summary by module + + Example: + >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn + >>> model = sparse_attn.sparsify(model, config) + >>> sparse_attn.print_sparse_attention_summary(model) + """ + sparse_modules = [] + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule): + sparse_modules.append((name, module)) + + if not sparse_modules: + print("No sparse attention modules found in model") + return + + enabled_count = sum(1 for _, m in sparse_modules if m.is_enabled) + disabled_count = len(sparse_modules) - enabled_count + + # Count methods + method_counts = {} + for _, module in sparse_modules: + method = getattr(module, "_method", "unknown") + method_counts[method] = method_counts.get(method, 0) + 1 + + print(f"\n{'=' * 70}") + print(f"{'Sparse Attention Summary':^70}") + print(f"{'=' * 70}") + print(f"Total sparse attention modules: {len(sparse_modules)}") + print(f" Enabled: {enabled_count}") + print(f" Disabled: {disabled_count}") + + if method_counts: + print("\nMethods:") + for method, count in sorted(method_counts.items()): + print(f" {method}: {count}") + + print(f"\n{'Module Details':^70}") + print(f"{'-' * 70}") + + for name, module in sparse_modules: + status = "✓" if module.is_enabled else "✗" + method = getattr(module, "_method", "unknown") + threshold = getattr(module, "_threshold", "N/A") + + # Format threshold nicely + if isinstance(threshold, dict): + threshold_str = str(threshold) + elif isinstance(threshold, float): + threshold_str = f"{threshold:.2e}" + else: + threshold_str = str(threshold) + + print(f"{status} {name}") + print(f" Method: {method}, Threshold: {threshold_str}") + + print(f"{'=' * 70}\n") diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py new file mode 100644 index 000000000..5120bd755 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sparse attention methods package.""" + +from .registry import SparseAttentionMethod, get_sparse_method, register_sparse_method + +__all__ = [ + "SparseAttentionMethod", + "get_sparse_method", + "register_sparse_method", +] + +# Import method implementations to trigger registration +from . import flash_softmax_skip diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py new file mode 100644 index 000000000..04b696d11 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py @@ -0,0 +1,289 @@ +"""Flash Attention-aware softmax skip method for sparse attention. + +This module implements block-wise sparsity that aligns with Flash Attention's +processing pattern for optimal performance. +""" + +import math + +import numpy as np +import torch + +from . import SparseAttentionMethod, register_sparse_method + + +@register_sparse_method("flash_softmax_skip") +class FlashSoftmaxSkipMethod(SparseAttentionMethod): + """Flash Attention-aware softmax skip sparse attention method. + + Implements row-level block-wise sparsity aligned with Flash Attention's + processing pattern for optimal performance and accuracy. + """ + + def __init__(self, method_config: dict | None = None): + """Initialize Flash softmax skip method. + + Args: + method_config: Configuration dict with threshold, br, bc, is_causal, etc. + """ + config = method_config or {} + + # Extract configuration + self.threshold_config = config.get("threshold", 1e-4) + self.br = config.get("br", 128) + self.bc = config.get("bc", 128) + self.enable_correction_factor = config.get("enable_correction_factor", True) + self.collect_stats = config.get("collect_stats", True) + self.phase = config.get("phase", None) + self.backend = config.get("backend", "pytorch") + self.is_causal = config.get("is_causal", True) + # Calibration mode: when True, prevent threshold updates to preserve calibrator's test threshold + self._calibration_mode = False + + # Initialize threshold + if isinstance(self.threshold_config, dict): + self.threshold = self.threshold_config.get( + "default", self.threshold_config.get("prefill", 1e-4) + ) + else: + self.threshold = self.threshold_config + + def _update_threshold(self, phase: str): + """Update threshold based on phase.""" + if isinstance(self.threshold_config, dict): + self.threshold = self.threshold_config.get( + phase, self.threshold_config.get("default", self.threshold) + ) + + def set_calibration_mode(self, enabled: bool): + """Set calibration mode to prevent _update_threshold from modifying the threshold.""" + self._calibration_mode = enabled + + def _infer_phase(self, attention_scores: torch.Tensor) -> str: + """Infer phase from attention scores shape.""" + return "decode" if attention_scores.shape[2] == 1 else "prefill" + + def _reshape_to_blocks( + self, tensor: torch.Tensor, br: int, bc: int + ) -> tuple[torch.Tensor, ...]: + """Reshape tensor into blocks for Flash Attention processing. 
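+
+        For example, a [batch, heads, 300, 300] score tensor with br=bc=128 is
+        padded to [batch, heads, 384, 384] (using the dtype minimum as pad value)
+        and viewed as [batch, heads, 3, 128, 3, 128], i.e. a 3x3 grid of 128x128
+        blocks.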
+ + Args: + tensor: Input tensor of shape [batch, heads, seq_q, seq_k] + br: Block row size + bc: Block column size + + Returns: + Tuple of (blocked_tensor, num_block_rows, num_block_cols, padded_seq_q, padded_seq_k) + """ + batch_size, num_heads, seq_q, seq_k = tensor.shape + + # Calculate padding needed + padded_seq_q = math.ceil(seq_q / br) * br + padded_seq_k = math.ceil(seq_k / bc) * bc + + # Pad tensor if necessary + if padded_seq_q != seq_q or padded_seq_k != seq_k: + pad_q = padded_seq_q - seq_q + pad_k = padded_seq_k - seq_k + # Use dtype min instead of -inf for numerical stability + pad_value = torch.finfo(tensor.dtype).min + tensor = torch.nn.functional.pad(tensor, (0, pad_k, 0, pad_q), value=pad_value) + + # Reshape to blocks + num_block_rows = padded_seq_q // br + num_block_cols = padded_seq_k // bc + + # Keep natural order for row-level processing: [batch, heads, block_rows, br, block_cols, bc] + blocked = tensor.view(batch_size, num_heads, num_block_rows, br, num_block_cols, bc) + + return blocked, num_block_rows, num_block_cols, padded_seq_q, padded_seq_k + + def calc_correction_factor_and_p( + self, attn_weights: torch.Tensor, phase: str + ) -> tuple[torch.Tensor, dict]: + """Calculate sparse mask and statistics for Flash Attention. + + Implements block-wise sparsity compatible with Flash Attention's online softmax: + 1. Reshape attention scores into 128x128 blocks + 2. Track block-wise maximum values (simulating Flash Attention's row processing) + 3. Compute cumulative maximum across blocks (for online normalization) + 4. Apply threshold: mask blocks where p = score - cummax < log(threshold) + 5. Calculate correction factor and sparsity statistics + + Args: + attn_weights: Pre-softmax attention scores [batch, heads, seq_q, seq_k] + phase: "prefill" (seq_q > 1) or "decode" (seq_q = 1) + + Returns: + element_mask: Boolean mask [batch, heads, seq_q, seq_k] + stats: Dict with sparsity, correction_factor, total_blocks, etc. 
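+
+        Note:
+            Scores are pre-softmax logits, so ``p = score - cummax`` compares each
+            element against the running block maximum in log space: a block is kept
+            iff some element satisfies ``exp(p) > threshold``, i.e.
+            ``p > log(threshold)``. With ``threshold=1e-3`` that means an element
+            within ``ln(1e-3) ≈ -6.9`` of the running maximum.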
+ """ + batch_size, num_heads, seq_q, seq_k = attn_weights.shape + + # Calculate threshold + threshold_scale_factor = getattr(self, "threshold_scale_factor", None) + if threshold_scale_factor: + # Use calibrated dynamic threshold: λ = scale_factor / length + log_threshold = np.log(threshold_scale_factor / seq_k) + else: + # Use static threshold from config + log_threshold = np.log(self.threshold) + + if phase == "prefill": + blocked_attn, num_block_rows, num_block_cols, padded_seq_q, padded_seq_k = ( + self._reshape_to_blocks(attn_weights, self.br, self.bc) + ) + + # Step 1: Compute maximum value in each block + # For each 128x128 block, find max across the 128 columns + # blocked_attn: [batch, heads, block_rows, br=128, block_cols, bc=128] + # block_max: [batch, heads, block_rows, br=128, block_cols] + block_max = blocked_attn.max(dim=-1)[0] + + # Step 2: Track cumulative maximum across blocks (left to right) + # This simulates Flash Attention's online softmax normalization + # block_max_cummax: [batch, heads, block_rows, br=128, block_cols] + block_max_cummax = block_max.cummax(dim=-1)[0] + + # Step 3: Calculate correction factor (how often max changes) + # Used by Flash Attention to adjust running sum when max increases + block_max_larger = torch.ones_like(block_max) + block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1] + correction_factor = float(torch.sum(block_max_larger) / torch.numel(block_max_larger)) + + # Step 4: Normalize attention scores by cumulative max + # p represents log-space difference: log(score) - log(cummax) + p = blocked_attn - block_max_cummax[..., None] + + # Step 5: Apply threshold and create block-level mask + # Keep blocks where at least one element exceeds log(threshold) + p_larger_than_thresh = p > log_threshold + # Reduce over bc (128 cols), then br (128 rows) to get block-level decision + # Result: [batch, heads, block_rows, block_cols] + block_mask = p_larger_than_thresh.any(dim=-1).any(dim=-2) + + # Step 6: Expand block mask back to element level + # All 128x128 elements in a block share the same mask value + # [batch, heads, block_rows, block_cols] -> [batch, heads, block_rows, br=128, block_cols, bc=128] + element_mask = block_mask.unsqueeze(-2).unsqueeze(-1).expand_as(blocked_attn) + + # Step 7: Reshape to original attention shape and remove padding + element_mask = element_mask.reshape(batch_size, num_heads, padded_seq_q, padded_seq_k) + element_mask = element_mask[:, :, :seq_q, :seq_k] + + # Step 8: Calculate sparsity statistics + # Count kept blocks (averaged across batch and heads) + kept_blocks = block_mask.sum().item() / (batch_size * num_heads) + + # Total valid blocks (lower triangle only for causal attention) + # Note: Causal mask pre-applied by attention module, so block_mask naturally + # has zeros in upper triangle. We only count lower triangle for denominator. 
+ total_blocks = ( + num_block_rows * (num_block_rows + 1) // 2 # Causal: N(N+1)/2 + if self.is_causal + else num_block_rows * num_block_cols # Non-causal: N*N + ) + sparsity = 1 - (kept_blocks / total_blocks) + else: # decode + blocked_attn, _, num_block_cols, _, padded_seq_k = self._reshape_to_blocks( + attn_weights, 1, self.bc + ) + + # Decode: Single query row attends to all past key blocks + # blocked_attn: [batch, heads, 1, 1, num_block_cols, bc=128] + + # Step 1: Find maximum in each key block + # block_max: [batch, heads, 1, 1, num_block_cols] + block_max = blocked_attn.max(dim=-1)[0] + + # Step 2: Track cumulative maximum across key blocks (left to right) + # Simulates Flash Attention's online softmax normalization + block_max_cummax = block_max.cummax(dim=-1)[0] + + # Step 3: Calculate correction factor + # Tracks how often the maximum increases (needed for Flash Attention rescaling) + block_max_larger = torch.ones_like(block_max) + block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1] + correction_factor = float(torch.sum(block_max_larger) / torch.numel(block_max_larger)) + + # Step 4: Normalize scores by cumulative max + # p = log(score) - log(cummax) in log-space + p = blocked_attn - block_max_cummax[..., None] + + # Step 5: Apply threshold and create block mask + # Keep blocks where at least one element exceeds threshold + p_larger_than_thresh = p > log_threshold + block_mask = p_larger_than_thresh.any(dim=-1, keepdim=False) + + # Step 6: Expand to element level and remove padding + element_mask = block_mask[..., None].expand_as(blocked_attn) + element_mask = element_mask.reshape(batch_size, num_heads, 1, padded_seq_k) + element_mask = element_mask[:, :, :seq_q, :seq_k] + + # Step 7: Calculate statistics + kept_blocks = block_mask.sum().item() / (batch_size * num_heads) + total_blocks = num_block_cols + sparsity = 1 - (kept_blocks / total_blocks) + + # Create stats dictionary + stats = { + "correction_factor": correction_factor if self.enable_correction_factor else 1.0, + "sparsity": sparsity, + "phase": phase, + "total_blocks": total_blocks, + "sparse_blocks": int(sparsity * total_blocks), + "sample_length": seq_k, + } + + return element_mask, stats + + def apply_sparsity( + self, + query: torch.Tensor | None = None, + key: torch.Tensor | None = None, + value: torch.Tensor | None = None, + attention_scores: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: + """Apply Flash Attention-aware block-wise sparsity. 
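+
+        Query, key, and value are returned unchanged; only ``attention_scores`` is
+        modified, with skipped positions filled with the dtype minimum so that they
+        vanish after softmax.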
+ + Args: + query: Query tensor (unused, for API compatibility) + key: Key tensor (unused, for API compatibility) + value: Value tensor (unused, for API compatibility) + attention_scores: Attention scores tensor with shape [batch, heads, seq_q, seq_k] + + Returns: + Tuple with potentially modified attention_scores + """ + # Attention scores must be provided for sparse attention + assert attention_scores is not None, "attention_scores must be provided for apply_sparsity" + + # Attention scores are always 4D: [batch, heads, seq_q, seq_k] + assert len(attention_scores.shape) == 4, ( + f"Expected 4D attention scores, got shape {attention_scores.shape}" + ) + + # Infer phase from tensor shape + phase = self._infer_phase(attention_scores) + + # Update threshold for the detected phase (skip during calibration) + if not self._calibration_mode: + self._update_threshold(phase) + + # Apply block-wise sparsity + sparse_mask, stats = self.calc_correction_factor_and_p(attention_scores, phase) + + # Store stats for module to collect (doesn't persist across calls) + self._last_stats = stats + + # Apply mask to create sparse scores + mask_value = torch.finfo(attention_scores.dtype).min + sparse_scores = attention_scores.masked_fill(~sparse_mask, mask_value) + + return query, key, value, sparse_scores + + @property + def name(self) -> str: + """Method identifier.""" + return "flash_softmax_skip" diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py new file mode 100644 index 000000000..081ad9e27 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Registry and base class for sparse attention methods.""" + +from abc import ABC, abstractmethod + +import torch + + +class SparseAttentionMethod(ABC): + """Base class for sparse attention methods.""" + + @abstractmethod + def apply_sparsity( + self, + query: torch.Tensor | None = None, + key: torch.Tensor | None = None, + value: torch.Tensor | None = None, + attention_scores: torch.Tensor | None = None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + """Apply sparsity to attention computation. + + Args: + query: Query tensor + key: Key tensor + value: Value tensor + attention_scores: Pre-computed attention scores + + Returns: + Tuple of (query, key, value, attention_scores) with sparsity applied + """ + + @property + @abstractmethod + def name(self) -> str: + """Method name identifier.""" + + +# Method Registry with versioning support +_SPARSE_ATTENTION_METHODS: dict[str, dict[str, type[SparseAttentionMethod]]] = {} + + +def register_sparse_method(name: str, version: str = "v1"): + """Decorator to register sparse attention methods with version support. 
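+
+    Methods are stored per name and version; registering the same name and version
+    again emits a ``RuntimeWarning`` and overrides the previous class.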
+ + Args: + name: Method name to register + version: Version string (default: "v1") + + Example: + @register_sparse_method("my_method", version="v3") + class MyMethodV3(SparseAttentionMethod): + ... + """ + + def decorator(cls: type[SparseAttentionMethod]): + if name not in _SPARSE_ATTENTION_METHODS: + _SPARSE_ATTENTION_METHODS[name] = {} + + if version in _SPARSE_ATTENTION_METHODS[name]: + import warnings + + warnings.warn( + f"Overriding existing sparse attention method: {name}@{version}", + RuntimeWarning, + stacklevel=2, + ) + + _SPARSE_ATTENTION_METHODS[name][version] = cls + return cls + + return decorator + + +def get_sparse_method(name: str, version: str | None = None) -> type[SparseAttentionMethod]: + """Get sparse attention method by name and optional version. + + Args: + name: Method name to retrieve + version: Optional version string. If None, uses latest version. + + Returns: + Method class + + Raises: + ValueError: If method name or version is not registered + + Example: + >>> get_sparse_method("flash_softmax_skip") # Latest version + >>> get_sparse_method("flash_softmax_skip", "v1") # Specific version + """ + if name not in _SPARSE_ATTENTION_METHODS: + available = list(_SPARSE_ATTENTION_METHODS.keys()) + raise ValueError(f"Unknown sparse attention method: {name}. Available: {available}") + + method_versions = _SPARSE_ATTENTION_METHODS[name] + + if not version: + version = sorted(method_versions.keys())[-1] + + if version not in method_versions: + available_versions = list(method_versions.keys()) + raise ValueError( + f"Unknown version {version} for method {name}. Available: {available_versions}" + ) + + return method_versions[version] diff --git a/modelopt/torch/sparsity/attention_sparsity/mode.py b/modelopt/torch/sparsity/attention_sparsity/mode.py new file mode 100644 index 000000000..f389509a5 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/mode.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sparse attention mode descriptor for ModelOpt.""" + +from modelopt.torch.opt.config import ModeloptBaseConfig +from modelopt.torch.opt.mode import ( + ConvertEntrypoint, + ModeDescriptor, + RestoreEntrypoint, + UpdateEntrypoint, + _ModeRegistryCls, +) + +from .config import SparseAttentionConfig +from .conversion import ( + convert_to_sparse_attention_model, + restore_sparse_attention_model, + update_sparse_attention_metadata, +) + +# Create registry for sparse attention modes +SparseAttentionModeRegistry = _ModeRegistryCls("sparse_attention") + + +@SparseAttentionModeRegistry.register_mode +class SparseAttentionModeDescriptor(ModeDescriptor): + """Mode descriptor for sparse attention optimization. + + This mode enables various sparse attention methods to reduce + computational complexity and memory usage in transformer models. 
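+
+    The mode is registered under the name ``"sparse_attention"`` and is typically
+    applied through ``sparsify()`` in ``model_sparsify``, which calls ``apply_mode``
+    with ``SparseAttentionModeRegistry``.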
+ """ + + @property + def name(self) -> str: + """Returns the value (str representation) of the mode.""" + return "sparse_attention" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Specifies the config class for the mode.""" + return SparseAttentionConfig + + @property + def next_prohibited_modes(self) -> set[str] | None: + """Modes that should not be applied after this mode.""" + # Can work with quantization but not with weight sparsity + return {"sparsity"} + + @property + def export_mode(self) -> str | None: + """The mode that corresponds to the export mode of this mode.""" + return "export_sparse_attention" + + @property + def convert(self) -> ConvertEntrypoint: + """The mode's entrypoint for converting a model.""" + return convert_to_sparse_attention_model + + @property + def restore(self) -> RestoreEntrypoint: + """The mode's entrypoint for restoring a model.""" + return restore_sparse_attention_model + + @property + def update_for_save(self) -> UpdateEntrypoint: + """The mode's entrypoint for updating the model's state before saving.""" + return update_sparse_attention_metadata + + @property + def update_for_new_mode(self) -> UpdateEntrypoint: + """The mode's entrypoint for updating the model's state before new mode.""" + return update_sparse_attention_metadata diff --git a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py new file mode 100644 index 000000000..908f3ad89 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py @@ -0,0 +1,197 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Main API functions for sparse attention optimization.""" + +from typing import Any + +import torch + +from modelopt.torch.opt.conversion import ModeloptStateManager, apply_mode +from modelopt.torch.opt.searcher import ForwardLoop + +from .calibration import calibrate_sparse_attention +from .config import SparseAttentionConfig +from .mode import SparseAttentionModeRegistry + +__all__ = [ + "calibrate", + "sparsify", +] + + +def sparsify( + model: torch.nn.Module, + config: dict[str, Any] | SparseAttentionConfig, + forward_loop: ForwardLoop | None = None, +) -> torch.nn.Module: + """Applies sparse attention optimization to the model in-place. + + This method performs replacement of attention modules with their sparse counterparts and + optionally performs calibration as specified by ``config``. + ``forward_loop`` is used to forward data through the model and gather statistics for calibration. + + Args: + model: A pytorch model + config: A dictionary or an instance of + :class:`SparseAttentionConfig ` + specifying the values for keys ``"sparse_cfg"``, ``"method"``, and optionally ``"calibration"``. + + The ``"sparse_cfg"`` key specifies the sparse attention configurations. 
+ The ``"method"`` key specifies the sparse attention method (e.g., "softmax_skip"). + The ``"calibration"`` key specifies calibration settings if automatic threshold tuning is desired. + + Sparse attention configurations is a dictionary mapping wildcards or filter functions + to its sparse attention attributes. The wildcards or filter functions are matched + against the module names. The sparse attention attributes include ``"threshold"``, + ``"enable"``, and method-specific parameters. + + An example ``config`` dictionary is given below: + + .. code-block::python + + config = { + "method": "softmax_skip", + "sparse_cfg": { + # Phase-aware thresholds with backend selection and calibration + "*attention*": { + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "backend": "pytorch", # Only pytorch backend supported + "enable": True, + "calibration": { # Optional: enables automatic threshold calibration + "target_sparse_ratio": 0.5, + "samples": 48, + "max_seqlen": 8192, + }, + }, + # Disable for specific layers + "*layer_0*": {"enable": False}, + # Default settings + "default": {"enable": False}, + }, + } + + The ``"backend"`` parameter must be set to ``"pytorch"``: + + - ``"pytorch"``: Softmax patching approach (only supported backend) + + This requires the model to be loaded with ``attn_implementation="eager"``. + + forward_loop: A callable that forwards all calibration data through the model. This is used + to gather statistics for calibration. It should take model as the argument. It does not need + to return anything. + + This argument is only required when calibration is enabled in the config. + + Here are a few examples for correct ``forward_loop`` definitions: + + Example 1: + + .. code-block:: + + def forward_loop(model) -> None: + # iterate over the data loader and forward data through the model + for batch in data_loader: + model(batch) + + Example 2: + + .. code-block:: + + def forward_loop(model) -> float: + # evaluate the model on the task + return evaluate(model, task, ....) + + .. note:: + + Calibration does not require forwarding the entire dataset through the model. + Please subsample the dataset or reduce the number of batches if needed. + + .. important:: + + The model must always be loaded with ``attn_implementation="eager"`` + for sparse attention to work correctly: + + .. code-block:: python + + from transformers import AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained( + model_path, + attn_implementation="eager", # Required for sparse attention + torch_dtype=torch.bfloat16, + ) + + This is because sparse attention works by patching torch.nn.functional.softmax, + which is only called in the eager attention implementation. + + Returns: + A pytorch model which has sparse attention applied and optionally calibrated. + """ + model = apply_mode( + model, mode=[("sparse_attention", config)], registry=SparseAttentionModeRegistry + ) + + # Calibrate the sparsity ratio of the attention modules + return calibrate(model, forward_loop=forward_loop) + + +def calibrate( + model: torch.nn.Module, + forward_loop: ForwardLoop | None = None, +) -> torch.nn.Module: + """Calibrates sparse attention thresholds based on target sparsity. + + This function performs calibration to find optimal thresholds that achieve + the target sparsity ratio specified in the sparse attention configuration. + + Args: + model: A pytorch model with sparse attention already applied + forward_loop: Optional callable that forwards calibration data through the model. 
+ It should take model as the argument and can optionally return metrics. + If None, will auto-generate RULER dataset for calibration. + + Returns: + The calibrated model with optimized sparse attention thresholds. + If no calibration is configured, returns the model unchanged. + """ + # Get the sparse attention config from the model's state + if not ModeloptStateManager.is_converted(model): + return model + + manager = ModeloptStateManager(model) + + sparse_attn_config = next( + (state["config"] for name, state in manager._state if name == "sparse_attention"), None + ) + + if sparse_attn_config is None: + return model + + # Check if calibration is configured in any sparse_cfg pattern + # Note: sparse_attn_config is always a dict (stored via config.model_dump()) + sparse_cfg = sparse_attn_config.get("sparse_cfg", {}) + + has_calibration = any( + isinstance(cfg, dict) and "calibration" in cfg for cfg in sparse_cfg.values() + ) + + if not has_calibration: + return model + + # Run calibration (handles stats collection internally) + calibrate_sparse_attention(model, sparse_attn_config, forward_loop=forward_loop) + + return model diff --git a/modelopt/torch/sparsity/attention_sparsity/nn/__init__.py b/modelopt/torch/sparsity/attention_sparsity/nn/__init__.py new file mode 100644 index 000000000..8a0d5ee71 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/nn/__init__.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Neural network modules for sparse attention.""" + +from .sparse_attention import SparseAttentionModule, SparseAttentionRegistry +from .stats_manager import SparseAttentionStatsManager + +__all__ = [ + "SparseAttentionModule", + "SparseAttentionRegistry", + "SparseAttentionStatsManager", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py new file mode 100644 index 000000000..a45931224 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py @@ -0,0 +1,205 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Extensible sparse attention module.""" + +import torch +import torch.nn.functional as F + +from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls +from modelopt.torch.quantization.utils import replace_function + +from ..config import SparseAttentionAttributeConfig +from ..methods import get_sparse_method +from .stats_manager import SparseAttentionStatsManager + + +class SparseAttentionModule(DynamicModule): + """Generic sparse attention module wrapper for applying sparsity to attention layers. + + This module wraps existing attention implementations to add sparse attention + capabilities by patching torch.nn.functional.softmax. + + Forward Flow: + ------------- + 1. Check if sparse attention is enabled (pass-through if disabled) + 2. Create softmax patch context with sparse_softmax function + 3. Apply sparse attention by patching F.softmax: + - Patches torch.nn.functional.softmax with sparse_softmax + - sparse_softmax applies method's sparsity logic before softmax + 4. Forward through original attention with sparsity applied + + Requirements: + ------------- + - Model must be loaded with attn_implementation="eager" for proper softmax interception + - Only PyTorch backend is supported (patches F.softmax) + + Attributes: + ----------- + _enabled: bool + Whether sparse attention is enabled + _method: str + The sparse attention method to use (e.g., "flash_softmax_skip") + _method_config: dict + Configuration dictionary for the sparse method (threshold, br, bc, etc.) + _sparse_method_instance: SparseAttentionMethod + Instance of the configured sparse attention method + """ + + def set_from_attribute_config( + self, attribute_cfg: SparseAttentionAttributeConfig | dict | None = None + ): + """Set sparse attention attributes from configuration. + + Similar to TensorQuantizer.set_from_attribute_config. + + Args: + attribute_cfg: Sparse attention attribute configuration. + If None, uses default SparseAttentionAttributeConfig. 
+ """ + # Use default config if not provided + attribute_cfg = ( + attribute_cfg if attribute_cfg is not None else SparseAttentionAttributeConfig() + ) + + # Store raw config for method initialization + self._method_config = {} + + # Define which attributes are method-specific vs module-specific + # Module-specific attributes control the SparseAttentionModule behavior + _module_attributes = {"enable", "method"} + + # Custom setters for special module attributes + _custom_setters = { + "enable": ("_enabled", lambda val: bool(val)), + "method": ("_method", lambda val: str(val)), + } + + # Process each attribute from config + for attribute, val in attribute_cfg.items(): + # Validate attribute if using config class + if hasattr(SparseAttentionAttributeConfig, "model_fields"): + assert attribute in SparseAttentionAttributeConfig.model_fields, ( + f"{attribute} is not a valid SparseAttentionModule attribute" + ) + + if attribute in _module_attributes: + # Module-level attribute: store with underscore prefix + attr_name, setter = _custom_setters.get(attribute, (f"_{attribute}", lambda v: v)) + setattr(self, attr_name, setter(val)) + else: + # Method-specific attribute: store in config dict + self._method_config[attribute] = val + + # Initialize sparse method instance + self._init_sparse_method() + + def _init_sparse_method(self): + """Initialize the sparse method instance.""" + method_class = get_sparse_method(self._method) + + # Initialize the sparse method instance + # _method_config is always initialized in set_from_attribute_config + self._sparse_method_instance = method_class(method_config=self._method_config) # type: ignore[call-arg] + + def enable(self): + """Enable sparse attention for this module.""" + self._enabled = True + + def disable(self): + """Disable sparse attention for this module.""" + self._enabled = False + + @property + def is_enabled(self) -> bool: + """Check if sparse attention is enabled.""" + return getattr(self, "_enabled", True) + + def get_stats(self) -> dict: + """Get sparsity statistics from the stats manager. + + Returns: + Dictionary with sparsity statistics including 'average_sparsity' if available. + Returns empty dict if stats manager is not enabled. + """ + if self._stats_manager is not None and self._stats_manager.enabled: + return self._stats_manager.get_summary() + return {} + + def _setup(self): + """Setup called by DynamicModule.""" + # Apply default configuration if not yet configured + if not hasattr(self, "_method"): + self.set_from_attribute_config(None) + + # Create stats manager if stats collection is enabled + if self._method_config.get("collect_stats", False): + self._stats_manager = SparseAttentionStatsManager( + module_name="sparse_attention", enabled=True + ) + else: + self._stats_manager = None + + def forward(self, *args, **kwargs): + """Forward with selected sparse attention method. + + This method dispatches to the appropriate sparse attention implementation + based on the configured method and backend. 
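> *For illustration, the wrappers can be toggled per module after conversion; the layer-name filter below is only an assumption about typical Hugging Face module naming.*

```python
from modelopt.torch.sparsity.attention_sparsity.nn import SparseAttentionModule

# `model` is assumed to have been converted already (e.g. via sparsify() above).
for name, module in model.named_modules():
    if isinstance(module, SparseAttentionModule):
        if ".layers.0." in name:  # illustrative: keep the first decoder layer dense
            module.disable()
        print(name, "enabled:", module.is_enabled, module.get_stats())
```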
+ """ + # Pass through if sparse attention is disabled + if not self.is_enabled: + return super().forward(*args, **kwargs) + + # Get the appropriate context manager for this configuration + context = self._get_sparse_context() + + # Apply sparse attention through the context + with context: + result = super().forward(*args, **kwargs) + + # Collect stats if manager is available + if self._stats_manager is not None and hasattr(self._sparse_method_instance, "_last_stats"): + self._stats_manager.collect(self._sparse_method_instance._last_stats) + + return result + + def _get_sparse_context(self): + """Get the softmax patch context for applying sparse attention.""" + return self._create_softmax_patch_context() + + def _create_softmax_patch_context(self): + """Create context manager for patching softmax function.""" + return replace_function(torch.nn.functional, "softmax", self._create_sparse_softmax()) + + def _create_sparse_softmax(self): + """Create sparse softmax function for current method.""" + original_softmax = F.softmax + + def sparse_softmax(input, dim=-1, *args, **kwargs): + # Let the method handle the sparsification + _, _, _, sparse_input = self._sparse_method_instance.apply_sparsity( + None, None, None, input + ) + + # Use sparse input if modified, otherwise use original + if sparse_input is not None: + return original_softmax(sparse_input, dim, *args, **kwargs) + return original_softmax(input, dim, *args, **kwargs) + + return sparse_softmax + + +# Create registry for sparse attention modules +SparseAttentionRegistry = _DMRegistryCls("SparseAttention", SparseAttentionModule) diff --git a/modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py b/modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py new file mode 100644 index 000000000..9862e6de4 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Statistics manager for sparse attention modules.""" + + +class SparseAttentionStatsManager: + """Centralized statistics manager for sparse attention. + + This class is the single source of truth for all statistics collection + in sparse attention modules. It handles both runtime aggregation and + per-sample calibration statistics. + + Design principles: + - Single responsibility: only stats management + - No computation: receives pre-computed stats from methods + - Optional: can be None if stats collection disabled + - Zero overhead when disabled + """ + + def __init__(self, module_name: str, enabled: bool = True): + """Initialize stats manager. + + Args: + module_name: Name of the module this manager is attached to + enabled: Whether stats collection is enabled + """ + self.module_name = module_name + self.enabled = enabled + self.calibration_mode = False + + # Aggregated stats (running totals across all forward passes) + self.aggregated_stats: dict = { + "total_calls": 0, + "total_blocks": 0, + "sparse_blocks": 0, + "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0}, + } + + # Per-sample stats (only populated during calibration) + self.per_sample_stats: list[dict] = [] + + def collect(self, stats: dict): + """Collect statistics from a single forward pass. + + Args: + stats: Dictionary containing statistics from method computation. 
+ Expected keys: sparsity, phase, total_blocks, sparse_blocks, + sample_length (optional) + """ + if not self.enabled: + return + + # Update aggregated stats + self.aggregated_stats["total_calls"] += 1 + self.aggregated_stats["total_blocks"] += stats.get("total_blocks", 0) + self.aggregated_stats["sparse_blocks"] += stats.get("sparse_blocks", 0) + + phase = stats.get("phase", "unknown") + if phase in self.aggregated_stats["phase_counts"]: + self.aggregated_stats["phase_counts"][phase] += 1 + + # In calibration mode, store per-sample stats + if self.calibration_mode: + self.per_sample_stats.append( + { + "module": self.module_name, + "sparsity": stats.get("sparsity", 0.0), + "sample_length": stats.get("sample_length", 0), + "phase": phase, + } + ) + + def get_summary(self) -> dict: + """Get aggregated statistics summary. + + Returns: + Dictionary with module name, total calls, average sparsity, + and phase distribution. + """ + total_blocks = self.aggregated_stats["total_blocks"] + if total_blocks > 0: + avg_sparsity = self.aggregated_stats["sparse_blocks"] / total_blocks + else: + avg_sparsity = 0.0 + + return { + "module": self.module_name, + "total_calls": self.aggregated_stats["total_calls"], + "average_sparsity": avg_sparsity, + "phase_distribution": self.aggregated_stats["phase_counts"].copy(), + } + + def set_calibration_mode(self, enabled: bool, reset_history: bool = True): + """Enable or disable calibration mode. + + In calibration mode, per-sample statistics are stored for detailed + analysis. Otherwise, only aggregated stats are maintained. + + Args: + enabled: Whether to enable calibration mode + reset_history: Whether to clear per_sample_stats when enabling + """ + self.calibration_mode = enabled + if enabled and reset_history: + self.per_sample_stats = [] + + def reset(self): + """Reset all statistics to initial state.""" + self.aggregated_stats = { + "total_calls": 0, + "total_blocks": 0, + "sparse_blocks": 0, + "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0}, + } + self.per_sample_stats = [] + + def get_calibration_stats(self) -> list[dict]: + """Get per-sample calibration statistics. + + Returns: + List of per-sample statistics dictionaries. + Empty list if not in calibration mode. + """ + return self.per_sample_stats diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py new file mode 100644 index 000000000..ba8c8b821 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
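> *A toy sketch of the stats manager contract described above: methods push pre-computed per-call statistics and the manager only aggregates them.*

```python
from modelopt.torch.sparsity.attention_sparsity.nn import SparseAttentionStatsManager

mgr = SparseAttentionStatsManager(module_name="model.layers.0.self_attn", enabled=True)
mgr.collect({"total_blocks": 64, "sparse_blocks": 16, "phase": "prefill", "sparsity": 0.25})
mgr.collect({"total_blocks": 64, "sparse_blocks": 48, "phase": "decode", "sparsity": 0.75})

summary = mgr.get_summary()
# -> {'module': 'model.layers.0.self_attn', 'total_calls': 2, 'average_sparsity': 0.5,
#     'phase_distribution': {'prefill': 1, 'decode': 1, 'unknown': 0}}
```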
+ +"""Plugins for sparse attention integration with various frameworks.""" + +from .huggingface import register_sparse_attention_on_the_fly + +__all__ = [ + "register_sparse_attention_on_the_fly", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py new file mode 100644 index 000000000..0012257b6 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dynamic sparse attention registration for HuggingFace models.""" + +import torch.nn as nn +import transformers + +from modelopt.torch.opt.dynamic import DynamicModule + +from ..nn.sparse_attention import SparseAttentionModule, SparseAttentionRegistry + + +class _GenericSparseAttention(SparseAttentionModule): + """Generic sparse attention that works with any HF attention module. + + This class provides a universal sparse attention wrapper that can + work with various transformer attention implementations. + """ + + def _setup(self): + """Setup sparse attention for any attention type. + + The base SparseAttentionModule handles detection and initialization. + """ + super()._setup() + + def get_attn_type(self, attn_module) -> type: + """Get the original attention type. + + Args: + attn_module: Attention module (possibly wrapped) + + Returns: + Original class type + """ + # If this is a DynamicModule, get the original class + if isinstance(attn_module, DynamicModule): + return attn_module.get_original_cls_by_level(level=0) + return type(attn_module) + + +def register_sparse_attention_on_the_fly(model: nn.Module) -> bool: + """Dynamically register sparse attention for any model. + + This function automatically detects attention modules in the model + and registers them for sparse attention optimization. 
+ + Args: + model: Model to process + + Returns: + True if any modules were registered + """ + if not _is_supported_model(model): + return False + + registered_count = 0 + attention_types = set() + + for name, module in model.named_modules(): + # Skip if already a sparse attention module + if isinstance(module, SparseAttentionModule): + continue + + # Check if this is an attention module by name + module_type = type(module) + type_name = module_type.__name__ + + # Common attention module patterns + is_attention = ( + "attention" in type_name.lower() + or type_name.endswith("Attention") + or type_name.endswith("SelfAttention") + ) + + if is_attention and module_type not in SparseAttentionRegistry: + # Register attention type + if module_type not in attention_types: + SparseAttentionRegistry.register({module_type: type_name})(_GenericSparseAttention) + attention_types.add(module_type) + registered_count += 1 + print(f"Registered {type_name} for sparse attention optimization") + + if registered_count > 0: + print(f"Dynamically registered {registered_count} attention module types for sparsity") + + return registered_count > 0 + + +def _is_supported_model(model: nn.Module) -> bool: + """Check if model is supported for sparse attention. + + Supports HuggingFace PreTrainedModel and any PyTorch model with attention modules. + + Args: + model: Model to check + + Returns: + True if model is supported + """ + # Check for HuggingFace PreTrainedModel + try: + if isinstance(model, transformers.PreTrainedModel): + return True + except ImportError: + pass + + # Support any PyTorch model with attention modules + return isinstance(model, nn.Module) diff --git a/modelopt/torch/sparsity/weight_sparsity/__init__.py b/modelopt/torch/sparsity/weight_sparsity/__init__.py new file mode 100644 index 000000000..3a3e4377f --- /dev/null +++ b/modelopt/torch/sparsity/weight_sparsity/__init__.py @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""API for weight sparsification algorithms.""" + +from . import mode, module, plugins + +# Explicitly expose commonly used items +from .mode import SparsityModeRegistry +from .module import SparseModule, SpDMRegistry +from .sparsification import * diff --git a/modelopt/torch/sparsity/weight_sparsity/config.py b/modelopt/torch/sparsity/weight_sparsity/config.py new file mode 100644 index 000000000..b88533885 --- /dev/null +++ b/modelopt/torch/sparsity/weight_sparsity/config.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
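> *Possible usage of the on-the-fly registration helper above. The checkpoint name is illustrative; registration only makes the detected attention classes known to `SparseAttentionRegistry` so they can be wrapped when the mode is applied.*

```python
from transformers import AutoModelForCausalLM

from modelopt.torch.sparsity.attention_sparsity.plugins import (
    register_sparse_attention_on_the_fly,
)

model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="eager")
if register_sparse_attention_on_the_fly(model):
    print("attention modules registered for sparse attention")
```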
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Default configurations for sparsity modes.""" + +from pydantic import create_model + +from modelopt.torch.opt.config import ModeloptBaseConfig, get_kwargs_for_create_model_with_rules + +from .module import SpDMRegistry + +SparseMagnitudeConfig = create_model( + "SparseMagnitudeConfig", + **get_kwargs_for_create_model_with_rules( + registry=SpDMRegistry, + default_rules={ + "nn.Linear": {"*": {}, "*lm_head*": None}, + "nn.Conv2d": {"*": {}, "*lm_head*": None}, + }, + doc='Configuration for the ``"sparse_magnitude"`` mode.', + ), +) + + +SparseGPTConfig = create_model( + "SparseGPTConfig", + **get_kwargs_for_create_model_with_rules( + registry=SpDMRegistry, + default_rules={ + "nn.Linear": {"*": {}, "*lm_head*": None}, + "nn.Conv2d": {"*": {}, "*lm_head*": None}, + }, + doc='Configuration for the ``"sparse_gpt"`` mode.', + ), +) + + +class ExportSparseConfig(ModeloptBaseConfig): + """Configuration (empty!) for the ``"export_sparse"`` mode.""" diff --git a/modelopt/torch/sparsity/weight_sparsity/magnitude.py b/modelopt/torch/sparsity/weight_sparsity/magnitude.py new file mode 100644 index 000000000..2896e7dd3 --- /dev/null +++ b/modelopt/torch/sparsity/weight_sparsity/magnitude.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Magnitude-base sparsity inspired by NVIDIA ASP (Automatic SParsity).""" + +import re +import warnings +from itertools import permutations + +import torch +import torch.nn as nn + +from .module import SparseModule +from .searcher import BaseSparseSearcher + + +def get_nmprune_info(pattern: str) -> tuple[bool, int, int]: + """Gets the n:m sparsity pattern information from a given string.""" + nm_prune = re.search(r"(\d+):(\d+) sparsity", pattern) + if nm_prune is not None: + n, m = map(int, nm_prune.groups()) + return nm_prune is not None, n, m + return False, 0, 0 + + +def fill(x): + """Calculates the ratio of non-zero elements in a tensor.""" + return float(x.nonzero().size(0)) / torch.numel(x) + + +def reshape_1d(matrix, m): + """Reshapes a given matrix into m-dimensional vectors: (h,w) -> (hw/m, m).""" + if matrix.shape[1] % m > 0: + new_cols = matrix.shape[1] + (m - matrix.shape[1] % m) + mat = matrix.new_empty(matrix.shape[0], new_cols).fill_(0) + mat[:, : matrix.shape[1]] = matrix + + return mat.view(-1, m), mat.shape + else: + return matrix.view(-1, m), matrix.shape + + +def compute_valid_1d_patterns(m, n): + """Computes all possible m:n patterns in a 1D vector. 
+ + The function generates a tensor of size m with n ones and (m-n) zeros. + It then generates all permutations of this tensor, removes duplicates, + and returns the unique patterns as a tensor. + """ + patterns = torch.zeros(m) + patterns[:n] = 1 + valid_patterns = torch.tensor(list(set(permutations(patterns.tolist())))) + return valid_patterns + + +def mn_1d_best(matrix, m, n): + """Finds the best m:n pattern in a given matrix. + + The function computes all possible m:n patterns and selects the one + that maximizes the sum of non-masked weights in the matrix. The selected + pattern is then used to create a mask for the matrix. + """ + patterns = compute_valid_1d_patterns(m, n).to(matrix.device) + + # Find the best m:n pattern (sum of non-masked weights). + mask = torch.IntTensor(matrix.shape).fill_(1).view(-1, m) + mat, _ = reshape_1d(matrix, m) + pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1) + mask[:] = patterns[pmax[:]] + mask = mask.view(matrix.shape) + return mask + + +def m4n2_1d(mat): + """Finds the best 2:4 pattern in a given matrix.""" + return mn_1d_best(mat, 4, 2) + + +def create_asp_mask(tensor: nn.Parameter, pattern: str) -> torch.BoolTensor: + """Creates a mask for a given tensor based on a specified sparse pattern. + + The function reshapes the tensor and applies the specified pattern to create a sparse mask. + The default pattern is m4n2_1d, which finds the best 2:4 sparsity pattern in the tensor. + """ + pattern_method_lut = {BaseSparseSearcher._pattern_2_4: m4n2_1d} + if pattern not in pattern_method_lut: + raise NotImplementedError(f"Unsupported pattern {pattern} for ASP sparsity") + func = pattern_method_lut[pattern] + + shape = tensor.shape + tensor.type() + t = tensor.float().contiguous() + + # 1d-tensor + if len(shape) == 1: + t = t.view(1, shape[0]) + mask = func(t) + # 2d-tensor (K, C) + elif len(shape) == 2: + # linear + t = t.view(shape[0], shape[1]) + mask = func(t) + # 3d-tensor (K, C, R) + elif len(shape) == 3: + # 1d convs + t = t.permute(0, 2, 1).contiguous().view(shape[0] * shape[2], shape[1]) + mask = func(t) + mask = mask.view(shape[0], shape[2], shape[1]).permute(0, 2, 1).contiguous() + # 4d-tensor (K, C, R, S) + elif len(shape) == 4: + # 2d convs + t = t.permute(2, 3, 0, 1).contiguous().view(shape[2] * shape[3] * shape[0], shape[1]) + mask = func(t) + mask = mask.view(shape[2], shape[3], shape[0], shape[1]).permute(2, 3, 0, 1).contiguous() + + return mask.view(shape).to(dtype=torch.bool) + + +class MagnitudeSearcher(BaseSparseSearcher): + """Searcher for magnitude-based sparsity.""" + + def _check_weight_size(self, weight: torch.nn.Parameter, mod_name: str) -> bool: + """Check if the weight size is supported.""" + # rules from ASP + if weight.size(0) % 8 != 0 or weight.size(1) % 16 != 0: + warnings.warn( + f"Skipping sparsifying {mod_name} of size={weight.size()!s} and" + f" type={weight.dtype!s} for sparsity" + ) + return False + + return True + + def _compute_mask(self, module: SparseModule) -> torch.BoolTensor: + """Compute the mask (and weight update) for the given module.""" + return create_asp_mask(module.weight, self.config["pattern"]) diff --git a/modelopt/torch/sparsity/weight_sparsity/mode.py b/modelopt/torch/sparsity/weight_sparsity/mode.py new file mode 100644 index 000000000..db7e1f332 --- /dev/null +++ b/modelopt/torch/sparsity/weight_sparsity/mode.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
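> *Quick check of the 2:4 magnitude helpers defined above, assuming they are importable from `weight_sparsity.magnitude`; the weight shape is arbitrary but divisible by the group size.*

```python
import torch

from modelopt.torch.sparsity.weight_sparsity.magnitude import fill, get_nmprune_info, m4n2_1d

ok, n, m = get_nmprune_info("2:4 sparsity")  # -> (True, 2, 4)

weight = torch.randn(8, 16)
mask = m4n2_1d(weight)  # keeps the 2 largest-magnitude weights of every 4
assert mask.view(-1, 4).sum(dim=1).eq(2).all()
print(fill(weight * mask))  # -> 0.5: half of the entries survive
```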
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sparsity mode descriptor.""" + +from torch import nn + +from modelopt.torch.opt.config import ModeloptBaseConfig +from modelopt.torch.opt.conversion import ApplyModeError +from modelopt.torch.opt.dynamic import DynamicSpace +from modelopt.torch.opt.mode import ( + ConvertEntrypoint, + ConvertReturnType, + MetadataDict, + ModeDescriptor, + RestoreEntrypoint, + UpdateEntrypoint, + _ModeRegistryCls, +) +from modelopt.torch.opt.searcher import BaseSearcher +from modelopt.torch.utils import compare_dict, unwrap_model + +from .config import ExportSparseConfig, SparseGPTConfig, SparseMagnitudeConfig +from .magnitude import MagnitudeSearcher +from .module import SpDMRegistry +from .sparsegpt import SparseGPTSearcher + +SparsityModeRegistry = _ModeRegistryCls("sparsity") + + +def convert_sparse_model(model: nn.Module, config: ModeloptBaseConfig) -> ConvertReturnType: + """Function for converting a model to a sparsity meta-model.""" + # we use the search space utility here with a custom registry to convert the model + dynamic_space = DynamicSpace(model) + dynamic_space.convert_to_dynamic(config.model_dump(), SpDMRegistry) + + return dynamic_space.model, {"subnet_config": DynamicSpace(model).config()} + + +def restore_sparse_model( + model: nn.Module, config: ModeloptBaseConfig, metadata: MetadataDict +) -> nn.Module: + """Function for restoring a previously convert model to a sparsity meta-model.""" + model, _ = convert_sparse_model(model, config) + + if "subnet_config" in metadata: + DynamicSpace(model).select(metadata["subnet_config"]) + + return model + + +def update_sparse_metadata( + model: nn.Module, config: ModeloptBaseConfig, metadata: MetadataDict +) -> None: + """Update subnet config to current subnet config of model.""" + metadata["subnet_config"] = DynamicSpace(model).config() + + +def export_sparse(model: nn.Module, config: ExportSparseConfig) -> ConvertReturnType: + """Export a sparse model to a regular model.""" + # sanity check to avoid DP/DDP here in the entrypoint + model = unwrap_model(model, raise_error=True) + + # store config from model if we can find it for a future convert/restore process + metadata = {"subnet_config": DynamicSpace(model).config()} + + # export model in-place + model = DynamicSpace(model).export(SpDMRegistry) + + return model, metadata + + +def restore_export_sparse( + model: nn.Module, config: ExportSparseConfig, metadata: MetadataDict +) -> nn.Module: + """Restore & export a sparse model to a regular model.""" + # select activated/deactivated sparse modules + DynamicSpace(model).select(metadata["subnet_config"]) + + # run export + model, metadata_new = export_sparse(model, config) + + # double check metadata + unmatched_keys = compare_dict(metadata, metadata_new) + if unmatched_keys: + raise ApplyModeError(f"Unmatched metadata={unmatched_keys}!") + + return model + + +@SparsityModeRegistry.register_mode +class SparseMagnitudeModeDescriptor(ModeDescriptor): + """Class to 
define and describe magnitude-based sparsification.""" + + @property + def name(self) -> str: + """Returns the name of the mode.""" + return "sparse_magnitude" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Specifies the config class for the mode.""" + return SparseMagnitudeConfig + + @property + def next_modes(self) -> set[str] | None: + """Specifies the next modes for the mode.""" + return {"export_sparse", "kd_loss", "quantize"} + + @property + def export_mode(self) -> str | None: + """The mode that corresponds to the export mode of this mode.""" + return "export_sparse" + + @property + def search_algorithm(self) -> type[BaseSearcher]: + """Specifies the search algorithm for the mode.""" + return MagnitudeSearcher + + @property + def convert(self) -> ConvertEntrypoint: + """The mode's entrypoint for converting a model.""" + return convert_sparse_model + + @property + def restore(self) -> RestoreEntrypoint: + """The mode's entrypoint for restoring a model.""" + return restore_sparse_model + + @property + def update_for_save(self) -> UpdateEntrypoint: + """The mode's entrypoint for updating the models metadata.""" + return update_sparse_metadata + + @property + def update_for_new_mode(self) -> UpdateEntrypoint: + """The mode's entrypoint for updating the models metadata.""" + return update_sparse_metadata + + +@SparsityModeRegistry.register_mode +class SparseGPTModeDescriptor(SparseMagnitudeModeDescriptor): + """Class to define and describe sparsification based on SparseGPT.""" + + @property + def name(self) -> str: + """Returns the name of the mode.""" + return "sparsegpt" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Specifies the config class for the mode.""" + return SparseGPTConfig + + @property + def search_algorithm(self) -> type[BaseSearcher]: + """Specifies the search algorithm for the mode.""" + return SparseGPTSearcher + + +@SparsityModeRegistry.register_mode +class ExportSparseModeDescriptor(ModeDescriptor): + """Class to describe the ``"export_sparse"`` mode. + + The properties of this mode can be inspected via the source code. + """ + + @property + def name(self) -> str: + """Returns the value (str representation) of the mode.""" + return "export_sparse" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Specifies the config class for the mode.""" + return ExportSparseConfig + + @property + def is_export_mode(self) -> bool: + """Specifies if this mode is an export mode.""" + return True + + @property + def convert(self) -> ConvertEntrypoint: + """The mode's entrypoint for converting a model.""" + return export_sparse + + @property + def restore(self) -> RestoreEntrypoint: + """The mode's entrypoint for restoring a model.""" + return restore_export_sparse diff --git a/modelopt/torch/sparsity/weight_sparsity/module.py b/modelopt/torch/sparsity/weight_sparsity/module.py new file mode 100644 index 000000000..65e04a184 --- /dev/null +++ b/modelopt/torch/sparsity/weight_sparsity/module.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dynamic class for all sparse modules.""" + +import torch +from torch import nn + +from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls +from modelopt.torch.opt.hparam import Hparam + +__all__ = ["SpDMRegistry", "SparseModule"] + + +SpDMRegistry = _DMRegistryCls(prefix="Sparse") # global instance for the sparsity registry + + +@SpDMRegistry.register({nn.Linear: "nn.Linear", nn.Conv2d: "nn.Conv2d"}) +class SparseModule(DynamicModule): + """Base dynamic class for all sparse modules.""" + + @staticmethod + def _get_weight(mod: "SparseModule", weight: torch.Tensor) -> torch.Tensor: + if mod.is_sparse and mod._weight_mask is not None: + masked_weight = weight * mod._weight_mask + # Quick workaround for the custom attribute for Megatron. + # TODO: maybe we need a more general way for customized attributes + if hasattr(weight, "main_grad"): + masked_weight.main_grad = weight.main_grad + return masked_weight + return weight + + def _setup(self): + # define hparam to check if sparsity is activated + hp = Hparam([0, -1], original=0) + hp.active = 0 + self._register_hparam("is_sparse", hp) + + # define the sparse mask here (don't pre-allocate memory to maximize memory savings) + self._register_temp_attribute("_weight_mask", None, lambda m, n, v: m.register_buffer(n, v)) + + # register dynamic attributes of the class + self._register_dynamic_attribute("weight", self._get_weight) + + def modify(self, *args, **kwargs): + """Initialize the sparsity mask when this is called. + + Note that for any module that is not frozen via ``None`` in the rules, this function will be + called. Hence, we use this function to initialize the sparsity mask only when necessary. + """ + hp = self.get_hparam("is_sparse") + if -1 in hp.choices and self._weight_mask is None: + hp.active = -1 + self._weight_mask = torch.ones_like(self.weight, dtype=torch.bool) + + def set_mask(self, value: torch.BoolTensor | None): + """Set the active sparse mask of the module weights.""" + if value is None: + self._weight_mask = None + return + + # sanity checks on the mask + w_shape = self.weight.shape + assert value is not None, "Mask cannot be None." + assert value.shape == w_shape, f"Mask must have shape {w_shape}, got {value.shape} instead." + assert value.dtype == torch.bool, f"Mask must be of type torch.bool, but got {value.dtype}." + + # assign mask + with torch.no_grad(): + if torch.all(value): + self._weight_mask = None + elif self._weight_mask is None: + self._weight_mask = value.detach().clone().to(self.weight.device) + else: + self._weight_mask.copy_(value.to(self._weight_mask.device)) diff --git a/modelopt/torch/sparsity/weight_sparsity/plugins/__init__.py b/modelopt/torch/sparsity/weight_sparsity/plugins/__init__.py new file mode 100644 index 000000000..94cc09e59 --- /dev/null +++ b/modelopt/torch/sparsity/weight_sparsity/plugins/__init__.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
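> *Illustrative only: manually assigning a 2:4 mask to a converted layer. In practice the searchers below compute the masks; `linear` stands for any `nn.Linear` that the sparsity mode has already converted into a `SparseModule`, and the shapes are assumed divisible by 4.*

```python
import torch

mask = torch.ones_like(linear.weight, dtype=torch.bool).view(-1, 4)
mask[:, :2] = False  # drop 2 of every 4 consecutive weights
linear.set_mask(mask.view(linear.weight.shape))

linear.set_mask(None)  # clears the mask again
```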
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Handles sparsity plugins for third-party modules. + +Currently, we support plugins for + +- :meth:`megatron` + +""" + +from modelopt.torch.utils import import_plugin + +with import_plugin("megatron"): + from .megatron import * diff --git a/modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py b/modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py new file mode 100644 index 000000000..c566db5c4 --- /dev/null +++ b/modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Support sparsify and save/resore for Megatron.""" + +import megatron.core.transformer.mlp as megatron_mlp +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint + +from modelopt.torch.opt.plugins.megatron import _MegatronMLP + +from ..config import SparseGPTConfig, SparseMagnitudeConfig +from ..module import SparseModule, SpDMRegistry + + +class _MegatronParallelLinear(SparseModule): + def _get_shard_axis_dict(self, state_dict): + raise NotImplementedError + + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets) + + sparse_state_dict = { + k: v + for k, v in self.state_dict(prefix="", keep_vars=True).items() + if k == "_weight_mask" + } + + sharded_axis_dict = self._get_shard_axis_dict(sparse_state_dict) + + if sparse_state_dict: + sharded_state_dict.update( + **make_sharded_tensors_for_checkpoint( + sparse_state_dict, prefix, sharded_axis_dict, sharded_offsets + ) + ) + + return sharded_state_dict + + +@SpDMRegistry.register( + {ColumnParallelLinear: "megatron.core.tensor_parallel.layers.ColumnParallelLinear"} +) +class _MegatronColumnParallelLinear(_MegatronParallelLinear): + def _get_shard_axis_dict(self, state_dict): + return {"_weight_mask": 0} + + +@SpDMRegistry.register( + {RowParallelLinear: "megatron.core.tensor_parallel.layers.RowParallelLinear"} +) +class _MegatronRowParallelLinear(_MegatronParallelLinear): + def _get_shard_axis_dict(self, state_dict): + return {"_weight_mask": 1} + + +@SpDMRegistry.register({megatron_mlp.MLP: "megatron.core.transformer.mlp.MLP"}) +class _SparseMegatronMLP(_MegatronMLP): + """Module to support special handling of `linear_fc1` in `sharded_state_dict()` of MCore `MLP`.""" + + _modelopt_state_keys = [r"\._weight_mask$"] + + +def _get_extra_rules(): + """Get the extra rules for megatron.""" + return { + "megatron.core.tensor_parallel.layers.ColumnParallelLinear": { + "*": {}, + "*output_layer*": None, + }, + "megatron.core.tensor_parallel.layers.RowParallelLinear": { + "*": {}, + "*output_layer*": None, + }, + "megatron.core.transformer.mlp.MLP": {}, + } + + +# Update the default rules +SparseMagnitudeConfig.register_default(_get_extra_rules()) +SparseGPTConfig.register_default(_get_extra_rules()) diff --git a/modelopt/torch/sparsity/weight_sparsity/searcher.py b/modelopt/torch/sparsity/weight_sparsity/searcher.py new file mode 100644 index 000000000..15281cd03 --- /dev/null +++ b/modelopt/torch/sparsity/weight_sparsity/searcher.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Searcher interface for sparsity algorithms.""" + +from abc import abstractmethod +from collections.abc import Iterator + +import torch +import torch.nn as nn + +from modelopt.torch.opt.searcher import BaseSearcher, SearchConfig, SearchStateDict +from modelopt.torch.utils import print_rank_0 + +from . import magnitude as asp +from .module import SparseModule + + +class BaseSparseSearcher(BaseSearcher): + """A generic sparse mask searching algorithm.""" + + _pattern_2_4 = "2:4 sparsity" + + @property + def default_search_config(self) -> SearchConfig: + """Get the default config for the searcher.""" + return {**super().default_search_config, "pattern": self._pattern_2_4} + + @property + def default_state_dict(self) -> SearchStateDict: + """Return default state dict.""" + return {} + + def sanitize_search_config(self, config: SearchConfig | None) -> SearchStateDict: + """Sanitize the search config dict.""" + config_sanitized = super().sanitize_search_config(config) + + # sanity check of sparsity format + is_nm_prune, n, m = asp.get_nmprune_info(config_sanitized["pattern"]) + assert is_nm_prune and n == 2 and m == 4, ( + f"Unsupported pattern {self.config['pattern']} for sparsity" + ) + + return config_sanitized + + @abstractmethod + def _check_weight_size(self, weight, mod_name) -> bool: + """Check if the weight size is supported by the algorithm.""" + raise NotImplementedError + + @abstractmethod + def _compute_mask(self, module: SparseModule) -> torch.BoolTensor: + """Compute the mask and update weight for a given module.""" + raise NotImplementedError + + def _named_sparsifiable_modules(self) -> Iterator[tuple[str, nn.Module]]: + """Get the named sparsifiable modules.""" + for name, module in self.model.named_modules(): + if ( + isinstance(module, SparseModule) + and module.is_sparse + and self._check_weight_size(module.weight, name) + ): + yield name, module + + def run_search(self): + """Search for sparse mask.""" + for name, module in self._named_sparsifiable_modules(): + # compute the mask (and potentially weight update inside compute_mask) + print_rank_0(f"Searching for sparse mask and weight update for module {name}.") + with torch.no_grad(): + module.set_mask(self._compute_mask(module)) diff --git a/modelopt/torch/sparsity/weight_sparsity/sparsegpt.py b/modelopt/torch/sparsity/weight_sparsity/sparsegpt.py new file mode 100644 index 000000000..3b8a25ca6 --- /dev/null +++ b/modelopt/torch/sparsity/weight_sparsity/sparsegpt.py @@ -0,0 +1,276 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
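> *A hypothetical sketch of the extension point provided by `BaseSparseSearcher`: a subclass only supplies a weight-size check and a mask computation, while `run_search()` handles iteration over sparsifiable modules. The random policy below is for experimentation only and is not part of this change.*

```python
import torch

from modelopt.torch.sparsity.weight_sparsity.module import SparseModule
from modelopt.torch.sparsity.weight_sparsity.searcher import BaseSparseSearcher


class RandomMaskSearcher(BaseSparseSearcher):
    """Keeps a random 2 of every 4 weights (illustrative only)."""

    def _check_weight_size(self, weight, mod_name) -> bool:
        return weight.dim() == 2 and weight.size(1) % 4 == 0

    def _compute_mask(self, module: SparseModule) -> torch.BoolTensor:
        w = module.weight
        scores = torch.rand_like(w, dtype=torch.float32).view(-1, 4)
        keep = torch.zeros_like(scores, dtype=torch.bool)
        keep.scatter_(1, scores.topk(2, dim=1).indices, True)
        return keep.view(w.shape)
```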
+ +"""Utility functions of SparseGPT.""" + +import math +import warnings + +import torch +import torch.nn as nn + +from modelopt.torch.opt.searcher import SearchConfig +from modelopt.torch.utils import print_rank_0 + +from .magnitude import get_nmprune_info +from .module import SparseModule +from .searcher import BaseSparseSearcher + + +def invert(hessian: torch.Tensor) -> torch.Tensor: + """Invert a Hessian matrix.""" + try: + hessian_inv = torch.linalg.cholesky(hessian) + hessian_inv = torch.cholesky_inverse(hessian_inv) + hessian_inv = torch.linalg.cholesky(hessian_inv, upper=True) + except RuntimeError: + cols = hessian.size(0) + eps = 1e-6 * torch.eye(cols).to(hessian.device) + hessian_inv = torch.cholesky_inverse(torch.linalg.cholesky(hessian + eps)) + + return hessian_inv + + +def prepare( + tensor: torch.Tensor, hessian: torch.Tensor, hessian_damp: float +) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare the inverse Hessian matrix.""" + weight = tensor.detach().clone() + # move the hessian matrix from CPU to GPU for acceleration + hessian = hessian.to(weight.device) + if len(weight.size()) == 4: + weight = weight.flatten(1) + + zero = torch.diag(hessian) == 0 + hessian[zero, zero] = 1 + weight[:, zero] = 0 + + damp = hessian_damp * torch.mean(torch.diag(hessian)) + cols = weight.size(1) + diag = torch.arange(cols) + hessian[diag, diag] += damp + + hessian_inv = invert(hessian) + + # remove the Hessian matrix to save GPU memory + del hessian + torch.cuda.empty_cache() + + return weight, hessian_inv + + +@torch.no_grad() +def create_sgpt_mask( + tensor: torch.Tensor, hessian: torch.Tensor, config: SearchConfig +) -> torch.Tensor: + """Create a sparse mask for the given tensor.""" + shape = tensor.size() + weight, hessian_inv = prepare(tensor, hessian, config["hessian_damp"]) + rows, cols = weight.size() + hessian_inv_diag = torch.diagonal(hessian_inv, dim1=0, dim2=1) + + is_nm_prune, n, m = get_nmprune_info(config["pattern"]) + col_bs = config["col_block_size"] + row_bs = config["row_block_size"] + # if row_bs is not specified, prune the whole weight block + if row_bs == -1: + row_bs = rows + + for r1 in range(0, rows, row_bs): + r2 = min(r1 + row_bs, rows) + # the mask of the weights not to be pruned + w_rows = weight[r1:r2].float() + + # pruning the weight block W[row:row+row_bs, i1:i1+col_bs] + for i1 in range(0, cols, col_bs): + i2 = min(i1 + col_bs, cols) + w_blk = w_rows[:, i1:i2].clone() + q_blk = torch.zeros_like(w_blk) + # the error of the weights to be pruned + delta_blk = torch.zeros_like(w_blk) + hinv_blk = hessian_inv[i1:i2, i1:i2] + hinv_diag_blk = hessian_inv_diag[i1:i2] + + errors_blk = (w_blk**2) / (hinv_diag_blk**2 + 1e-9) + if torch.isnan(errors_blk).any(): + print("nan in errors_blk.") + + mask_blk = torch.zeros_like(w_blk, dtype=torch.bool) + + for j in range(i2 - i1): + # compute the error of the weights to be pruned + w = w_blk[:, j] + d = hinv_diag_blk[j] + if is_nm_prune and j % m == 0: + errors_blk = (w_blk[:, j : j + m] ** 2) / (hinv_diag_blk[j : j + m] ** 2 + 1e-9) + mask_blk.scatter_( + 1, j + torch.topk(errors_blk, n, dim=1, largest=False)[1], True + ) + + q = w.clone() + q[mask_blk[:, j]] = 0 + q_blk[:, j] = q + + # update the remaining weights in the col_bs block to compensate the error caused by pruning W[:, j] + err = (w - q) / d + w_blk[:, j:] -= err.unsqueeze(1).matmul(hinv_blk[j, j:].unsqueeze(0)) + delta_blk[:, j] = err + + # compensate the error caused by pruning W[:, i: i + col_bs] with the weights update in W[:, i + col_bs:] + w_rows[:, 
i1:i2] = q_blk + w_rows[:, i2:] -= delta_blk.matmul(hessian_inv[i1:i2, i2:]) + if torch.isnan(w_rows[:, i2:]).any(): + print("nan") + + weight[r1:r2] = w_rows + + mask = weight != 0 + + return mask.view(shape) + + +class SparseGPTSearcher(BaseSparseSearcher): + """SparseGPT-based sparse mask searching algorithm.""" + + @property + def default_search_config(self) -> SearchConfig: + """Get the default config for the searcher.""" + return { + **super().default_search_config, + "col_block_size": 128, # column block size in sparsegpt + "row_block_size": -1, # row block size in sparsegpt + "hessian_damp": 0.1, # hessian damp in sparsegpt + "calib_size": 256, # calibration size for hessian matrix calculation + "device": "cuda", # device of hessian matrix + } + + def _check_weight_size(self, weight, mod_name) -> bool: + """Check if the weight size is supported by SparseGPT.""" + _, _, m = get_nmprune_info(self.config["pattern"]) + + # the column size must be divisible by m + if weight.size(0) % m != 0 or weight.size(1) % m != 0: + warnings.warn( + f"Skipping pruning {mod_name} of size={weight.size()!s} and" + f" type={weight.dtype!s} for SparseGPT" + ) + return False + + return True + + def _compute_mask(self, module: SparseModule) -> torch.BoolTensor: + """Compute the mask (and weight update) for the given module.""" + return create_sgpt_mask(module.weight, module.hessian, self.config) + + @torch.no_grad() + def before_search(self): + """Register the forward hook to collect the hessian matrix.""" + super().before_search() + + handles = [] + for _, module in self._named_sparsifiable_modules(): + # setup and register the forward hook + self._setup_forward_hook(module) + handles.append(module.register_forward_hook(self._hook_compute_hessian)) + + print_rank_0(f"Collecting Hessian statistics for {len(handles)} modules.") + + # run a forward loop to collect the hessian matrix + assert self.forward_loop is not None, "Please provide `data_loader` or `forward_loop`!" + self.forward_loop(self.model) + + # remove the forward hooks + for handle in handles: + handle.remove() + + def after_search(self): + """Remove Hessian artifacts from network.""" + super().after_search() + for _, module in self._named_sparsifiable_modules(): + del module.hessian + del module.samples + + @staticmethod + def _is_memory_sufficient(device_id, threshold): + """Check if the memory usage on the CUDA device is below the threshold.""" + total_memory = torch.cuda.get_device_properties(device_id).total_memory + allocated_memory = torch.cuda.memory_allocated(device_id) + free_memory = total_memory - allocated_memory + return free_memory / total_memory > (1 - threshold) + + @classmethod + def _setup_forward_hook(cls, mod: SparseModule) -> None: + """Setup the attributes we need for our forward hook during the SparseGPT search.""" + # initialize the hessian matrix + if isinstance(mod, nn.Conv2d): + # For conv2d layers, the hessian matrix is calculated as X * X^T, where X is the + # flattened weight matrix. + cols = mod.weight.size(1) * mod.weight.size(2) * mod.weight.size(3) + else: + # For linear layers, the hessian matrix is calculated as X * X^T, where X is the + # weight matrix. 
+ cols = mod.weight.size(1) + + target_device = mod.weight.device + # Hessian matrix is stored in the GPU memory by default + if target_device.type == "cuda" and cls._is_memory_sufficient(target_device.index, 0.8): + hessian = torch.zeros((cols, cols), dtype=torch.float32).to(target_device) + else: + hessian = torch.zeros((cols, cols), dtype=torch.float32).to("cpu") + + # store the hessian matrix and the number of samples + # TODO: this should probably be improved eventually!! + mod.hessian = hessian + mod.samples = 0 + + @classmethod + def _hook_compute_hessian(cls, mod: nn.Module, inp: torch.Tensor, out: torch.Tensor): + with torch.inference_mode(): + # TODO: move the hessian matrix to GPU memory if possible + if isinstance(inp, tuple): + inp = inp[0] + # use torch.float32 to avoid overflow in Hessian + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + # nn.Linear and *ParallelLinear in mcore + if "Linear" in type(mod).__name__: + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp.t_() + if isinstance(mod, nn.Conv2d): + unfold = nn.Unfold( + mod.kernel_size, + dilation=mod.dilation, + stride=mod.stride, + ) + inp = unfold(inp) + inp = inp.permute([1, 0, 2]) + inp = inp.flatten(1) + mod.hessian *= mod.samples / (mod.samples + tmp) + mod.samples += tmp + inp = math.sqrt(2 / mod.samples) * inp.float() + + # the hessian matrix is calculated as X * X^T + target_device = mod.hessian.device + if mod.hessian.device.type == "cuda": + if cls._is_memory_sufficient(mod.hessian.device.index, 0.8): + mod.hessian += inp.matmul(inp.t()).to(mod.hessian.device) + else: + target_device = "cpu" + mod.hessian = mod.hessian.to("cpu") + mod.hessian += inp.matmul(inp.t()).to(target_device) + + assert not torch.isinf(mod.hessian).any(), "Hessian contains inf" diff --git a/modelopt/torch/sparsity/weight_sparsity/sparsification.py b/modelopt/torch/sparsity/weight_sparsity/sparsification.py new file mode 100644 index 000000000..6af07db84 --- /dev/null +++ b/modelopt/torch/sparsity/weight_sparsity/sparsification.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""High-level API to automatically sparsify your model with various algorithms.""" + +from typing import Any + +from torch import nn + +from modelopt.torch.opt.conversion import apply_mode, get_mode +from modelopt.torch.opt.mode import ModeLike +from modelopt.torch.opt.searcher import BaseSearcher, SearchConfig +from modelopt.torch.utils import unwrap_model + +from .mode import SparsityModeRegistry + +__all__ = ["export", "sparsify"] + + +def sparsify( + model: nn.Module, mode: ModeLike, config: SearchConfig | None = None +) -> tuple[nn.Module, dict[str, Any]]: + """Sparsify a given model and search for they optimal sparsified weights. + + Args: + model: A standard model that contains standard building blocks to be sparsified in-place. 
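> *The hook above maintains the Hessian estimate `H = (2 / N) * X @ X.T` as a running average. A small self-check of that arithmetic, under the simplifying assumption that each call contributes `b` sample columns, could look like:*

```python
import torch

torch.manual_seed(0)
batches = [torch.randn(16, 4) for _ in range(3)]  # feature dim 16, 4 samples per call

hessian, n = torch.zeros(16, 16), 0
for x in batches:
    b = x.shape[1]
    hessian *= n / (n + b)       # re-weight the old estimate
    n += b
    xs = (2.0 / n) ** 0.5 * x    # same scaling as math.sqrt(2 / samples) * inp
    hessian += xs @ xs.t()

all_x = torch.cat(batches, dim=1)
assert torch.allclose(hessian, (2.0 / n) * all_x @ all_x.t(), atol=1e-5)
```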
+ mode: A (list of) string(s) or Mode(s) or a list of tuples containing the mode and its + config indicating the desired mode(s) (and configurations) for the convert + process. Modes set up the model for different algorithms for model optimization. The + following modes are available: + + * :class:`"sparse_magnitude"`: + The ``model`` will be sparsified according to the magnitude of weights in each + layer. The mode's config is described in + :class:`SparseMagnitudeConfig`. + * :class:`"sparsegpt"`: + The ``model`` will be sparsified and weights are updated optimally using an Hessian + approximation of the loss function (see SparseGPT paper for details). The mode's + config is described in + :class:`SparseGPTConfig`. + + If the mode argument is specified as a dictionary, the keys should indicate the mode and + the values specify the per-mode configuration. If not provided, then default + configuration would be used. + + config: Additional optional arguments to configure the search. Currently, we support: + + * ``verbose``: Whether to print detailed search stats during search. + * ``forward_loop``: A ``Callable`` that takes a model as input and runs a forward loop + on it. It is recommended to choose the data loader used inside the forward loop + carefully to reduce the runtime. Cannot be provided at the same time as + ``data_loader`` and ``collect_func``. + * ``data_loader``: An iterator yielding batches of data for calibrating the + normalization layers in the model or compute gradient scores. It is recommended to use + the same data loader as for training but with significantly fewer iterations. Cannot + be provided at the same time as ``forward_loop``. + * ``collect_func``: A ``Callable`` that takes a batch of data from the data loader as + input and returns the input to ``model.forward()`` as described in + :meth:`run_forward_loop `. Cannot + be provided at the same time as ``forward_loop``. + + .. note:: + + Additional configuration options may be added by individual algorithms. Please + refer to the documentation of the individual algorithms for more information. + + Returns: A sparsified model + + .. note:: + + The given model is sparsified in-place. The returned model is thus a reference to the same + model instance as the input model. + """ + # apply sparsity to the model + model = apply_mode(model, mode, registry=SparsityModeRegistry) + + # retrieve searcher class + searcher_cls: type[BaseSearcher] = getattr(get_mode(model), "search_algorithm") + + # run search+sparsification algorithm + searcher = searcher_cls() + searcher.search(model, {}, (), config) + + # return the sparsified model + return model + + +def export(model: nn.Module) -> nn.Module: + """Export a sparse dynamic model to a regular model. + + This should be done after the model is fine-tuned and the weights are fixed. + + .. warning:: + + After the call to ``export()``, the sparsity mask will no longer be enforced. This means any + future weight updates would destroy the sparsity pattern. If you want to continue training, + call ``export()`` after training is finished. + """ + # unwrap a DP/DDP model + model = unwrap_model( + model, + warn=True, + msg=( + f"Unwrapping a {type(model).__name__} model for export! Note that the export is" + " in-place and the model wrapper should be re-created after export since the wrapper" + " might not support changing parameters after initialization." 
+ ), + ) + + # apply export mode and return model + return apply_mode(model, "export_sparse", registry=SparsityModeRegistry) diff --git a/tests/_test_utils/torch_sparsity/sparse_attention_common.py b/tests/_test_utils/torch_sparsity/sparse_attention_common.py new file mode 100644 index 000000000..6a848fe7d --- /dev/null +++ b/tests/_test_utils/torch_sparsity/sparse_attention_common.py @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common utilities for sparse attention testing.""" + +import torch +import torch.nn as nn + +import modelopt.torch.opt as mto +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule + + +# Test models for sparse attention +class SimpleAttentionModel(nn.Module): + """Simple attention model for testing.""" + + def __init__(self, hidden_size=256, num_heads=8): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.attention = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True + ) + self.fc = nn.Linear(hidden_size, hidden_size) + + def forward(self, x): + attn_output, _ = self.attention(x, x, x, need_weights=False) + return self.fc(attn_output) + + @classmethod + def get_input(cls, hidden_size=256, seq_len=10, batch_size=2): + """Get input tensor for testing.""" + return torch.randn(batch_size, seq_len, hidden_size) + + +class SimpleTransformerEncoderLayer(nn.Module): + """Simple TransformerEncoderLayer wrapper for testing.""" + + def __init__(self, d_model=128, nhead=4, dim_feedforward=256): + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + batch_first=True, + ) + + def forward(self, x): + return self.layer(x) + + @classmethod + def get_input(cls, d_model=128, seq_len=20, batch_size=2): + """Get input tensor for testing.""" + return torch.randn(batch_size, seq_len, d_model) + + +class SimpleTransformerEncoder(nn.Module): + """Simple TransformerEncoder wrapper for testing.""" + + def __init__(self, d_model=128, nhead=4, num_layers=2): + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.encoder = nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True), + num_layers=num_layers, + ) + + def forward(self, x): + return self.encoder(x) + + @classmethod + def get_input(cls, d_model=128, seq_len=10, batch_size=2): + """Get input tensor for testing.""" + return torch.randn(batch_size, seq_len, d_model) + + +# Test configurations +FLASH_SOFTMAX_SKIP_DEFAULT_CFG = { + "method": "flash_softmax_skip", + "sparse_cfg": {"*attn*": {"threshold": 1e-4, "br": 128, "bc": 128, "enable": True}}, +} + +FLASH_SOFTMAX_SKIP_PHASE_AWARE_CFG = { + "method": 
"flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "br": 128, + "bc": 128, + "enable": True, + } + }, +} + +FLASH_SOFTMAX_SKIP_STATS_CFG = { + "method": "flash_softmax_skip", + "collect_stats": True, + "sparse_cfg": { + "*attn*": { + "threshold": 1e-4, + "br": 128, + "bc": 128, + "collect_stats": True, + "enable": True, + } + }, +} + +FLASH_SOFTMAX_SKIP_CALIBRATION_CFG = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "br": 128, + "bc": 128, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 6, + "max_seqlen": 1024, + }, + } + }, +} + + +def get_test_configs(): + """Get test configurations for parameterized tests. + + Note: Calibration config excluded (requires GPU and real tokenizers). + """ + return [FLASH_SOFTMAX_SKIP_DEFAULT_CFG, FLASH_SOFTMAX_SKIP_PHASE_AWARE_CFG] + + +def sparsify_model_and_forward(model, config, calib_data): + """Apply sparse attention and run forward passes. + + Args: + model: Model to sparsify + config: Sparse attention configuration + calib_data: List of calibration data tensors + + Returns: + Sparsified model + """ + + def forward_loop(model): + for batch in calib_data: + model(batch) + + # Apply sparse attention + model = sparse_attn.sparsify(model, config, forward_loop=forward_loop) + + # Verify sparse attention modules were inserted + assert any(isinstance(m, SparseAttentionModule) for m in model.modules()), ( + "No sparse attention modules found" + ) + + # Test forward passes + model.eval() + with torch.no_grad(): + for batch in calib_data: + output = model(batch) + assert not torch.isnan(output).any(), "NaN in output" + assert output is not None, "Output is None" + + return model + + +def save_restore_test(model_cls, device, sparse_config): + """Test save and restore of sparse attention state. + + Args: + model_cls: Model class to test + device: Device to run on ('cpu' or 'cuda') + sparse_config: Sparse attention configuration + """ + # Create and sparsify reference model + model_sparse = model_cls().to(device) + calib_data = [model_sparse.get_input().to(device) for _ in range(2)] + + sparsify_model_and_forward(model_sparse, sparse_config, calib_data) + + # Save state + state_dict = mto.modelopt_state(model_sparse) + + # Restore to new model + model_restored = model_cls().to(device) + mto.restore_from_modelopt_state(model_restored, state_dict) + model_restored.load_state_dict(model_sparse.state_dict()) + + # Verify outputs match + test_input = calib_data[0] + model_sparse.eval() + model_restored.eval() + + with torch.no_grad(): + output_sparse = model_sparse(test_input) + output_restored = model_restored(test_input) + + assert torch.allclose(output_sparse, output_restored, atol=1e-6), ( + "Restored model output doesn't match original" + ) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_basic_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_basic_gpu.py new file mode 100644 index 000000000..68651c306 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_basic_gpu.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for attention sparsity module.""" + +import pytest +import torch +from _test_utils.torch_sparsity.sparse_attention_common import ( + FLASH_SOFTMAX_SKIP_DEFAULT_CFG, + SimpleAttentionModel, + SimpleTransformerEncoder, + SimpleTransformerEncoderLayer, + get_test_configs, + save_restore_test, + sparsify_model_and_forward, +) + +import modelopt.torch.sparsity.attention_sparsity as sparse_attn + +# Skip all tests if GPU is not available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + + +class TestAttentionSparsityGPU: + """GPU tests for attention sparsity.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Setup for each test.""" + self.device = torch.device("cuda") + torch.cuda.empty_cache() + + @pytest.mark.parametrize( + "model_cls", + [SimpleAttentionModel, SimpleTransformerEncoderLayer, SimpleTransformerEncoder], + ) + @pytest.mark.parametrize("config", get_test_configs()) + def test_gpu_forward(self, model_cls, config): + """Test sparse attention forward pass on GPU.""" + model = model_cls().to(self.device) + calib_data = [model.get_input().to(self.device) for _ in range(2)] + + sparsify_model_and_forward(model, config, calib_data) + + # Additional GPU-specific checks + for batch in calib_data: + with torch.no_grad(): + output = model(batch) + assert output.device.type == "cuda" + + @pytest.mark.parametrize( + "model_cls", + [SimpleAttentionModel, SimpleTransformerEncoderLayer, SimpleTransformerEncoder], + ) + def test_save_restore(self, model_cls): + """Test save and restore on GPU.""" + save_restore_test(model_cls, "cuda", FLASH_SOFTMAX_SKIP_DEFAULT_CFG) + + def test_memory_efficiency(self): + """Test that sparse attention can handle larger sequences.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).to(self.device) + calib_data = [model.get_input(d_model=256, seq_len=32).to(self.device) for _ in range(2)] + + sparse_model = sparsify_model_and_forward(model, FLASH_SOFTMAX_SKIP_DEFAULT_CFG, calib_data) + + # Test with longer sequence + x = model.get_input(d_model=256, seq_len=512).to(self.device) + + with torch.no_grad(): + output = sparse_model(x) + + assert not torch.isnan(output).any() + + def test_mixed_precision(self): + """Test sparse attention with mixed precision.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).to(self.device) + calib_data = [model.get_input(d_model=256).to(self.device) for _ in range(2)] + + sparse_model = sparsify_model_and_forward(model, FLASH_SOFTMAX_SKIP_DEFAULT_CFG, calib_data) + + # Test with fp16 + sparse_model = sparse_model.half() + x = model.get_input(d_model=256, seq_len=64).to(self.device).half() + + with torch.no_grad(): + output = sparse_model(x) + + assert output.dtype == torch.float16 + assert not torch.isnan(output).any() + assert not torch.isinf(output).any() + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_different_dtypes(self, dtype): + """Test sparse attention with different dtypes.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).to(self.device).to(dtype) + 
calib_data = [model.get_input(d_model=256).to(self.device).to(dtype) for _ in range(2)] + + sparse_model = sparsify_model_and_forward(model, FLASH_SOFTMAX_SKIP_DEFAULT_CFG, calib_data) + + # Test forward + x = model.get_input(d_model=256).to(self.device).to(dtype) + with torch.no_grad(): + output = sparse_model(x) + + assert output.dtype == dtype + assert not torch.isnan(output).any() + if dtype != torch.bfloat16: # bfloat16 can have inf + assert not torch.isinf(output).any() + + def test_backward_pass(self): + """Test that gradients flow correctly through sparse attention.""" + model = SimpleAttentionModel(hidden_size=128, num_heads=4).to(self.device) + model = sparse_attn.sparsify(model, FLASH_SOFTMAX_SKIP_DEFAULT_CFG) + + # Enable training mode + model.train() + + x = model.get_input(hidden_size=128, seq_len=32).to(self.device) + x.requires_grad = True + + # Forward + output = model(x) + loss = output.sum() + + # Backward + loss.backward() + + # Check gradients exist + assert x.grad is not None + assert not torch.isnan(x.grad).any() + + # Check model gradients + for name, param in model.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"No gradient for {name}" + + @pytest.mark.parametrize("seq_len", [1, 64, 256, 1024, 2048]) + def test_various_sequence_lengths(self, seq_len): + """Test sparse attention with various sequence lengths.""" + model = SimpleAttentionModel(hidden_size=128, num_heads=4).to(self.device) + model = sparse_attn.sparsify(model, FLASH_SOFTMAX_SKIP_DEFAULT_CFG) + + x = model.get_input(hidden_size=128, seq_len=seq_len, batch_size=1).to(self.device) + + model.eval() + with torch.no_grad(): + output = model(x) + + assert output.shape == (1, seq_len, 128) + assert not torch.isnan(output).any() + + @pytest.mark.parametrize("batch_size", [1, 2, 8, 16]) + def test_various_batch_sizes(self, batch_size): + """Test sparse attention with various batch sizes.""" + model = SimpleTransformerEncoderLayer(d_model=128, nhead=4).to(self.device) + model = sparse_attn.sparsify(model, FLASH_SOFTMAX_SKIP_DEFAULT_CFG) + + x = model.get_input(d_model=128, seq_len=64, batch_size=batch_size).to(self.device) + + model.eval() + with torch.no_grad(): + output = model(x) + + assert output.shape == (batch_size, 64, 128) + assert not torch.isnan(output).any() + + @pytest.mark.skip(reason="Performance benchmarking not needed for basic CI") + def test_performance_comparison(self): + """Compare performance of sparse vs dense attention.""" + # Skipped for basic CI - can be run manually for performance analysis + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py new file mode 100644 index 000000000..f32888ead --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py @@ -0,0 +1,388 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for sparse attention calibration.""" + +import pytest +import torch +from _test_utils.torch_sparsity.sparse_attention_common import SimpleTransformerEncoderLayer + +import modelopt.torch.opt as mto +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.calibration import RulerDatasetBuilder +from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule + +# Skip all tests if no GPU available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required") + + +class TestRulerDatasetBuilderGPU: + """Test RULER dataset generation with real tokenizers on GPU.""" + + def test_ruler_generation_with_real_tokenizer(self): + """Test RULER generation with GPT2 tokenizer.""" + builder = RulerDatasetBuilder( + samples=6, # Need at least 6 samples (1 per task) + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should generate 6 samples (1 per task) + assert len(dataset) == 6 + + # All samples should have valid structure + for sample in dataset: + assert "input" in sample + assert "length" in sample + assert sample["length"] > 0 + + def test_generated_length_accuracy(self): + """Test that generated token counts are accurate.""" + builder = RulerDatasetBuilder( + samples=3, + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Check that lengths are within reasonable range of target + for sample in dataset: + # RULER aims for 70-90% of target for context + assert 700 < sample["length"] < 1400 + + def test_multiple_subtasks(self): + """Test generation with multiple RULER subtasks.""" + builder = RulerDatasetBuilder( + samples=12, # Need at least 6 for 1 per task, use 12 for 2 per task + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Check task distribution (should have multiple tasks from RULER_TASKS) + tasks_found = {s["task"] for s in dataset} + assert len(tasks_found) >= 2 # At least 2 different tasks + + def test_large_context_lengths(self): + """Test with larger context lengths.""" + builder = RulerDatasetBuilder( + samples=24, # 4 lengths * 6 tasks = need 24 for 1 per (length, task) + max_seqlen=8192, # Generates: [8192, 4096, 2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + assert len(dataset) == 24 + + # Verify we have different lengths + lengths = [s["length"] for s in dataset] + # Should have variety of lengths across the bins + assert len(set(lengths)) > 1 # At least 2 different target lengths used + + +class TestCalibrationGPU: + """Test calibration with real models on GPU.""" + + @pytest.fixture + def simple_model(self): + """Create simple attention model for testing.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + return model + + def test_calibration_simple_model(self, simple_model): + """Test calibration with simple attention model.""" + model = simple_model + + config = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "br": 64, + "bc": 64, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + 
"max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + # Simple forward loop for calibration + pass + + # Apply sparse attention with calibration + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Verify sparse attention modules exist + sparse_modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + assert len(sparse_modules) > 0 + + # Verify calibration was applied + for module in sparse_modules: + method = module._sparse_method_instance + # Check if calibrated threshold scale factor is set + if hasattr(method, "threshold_scale_factor") and method.threshold_scale_factor: + assert method.threshold_scale_factor > 0 + + def test_calibration_pytorch_backend(self, simple_model): + """Test calibration with pytorch backend.""" + model = simple_model + + config = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Check backend is set correctly + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + method = module._sparse_method_instance + assert hasattr(method, "backend") + assert method.backend == "pytorch" + + def test_simplified_calibration(self, simple_model): + """Test simplified calibration (prefill phase only).""" + model = simple_model + + config = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Should complete without errors + assert sparse_model is not None + + def test_calibration_persistence(self, simple_model): + """Test save and restore of calibrated model.""" + model = simple_model + + config = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate model + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Save modelopt state + modelopt_state = mto.modelopt_state(sparse_model) + + # Create new model and restore + model_restored = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + + restored = mto.restore_from_modelopt_state(model_restored, modelopt_state) + + # Check that sparse attention is restored + has_sparse = any(isinstance(m, SparseAttentionModule) for m in restored.modules()) + assert has_sparse + + +class TestCalibrationEndToEnd: + """Integration tests with inference.""" + + @pytest.fixture + def simple_model_setup(self): + """Setup simple model.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + return model + + def test_calibrated_model_inference(self, simple_model_setup): + """Test inference with calibrated model.""" + model = simple_model_setup + + config = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate model + 
sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Test inference + test_input = SimpleTransformerEncoderLayer.get_input(d_model=256, seq_len=10).cuda() + + sparse_model.eval() + with torch.no_grad(): + output = sparse_model(test_input) + + # Check output is valid + assert output is not None + assert not torch.isnan(output).any() + + def test_calibrated_vs_fixed_threshold(self, simple_model_setup): + """Compare calibrated vs fixed threshold models.""" + # Config with calibration + config_calibrated = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + # Config with fixed threshold (no calibration) + config_fixed = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "enable": True, + } + }, + } + + def forward_loop(model): + pass + + # Test both can be created + model_calibrated = sparsify( + SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda(), + config_calibrated, + forward_loop=forward_loop, + ) + + model_fixed = sparsify( + SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda(), + config_fixed, + ) + + # Both should work for inference + test_input = SimpleTransformerEncoderLayer.get_input(d_model=256, seq_len=10).cuda() + + with torch.no_grad(): + output_calibrated = model_calibrated(test_input) + output_fixed = model_fixed(test_input) + + assert output_calibrated is not None + assert output_fixed is not None + + def test_memory_usage(self, simple_model_setup): + """Test that calibration doesn't cause memory issues.""" + model = simple_model_setup + + # Clear cache before test + torch.cuda.empty_cache() + initial_memory = torch.cuda.memory_allocated() + + config = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate + sparsify(model, config, forward_loop=forward_loop) + + # Check memory didn't explode + final_memory = torch.cuda.memory_allocated() + memory_increase = final_memory - initial_memory + + # Memory should be reasonable (not more than 2GB increase) + assert memory_increase < 2 * 1024**3 # 2GB + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_export_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_export_gpu.py new file mode 100644 index 000000000..155524770 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_export_gpu.py @@ -0,0 +1,335 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for sparse attention HF checkpoint export. 
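+
+For reference, the ``sparse_attention_config`` section that these tests expect in the exported
+config.json looks roughly as follows (an illustrative sketch pieced together from the assertions
+below; the target class name and numeric values are placeholders, not authoritative):
+
+    "sparse_attention_config": {
+        "producer": {"name": "modelopt", "version": "<modelopt version>"},
+        "config_groups": {
+            "group_0": {
+                "sparse_algo": "softmax_skip",
+                "threshold": 1e-4,  # or a {"prefill": ..., "decode": ...} dict; dropped after calibration
+                "targets": ["<attention class name>"],
+            }
+        },
+        "threshold_scale_factor": 0.001,  # present only for calibrated models
+        "target_sparsity": 0.5,  # present only for calibrated models
+    }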
+ +Note: These tests use HuggingFace models created with create_tiny_llama_dir() rather than +the simple test models from sparse_attention_common.py because export_hf_checkpoint() +requires HF-specific features (model.save_pretrained(), model.config, etc.). +""" + +import json + +import pytest +import torch +from _test_utils.torch_model.transformers_models import create_tiny_llama_dir +from _test_utils.torch_sparsity.sparse_attention_common import FLASH_SOFTMAX_SKIP_CALIBRATION_CFG +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.export import export_hf_checkpoint +from modelopt.torch.sparsity.attention_sparsity.config import SparseAttentionConfig + +# Skip all tests if GPU is not available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + + +@pytest.fixture(scope="module") +def tiny_llama_dir(tmp_path_factory): + """Create tiny Llama model directory for testing. + + Uses create_tiny_llama_dir() to create a minimal local HF model + without downloading, which is faster and doesn't require network access. + """ + tmp_path = tmp_path_factory.mktemp("models") + return create_tiny_llama_dir(tmp_path, with_tokenizer=True, num_hidden_layers=2) + + +@pytest.fixture(scope="module") +def tinyllama_model(tiny_llama_dir): + """Load tiny Llama model for testing.""" + model = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + attn_implementation="eager", + torch_dtype=torch.bfloat16, + device_map="cuda", + ) + return model + + +@pytest.fixture(scope="module") +def tinyllama_tokenizer(tiny_llama_dir): + """Load tiny Llama tokenizer for testing.""" + return AutoTokenizer.from_pretrained(tiny_llama_dir) + + +class TestSparseAttentionExport: + """Test sparse attention model export to HF unified checkpoint.""" + + def test_export_non_calibrated_model(self, tinyllama_model, tmp_path): + """Test export of non-calibrated sparse attention model.""" + model = tinyllama_model + + # Apply sparse attention without calibration + config = SparseAttentionConfig( + method="flash_softmax_skip", + sparse_cfg={ + "*attn*": { + "threshold": 1e-4, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Export to temporary directory + export_dir = tmp_path / "non_calibrated_export" + export_hf_checkpoint(sparse_model, export_dir=export_dir) + + # Verify config.json was created + config_path = export_dir / "config.json" + assert config_path.exists(), "config.json not found" + + # Load and verify sparse_attention_config + with open(config_path) as f: + exported_config = json.load(f) + + assert "sparse_attention_config" in exported_config, "sparse_attention_config not found" + sparse_config = exported_config["sparse_attention_config"] + + # Verify structure for non-calibrated model + assert "config_groups" in sparse_config + assert "producer" in sparse_config + assert sparse_config["producer"]["name"] == "modelopt" + + # Should NOT have global calibration parameters + assert "threshold_scale_factor" not in sparse_config + assert "target_sparsity" not in sparse_config + + # Verify config_groups has threshold + groups = sparse_config["config_groups"] + assert len(groups) > 0, "No groups found in sparse_attention_config" + + # Check first group + group_0 = groups["group_0"] + + assert "sparse_algo" in group_0 + assert group_0["sparse_algo"] == "softmax_skip" + assert "threshold" in group_0 + 
assert isinstance(group_0["threshold"], float) + assert group_0["threshold"] == 1e-4 + assert "targets" in group_0 + assert isinstance(group_0["targets"], list) + assert len(group_0["targets"]) > 0 + + def test_export_non_calibrated_phase_aware_model(self, tinyllama_model, tmp_path): + """Test export of non-calibrated model with phase-aware thresholds.""" + model = tinyllama_model + + # Apply sparse attention with phase-aware thresholds + config = SparseAttentionConfig( + method="flash_softmax_skip", + sparse_cfg={ + "*attn*": { + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Export to temporary directory + export_dir = tmp_path / "phase_aware_export" + export_hf_checkpoint(sparse_model, export_dir=export_dir) + + # Load and verify config + config_path = export_dir / "config.json" + with open(config_path) as f: + exported_config = json.load(f) + + sparse_config = exported_config["sparse_attention_config"] + groups = sparse_config["config_groups"] + + # Check first group has phase-aware threshold dict + group_0 = groups["group_0"] + + assert "threshold" in group_0 + threshold = group_0["threshold"] + assert isinstance(threshold, dict) + assert "prefill" in threshold + assert "decode" in threshold + assert threshold["prefill"] == 1e-3 + assert threshold["decode"] == 1e-5 + assert "targets" in group_0 + + def test_export_calibrated_model(self, tiny_llama_dir, tmp_path): + """Test export of calibrated sparse attention model.""" + # Load a fresh model instance for calibration + model = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + attn_implementation="eager", + torch_dtype=torch.bfloat16, + device_map="cuda", + ) + + # Apply sparse attention with calibration (use common config) + config = SparseAttentionConfig(**FLASH_SOFTMAX_SKIP_CALIBRATION_CFG) + + # Create a simple forward loop for calibration + def forward_loop(model): + """Simple forward loop for calibration.""" + device = next(model.parameters()).device + # Create a few samples of different lengths + for seq_len in [512, 768, 1024]: + input_ids = torch.randint(0, 100, (1, seq_len), device=device) + with torch.no_grad(): + model(input_ids) + + # Sparsify with calibration + sparse_model = sparse_attn.sparsify(model, config, forward_loop=forward_loop) + + # Export to temporary directory + export_dir = tmp_path / "calibrated_export" + export_hf_checkpoint(sparse_model, export_dir=export_dir) + + # Verify config.json + config_path = export_dir / "config.json" + assert config_path.exists() + + with open(config_path) as f: + exported_config = json.load(f) + + assert "sparse_attention_config" in exported_config + sparse_config = exported_config["sparse_attention_config"] + + # Verify structure for calibrated model + assert "config_groups" in sparse_config + assert "producer" in sparse_config + + # SHOULD have global calibration parameters + assert "threshold_scale_factor" in sparse_config + assert "target_sparsity" in sparse_config + + # Verify calibration values + assert isinstance(sparse_config["threshold_scale_factor"], float) + assert sparse_config["threshold_scale_factor"] > 0 + assert sparse_config["target_sparsity"] == 0.5 + + # Verify config_groups do NOT have threshold (calibrated) + groups = sparse_config["config_groups"] + assert len(groups) > 0 + + group_0 = groups["group_0"] + + assert "sparse_algo" in group_0 + assert group_0["sparse_algo"] == "softmax_skip" + # Should NOT have 
threshold field for calibrated model + assert "threshold" not in group_0 + assert "targets" in group_0 + + def test_export_model_without_sparse_attention(self, tiny_llama_dir, tmp_path): + """Test export of model without sparse attention.""" + # Load a fresh model instance (not the shared fixture) + model = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + attn_implementation="eager", + torch_dtype=torch.bfloat16, + device_map="cuda", + ) + + # Export without applying sparse attention + export_dir = tmp_path / "no_sparse_export" + export_hf_checkpoint(model, export_dir=export_dir) + + # Verify config.json + config_path = export_dir / "config.json" + assert config_path.exists() + + with open(config_path) as f: + exported_config = json.load(f) + + # Should NOT have sparse_attention_config + assert "sparse_attention_config" not in exported_config + + def test_export_disabled_sparse_attention(self, tinyllama_model, tmp_path): + """Test export of model with disabled sparse attention modules.""" + model = tinyllama_model + + # Apply sparse attention + config = SparseAttentionConfig( + method="flash_softmax_skip", + sparse_cfg={ + "*attn*": { + "threshold": 1e-4, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Disable all sparse attention modules + sparse_attn.disable_sparse_attention(sparse_model, "*") + + # Export to temporary directory + export_dir = tmp_path / "disabled_export" + export_hf_checkpoint(sparse_model, export_dir=export_dir) + + # Verify config.json + config_path = export_dir / "config.json" + with open(config_path) as f: + exported_config = json.load(f) + + # Should NOT have sparse_attention_config (all modules disabled) + assert "sparse_attention_config" not in exported_config + + def test_export_all_layers_have_same_config(self, tinyllama_model, tmp_path): + """Test that all layers in exported config have consistent sparse_algo.""" + model = tinyllama_model + + config = SparseAttentionConfig( + method="flash_softmax_skip", + sparse_cfg={ + "*attn*": { + "threshold": 1e-4, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Export + export_dir = tmp_path / "consistent_config_export" + export_hf_checkpoint(sparse_model, export_dir=export_dir) + + # Load config + config_path = export_dir / "config.json" + with open(config_path) as f: + exported_config = json.load(f) + + groups = exported_config["sparse_attention_config"]["config_groups"] + + # Verify there's only one config group (all layers have same config) + assert len(groups) == 1, f"Expected 1 group (all layers same config), got: {len(groups)}" + + # Verify the single group has correct algo + group_0 = groups["group_0"] + assert group_0["sparse_algo"] == "softmax_skip" + assert "targets" in group_0 diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py new file mode 100644 index 000000000..d2cae4675 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py @@ -0,0 +1,398 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Real-world testing with small Llama3/TinyLlama model.""" + +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.opt as mto +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig +from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule + +# Skip all tests if GPU is not available +# Note: These tests are slower due to model loading +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + + +@pytest.fixture(scope="module") +def tinyllama_model(): + """Load TinyLlama model for testing.""" + try: + model = AutoModelForCausalLM.from_pretrained( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + torch_dtype=torch.float16, + device_map="cuda", + ) + return model + except Exception as e: + pytest.skip(f"Could not load TinyLlama model: {e}") + + +@pytest.fixture(scope="module") +def tinyllama_tokenizer(): + """Load TinyLlama tokenizer for testing.""" + try: + tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") + return tokenizer + except Exception as e: + pytest.skip(f"Could not load TinyLlama tokenizer: {e}") + + +class TestLlama3Basic: + """Basic Llama3/TinyLlama sparse attention tests.""" + + def test_load_and_sparsify(self, tinyllama_model): + """Load TinyLlama and apply sparse attention.""" + model = tinyllama_model + + config = SparseAttentionConfig( + method="flash_softmax_skip", + sparse_cfg={ + "*attn*": { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Verify sparse attention modules were added + sparse_count = sum( + 1 for m in sparse_model.modules() if isinstance(m, SparseAttentionModule) + ) + assert sparse_count > 0, "No sparse attention modules found" + + # TinyLlama has 22 layers, so should have 22 attention modules + assert sparse_count > 10, f"Expected >10 sparse modules, got {sparse_count}" + + def test_forward_prefill(self, tinyllama_model, tinyllama_tokenizer): + """Forward pass with seq_len=64 (prefill).""" + model = tinyllama_model + tokenizer = tinyllama_tokenizer + + config = SparseAttentionConfig( + method="flash_softmax_skip", + sparse_cfg={ + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Create prefill input (seq_len > 1) + test_text = "Once upon a time in a land far away" + inputs = tokenizer(test_text, return_tensors="pt").to("cuda") + + # Forward pass + sparse_model.eval() + with torch.no_grad(): + outputs = sparse_model(**inputs) + + # Verify output + assert outputs.logits is not None + assert not torch.isnan(outputs.logits).any() + assert outputs.logits.shape[1] == inputs.input_ids.shape[1] # seq_len preserved + + def test_forward_decode(self, tinyllama_model): + """Forward pass with seq_len=1 (decode).""" + model = tinyllama_model + + config = SparseAttentionConfig( + 
method="flash_softmax_skip", + sparse_cfg={ + "*attn*": { + "threshold": 1e-5, # More conservative for decode + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Create decode input (seq_len = 1) + input_ids = torch.randint(0, 32000, (1, 1), device="cuda") + + # Forward pass + sparse_model.eval() + with torch.no_grad(): + outputs = sparse_model(input_ids) + + # Verify output + assert outputs.logits is not None + assert not torch.isnan(outputs.logits).any() + assert outputs.logits.shape == (1, 1, 32000) # batch=1, seq=1, vocab_size + + def test_gqa_attention(self, tinyllama_model): + """Verify GQA support (num_kv_heads < num_heads).""" + model = tinyllama_model + + # Check if model uses GQA + config = model.config + has_gqa = hasattr(config, "num_key_value_heads") and ( + config.num_key_value_heads < config.num_attention_heads + ) + + if not has_gqa: + pytest.skip("Model does not use GQA") + + # Apply sparse attention + sparse_config = SparseAttentionConfig( + method="flash_softmax_skip", + sparse_cfg={ + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, sparse_config) + + # Test forward pass with GQA + input_ids = torch.randint(0, 32000, (1, 32), device="cuda") + + sparse_model.eval() + with torch.no_grad(): + outputs = sparse_model(input_ids) + + assert outputs.logits is not None + assert not torch.isnan(outputs.logits).any() + + +class TestLlama3Calibration: + """Llama3/TinyLlama calibration tests.""" + + def test_calibration_with_ruler(self, tinyllama_model): + """Full calibration with RULER dataset.""" + model = tinyllama_model + + config = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + """Simple forward loop for calibration.""" + test_input = torch.randint(0, 32000, (1, 64), device="cuda") + with torch.no_grad(): + model(test_input) + + # Apply sparsification with calibration + sparse_model = sparse_attn.sparsify(model, config, forward_loop=forward_loop) + + # Verify sparse modules exist + sparse_count = sum( + 1 for m in sparse_model.modules() if isinstance(m, SparseAttentionModule) + ) + assert sparse_count > 0 + + def test_phase_aware_thresholds(self, tinyllama_model, tinyllama_tokenizer): + """Test prefill vs decode threshold differences.""" + model = tinyllama_model + tokenizer = tinyllama_tokenizer + + config = SparseAttentionConfig( + method="flash_softmax_skip", + sparse_cfg={ + "*attn*": { + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Test prefill phase + prefill_text = "Once upon a time" + prefill_inputs = tokenizer(prefill_text, return_tensors="pt").to("cuda") + + sparse_model.eval() + with torch.no_grad(): + prefill_output = sparse_model(**prefill_inputs) + + assert not torch.isnan(prefill_output.logits).any() + + # Test decode phase + decode_input = torch.randint(0, 32000, (1, 1), device="cuda") + + with torch.no_grad(): + decode_output = sparse_model(decode_input) + + assert not torch.isnan(decode_output.logits).any() + + def test_calibration_persistence(self, tinyllama_model): + """Save and restore calibrated model.""" + model = tinyllama_model + + config = { + "method": 
"flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + } + }, + } + + # Sparsify model + sparse_model = sparse_attn.sparsify(model, config) + + # Save modelopt state + modelopt_state = mto.modelopt_state(sparse_model) + + # Verify state is not empty + assert modelopt_state is not None + assert isinstance(modelopt_state, dict) + + +class TestLlama3Inference: + """Llama3/TinyLlama inference tests.""" + + def test_text_generation(self, tinyllama_model, tinyllama_tokenizer): + """Generate text with sparse attention.""" + model = tinyllama_model + tokenizer = tinyllama_tokenizer + + config = SparseAttentionConfig( + method="flash_softmax_skip", + sparse_cfg={ + "*attn*": { + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Generate text + prompt = "Once upon a time" + inputs = tokenizer(prompt, return_tensors="pt").to("cuda") + + sparse_model.eval() + with torch.no_grad(): + outputs = sparse_model.generate( + **inputs, + max_new_tokens=20, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + ) + + # Verify generation worked + assert outputs is not None + assert outputs.shape[1] > inputs.input_ids.shape[1] # Generated new tokens + + # Decode to verify it's valid text + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + assert len(generated_text) > len(prompt) + + def test_both_backends(self, tinyllama_model): + """Test PyTorch and Triton backends.""" + model_pytorch = tinyllama_model + + # Test PyTorch backend + config_pytorch = SparseAttentionConfig( + method="flash_softmax_skip", + sparse_cfg={ + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_pytorch = sparse_attn.sparsify(model_pytorch, config_pytorch) + + test_input = torch.randint(0, 32000, (1, 32), device="cuda") + + sparse_pytorch.eval() + with torch.no_grad(): + output_pytorch = sparse_pytorch(test_input) + + assert not torch.isnan(output_pytorch.logits).any() + + def test_sparsity_statistics(self, tinyllama_model): + """Collect and verify sparsity stats.""" + model = tinyllama_model + + config = SparseAttentionConfig( + method="flash_softmax_skip", + sparse_cfg={ + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "collect_stats": True, + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Run forward pass + test_input = torch.randint(0, 32000, (1, 64), device="cuda") + + sparse_model.eval() + with torch.no_grad(): + sparse_model(test_input) + + # Check if stats were collected + stats_collected = False + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + if hasattr(module, "_sparse_method_instance"): + method = module._sparse_method_instance + if hasattr(method, "stats") and method.stats: + stats_collected = True + # Verify stats have expected keys + assert "sparsity" in method.stats or "total_blocks" in method.stats + break + + # Stats collection may not always be enabled + if not stats_collected: + pytest.skip("Statistics collection not available") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/gpu/torch/sparsity/plugins/test_megatron_sparsity.py b/tests/gpu/torch/sparsity/weight_sparsity/plugins/test_megatron_sparsity.py similarity index 100% rename from tests/gpu/torch/sparsity/plugins/test_megatron_sparsity.py rename to 
tests/gpu/torch/sparsity/weight_sparsity/plugins/test_megatron_sparsity.py diff --git a/tests/gpu/torch/sparsity/test_sparse_fsdp.py b/tests/gpu/torch/sparsity/weight_sparsity/test_sparse_fsdp.py similarity index 100% rename from tests/gpu/torch/sparsity/test_sparse_fsdp.py rename to tests/gpu/torch/sparsity/weight_sparsity/test_sparse_fsdp.py diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_export_config.py b/tests/unit/torch/sparsity/attention_sparsity/test_export_config.py new file mode 100644 index 000000000..4bc2953e6 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_export_config.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for sparse attention export configuration extraction.""" + +from _test_utils.torch_sparsity.sparse_attention_common import ( + FLASH_SOFTMAX_SKIP_DEFAULT_CFG, + FLASH_SOFTMAX_SKIP_PHASE_AWARE_CFG, + SimpleTransformerEncoderLayer, + sparsify_model_and_forward, +) + +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.export.unified_export_hf import _get_sparse_attention_config + + +class TestSparseAttentionConfigExtraction: + """Test sparse attention config extraction for export.""" + + def test_extract_non_calibrated_config(self): + """Test extraction of non-calibrated sparse attention config.""" + model = SimpleTransformerEncoderLayer() + calib_data = [model.get_input() for _ in range(2)] + + # Apply sparse attention without calibration + sparse_model = sparsify_model_and_forward(model, FLASH_SOFTMAX_SKIP_DEFAULT_CFG, calib_data) + + # Extract config + extracted_config = _get_sparse_attention_config(sparse_model) + + # Verify structure + assert "config_groups" in extracted_config + assert "producer" in extracted_config + assert extracted_config["producer"]["name"] == "modelopt" + + # Should NOT have calibration parameters + assert "threshold_scale_factor" not in extracted_config + assert "target_sparsity" not in extracted_config + + # Check config_groups + groups = extracted_config["config_groups"] + assert len(groups) > 0 + + # Verify first group has sparse_algo and threshold + group_0 = groups["group_0"] + assert "sparse_algo" in group_0 + assert group_0["sparse_algo"] == "softmax_skip" + assert "threshold" in group_0 + assert group_0["threshold"] == 1e-4 + assert "targets" in group_0 + assert isinstance(group_0["targets"], list) + assert len(group_0["targets"]) > 0 + + def test_extract_phase_aware_config(self): + """Test extraction of phase-aware threshold config.""" + model = SimpleTransformerEncoderLayer() + calib_data = [model.get_input() for _ in range(2)] + + # Apply sparse attention with phase-aware thresholds + sparse_model = sparsify_model_and_forward( + model, FLASH_SOFTMAX_SKIP_PHASE_AWARE_CFG, calib_data + ) + + # Extract config + extracted_config = _get_sparse_attention_config(sparse_model) + + # Check 
config_groups have phase-aware threshold + groups = extracted_config["config_groups"] + group_0 = groups["group_0"] + assert "threshold" in group_0 + threshold = group_0["threshold"] + assert isinstance(threshold, dict) + assert "prefill" in threshold + assert "decode" in threshold + assert threshold["prefill"] == 1e-3 + assert threshold["decode"] == 1e-5 + + def test_extract_empty_config_no_sparse_modules(self): + """Test extraction returns empty dict when no sparse modules.""" + model = SimpleTransformerEncoderLayer() + + # Don't apply sparse attention + extracted_config = _get_sparse_attention_config(model) + + # Should return empty dict + assert extracted_config == {} + + def test_extract_empty_config_all_disabled(self): + """Test extraction returns empty dict when all modules disabled.""" + model = SimpleTransformerEncoderLayer() + calib_data = [model.get_input() for _ in range(2)] + + # Apply sparse attention + sparse_model = sparsify_model_and_forward(model, FLASH_SOFTMAX_SKIP_DEFAULT_CFG, calib_data) + + # Disable all modules + sparse_attn.disable_sparse_attention(sparse_model, "*") + + # Extract config + extracted_config = _get_sparse_attention_config(sparse_model) + + # Should return empty dict + assert extracted_config == {} + + def test_sparse_algo_name_extraction(self): + """Test that sparse_algo is correctly extracted from method name.""" + model = SimpleTransformerEncoderLayer() + calib_data = [model.get_input() for _ in range(2)] + + # Apply sparse attention + sparse_model = sparsify_model_and_forward(model, FLASH_SOFTMAX_SKIP_DEFAULT_CFG, calib_data) + + # Extract config + extracted_config = _get_sparse_attention_config(sparse_model) + + # Verify sparse_algo is "softmax_skip" (stripped "flash_" prefix) + groups = extracted_config["config_groups"] + for group_config in groups.values(): + assert group_config["sparse_algo"] == "softmax_skip" + + def test_mock_calibrated_config(self): + """Test extraction with mock calibrated parameters.""" + model = SimpleTransformerEncoderLayer() + calib_data = [model.get_input() for _ in range(2)] + + # Apply sparse attention + sparse_model = sparsify_model_and_forward(model, FLASH_SOFTMAX_SKIP_DEFAULT_CFG, calib_data) + + # Manually set calibration parameters on method instances + # (simulating what calibration does) + for module in sparse_model.modules(): + if hasattr(module, "_sparse_method_instance"): + module._sparse_method_instance.threshold_scale_factor = 0.00123456 + module._sparse_method_instance.target_sparsity = 0.5 + + # Extract config + extracted_config = _get_sparse_attention_config(sparse_model) + + # Verify global calibration parameters + assert "threshold_scale_factor" in extracted_config + assert "target_sparsity" in extracted_config + assert extracted_config["threshold_scale_factor"] == 0.00123456 + assert extracted_config["target_sparsity"] == 0.5 + + # Verify config_groups do NOT have threshold field (calibrated) + groups = extracted_config["config_groups"] + for group_config in groups.values(): + assert "sparse_algo" in group_config + assert "threshold" not in group_config + assert "targets" in group_config + + def test_producer_metadata(self): + """Test that producer metadata is correctly added.""" + model = SimpleTransformerEncoderLayer() + calib_data = [model.get_input() for _ in range(2)] + + # Apply sparse attention + sparse_model = sparsify_model_and_forward(model, FLASH_SOFTMAX_SKIP_DEFAULT_CFG, calib_data) + + # Extract config + extracted_config = _get_sparse_attention_config(sparse_model) + + # Verify producer 
metadata + assert "producer" in extracted_config + producer = extracted_config["producer"] + assert "name" in producer + assert "version" in producer + assert producer["name"] == "modelopt" + assert isinstance(producer["version"], str) + assert len(producer["version"]) > 0 + + def test_targets_in_config(self): + """Test that targets field contains module class names.""" + model = SimpleTransformerEncoderLayer() + calib_data = [model.get_input() for _ in range(2)] + + # Apply sparse attention + sparse_model = sparsify_model_and_forward(model, FLASH_SOFTMAX_SKIP_DEFAULT_CFG, calib_data) + + # Extract config + extracted_config = _get_sparse_attention_config(sparse_model) + + # Verify targets contain class names + groups = extracted_config["config_groups"] + for group_config in groups.values(): + assert "targets" in group_config + targets = group_config["targets"] + assert isinstance(targets, list) + assert len(targets) > 0 + # Should contain attention class names + assert all(isinstance(t, str) for t in targets) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py new file mode 100644 index 000000000..051ea20db --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py @@ -0,0 +1,450 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for sparse attention calibration.""" + +import numpy as np +import pytest +import torch.nn as nn +from _test_utils.torch_sparsity.sparse_attention_common import SimpleAttentionModel + +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.calibration import ( + DynamicThresholdCalibrator, + RulerDatasetBuilder, +) +from modelopt.torch.sparsity.attention_sparsity.calibration.dataset import _generate_target_lengths +from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule + + +class TestLengthGeneration: + """Test automatic target length generation.""" + + def test_generate_target_lengths_default(self): + """Test default 4 bins generation.""" + lengths = _generate_target_lengths(32768, num_length_bins=4) + assert lengths == [32768, 16384, 8192, 4096] + + def test_generate_target_lengths_stops_at_minimum(self): + """Test generation stops at minimum threshold.""" + lengths = _generate_target_lengths(2048, num_length_bins=4) + assert lengths == [2048, 1024] # Stops at 1024 + + def test_generate_target_lengths_fewer_bins(self): + """Test with fewer bins.""" + lengths = _generate_target_lengths(16384, num_length_bins=2) + assert lengths == [16384, 8192] + + def test_generate_target_lengths_more_bins(self): + """Test with more bins.""" + lengths = _generate_target_lengths(65536, num_length_bins=6) + assert lengths == [65536, 32768, 16384, 8192, 4096, 2048] + + def test_generate_target_lengths_exactly_minimum(self): + """Test when max_seqlen equals minimum.""" + lengths = _generate_target_lengths(1024, num_length_bins=4) + assert lengths == [1024] + + +class TestRulerDatasetBuilder: + """Test RULER dataset generation without requiring real tokenizers.""" + + def test_builder_initialization(self): + """Test that builder initializes correctly.""" + builder = RulerDatasetBuilder( + samples=12, + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + assert builder.total_samples == 12 + assert builder.max_seqlen == 2048 + assert builder.target_lengths == [2048, 1024] + assert builder.samples_per_length == [6, 6] # Evenly distributed + assert len(builder.subtasks) == 6 # All RULER_TASKS + assert builder.seed == 42 + + def test_builder_initialization_invalid_config(self): + """Test that builder raises error for invalid inputs.""" + # Test invalid samples + with pytest.raises(ValueError, match="samples must be positive"): + RulerDatasetBuilder( + samples=0, + max_seqlen=2048, + tokenizer_name_or_path="gpt2", + ) + + # Test max_seqlen below minimum + with pytest.raises(ValueError, match="max_seqlen must be >= 1024"): + RulerDatasetBuilder( + samples=4, + max_seqlen=512, # Below minimum + tokenizer_name_or_path="gpt2", + ) + + def test_dataset_generation_minimal(self): + """Test generating small dataset.""" + builder = RulerDatasetBuilder( + samples=12, # 6 tasks x 2 lengths = need 12 for 1 per task per length + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should generate 12 samples (6 tasks x 1 sample per task x 2 lengths) + assert len(dataset) == 12 + assert all(isinstance(sample, dict) for sample in dataset) + + def test_dataset_structure(self): + """Test that dataset has correct structure.""" + builder = RulerDatasetBuilder( + samples=6, # Need at least 6 (1 per task) + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, 
+ ) + + dataset = builder.build_calibration_dataset() + sample = dataset[0] + + # Check required fields + assert "input" in sample + assert "length" in sample + assert "task" in sample + assert "target_length" in sample + + # Check field types + assert isinstance(sample["input"], str) + assert isinstance(sample["length"], int) + assert isinstance(sample["task"], str) + assert sample["length"] > 0 + + def test_sample_distribution(self): + """Test that samples are distributed across lengths and subtasks.""" + builder = RulerDatasetBuilder( + samples=24, # 6 tasks x 2 lengths x 2 samples = 24 + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should have 24 samples (12 per length, 2 per task) + assert len(dataset) == 24 + + # Check task distribution (should have variety from all RULER_TASKS) + tasks = [s["task"] for s in dataset] + # Verify we have all 6 tasks represented + assert len(set(tasks)) == 6 + + def test_length_targeting(self): + """Test that generated lengths are close to targets.""" + builder = RulerDatasetBuilder( + samples=6, # 1 per task + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Lengths should be within reasonable range of target + # RULER aims for 70-90% of target length for context + for sample in dataset: + assert 700 < sample["length"] < 1400 # Reasonable range around 1024 + + def test_uneven_sample_distribution(self): + """Test that samples are distributed evenly (remainder dropped).""" + builder = RulerDatasetBuilder( + samples=50, # 50 samples across 4 lengths + max_seqlen=8192, # Generates: [8192, 4096, 2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + # Even distribution: 50//4 = 12 per length + assert builder.total_samples == 50 + assert builder.target_lengths == [8192, 4096, 2048, 1024] + assert builder.samples_per_length == [12, 12, 12, 12] + assert sum(builder.samples_per_length) == 48 # 2 samples dropped (remainder) + + # Actual generated samples: 12//6=2 per task, 4 lengths, 6 tasks + # Total: 2 x 6 x 4 = 48 + dataset = builder.build_calibration_dataset() + assert len(dataset) == 48 + + +class TestDynamicThresholdCalibrator: + """Test calibration algorithm correctness.""" + + def test_calibrator_initialization(self): + """Test that calibrator initializes correctly.""" + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=0.5, + threshold_trials=[1e-4, 1e-3, 1e-2], + ) + + assert calibrator.target_sparse_ratio == 0.5 + assert len(calibrator.threshold_trials) == 3 + + def test_calibrator_default_threshold_trials(self): + """Test that calibrator has default threshold trials.""" + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=0.5, + ) + + # Should have default threshold trials + assert calibrator.threshold_trials is not None + assert len(calibrator.threshold_trials) == 12 + # Check they are positive and in valid range + trials = calibrator.threshold_trials + assert all(0 < t < 1 for t in trials) + + def test_regression_calculation_synthetic(self): + """Test 'a' parameter calculation with synthetic data.""" + # Create synthetic optimal pairs + # If threshold = a / length, then with perfect data: + # length=1000, threshold=10 => a=10000 + # length=2000, threshold=5 => a=10000 + optimal_pairs = [ + {"length": 1000, "optimal_threshold": 10.0, "achieved_sparsity": 0.5}, + {"length": 2000, "optimal_threshold": 5.0, "achieved_sparsity": 
0.5}, + {"length": 4000, "optimal_threshold": 2.5, "achieved_sparsity": 0.5}, + ] + + # Manual regression calculation + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + x = 1.0 / lengths + y = thresholds + + # Calculate 'a' using least squares + a_parameter = np.sum(x * y) / np.sum(x**2) + + # Should be close to 10000 + assert 9500 < a_parameter < 10500 + + # Test individual 'a' values + a_per_sample = y * lengths + assert np.allclose(a_per_sample, 10000, rtol=0.05) + + def test_multiple_samples_different_lengths(self): + """Test regression with varied lengths.""" + # More realistic scenario with some variance + optimal_pairs = [ + {"length": 500, "optimal_threshold": 20.0, "achieved_sparsity": 0.5}, + {"length": 1000, "optimal_threshold": 10.5, "achieved_sparsity": 0.51}, + {"length": 2000, "optimal_threshold": 5.2, "achieved_sparsity": 0.49}, + {"length": 4000, "optimal_threshold": 2.4, "achieved_sparsity": 0.50}, + ] + + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + x = 1.0 / lengths + y = thresholds + + a_parameter = np.sum(x * y) / np.sum(x**2) + + # Should still be around 10000 with some tolerance for variance + assert 9000 < a_parameter < 11000 + + def test_r_squared_calculation(self): + """Test R-squared calculation for regression quality.""" + # Perfect fit data + optimal_pairs = [ + {"length": 1000, "optimal_threshold": 10.0}, + {"length": 2000, "optimal_threshold": 5.0}, + {"length": 4000, "optimal_threshold": 2.5}, + ] + + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + x = 1.0 / lengths + y = thresholds + + a_parameter = np.sum(x * y) / np.sum(x**2) + + # Calculate R-squared + y_pred = a_parameter * x + ss_res = np.sum((y - y_pred) ** 2) + ss_tot = np.sum((y - np.mean(y)) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 + + # Perfect fit should have R^2 close to 1 + assert r_squared > 0.99 + + +class TestCalibrationIntegration: + """Test end-to-end calibration without GPU.""" + + def test_calibration_disabled(self): + """Test that no calibration occurs when disabled.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attention*": { + "threshold": 1e-3, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + # No forward_loop needed when calibration disabled + sparse_model = sparsify(model, config) + + # Check that sparse attention is applied but not calibrated + has_sparse = any(isinstance(m, SparseAttentionModule) for m in sparse_model.modules()) + assert has_sparse + + # Check that no calibration is set + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + method = module._sparse_method_instance + assert not getattr(method, "threshold_scale_factor", None) + + def test_sparsify_with_calibration_requires_forward_loop(self): + """Test that calibration requires forward_loop or proper model config.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attention*": { + "threshold": 1e-3, + "br": 64, + "bc": 64, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + "max_seqlen": 1024, + }, + } + }, + } + + # Without forward_loop and without model.config._name_or_path, should raise 
ValueError + with pytest.raises(ValueError, match="Could not load tokenizer"): + sparsify(model, config, forward_loop=None) + + def test_multiple_sparse_modules(self): + """Test that calibration handles multiple attention layers.""" + + # Create model with multiple attention modules + class MultiAttentionModel(nn.Module): + def __init__(self): + super().__init__() + self.attn1 = nn.MultiheadAttention(64, 4, batch_first=True) + self.attn2 = nn.MultiheadAttention(64, 4, batch_first=True) + + def forward(self, x): + x, _ = self.attn1(x, x, x) + x, _ = self.attn2(x, x, x) + return x + + model = MultiAttentionModel() + + config = { + "method": "flash_softmax_skip", + "sparse_cfg": {"*attn*": {"threshold": 1e-3, "br": 64, "bc": 64, "enable": True}}, + } + + sparse_model = sparsify(model, config) + + # Count sparse attention modules + sparse_count = sum( + 1 for m in sparse_model.modules() if isinstance(m, SparseAttentionModule) + ) + + # Should have 2 sparse attention modules + assert sparse_count == 2 + + def test_calibration_config_validation(self): + """Test CalibrationConfig validation.""" + from modelopt.torch.sparsity.attention_sparsity.config import CalibrationConfig + + # Valid config + config = CalibrationConfig( + target_sparse_ratio=0.5, + samples=48, + max_seqlen=32768, + ) + assert config.target_sparse_ratio == 0.5 + assert config.samples == 48 + assert config.max_seqlen == 32768 + + # Invalid target_sparse_ratio (> 1.0) + with pytest.raises(ValueError, match="target_sparse_ratio must be between"): + CalibrationConfig(target_sparse_ratio=1.5, samples=48, max_seqlen=32768) + + # Invalid target_sparse_ratio (< 0.0) + with pytest.raises(ValueError, match="target_sparse_ratio must be between"): + CalibrationConfig(target_sparse_ratio=-0.1, samples=48, max_seqlen=32768) + + # Invalid samples + with pytest.raises(ValueError, match="samples must be positive"): + CalibrationConfig(target_sparse_ratio=0.5, samples=0, max_seqlen=32768) + + # Invalid max_seqlen + with pytest.raises(ValueError, match="max_seqlen must be >= 1024"): + CalibrationConfig(target_sparse_ratio=0.5, samples=48, max_seqlen=512) + + def test_threshold_trials_validation(self): + """Test threshold_trials validation.""" + from modelopt.torch.sparsity.attention_sparsity.config import CalibrationConfig + + # Valid custom threshold_trials + config = CalibrationConfig( + target_sparse_ratio=0.5, + threshold_trials=[1e-5, 1e-4, 1e-3, 1e-2], + ) + assert config.threshold_trials == [1e-5, 1e-4, 1e-3, 1e-2] + + # None (use defaults) + config_default = CalibrationConfig(target_sparse_ratio=0.5) + assert config_default.threshold_trials is None + + # Invalid: empty list + with pytest.raises(ValueError, match="threshold_trials must not be empty"): + CalibrationConfig(threshold_trials=[]) + + # Invalid: threshold out of range (>= 1.0) + with pytest.raises(ValueError, match="must be in range"): + CalibrationConfig(threshold_trials=[1e-4, 1.0]) + + # Invalid: threshold out of range (<= 0) + with pytest.raises(ValueError, match="must be in range"): + CalibrationConfig(threshold_trials=[1e-4, 0]) + + # Invalid: not a list (Pydantic raises ValidationError, not ValueError) + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="Input should be a valid list"): + CalibrationConfig(threshold_trials=1e-4) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py 
b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py new file mode 100644 index 000000000..c135d7a48 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py @@ -0,0 +1,126 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test sparse attention configuration validation.""" + +import pytest +from pydantic import ValidationError + +from modelopt.torch.sparsity.attention_sparsity.config import ( + SKIP_SOFTMAX_DEFAULT, + FlashSoftmaxSkipConfig, + SparseAttentionAttributeConfig, + SparseAttentionConfig, +) + + +class TestSparseAttentionAttributeConfig: + """Test SparseAttentionAttributeConfig validators.""" + + def test_valid_config(self): + """Test creating valid config.""" + config = SparseAttentionAttributeConfig( + method="flash_softmax_skip", + threshold=1e-4, + br=128, + bc=128, + enable=True, + ) + assert config.method == "flash_softmax_skip" + assert config.threshold == 1e-4 + assert config.br == 128 + assert config.bc == 128 + + def test_method_validation(self): + """Test method must be string.""" + with pytest.raises(ValidationError, match="Input should be a valid string"): + SparseAttentionAttributeConfig(method=123) + + def test_block_size_validation_negative(self): + """Test block sizes must be positive.""" + with pytest.raises(ValidationError, match="Block size must be positive"): + SparseAttentionAttributeConfig(br=-1) + + with pytest.raises(ValidationError, match="Block size must be positive"): + SparseAttentionAttributeConfig(bc=0) + + def test_block_size_validation_large(self): + """Test that large block sizes are accepted.""" + # Large block sizes are allowed (warning removed for simplicity) + config = SparseAttentionAttributeConfig(br=2048) + assert config.br == 2048 + + def test_threshold_validation_range(self): + """Test threshold must be in range (0, 1).""" + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=0) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=-0.1) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=1.0) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=1.5) + + def test_threshold_validation_dict(self): + """Test threshold dict validation.""" + # Valid phase-aware threshold + config = SparseAttentionAttributeConfig(threshold={"prefill": 1e-3, "decode": 1e-5}) + assert config.threshold == {"prefill": 1e-3, "decode": 1e-5} + + # Invalid phase key + with pytest.raises(ValidationError, match="Invalid threshold phases"): + SparseAttentionAttributeConfig(threshold={"invalid_phase": 1e-3}) + + # Invalid threshold value in dict (negative) + with pytest.raises(ValidationError, match="must 
be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": -1e-3}) + + # Invalid threshold value in dict (>= 1.0) + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": 1.0}) + + def test_threshold_validation_type(self): + """Test threshold type validation.""" + with pytest.raises(ValidationError, match="Input should be a valid"): + SparseAttentionAttributeConfig(threshold="invalid") + + +class TestSparseAttentionConfig: + """Test SparseAttentionConfig.""" + + def test_default_config(self): + """Test default configuration.""" + config = SparseAttentionConfig() + assert config.method == "flash_softmax_skip" + assert config.collect_stats is False + + def test_predefined_config(self): + """Test pre-defined configuration.""" + assert "method" in SKIP_SOFTMAX_DEFAULT + assert "sparse_cfg" in SKIP_SOFTMAX_DEFAULT + assert "*attn*" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"] + + +class TestFlashSoftmaxSkipConfig: + """Test FlashSoftmaxSkipConfig.""" + + def test_default_values(self): + """Test default values for flash_softmax_skip config.""" + config = FlashSoftmaxSkipConfig() + assert config.method == "flash_softmax_skip" + assert "*attention*" in config.sparse_cfg diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py new file mode 100644 index 000000000..f80761887 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for sparse attention conversion and replacement.""" + +import torch.nn as nn +from _test_utils.torch_sparsity.sparse_attention_common import ( + FLASH_SOFTMAX_SKIP_DEFAULT_CFG, + SimpleAttentionModel, + SimpleTransformerEncoderLayer, +) + +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.sparsity.attention_sparsity.conversion import ( + disable_sparse_attention, + enable_sparse_attention, +) +from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule + + +class TestSparseAttentionReplacement: + """Test module replacement logic.""" + + def test_basic_replacement(self): + """Test that attention modules are replaced with sparse versions.""" + model = SimpleAttentionModel() + + # Count original attention modules + original_attention_count = sum( + isinstance(m, nn.MultiheadAttention) for m in model.modules() + ) + assert original_attention_count > 0 + + # Apply sparse attention + sparse_model = sparse_attn.sparsify(model, FLASH_SOFTMAX_SKIP_DEFAULT_CFG) + + # Count sparse attention modules + sparse_attention_count = sum( + isinstance(m, SparseAttentionModule) for m in sparse_model.modules() + ) + + # Verify replacement occurred + assert sparse_attention_count > 0 + + def test_enable_disable_toggle(self): + """Test enabling and disabling sparse attention.""" + model = SimpleAttentionModel() + model = sparse_attn.sparsify(model, FLASH_SOFTMAX_SKIP_DEFAULT_CFG) + + # Check initially enabled + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert module.is_enabled + + # Disable all sparse attention modules + disable_sparse_attention(model, "*") + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert not module.is_enabled + + # Re-enable all sparse attention modules + enable_sparse_attention(model, "*") + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert module.is_enabled + + def test_pattern_based_replacement(self): + """Test pattern-based selective replacement.""" + model = SimpleTransformerEncoderLayer() + + # Apply with pattern + config = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*self_attn*": {"threshold": 1e-4, "br": 128, "bc": 128, "enable": True}, + "default": {"enable": False}, + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Verify sparse modules exist + has_sparse = any(isinstance(m, SparseAttentionModule) for m in sparse_model.modules()) + assert has_sparse diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py new file mode 100644 index 000000000..7644bff4a --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for sparse attention mode registry.""" + +from modelopt.torch.opt.mode import _ModeRegistryCls +from modelopt.torch.sparsity.attention_sparsity.mode import SparseAttentionModeRegistry + + +def test_sparse_attention_mode_exists(): + """Test that sparse_attention mode is registered.""" + assert "sparse_attention" in SparseAttentionModeRegistry + + +def test_sparse_attention_mode_descriptor(): + """Test sparse attention mode descriptor properties.""" + mode_descriptor = _ModeRegistryCls.get_from_any("sparse_attention") + + assert mode_descriptor is not None + assert hasattr(mode_descriptor, "config_class") + assert hasattr(mode_descriptor, "convert") + + +def test_mode_registry_get(): + """Test getting mode from registry.""" + mode = SparseAttentionModeRegistry["sparse_attention"] + assert mode is not None diff --git a/tests/unit/torch/sparsity/test_sparsify.py b/tests/unit/torch/sparsity/weight_sparsity/test_sparsify.py similarity index 100% rename from tests/unit/torch/sparsity/test_sparsify.py rename to tests/unit/torch/sparsity/weight_sparsity/test_sparsify.py