diff --git a/.vscode/settings.json b/.vscode/settings.json index 0a3a2353e..50391ea82 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -45,4 +45,7 @@ ], "git.alwaysSignOff": true, "git.enableCommitSigning": true, + "cursorpyright.analysis.extraPaths": [ + "./tests/" + ], } diff --git a/examples/llm_eval/lm_eval_hf.py b/examples/llm_eval/lm_eval_hf.py index 31103ff86..24dcb28f6 100755 --- a/examples/llm_eval/lm_eval_hf.py +++ b/examples/llm_eval/lm_eval_hf.py @@ -43,9 +43,11 @@ from lm_eval.api.model import T from lm_eval.models.huggingface import HFLM from quantization_utils import quantize_model +from sparse_attention_utils import sparsify_model import modelopt.torch.opt as mto from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T: @@ -60,9 +62,20 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | calib_size = arg_dict.pop("calib_size", 512) compress = arg_dict.pop("compress", False) + # Sparse attention arguments + sparse_cfg = arg_dict.pop("sparse_cfg", None) + additional_config = {} if additional_config is None else additional_config additional_config = {k: v for k, v in additional_config.items() if v is not None} + # Force eager attention if sparse attention is requested + if sparse_cfg: + additional_config["attn_implementation"] = "eager" + warnings.warn( + "Sparse attention requires attn_implementation='eager'. " + "Forcing eager attention implementation." + ) + # Enable automatic save/load of modelopt state huggingface checkpointing mto.enable_huggingface_checkpointing() @@ -91,6 +104,15 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | auto_quantize_checkpoint=auto_quantize_checkpoint, ) + if sparse_cfg: + if is_attn_sparsified(model_obj.model): + warnings.warn("Skipping sparse attention: model already has sparse attention applied.") + else: + sparsify_model( + model=model_obj, + sparse_cfg=sparse_cfg, + ) + return model_obj @@ -152,6 +174,11 @@ def setup_parser_with_modelopt_args(): action="store_true", help="Compress the model after quantization", ) + parser.add_argument( + "--sparse_cfg", + type=str, + help="Sparse attention configuration (e.g., SKIP_SOFTMAX_DEFAULT, SKIP_SOFTMAX_CALIB)", + ) return parser @@ -177,6 +204,7 @@ def setup_parser_with_modelopt_args(): "calib_batch_size": args.calib_batch_size, "calib_size": args.calib_size, "compress": args.compress, + "sparse_cfg": args.sparse_cfg, } ) diff --git a/examples/llm_eval/mmlu.py b/examples/llm_eval/mmlu.py index ca244052b..0bf47fcd3 100755 --- a/examples/llm_eval/mmlu.py +++ b/examples/llm_eval/mmlu.py @@ -48,6 +48,7 @@ from fire import Fire from modeling import EvalModel, select_model from quantization_utils import MAX_SEQ_LEN, get_tokenizer, quantize_model +from sparse_attention_utils import sparsify_model from tqdm import tqdm try: @@ -56,6 +57,7 @@ LLM = None # type: ignore[misc] import modelopt.torch.opt as mto from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -230,6 +232,7 @@ def main( auto_quantize_method: str = "gradient", auto_quantize_score_size: int = 128, auto_quantize_checkpoint: str | None = None, + sparse_cfg: str | None = None, **kwargs, ): random.seed(RAND_SEED) @@ -266,6 +269,14 @@ def main( max_batch_size=1, ) else: + # Force eager attention if sparse attention is requested + if sparse_cfg: + kwargs["attn_implementation"] = "eager" + warnings.warn( + "Sparse attention requires attn_implementation='eager'. " + "Forcing eager attention implementation." + ) + model = select_model( max_input_length=MAX_SEQ_LEN, max_output_length=2, dtype=dtype, **kwargs ) @@ -289,6 +300,20 @@ def main( auto_quantize_checkpoint=auto_quantize_checkpoint, ) + # Apply sparse attention if requested + if sparse_cfg: + model.load() + + if is_attn_sparsified(model.model): + warnings.warn( + "Skipping sparse attention: model already has sparse attention applied." + ) + else: + sparsify_model( + model=model, + sparse_cfg=sparse_cfg, + ) + for subject in tqdm(subjects): dev_df = pd.read_csv(os.path.join(data_dir, "dev", subject + "_dev.csv"), header=None)[ :ntrain diff --git a/examples/llm_eval/modeling.py b/examples/llm_eval/modeling.py index 747b95d5b..d06d05560 100644 --- a/examples/llm_eval/modeling.py +++ b/examples/llm_eval/modeling.py @@ -179,6 +179,7 @@ class SeqToSeqModel(EvalModel): lora_path: str = "" device: str = "cuda" load_8bit: bool = False + attn_implementation: str | None = None def load(self): if self.model is None: @@ -188,6 +189,8 @@ def load(self): if self.load_8bit: args.update(device_map="auto", load_in_8bit=True) args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto") + if self.attn_implementation: + args["attn_implementation"] = self.attn_implementation self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_path, **args) print_gpu_utilization() if self.lora_path: @@ -241,6 +244,8 @@ def load(self): if self.load_8bit: args.update(device_map="auto", load_in_8bit=True) args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto") + if self.attn_implementation: + args["attn_implementation"] = self.attn_implementation self.model = AutoModelForCausalLM.from_pretrained( self.model_path, trust_remote_code=True, **args ) diff --git a/examples/llm_eval/sparse_attention_utils.py b/examples/llm_eval/sparse_attention_utils.py new file mode 100644 index 000000000..dc7a1b14e --- /dev/null +++ b/examples/llm_eval/sparse_attention_utils.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for sparse attention integration with llm_eval.""" + +import modelopt.torch.sparsity.attention_sparsity as mtsa + +# Custom sparse attention configurations +CUSTOM_SPARSE_CONFIG = { + "SPARSE_CONSERVATIVE": { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 5e-4, "decode": 1e-5}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + }, + "default": {"enable": False}, + }, + }, + "SPARSE_AGGRESSIVE": { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 5e-3, "decode": 5e-4}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + }, + "default": {"enable": False}, + }, + }, +} + + +def _extract_model(model_obj): + """Extract actual model from wrapper (HFLM or EvalModel).""" + if hasattr(model_obj, "gpt2"): + return model_obj.gpt2 + elif hasattr(model_obj, "model"): + return model_obj.model + else: + return model_obj + + +def sparsify_model( + model, + sparse_cfg: str, + backend=None, +): + """Apply sparse attention to model with optional RULER calibration. + + Args: + model: Model wrapper (HFLM or EvalModel) or raw model + sparse_cfg: Sparse attention config name or dict + backend: Backend to use (optional, overrides config backend) + + Returns: + The model with sparse attention applied + + Note: + Calibration is automatically triggered if the config contains a 'calibration' field. + The calibration will auto-generate RULER dataset from the model's tokenizer. + """ + # Extract actual model + net = _extract_model(model) + + # Resolve config + if isinstance(sparse_cfg, str): + # Try custom configs first + mtsa_cfg = CUSTOM_SPARSE_CONFIG.get(sparse_cfg) + if mtsa_cfg is None: + # Try predefined configs + mtsa_cfg = getattr(mtsa, sparse_cfg, None) + if mtsa_cfg is None: + raise ValueError(f"Unknown sparse_cfg: {sparse_cfg}") + else: + mtsa_cfg = sparse_cfg + + # Override backend if specified + if backend: + if isinstance(mtsa_cfg, dict) and "sparse_cfg" in mtsa_cfg: + modified_sparse_cfg = {} + for pattern, cfg in mtsa_cfg["sparse_cfg"].items(): + modified_cfg = cfg.copy() if isinstance(cfg, dict) else cfg + if isinstance(modified_cfg, dict): + modified_cfg["backend"] = backend + modified_sparse_cfg[pattern] = modified_cfg + mtsa_cfg = {"sparse_cfg": modified_sparse_cfg} + + # Apply sparsification + print(f"\nApplying sparse attention with config: {sparse_cfg}") + mtsa.sparsify(net, mtsa_cfg) + print("Sparse attention applied successfully!") + + return model diff --git a/examples/llm_sparsity/attention_sparsity/README.md b/examples/llm_sparsity/attention_sparsity/README.md new file mode 100644 index 000000000..708947683 --- /dev/null +++ b/examples/llm_sparsity/attention_sparsity/README.md @@ -0,0 +1,161 @@ +# Attention Sparsity for HuggingFace Models + +In this tutorial, we demonstrate how to use NVIDIA TensorRT Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation. + +## Getting Started + +### Quick Example + +```python +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT + +# Load your model +model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen3-8B", + attn_implementation="eager", # Required for sparse attention + torch_dtype=torch.bfloat16, +) + +# Apply sparse attention +model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT) +``` + +> [!Note] +> `attn_implementation="eager"` is required for sparse attention to work properly. Flash Attention 2 or SDPA would bypass the softmax patching needed for stats collection. + +## Configuration Options + +Two pre-defined configurations are available: + +### 1. Fixed Threshold (SKIP_SOFTMAX_DEFAULT) + +Uses a fixed threshold value. Simple but may not be optimal for all sequence lengths. + +```python +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT + +model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT) +``` + +### 2. Calibrated Threshold (SKIP_SOFTMAX_CALIB) + +Uses RULER-based calibration to determine an optimal dynamic threshold that adapts to sequence length. Recommended for production use. + +```python +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_CALIB + +model = mtsa.sparsify(model, config=SKIP_SOFTMAX_CALIB) +``` + +## Prerequisites + +### Install Requirements + +```bash +pip install -r requirements.txt +``` + +### Download RULER Calibration Data (Required for Calibration) + +If using `SKIP_SOFTMAX_CALIB`, you need to download the RULER calibration dataset first: + +```bash +bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh +``` + +This downloads the Paul Graham essays dataset used for generating calibration samples. + +## Run Sparse Attention on HuggingFace Models + +### Basic Usage (Without Calibration) + +Apply sparse attention with a fixed threshold: + +```bash +python examples/llm_sparsity/attention_sparsity/hf_sa.py \ + --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax +``` + +### With RULER Calibration + +Apply sparse attention with calibrated thresholds for optimal sparsity: + +```bash +python examples/llm_sparsity/attention_sparsity/hf_sa.py \ + --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax_calib +``` + +The calibration process: + +1. Generates RULER calibration samples +2. Collects attention statistics during forward passes +3. Determines optimal threshold scale factor for target sparsity ratio + +### Command Line Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--pyt_ckpt_path` | Required | HuggingFace model path or name | +| `--sparse_attn` | `skip_softmax` | Configuration: `skip_softmax` or `skip_softmax_calib` | +| `--backend` | `pytorch` | Backend: `pytorch` or `triton` | +| `--seq_len` | `2048` | Maximum sequence length for input prompts | +| `--export_dir` | `None` | Directory to export the sparsified model | + +## Output Comparison + +The script automatically compares outputs before and after applying sparse attention: + +1. Loads a test sample from the NarrativeQA dataset +2. Generates text before sparse attention is applied +3. Applies sparse attention (with optional calibration) +4. Generates text after sparse attention is applied +5. Compares and displays both outputs + +## Export Model + +Export the sparsified model to a HuggingFace checkpoint: + +```bash +python examples/llm_sparsity/attention_sparsity/hf_sa.py \ + --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax_calib \ + --export_dir ./exported_sparse_model +``` + +The exported model can be loaded and used with standard HuggingFace APIs. + +## Custom Configuration + +You can create custom sparse attention configurations: + +```python +custom_config = { + "sparse_cfg": { + "calibration": { # Optional: omit for fixed threshold + "target_sparse_ratio": 0.5, # Target 50% sparsity + "samples": 128, # Number of calibration samples + "max_seqlen": 8192, # Maximum sequence length + }, + "*attn*": { # Pattern to match attention modules + "method": "flash_skip_softmax", + "threshold": 1e-4, # Fixed threshold (ignored if calibration is used) + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", + "collect_stats": True, + "enable": True, + }, + "default": {"enable": False}, + }, +} + +model = mtsa.sparsify(model, config=custom_config) +``` + +## References + +- [TensorRT Model Optimizer Documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/) +- [RULER: What's the Real Context Size of Your Long-Context Language Models?](https://github.com/NVIDIA/RULER) diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py index 11564a4ec..29a2b53aa 100644 --- a/examples/llm_sparsity/attention_sparsity/hf_sa.py +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -28,9 +28,10 @@ import modelopt.torch.opt as mto import modelopt.torch.sparsity.attention_sparsity as mtsa from modelopt.torch.export import export_hf_checkpoint -from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig -from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT -from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule +from modelopt.torch.sparsity.attention_sparsity.config import ( + SKIP_SOFTMAX_CALIB, + SKIP_SOFTMAX_DEFAULT, +) from modelopt.torch.utils.memory_monitor import launch_memory_monitor RAND_SEED = 1234 @@ -38,9 +39,10 @@ # Enable HuggingFace checkpointing support mto.enable_huggingface_checkpointing() -# You can define custom configurations or use the default +# Sparse attention configuration choices SPARSE_ATTN_CFG_CHOICES = { "skip_softmax": SKIP_SOFTMAX_DEFAULT, + "skip_softmax_calib": SKIP_SOFTMAX_CALIB, } @@ -116,30 +118,23 @@ def truncate_text(text: str, tokenizer, max_length: int): return begin_text + " [...] " + end_text -def verify_outputs(model, tokenizer, args): - """Compare outputs between baseline and sparse attention models.""" - # Update seq_len to match calibration max_seqlen if calibration was used - base_config = SPARSE_ATTN_CFG_CHOICES.get(args.sparse_attn, {}) - if "calibration" in base_config and "max_seqlen" in base_config["calibration"]: - calib_max_seqlen = base_config["calibration"]["max_seqlen"] - if args.seq_len != calib_max_seqlen: - print( - f"\nNote: Updating test seq_len from {args.seq_len} to {calib_max_seqlen} " - f"to match calibration config" - ) - args.seq_len = calib_max_seqlen +def generate_sample_output(model, tokenizer, args): + """Generate sample output for comparison. + + Args: + model: The model to generate with + tokenizer: Tokenizer for encoding/decoding + args: Command line arguments - # Load and prepare a single test prompt - print(f"\nLoading test sample (will be tokenized up to {args.seq_len} tokens)") + Returns: + Tuple of (generated_text, input_prompt, input_ids) + """ + # Load test sample prompts = get_narrativeqa_samples(num_samples=1) prompt = prompts[0] # Prepare inputs truncated_prompt = truncate_text(prompt, tokenizer, args.seq_len) - display_prompt = ( - truncated_prompt[:150] + "..." if len(truncated_prompt) > 150 else truncated_prompt - ) - inputs = tokenizer( truncated_prompt, return_tensors="pt", @@ -150,14 +145,7 @@ def verify_outputs(model, tokenizer, args): if torch.cuda.is_available(): inputs = {k: v.cuda() for k, v in inputs.items()} - print("\n" + "=" * 60) - print("BASELINE vs SPARSE ATTENTION COMPARISON") - print("=" * 60) - print(f"\nTest prompt: {display_prompt}") - print(f"Input tokens: {inputs['input_ids'].shape[1]}") - - # Helper function to generate text - def generate_text(model, inputs, args, tokenizer): + # Generate with torch.no_grad(): outputs = model.generate( **inputs, @@ -168,60 +156,9 @@ def generate_text(model, inputs, args, tokenizer): ) input_length = inputs["input_ids"].shape[1] generated_ids = outputs[0][input_length:] - return tokenizer.decode(generated_ids, skip_special_tokens=True) - - # Find all sparse attention modules - sparse_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] - - # Generate baseline by temporarily disabling sparse attention - print("\n" + "-" * 60) - print("Generating baseline (sparse attention disabled)...") - for module in sparse_modules: - module.disable() - baseline_text = generate_text(model, inputs, args, tokenizer) + generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True) - # Generate with sparse attention enabled - print("\nGenerating with sparse attention (calibrated thresholds)...") - for module in sparse_modules: - module.enable() - sparse_text = generate_text(model, inputs, args, tokenizer) - - # Display comparison - print("\n" + "-" * 60) - print("RESULTS:") - baseline_display = baseline_text[:300] + "..." if len(baseline_text) > 300 else baseline_text - sparse_display = sparse_text[:300] + "..." if len(sparse_text) > 300 else sparse_text - - print(f"\nBaseline: {baseline_display}") - print(f"With Sparse: {sparse_display}") - - if baseline_text == sparse_text: - print("\nOutputs are identical") - else: - print("\nOutputs differ") - - -def sparsify_model(model, args): - """Apply sparse attention to the model with optional calibration.""" - print(f"\nApplying sparse attention: {args.sparse_attn} with backend: {args.backend}") - base_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn] - - # Create modified config with selected backend - modified_sparse_cfg = {} - for pattern, cfg in base_config["sparse_cfg"].items(): - modified_cfg = cfg.copy() - modified_cfg["backend"] = args.backend - modified_sparse_cfg[pattern] = modified_cfg - - # Create new config with modified settings - sparse_config = SparseAttentionConfig(sparse_cfg=modified_sparse_cfg) - - # Sparsify the model - model = mtsa.sparsify(model, config=sparse_config) - - print("Sparse attention applied successfully!") - - return model + return generated_text, truncated_prompt, inputs["input_ids"] def main(args): @@ -254,12 +191,40 @@ def main(args): model = model.cuda() print("Model moved to CUDA") - # Apply sparse attention to the model (with calibration if configured) - model = sparsify_model(model, args) + # Generate sample output BEFORE sparse attention + print("\nGenerating sample output before sparse attention...") + output_before, test_prompt, input_ids = generate_sample_output(model, tokenizer, args) + + # Apply sparse attention with optional calibration + print(f"\nApplying sparse attention: {args.sparse_attn}") + sparse_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn] + model = mtsa.sparsify(model, config=sparse_config) + print("Sparse attention applied successfully!") + + # Generate sample output AFTER sparse attention + print("\nGenerating sample output after sparse attention...") + output_after, _, _ = generate_sample_output(model, tokenizer, args) + + # Display comparison + print("\n" + "=" * 60) + print("OUTPUT COMPARISON (Before vs After Sparse Attention)") + print("=" * 60) + display_prompt = test_prompt[:150] + "..." if len(test_prompt) > 150 else test_prompt + print(f"\nTest prompt: {display_prompt}") + print(f"Input tokens: {input_ids.shape[1]}") + + output_before_display = ( + output_before[:300] + "..." if len(output_before) > 300 else output_before + ) + output_after_display = output_after[:300] + "..." if len(output_after) > 300 else output_after + + print(f"\nBefore sparse attention: {output_before_display}") + print(f"After sparse attention: {output_after_display}") - # Verify outputs if requested (compares baseline vs calibrated sparse model) - if args.verify_output: - verify_outputs(model, tokenizer, args) + if output_before == output_after: + print("\nOutputs are identical") + else: + print("\nOutputs differ") # Export if requested if args.export_dir: @@ -306,12 +271,6 @@ def main(args): default=2048, help="Maximum sequence length for input prompts (will be truncated if longer)", ) - parser.add_argument( - "--num_samples", - type=int, - default=3, - help="Number of samples to use from NarrativeQA dataset", - ) # Generation arguments parser.add_argument( @@ -321,11 +280,6 @@ def main(args): parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling") # Operation arguments - parser.add_argument( - "--verify_output", - action="store_true", - help="Verify that sparse attention outputs match baseline", - ) parser.add_argument( "--export_dir", type=str, diff --git a/examples/llm_sparsity/attention_sparsity/requirements.txt b/examples/llm_sparsity/attention_sparsity/requirements.txt new file mode 100644 index 000000000..a3e0dfa17 --- /dev/null +++ b/examples/llm_sparsity/attention_sparsity/requirements.txt @@ -0,0 +1,2 @@ +nltk +wonderwords diff --git a/examples/llm_sparsity/weight_sparsity/finetune.py b/examples/llm_sparsity/weight_sparsity/finetune.py index d13b43fde..8e92a5cac 100644 --- a/examples/llm_sparsity/weight_sparsity/finetune.py +++ b/examples/llm_sparsity/weight_sparsity/finetune.py @@ -1,5 +1,6 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py + +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,15 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py - -# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py new file mode 100644 index 000000000..3b616e8e3 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration framework for sparse attention methods.""" + +from .calibrate import calibrate_sparse_attention +from .calibrator import DynamicThresholdCalibrator +from .dataset import RulerDatasetBuilder + +__all__ = [ + "DynamicThresholdCalibrator", + "RulerDatasetBuilder", + "calibrate_sparse_attention", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py new file mode 100644 index 000000000..1b8f0e71b --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration functions for sparse attention.""" + +import warnings +from collections.abc import Callable +from typing import Any + +import torch +import torch.nn as nn +from transformers import AutoTokenizer + +from ..config import CalibrationConfig +from ..conversion import print_sparse_attention_summary +from ..sparse_attention import SparseAttentionModule +from .calibrator import DynamicThresholdCalibrator +from .dataset import RulerDatasetBuilder + + +def _extract_tokenizer_from_model(model: nn.Module) -> str: + """Extract tokenizer name/path from model config. + + Args: + model: Model to extract tokenizer from + + Returns: + Tokenizer name or path + + Raises: + ValueError: If tokenizer path cannot be determined from model + """ + # Extract tokenizer path from model config + tokenizer_path = getattr(getattr(model, "config", None), "_name_or_path", None) + + if not tokenizer_path: + raise ValueError("Could not load tokenizer from model.") + + return tokenizer_path + + +def _extract_calibration_config(config: dict[str, Any]) -> CalibrationConfig | None: + """Extract and validate calibration config from sparse_cfg. + + Args: + config: Sparse attention configuration dict + + Returns: + Validated CalibrationConfig instance, or None if calibration is not configured + + Raises: + ValueError: If calibration config has invalid type or contains invalid values + """ + sparse_cfg = config.get("sparse_cfg", {}) + + # Calibration is optional + if "calibration" not in sparse_cfg: + return None + + calib_dict = sparse_cfg["calibration"] + + # Validate calibration is a dict + if not isinstance(calib_dict, dict): + raise ValueError(f"Calibration config must be a dict, got {type(calib_dict).__name__}. ") + + # Create and validate CalibrationConfig + return CalibrationConfig(**calib_dict) + + +def create_calibration_forward_loop( + calibration_data: list[dict[str, Any]], + tokenizer_name_or_path: str, + batch_size: int = 1, + chunk_size: int = 2048, +) -> Callable: + """Create forward loop for calibration. + + Args: + calibration_data: List of samples with 'input' and 'length' fields + tokenizer_name_or_path: HuggingFace tokenizer path + batch_size: Batch size (currently unused, always 1) + chunk_size: Chunk size for chunked prefill to avoid OOM. Set to -1 to disable. + + Returns: + Forward loop function that takes model as argument + """ + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + if not tokenizer.pad_token: + tokenizer.pad_token = tokenizer.eos_token + + def forward_loop(model: nn.Module) -> None: + device = next(model.parameters()).device + + for sample in calibration_data: + inputs = tokenizer( + sample["input"], return_tensors="pt", truncation=True, max_length=sample["length"] + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + input_ids = inputs["input_ids"].to(device) + seq_len = input_ids.shape[1] + + with torch.no_grad(): + if chunk_size > 0 and seq_len > chunk_size: + # Chunked prefill to avoid OOM with long sequences + past_key_values = None + for start_idx in range(0, seq_len, chunk_size): + end_idx = min(start_idx + chunk_size, seq_len) + chunk_input_ids = input_ids[:, start_idx:end_idx] + + outputs = model( + chunk_input_ids, + past_key_values=past_key_values, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # Clean up KV cache + del past_key_values + torch.cuda.empty_cache() + else: + # Full prefill without chunking + model(input_ids, use_cache=False) + + return forward_loop + + +def calibrate_sparse_attention( + model: nn.Module, + config: dict[str, Any], + forward_loop: Callable | None = None, +) -> dict[str, Any]: + """Calibrate sparse attention parameters for optimal sparsity. + + Args: + model: Model with sparse attention modules + config: Sparse attention configuration dict + forward_loop: Callable that forwards calibration data through model. + If None, auto-generates RULER dataset. + + Returns: + Dictionary with calibration results + """ + # Extract and validate calibration config + calib_config = _extract_calibration_config(config) + + # Skip calibration if not configured + if calib_config is None: + return {} + + # Generate forward_loop if not provided + if not forward_loop: + tokenizer = _extract_tokenizer_from_model(model) + builder = RulerDatasetBuilder( + samples=calib_config.samples, + max_seqlen=calib_config.max_seqlen, + tokenizer_name_or_path=tokenizer, + num_length_bins=calib_config.num_length_bins, + max_length_filter=int(calib_config.max_seqlen * 1.5), + ) + calibration_data = builder.build_calibration_dataset() + print(f"Generated {len(calibration_data)} calibration samples") + forward_loop = create_calibration_forward_loop( + calibration_data, tokenizer, chunk_size=calib_config.chunk_size + ) + + # Get sparse attention modules + sparse_modules = [ + (name, m) for name, m in model.named_modules() if isinstance(m, SparseAttentionModule) + ] + + if not sparse_modules: + print("No sparse attention modules found for calibration") + return {} + + print(f"Calibrating {len(sparse_modules)} sparse attention modules together...") + + # Run calibration + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=calib_config.target_sparse_ratio, + threshold_trials=calib_config.threshold_trials, + ) + calibration_result = calibrator.calibrate(model, forward_loop) + + # Print calibration statistics (regardless of success/failure for debugging) + print("\nCalibration complete!") + print_sparse_attention_summary(model) + + if "scale_factor" not in calibration_result: + warnings.warn("Calibration did not produce valid results") + return {} + + # Apply calibrated scale factor to all modules + scale_factor = calibration_result["scale_factor"] + print(f"\nApplying calibrated scale factor={scale_factor:.6f} to {len(sparse_modules)} modules") + + for module_name, module in sparse_modules: + module._sparse_method_instance.threshold_scale_factor = scale_factor + + return {"calibration_results": {name: calibration_result for name, _ in sparse_modules}} diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py new file mode 100644 index 000000000..39abdbca8 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -0,0 +1,312 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration framework for sparse attention methods.""" + +import warnings +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from tqdm import tqdm + +from ..sparse_attention import SparseAttentionModule +from ..stats_manager import SparseAttentionStatsManager + + +class DynamicThresholdCalibrator: + """Dynamic threshold calibrator using length-based linear relationship. + + Implements calibration algorithm: + 1. Find hyperparameter 'a' where threshold λ = a / context_length + 2. Use dataset with different lengths and test multiple thresholds + 3. For each sample, find optimal threshold closest to target sparsity + 4. Use linear regression to fit: threshold = a * (1/length) + """ + + @dataclass + class SampleSparsity: + """Sparsity results for a single calibration sample.""" + + length: int + threshold_sparsities: dict[float, float] + + def __init__( + self, + target_sparse_ratio: float = 0.5, + threshold_trials: list[float] | None = None, + ): + """Initialize dynamic threshold calibrator. + + Args: + target_sparse_ratio: Target sparsity ratio (0.0 to 1.0) + threshold_trials: List of thresholds to try during calibration + + Note: + Calibration only supports prefill phase (seq_len > 1). + Decode phase uses the same calibrated threshold. + """ + self.target_sparse_ratio = target_sparse_ratio + + # Default threshold trials if not provided + self.threshold_trials = threshold_trials or [ + 1e-6, + 5e-6, + 1e-5, + 5e-5, + 1e-4, + 5e-4, + 1e-3, + 5e-3, + 1e-2, + 5e-2, + 1e-1, + 5e-1, + ] + + # Statistics tracking + self.sparsity_results = [] + + def calibrate(self, model: nn.Module, forward_loop: Callable) -> dict[str, Any]: + """Find optimal 'a' parameter for length-based threshold. + + Algorithm: + 1. Test all threshold trials by running forward_loop multiple times + 2. For each sample, find optimal threshold closest to target sparsity + 3. Use regression to find 'a' in: threshold = a / length + + Args: + model: The model with sparse attention modules + forward_loop: Callable that takes model and forwards calibration data + """ + # Extract attention modules + attention_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] + + if not attention_modules: + raise ValueError("No sparse attention modules found for calibration") + + print("Starting dynamic threshold calibration") + print(f"Target sparsity: {self.target_sparse_ratio}") + print(f"Threshold trials: {len(self.threshold_trials)}") + + # Stage 1: Collect sparsity for all sample-threshold pairs + print("\nStage 1: Collecting sparsity data...") + + # Run first threshold to discover samples and initialize results + self._set_threshold(attention_modules, self.threshold_trials[0]) + self._enable_calibration_mode(attention_modules) + with torch.no_grad(): + forward_loop(model) + per_sample_stats = self._extract_calibration_stats(attention_modules) + self._disable_calibration_mode(attention_modules) + + # Initialize sparsity_results with sample info + self.sparsity_results = [ + self.SampleSparsity( + length=stat["sample_length"], + threshold_sparsities={self.threshold_trials[0]: stat["sparsity"]}, + ) + for stat in per_sample_stats + ] + + # Collect remaining thresholds + for threshold in tqdm(self.threshold_trials[1:], desc="Testing thresholds"): + self._set_threshold(attention_modules, threshold) + self._enable_calibration_mode(attention_modules) + with torch.no_grad(): + forward_loop(model) + per_sample_stats = self._extract_calibration_stats(attention_modules) + self._disable_calibration_mode(attention_modules) + + for sample_idx, sample_stat in enumerate(per_sample_stats): + self.sparsity_results[sample_idx].threshold_sparsities[threshold] = sample_stat[ + "sparsity" + ] + + if not self.sparsity_results: + warnings.warn("No valid sparsity measurements collected during calibration") + return {} + + print(f"Collected statistics for {len(self.sparsity_results)} samples") + + # Stage 2: Find optimal threshold for each sample and compute 'a' + print( + f"\nStage 2: Finding 'a' parameter for target sparsity {self.target_sparse_ratio:.2f}" + ) + + # Find optimal threshold for each sample + optimal_pairs = [] + for sample_result in self.sparsity_results: + # Find threshold closest to target sparsity + best_threshold, achieved_sparsity = min( + sample_result.threshold_sparsities.items(), + key=lambda item: abs(item[1] - self.target_sparse_ratio), + ) + + optimal_pairs.append( + { + "length": sample_result.length, + "optimal_threshold": best_threshold, + "achieved_sparsity": achieved_sparsity, + "target_sparsity": self.target_sparse_ratio, + } + ) + + if not optimal_pairs: + warnings.warn( + f"No optimal threshold pairs found for target sparsity {self.target_sparse_ratio}. " + f"Collected {len(self.sparsity_results)} samples but none achieved target sparsity." + ) + return {} + + # Linear regression: threshold = a * (1/length) + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + # X = 1/length, Y = threshold + x = 1.0 / lengths + y = thresholds + + # Least squares: scale_factor = sum(x*y) / sum(x^2) + scale_factor = np.sum(x * y) / np.sum(x**2) + + # Calculate statistics + scale_factors_per_sample = y * lengths + scale_factor_std = np.std(scale_factors_per_sample) + + # Calculate R-squared for quality metric + y_pred = scale_factor * x + ss_res = np.sum((y - y_pred) ** 2) + ss_tot = np.sum((y - np.mean(y)) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 + + # Calculate average achieved sparsity + avg_achieved_sparsity = np.mean([p["achieved_sparsity"] for p in optimal_pairs]) + + print("\nCalibration Results:") + print(f" Threshold scale factor: {scale_factor:.6f} (std: {scale_factor_std:.6f})") + print(f" R-squared: {r_squared:.4f}") + print( + f" Average achieved sparsity: {avg_achieved_sparsity:.2%} (target: {self.target_sparse_ratio:.2%})" + ) + print(f"\nExample thresholds with λ = {scale_factor:.6f} / length:") + for length in [1024, 2048, 4096, 8192, 16384]: + print(f" Length {length:5d}: threshold = {scale_factor / length:.2e}") + + # Apply the calibrated scale factor to modules + self._apply_length_based_calibration(attention_modules, scale_factor) + + return { + "scale_factor": scale_factor, + "scale_factor_std": scale_factor_std, + "r_squared": r_squared, + "num_samples": len(optimal_pairs), + "target_sparsity": self.target_sparse_ratio, + "avg_achieved_sparsity": avg_achieved_sparsity, + "optimal_pairs": optimal_pairs, + "calibration_type": "length_based_dynamic", + } + + def _apply_length_based_calibration(self, modules: list[nn.Module], scale_factor: float): + """Apply calibrated threshold scale factor to modules. + + Args: + modules: List of attention modules + scale_factor: Calibrated scale factor for λ = scale_factor / length + """ + for module in modules: + module._sparse_method_instance.threshold_scale_factor = scale_factor + + def _enable_calibration_mode(self, modules: list[nn.Module]): + """Enable calibration mode on sparse attention modules.""" + for idx, module in enumerate(modules): + # Create stats manager if needed + if not module._stats_manager: + module._stats_manager = SparseAttentionStatsManager( + module_name=f"sparse_attn_{idx}", enabled=True + ) + else: + # Re-enable if disabled + module._stats_manager.enabled = True + + # Enable calibration mode with fresh stats + module._stats_manager.set_calibration_mode(enabled=True, reset_history=True) + module._sparse_method_instance.set_calibration_mode(True) + + def _disable_calibration_mode(self, modules: list[nn.Module]): + """Disable calibration mode (but keep stats enabled if collect_stats=True).""" + for module in modules: + if module._stats_manager: + module._stats_manager.set_calibration_mode(enabled=False) + + module._sparse_method_instance.set_calibration_mode(False) + + def _extract_calibration_stats(self, modules: list[nn.Module]) -> list[dict]: + """Extract per-sample calibration statistics from modules. + + Args: + modules: List of attention modules + + Returns: + List of per-sample statistics across all modules + """ + # Collect from all stats managers + all_per_sample_stats = [] + + for module in modules: + # Skip modules without stats manager + if not hasattr(module, "_stats_manager") or module._stats_manager is None: + continue + + manager_stats = module._stats_manager.get_calibration_stats() + if manager_stats: + all_per_sample_stats.append(manager_stats) + + if not all_per_sample_stats: + return [] + + # Aggregate across modules by sample index + num_samples = len(all_per_sample_stats[0]) + aggregated_stats = [] + + for sample_idx in range(num_samples): + sparsities = [] + sample_length = 0 + + for module_stats in all_per_sample_stats: + if sample_idx < len(module_stats): + sample_stat = module_stats[sample_idx] + sparsities.append(sample_stat.get("sparsity", 0.0)) + if not sample_length and "sample_length" in sample_stat: + sample_length = sample_stat["sample_length"] + + avg_sparsity = float(np.mean(sparsities)) if sparsities else 0.0 + + aggregated_stats.append( + { + "sparsity": avg_sparsity, + "sample_length": sample_length, + } + ) + + return aggregated_stats + + def _set_threshold(self, modules: list[nn.Module], threshold: float): + """Set threshold on sparse attention modules.""" + for module in modules: + module._sparse_method_instance.threshold = threshold diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py new file mode 100644 index 000000000..7603b4e1d --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py @@ -0,0 +1,546 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RULER dataset builder for sparse attention calibration.""" + +import random +import string +from dataclasses import dataclass +from typing import Any + +from tqdm import tqdm +from transformers import AutoTokenizer + +from . import ruler_utils + + +def _generate_target_lengths( + max_seqlen: int, num_length_bins: int = 4, min_seqlen: int = 1024 +) -> list[int]: + """Generate target lengths as descending powers of 2. + + Args: + max_seqlen: Maximum sequence length + num_length_bins: Maximum number of length bins to generate + min_seqlen: Minimum sequence length threshold + + Returns: + List of target lengths in descending order + + Examples: + >>> _generate_target_lengths(32768, 4) + [32768, 16384, 8192, 4096] + >>> _generate_target_lengths(2048, 4) + [2048, 1024] + """ + target_lengths = [] + current = max_seqlen + + for _ in range(num_length_bins): + if current < min_seqlen: + break + target_lengths.append(current) + current = current // 2 + + return target_lengths + + +@dataclass +class RulerTask: + """Configuration for a RULER task.""" + + name: str + task_type: str # niah, variable_tracking, freq_words_extraction, qa + tokens_to_generate: int + template: str + answer_prefix: str + args: dict[str, Any] + + +# Task configurations based on RULER benchmark +RULER_TASKS = { + "niah_multikey_2": RulerTask( + name="niah_multikey_2", + task_type="niah", + tokens_to_generate=128, + template=( + "Some special magic {type_needle_v} are hidden within the following text. " + "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" + "{context}\n" + "What are all the special magic {type_needle_v} for {query} mentioned in the provided text?" + ), + answer_prefix=( + " The special magic {type_needle_v} for {query} mentioned in the provided text are" + ), + args={ + "type_haystack": "needle", + "type_needle_k": "words", + "type_needle_v": "numbers", + "num_needle_k": 1, + "num_needle_v": 1, + "num_needle_q": 1, + }, + ), + "niah_multikey_3": RulerTask( + name="niah_multikey_3", + task_type="niah", + tokens_to_generate=128, + template=( + "Some special magic {type_needle_v} are hidden within the following text. " + "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" + "{context}\n" + "What are all the special magic {type_needle_v} for {query} mentioned in the provided text?" + ), + answer_prefix=( + " The special magic {type_needle_v} for {query} mentioned in the provided text are" + ), + args={ + "type_haystack": "needle", + "type_needle_k": "uuids", + "type_needle_v": "uuids", + "num_needle_k": 1, + "num_needle_v": 1, + "num_needle_q": 1, + }, + ), + "vt": RulerTask( + name="vt", + task_type="variable_tracking", + tokens_to_generate=30, + template=( + "Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n" + "{context}\n" + "Question: Find all variables that are assigned the value {query} in the text above." + ), + answer_prefix=( + " Answer: According to the chain(s) of variable assignment in the text above, " + "{num_v} variables are assgined the value {query}, they are: " + ), + args={"num_chains": 1, "num_hops": 4}, + ), + "fwe": RulerTask( + name="fwe", + task_type="freq_words_extraction", + tokens_to_generate=50, + template=( + "Read the following coded text and track the frequency of each coded word. " + "Find the three most frequently appeared coded words. {context}\n" + "Question: Do not provide any explanation. Please ignore the dots '....'. " + "What are the three most frequently appeared words in the above coded text?" + ), + answer_prefix=( + " Answer: According to the coded text above, " + "the three most frequently appeared words are:" + ), + args={"alpha": 2.0}, + ), + "qa_1": RulerTask( + name="qa_1", + task_type="qa", + tokens_to_generate=32, + template=( + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "The following are given documents.\n\n{context}\n\n" + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {query}" + ), + answer_prefix=" Answer:", + args={"dataset": "squad"}, + ), + "qa_2": RulerTask( + name="qa_2", + task_type="qa", + tokens_to_generate=32, + template=( + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "The following are given documents.\n\n{context}\n\n" + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {query}" + ), + answer_prefix=" Answer:", + args={"dataset": "hotpotqa"}, + ), +} + + +class RulerDatasetBuilder: + """Builder for RULER calibration datasets.""" + + def __init__( + self, + samples: int, + max_seqlen: int, + tokenizer_name_or_path: str | object, + num_length_bins: int = 4, + max_length_filter: int = 65536, + seed: int = 42, + ): + """Initialize RULER dataset builder. + + Args: + samples: Total number of samples to generate (distributed evenly across length bins) + max_seqlen: Maximum sequence length (length bins auto-generated as powers of 2) + tokenizer_name_or_path: HuggingFace tokenizer path or tokenizer object + seed: Random seed for reproducibility + num_length_bins: Number of length bins to generate (default: 4) + max_length_filter: Maximum sequence length to keep (default: 65536) + + Note: + Length bins are auto-generated as descending powers of 2: + [max_seqlen, max_seqlen/2, max_seqlen/4, ...] + Generation stops when num_length_bins is reached or length < 1024. + Subtasks are set to all the difficult tasks defined in RULER_TASKS. + """ + # Validate inputs + if samples <= 0: + raise ValueError(f"samples must be positive, got {samples}") + if max_seqlen < 1024: + raise ValueError(f"max_seqlen must be >= 1024, got {max_seqlen}") + + # Store parameters + self.total_samples = samples + self.max_seqlen = max_seqlen + self.num_length_bins = num_length_bins + self.subtasks = list(RULER_TASKS.keys()) + self.tokenizer_name_or_path = tokenizer_name_or_path + self.seed = seed + self.max_length_filter = max_length_filter + + # Generate target lengths and validate + self.target_lengths = _generate_target_lengths(max_seqlen, num_length_bins, min_seqlen=1024) + if not self.target_lengths: + raise ValueError(f"No valid target lengths generated from max_seqlen={max_seqlen}") + + # Distribute samples evenly across lengths + self.samples_per_length = [samples // len(self.target_lengths)] * len(self.target_lengths) + + # Initialize tokenizer + if isinstance(tokenizer_name_or_path, str): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + else: + self.tokenizer = tokenizer_name_or_path + random.seed(seed) + + def build_calibration_dataset(self) -> list[dict[str, Any]]: + """Build the complete calibration dataset. + + Returns: + List of calibration samples with 'input' and 'length' fields + """ + all_samples = [] + + # Generate calibration samples + for num_samples, target_length in tqdm( + zip(self.samples_per_length, self.target_lengths), + desc="Generating RULER calibration samples", + total=len(self.target_lengths), + ): + samples_per_task = max(num_samples // len(self.subtasks), 1) + + # Generate equal samples for each task + for task_name in self.subtasks: + for sample_idx in range(samples_per_task): + sample = self._generate_sample(task_name, target_length, sample_idx) + if sample and sample["length"] <= self.max_length_filter: + all_samples.append(sample) + + random.shuffle(all_samples) + return all_samples + + def _generate_sample( + self, task_name: str, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a single RULER sample. + + Args: + task_name: Name of the RULER task + target_length: Target sequence length in tokens + sample_idx: Index of the sample (for uniqueness) + + Returns: + Dict with 'input', 'length', and metadata fields + """ + task = RULER_TASKS[task_name] + + if task.task_type == "niah": + return self._generate_niah_sample(task, target_length, sample_idx) + elif task.task_type == "variable_tracking": + return self._generate_vt_sample(task, target_length, sample_idx) + elif task.task_type == "freq_words_extraction": + return self._generate_fwe_sample(task, target_length, sample_idx) + elif task.task_type == "qa": + return self._generate_qa_sample(task, target_length, sample_idx) + else: + raise ValueError(f"Unknown task type: {task.task_type}") + + def _generate_niah_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a needle-in-haystack sample.""" + args = task.args + + # Find optimal haystack size for target length + optimal_haystack = ruler_utils.find_optimal_haystack_size( + tokenizer=self.tokenizer, + max_seq_length=target_length, + template=task.template, + answer_prefix=task.answer_prefix, + tokens_to_generate=task.tokens_to_generate, + type_haystack=args.get("type_haystack", "essay"), + type_needle_k=args.get("type_needle_k", "words"), + type_needle_v=args.get("type_needle_v", "numbers"), + num_needle_k=args.get("num_needle_k", 1), + num_needle_v=args.get("num_needle_v", 1), + num_needle_q=args.get("num_needle_q", 1), + ) + + # Generate sample using official RULER implementation + sample = ruler_utils.generate_niah_sample( + num_haystack=optimal_haystack, + tokenizer=self.tokenizer, + template=task.template, + answer_prefix=task.answer_prefix, + tokens_to_generate=task.tokens_to_generate, + type_haystack=args.get("type_haystack", "essay"), + type_needle_k=args.get("type_needle_k", "words"), + type_needle_v=args.get("type_needle_v", "numbers"), + num_needle_k=args.get("num_needle_k", 1), + num_needle_v=args.get("num_needle_v", 1), + num_needle_q=args.get("num_needle_q", 1), + random_seed=self.seed + sample_idx, + ) + + # Add task metadata + sample["task"] = task.name + sample["target_length"] = target_length + sample["sample_idx"] = sample_idx + + return sample + + def _generate_vt_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a variable tracking sample.""" + args = task.args + num_chains = args["num_chains"] + num_hops = args["num_hops"] + + # Generate variable chains + variables = [] + chains = [] + for _ in range(num_chains): + chain = [self._generate_random_variable() for _ in range(num_hops + 1)] + variables.extend(chain) + chains.append(chain) + + # Generate assignments + assignments = [ + f"VAR {chain[i]} = {chain[i + 1]}" for chain in chains for i in range(len(chain) - 1) + ] + + # Create context with padding + context = self._pad_context_with_text( + "\n".join(assignments), target_length, "variable tracking context" + ) + + # Select a query value + query_value = random.choice([chain[-1] for chain in chains]) + + # Format template + template = task.template.format(context=context, query=query_value) + + # Count variables with the query value + num_v = sum(1 for chain in chains if chain[-1] == query_value) + + # Add answer prefix + full_input = template + task.answer_prefix.format(num_v=num_v, query=query_value) + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _generate_fwe_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a frequency word extraction sample.""" + # Generate coded words with frequencies + num_unique_words = 50 + coded_words = [self._generate_coded_word() for _ in range(num_unique_words)] + + # Assign frequencies (make top 3 clearly more frequent) + frequencies = {} + for i, word in enumerate(coded_words): + if i < 3: + frequencies[word] = random.randint(20, 30) # High frequency + else: + frequencies[word] = random.randint(1, 10) # Low frequency + + # Generate the coded text + word_list = [] + for word, freq in frequencies.items(): + word_list.extend([word] * freq) + random.shuffle(word_list) + + # Add dots for separation + coded_text = " .... ".join(word_list) + + # Pad to target length + context = self._pad_context_with_text(coded_text, target_length, "coded text padding") + + # Format template + template = task.template.format(context=context) + full_input = template + task.answer_prefix + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _generate_qa_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a QA sample.""" + # Generate synthetic documents + num_docs = 5 + documents = [] + + # Create a simple QA pair + answer = self._generate_random_phrase() + question = f"What is the special code mentioned in document {random.randint(1, num_docs)}?" + + for i in range(num_docs): + doc_text = self._generate_document_text(200) # Base document + if i == 2: # Insert answer in one document + doc_text += f" The special code is {answer}. " + documents.append(f"Document {i + 1}:\n{doc_text}\n") + + # Combine documents + context_base = "\n".join(documents) + + # Pad to target length + context = self._pad_context_with_text( + context_base, target_length, "additional document text" + ) + + # Format template + template = task.template.format(context=context, query=question) + full_input = template + task.answer_prefix + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _pad_context_with_text( + self, base_context: str, target_length: int, padding_type: str + ) -> str: + """Pad context to approach target length.""" + tokens = self.tokenizer.encode(base_context, add_special_tokens=False) + + while len(tokens) < target_length * 0.7: # Leave room for template + if padding_type == "variable tracking context": + padding = ( + f" VAR {self._generate_random_variable()} = {self._generate_random_variable()}." + ) + elif padding_type == "coded text padding": + padding = f" .... {self._generate_coded_word()} .... " + else: + padding = " " + self._generate_essay_text(50) + + base_context += padding + tokens = self.tokenizer.encode(base_context, add_special_tokens=False) + + if len(tokens) > target_length * 0.9: + # Truncate if too long + base_context = self.tokenizer.decode(tokens[: int(target_length * 0.8)]) + + return base_context + + def _generate_random_word(self) -> str: + """Generate a random word.""" + return "".join(random.choices(string.ascii_lowercase, k=random.randint(5, 10))) + + def _generate_random_variable(self) -> str: + """Generate a random variable name.""" + return "".join(random.choices(string.ascii_uppercase, k=1)) + "".join( + random.choices(string.digits, k=3) + ) + + def _generate_coded_word(self) -> str: + """Generate a coded word.""" + return "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) + + def _generate_random_phrase(self) -> str: + """Generate a random phrase.""" + words = [self._generate_random_word() for _ in range(random.randint(2, 4))] + return " ".join(words) + + def _generate_essay_text(self, num_words: int) -> str: + """Generate essay-like text.""" + topics = [ + "technology", + "science", + "nature", + "history", + "culture", + "education", + "health", + "economics", + "politics", + "philosophy", + "art", + "literature", + ] + + sentences = [] + words_generated = 0 + + while words_generated < num_words: + topic = random.choice(topics) + word1 = self._generate_random_word() + word2 = self._generate_random_word() + word3 = self._generate_random_word() + sentence = f"The {topic} of {word1} is {word2} and {word3}. " + sentences.append(sentence) + words_generated += len(sentence.split()) + + return " ".join(sentences) + + def _generate_document_text(self, num_words: int) -> str: + """Generate document-like text.""" + return self._generate_essay_text(num_words) diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh b/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh new file mode 100755 index 000000000..54797f2a5 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Download RULER calibration data for attention sparsity. + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DATA_DIR="${SCRIPT_DIR}/data" +ESSAYS_DIR="${DATA_DIR}/essays" +URLS_FILE="${DATA_DIR}/PaulGrahamEssays_URLs.txt" +URLS_URL="https://raw.githubusercontent.com/NVIDIA/RULER/main/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt" + +mkdir -p "${ESSAYS_DIR}" + +# Download URL list if not exists +if [ ! -f "${URLS_FILE}" ]; then + echo "Downloading URL list..." + curl -fsSL "${URLS_URL}" -o "${URLS_FILE}" +fi + +# Download essays from GitHub URLs +echo -n "Downloading essays" +count=0 +while IFS= read -r url || [ -n "$url" ]; do + if [[ "${url}" == https://github.com*.txt ]]; then + filename=$(basename "${url}") + filepath="${ESSAYS_DIR}/${filename}" + if [ ! -f "${filepath}" ]; then + raw_url="${url/github.com/raw.githubusercontent.com}" + raw_url="${raw_url/\/raw\//\/}" + curl -fsSL "${raw_url}" -o "${filepath}" 2>/dev/null && echo -n "." + count=$((count + 1)) + fi + fi +done < "${URLS_FILE}" +echo " done" + +echo "Downloaded ${count} essays to ${ESSAYS_DIR}" diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py new file mode 100644 index 000000000..70d4da81b --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py @@ -0,0 +1,487 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copied and Adapted from https://github.com/NVIDIA/RULER +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +"""Official RULER dataset generation utilities adapted for Model Optimizer. + +This module contains core logic from the RULER benchmark (https://github.com/NVIDIA/RULER) +adapted to work as a library for calibration purposes. The generation logic closely follows +the official RULER implementation to ensure dataset consistency. + +Key adaptations from official RULER: +- Converted from CLI scripts to library functions +- Works with HuggingFace tokenizers directly +- Removed file I/O, returns data structures +- Simplified for calibration use case (primarily NIAH tasks) +""" + +import logging +import random +import re +import uuid +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +# Needle/Haystack template from official RULER +NEEDLE_TEMPLATE = "One of the special magic {type_needle_v} for {key} is: {value}." + +# Depth positions for needle insertion (from official RULER) +DEPTHS = [ + 0, + 2, + 5, + 7, + 10, + 12, + 15, + 18, + 20, + 23, + 25, + 28, + 30, + 33, + 35, + 38, + 40, + 43, + 45, + 48, + 50, + 53, + 55, + 58, + 60, + 62, + 65, + 67, + 70, + 72, + 75, + 77, + 80, + 82, + 85, + 87, + 90, + 92, + 95, + 97, + 100, +] + +# Data directory for RULER calibration files (downloaded via download_ruler_data.sh) +DATA_DIR = Path(__file__).parent / "data" +RULER_URLS_FILE = DATA_DIR / "PaulGrahamEssays_URLs.txt" +ESSAYS_DIR = DATA_DIR / "essays" + + +def _get_data_dir() -> Path: + """Get data directory for RULER data. + + Returns: + Path to data directory under calibration/ (created if doesn't exist) + """ + data_dir = Path(__file__).parent / "data" + data_dir.mkdir(parents=True, exist_ok=True) + return data_dir + + +def _load_paul_graham_essays_from_files() -> str: + """Load Paul Graham essays from local files. + + Reads essay .txt files from the data/essays directory. + Files must be downloaded first using download_ruler_data.sh. + + Returns: + Combined essay text + + Raises: + RuntimeError: If essays directory doesn't exist or is empty + """ + if not ESSAYS_DIR.exists(): + raise RuntimeError( + f"Essays directory not found at {ESSAYS_DIR}.\n" + "Please run the download script first:\n" + " bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh" + ) + + essay_files = list(ESSAYS_DIR.glob("*.txt")) + if not essay_files: + raise RuntimeError( + f"No essay files found in {ESSAYS_DIR}.\n" + "Please run the download script first:\n" + " bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh" + ) + + logger.info(f"Loading {len(essay_files)} Paul Graham essays from local files...") + + all_essays = [] + for filepath in essay_files: + text = filepath.read_text() + all_essays.append(text) + + combined_text = " ".join(all_essays) + logger.info(f"Loaded {len(all_essays)} essays successfully") + + return combined_text + + +def _load_paul_graham_essays() -> str: + """Load Paul Graham essays from local files. + + Essay files must be downloaded first using download_ruler_data.sh. + + Returns: + Essay text as string + """ + essay_text = _load_paul_graham_essays_from_files() + return re.sub(r"\s+", " ", essay_text) + + +def _load_word_lists(): + """Load word lists for random word generation. + + Returns: + List of words (adj-noun combinations) + """ + import wonderwords + + # Load wonderwords lists (same as official RULER) + nouns = wonderwords.random_word._get_words_from_text_file("nounlist.txt") + adjs = wonderwords.random_word._get_words_from_text_file("adjectivelist.txt") + words = [f"{adj}-{noun}" for adj in adjs for noun in nouns] + words = sorted(set(words)) + return words + + +# Global word list (loaded once) +_WORD_LIST = None + + +def generate_random_number(num_digits=7) -> str: + """Generate random number (from official RULER).""" + lower_bound = 10 ** (num_digits - 1) + upper_bound = 10**num_digits - 1 + return str(random.randint(lower_bound, upper_bound)) + + +def generate_random_word() -> str: + """Generate random word (from official RULER).""" + global _WORD_LIST + if _WORD_LIST is None: + _WORD_LIST = _load_word_lists() + return random.choice(_WORD_LIST) + + +def generate_random_uuid() -> str: + """Generate random UUID (from official RULER).""" + return str(uuid.UUID(int=random.getrandbits(128), version=4)) + + +def generate_random(type_needle: str) -> str: + """Generate random needle value based on type (from official RULER). + + Args: + type_needle: Type of needle ('numbers', 'words', 'uuids') + + Returns: + Random value as string + """ + if type_needle == "numbers": + return generate_random_number() + elif type_needle == "words": + return generate_random_word() + elif type_needle == "uuids": + return generate_random_uuid() + else: + raise ValueError(f"Unknown needle type: {type_needle}") + + +def generate_niah_sample( + num_haystack: int, + tokenizer, + template: str, + answer_prefix: str, + tokens_to_generate: int = 128, + type_haystack: str = "essay", + type_needle_k: str = "words", + type_needle_v: str = "numbers", + num_needle_k: int = 1, + num_needle_v: int = 1, + num_needle_q: int = 1, + random_seed: int = 42, +) -> dict[str, Any]: + """Generate a single NIAH (Needle in a Haystack) sample. + + This function implements the core generation logic from official RULER's niah.py, + adapted to work as a library function. + + Args: + num_haystack: Number of haystack items/words + tokenizer: HuggingFace tokenizer (AutoTokenizer instance) + template: NIAH question template + answer_prefix: Answer prefix template + tokens_to_generate: Expected number of generation tokens + type_haystack: Type of haystack ('essay', 'noise', 'needle') + type_needle_k: Type of needle keys ('numbers', 'words', 'uuids') + type_needle_v: Type of needle values ('numbers', 'words', 'uuids') + num_needle_k: Number of needle keys + num_needle_v: Number of needle values per key + num_needle_q: Number of needles to query + random_seed: Random seed for this sample + + Returns: + Dictionary with 'input', 'outputs', 'length' keys + """ + import nltk + from nltk.tokenize import sent_tokenize + + try: + nltk.data.find("tokenizers/punkt") + except LookupError: + nltk.download("punkt", quiet=True) + nltk.download("punkt_tab", quiet=True) + + if random_seed is not None: + random.seed(random_seed) + + # Ensure num_needle_k >= num_needle_q + num_needle_k = max(num_needle_k, num_needle_q) + + # Generate needles (keys and values) + keys, values, needles = [], [], [] + for _ in range(num_needle_k): + keys.append(generate_random(type_needle_k)) + value = [] + for _ in range(num_needle_v): + value.append(generate_random(type_needle_v)) + needles.append( + NEEDLE_TEMPLATE.format( + type_needle_v=type_needle_v, + key=keys[-1], + value=value[-1], + ) + ) + values.append(value) + + random.shuffle(needles) + + # Generate context based on haystack type + if type_haystack == "essay": + # Load essay corpus + essay_text = _load_paul_graham_essays() + haystack = essay_text.split(" ") + + # Create text from haystack + if num_haystack <= len(haystack): + text = " ".join(haystack[:num_haystack]) + else: + # Repeat haystack as needed + repeats = (num_haystack + len(haystack) - 1) // len(haystack) + text = " ".join((haystack * repeats)[:num_haystack]) + + # Insert needles at various depths + document_sents = sent_tokenize(text.strip()) + insertion_positions = [ + 0, + *sorted( + int(len(document_sents) * (depth / 100)) + for depth in random.sample(DEPTHS, len(needles)) + ), + len(document_sents), + ] + + document_sents_list = [] + for i in range(1, len(insertion_positions)): + last_pos = insertion_positions[i - 1] + next_pos = insertion_positions[i] + document_sents_list.append(" ".join(document_sents[last_pos:next_pos])) + if i - 1 < len(needles): + document_sents_list.append(needles[i - 1]) + + context = " ".join(document_sents_list) + + if type_haystack == "noise": + haystack_sent = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again." + sentences = [haystack_sent] * num_haystack + indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True) + for index, element in zip(indexes, needles): + sentences.insert(index, element) + context = "\n".join(sentences) + + elif type_haystack == "needle": + sentences = [ + NEEDLE_TEMPLATE.format( + type_needle_v=type_needle_v, + key=generate_random(type_needle_k), + value=generate_random(type_needle_v), + ) + for _ in range(num_haystack) + ] + + indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True) + for index, element in zip(indexes, needles): + sentences.insert(index, element) + context = "\n".join(sentences) + + # Generate query and answer + indices = random.sample(range(num_needle_k), num_needle_q) + queries = [keys[i] for i in indices] + answers = [a for i in indices for a in values[i]] + query = ", ".join(queries[:-1]) + ", and " + queries[-1] if len(queries) > 1 else queries[0] + + # Format template (adjust for singular vs plural) + type_needle_v_display = type_needle_v + formatted_template = template + if num_needle_q * num_needle_v == 1: + formatted_template = formatted_template.replace("Some", "A") + formatted_template = formatted_template.replace("are all", "is") + formatted_template = formatted_template.replace("are", "is") + formatted_template = formatted_template.replace("answers", "answer") + type_needle_v_display = type_needle_v[:-1] # remove "s" + + input_text = formatted_template.format( + type_needle_v=type_needle_v_display, + context=context, + query=query, + ) + + # Add answer prefix + formatted_answer_prefix = answer_prefix.format( + type_needle_v=type_needle_v_display, + query=query, + ) + input_text = input_text + formatted_answer_prefix + + # Calculate actual length + if hasattr(tokenizer, "encode"): + # HuggingFace tokenizer + tokens = tokenizer.encode(input_text, add_special_tokens=False) + length = len(tokens) + tokens_to_generate + else: + # Fallback + length = len(input_text.split()) + tokens_to_generate + + return { + "input": input_text, + "outputs": answers, + "length": length, + } + + +def find_optimal_haystack_size( + tokenizer, + max_seq_length: int, + template: str, + answer_prefix: str, + tokens_to_generate: int = 128, + type_haystack: str = "essay", + **kwargs, +) -> int: + """Find optimal haystack size using binary search (from official RULER). + + Args: + tokenizer: HuggingFace tokenizer + max_seq_length: Maximum sequence length + tokens_to_generate: Expected generation tokens + type_haystack: Type of haystack + template: NIAH question template + answer_prefix: Answer prefix template + **kwargs: Additional arguments for generate_niah_sample + + Returns: + Optimal number of haystack items + """ + # Determine incremental step based on haystack type + if type_haystack == "essay": + incremental = 500 + elif type_haystack in ["noise", "needle"]: + incremental = 25 + else: + incremental = 100 + + if max_seq_length < 4096 and type_haystack != "essay": + incremental = 5 + + # Estimate tokens per haystack item + sample = generate_niah_sample( + incremental, + tokenizer, + template, + answer_prefix, + tokens_to_generate, + type_haystack=type_haystack, + **kwargs, + ) + + if hasattr(tokenizer, "encode"): + sample_tokens = len(tokenizer.encode(sample["input"], add_special_tokens=False)) + else: + sample_tokens = len(sample["input"].split()) + + tokens_per_haystack = sample_tokens / incremental + estimated_max = int((max_seq_length / tokens_per_haystack) * 3) + + # Binary search for optimal size + lower_bound = incremental + upper_bound = max(estimated_max, incremental * 2) + optimal_num_haystack = None + + logger.info(f"Estimated {tokens_per_haystack:.1f} tokens per haystack") + logger.info(f"Binary search bounds: {lower_bound} to {upper_bound}") + + while lower_bound <= upper_bound: + mid = (lower_bound + upper_bound) // 2 + sample = generate_niah_sample( + mid, + tokenizer, + template, + answer_prefix, + tokens_to_generate, + type_haystack=type_haystack, + **kwargs, + ) + total_tokens = sample["length"] + + logger.debug(f"Testing haystack size: {mid}, tokens: {total_tokens}/{max_seq_length}") + + if total_tokens <= max_seq_length: + optimal_num_haystack = mid + lower_bound = mid + 1 + else: + upper_bound = mid - 1 + + final_size = optimal_num_haystack if optimal_num_haystack is not None else incremental + logger.info(f"Optimal haystack size: {final_size}") + + return final_size diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index e72dacc94..8271dd4a2 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -77,6 +77,12 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): ), ) + collect_stats: bool = ModeloptField( + default=False, + title="Collect statistics.", + description="Whether to collect sparsity statistics during forward pass for monitoring.", + ) + is_causal: bool = ModeloptField( default=True, title="Causal attention flag.", @@ -87,16 +93,6 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): ), ) - calibration: dict | None = ModeloptField( - default=None, - title="Calibration configuration", - description=( - "Calibration settings for this pattern. " - "If provided, enables automatic threshold calibration. " - "Only one pattern should have calibration enabled." - ), - ) - @field_validator("method") @classmethod def validate_method(cls, v): @@ -150,24 +146,113 @@ def validate_threshold(cls, v): return v -# Pre-defined Sparse Attention Configuration -# Default configuration with block-wise sparsity optimized for Flash Attention -SKIP_SOFTMAX_DEFAULT = { - "sparse_cfg": { - "*attn*": { - "method": "flash_skip_softmax", - "threshold": { - "prefill": 1e-3, # More aggressive during prefill - "decode": 1e-4, # Conservative during decode - }, - "br": 128, # Flash Attention block rows - "bc": 128, # Flash Attention block columns - "backend": "pytorch", # Only pytorch backend supported - "enable": True, - }, - "default": {"enable": False}, - }, -} +class CalibrationConfig(ModeloptBaseConfig): + """Configuration for automatic threshold calibration using RULER dataset. + + Calibration learns a dynamic threshold λ = scale_factor / sequence_length that + achieves target sparsity. Only supports prefill phase (seq_len > 1). + """ + + target_sparse_ratio: float = ModeloptField( + default=0.5, + title="Target sparsity ratio", + description="Target ratio of sparse attention blocks (0.0 to 1.0).", + ) + + samples: int = ModeloptField( + default=24, + title="Calibration samples", + description="Total number of RULER samples for calibration (distributed across length bins).", + ) + + max_seqlen: int = ModeloptField( + default=32768, + title="Maximum sequence length", + description="Maximum sequence length for calibration (length bins auto-generated as powers of 2).", + ) + + num_length_bins: int = ModeloptField( + default=4, + title="Number of length bins", + description="Number of length bins to generate (hidden parameter, default: 4).", + ) + + chunk_size: int = ModeloptField( + default=2048, + title="Chunk size for prefill", + description=( + "Chunk size for chunked prefill to avoid OOM with long sequences. " + "When sequence length exceeds chunk_size, prefill is done in chunks using KV cache. " + "Set to -1 to disable chunking (full prefill)." + ), + ) + + threshold_trials: list[float] | None = ModeloptField( + default=None, + title="Threshold trials", + description=( + "List of threshold values to test during calibration. " + "If None, uses default: [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1]" + ), + ) + + @field_validator("threshold_trials") + @classmethod + def validate_threshold_trials(cls, v): + """Validate threshold_trials are in valid range.""" + if v is not None: + if not isinstance(v, list): + raise ValueError(f"threshold_trials must be a list, got {type(v)}") + if len(v) == 0: + raise ValueError("threshold_trials must not be empty") + for threshold in v: + if not isinstance(threshold, (int, float)): + raise ValueError(f"All threshold_trials must be numbers, got {type(threshold)}") + if threshold <= 0 or threshold >= 1: + raise ValueError( + f"All threshold_trials must be in range (0, 1), got {threshold}" + ) + return v + + @field_validator("target_sparse_ratio") + @classmethod + def validate_target_sparse_ratio(cls, v): + """Validate target sparsity ratio is between 0 and 1.""" + if not 0.0 <= v <= 1.0: + raise ValueError(f"target_sparse_ratio must be between 0.0 and 1.0, got {v}") + return v + + @field_validator("samples") + @classmethod + def validate_samples(cls, v): + """Validate samples is positive.""" + if v <= 0: + raise ValueError(f"samples must be positive, got {v}") + return v + + @field_validator("max_seqlen") + @classmethod + def validate_max_seqlen(cls, v): + """Validate max_seqlen is at least 1024.""" + if v < 1024: + raise ValueError(f"max_seqlen must be >= 1024, got {v}") + return v + + @field_validator("num_length_bins") + @classmethod + def validate_num_length_bins(cls, v): + """Validate num_length_bins is positive.""" + if v <= 0: + raise ValueError(f"num_length_bins must be positive, got {v}") + return v + + @field_validator("chunk_size") + @classmethod + def validate_chunk_size(cls, v): + """Validate chunk_size is positive or -1 (disabled).""" + if v != -1 and v <= 0: + raise ValueError(f"chunk_size must be positive or -1 (disabled), got {v}") + return v class SparseAttentionConfig(ModeloptBaseConfig): @@ -184,8 +269,9 @@ class SparseAttentionConfig(ModeloptBaseConfig): "default": {"enable": False}, }, title="Sparse attention configuration", - description="Pattern-based configuration for sparse attention. Keys are patterns to match module names, " - "values are configuration dicts with parameters like 'threshold', 'enable', and 'calibration'.", + description="Pattern-based configuration for sparse attention. Keys are patterns to match module names " + "(or 'calibration' for global calibration settings), values are configuration dicts with parameters like " + "'threshold', 'enable', etc.", validate_default=True, ) @@ -198,15 +284,17 @@ class SparseAttentionConfig(ModeloptBaseConfig): class FlashSkipSoftmaxConfig(SparseAttentionConfig): """Configuration for Flash Attention-aware softmax skip sparse attention.""" + # Override sparse_cfg with flash_skip_softmax specific defaults # Override sparse_cfg with flash_skip_softmax specific defaults sparse_cfg: SparseAttentionCfgType = ModeloptField( default={ "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "threshold": {"prefill": 1e-3, "decode": 1e-5}, "br": 128, # Flash Attention block rows "bc": 128, # Flash Attention block columns "backend": "pytorch", # Only pytorch backend supported + "collect_stats": True, # Enable statistics collection "enable": True, }, "default": {"enable": False}, @@ -218,8 +306,55 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): ) +# Pre-defined Sparse Attention Configuration +# Default configuration with block-wise sparsity optimized for Flash Attention +SKIP_SOFTMAX_DEFAULT = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": { + "prefill": 1e-3, # More aggressive during prefill + "decode": 1e-4, # Conservative during decode + }, + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", # Only pytorch backend supported + "collect_stats": True, + "enable": True, + }, + "default": {"enable": False}, + }, +} + + +# Configuration with RULER calibration +# Note: threshold field is omitted - calibration determines dynamic threshold λ = a / length +# The calibrated threshold adapts to sequence length for optimal sparsity +SKIP_SOFTMAX_CALIB = { + "sparse_cfg": { + "calibration": { + "target_sparse_ratio": 0.75, + "samples": 16, + "max_seqlen": 16384, + }, + "*attn*": { + "method": "flash_skip_softmax", + "br": 128, + "bc": 128, + "backend": "pytorch", # Only pytorch backend supported + "collect_stats": True, + "enable": True, + }, + "default": {"enable": False}, + }, +} + + __all__ = [ + "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_DEFAULT", + "CalibrationConfig", + "FlashSkipSoftmaxConfig", "FlashSkipSoftmaxConfig", "SparseAttentionAttributeConfig", "SparseAttentionCfgType", diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index ad137e9ee..aa3eb7c29 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -226,6 +226,8 @@ def update_sparse_attention_metadata( if isinstance(module, SparseAttentionModule): module_name = get_unwrapped_name(name, model) + # Save the method configuration that was used + # _method_config already contains the validated config dict # Save the method configuration that was used # _method_config already contains the validated config dict module_state = { @@ -299,3 +301,44 @@ def enable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Cal if matched: module.enable() + + +def _format_threshold(info: dict) -> str: + """Format threshold info for display.""" + t = info.get("type") + if t == "dynamic": + return f"λ={info.get('scale_factor', 0):.2f}" + if t == "static": + v = info.get("value") + if isinstance(v, dict): + return f"threshold={v}" + return f"threshold={v:.2e}" if isinstance(v, float) else f"threshold={v}" + return "threshold=N/A" + + +def print_sparse_attention_summary(model: nn.Module): + """Print summary of sparse attention modules in the model. + + Args: + model: Model with sparse attention applied + """ + sparse_modules = [ + (name, m) for name, m in model.named_modules() if isinstance(m, SparseAttentionModule) + ] + + if not sparse_modules: + print("No sparse attention modules found") + return + + enabled = sum(1 for _, m in sparse_modules if m.is_enabled) + print(f"Sparse attention: {enabled}/{len(sparse_modules)} modules enabled") + + # Group by (method, threshold) + groups: dict[tuple[str, str], int] = {} + for _, module in sparse_modules: + method = getattr(module, "_method", "unknown") + threshold = _format_threshold(module.get_threshold_info()) + groups[(method, threshold)] = groups.get((method, threshold), 0) + 1 + + for (method, threshold), count in sorted(groups.items()): + print(f" {method}: {count} layers, {threshold}") diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index 8801bafb0..458fbeb50 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -20,6 +20,7 @@ """ import math +from typing import Any import numpy as np import torch @@ -44,7 +45,7 @@ def __init__(self, method_config: dict | None = None): """ config = method_config or {} - # Extract configuration (defaults handled by Pydantic) + # Extract configuration self.threshold_config = config["threshold"] self.br = config["br"] self.bc = config["bc"] @@ -52,9 +53,11 @@ def __init__(self, method_config: dict | None = None): self.is_causal = config["is_causal"] # Optional parameters not in Pydantic config - self.enable_correction_factor = config.get("enable_correction_factor", True) self.phase = config.get("phase", None) + # Calibration mode: when True, prevent threshold updates to preserve calibrator's test threshold + self._calibration_mode = False + # Initialize threshold if isinstance(self.threshold_config, dict): self.threshold = self.threshold_config.get( @@ -63,6 +66,10 @@ def __init__(self, method_config: dict | None = None): else: self.threshold = self.threshold_config + def set_calibration_mode(self, enabled: bool): + """Set calibration mode to prevent _update_threshold from modifying the threshold.""" + self._calibration_mode = enabled + def _update_threshold(self, phase: str): """Update threshold based on phase.""" if isinstance(self.threshold_config, dict): @@ -184,18 +191,15 @@ def calc_correction_factor_and_p( element_mask = element_mask[:, :, :seq_q, :seq_k] # Step 8: Calculate sparsity statistics - # Count kept blocks (averaged across batch and heads) - kept_blocks = block_mask.sum().item() / (batch_size * num_heads) - - # Total valid blocks (lower triangle only for causal attention) - # Note: Causal mask pre-applied by attention module, so block_mask naturally - # has zeros in upper triangle. We only count lower triangle for denominator. - total_blocks = ( - num_block_rows * (num_block_rows + 1) // 2 # Causal: N(N+1)/2 - if self.is_causal - else num_block_rows * num_block_cols # Non-causal: N*N - ) - sparsity = 1 - (kept_blocks / total_blocks) + # density = sum(mask) / numel(mask) * N / (N+1) for causal + if self.is_causal: + density = float(block_mask.sum() / block_mask.numel()) * ( + num_block_rows / (num_block_rows + 1) + ) + else: + density = float(block_mask.sum() / block_mask.numel()) + sparsity = 1 - density + total_blocks = num_block_rows * num_block_cols else: # decode blocked_attn, _, num_block_cols, _, padded_seq_k = self._reshape_to_blocks( attn_weights, 1, self.bc @@ -232,14 +236,14 @@ def calc_correction_factor_and_p( element_mask = element_mask.reshape(batch_size, num_heads, 1, padded_seq_k) element_mask = element_mask[:, :, :seq_q, :seq_k] - # Step 7: Calculate statistics - kept_blocks = block_mask.sum().item() / (batch_size * num_heads) + # Step 7: Calculate sparsity statistics + density = float(block_mask.sum() / block_mask.numel()) + sparsity = 1 - density total_blocks = num_block_cols - sparsity = 1 - (kept_blocks / total_blocks) # Create stats dictionary stats = { - "correction_factor": correction_factor if self.enable_correction_factor else 1.0, + "correction_factor": correction_factor, "sparsity": sparsity, "phase": phase, "total_blocks": total_blocks, @@ -278,8 +282,9 @@ def apply_sparsity( # Infer phase from tensor shape phase = self._infer_phase(attention_scores) - # Update threshold for the detected phase - self._update_threshold(phase) + # Update threshold for the detected phase (skip during calibration) + if not self._calibration_mode: + self._update_threshold(phase) # Apply block-wise sparsity sparse_mask, stats = self.calc_correction_factor_and_p(attention_scores, phase) @@ -293,6 +298,34 @@ def apply_sparsity( return query, key, value, sparse_scores + def get_threshold_info(self) -> dict[str, Any]: + """Get threshold information for this method. + + Returns: + Dictionary with threshold configuration and calibration info. + """ + threshold_scale_factor = getattr(self, "threshold_scale_factor", None) + + if threshold_scale_factor is not None: + # Calibrated dynamic threshold + return { + "type": "dynamic", + "scale_factor": threshold_scale_factor, + "formula": "λ / length", + "example_lengths": { + 1024: threshold_scale_factor / 1024, + 2048: threshold_scale_factor / 2048, + 4096: threshold_scale_factor / 4096, + 8192: threshold_scale_factor / 8192, + }, + } + else: + # Static threshold (single value or phase-specific dict) + return { + "type": "static", + "value": self.threshold_config, + } + @property def name(self) -> str: """Method identifier.""" diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index df7b5853b..2d79d9a40 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -18,6 +18,7 @@ import re import warnings from abc import ABC, abstractmethod +from typing import Any import torch @@ -45,6 +46,18 @@ def apply_sparsity( Tuple of (query, key, value, attention_scores) with sparsity applied """ + def get_threshold_info(self) -> dict[str, Any]: + """Get threshold information for display/debugging. + + Returns: + Dictionary with threshold information. Should include: + - 'type': 'static', 'dynamic', or 'none' + - 'value': threshold value (for static) + - 'scale_factor': scale factor (for dynamic) + - Other method-specific info + """ + return {"type": "none", "value": None} + @property @abstractmethod def name(self) -> str: diff --git a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py index 88434e746..b6b1e809f 100644 --- a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py +++ b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py @@ -22,10 +22,12 @@ from modelopt.torch.opt.conversion import apply_mode from modelopt.torch.opt.searcher import ForwardLoop +from .calibration import calibrate_sparse_attention from .config import SparseAttentionConfig from .mode import SparseAttentionModeRegistry __all__ = [ + "calibrate", "sparsify", ] @@ -58,12 +60,36 @@ def sparsify( .. code-block::python config = { - "method": "flash_skip_softmax", "sparse_cfg": { + # Phase-aware thresholds with backend selection "*attention*": { + "method": "flash_skip_softmax", "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "backend": "pytorch", # Only pytorch backend supported + "enable": True, + }, + # Disable for specific layers + "*layer.0*": {"enable": False}, + # Default settings + "default": {"enable": False}, + }, + } + + For automatic threshold calibration using RULER dataset: + + .. code-block::python + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", "backend": "pytorch", "enable": True, + "calibration": { # Enables automatic threshold calibration + "target_sparse_ratio": 0.5, + "samples": 48, + "max_seqlen": 8192, + }, }, "default": {"enable": False}, }, @@ -110,7 +136,7 @@ def forward_loop(model) -> float: from transformers import AutoModelForCausalLM - model = AutoModelForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained(b model_path, attn_implementation="eager", # Required for sparse attention torch_dtype=torch.bfloat16, @@ -126,4 +152,26 @@ def forward_loop(model) -> float: model, mode=[("sparse_attention", config)], registry=SparseAttentionModeRegistry ) + # Calibrate the sparsity ratio of the attention modules + return calibrate(model, config, forward_loop=forward_loop) + + +def calibrate( + model: torch.nn.Module, + config: dict[str, Any] | SparseAttentionConfig, + forward_loop: ForwardLoop | None = None, +) -> torch.nn.Module: + """Calibrates sparse attention thresholds based on target sparsity. + + Args: + model: Model with sparse attention modules + config: Sparse attention configuration with calibration settings + forward_loop: Optional callable that forwards calibration data through the model. + If provided, uses this for calibration data. + If None, will auto-generate RULER dataset for calibration. + + Returns: + The calibrated model with optimized sparse attention thresholds. + """ + calibrate_sparse_attention(model, config, forward_loop=forward_loop) return model diff --git a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py index 16b08bf19..3fbe9b79b 100644 --- a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py +++ b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py @@ -15,6 +15,8 @@ """Extensible sparse attention module.""" +from typing import Any + import torch import torch.nn.functional as F @@ -23,6 +25,7 @@ from .config import SparseAttentionAttributeConfig from .methods import get_sparse_method +from .stats_manager import SparseAttentionStatsManager class SparseAttentionModule(DynamicModule): @@ -103,6 +106,14 @@ def set_from_attribute_config( # Initialize sparse method instance self._init_sparse_method() + # Create stats manager based on config + if self._method_config.get("collect_stats", False): + self._stats_manager = SparseAttentionStatsManager( + module_name="sparse_attention", enabled=True + ) + else: + self._stats_manager = None + def _init_sparse_method(self): """Initialize the sparse method instance.""" method_class = get_sparse_method(self._method) @@ -129,11 +140,22 @@ def get_stats(self) -> dict: Returns: Dictionary with sparsity statistics including 'average_sparsity' if available. - Returns empty dict (statistics collection will be added in calibration PR). + Returns empty dict if stats manager is not enabled. """ - # TODO: Statistics collection will be added in calibration PR + if self._stats_manager is not None and self._stats_manager.enabled: + return self._stats_manager.get_summary() return {} + def get_threshold_info(self) -> dict[str, Any]: + """Get threshold information from the sparse method instance. + + Returns: + Dictionary with threshold information from the sparse method. + """ + if hasattr(self, "_sparse_method_instance") and self._sparse_method_instance is not None: + return self._sparse_method_instance.get_threshold_info() + return {"type": "none", "value": None} + def _setup(self): """Setup called by DynamicModule.""" # Apply default configuration if not yet configured @@ -157,6 +179,10 @@ def forward(self, *args, **kwargs): with context: result = super().forward(*args, **kwargs) + # Collect stats if manager is available + if self._stats_manager is not None and hasattr(self._sparse_method_instance, "_last_stats"): + self._stats_manager.collect(self._sparse_method_instance._last_stats) + return result def _get_sparse_context(self): @@ -178,8 +204,8 @@ def sparse_softmax(input, dim=-1, *args, **kwargs): ) # Use sparse input if modified, otherwise use original - if sparse_input is not None: - return original_softmax(sparse_input, dim, *args, **kwargs) + # if sparse_input is not None: + # return original_softmax(sparse_input, dim, *args, **kwargs) return original_softmax(input, dim, *args, **kwargs) return sparse_softmax diff --git a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py new file mode 100644 index 000000000..9fc57a0b1 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Statistics manager for sparse attention modules.""" + + +class SparseAttentionStatsManager: + """Centralized statistics manager for sparse attention. + + This class is the single source of truth for all statistics collection + in sparse attention modules. It handles both runtime aggregation and + per-sample calibration statistics. + + Design principles: + - Single responsibility: only stats management + - No computation: receives pre-computed stats from methods + - Optional: can be None if stats collection disabled + - Zero overhead when disabled + """ + + def __init__(self, module_name: str, enabled: bool = True): + """Initialize stats manager. + + Args: + module_name: Name of the module this manager is attached to + enabled: Whether stats collection is enabled + """ + self.module_name = module_name + self.enabled = enabled + self.calibration_mode = False + + # Aggregated stats (running totals across all forward passes) + self.aggregated_stats: dict = { + "total_calls": 0, + "total_blocks": 0, + "sparse_blocks": 0, + "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0}, + } + + # Per-sample stats (only populated during calibration) + self.per_sample_stats: list[dict] = [] + + def collect(self, stats: dict): + """Collect statistics from a single forward pass. + + Args: + stats: Dictionary containing statistics from method computation. + Expected keys: sparsity, phase, total_blocks, sparse_blocks, + sample_length (optional) + """ + if not self.enabled: + return + + # Update aggregated stats + self.aggregated_stats["total_calls"] += 1 + self.aggregated_stats["total_blocks"] += stats.get("total_blocks", 0) + self.aggregated_stats["sparse_blocks"] += stats.get("sparse_blocks", 0) + + phase = stats.get("phase", "unknown") + if phase in self.aggregated_stats["phase_counts"]: + self.aggregated_stats["phase_counts"][phase] += 1 + + # In calibration mode, store per-sample stats + if self.calibration_mode: + self.per_sample_stats.append( + { + "module": self.module_name, + "sparsity": stats.get("sparsity", 0.0), + "sample_length": stats.get("sample_length", 0), + "phase": phase, + } + ) + + def get_summary(self) -> dict: + """Get aggregated statistics summary. + + Returns: + Dictionary with module name, total calls, average sparsity, + and phase distribution. + """ + total_blocks = self.aggregated_stats["total_blocks"] + if total_blocks > 0: + avg_sparsity = self.aggregated_stats["sparse_blocks"] / total_blocks + else: + avg_sparsity = 0.0 + + return { + "module": self.module_name, + "total_calls": self.aggregated_stats["total_calls"], + "average_sparsity": avg_sparsity, + "phase_distribution": self.aggregated_stats["phase_counts"].copy(), + } + + def set_calibration_mode(self, enabled: bool, reset_history: bool = True): + """Enable or disable calibration mode. + + In calibration mode, per-sample statistics are stored for detailed + analysis. Otherwise, only aggregated stats are maintained. + + Args: + enabled: Whether to enable calibration mode + reset_history: Whether to clear per_sample_stats when enabling + """ + self.calibration_mode = enabled + if enabled and reset_history: + self.per_sample_stats = [] + + def reset(self): + """Reset all statistics to initial state.""" + self.aggregated_stats = { + "total_calls": 0, + "total_blocks": 0, + "sparse_blocks": 0, + "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0}, + } + self.per_sample_stats = [] + + def get_calibration_stats(self) -> list[dict]: + """Get per-sample calibration statistics. + + Returns: + List of per-sample statistics dictionaries. + Empty list if not in calibration mode. + """ + return self.per_sample_stats diff --git a/setup.py b/setup.py index dd124e10e..cbb3e5eca 100644 --- a/setup.py +++ b/setup.py @@ -83,6 +83,8 @@ "torch-geometric", "tox>4.18", "tox-current-env>=0.0.12", + "nltk", + "wonderwords", ], # docs "dev-docs": [ diff --git a/tests/_test_utils/torch_sparsity/sparse_attention_common.py b/tests/_test_utils/torch_sparsity/sparse_attention_common.py index 7724908b0..5ed079966 100644 --- a/tests/_test_utils/torch_sparsity/sparse_attention_common.py +++ b/tests/_test_utils/torch_sparsity/sparse_attention_common.py @@ -153,13 +153,15 @@ def forward_loop(model): with torch.no_grad(): for batch in calib_data: output = model(batch) - assert not torch.isnan(output).any(), "NaN in output" - assert output is not None, "Output is None" + assert not torch.isnan(output).any(), ( + f"NaN detected in output for batch shape {batch.shape}" + ) + assert output is not None, f"Output is None for batch shape {batch.shape}" return model -def save_restore_test(model_cls, device, sparse_config): +def save_restore_test(model_cls, device, sparse_config, atol=1e-6): """Test save and restore of sparse attention state. Args: @@ -190,6 +192,6 @@ def save_restore_test(model_cls, device, sparse_config): output_sparse = model_sparse(test_input) output_restored = model_restored(test_input) - assert torch.allclose(output_sparse, output_restored, atol=1e-6), ( + assert torch.allclose(output_sparse, output_restored, atol), ( "Restored model output doesn't match original" ) diff --git a/tests/examples/llm_eval/test_llm_eval.py b/tests/examples/llm_eval/test_llm_eval.py index 0abf78b53..88d29dedc 100644 --- a/tests/examples/llm_eval/test_llm_eval.py +++ b/tests/examples/llm_eval/test_llm_eval.py @@ -36,3 +36,20 @@ def test_llama_eval_fp8(): finally: # Force kill llm-serve if it's still running subprocess.run(["pkill", "-f", "llm-serve"], check=False) + + +def test_llama_eval_sparse_attention(tiny_llama_path): + """Test sparse attention with llm_eval integration.""" + try: + # Test with default sparse attention config (no quantization) + run_llm_ptq_command( + model=tiny_llama_path, + quant="none", # No quantization, only sparse attention + tasks="lm_eval", + lm_eval_tasks="hellaswag", + lm_eval_limit=0.05, # Small limit for fast test + sparse_cfg="SKIP_SOFTMAX_DEFAULT", + batch=4, + ) + finally: + subprocess.run(["pkill", "-f", "llm-serve"], check=False) diff --git a/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py index b70dfab35..9f1cb8125 100644 --- a/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py +++ b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py @@ -34,7 +34,6 @@ def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax", } ) kwargs.setdefault("seq_len", 128) - kwargs.setdefault("num_samples", 1) kwargs.setdefault("max_new_tokens", 16) cmd_parts = extend_cmd_parts(["python", "hf_sa.py"], **kwargs) @@ -43,8 +42,10 @@ def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax", @pytest.mark.parametrize("method", ["skip_softmax"]) def test_attention_sparsity(tiny_llama_path, tmp_path, method): - """Test sparse attention with TinyLlama.""" + """Test sparse attention with TinyLlama (with and without calibration).""" run_attention_sparsity_command( model=tiny_llama_path, method=method, + seq_len=128, + max_new_tokens=10, ) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py new file mode 100644 index 000000000..913dc24a0 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py @@ -0,0 +1,388 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for sparse attention calibration.""" + +import pytest +import torch +from _test_utils.torch_sparsity.sparse_attention_common import SimpleTransformerEncoderLayer + +import modelopt.torch.opt as mto +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.calibration import RulerDatasetBuilder +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + +# Skip all tests if no GPU available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required") + + +class TestRulerDatasetBuilderGPU: + """Test RULER dataset generation with real tokenizers on GPU.""" + + def test_ruler_generation_with_real_tokenizer(self): + """Test RULER generation with GPT2 tokenizer.""" + builder = RulerDatasetBuilder( + samples=6, # Need at least 6 samples (1 per task) + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should generate 6 samples (1 per task) + assert len(dataset) == 6 + + # All samples should have valid structure + for sample in dataset: + assert "input" in sample + assert "length" in sample + assert sample["length"] > 0 + + def test_generated_length_accuracy(self): + """Test that generated token counts are accurate.""" + builder = RulerDatasetBuilder( + samples=3, + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Check that lengths are within reasonable range of target + for sample in dataset: + # RULER aims for 70-90% of target for context + assert 700 < sample["length"] < 1400 + + def test_multiple_subtasks(self): + """Test generation with multiple RULER subtasks.""" + builder = RulerDatasetBuilder( + samples=12, # Need at least 6 for 1 per task, use 12 for 2 per task + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Check task distribution (should have multiple tasks from RULER_TASKS) + tasks_found = {s["task"] for s in dataset} + assert len(tasks_found) >= 2 # At least 2 different tasks + + def test_large_context_lengths(self): + """Test with larger context lengths.""" + builder = RulerDatasetBuilder( + samples=24, # 4 lengths * 6 tasks = need 24 for 1 per (length, task) + max_seqlen=8192, # Generates: [8192, 4096, 2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + assert len(dataset) == 24 + + # Verify we have different lengths + lengths = [s["length"] for s in dataset] + # Should have variety of lengths across the bins + assert len(set(lengths)) > 1 # At least 2 different target lengths used + + +class TestCalibrationGPU: + """Test calibration with real models on GPU.""" + + @pytest.fixture + def simple_model(self): + """Create simple attention model for testing.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + return model + + def test_calibration_simple_model(self, simple_model): + """Test calibration with simple attention model.""" + model = simple_model + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "br": 64, + "bc": 64, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + # Simple forward loop for calibration + pass + + # Apply sparse attention with calibration + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Verify sparse attention modules exist + sparse_modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + assert len(sparse_modules) > 0 + + # Verify calibration was applied + for module in sparse_modules: + method = module._sparse_method_instance + # Check if calibrated threshold scale factor is set + if hasattr(method, "threshold_scale_factor") and method.threshold_scale_factor: + assert method.threshold_scale_factor > 0 + + def test_calibration_pytorch_backend(self, simple_model): + """Test calibration with pytorch backend.""" + model = simple_model + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Check backend is set correctly + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + method = module._sparse_method_instance + assert hasattr(method, "backend") + assert method.backend == "pytorch" + + def test_simplified_calibration(self, simple_model): + """Test simplified calibration (prefill phase only).""" + model = simple_model + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Should complete without errors + assert sparse_model is not None + + def test_calibration_persistence(self, simple_model): + """Test save and restore of calibrated model.""" + model = simple_model + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate model + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Save modelopt state + modelopt_state = mto.modelopt_state(sparse_model) + + # Create new model and restore + model_restored = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + + restored = mto.restore_from_modelopt_state(model_restored, modelopt_state) + + # Check that sparse attention is restored + has_sparse = any(isinstance(m, SparseAttentionModule) for m in restored.modules()) + assert has_sparse + + +class TestCalibrationEndToEnd: + """Integration tests with inference.""" + + @pytest.fixture + def simple_model_setup(self): + """Setup simple model.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + return model + + def test_calibrated_model_inference(self, simple_model_setup): + """Test inference with calibrated model.""" + model = simple_model_setup + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate model + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Test inference + test_input = SimpleTransformerEncoderLayer.get_input(d_model=256, seq_len=10).cuda() + + sparse_model.eval() + with torch.no_grad(): + output = sparse_model(test_input) + + # Check output is valid + assert output is not None + assert not torch.isnan(output).any() + + def test_calibrated_vs_fixed_threshold(self, simple_model_setup): + """Compare calibrated vs fixed threshold models.""" + # Config with calibration + config_calibrated = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + # Config with fixed threshold (no calibration) + config_fixed = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + } + }, + } + + def forward_loop(model): + pass + + # Test both can be created + model_calibrated = sparsify( + SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda(), + config_calibrated, + forward_loop=forward_loop, + ) + + model_fixed = sparsify( + SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda(), + config_fixed, + ) + + # Both should work for inference + test_input = SimpleTransformerEncoderLayer.get_input(d_model=256, seq_len=10).cuda() + + with torch.no_grad(): + output_calibrated = model_calibrated(test_input) + output_fixed = model_fixed(test_input) + + assert output_calibrated is not None + assert output_fixed is not None + + def test_memory_usage(self, simple_model_setup): + """Test that calibration doesn't cause memory issues.""" + model = simple_model_setup + + # Clear cache before test + torch.cuda.empty_cache() + initial_memory = torch.cuda.memory_allocated() + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate + sparsify(model, config, forward_loop=forward_loop) + + # Check memory didn't explode + final_memory = torch.cuda.memory_allocated() + memory_increase = final_memory - initial_memory + + # Memory should be reasonable (not more than 2GB increase) + assert memory_increase < 2 * 1024**3 # 2GB + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py new file mode 100644 index 000000000..4558ca22b --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py @@ -0,0 +1,623 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for sparse attention calibration.""" + +import pytest + +pytest.importorskip("transformers") + +import numpy as np +from _test_utils.torch_sparsity.sparse_attention_common import ( + SimpleAttentionModel, + SimpleTransformerEncoder, +) +from pydantic import ValidationError + +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.calibration import ( + DynamicThresholdCalibrator, + RulerDatasetBuilder, +) +from modelopt.torch.sparsity.attention_sparsity.calibration.calibrate import ( + _extract_calibration_config, + calibrate_sparse_attention, + create_calibration_forward_loop, +) +from modelopt.torch.sparsity.attention_sparsity.calibration.dataset import _generate_target_lengths +from modelopt.torch.sparsity.attention_sparsity.config import CalibrationConfig +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +class TestLengthGeneration: + """Test automatic target length generation.""" + + def test_generate_target_lengths_default(self): + """Test default 4 bins generation.""" + lengths = _generate_target_lengths(32768, num_length_bins=4) + assert lengths == [32768, 16384, 8192, 4096] + + def test_generate_target_lengths_stops_at_minimum(self): + """Test generation stops at minimum threshold.""" + lengths = _generate_target_lengths(2048, num_length_bins=4) + assert lengths == [2048, 1024] # Stops at 1024 + + def test_generate_target_lengths_fewer_bins(self): + """Test with fewer bins.""" + lengths = _generate_target_lengths(16384, num_length_bins=2) + assert lengths == [16384, 8192] + + def test_generate_target_lengths_more_bins(self): + """Test with more bins.""" + lengths = _generate_target_lengths(65536, num_length_bins=6) + assert lengths == [65536, 32768, 16384, 8192, 4096, 2048] + + def test_generate_target_lengths_exactly_minimum(self): + """Test when max_seqlen equals minimum.""" + lengths = _generate_target_lengths(1024, num_length_bins=4) + assert lengths == [1024] + + +class TestRulerDatasetBuilder: + """Test RULER dataset generation without requiring real tokenizers.""" + + def test_builder_initialization(self): + """Test that builder initializes correctly.""" + builder = RulerDatasetBuilder( + samples=12, + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + assert builder.total_samples == 12 + assert builder.max_seqlen == 2048 + assert builder.target_lengths == [2048, 1024] + assert builder.samples_per_length == [6, 6] # Evenly distributed + assert len(builder.subtasks) == 6 # All RULER_TASKS + assert builder.seed == 42 + + def test_builder_initialization_invalid_config(self): + """Test that builder raises error for invalid inputs.""" + # Test invalid samples + with pytest.raises(ValueError, match="samples must be positive"): + RulerDatasetBuilder( + samples=0, + max_seqlen=2048, + tokenizer_name_or_path="gpt2", + ) + + # Test max_seqlen below minimum + with pytest.raises(ValueError, match="max_seqlen must be >= 1024"): + RulerDatasetBuilder( + samples=4, + max_seqlen=512, # Below minimum + tokenizer_name_or_path="gpt2", + ) + + def test_dataset_generation_minimal(self): + """Test generating small dataset.""" + builder = RulerDatasetBuilder( + samples=12, # 6 tasks x 2 lengths = need 12 for 1 per task per length + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should generate 12 samples (6 tasks x 1 sample per task x 2 lengths) + assert len(dataset) == 12 + assert all(isinstance(sample, dict) for sample in dataset) + + def test_dataset_structure(self): + """Test that dataset has correct structure.""" + builder = RulerDatasetBuilder( + samples=6, # Need at least 6 (1 per task) + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + sample = dataset[0] + + # Check required fields + assert "input" in sample + assert "length" in sample + assert "task" in sample + assert "target_length" in sample + + # Check field types + assert isinstance(sample["input"], str) + assert isinstance(sample["length"], int) + assert isinstance(sample["task"], str) + assert sample["length"] > 0 + + def test_sample_distribution(self): + """Test that samples are distributed across lengths and subtasks.""" + builder = RulerDatasetBuilder( + samples=24, # 6 tasks x 2 lengths x 2 samples = 24 + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should have 24 samples (12 per length, 2 per task) + assert len(dataset) == 24 + + # Check task distribution (should have variety from all RULER_TASKS) + tasks = [s["task"] for s in dataset] + # Verify we have all 6 tasks represented + assert len(set(tasks)) == 6 + + def test_length_targeting(self): + """Test that generated lengths are close to targets.""" + builder = RulerDatasetBuilder( + samples=6, # 1 per task + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Lengths should be within reasonable range of target + # RULER aims for 70-90% of target length for context + for sample in dataset: + assert 700 < sample["length"] < 1400 # Reasonable range around 1024 + + def test_uneven_sample_distribution(self): + """Test that samples are distributed evenly (remainder dropped).""" + builder = RulerDatasetBuilder( + samples=50, # 50 samples across 4 lengths + max_seqlen=8192, # Generates: [8192, 4096, 2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + # Even distribution: 50//4 = 12 per length + assert builder.total_samples == 50 + assert builder.target_lengths == [8192, 4096, 2048, 1024] + assert builder.samples_per_length == [12, 12, 12, 12] + assert sum(builder.samples_per_length) == 48 # 2 samples dropped (remainder) + + # Actual generated samples: 12//6=2 per task, 4 lengths, 6 tasks + # Total: 2 x 6 x 4 = 48 + dataset = builder.build_calibration_dataset() + assert len(dataset) == 48 + + +class TestDynamicThresholdCalibrator: + """Test calibration algorithm correctness.""" + + def test_calibrator_initialization(self): + """Test that calibrator initializes correctly.""" + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=0.5, + threshold_trials=[1e-4, 1e-3, 1e-2], + ) + + assert calibrator.target_sparse_ratio == 0.5 + assert len(calibrator.threshold_trials) == 3 + + def test_calibrator_default_threshold_trials(self): + """Test that calibrator has default threshold trials.""" + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=0.5, + ) + + # Should have default threshold trials + assert calibrator.threshold_trials is not None + assert len(calibrator.threshold_trials) == 12 + # Check they are positive and in valid range + trials = calibrator.threshold_trials + assert all(0 < t < 1 for t in trials) + + def test_regression_calculation_synthetic(self): + """Test 'a' parameter calculation with synthetic data.""" + # Create synthetic optimal pairs + # If threshold = a / length, then with perfect data: + # length=1000, threshold=10 => a=10000 + # length=2000, threshold=5 => a=10000 + optimal_pairs = [ + {"length": 1000, "optimal_threshold": 10.0, "achieved_sparsity": 0.5}, + {"length": 2000, "optimal_threshold": 5.0, "achieved_sparsity": 0.5}, + {"length": 4000, "optimal_threshold": 2.5, "achieved_sparsity": 0.5}, + ] + + # Manual regression calculation + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + x = 1.0 / lengths + y = thresholds + + # Calculate 'a' using least squares + a_parameter = np.sum(x * y) / np.sum(x**2) + + # Should be close to 10000 + assert 9500 < a_parameter < 10500 + + # Test individual 'a' values + a_per_sample = y * lengths + assert np.allclose(a_per_sample, 10000, rtol=0.05) + + def test_multiple_samples_different_lengths(self): + """Test regression with varied lengths.""" + # More realistic scenario with some variance + optimal_pairs = [ + {"length": 500, "optimal_threshold": 20.0, "achieved_sparsity": 0.5}, + {"length": 1000, "optimal_threshold": 10.5, "achieved_sparsity": 0.51}, + {"length": 2000, "optimal_threshold": 5.2, "achieved_sparsity": 0.49}, + {"length": 4000, "optimal_threshold": 2.4, "achieved_sparsity": 0.50}, + ] + + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + x = 1.0 / lengths + y = thresholds + + a_parameter = np.sum(x * y) / np.sum(x**2) + + # Should still be around 10000 with some tolerance for variance + assert 9000 < a_parameter < 11000 + + def test_r_squared_calculation(self): + """Test R-squared calculation for regression quality.""" + # Perfect fit data + optimal_pairs = [ + {"length": 1000, "optimal_threshold": 10.0}, + {"length": 2000, "optimal_threshold": 5.0}, + {"length": 4000, "optimal_threshold": 2.5}, + ] + + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + x = 1.0 / lengths + y = thresholds + + a_parameter = np.sum(x * y) / np.sum(x**2) + + # Calculate R-squared + y_pred = a_parameter * x + ss_res = np.sum((y - y_pred) ** 2) + ss_tot = np.sum((y - np.mean(y)) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 + + # Perfect fit should have R^2 close to 1 + assert r_squared > 0.99 + + +class TestCalibrationIntegration: + """Test end-to-end calibration without GPU.""" + + def test_calibration_disabled(self): + """Test that no calibration occurs when disabled.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + # No forward_loop needed when calibration disabled + sparse_model = sparsify(model, config) + + # Check that sparse attention is applied but not calibrated + has_sparse = any(isinstance(m, SparseAttentionModule) for m in sparse_model.modules()) + assert has_sparse + + # Check that no calibration is set + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + method = module._sparse_method_instance + assert not getattr(method, "threshold_scale_factor", None) + + def test_sparsify_with_calibration_requires_forward_loop(self): + """Test that calibration requires forward_loop or proper model config.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + "max_seqlen": 1024, + }, + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "br": 64, + "bc": 64, + "enable": True, + }, + }, + } + + # Without forward_loop and without model.config._name_or_path, should raise ValueError + with pytest.raises(ValueError, match="Could not load tokenizer"): + sparsify(model, config, forward_loop=None) + + def test_multiple_sparse_modules(self): + """Test that calibration handles multiple attention layers.""" + model = SimpleTransformerEncoder() + + config = { + "sparse_cfg": {"*attn*": {"threshold": 1e-3, "br": 64, "bc": 64, "enable": True}}, + } + + sparse_model = sparsify(model, config) + + # Count sparse attention modules + sparse_count = sum( + 1 for m in sparse_model.modules() if isinstance(m, SparseAttentionModule) + ) + + # Should have 2 sparse attention modules + assert sparse_count == 2 + + def test_calibration_config_validation(self): + """Test CalibrationConfig validation.""" + # Valid config + config = CalibrationConfig( + target_sparse_ratio=0.5, + samples=48, + max_seqlen=32768, + ) + assert config.target_sparse_ratio == 0.5 + assert config.samples == 48 + assert config.max_seqlen == 32768 + + # Invalid target_sparse_ratio (> 1.0) + with pytest.raises(ValueError, match="target_sparse_ratio must be between"): + CalibrationConfig(target_sparse_ratio=1.5, samples=48, max_seqlen=32768) + + # Invalid target_sparse_ratio (< 0.0) + with pytest.raises(ValueError, match="target_sparse_ratio must be between"): + CalibrationConfig(target_sparse_ratio=-0.1, samples=48, max_seqlen=32768) + + # Invalid samples + with pytest.raises(ValueError, match="samples must be positive"): + CalibrationConfig(target_sparse_ratio=0.5, samples=0, max_seqlen=32768) + + # Invalid max_seqlen + with pytest.raises(ValueError, match="max_seqlen must be >= 1024"): + CalibrationConfig(target_sparse_ratio=0.5, samples=48, max_seqlen=512) + + def test_threshold_trials_validation(self): + """Test threshold_trials validation.""" + # Valid custom threshold_trials + config = CalibrationConfig( + target_sparse_ratio=0.5, + threshold_trials=[1e-5, 1e-4, 1e-3, 1e-2], + ) + assert config.threshold_trials == [1e-5, 1e-4, 1e-3, 1e-2] + + # None (use defaults) + config_default = CalibrationConfig(target_sparse_ratio=0.5) + assert config_default.threshold_trials is None + + # Invalid: empty list + with pytest.raises(ValueError, match="threshold_trials must not be empty"): + CalibrationConfig(threshold_trials=[]) + + # Invalid: threshold out of range (>= 1.0) + with pytest.raises(ValueError, match="must be in range"): + CalibrationConfig(threshold_trials=[1e-4, 1.0]) + + # Invalid: threshold out of range (<= 0) + with pytest.raises(ValueError, match="must be in range"): + CalibrationConfig(threshold_trials=[1e-4, 0]) + + # Invalid: not a list (Pydantic raises ValidationError, not ValueError) + with pytest.raises(ValidationError, match="Input should be a valid list"): + CalibrationConfig(threshold_trials=1e-4) + + +class TestDynamicThresholdCalibratorMethods: + """Test individual methods of DynamicThresholdCalibrator.""" + + def test_set_threshold(self): + """Test _set_threshold method.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + sparse_model = sparsify(model, config) + + # Get sparse modules + modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + assert len(modules) > 0 + + # Create calibrator and set threshold + calibrator = DynamicThresholdCalibrator(target_sparse_ratio=0.5) + calibrator._set_threshold(modules, 0.05) + + # Verify threshold was set + for module in modules: + assert module._sparse_method_instance.threshold == 0.05 + + def test_enable_disable_calibration_mode(self): + """Test _enable_calibration_mode and _disable_calibration_mode.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + sparse_model = sparsify(model, config) + + modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + + calibrator = DynamicThresholdCalibrator(target_sparse_ratio=0.5) + + # Enable calibration mode + calibrator._enable_calibration_mode(modules) + + for module in modules: + assert module._stats_manager is not None + assert module._stats_manager.enabled is True + assert module._stats_manager.calibration_mode is True + assert module._sparse_method_instance._calibration_mode is True + + # Disable calibration mode + calibrator._disable_calibration_mode(modules) + + for module in modules: + assert module._stats_manager.calibration_mode is False + assert module._sparse_method_instance._calibration_mode is False + + def test_extract_calibration_stats_no_stats(self): + """Test _extract_calibration_stats when no stats collected.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + sparse_model = sparsify(model, config) + + modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + + calibrator = DynamicThresholdCalibrator(target_sparse_ratio=0.5) + + # Extract stats without running any forward passes + stats = calibrator._extract_calibration_stats(modules) + + # Should return empty list + assert stats == [] + + def test_calibrator_with_single_sample(self): + """Test calibrator edge case with only one sample.""" + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=0.5, + threshold_trials=[0.001, 0.01, 0.1], + ) + + # Even with one sample, regression should work + assert calibrator.target_sparse_ratio == 0.5 + assert len(calibrator.threshold_trials) == 3 + + +class TestCalibrateFunction: + """Test calibrate_sparse_attention function.""" + + def test_calibrate_no_config(self): + """Test calibration when config has no calibration section.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + # Config without calibration + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + # Should return empty dict when no calibration config + result = calibrate_sparse_attention(model, config) + + assert result == {} + + def test_extract_calibration_config(self): + """Test _extract_calibration_config function.""" + # Config with calibration + config = { + "sparse_cfg": { + "calibration": { + "target_sparse_ratio": 0.3, + "samples": 12, + "max_seqlen": 2048, + }, + "*attn*": { + "method": "flash_skip_softmax", + }, + }, + } + + calib_config = _extract_calibration_config(config) + + assert calib_config is not None + assert calib_config.target_sparse_ratio == 0.3 + assert calib_config.samples == 12 + assert calib_config.max_seqlen == 2048 + + def test_extract_calibration_config_none(self): + """Test _extract_calibration_config when no calibration.""" + # Config without calibration + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + } + }, + } + + calib_config = _extract_calibration_config(config) + + assert calib_config is None + + def test_create_calibration_forward_loop(self): + """Test create_calibration_forward_loop function.""" + calibration_data = [ + {"input": "This is a test sample.", "length": 512}, + {"input": "Another test sample.", "length": 1024}, + ] + + forward_loop = create_calibration_forward_loop( + calibration_data=calibration_data, + tokenizer_name_or_path="gpt2", + ) + + # Should return a callable + assert callable(forward_loop) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py new file mode 100644 index 000000000..1824825f9 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test sparse attention configuration validation.""" + +import pytest +from pydantic import ValidationError + +pytest.importorskip("transformers") + +from modelopt.torch.sparsity.attention_sparsity.config import ( + SKIP_SOFTMAX_DEFAULT, + FlashSkipSoftmaxConfig, + SparseAttentionAttributeConfig, + SparseAttentionConfig, +) + + +class TestSparseAttentionAttributeConfig: + """Test SparseAttentionAttributeConfig validators.""" + + def test_valid_config(self): + """Test creating valid config.""" + config = SparseAttentionAttributeConfig( + method="flash_skip_softmax", + threshold=1e-4, + br=128, + bc=128, + enable=True, + ) + assert config.method == "flash_skip_softmax" + assert config.threshold == 1e-4 + assert config.br == 128 + assert config.bc == 128 + + def test_method_validation(self): + """Test method must be string.""" + with pytest.raises(ValidationError, match="Input should be a valid string"): + SparseAttentionAttributeConfig(method=123) + + def test_block_size_validation_negative(self): + """Test block sizes must be positive.""" + with pytest.raises(ValidationError, match="Block size must be positive"): + SparseAttentionAttributeConfig(br=-1) + + with pytest.raises(ValidationError, match="Block size must be positive"): + SparseAttentionAttributeConfig(bc=0) + + def test_block_size_validation_large(self): + """Test that large block sizes are accepted.""" + # Large block sizes are allowed (warning removed for simplicity) + config = SparseAttentionAttributeConfig(br=2048) + assert config.br == 2048 + + def test_threshold_validation_range(self): + """Test threshold must be in range (0, 1).""" + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=0) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=-0.1) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=1.0) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=1.5) + + def test_threshold_validation_dict(self): + """Test threshold dict validation.""" + # Valid phase-aware threshold + config = SparseAttentionAttributeConfig(threshold={"prefill": 1e-3, "decode": 1e-5}) + assert config.threshold == {"prefill": 1e-3, "decode": 1e-5} + + # Invalid phase key + with pytest.raises(ValidationError, match="Invalid threshold phases"): + SparseAttentionAttributeConfig(threshold={"invalid_phase": 1e-3}) + + # Invalid threshold value in dict (negative) + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": -1e-3}) + + # Invalid threshold value in dict (>= 1.0) + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": 1.0}) + + def test_threshold_validation_type(self): + """Test threshold type validation.""" + with pytest.raises(ValidationError, match="Input should be a valid"): + SparseAttentionAttributeConfig(threshold="invalid") + + +class TestSparseAttentionConfig: + """Test SparseAttentionConfig.""" + + def test_default_config(self): + """Test default configuration.""" + config = SparseAttentionConfig() + assert "sparse_cfg" in config.model_dump() + # Check default pattern has method + assert config.sparse_cfg["*attention*"]["method"] == "flash_skip_softmax" + + def test_predefined_config(self): + """Test pre-defined configuration.""" + assert "sparse_cfg" in SKIP_SOFTMAX_DEFAULT + assert "method" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"]["*attn*"] + assert "*attn*" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"] + + +class TestFlashSkipSoftmaxConfig: + """Test FlashSkipSoftmaxConfig.""" + + def test_default_values(self): + """Test default values for flash_skip_softmax config.""" + config = FlashSkipSoftmaxConfig() + assert "*attention*" in config.sparse_cfg + assert config.sparse_cfg["*attention*"]["method"] == "flash_skip_softmax" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py index d93e929dc..1ba86c143 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -31,6 +31,7 @@ from modelopt.torch.sparsity.attention_sparsity.conversion import ( disable_sparse_attention, enable_sparse_attention, + print_sparse_attention_summary, ) from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule @@ -170,6 +171,19 @@ def test_disable_enable_functions(self): if isinstance(module, SparseAttentionModule): assert module.is_enabled + def test_print_sparse_attention_summary(self, capsys): + """Test print_sparse_attention_summary function.""" + model = SimpleAttentionModel() + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Print summary + print_sparse_attention_summary(model) + + # Capture output + captured = capsys.readouterr() + assert "Total sparse attention modules:" in captured.out + assert "Enabled:" in captured.out + def test_restore_sparse_attention_model(self): """Test save/restore via modelopt_state.""" # Create and sparsify original model @@ -192,3 +206,100 @@ def test_restore_sparse_attention_model(self): if isinstance(module, SparseAttentionModule): assert hasattr(module, "_method") assert module._method == "flash_skip_softmax" + + +class TestSparseAttentionModuleMethods: + """Test SparseAttentionModule methods.""" + + def test_get_stats_with_stats_manager(self): + """Test get_stats() when stats manager exists and is enabled.""" + model = SimpleAttentionModel() + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "collect_stats": True, # Enable stats collection + "enable": True, + } + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Find sparse module + sparse_module = None + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + sparse_module = module + break + + assert sparse_module is not None + assert sparse_module._stats_manager is not None + + # Get stats (should return summary) + stats = sparse_module.get_stats() + + assert isinstance(stats, dict) + assert "module" in stats + assert "total_calls" in stats + assert "average_sparsity" in stats + + def test_get_stats_without_stats_manager(self): + """Test get_stats() when stats manager is None.""" + model = SimpleAttentionModel() + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "collect_stats": False, # Disable stats collection + "enable": True, + } + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Find sparse module + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + # Stats manager should be None + assert module._stats_manager is None + + # get_stats should return empty dict + stats = module.get_stats() + assert stats == {} + break + + def test_get_threshold_info(self): + """Test get_threshold_info() method.""" + model = SimpleAttentionModel() + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.005, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Find sparse module and test threshold info + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + info = module.get_threshold_info() + + assert isinstance(info, dict) + assert "type" in info + assert info["type"] == "static" + assert info["value"] == 0.005 + break diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py b/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py new file mode 100644 index 000000000..02188e97a --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py @@ -0,0 +1,334 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for SparseAttentionStatsManager.""" + +import pytest + +pytest.importorskip("transformers") + +from modelopt.torch.sparsity.attention_sparsity.stats_manager import SparseAttentionStatsManager + + +class TestStatsManagerInitialization: + """Test stats manager initialization.""" + + def test_initialization_defaults(self): + """Test default initialization.""" + manager = SparseAttentionStatsManager(module_name="test_module") + + assert manager.module_name == "test_module" + assert manager.enabled is True + assert manager.calibration_mode is False + assert manager.aggregated_stats["total_calls"] == 0 + assert manager.aggregated_stats["total_blocks"] == 0 + assert manager.aggregated_stats["sparse_blocks"] == 0 + assert manager.per_sample_stats == [] + + def test_initialization_disabled(self): + """Test initialization with disabled stats.""" + manager = SparseAttentionStatsManager(module_name="test_module", enabled=False) + + assert manager.enabled is False + assert manager.calibration_mode is False + + def test_initialization_custom_name(self): + """Test initialization with custom module name.""" + manager = SparseAttentionStatsManager(module_name="custom.attention.module") + + assert manager.module_name == "custom.attention.module" + + +class TestStatsCollection: + """Test statistics collection functionality.""" + + def test_collect_stats_enabled(self): + """Test collecting stats when enabled.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + + manager.collect(stats) + + assert manager.aggregated_stats["total_calls"] == 1 + assert manager.aggregated_stats["total_blocks"] == 100 + assert manager.aggregated_stats["sparse_blocks"] == 50 + assert manager.aggregated_stats["phase_counts"]["prefill"] == 1 + assert manager.aggregated_stats["phase_counts"]["decode"] == 0 + + def test_collect_stats_disabled(self): + """Test that collect() is no-op when disabled.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=False) + + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + } + + manager.collect(stats) + + # Should remain at initial values + assert manager.aggregated_stats["total_calls"] == 0 + assert manager.aggregated_stats["total_blocks"] == 0 + assert manager.aggregated_stats["sparse_blocks"] == 0 + + def test_collect_multiple_calls(self): + """Test accumulation over multiple collect calls.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect multiple times + for i in range(5): + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + } + manager.collect(stats) + + assert manager.aggregated_stats["total_calls"] == 5 + assert manager.aggregated_stats["total_blocks"] == 500 + assert manager.aggregated_stats["sparse_blocks"] == 250 + assert manager.aggregated_stats["phase_counts"]["prefill"] == 5 + + def test_collect_different_phases(self): + """Test phase counting.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect prefill stats + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + + # Collect decode stats + manager.collect({"phase": "decode", "total_blocks": 10, "sparse_blocks": 5}) + + assert manager.aggregated_stats["phase_counts"]["prefill"] == 2 + assert manager.aggregated_stats["phase_counts"]["decode"] == 1 + assert manager.aggregated_stats["phase_counts"]["unknown"] == 0 + + +class TestCalibrationMode: + """Test calibration mode functionality.""" + + def test_calibration_mode_per_sample_collection(self): + """Test that calibration mode stores per-sample stats.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Enable calibration mode + manager.set_calibration_mode(enabled=True) + + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + + manager.collect(stats) + + # Should store in per_sample_stats + assert len(manager.per_sample_stats) == 1 + assert manager.per_sample_stats[0]["module"] == "test" + assert manager.per_sample_stats[0]["sparsity"] == 0.5 + assert manager.per_sample_stats[0]["sample_length"] == 1024 + assert manager.per_sample_stats[0]["phase"] == "prefill" + + def test_calibration_mode_off(self): + """Test that per-sample stats are not collected when calibration mode is off.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + # Calibration mode is off by default + + stats = {"sparsity": 0.5, "phase": "prefill", "total_blocks": 100, "sparse_blocks": 50} + + manager.collect(stats) + + # Should NOT store in per_sample_stats + assert len(manager.per_sample_stats) == 0 + + # But should still aggregate + assert manager.aggregated_stats["total_calls"] == 1 + + def test_set_calibration_mode_with_reset(self): + """Test set_calibration_mode with reset_history=True.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect some stats in calibration mode + manager.set_calibration_mode(enabled=True) + manager.collect( + { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + ) + assert len(manager.per_sample_stats) == 1 + + # Re-enable with reset + manager.set_calibration_mode(enabled=True, reset_history=True) + assert len(manager.per_sample_stats) == 0 # Should be cleared + + def test_set_calibration_mode_without_reset(self): + """Test set_calibration_mode with reset_history=False.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect some stats + manager.set_calibration_mode(enabled=True) + manager.collect( + { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + ) + assert len(manager.per_sample_stats) == 1 + + # Disable without reset + manager.set_calibration_mode(enabled=False, reset_history=False) + assert len(manager.per_sample_stats) == 1 # Should be preserved + + +class TestGetSummary: + """Test get_summary() functionality.""" + + def test_get_summary_with_data(self): + """Test get_summary returns correct averages.""" + manager = SparseAttentionStatsManager(module_name="test_module", enabled=True) + + # Collect stats + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 30}) + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + + summary = manager.get_summary() + + assert summary["module"] == "test_module" + assert summary["total_calls"] == 2 + # Average sparsity: (30+50) / (100+100) = 80/200 = 0.4 + assert summary["average_sparsity"] == 0.4 + assert summary["phase_distribution"]["prefill"] == 2 + + def test_get_summary_no_data(self): + """Test get_summary with no collected data.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + summary = manager.get_summary() + + assert summary["module"] == "test" + assert summary["total_calls"] == 0 + assert summary["average_sparsity"] == 0.0 + assert summary["phase_distribution"]["prefill"] == 0 + + def test_get_summary_zero_blocks(self): + """Test get_summary when total_blocks is zero.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect stats with zero blocks + manager.collect({"phase": "prefill", "total_blocks": 0, "sparse_blocks": 0}) + + summary = manager.get_summary() + + assert summary["average_sparsity"] == 0.0 # Should handle division by zero + + +class TestGetCalibrationStats: + """Test get_calibration_stats() functionality.""" + + def test_get_calibration_stats(self): + """Test retrieving per-sample calibration stats.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + manager.set_calibration_mode(enabled=True) + + # Collect multiple samples + for i in range(3): + manager.collect( + { + "sparsity": 0.3 + i * 0.1, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 30, + "sample_length": 1024 + i * 512, + } + ) + + calib_stats = manager.get_calibration_stats() + + assert len(calib_stats) == 3 + assert calib_stats[0]["sparsity"] == 0.3 + assert calib_stats[1]["sparsity"] == 0.4 + assert calib_stats[2]["sparsity"] == 0.5 + + def test_get_calibration_stats_empty(self): + """Test get_calibration_stats when no calibration data.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + calib_stats = manager.get_calibration_stats() + + assert calib_stats == [] + + +class TestReset: + """Test reset functionality.""" + + def test_reset(self): + """Test reset() clears all statistics.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + manager.set_calibration_mode(enabled=True) + + # Collect some stats + manager.collect( + { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + ) + manager.collect( + { + "sparsity": 0.3, + "phase": "decode", + "total_blocks": 10, + "sparse_blocks": 3, + "sample_length": 128, + } + ) + + # Verify stats exist + assert manager.aggregated_stats["total_calls"] == 2 + assert len(manager.per_sample_stats) == 2 + + # Reset + manager.reset() + + # All stats should be cleared + assert manager.aggregated_stats["total_calls"] == 0 + assert manager.aggregated_stats["total_blocks"] == 0 + assert manager.aggregated_stats["sparse_blocks"] == 0 + assert manager.per_sample_stats == [] + assert manager.aggregated_stats["phase_counts"]["prefill"] == 0 + assert manager.aggregated_stats["phase_counts"]["decode"] == 0 diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py new file mode 100644 index 000000000..ac9f46a54 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for threshold calibration functionality.""" + +import pytest + +pytest.importorskip("transformers") + +from _test_utils.torch_sparsity.sparse_attention_common import SimpleAttentionModel + +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import FlashSkipSoftmax +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +class TestFlashSkipSoftmaxThresholdInfo: + """Test FlashSkipSoftmax.get_threshold_info() method.""" + + def test_static_threshold(self): + """Test threshold info for static threshold.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": 0.001, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + info = method.get_threshold_info() + + assert info["type"] == "static" + assert info["value"] == 0.001 + + def test_phased_threshold(self): + """Test threshold info for phase-specific thresholds.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": {"prefill": 0.001, "decode": 0.0001}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + info = method.get_threshold_info() + + assert info["type"] == "static_phased" + assert "thresholds" in info + assert info["thresholds"]["prefill"] == 0.001 + assert info["thresholds"]["decode"] == 0.0001 + assert "current" in info + + def test_dynamic_calibrated_threshold(self): + """Test threshold info for calibrated dynamic threshold.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": 0.001, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Simulate calibration setting scale factor + method.threshold_scale_factor = 437.5 + + info = method.get_threshold_info() + + assert info["type"] == "dynamic" + assert info["scale_factor"] == 437.5 + assert info["formula"] == "λ / length" + assert "example_lengths" in info + assert abs(info["example_lengths"][1024] - 437.5 / 1024) < 1e-6 + assert abs(info["example_lengths"][2048] - 437.5 / 2048) < 1e-6 + + def test_threshold_info_structure(self): + """Test that threshold info has expected structure.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": 0.001, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + info = method.get_threshold_info() + + # Should always have 'type' key + assert "type" in info + assert isinstance(info, dict) + + +class TestSparseAttentionModuleThresholdInfo: + """Test SparseAttentionModule.get_threshold_info() delegation.""" + + def test_module_delegates_to_method(self): + """Test that module correctly delegates to sparse method instance.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.005, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Find sparse attention module + sparse_module = None + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + sparse_module = module + break + + assert sparse_module is not None + + # Test get_threshold_info + info = sparse_module.get_threshold_info() + + assert info["type"] == "static" + assert info["value"] == 0.005 + + def test_module_with_calibrated_threshold(self): + """Test module reports calibrated threshold correctly.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Find module and set calibrated threshold + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + module._sparse_method_instance.threshold_scale_factor = 500.0 + break + + # Get threshold info + info = module.get_threshold_info() + + assert info["type"] == "dynamic" + assert info["scale_factor"] == 500.0 + + def test_module_without_method_instance(self): + """Test get_threshold_info when sparse method instance doesn't exist.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Find module + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + # Remove sparse method instance to test fallback + delattr(module, "_sparse_method_instance") + + info = module.get_threshold_info() + + assert info["type"] == "none" + assert info["value"] is None + break + + +class TestPrintSparseAttentionSummaryIntegration: + """Test integration with print_sparse_attention_summary.""" + + def test_summary_displays_static_threshold(self, capsys): + """Test that print function displays static thresholds.""" + from modelopt.torch.sparsity.attention_sparsity.conversion import ( + print_sparse_attention_summary, + ) + + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + print_sparse_attention_summary(sparse_model) + + captured = capsys.readouterr() + assert "Static (1.00e-03)" in captured.out + assert "flash_skip_softmax" in captured.out + + def test_summary_displays_dynamic_threshold(self, capsys): + """Test that print function displays dynamic thresholds.""" + from modelopt.torch.sparsity.attention_sparsity.conversion import ( + print_sparse_attention_summary, + ) + + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Set calibrated threshold + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + module._sparse_method_instance.threshold_scale_factor = 437.5 + + print_sparse_attention_summary(sparse_model) + + captured = capsys.readouterr() + assert "Dynamic (λ=437.500000)" in captured.out + assert "flash_skip_softmax" in captured.out