From 8aefc7dfbd853160e38f5e1f40c00810e56bfae6 Mon Sep 17 00:00:00 2001
From: Kai Xu
Date: Tue, 30 Sep 2025 17:59:31 -0700
Subject: [PATCH 1/6] add initial support for sparse attention

Signed-off-by: Kai Xu
---
 examples/llm_sparse_attention/hf_spar_attn.py | 368 +++++++++++++++++
 .../sparsity/attention_sparsity/__init__.py   |  24 ++
 .../calibration/__init__.py                   |  26 ++
 .../sparsity/attention_sparsity/config.py     | 358 ++++++++++++++++
 .../sparsity/attention_sparsity/conversion.py | 387 ++++++++++++++++++
 .../attention_sparsity/methods/__init__.py    |  27 ++
 .../methods/flash_softmax_skip.py             | 289 +++++++++++++
 .../attention_sparsity/methods/registry.py    | 120 ++++++
 .../torch/sparsity/attention_sparsity/mode.py |  85 ++++
 .../attention_sparsity/model_sparsify.py      | 197 +++++++++
 .../attention_sparsity/nn/__init__.py         |  20 +
 .../attention_sparsity/nn/sparse_attention.py | 205 ++++++++++
 .../attention_sparsity/plugins/__init__.py    |  22 +
 .../attention_sparsity/plugins/huggingface.py | 122 ++++++
 14 files changed, 2250 insertions(+)
 create mode 100644 examples/llm_sparse_attention/hf_spar_attn.py
 create mode 100644 modelopt/torch/sparsity/attention_sparsity/__init__.py
 create mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py
 create mode 100644 modelopt/torch/sparsity/attention_sparsity/config.py
 create mode 100644 modelopt/torch/sparsity/attention_sparsity/conversion.py
 create mode 100644 modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
 create mode 100644 modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py
 create mode 100644 modelopt/torch/sparsity/attention_sparsity/methods/registry.py
 create mode 100644 modelopt/torch/sparsity/attention_sparsity/mode.py
 create mode 100644 modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
 create mode 100644 modelopt/torch/sparsity/attention_sparsity/nn/__init__.py
 create mode 100644 modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py
 create mode 100644 modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
 create mode 100644 modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py

diff --git a/examples/llm_sparse_attention/hf_spar_attn.py b/examples/llm_sparse_attention/hf_spar_attn.py
new file mode 100644
index 000000000..461af581e
--- /dev/null
+++ b/examples/llm_sparse_attention/hf_spar_attn.py
@@ -0,0 +1,368 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
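The example script below exercises the full pipeline (load, sparsify, optionally calibrate, verify, export). As a quick orientation, a minimal sketch of the intended API usage follows; the checkpoint name is a placeholder, and everything else refers to the modules and presets added in this patch:

# Minimal usage sketch: apply the default skip-softmax preset to an HF model.
# Eager attention is required because the pytorch backend patches F.softmax,
# which the flash_attention_2/sdpa kernels never call.
import torch
from transformers import AutoModelForCausalLM

import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",  # placeholder checkpoint path
    attn_implementation="eager",  # required for softmax patching
    torch_dtype=torch.bfloat16,
).cuda()
model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT)  # no calibration in this preset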
+ +"""Example script for applying sparse attention to HuggingFace models.""" + +import argparse +import random +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.export import export_hf_checkpoint +from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig +from modelopt.torch.sparsity.attention_sparsity.config import ( + SKIP_SOFTMAX_CALIB, + SKIP_SOFTMAX_DEFAULT, +) +from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule +from modelopt.torch.utils.memory_monitor import launch_memory_monitor + +RAND_SEED = 1234 + +# You can define custom configurations or use the default +SPARSE_ATTN_CFG_CHOICES = { + "skip_softmax": SKIP_SOFTMAX_DEFAULT, + "skip_softmax_calib": SKIP_SOFTMAX_CALIB, +} + + +def print_sparsity_stats(model: nn.Module): + """Print sparsity statistics if available.""" + module_stats = [] + for name, module in model.named_modules(): + if hasattr(module, "get_stats"): + stats = module.get_stats() + if stats and "average_sparsity" in stats: + module_stats.append((name, stats["average_sparsity"])) + + if not module_stats: + print("No sparsity statistics available") + return + + # Check if all modules have the same sparsity + sparsities = [s for _, s in module_stats] + if len(set(sparsities)) == 1: + # All identical - show summary + print(f"Average sparsity across all {len(module_stats)} modules: {sparsities[0]:.2%}") + else: + # Different sparsities - show individual values + avg_sparsity = sum(sparsities) / len(sparsities) + print(f"Average sparsity: {avg_sparsity:.2%}") + print("Per-module breakdown:") + for name, sparsity in module_stats: + print(f" {name}: {sparsity:.2%} sparse") + + +def get_narrativeqa_samples(num_samples=3): + """Load samples from NarrativeQA dataset for testing. + + Args: + num_samples: Number of samples to generate + """ + # Load NarrativeQA dataset + dataset = load_dataset("narrativeqa", split="test", streaming=True) + + samples = [] + for i, item in enumerate(dataset): + if i >= num_samples: + break + + # Combine document context and question + context = item.get("document", {}).get("text", "") + question = item.get("question", {}).get("text", "") + + if context and question: + # Use the full context as-is + prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:" + samples.append(prompt) + + if not samples: + raise ValueError("Could not load NarrativeQA samples") + + print(f"Loaded {len(samples)} NarrativeQA samples") + return samples + + +def truncate_text(text: str, tokenizer, max_length: int): + """Truncate text from the middle to preserve beginning and end. 
+ + Args: + text: Input text to truncate + tokenizer: Tokenizer to use for encoding + max_length: Maximum number of tokens + + Returns: + Truncated text that fits within max_length tokens + """ + # First tokenize to see if truncation is needed + tokens = tokenizer.encode(text, add_special_tokens=True) + + if len(tokens) <= max_length: + return text + + # Need to truncate - preserve beginning and end + # Reserve some tokens for special tokens + available_tokens = max_length - 2 # Account for special tokens + + # Split tokens roughly in half for beginning and end + begin_tokens = available_tokens // 2 + end_tokens = available_tokens - begin_tokens + + # Decode beginning and end parts + begin_text = tokenizer.decode(tokens[:begin_tokens], skip_special_tokens=True) + end_text = tokenizer.decode(tokens[-end_tokens:], skip_special_tokens=True) + + # Combine with ellipsis marker + return begin_text + " [...] " + end_text + + +def verify_outputs(model, tokenizer, args): + """Compare outputs between baseline and sparse attention models.""" + # Update seq_len to match calibration max_seqlen if calibration was used + base_config = SPARSE_ATTN_CFG_CHOICES.get(args.sparse_attn, {}) + if "calibration" in base_config and "max_seqlen" in base_config["calibration"]: + calib_max_seqlen = base_config["calibration"]["max_seqlen"] + if args.seq_len != calib_max_seqlen: + print( + f"\nNote: Updating test seq_len from {args.seq_len} to {calib_max_seqlen} " + f"to match calibration config" + ) + args.seq_len = calib_max_seqlen + + # Load and prepare a single test prompt + print(f"\nLoading test sample (will be tokenized up to {args.seq_len} tokens)") + prompts = get_narrativeqa_samples(num_samples=1) + prompt = prompts[0] + + # Prepare inputs + truncated_prompt = truncate_text(prompt, tokenizer, args.seq_len) + display_prompt = ( + truncated_prompt[:150] + "..." 
if len(truncated_prompt) > 150 else truncated_prompt + ) + + inputs = tokenizer( + truncated_prompt, + return_tensors="pt", + max_length=args.seq_len, + truncation=True, + padding=False, + ) + if torch.cuda.is_available(): + inputs = {k: v.cuda() for k, v in inputs.items()} + + print("\n" + "=" * 60) + print("BASELINE vs SPARSE ATTENTION COMPARISON") + print("=" * 60) + print(f"\nTest prompt: {display_prompt}") + print(f"Input tokens: {inputs['input_ids'].shape[1]} (max: {args.seq_len})") + if "[...]" in truncated_prompt: + print("Note: Text was middle-truncated to fit token limit") + + # Helper function to generate text + def generate_text(model, inputs, args, tokenizer): + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=args.max_new_tokens, + do_sample=args.do_sample, + temperature=args.temperature if args.do_sample else 1.0, + pad_token_id=tokenizer.pad_token_id, + ) + input_length = inputs["input_ids"].shape[1] + generated_ids = outputs[0][input_length:] + return tokenizer.decode(generated_ids, skip_special_tokens=True) + + # Find all sparse attention modules + sparse_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] + + # Generate baseline by temporarily disabling sparse attention + print("\n" + "-" * 60) + print("Generating baseline (sparse attention disabled)...") + for module in sparse_modules: + module.disable() + baseline_text = generate_text(model, inputs, args, tokenizer) + + # Generate with sparse attention enabled + print("\nGenerating with sparse attention (calibrated thresholds)...") + for module in sparse_modules: + module.enable() + sparse_text = generate_text(model, inputs, args, tokenizer) + + # Display comparison + print("\n" + "-" * 60) + print("RESULTS:") + baseline_display = baseline_text[:300] + "..." if len(baseline_text) > 300 else baseline_text + sparse_display = sparse_text[:300] + "..." 
if len(sparse_text) > 300 else sparse_text + + print(f"\nBaseline: {baseline_display}") + print(f"With Sparse: {sparse_display}") + + if baseline_text == sparse_text: + print("\nOutputs are identical") + else: + print("\nOutputs differ") + + +def sparsify_model(model, args): + """Apply sparse attention to the model with optional calibration.""" + print(f"\nApplying sparse attention: {args.sparse_attn} with backend: {args.backend}") + base_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn] + + # Create modified config with selected backend + modified_sparse_cfg = {} + for pattern, cfg in base_config["sparse_cfg"].items(): + modified_cfg = cfg.copy() + modified_cfg["backend"] = args.backend + modified_sparse_cfg[pattern] = modified_cfg + + # Create new config with modified settings + sparse_config = SparseAttentionConfig( + method=base_config["method"], + sparse_cfg=modified_sparse_cfg, + collect_stats=True, # Enable stats collection for monitoring + ) + + # Sparsify with optional calibration - framework handles calibration automatically + model = mtsa.sparsify(model, config=sparse_config) + + print("Sparse attention applied successfully!") + + # Show sparsity statistics + print("\n" + "=" * 60) + print("Sparsity Statistics") + print("=" * 60) + print_sparsity_stats(model) + + return model + + +def main(args): + """Main function to run the selected mode.""" + if not torch.cuda.is_available(): + raise OSError("GPU is required for inference.") + + random.seed(RAND_SEED) + np.random.seed(RAND_SEED) + launch_memory_monitor() + + print(f"Loading model: {args.pyt_ckpt_path}") + + # Load model and tokenizer + # Note: attn_implementation="eager" is required for calibration to work properly + # (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection) + model = AutoModelForCausalLM.from_pretrained( + args.pyt_ckpt_path, + attn_implementation="eager", + torch_dtype=torch.bfloat16, + ) + tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path) + + # Set pad token if not set + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Move model to GPU if available + if torch.cuda.is_available(): + model = model.cuda() + print("Model moved to CUDA") + + # Apply sparse attention to the model (with calibration if configured) + model = sparsify_model(model, args) + + # Verify outputs if requested (compares baseline vs calibrated sparse model) + if args.verify_output: + verify_outputs(model, tokenizer, args) + + # Export if requested + if args.export_dir: + print(f"\nExporting model to: {args.export_dir}") + export_dir = Path(args.export_dir) + export_dir.mkdir(parents=True, exist_ok=True) + + with torch.inference_mode(): + export_hf_checkpoint(model, export_dir=export_dir) + + tokenizer.save_pretrained(export_dir) + print(f"Model exported successfully to: {export_dir}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + + # Model arguments + parser.add_argument( + "--pyt_ckpt_path", + type=str, + required=True, + help="Specify where the PyTorch checkpoint path is", + ) + parser.add_argument( + "--sparse_attn", + type=str, + default="skip_softmax", + choices=list(SPARSE_ATTN_CFG_CHOICES.keys()), + help="Sparse attention configuration to apply.", + ) + parser.add_argument( + "--backend", + type=str, + default="pytorch", + choices=["pytorch", "triton"], + help="Backend to use for sparse attention computation (default: pytorch)", + ) + + # Sequence length arguments + parser.add_argument( + "--seq_len", + 
type=int, + default=2048, + help="Maximum sequence length for input prompts (will be truncated if longer)", + ) + parser.add_argument( + "--num_samples", + type=int, + default=3, + help="Number of samples to use from NarrativeQA dataset", + ) + + # Generation arguments + parser.add_argument( + "--max_new_tokens", type=int, default=50, help="Maximum new tokens to generate" + ) + parser.add_argument("--do_sample", action="store_true", help="Use sampling for generation") + parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling") + + # Operation arguments + parser.add_argument( + "--verify_output", + action="store_true", + help="Verify that sparse attention outputs match baseline", + ) + parser.add_argument( + "--export_dir", + type=str, + default=None, + help="Directory to export the model with sparse attention applied", + ) + + args = parser.parse_args() + main(args) diff --git a/modelopt/torch/sparsity/attention_sparsity/__init__.py b/modelopt/torch/sparsity/attention_sparsity/__init__.py new file mode 100644 index 000000000..150f93a3a --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/__init__.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Extensible sparse attention optimization for transformer models.""" + +# Initialize mode +from . import mode + +# Add methods to namespace +from .config import * +from .conversion import * +from .model_sparsify import * diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py new file mode 100644 index 000000000..3b616e8e3 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
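For reference, a user-defined configuration follows the same pattern-based shape as the presets used above. A sketch with illustrative values is shown below; the field names mirror SparseAttentionAttributeConfig defined in config.py further down:

# Hand-written config in the same shape as SKIP_SOFTMAX_DEFAULT. Patterns are
# applied in order after "default", so the later "*layer_0*" entry keeps the
# first layer's attention dense even though it also matches "*attn*".
CUSTOM_SKIP_SOFTMAX_CFG = {
    "method": "flash_softmax_skip",
    "sparse_cfg": {
        "*attn*": {
            "threshold": {"prefill": 1e-3, "decode": 1e-5},  # phase-specific thresholds
            "br": 128,  # Flash Attention block rows
            "bc": 128,  # Flash Attention block columns
            "backend": "pytorch",  # only supported backend
            "enable": True,
        },
        "*layer_0*": {"enable": False},
        "default": {"enable": False},
    },
}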
+ +"""Calibration framework for sparse attention methods.""" + +from .calibrate import calibrate_sparse_attention +from .calibrator import DynamicThresholdCalibrator +from .dataset import RulerDatasetBuilder + +__all__ = [ + "DynamicThresholdCalibrator", + "RulerDatasetBuilder", + "calibrate_sparse_attention", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py new file mode 100644 index 000000000..5fdab0032 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -0,0 +1,358 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration classes for sparse attention optimization.""" + +from collections.abc import Callable +from typing import Any + +from pydantic import Field, field_validator + +from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField + +# Type definitions for sparse configuration +SparseAttributeConfig = dict[str, Any] # Configuration for a specific pattern + +SparseAttentionCfgType = dict[ + str | Callable, # Pattern or callable for matching modules + SparseAttributeConfig, # Configuration dict with threshold, enable, etc. +] + + +class SparseAttentionAttributeConfig(ModeloptBaseConfig): + """Sparse attention attribute configuration for pattern-based module config.""" + + enable: bool = ModeloptField( + default=True, + title="Enable sparse attention.", + description="If True, enables sparse attention. If False, bypasses sparsity.", + ) + + method: str = ModeloptField( + default="flash_softmax_skip", + title="Sparse attention method.", + description="The sparse attention method to use (e.g., 'flash_softmax_skip').", + ) + + threshold: float | dict[str, float] = ModeloptField( + default=1e-3, + title="Sparsity threshold.", + description=( + "Threshold for determining which attention values to skip. " + "Can be a float or dict with phase-specific values." + ), + ) + + br: int = ModeloptField( + default=128, + title="Block row size.", + description="Block row size for block-wise sparsity in Flash Attention.", + ) + + bc: int = ModeloptField( + default=128, + title="Block column size.", + description="Block column size for block-wise sparsity in Flash Attention.", + ) + + collect_stats: bool = ModeloptField( + default=False, + title="Collect statistics.", + description="Whether to collect sparsity statistics during forward pass.", + ) + + backend: str = ModeloptField( + default="pytorch", + title="Backend implementation.", + description=( + "Backend to use for sparse attention computation. " + "Only 'pytorch' is supported, which uses softmax patching with F.softmax. " + "Requires model to be loaded with attn_implementation='eager'." + ), + ) + + is_causal: bool = ModeloptField( + default=True, + title="Causal attention flag.", + description=( + "Whether the model uses causal (autoregressive) attention. 
" + "If True, sparsity statistics are calculated over the lower triangle only. " + "Defaults to True for decoder-only models like GPT, LLaMA, etc." + ), + ) + + calibration: dict | None = ModeloptField( + default=None, + title="Calibration configuration", + description=( + "Calibration settings for this pattern. " + "If provided, enables automatic threshold calibration. " + "Only one pattern should have calibration enabled." + ), + ) + + @field_validator("method") + @classmethod + def validate_method(cls, v): + """Validate method is a string.""" + if not isinstance(v, str): + raise ValueError("method must be a string") + return v + + @field_validator("backend") + @classmethod + def validate_backend(cls, v): + """Validate backend is pytorch.""" + if v != "pytorch": + raise ValueError( + f"Invalid backend: {v}. Only 'pytorch' backend is supported. " + f"Model must be loaded with attn_implementation='eager'." + ) + return v + + @field_validator("br", "bc") + @classmethod + def validate_block_size(cls, v): + """Validate block sizes are positive integers.""" + if v <= 0: + raise ValueError(f"Block size must be positive, got {v}") + return v + + @field_validator("threshold") + @classmethod + def validate_threshold(cls, v): + """Validate threshold is in valid range (0, 1) or dict with valid phases.""" + if isinstance(v, dict): + # Validate phase keys + valid_phases = {"prefill", "decode", "default"} + invalid_keys = set(v.keys()) - valid_phases + if invalid_keys: + raise ValueError( + f"Invalid threshold phases: {invalid_keys}. Valid phases: {valid_phases}" + ) + # Validate all values are in range (0, 1) + for phase, threshold in v.items(): + if not isinstance(threshold, (int, float)) or threshold <= 0 or threshold >= 1: + raise ValueError( + f"Threshold for phase '{phase}' must be in range (0, 1), got {threshold}" + ) + elif isinstance(v, (int, float)): + if v <= 0 or v >= 1: + raise ValueError(f"Threshold must be in range (0, 1), got {v}") + else: + raise ValueError(f"Threshold must be a number in range (0, 1) or dict, got {type(v)}") + return v + + +class CalibrationConfig(ModeloptBaseConfig): + """Configuration for automatic threshold calibration using RULER dataset. + + Calibration learns a dynamic threshold λ = scale_factor / sequence_length that + achieves target sparsity. Only supports prefill phase (seq_len > 1). + """ + + target_sparse_ratio: float = ModeloptField( + default=0.5, + title="Target sparsity ratio", + description="Target ratio of sparse attention blocks (0.0 to 1.0).", + ) + + samples: int = ModeloptField( + default=24, + title="Calibration samples", + description="Total number of RULER samples for calibration (distributed across length bins).", + ) + + max_seqlen: int = ModeloptField( + default=32768, + title="Maximum sequence length", + description="Maximum sequence length for calibration (length bins auto-generated as powers of 2).", + ) + + num_length_bins: int = ModeloptField( + default=4, + title="Number of length bins", + description="Number of length bins to generate (hidden parameter, default: 4).", + ) + + threshold_trials: list[float] | None = ModeloptField( + default=None, + title="Threshold trials", + description=( + "List of threshold values to test during calibration. 
" + "If None, uses default: [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1]" + ), + ) + + @field_validator("threshold_trials") + @classmethod + def validate_threshold_trials(cls, v): + """Validate threshold_trials are in valid range.""" + if v is not None: + if not isinstance(v, list): + raise ValueError(f"threshold_trials must be a list, got {type(v)}") + if len(v) == 0: + raise ValueError("threshold_trials must not be empty") + for threshold in v: + if not isinstance(threshold, (int, float)): + raise ValueError(f"All threshold_trials must be numbers, got {type(threshold)}") + if threshold <= 0 or threshold >= 1: + raise ValueError( + f"All threshold_trials must be in range (0, 1), got {threshold}" + ) + return v + + @field_validator("target_sparse_ratio") + @classmethod + def validate_target_sparse_ratio(cls, v): + """Validate target sparsity ratio is between 0 and 1.""" + if not 0.0 <= v <= 1.0: + raise ValueError(f"target_sparse_ratio must be between 0.0 and 1.0, got {v}") + return v + + @field_validator("samples") + @classmethod + def validate_samples(cls, v): + """Validate samples is positive.""" + if v <= 0: + raise ValueError(f"samples must be positive, got {v}") + return v + + @field_validator("max_seqlen") + @classmethod + def validate_max_seqlen(cls, v): + """Validate max_seqlen is at least 1024.""" + if v < 1024: + raise ValueError(f"max_seqlen must be >= 1024, got {v}") + return v + + @field_validator("num_length_bins") + @classmethod + def validate_num_length_bins(cls, v): + """Validate num_length_bins is positive.""" + if v <= 0: + raise ValueError(f"num_length_bins must be positive, got {v}") + return v + + +# Pre-defined Sparse Attention Configuration +# Default configuration with block-wise sparsity optimized for Flash Attention +SKIP_SOFTMAX_DEFAULT = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": { + "prefill": 1e-3, # More aggressive during prefill + "decode": 1e-4, # Conservative during decode + }, + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", # Only pytorch backend supported + "enable": True, + }, + "default": {"enable": False}, + }, +} + + +# Configuration with RULER calibration +# Note: threshold field is omitted - calibration determines dynamic threshold λ = a / length +# The calibrated threshold adapts to sequence length for optimal sparsity +SKIP_SOFTMAX_CALIB = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "br": 128, + "bc": 128, + "backend": "pytorch", # Only pytorch backend supported + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 120, + "max_seqlen": 8192, + }, + }, + "default": {"enable": False}, + }, +} + + +class SparseAttentionConfig(ModeloptBaseConfig): + """Base configuration for sparse attention optimization. + + This base configuration provides the common structure for all sparse + attention methods and supports pattern-based layer configuration. 
+ """ + + # Method selection + method: str = Field("flash_softmax_skip", description="Sparse attention method to use") + + # Statistics collection + collect_stats: bool = Field( + False, description="Whether to collect sparsity statistics during forward pass" + ) + + # Pattern-based sparse configuration (similar to quant_cfg in quantization) + sparse_cfg: SparseAttentionCfgType = ModeloptField( + default={"*attention*": {"enable": True}, "default": {"enable": False}}, + title="Sparse attention configuration", + description="Pattern-based configuration for sparse attention. Keys are patterns to match module names, " + "values are configuration dicts with parameters like 'threshold', 'enable', and 'calibration'.", + validate_default=True, + ) + + # Export configuration + export_format: str | None = Field( + None, description="Export format for sparse attention (e.g., 'onnx', 'tensorrt')" + ) + + +class FlashSoftmaxSkipConfig(SparseAttentionConfig): + """Configuration for Flash Attention-aware softmax skip sparse attention.""" + + # Override method to default to flash_softmax_skip + method: str = Field( + "flash_softmax_skip", description="Sparse attention method (fixed to flash_softmax_skip)" + ) + + # Override sparse_cfg with flash_softmax_skip specific defaults + sparse_cfg: SparseAttentionCfgType = ModeloptField( + default={ + "*attention*": { + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", # Only pytorch backend supported + "enable": True, + }, + "default": {"enable": False}, + }, + title="Flash softmax skip sparse configuration", + description="Pattern-based configuration with flash_softmax_skip specific defaults. " + "Includes FA block sizes (br, bc) and correction factor settings.", + validate_default=True, + ) + + +__all__ = [ + "SKIP_SOFTMAX_CALIB", + "SKIP_SOFTMAX_DEFAULT", + "CalibrationConfig", + "FlashSoftmaxSkipConfig", + "SparseAttentionAttributeConfig", + "SparseAttentionCfgType", + "SparseAttentionConfig", + "SparseAttributeConfig", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py new file mode 100644 index 000000000..028e2bb67 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -0,0 +1,387 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Conversion and restoration utilities for sparse attention.""" + +import fnmatch +from collections.abc import Callable +from typing import Any + +import torch.nn as nn + +from modelopt.torch.opt.conversion import ModelLikeModule, ModeloptStateManager +from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict +from modelopt.torch.utils import get_unwrapped_name + +from .config import SparseAttentionConfig +from .nn.sparse_attention import SparseAttentionModule, SparseAttentionRegistry +from .plugins.huggingface import register_sparse_attention_on_the_fly + + +def is_attn_sparsified(model: nn.Module) -> bool: + """Check if a model has sparse attention applied. + + Similar to quantization's is_quantized for API consistency. + + Args: + model: Model to check + + Returns: + True if model contains any SparseAttentionModule instances + """ + return any(isinstance(module, SparseAttentionModule) for module in model.modules()) + + +def convert_to_sparse_attention_model( + model: ModelLikeModule, config: SparseAttentionConfig +) -> ConvertReturnType: + """Convert model to use sparse attention. + + Args: + model: Model to convert + config: Sparse attention configuration + + Returns: + Tuple of (converted_model, metadata) + """ + # Initialize the true module if necessary + model = model.init_modellike() if isinstance(model, ModelLikeModule) else model + + # Register sparse attention modules dynamically + register_sparse_attention_on_the_fly(model) + + # Replace attention modules with sparse versions + replace_sparse_attention_modules(model, version=ModeloptStateManager(model).state_version) + + # Apply configuration to sparse attention modules + sparse_cfg = config.sparse_cfg if hasattr(config, "sparse_cfg") else {} + set_sparse_attention_by_cfg(model, sparse_cfg, config) + + # Create metadata + metadata = {} + update_sparse_attention_metadata(model, config, metadata) + + return model, metadata + + +def replace_sparse_attention_modules(model: nn.Module, version=None): + """Replace regular attention modules with sparse attention modules. + + Recursively replace all attention modules in the model with their sparse attention counterparts. + + Args: + model: Model to process + version: State version for tracking (optional) + """ + # Recursively replace modules + _replace_sparse_attention_modules(model, version=version) + + # Count and report replaced modules + replaced_count = sum(isinstance(m, SparseAttentionModule) for _, m in model.named_modules()) + if replaced_count > 0: + print(f"Inserted {replaced_count} sparse attention modules") + + +def _replace_sparse_attention_modules(model: nn.Module, version=None): + """Helper function for replace_sparse_attention_modules.""" + for name, child in model.named_children(): + if type(child) in SparseAttentionRegistry: + # REPLACE on the parent (model), not on child + sparse_module = SparseAttentionRegistry.convert(child) + setattr(model, name, sparse_module) + + # Now recurse into whichever module is now at `model.name` + _replace_sparse_attention_modules(getattr(model, name), version=version) + + +def set_sparse_attention_by_cfg(model: nn.Module, sparse_cfg: dict, config: SparseAttentionConfig): + """Apply sparse attention configuration to model. + + Similar to quantization's set_quantizer_by_cfg. 
+ + Args: + model: Model with sparse attention modules + sparse_cfg: Sparse configuration dictionary + config: Global sparse attention configuration + """ + sparse_cfg = sparse_cfg.copy() + + # Apply default first if exists + if "default" in sparse_cfg: + set_sparse_attention_attribute(model, "*", sparse_cfg["default"], config) + sparse_cfg.pop("default") + + # Apply pattern-specific configs + for pattern, cfg in sparse_cfg.items(): + set_sparse_attention_attribute(model, pattern, cfg, config) + + +def set_sparse_attention_attribute( + model: nn.Module, + wildcard_or_filter: str | Callable, + attribute_cfg: dict[str, Any], + global_config: SparseAttentionConfig, +): + """Set sparse attention attributes for modules matching pattern. + + Similar to quantization's set_quantizer_attribute. + + Args: + model: Model to configure + wildcard_or_filter: Pattern to match module names + attribute_cfg: Attributes to apply + global_config: Global sparse attention configuration + """ + # Merge global config fields with pattern config + # Filter out model-level configs that shouldn't be passed to modules + module_cfg = {k: v for k, v in attribute_cfg.items() if k != "calibration"} + + full_cfg = { + "method": global_config.method, + "collect_stats": global_config.collect_stats, + **module_cfg, + } + + for name, module in model.named_modules(): + if not isinstance(module, SparseAttentionModule): + continue + + # Check pattern match + matched = False + if isinstance(wildcard_or_filter, str): + matched = fnmatch.fnmatch(name, wildcard_or_filter) + elif callable(wildcard_or_filter): + matched = wildcard_or_filter(name) + else: + continue + + if matched: + # Apply config using the same method as TensorQuantizer + module.set_from_attribute_config(full_cfg) + + +def restore_sparse_attention_model( + model: ModelLikeModule, config: SparseAttentionConfig, metadata: MetadataDict +) -> nn.Module: + """Restore sparse attention model from saved state. + + Args: + model: Model to restore + config: Sparse attention configuration + metadata: Saved metadata + + Returns: + Restored model + """ + # Convert to sparse attention model + model, _ = convert_to_sparse_attention_model(model, config) + + # Restore sparse attention state from metadata + if "sparse_attention_state" in metadata: + restore_sparse_attention_state(model, metadata["sparse_attention_state"]) + + return model + + +def restore_sparse_attention_state(model: nn.Module, state_dict: dict[str, Any]): + """Restore sparse attention state from state dict. + + Args: + model: Model with sparse attention modules + state_dict: Saved state dictionary + """ + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule): + module_name = get_unwrapped_name(name, model) + if module_name in state_dict: + module_state = state_dict[module_name] + + # Restore method and config + if "method" in module_state: + module._method = module_state["method"] + if "method_config" in module_state: + # Restore config attributes + for key, val in module_state["method_config"].items(): + setattr(module, f"_{key}", val) + + # Re-setup with restored config + module._setup() + + +def update_sparse_attention_metadata( + model: nn.Module, config: SparseAttentionConfig, metadata: MetadataDict +) -> None: + """Update metadata with sparse attention state. 
+ + Args: + model: Model with sparse attention + config: Configuration used + metadata: Metadata dict to update + """ + sparse_state = {} + + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule): + module_name = get_unwrapped_name(name, model) + + # Collect method config from module attributes + method_config = { + k[1:]: v + for k, v in module.__dict__.items() + if k.startswith("_") and k not in ("_method", "_enabled", "_sparse_method_instance") + } + + module_state = { + "method": module._sparse_method_instance.name, + "method_config": method_config, + } + + sparse_state[module_name] = module_state + + metadata["sparse_attention_state"] = sparse_state + metadata["sparse_attention_config"] = ( + config.model_dump() if hasattr(config, "model_dump") else vars(config) + ) + + +def disable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable): + """Disable sparse attention for matching modules. + + Similar to mtq.disable_quantizer for API consistency. + + Args: + model: Model with sparse attention applied + wildcard_or_filter_func: Wildcard string or filter function to match module names. + For example: "*lm_head*", "*layer_0*", etc. + + Example: + >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn + >>> model = sparse_attn.sparsify(model, config) + >>> # Disable sparse attention for lm_head + >>> sparse_attn.disable_sparse_attention(model, "*lm_head*") + """ + for name, module in model.named_modules(): + if not isinstance(module, SparseAttentionModule): + continue + + matched = False + if isinstance(wildcard_or_filter_func, str): + matched = fnmatch.fnmatch(name, wildcard_or_filter_func) + elif callable(wildcard_or_filter_func): + matched = wildcard_or_filter_func(name) + + if matched: + module.disable() + + +def enable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable): + """Enable sparse attention for matching modules. + + Similar to mtq.enable_quantizer for API consistency. + + Args: + model: Model with sparse attention applied + wildcard_or_filter_func: Wildcard string or filter function to match module names. + For example: "*attention*", "*attn*", etc. + + Example: + >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn + >>> model = sparse_attn.sparsify(model, config) + >>> # Re-enable sparse attention for all attention modules + >>> sparse_attn.enable_sparse_attention(model, "*attention*") + """ + for name, module in model.named_modules(): + if not isinstance(module, SparseAttentionModule): + continue + + matched = False + if isinstance(wildcard_or_filter_func, str): + matched = fnmatch.fnmatch(name, wildcard_or_filter_func) + elif callable(wildcard_or_filter_func): + matched = wildcard_or_filter_func(name) + + if matched: + module.enable() + + +def print_sparse_attention_summary(model: nn.Module): + """Print summary of sparse attention modules in the model. + + Similar to mtq.print_quant_summary for API consistency. 
+ + Args: + model: Model with sparse attention applied + + Prints: + - Total sparse attention modules + - Enabled vs disabled count + - Method distribution + - Configuration summary by module + + Example: + >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn + >>> model = sparse_attn.sparsify(model, config) + >>> sparse_attn.print_sparse_attention_summary(model) + """ + sparse_modules = [] + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule): + sparse_modules.append((name, module)) + + if not sparse_modules: + print("No sparse attention modules found in model") + return + + enabled_count = sum(1 for _, m in sparse_modules if m.is_enabled) + disabled_count = len(sparse_modules) - enabled_count + + # Count methods + method_counts = {} + for _, module in sparse_modules: + method = getattr(module, "_method", "unknown") + method_counts[method] = method_counts.get(method, 0) + 1 + + print(f"\n{'=' * 70}") + print(f"{'Sparse Attention Summary':^70}") + print(f"{'=' * 70}") + print(f"Total sparse attention modules: {len(sparse_modules)}") + print(f" Enabled: {enabled_count}") + print(f" Disabled: {disabled_count}") + + if method_counts: + print("\nMethods:") + for method, count in sorted(method_counts.items()): + print(f" {method}: {count}") + + print(f"\n{'Module Details':^70}") + print(f"{'-' * 70}") + + for name, module in sparse_modules: + status = "✓" if module.is_enabled else "✗" + method = getattr(module, "_method", "unknown") + threshold = getattr(module, "_threshold", "N/A") + + # Format threshold nicely + if isinstance(threshold, dict): + threshold_str = str(threshold) + elif isinstance(threshold, float): + threshold_str = f"{threshold:.2e}" + else: + threshold_str = str(threshold) + + print(f"{status} {name}") + print(f" Method: {method}, Threshold: {threshold_str}") + + print(f"{'=' * 70}\n") diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py new file mode 100644 index 000000000..5120bd755 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sparse attention methods package.""" + +from .registry import SparseAttentionMethod, get_sparse_method, register_sparse_method + +__all__ = [ + "SparseAttentionMethod", + "get_sparse_method", + "register_sparse_method", +] + +# Import method implementations to trigger registration +from . 
import flash_softmax_skip diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py new file mode 100644 index 000000000..04b696d11 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py @@ -0,0 +1,289 @@ +"""Flash Attention-aware softmax skip method for sparse attention. + +This module implements block-wise sparsity that aligns with Flash Attention's +processing pattern for optimal performance. +""" + +import math + +import numpy as np +import torch + +from . import SparseAttentionMethod, register_sparse_method + + +@register_sparse_method("flash_softmax_skip") +class FlashSoftmaxSkipMethod(SparseAttentionMethod): + """Flash Attention-aware softmax skip sparse attention method. + + Implements row-level block-wise sparsity aligned with Flash Attention's + processing pattern for optimal performance and accuracy. + """ + + def __init__(self, method_config: dict | None = None): + """Initialize Flash softmax skip method. + + Args: + method_config: Configuration dict with threshold, br, bc, is_causal, etc. + """ + config = method_config or {} + + # Extract configuration + self.threshold_config = config.get("threshold", 1e-4) + self.br = config.get("br", 128) + self.bc = config.get("bc", 128) + self.enable_correction_factor = config.get("enable_correction_factor", True) + self.collect_stats = config.get("collect_stats", True) + self.phase = config.get("phase", None) + self.backend = config.get("backend", "pytorch") + self.is_causal = config.get("is_causal", True) + # Calibration mode: when True, prevent threshold updates to preserve calibrator's test threshold + self._calibration_mode = False + + # Initialize threshold + if isinstance(self.threshold_config, dict): + self.threshold = self.threshold_config.get( + "default", self.threshold_config.get("prefill", 1e-4) + ) + else: + self.threshold = self.threshold_config + + def _update_threshold(self, phase: str): + """Update threshold based on phase.""" + if isinstance(self.threshold_config, dict): + self.threshold = self.threshold_config.get( + phase, self.threshold_config.get("default", self.threshold) + ) + + def set_calibration_mode(self, enabled: bool): + """Set calibration mode to prevent _update_threshold from modifying the threshold.""" + self._calibration_mode = enabled + + def _infer_phase(self, attention_scores: torch.Tensor) -> str: + """Infer phase from attention scores shape.""" + return "decode" if attention_scores.shape[2] == 1 else "prefill" + + def _reshape_to_blocks( + self, tensor: torch.Tensor, br: int, bc: int + ) -> tuple[torch.Tensor, ...]: + """Reshape tensor into blocks for Flash Attention processing. 
+ + Args: + tensor: Input tensor of shape [batch, heads, seq_q, seq_k] + br: Block row size + bc: Block column size + + Returns: + Tuple of (blocked_tensor, num_block_rows, num_block_cols, padded_seq_q, padded_seq_k) + """ + batch_size, num_heads, seq_q, seq_k = tensor.shape + + # Calculate padding needed + padded_seq_q = math.ceil(seq_q / br) * br + padded_seq_k = math.ceil(seq_k / bc) * bc + + # Pad tensor if necessary + if padded_seq_q != seq_q or padded_seq_k != seq_k: + pad_q = padded_seq_q - seq_q + pad_k = padded_seq_k - seq_k + # Use dtype min instead of -inf for numerical stability + pad_value = torch.finfo(tensor.dtype).min + tensor = torch.nn.functional.pad(tensor, (0, pad_k, 0, pad_q), value=pad_value) + + # Reshape to blocks + num_block_rows = padded_seq_q // br + num_block_cols = padded_seq_k // bc + + # Keep natural order for row-level processing: [batch, heads, block_rows, br, block_cols, bc] + blocked = tensor.view(batch_size, num_heads, num_block_rows, br, num_block_cols, bc) + + return blocked, num_block_rows, num_block_cols, padded_seq_q, padded_seq_k + + def calc_correction_factor_and_p( + self, attn_weights: torch.Tensor, phase: str + ) -> tuple[torch.Tensor, dict]: + """Calculate sparse mask and statistics for Flash Attention. + + Implements block-wise sparsity compatible with Flash Attention's online softmax: + 1. Reshape attention scores into 128x128 blocks + 2. Track block-wise maximum values (simulating Flash Attention's row processing) + 3. Compute cumulative maximum across blocks (for online normalization) + 4. Apply threshold: mask blocks where p = score - cummax < log(threshold) + 5. Calculate correction factor and sparsity statistics + + Args: + attn_weights: Pre-softmax attention scores [batch, heads, seq_q, seq_k] + phase: "prefill" (seq_q > 1) or "decode" (seq_q = 1) + + Returns: + element_mask: Boolean mask [batch, heads, seq_q, seq_k] + stats: Dict with sparsity, correction_factor, total_blocks, etc. 
+ """ + batch_size, num_heads, seq_q, seq_k = attn_weights.shape + + # Calculate threshold + threshold_scale_factor = getattr(self, "threshold_scale_factor", None) + if threshold_scale_factor: + # Use calibrated dynamic threshold: λ = scale_factor / length + log_threshold = np.log(threshold_scale_factor / seq_k) + else: + # Use static threshold from config + log_threshold = np.log(self.threshold) + + if phase == "prefill": + blocked_attn, num_block_rows, num_block_cols, padded_seq_q, padded_seq_k = ( + self._reshape_to_blocks(attn_weights, self.br, self.bc) + ) + + # Step 1: Compute maximum value in each block + # For each 128x128 block, find max across the 128 columns + # blocked_attn: [batch, heads, block_rows, br=128, block_cols, bc=128] + # block_max: [batch, heads, block_rows, br=128, block_cols] + block_max = blocked_attn.max(dim=-1)[0] + + # Step 2: Track cumulative maximum across blocks (left to right) + # This simulates Flash Attention's online softmax normalization + # block_max_cummax: [batch, heads, block_rows, br=128, block_cols] + block_max_cummax = block_max.cummax(dim=-1)[0] + + # Step 3: Calculate correction factor (how often max changes) + # Used by Flash Attention to adjust running sum when max increases + block_max_larger = torch.ones_like(block_max) + block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1] + correction_factor = float(torch.sum(block_max_larger) / torch.numel(block_max_larger)) + + # Step 4: Normalize attention scores by cumulative max + # p represents log-space difference: log(score) - log(cummax) + p = blocked_attn - block_max_cummax[..., None] + + # Step 5: Apply threshold and create block-level mask + # Keep blocks where at least one element exceeds log(threshold) + p_larger_than_thresh = p > log_threshold + # Reduce over bc (128 cols), then br (128 rows) to get block-level decision + # Result: [batch, heads, block_rows, block_cols] + block_mask = p_larger_than_thresh.any(dim=-1).any(dim=-2) + + # Step 6: Expand block mask back to element level + # All 128x128 elements in a block share the same mask value + # [batch, heads, block_rows, block_cols] -> [batch, heads, block_rows, br=128, block_cols, bc=128] + element_mask = block_mask.unsqueeze(-2).unsqueeze(-1).expand_as(blocked_attn) + + # Step 7: Reshape to original attention shape and remove padding + element_mask = element_mask.reshape(batch_size, num_heads, padded_seq_q, padded_seq_k) + element_mask = element_mask[:, :, :seq_q, :seq_k] + + # Step 8: Calculate sparsity statistics + # Count kept blocks (averaged across batch and heads) + kept_blocks = block_mask.sum().item() / (batch_size * num_heads) + + # Total valid blocks (lower triangle only for causal attention) + # Note: Causal mask pre-applied by attention module, so block_mask naturally + # has zeros in upper triangle. We only count lower triangle for denominator. 
+ total_blocks = ( + num_block_rows * (num_block_rows + 1) // 2 # Causal: N(N+1)/2 + if self.is_causal + else num_block_rows * num_block_cols # Non-causal: N*N + ) + sparsity = 1 - (kept_blocks / total_blocks) + else: # decode + blocked_attn, _, num_block_cols, _, padded_seq_k = self._reshape_to_blocks( + attn_weights, 1, self.bc + ) + + # Decode: Single query row attends to all past key blocks + # blocked_attn: [batch, heads, 1, 1, num_block_cols, bc=128] + + # Step 1: Find maximum in each key block + # block_max: [batch, heads, 1, 1, num_block_cols] + block_max = blocked_attn.max(dim=-1)[0] + + # Step 2: Track cumulative maximum across key blocks (left to right) + # Simulates Flash Attention's online softmax normalization + block_max_cummax = block_max.cummax(dim=-1)[0] + + # Step 3: Calculate correction factor + # Tracks how often the maximum increases (needed for Flash Attention rescaling) + block_max_larger = torch.ones_like(block_max) + block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1] + correction_factor = float(torch.sum(block_max_larger) / torch.numel(block_max_larger)) + + # Step 4: Normalize scores by cumulative max + # p = log(score) - log(cummax) in log-space + p = blocked_attn - block_max_cummax[..., None] + + # Step 5: Apply threshold and create block mask + # Keep blocks where at least one element exceeds threshold + p_larger_than_thresh = p > log_threshold + block_mask = p_larger_than_thresh.any(dim=-1, keepdim=False) + + # Step 6: Expand to element level and remove padding + element_mask = block_mask[..., None].expand_as(blocked_attn) + element_mask = element_mask.reshape(batch_size, num_heads, 1, padded_seq_k) + element_mask = element_mask[:, :, :seq_q, :seq_k] + + # Step 7: Calculate statistics + kept_blocks = block_mask.sum().item() / (batch_size * num_heads) + total_blocks = num_block_cols + sparsity = 1 - (kept_blocks / total_blocks) + + # Create stats dictionary + stats = { + "correction_factor": correction_factor if self.enable_correction_factor else 1.0, + "sparsity": sparsity, + "phase": phase, + "total_blocks": total_blocks, + "sparse_blocks": int(sparsity * total_blocks), + "sample_length": seq_k, + } + + return element_mask, stats + + def apply_sparsity( + self, + query: torch.Tensor | None = None, + key: torch.Tensor | None = None, + value: torch.Tensor | None = None, + attention_scores: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: + """Apply Flash Attention-aware block-wise sparsity. 
+ + Args: + query: Query tensor (unused, for API compatibility) + key: Key tensor (unused, for API compatibility) + value: Value tensor (unused, for API compatibility) + attention_scores: Attention scores tensor with shape [batch, heads, seq_q, seq_k] + + Returns: + Tuple with potentially modified attention_scores + """ + # Attention scores must be provided for sparse attention + assert attention_scores is not None, "attention_scores must be provided for apply_sparsity" + + # Attention scores are always 4D: [batch, heads, seq_q, seq_k] + assert len(attention_scores.shape) == 4, ( + f"Expected 4D attention scores, got shape {attention_scores.shape}" + ) + + # Infer phase from tensor shape + phase = self._infer_phase(attention_scores) + + # Update threshold for the detected phase (skip during calibration) + if not self._calibration_mode: + self._update_threshold(phase) + + # Apply block-wise sparsity + sparse_mask, stats = self.calc_correction_factor_and_p(attention_scores, phase) + + # Store stats for module to collect (doesn't persist across calls) + self._last_stats = stats + + # Apply mask to create sparse scores + mask_value = torch.finfo(attention_scores.dtype).min + sparse_scores = attention_scores.masked_fill(~sparse_mask, mask_value) + + return query, key, value, sparse_scores + + @property + def name(self) -> str: + """Method identifier.""" + return "flash_softmax_skip" diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py new file mode 100644 index 000000000..081ad9e27 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Registry and base class for sparse attention methods.""" + +from abc import ABC, abstractmethod + +import torch + + +class SparseAttentionMethod(ABC): + """Base class for sparse attention methods.""" + + @abstractmethod + def apply_sparsity( + self, + query: torch.Tensor | None = None, + key: torch.Tensor | None = None, + value: torch.Tensor | None = None, + attention_scores: torch.Tensor | None = None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + """Apply sparsity to attention computation. + + Args: + query: Query tensor + key: Key tensor + value: Value tensor + attention_scores: Pre-computed attention scores + + Returns: + Tuple of (query, key, value, attention_scores) with sparsity applied + """ + + @property + @abstractmethod + def name(self) -> str: + """Method name identifier.""" + + +# Method Registry with versioning support +_SPARSE_ATTENTION_METHODS: dict[str, dict[str, type[SparseAttentionMethod]]] = {} + + +def register_sparse_method(name: str, version: str = "v1"): + """Decorator to register sparse attention methods with version support. 
+ + Args: + name: Method name to register + version: Version string (default: "v1") + + Example: + @register_sparse_method("my_method", version="v3") + class MyMethodV3(SparseAttentionMethod): + ... + """ + + def decorator(cls: type[SparseAttentionMethod]): + if name not in _SPARSE_ATTENTION_METHODS: + _SPARSE_ATTENTION_METHODS[name] = {} + + if version in _SPARSE_ATTENTION_METHODS[name]: + import warnings + + warnings.warn( + f"Overriding existing sparse attention method: {name}@{version}", + RuntimeWarning, + stacklevel=2, + ) + + _SPARSE_ATTENTION_METHODS[name][version] = cls + return cls + + return decorator + + +def get_sparse_method(name: str, version: str | None = None) -> type[SparseAttentionMethod]: + """Get sparse attention method by name and optional version. + + Args: + name: Method name to retrieve + version: Optional version string. If None, uses latest version. + + Returns: + Method class + + Raises: + ValueError: If method name or version is not registered + + Example: + >>> get_sparse_method("flash_softmax_skip") # Latest version + >>> get_sparse_method("flash_softmax_skip", "v1") # Specific version + """ + if name not in _SPARSE_ATTENTION_METHODS: + available = list(_SPARSE_ATTENTION_METHODS.keys()) + raise ValueError(f"Unknown sparse attention method: {name}. Available: {available}") + + method_versions = _SPARSE_ATTENTION_METHODS[name] + + if not version: + version = sorted(method_versions.keys())[-1] + + if version not in method_versions: + available_versions = list(method_versions.keys()) + raise ValueError( + f"Unknown version {version} for method {name}. Available: {available_versions}" + ) + + return method_versions[version] diff --git a/modelopt/torch/sparsity/attention_sparsity/mode.py b/modelopt/torch/sparsity/attention_sparsity/mode.py new file mode 100644 index 000000000..f389509a5 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/mode.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sparse attention mode descriptor for ModelOpt.""" + +from modelopt.torch.opt.config import ModeloptBaseConfig +from modelopt.torch.opt.mode import ( + ConvertEntrypoint, + ModeDescriptor, + RestoreEntrypoint, + UpdateEntrypoint, + _ModeRegistryCls, +) + +from .config import SparseAttentionConfig +from .conversion import ( + convert_to_sparse_attention_model, + restore_sparse_attention_model, + update_sparse_attention_metadata, +) + +# Create registry for sparse attention modes +SparseAttentionModeRegistry = _ModeRegistryCls("sparse_attention") + + +@SparseAttentionModeRegistry.register_mode +class SparseAttentionModeDescriptor(ModeDescriptor): + """Mode descriptor for sparse attention optimization. + + This mode enables various sparse attention methods to reduce + computational complexity and memory usage in transformer models. 
+ """ + + @property + def name(self) -> str: + """Returns the value (str representation) of the mode.""" + return "sparse_attention" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Specifies the config class for the mode.""" + return SparseAttentionConfig + + @property + def next_prohibited_modes(self) -> set[str] | None: + """Modes that should not be applied after this mode.""" + # Can work with quantization but not with weight sparsity + return {"sparsity"} + + @property + def export_mode(self) -> str | None: + """The mode that corresponds to the export mode of this mode.""" + return "export_sparse_attention" + + @property + def convert(self) -> ConvertEntrypoint: + """The mode's entrypoint for converting a model.""" + return convert_to_sparse_attention_model + + @property + def restore(self) -> RestoreEntrypoint: + """The mode's entrypoint for restoring a model.""" + return restore_sparse_attention_model + + @property + def update_for_save(self) -> UpdateEntrypoint: + """The mode's entrypoint for updating the model's state before saving.""" + return update_sparse_attention_metadata + + @property + def update_for_new_mode(self) -> UpdateEntrypoint: + """The mode's entrypoint for updating the model's state before new mode.""" + return update_sparse_attention_metadata diff --git a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py new file mode 100644 index 000000000..908f3ad89 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py @@ -0,0 +1,197 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Main API functions for sparse attention optimization.""" + +from typing import Any + +import torch + +from modelopt.torch.opt.conversion import ModeloptStateManager, apply_mode +from modelopt.torch.opt.searcher import ForwardLoop + +from .calibration import calibrate_sparse_attention +from .config import SparseAttentionConfig +from .mode import SparseAttentionModeRegistry + +__all__ = [ + "calibrate", + "sparsify", +] + + +def sparsify( + model: torch.nn.Module, + config: dict[str, Any] | SparseAttentionConfig, + forward_loop: ForwardLoop | None = None, +) -> torch.nn.Module: + """Applies sparse attention optimization to the model in-place. + + This method performs replacement of attention modules with their sparse counterparts and + optionally performs calibration as specified by ``config``. + ``forward_loop`` is used to forward data through the model and gather statistics for calibration. + + Args: + model: A pytorch model + config: A dictionary or an instance of + :class:`SparseAttentionConfig ` + specifying the values for keys ``"sparse_cfg"``, ``"method"``, and optionally ``"calibration"``. + + The ``"sparse_cfg"`` key specifies the sparse attention configurations. 
+ The ``"method"`` key specifies the sparse attention method (e.g., "softmax_skip"). + The ``"calibration"`` key specifies calibration settings if automatic threshold tuning is desired. + + Sparse attention configurations is a dictionary mapping wildcards or filter functions + to its sparse attention attributes. The wildcards or filter functions are matched + against the module names. The sparse attention attributes include ``"threshold"``, + ``"enable"``, and method-specific parameters. + + An example ``config`` dictionary is given below: + + .. code-block::python + + config = { + "method": "softmax_skip", + "sparse_cfg": { + # Phase-aware thresholds with backend selection and calibration + "*attention*": { + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "backend": "pytorch", # Only pytorch backend supported + "enable": True, + "calibration": { # Optional: enables automatic threshold calibration + "target_sparse_ratio": 0.5, + "samples": 48, + "max_seqlen": 8192, + }, + }, + # Disable for specific layers + "*layer_0*": {"enable": False}, + # Default settings + "default": {"enable": False}, + }, + } + + The ``"backend"`` parameter must be set to ``"pytorch"``: + + - ``"pytorch"``: Softmax patching approach (only supported backend) + + This requires the model to be loaded with ``attn_implementation="eager"``. + + forward_loop: A callable that forwards all calibration data through the model. This is used + to gather statistics for calibration. It should take model as the argument. It does not need + to return anything. + + This argument is only required when calibration is enabled in the config. + + Here are a few examples for correct ``forward_loop`` definitions: + + Example 1: + + .. code-block:: + + def forward_loop(model) -> None: + # iterate over the data loader and forward data through the model + for batch in data_loader: + model(batch) + + Example 2: + + .. code-block:: + + def forward_loop(model) -> float: + # evaluate the model on the task + return evaluate(model, task, ....) + + .. note:: + + Calibration does not require forwarding the entire dataset through the model. + Please subsample the dataset or reduce the number of batches if needed. + + .. important:: + + The model must always be loaded with ``attn_implementation="eager"`` + for sparse attention to work correctly: + + .. code-block:: python + + from transformers import AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained( + model_path, + attn_implementation="eager", # Required for sparse attention + torch_dtype=torch.bfloat16, + ) + + This is because sparse attention works by patching torch.nn.functional.softmax, + which is only called in the eager attention implementation. + + Returns: + A pytorch model which has sparse attention applied and optionally calibrated. + """ + model = apply_mode( + model, mode=[("sparse_attention", config)], registry=SparseAttentionModeRegistry + ) + + # Calibrate the sparsity ratio of the attention modules + return calibrate(model, forward_loop=forward_loop) + + +def calibrate( + model: torch.nn.Module, + forward_loop: ForwardLoop | None = None, +) -> torch.nn.Module: + """Calibrates sparse attention thresholds based on target sparsity. + + This function performs calibration to find optimal thresholds that achieve + the target sparsity ratio specified in the sparse attention configuration. + + Args: + model: A pytorch model with sparse attention already applied + forward_loop: Optional callable that forwards calibration data through the model. 
+ It should take model as the argument and can optionally return metrics. + If None, will auto-generate RULER dataset for calibration. + + Returns: + The calibrated model with optimized sparse attention thresholds. + If no calibration is configured, returns the model unchanged. + """ + # Get the sparse attention config from the model's state + if not ModeloptStateManager.is_converted(model): + return model + + manager = ModeloptStateManager(model) + + sparse_attn_config = next( + (state["config"] for name, state in manager._state if name == "sparse_attention"), None + ) + + if sparse_attn_config is None: + return model + + # Check if calibration is configured in any sparse_cfg pattern + # Note: sparse_attn_config is always a dict (stored via config.model_dump()) + sparse_cfg = sparse_attn_config.get("sparse_cfg", {}) + + has_calibration = any( + isinstance(cfg, dict) and "calibration" in cfg for cfg in sparse_cfg.values() + ) + + if not has_calibration: + return model + + # Run calibration (handles stats collection internally) + calibrate_sparse_attention(model, sparse_attn_config, forward_loop=forward_loop) + + return model diff --git a/modelopt/torch/sparsity/attention_sparsity/nn/__init__.py b/modelopt/torch/sparsity/attention_sparsity/nn/__init__.py new file mode 100644 index 000000000..00ff275bc --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/nn/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Neural network modules for sparse attention.""" + +from .sparse_attention import SparseAttentionModule, SparseAttentionRegistry + +__all__ = ["SparseAttentionModule", "SparseAttentionRegistry"] diff --git a/modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py new file mode 100644 index 000000000..a45931224 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py @@ -0,0 +1,205 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
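To ground the sparsify API described above, here is a minimal usage sketch on a HuggingFace causal LM. The model id and threshold values are illustrative only; the config shape follows the docstring example, using the method name registered in this patch, and eager attention is required so the patched softmax is actually reached.

.. code-block:: python

    import torch
    from transformers import AutoModelForCausalLM

    import modelopt.torch.sparsity.attention_sparsity as mtsa

    # Eager attention is required: sparsity is applied by patching F.softmax,
    # which is only called in the eager attention implementation.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
        attn_implementation="eager",
        torch_dtype=torch.bfloat16,
    )

    config = {
        "method": "flash_softmax_skip",
        "sparse_cfg": {
            "*attn*": {
                "threshold": {"prefill": 1e-3, "decode": 1e-5},
                "backend": "pytorch",
                "enable": True,
            },
            "default": {"enable": False},
        },
    }

    model = mtsa.sparsify(model, config=config)

After conversion the model is used as usual; softmax is only patched inside the wrapped attention modules' forward.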
+ +"""Extensible sparse attention module.""" + +import torch +import torch.nn.functional as F + +from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls +from modelopt.torch.quantization.utils import replace_function + +from ..config import SparseAttentionAttributeConfig +from ..methods import get_sparse_method +from .stats_manager import SparseAttentionStatsManager + + +class SparseAttentionModule(DynamicModule): + """Generic sparse attention module wrapper for applying sparsity to attention layers. + + This module wraps existing attention implementations to add sparse attention + capabilities by patching torch.nn.functional.softmax. + + Forward Flow: + ------------- + 1. Check if sparse attention is enabled (pass-through if disabled) + 2. Create softmax patch context with sparse_softmax function + 3. Apply sparse attention by patching F.softmax: + - Patches torch.nn.functional.softmax with sparse_softmax + - sparse_softmax applies method's sparsity logic before softmax + 4. Forward through original attention with sparsity applied + + Requirements: + ------------- + - Model must be loaded with attn_implementation="eager" for proper softmax interception + - Only PyTorch backend is supported (patches F.softmax) + + Attributes: + ----------- + _enabled: bool + Whether sparse attention is enabled + _method: str + The sparse attention method to use (e.g., "flash_softmax_skip") + _method_config: dict + Configuration dictionary for the sparse method (threshold, br, bc, etc.) + _sparse_method_instance: SparseAttentionMethod + Instance of the configured sparse attention method + """ + + def set_from_attribute_config( + self, attribute_cfg: SparseAttentionAttributeConfig | dict | None = None + ): + """Set sparse attention attributes from configuration. + + Similar to TensorQuantizer.set_from_attribute_config. + + Args: + attribute_cfg: Sparse attention attribute configuration. + If None, uses default SparseAttentionAttributeConfig. 
+ """ + # Use default config if not provided + attribute_cfg = ( + attribute_cfg if attribute_cfg is not None else SparseAttentionAttributeConfig() + ) + + # Store raw config for method initialization + self._method_config = {} + + # Define which attributes are method-specific vs module-specific + # Module-specific attributes control the SparseAttentionModule behavior + _module_attributes = {"enable", "method"} + + # Custom setters for special module attributes + _custom_setters = { + "enable": ("_enabled", lambda val: bool(val)), + "method": ("_method", lambda val: str(val)), + } + + # Process each attribute from config + for attribute, val in attribute_cfg.items(): + # Validate attribute if using config class + if hasattr(SparseAttentionAttributeConfig, "model_fields"): + assert attribute in SparseAttentionAttributeConfig.model_fields, ( + f"{attribute} is not a valid SparseAttentionModule attribute" + ) + + if attribute in _module_attributes: + # Module-level attribute: store with underscore prefix + attr_name, setter = _custom_setters.get(attribute, (f"_{attribute}", lambda v: v)) + setattr(self, attr_name, setter(val)) + else: + # Method-specific attribute: store in config dict + self._method_config[attribute] = val + + # Initialize sparse method instance + self._init_sparse_method() + + def _init_sparse_method(self): + """Initialize the sparse method instance.""" + method_class = get_sparse_method(self._method) + + # Initialize the sparse method instance + # _method_config is always initialized in set_from_attribute_config + self._sparse_method_instance = method_class(method_config=self._method_config) # type: ignore[call-arg] + + def enable(self): + """Enable sparse attention for this module.""" + self._enabled = True + + def disable(self): + """Disable sparse attention for this module.""" + self._enabled = False + + @property + def is_enabled(self) -> bool: + """Check if sparse attention is enabled.""" + return getattr(self, "_enabled", True) + + def get_stats(self) -> dict: + """Get sparsity statistics from the stats manager. + + Returns: + Dictionary with sparsity statistics including 'average_sparsity' if available. + Returns empty dict if stats manager is not enabled. + """ + if self._stats_manager is not None and self._stats_manager.enabled: + return self._stats_manager.get_summary() + return {} + + def _setup(self): + """Setup called by DynamicModule.""" + # Apply default configuration if not yet configured + if not hasattr(self, "_method"): + self.set_from_attribute_config(None) + + # Create stats manager if stats collection is enabled + if self._method_config.get("collect_stats", False): + self._stats_manager = SparseAttentionStatsManager( + module_name="sparse_attention", enabled=True + ) + else: + self._stats_manager = None + + def forward(self, *args, **kwargs): + """Forward with selected sparse attention method. + + This method dispatches to the appropriate sparse attention implementation + based on the configured method and backend. 
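As a standalone illustration of the attribute split performed by set_from_attribute_config above (a sketch, not the module's actual code path): module-level keys become underscore-prefixed attributes on the wrapper, while everything else lands in the method config handed to the method class.

.. code-block:: python

    attribute_cfg = {
        "enable": True,
        "method": "flash_softmax_skip",
        "threshold": {"prefill": 1e-3, "decode": 1e-5},
        "br": 128,
        "bc": 128,
    }

    module_attributes = {"enable", "method"}
    module_state = {}   # mirrors self._enabled / self._method
    method_config = {}  # mirrors self._method_config

    for key, value in attribute_cfg.items():
        if key in module_attributes:
            module_state["_enabled" if key == "enable" else f"_{key}"] = value
        else:
            method_config[key] = value

    assert module_state == {"_enabled": True, "_method": "flash_softmax_skip"}
    assert set(method_config) == {"threshold", "br", "bc"}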
+ """ + # Pass through if sparse attention is disabled + if not self.is_enabled: + return super().forward(*args, **kwargs) + + # Get the appropriate context manager for this configuration + context = self._get_sparse_context() + + # Apply sparse attention through the context + with context: + result = super().forward(*args, **kwargs) + + # Collect stats if manager is available + if self._stats_manager is not None and hasattr(self._sparse_method_instance, "_last_stats"): + self._stats_manager.collect(self._sparse_method_instance._last_stats) + + return result + + def _get_sparse_context(self): + """Get the softmax patch context for applying sparse attention.""" + return self._create_softmax_patch_context() + + def _create_softmax_patch_context(self): + """Create context manager for patching softmax function.""" + return replace_function(torch.nn.functional, "softmax", self._create_sparse_softmax()) + + def _create_sparse_softmax(self): + """Create sparse softmax function for current method.""" + original_softmax = F.softmax + + def sparse_softmax(input, dim=-1, *args, **kwargs): + # Let the method handle the sparsification + _, _, _, sparse_input = self._sparse_method_instance.apply_sparsity( + None, None, None, input + ) + + # Use sparse input if modified, otherwise use original + if sparse_input is not None: + return original_softmax(sparse_input, dim, *args, **kwargs) + return original_softmax(input, dim, *args, **kwargs) + + return sparse_softmax + + +# Create registry for sparse attention modules +SparseAttentionRegistry = _DMRegistryCls("SparseAttention", SparseAttentionModule) diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py new file mode 100644 index 000000000..ba8c8b821 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Plugins for sparse attention integration with various frameworks.""" + +from .huggingface import register_sparse_attention_on_the_fly + +__all__ = [ + "register_sparse_attention_on_the_fly", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py new file mode 100644 index 000000000..0012257b6 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
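The softmax patching mechanism above reduces to a small, self-contained pattern. Below is a sketch using contextlib; replace_function from modelopt.torch.quantization.utils plays this role in the patch, and the thresholding helper here is illustrative rather than the FlashSoftmaxSkip implementation.

.. code-block:: python

    import contextlib

    import torch
    import torch.nn.functional as F


    @contextlib.contextmanager
    def patch_softmax(sparsify_scores):
        """Temporarily route F.softmax through a score-sparsifying wrapper."""
        original_softmax = F.softmax

        def sparse_softmax(input, dim=-1, *args, **kwargs):
            return original_softmax(sparsify_scores(input), dim, *args, **kwargs)

        torch.nn.functional.softmax = sparse_softmax
        try:
            yield
        finally:
            torch.nn.functional.softmax = original_softmax


    def drop_tiny_scores(scores, threshold=1e-3):
        # Mask entries whose softmax weight would fall below `threshold`,
        # i.e. scores more than log(threshold) below the row maximum.
        cutoff = scores.amax(dim=-1, keepdim=True) + torch.log(torch.tensor(threshold))
        return scores.masked_fill(scores < cutoff, torch.finfo(scores.dtype).min)


    scores = torch.randn(1, 2, 4, 4)  # [batch, heads, seq_q, seq_k]
    with patch_softmax(drop_tiny_scores):
        probs = F.softmax(scores, dim=-1)  # uses the sparsified scores inside the context

Because the original function is restored in the finally block, softmax behaves normally outside the context, which is what keeps the patch scoped to the wrapped attention forward.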
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dynamic sparse attention registration for HuggingFace models.""" + +import torch.nn as nn +import transformers + +from modelopt.torch.opt.dynamic import DynamicModule + +from ..nn.sparse_attention import SparseAttentionModule, SparseAttentionRegistry + + +class _GenericSparseAttention(SparseAttentionModule): + """Generic sparse attention that works with any HF attention module. + + This class provides a universal sparse attention wrapper that can + work with various transformer attention implementations. + """ + + def _setup(self): + """Setup sparse attention for any attention type. + + The base SparseAttentionModule handles detection and initialization. + """ + super()._setup() + + def get_attn_type(self, attn_module) -> type: + """Get the original attention type. + + Args: + attn_module: Attention module (possibly wrapped) + + Returns: + Original class type + """ + # If this is a DynamicModule, get the original class + if isinstance(attn_module, DynamicModule): + return attn_module.get_original_cls_by_level(level=0) + return type(attn_module) + + +def register_sparse_attention_on_the_fly(model: nn.Module) -> bool: + """Dynamically register sparse attention for any model. + + This function automatically detects attention modules in the model + and registers them for sparse attention optimization. + + Args: + model: Model to process + + Returns: + True if any modules were registered + """ + if not _is_supported_model(model): + return False + + registered_count = 0 + attention_types = set() + + for name, module in model.named_modules(): + # Skip if already a sparse attention module + if isinstance(module, SparseAttentionModule): + continue + + # Check if this is an attention module by name + module_type = type(module) + type_name = module_type.__name__ + + # Common attention module patterns + is_attention = ( + "attention" in type_name.lower() + or type_name.endswith("Attention") + or type_name.endswith("SelfAttention") + ) + + if is_attention and module_type not in SparseAttentionRegistry: + # Register attention type + if module_type not in attention_types: + SparseAttentionRegistry.register({module_type: type_name})(_GenericSparseAttention) + attention_types.add(module_type) + registered_count += 1 + print(f"Registered {type_name} for sparse attention optimization") + + if registered_count > 0: + print(f"Dynamically registered {registered_count} attention module types for sparsity") + + return registered_count > 0 + + +def _is_supported_model(model: nn.Module) -> bool: + """Check if model is supported for sparse attention. + + Supports HuggingFace PreTrainedModel and any PyTorch model with attention modules. 
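The name-based detection used by register_sparse_attention_on_the_fly above boils down to a small predicate; here is a self-contained sketch of the same heuristic, with the class names in the asserts given only as examples.

.. code-block:: python

    def looks_like_attention(type_name: str) -> bool:
        # Mirrors the heuristic used during on-the-fly registration.
        return (
            "attention" in type_name.lower()
            or type_name.endswith("Attention")
            or type_name.endswith("SelfAttention")
        )


    assert looks_like_attention("LlamaAttention")
    assert looks_like_attention("BertSelfAttention")
    assert not looks_like_attention("LlamaMLP")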
+ + Args: + model: Model to check + + Returns: + True if model is supported + """ + # Check for HuggingFace PreTrainedModel + try: + if isinstance(model, transformers.PreTrainedModel): + return True + except ImportError: + pass + + # Support any PyTorch model with attention modules + return isinstance(model, nn.Module) From 1910fc609fa827694ad0bc7cd9a74fb8dd878a32 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 5 Nov 2025 17:21:28 -0800 Subject: [PATCH 2/6] Add unit and GPU tests for core sparse attention functionality Signed-off-by: Kai Xu --- .../attention_sparsity/hf_sa.py} | 75 ++--- .../{ => weight_sparsity}/.gitignore | 0 .../{ => weight_sparsity}/README.md | 0 .../{ => weight_sparsity}/data_prep.py | 0 .../{ => weight_sparsity}/eval.py | 0 .../export_trtllm_ckpt.py | 0 .../{ => weight_sparsity}/finetune.py | 15 + .../{ => weight_sparsity}/hf_pts.py | 0 .../{ => weight_sparsity}/launch_finetune.sh | 0 .../{ => weight_sparsity}/requirements.txt | 0 .../{ => weight_sparsity}/utils.py | 0 .../calibration/__init__.py | 26 -- .../sparsity/attention_sparsity/config.py | 164 ++-------- .../sparsity/attention_sparsity/conversion.py | 57 +--- .../attention_sparsity/methods/__init__.py | 2 +- ..._softmax_skip.py => flash_skip_softmax.py} | 48 +-- .../attention_sparsity/methods/registry.py | 37 ++- .../attention_sparsity/model_sparsify.py | 82 +---- .../attention_sparsity/nn/__init__.py | 20 -- .../attention_sparsity/plugins/huggingface.py | 18 +- .../{nn => }/sparse_attention.py | 36 +-- .../torch_sparsity/sparse_attention_common.py | 195 ++++++++++++ .../test_attention_sparsity.py | 52 ++++ .../test_llama_sparsify.py | 8 +- .../test_attention_sparsity_gpu.py | 144 +++++++++ .../test_integration_gpu.py | 190 ++++++++++++ .../test_flash_skip_softmax.py | 282 ++++++++++++++++++ .../test_sparse_attention_config.py | 129 ++++++++ .../test_sparse_attention_conversion.py | 208 +++++++++++++ .../test_sparse_attention_mode.py | 43 +++ 30 files changed, 1405 insertions(+), 426 deletions(-) rename examples/{llm_sparse_attention/hf_spar_attn.py => llm_sparsity/attention_sparsity/hf_sa.py} (83%) rename examples/llm_sparsity/{ => weight_sparsity}/.gitignore (100%) rename examples/llm_sparsity/{ => weight_sparsity}/README.md (100%) rename examples/llm_sparsity/{ => weight_sparsity}/data_prep.py (100%) rename examples/llm_sparsity/{ => weight_sparsity}/eval.py (100%) rename examples/llm_sparsity/{ => weight_sparsity}/export_trtllm_ckpt.py (100%) rename examples/llm_sparsity/{ => weight_sparsity}/finetune.py (95%) rename examples/llm_sparsity/{ => weight_sparsity}/hf_pts.py (100%) rename examples/llm_sparsity/{ => weight_sparsity}/launch_finetune.sh (100%) rename examples/llm_sparsity/{ => weight_sparsity}/requirements.txt (100%) rename examples/llm_sparsity/{ => weight_sparsity}/utils.py (100%) delete mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py rename modelopt/torch/sparsity/attention_sparsity/methods/{flash_softmax_skip.py => flash_skip_softmax.py} (90%) delete mode 100644 modelopt/torch/sparsity/attention_sparsity/nn/__init__.py rename modelopt/torch/sparsity/attention_sparsity/{nn => }/sparse_attention.py (83%) create mode 100644 tests/_test_utils/torch_sparsity/sparse_attention_common.py create mode 100644 tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py rename tests/examples/llm_sparsity/{ => weight_sparsity}/test_llama_sparsify.py (93%) create mode 100644 tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py create mode 
100644 tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py diff --git a/examples/llm_sparse_attention/hf_spar_attn.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py similarity index 83% rename from examples/llm_sparse_attention/hf_spar_attn.py rename to examples/llm_sparsity/attention_sparsity/hf_sa.py index 461af581e..2f68cfa68 100644 --- a/examples/llm_sparse_attention/hf_spar_attn.py +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -22,64 +22,43 @@ import numpy as np import torch -import torch.nn as nn from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer +import modelopt.torch.opt as mto import modelopt.torch.sparsity.attention_sparsity as mtsa from modelopt.torch.export import export_hf_checkpoint from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig -from modelopt.torch.sparsity.attention_sparsity.config import ( - SKIP_SOFTMAX_CALIB, - SKIP_SOFTMAX_DEFAULT, -) -from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule from modelopt.torch.utils.memory_monitor import launch_memory_monitor RAND_SEED = 1234 +# Enable HuggingFace checkpointing support +mto.enable_huggingface_checkpointing() + # You can define custom configurations or use the default SPARSE_ATTN_CFG_CHOICES = { "skip_softmax": SKIP_SOFTMAX_DEFAULT, - "skip_softmax_calib": SKIP_SOFTMAX_CALIB, } -def print_sparsity_stats(model: nn.Module): - """Print sparsity statistics if available.""" - module_stats = [] - for name, module in model.named_modules(): - if hasattr(module, "get_stats"): - stats = module.get_stats() - if stats and "average_sparsity" in stats: - module_stats.append((name, stats["average_sparsity"])) - - if not module_stats: - print("No sparsity statistics available") - return - - # Check if all modules have the same sparsity - sparsities = [s for _, s in module_stats] - if len(set(sparsities)) == 1: - # All identical - show summary - print(f"Average sparsity across all {len(module_stats)} modules: {sparsities[0]:.2%}") - else: - # Different sparsities - show individual values - avg_sparsity = sum(sparsities) / len(sparsities) - print(f"Average sparsity: {avg_sparsity:.2%}") - print("Per-module breakdown:") - for name, sparsity in module_stats: - print(f" {name}: {sparsity:.2%} sparse") - - def get_narrativeqa_samples(num_samples=3): """Load samples from NarrativeQA dataset for testing. 
Args: num_samples: Number of samples to generate + + Raises: + RuntimeError: If dataset loading fails + ValueError: If no valid samples could be loaded """ - # Load NarrativeQA dataset - dataset = load_dataset("narrativeqa", split="test", streaming=True) + # Load NarrativeQA dataset with retry logic + try: + dataset = load_dataset("narrativeqa", split="test", streaming=True) + except Exception as e: + raise RuntimeError(f"Failed to load NarrativeQA dataset: {e}") samples = [] for i, item in enumerate(dataset): @@ -120,8 +99,10 @@ def truncate_text(text: str, tokenizer, max_length: int): return text # Need to truncate - preserve beginning and end - # Reserve some tokens for special tokens - available_tokens = max_length - 2 # Account for special tokens + # Calculate actual special tokens used + dummy_tokens = tokenizer.encode("", add_special_tokens=True) + special_token_count = len(dummy_tokens) + available_tokens = max_length - special_token_count # Split tokens roughly in half for beginning and end begin_tokens = available_tokens // 2 @@ -173,9 +154,7 @@ def verify_outputs(model, tokenizer, args): print("BASELINE vs SPARSE ATTENTION COMPARISON") print("=" * 60) print(f"\nTest prompt: {display_prompt}") - print(f"Input tokens: {inputs['input_ids'].shape[1]} (max: {args.seq_len})") - if "[...]" in truncated_prompt: - print("Note: Text was middle-truncated to fit token limit") + print(f"Input tokens: {inputs['input_ids'].shape[1]}") # Helper function to generate text def generate_text(model, inputs, args, tokenizer): @@ -235,23 +214,13 @@ def sparsify_model(model, args): modified_sparse_cfg[pattern] = modified_cfg # Create new config with modified settings - sparse_config = SparseAttentionConfig( - method=base_config["method"], - sparse_cfg=modified_sparse_cfg, - collect_stats=True, # Enable stats collection for monitoring - ) + sparse_config = SparseAttentionConfig(sparse_cfg=modified_sparse_cfg) - # Sparsify with optional calibration - framework handles calibration automatically + # Sparsify the model model = mtsa.sparsify(model, config=sparse_config) print("Sparse attention applied successfully!") - # Show sparsity statistics - print("\n" + "=" * 60) - print("Sparsity Statistics") - print("=" * 60) - print_sparsity_stats(model) - return model diff --git a/examples/llm_sparsity/.gitignore b/examples/llm_sparsity/weight_sparsity/.gitignore similarity index 100% rename from examples/llm_sparsity/.gitignore rename to examples/llm_sparsity/weight_sparsity/.gitignore diff --git a/examples/llm_sparsity/README.md b/examples/llm_sparsity/weight_sparsity/README.md similarity index 100% rename from examples/llm_sparsity/README.md rename to examples/llm_sparsity/weight_sparsity/README.md diff --git a/examples/llm_sparsity/data_prep.py b/examples/llm_sparsity/weight_sparsity/data_prep.py similarity index 100% rename from examples/llm_sparsity/data_prep.py rename to examples/llm_sparsity/weight_sparsity/data_prep.py diff --git a/examples/llm_sparsity/eval.py b/examples/llm_sparsity/weight_sparsity/eval.py similarity index 100% rename from examples/llm_sparsity/eval.py rename to examples/llm_sparsity/weight_sparsity/eval.py diff --git a/examples/llm_sparsity/export_trtllm_ckpt.py b/examples/llm_sparsity/weight_sparsity/export_trtllm_ckpt.py similarity index 100% rename from examples/llm_sparsity/export_trtllm_ckpt.py rename to examples/llm_sparsity/weight_sparsity/export_trtllm_ckpt.py diff --git a/examples/llm_sparsity/finetune.py b/examples/llm_sparsity/weight_sparsity/finetune.py similarity index 95% 
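For context on the truncation change above, this is a condensed sketch of token-level middle truncation that budgets for the special tokens the tokenizer actually adds instead of assuming two. The helper name and return type are illustrative; the example script itself works on decoded text rather than raw ids.

.. code-block:: python

    def middle_truncate_ids(tokenizer, text: str, max_length: int) -> list[int]:
        """Keep the head and tail of a long prompt, dropping the middle."""
        ids = tokenizer.encode(text, add_special_tokens=False)
        # Count the special tokens this tokenizer really adds instead of assuming 2.
        special = len(tokenizer.encode("", add_special_tokens=True))
        budget = max_length - special
        if len(ids) <= budget:
            return ids
        head = budget // 2
        tail = budget - head
        return ids[:head] + ids[len(ids) - tail:]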
rename from examples/llm_sparsity/finetune.py rename to examples/llm_sparsity/weight_sparsity/finetune.py index 3cfc1073f..869068dbd 100644 --- a/examples/llm_sparsity/finetune.py +++ b/examples/llm_sparsity/weight_sparsity/finetune.py @@ -1,3 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li diff --git a/examples/llm_sparsity/hf_pts.py b/examples/llm_sparsity/weight_sparsity/hf_pts.py similarity index 100% rename from examples/llm_sparsity/hf_pts.py rename to examples/llm_sparsity/weight_sparsity/hf_pts.py diff --git a/examples/llm_sparsity/launch_finetune.sh b/examples/llm_sparsity/weight_sparsity/launch_finetune.sh similarity index 100% rename from examples/llm_sparsity/launch_finetune.sh rename to examples/llm_sparsity/weight_sparsity/launch_finetune.sh diff --git a/examples/llm_sparsity/requirements.txt b/examples/llm_sparsity/weight_sparsity/requirements.txt similarity index 100% rename from examples/llm_sparsity/requirements.txt rename to examples/llm_sparsity/weight_sparsity/requirements.txt diff --git a/examples/llm_sparsity/utils.py b/examples/llm_sparsity/weight_sparsity/utils.py similarity index 100% rename from examples/llm_sparsity/utils.py rename to examples/llm_sparsity/weight_sparsity/utils.py diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py deleted file mode 100644 index 3b616e8e3..000000000 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Calibration framework for sparse attention methods.""" - -from .calibrate import calibrate_sparse_attention -from .calibrator import DynamicThresholdCalibrator -from .dataset import RulerDatasetBuilder - -__all__ = [ - "DynamicThresholdCalibrator", - "RulerDatasetBuilder", - "calibrate_sparse_attention", -] diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 5fdab0032..e72dacc94 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -34,18 +34,18 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): """Sparse attention attribute configuration for pattern-based module config.""" + method: str = ModeloptField( + default="flash_skip_softmax", + title="Sparse attention method.", + description="The sparse attention method to use (e.g., 'flash_skip_softmax').", + ) + enable: bool = ModeloptField( default=True, title="Enable sparse attention.", description="If True, enables sparse attention. If False, bypasses sparsity.", ) - method: str = ModeloptField( - default="flash_softmax_skip", - title="Sparse attention method.", - description="The sparse attention method to use (e.g., 'flash_softmax_skip').", - ) - threshold: float | dict[str, float] = ModeloptField( default=1e-3, title="Sparsity threshold.", @@ -67,12 +67,6 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): description="Block column size for block-wise sparsity in Flash Attention.", ) - collect_stats: bool = ModeloptField( - default=False, - title="Collect statistics.", - description="Whether to collect sparsity statistics during forward pass.", - ) - backend: str = ModeloptField( default="pytorch", title="Backend implementation.", @@ -156,103 +150,12 @@ def validate_threshold(cls, v): return v -class CalibrationConfig(ModeloptBaseConfig): - """Configuration for automatic threshold calibration using RULER dataset. - - Calibration learns a dynamic threshold λ = scale_factor / sequence_length that - achieves target sparsity. Only supports prefill phase (seq_len > 1). - """ - - target_sparse_ratio: float = ModeloptField( - default=0.5, - title="Target sparsity ratio", - description="Target ratio of sparse attention blocks (0.0 to 1.0).", - ) - - samples: int = ModeloptField( - default=24, - title="Calibration samples", - description="Total number of RULER samples for calibration (distributed across length bins).", - ) - - max_seqlen: int = ModeloptField( - default=32768, - title="Maximum sequence length", - description="Maximum sequence length for calibration (length bins auto-generated as powers of 2).", - ) - - num_length_bins: int = ModeloptField( - default=4, - title="Number of length bins", - description="Number of length bins to generate (hidden parameter, default: 4).", - ) - - threshold_trials: list[float] | None = ModeloptField( - default=None, - title="Threshold trials", - description=( - "List of threshold values to test during calibration. 
" - "If None, uses default: [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1]" - ), - ) - - @field_validator("threshold_trials") - @classmethod - def validate_threshold_trials(cls, v): - """Validate threshold_trials are in valid range.""" - if v is not None: - if not isinstance(v, list): - raise ValueError(f"threshold_trials must be a list, got {type(v)}") - if len(v) == 0: - raise ValueError("threshold_trials must not be empty") - for threshold in v: - if not isinstance(threshold, (int, float)): - raise ValueError(f"All threshold_trials must be numbers, got {type(threshold)}") - if threshold <= 0 or threshold >= 1: - raise ValueError( - f"All threshold_trials must be in range (0, 1), got {threshold}" - ) - return v - - @field_validator("target_sparse_ratio") - @classmethod - def validate_target_sparse_ratio(cls, v): - """Validate target sparsity ratio is between 0 and 1.""" - if not 0.0 <= v <= 1.0: - raise ValueError(f"target_sparse_ratio must be between 0.0 and 1.0, got {v}") - return v - - @field_validator("samples") - @classmethod - def validate_samples(cls, v): - """Validate samples is positive.""" - if v <= 0: - raise ValueError(f"samples must be positive, got {v}") - return v - - @field_validator("max_seqlen") - @classmethod - def validate_max_seqlen(cls, v): - """Validate max_seqlen is at least 1024.""" - if v < 1024: - raise ValueError(f"max_seqlen must be >= 1024, got {v}") - return v - - @field_validator("num_length_bins") - @classmethod - def validate_num_length_bins(cls, v): - """Validate num_length_bins is positive.""" - if v <= 0: - raise ValueError(f"num_length_bins must be positive, got {v}") - return v - - # Pre-defined Sparse Attention Configuration # Default configuration with block-wise sparsity optimized for Flash Attention SKIP_SOFTMAX_DEFAULT = { - "method": "flash_softmax_skip", "sparse_cfg": { "*attn*": { + "method": "flash_skip_softmax", "threshold": { "prefill": 1e-3, # More aggressive during prefill "decode": 1e-4, # Conservative during decode @@ -267,28 +170,6 @@ def validate_num_length_bins(cls, v): } -# Configuration with RULER calibration -# Note: threshold field is omitted - calibration determines dynamic threshold λ = a / length -# The calibrated threshold adapts to sequence length for optimal sparsity -SKIP_SOFTMAX_CALIB = { - "method": "flash_softmax_skip", - "sparse_cfg": { - "*attn*": { - "br": 128, - "bc": 128, - "backend": "pytorch", # Only pytorch backend supported - "enable": True, - "calibration": { - "target_sparse_ratio": 0.5, - "samples": 120, - "max_seqlen": 8192, - }, - }, - "default": {"enable": False}, - }, -} - - class SparseAttentionConfig(ModeloptBaseConfig): """Base configuration for sparse attention optimization. @@ -296,17 +177,12 @@ class SparseAttentionConfig(ModeloptBaseConfig): attention methods and supports pattern-based layer configuration. 
""" - # Method selection - method: str = Field("flash_softmax_skip", description="Sparse attention method to use") - - # Statistics collection - collect_stats: bool = Field( - False, description="Whether to collect sparsity statistics during forward pass" - ) - # Pattern-based sparse configuration (similar to quant_cfg in quantization) sparse_cfg: SparseAttentionCfgType = ModeloptField( - default={"*attention*": {"enable": True}, "default": {"enable": False}}, + default={ + "*attention*": {"method": "flash_skip_softmax", "enable": True}, + "default": {"enable": False}, + }, title="Sparse attention configuration", description="Pattern-based configuration for sparse attention. Keys are patterns to match module names, " "values are configuration dicts with parameters like 'threshold', 'enable', and 'calibration'.", @@ -319,19 +195,15 @@ class SparseAttentionConfig(ModeloptBaseConfig): ) -class FlashSoftmaxSkipConfig(SparseAttentionConfig): +class FlashSkipSoftmaxConfig(SparseAttentionConfig): """Configuration for Flash Attention-aware softmax skip sparse attention.""" - # Override method to default to flash_softmax_skip - method: str = Field( - "flash_softmax_skip", description="Sparse attention method (fixed to flash_softmax_skip)" - ) - - # Override sparse_cfg with flash_softmax_skip specific defaults + # Override sparse_cfg with flash_skip_softmax specific defaults sparse_cfg: SparseAttentionCfgType = ModeloptField( default={ "*attention*": { - "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, # Flash Attention block rows "bc": 128, # Flash Attention block columns "backend": "pytorch", # Only pytorch backend supported @@ -340,17 +212,15 @@ class FlashSoftmaxSkipConfig(SparseAttentionConfig): "default": {"enable": False}, }, title="Flash softmax skip sparse configuration", - description="Pattern-based configuration with flash_softmax_skip specific defaults. " + description="Pattern-based configuration with flash_skip_softmax specific defaults. 
" "Includes FA block sizes (br, bc) and correction factor settings.", validate_default=True, ) __all__ = [ - "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_DEFAULT", - "CalibrationConfig", - "FlashSoftmaxSkipConfig", + "FlashSkipSoftmaxConfig", "SparseAttentionAttributeConfig", "SparseAttentionCfgType", "SparseAttentionConfig", diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index 028e2bb67..25347c37f 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -26,8 +26,8 @@ from modelopt.torch.utils import get_unwrapped_name from .config import SparseAttentionConfig -from .nn.sparse_attention import SparseAttentionModule, SparseAttentionRegistry from .plugins.huggingface import register_sparse_attention_on_the_fly +from .sparse_attention import SparseAttentionModule, SparseAttentionRegistry def is_attn_sparsified(model: nn.Module) -> bool: @@ -67,7 +67,7 @@ def convert_to_sparse_attention_model( # Apply configuration to sparse attention modules sparse_cfg = config.sparse_cfg if hasattr(config, "sparse_cfg") else {} - set_sparse_attention_by_cfg(model, sparse_cfg, config) + set_sparse_attention_by_cfg(model, sparse_cfg) # Create metadata metadata = {} @@ -106,33 +106,31 @@ def _replace_sparse_attention_modules(model: nn.Module, version=None): _replace_sparse_attention_modules(getattr(model, name), version=version) -def set_sparse_attention_by_cfg(model: nn.Module, sparse_cfg: dict, config: SparseAttentionConfig): +def set_sparse_attention_by_cfg(model: nn.Module, sparse_cfg: dict): """Apply sparse attention configuration to model. Similar to quantization's set_quantizer_by_cfg. Args: model: Model with sparse attention modules - sparse_cfg: Sparse configuration dictionary - config: Global sparse attention configuration + sparse_cfg: Sparse configuration dictionary mapping patterns to attributes """ sparse_cfg = sparse_cfg.copy() # Apply default first if exists if "default" in sparse_cfg: - set_sparse_attention_attribute(model, "*", sparse_cfg["default"], config) + set_sparse_attention_attribute(model, "*", sparse_cfg["default"]) sparse_cfg.pop("default") # Apply pattern-specific configs for pattern, cfg in sparse_cfg.items(): - set_sparse_attention_attribute(model, pattern, cfg, config) + set_sparse_attention_attribute(model, pattern, cfg) def set_sparse_attention_attribute( model: nn.Module, wildcard_or_filter: str | Callable, attribute_cfg: dict[str, Any], - global_config: SparseAttentionConfig, ): """Set sparse attention attributes for modules matching pattern. 
@@ -141,19 +139,11 @@ def set_sparse_attention_attribute( Args: model: Model to configure wildcard_or_filter: Pattern to match module names - attribute_cfg: Attributes to apply - global_config: Global sparse attention configuration + attribute_cfg: Attributes to apply (must include 'method') """ - # Merge global config fields with pattern config # Filter out model-level configs that shouldn't be passed to modules module_cfg = {k: v for k, v in attribute_cfg.items() if k != "calibration"} - full_cfg = { - "method": global_config.method, - "collect_stats": global_config.collect_stats, - **module_cfg, - } - for name, module in model.named_modules(): if not isinstance(module, SparseAttentionModule): continue @@ -165,11 +155,11 @@ def set_sparse_attention_attribute( elif callable(wildcard_or_filter): matched = wildcard_or_filter(name) else: - continue + raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter)}") if matched: # Apply config using the same method as TensorQuantizer - module.set_from_attribute_config(full_cfg) + module.set_from_attribute_config(module_cfg) def restore_sparse_attention_model( @@ -236,16 +226,11 @@ def update_sparse_attention_metadata( if isinstance(module, SparseAttentionModule): module_name = get_unwrapped_name(name, model) - # Collect method config from module attributes - method_config = { - k[1:]: v - for k, v in module.__dict__.items() - if k.startswith("_") and k not in ("_method", "_enabled", "_sparse_method_instance") - } - + # Save the method configuration that was used + # _method_config already contains the validated config dict module_state = { "method": module._sparse_method_instance.name, - "method_config": method_config, + "method_config": module._method_config.copy(), } sparse_state[module_name] = module_state @@ -353,23 +338,16 @@ def print_sparse_attention_summary(model: nn.Module): method = getattr(module, "_method", "unknown") method_counts[method] = method_counts.get(method, 0) + 1 - print(f"\n{'=' * 70}") - print(f"{'Sparse Attention Summary':^70}") - print(f"{'=' * 70}") print(f"Total sparse attention modules: {len(sparse_modules)}") - print(f" Enabled: {enabled_count}") - print(f" Disabled: {disabled_count}") + print(f"Enabled: {enabled_count}") + print(f"Disabled: {disabled_count}") if method_counts: print("\nMethods:") for method, count in sorted(method_counts.items()): - print(f" {method}: {count}") - - print(f"\n{'Module Details':^70}") - print(f"{'-' * 70}") + print(f"{method}: {count}") for name, module in sparse_modules: - status = "✓" if module.is_enabled else "✗" method = getattr(module, "_method", "unknown") threshold = getattr(module, "_threshold", "N/A") @@ -381,7 +359,4 @@ def print_sparse_attention_summary(model: nn.Module): else: threshold_str = str(threshold) - print(f"{status} {name}") - print(f" Method: {method}, Threshold: {threshold_str}") - - print(f"{'=' * 70}\n") + print(f"{name}: Method: {method}, Threshold: {threshold_str}") diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py index 5120bd755..8a109fda7 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py @@ -24,4 +24,4 @@ ] # Import method implementations to trigger registration -from . import flash_softmax_skip +from . 
import flash_skip_softmax diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py similarity index 90% rename from modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py rename to modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index 04b696d11..8801bafb0 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -1,3 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Flash Attention-aware softmax skip method for sparse attention. This module implements block-wise sparsity that aligns with Flash Attention's @@ -12,8 +27,8 @@ from . import SparseAttentionMethod, register_sparse_method -@register_sparse_method("flash_softmax_skip") -class FlashSoftmaxSkipMethod(SparseAttentionMethod): +@register_sparse_method("flash_skip_softmax") +class FlashSkipSoftmax(SparseAttentionMethod): """Flash Attention-aware softmax skip sparse attention method. Implements row-level block-wise sparsity aligned with Flash Attention's @@ -25,20 +40,20 @@ def __init__(self, method_config: dict | None = None): Args: method_config: Configuration dict with threshold, br, bc, is_causal, etc. + All required fields should have defaults from SparseAttentionAttributeConfig. 
""" config = method_config or {} - # Extract configuration - self.threshold_config = config.get("threshold", 1e-4) - self.br = config.get("br", 128) - self.bc = config.get("bc", 128) + # Extract configuration (defaults handled by Pydantic) + self.threshold_config = config["threshold"] + self.br = config["br"] + self.bc = config["bc"] + self.backend = config["backend"] + self.is_causal = config["is_causal"] + + # Optional parameters not in Pydantic config self.enable_correction_factor = config.get("enable_correction_factor", True) - self.collect_stats = config.get("collect_stats", True) self.phase = config.get("phase", None) - self.backend = config.get("backend", "pytorch") - self.is_causal = config.get("is_causal", True) - # Calibration mode: when True, prevent threshold updates to preserve calibrator's test threshold - self._calibration_mode = False # Initialize threshold if isinstance(self.threshold_config, dict): @@ -55,10 +70,6 @@ def _update_threshold(self, phase: str): phase, self.threshold_config.get("default", self.threshold) ) - def set_calibration_mode(self, enabled: bool): - """Set calibration mode to prevent _update_threshold from modifying the threshold.""" - self._calibration_mode = enabled - def _infer_phase(self, attention_scores: torch.Tensor) -> str: """Infer phase from attention scores shape.""" return "decode" if attention_scores.shape[2] == 1 else "prefill" @@ -267,9 +278,8 @@ def apply_sparsity( # Infer phase from tensor shape phase = self._infer_phase(attention_scores) - # Update threshold for the detected phase (skip during calibration) - if not self._calibration_mode: - self._update_threshold(phase) + # Update threshold for the detected phase + self._update_threshold(phase) # Apply block-wise sparsity sparse_mask, stats = self.calc_correction_factor_and_p(attention_scores, phase) @@ -286,4 +296,4 @@ def apply_sparsity( @property def name(self) -> str: """Method identifier.""" - return "flash_softmax_skip" + return "flash_skip_softmax" diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index 081ad9e27..df7b5853b 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -15,6 +15,8 @@ """Registry and base class for sparse attention methods.""" +import re +import warnings from abc import ABC, abstractmethod import torch @@ -53,6 +55,27 @@ def name(self) -> str: _SPARSE_ATTENTION_METHODS: dict[str, dict[str, type[SparseAttentionMethod]]] = {} +def _version_key(version_str: str) -> list[int]: + """Extract numeric parts for proper version sorting. + + Args: + version_str: Version string (e.g., "v1", "v2", "v10") + + Returns: + List of integers extracted from version string for sorting + + Examples: + >>> _version_key("v1") + [1] + >>> _version_key("v10") + [10] + >>> _version_key("v2.3.1") + [2, 3, 1] + """ + parts = re.findall(r"\d+", version_str) + return [int(p) for p in parts] if parts else [0] + + def register_sparse_method(name: str, version: str = "v1"): """Decorator to register sparse attention methods with version support. @@ -60,10 +83,10 @@ def register_sparse_method(name: str, version: str = "v1"): name: Method name to register version: Version string (default: "v1") - Example: + Example:: + @register_sparse_method("my_method", version="v3") - class MyMethodV3(SparseAttentionMethod): - ... + class MyMethodV3(SparseAttentionMethod): ... 
""" def decorator(cls: type[SparseAttentionMethod]): @@ -71,8 +94,6 @@ def decorator(cls: type[SparseAttentionMethod]): _SPARSE_ATTENTION_METHODS[name] = {} if version in _SPARSE_ATTENTION_METHODS[name]: - import warnings - warnings.warn( f"Overriding existing sparse attention method: {name}@{version}", RuntimeWarning, @@ -99,8 +120,8 @@ def get_sparse_method(name: str, version: str | None = None) -> type[SparseAtten ValueError: If method name or version is not registered Example: - >>> get_sparse_method("flash_softmax_skip") # Latest version - >>> get_sparse_method("flash_softmax_skip", "v1") # Specific version + >>> get_sparse_method("flash_skip_softmax") # Latest version + >>> get_sparse_method("flash_skip_softmax", "v1") # Specific version """ if name not in _SPARSE_ATTENTION_METHODS: available = list(_SPARSE_ATTENTION_METHODS.keys()) @@ -109,7 +130,7 @@ def get_sparse_method(name: str, version: str | None = None) -> type[SparseAtten method_versions = _SPARSE_ATTENTION_METHODS[name] if not version: - version = sorted(method_versions.keys())[-1] + version = sorted(method_versions.keys(), key=_version_key)[-1] if version not in method_versions: available_versions = list(method_versions.keys()) diff --git a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py index 908f3ad89..88434e746 100644 --- a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py +++ b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py @@ -19,15 +19,13 @@ import torch -from modelopt.torch.opt.conversion import ModeloptStateManager, apply_mode +from modelopt.torch.opt.conversion import apply_mode from modelopt.torch.opt.searcher import ForwardLoop -from .calibration import calibrate_sparse_attention from .config import SparseAttentionConfig from .mode import SparseAttentionModeRegistry __all__ = [ - "calibrate", "sparsify", ] @@ -39,19 +37,16 @@ def sparsify( ) -> torch.nn.Module: """Applies sparse attention optimization to the model in-place. - This method performs replacement of attention modules with their sparse counterparts and - optionally performs calibration as specified by ``config``. - ``forward_loop`` is used to forward data through the model and gather statistics for calibration. + This method performs replacement of attention modules with their sparse counterparts. Args: model: A pytorch model config: A dictionary or an instance of :class:`SparseAttentionConfig ` - specifying the values for keys ``"sparse_cfg"``, ``"method"``, and optionally ``"calibration"``. + specifying the values for keys ``"sparse_cfg"`` and ``"method"``. The ``"sparse_cfg"`` key specifies the sparse attention configurations. - The ``"method"`` key specifies the sparse attention method (e.g., "softmax_skip"). - The ``"calibration"`` key specifies calibration settings if automatic threshold tuning is desired. + The ``"method"`` key specifies the sparse attention method (e.g., "flash_skip_softmax"). Sparse attention configurations is a dictionary mapping wildcards or filter functions to its sparse attention attributes. The wildcards or filter functions are matched @@ -63,22 +58,13 @@ def sparsify( .. 
code-block::python config = { - "method": "softmax_skip", + "method": "flash_skip_softmax", "sparse_cfg": { - # Phase-aware thresholds with backend selection and calibration "*attention*": { "threshold": {"prefill": 1e-3, "decode": 1e-5}, - "backend": "pytorch", # Only pytorch backend supported + "backend": "pytorch", "enable": True, - "calibration": { # Optional: enables automatic threshold calibration - "target_sparse_ratio": 0.5, - "samples": 48, - "max_seqlen": 8192, - }, }, - # Disable for specific layers - "*layer_0*": {"enable": False}, - # Default settings "default": {"enable": False}, }, } @@ -89,11 +75,7 @@ def sparsify( This requires the model to be loaded with ``attn_implementation="eager"``. - forward_loop: A callable that forwards all calibration data through the model. This is used - to gather statistics for calibration. It should take model as the argument. It does not need - to return anything. - - This argument is only required when calibration is enabled in the config. + forward_loop: Reserved for future use. Here are a few examples for correct ``forward_loop`` definitions: @@ -144,54 +126,4 @@ def forward_loop(model) -> float: model, mode=[("sparse_attention", config)], registry=SparseAttentionModeRegistry ) - # Calibrate the sparsity ratio of the attention modules - return calibrate(model, forward_loop=forward_loop) - - -def calibrate( - model: torch.nn.Module, - forward_loop: ForwardLoop | None = None, -) -> torch.nn.Module: - """Calibrates sparse attention thresholds based on target sparsity. - - This function performs calibration to find optimal thresholds that achieve - the target sparsity ratio specified in the sparse attention configuration. - - Args: - model: A pytorch model with sparse attention already applied - forward_loop: Optional callable that forwards calibration data through the model. - It should take model as the argument and can optionally return metrics. - If None, will auto-generate RULER dataset for calibration. - - Returns: - The calibrated model with optimized sparse attention thresholds. - If no calibration is configured, returns the model unchanged. - """ - # Get the sparse attention config from the model's state - if not ModeloptStateManager.is_converted(model): - return model - - manager = ModeloptStateManager(model) - - sparse_attn_config = next( - (state["config"] for name, state in manager._state if name == "sparse_attention"), None - ) - - if sparse_attn_config is None: - return model - - # Check if calibration is configured in any sparse_cfg pattern - # Note: sparse_attn_config is always a dict (stored via config.model_dump()) - sparse_cfg = sparse_attn_config.get("sparse_cfg", {}) - - has_calibration = any( - isinstance(cfg, dict) and "calibration" in cfg for cfg in sparse_cfg.values() - ) - - if not has_calibration: - return model - - # Run calibration (handles stats collection internally) - calibrate_sparse_attention(model, sparse_attn_config, forward_loop=forward_loop) - return model diff --git a/modelopt/torch/sparsity/attention_sparsity/nn/__init__.py b/modelopt/torch/sparsity/attention_sparsity/nn/__init__.py deleted file mode 100644 index 00ff275bc..000000000 --- a/modelopt/torch/sparsity/attention_sparsity/nn/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Neural network modules for sparse attention.""" - -from .sparse_attention import SparseAttentionModule, SparseAttentionRegistry - -__all__ = ["SparseAttentionModule", "SparseAttentionRegistry"] diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py index 0012257b6..b0cd1dff6 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -15,12 +15,16 @@ """Dynamic sparse attention registration for HuggingFace models.""" +import logging + import torch.nn as nn import transformers from modelopt.torch.opt.dynamic import DynamicModule -from ..nn.sparse_attention import SparseAttentionModule, SparseAttentionRegistry +from ..sparse_attention import SparseAttentionModule, SparseAttentionRegistry + +logger = logging.getLogger(__name__) class _GenericSparseAttention(SparseAttentionModule): @@ -80,10 +84,8 @@ def register_sparse_attention_on_the_fly(model: nn.Module) -> bool: type_name = module_type.__name__ # Common attention module patterns - is_attention = ( - "attention" in type_name.lower() - or type_name.endswith("Attention") - or type_name.endswith("SelfAttention") + is_attention = "attention" in type_name.lower() or type_name.endswith( + ("Attention", "SelfAttention") ) if is_attention and module_type not in SparseAttentionRegistry: @@ -92,10 +94,12 @@ def register_sparse_attention_on_the_fly(model: nn.Module) -> bool: SparseAttentionRegistry.register({module_type: type_name})(_GenericSparseAttention) attention_types.add(module_type) registered_count += 1 - print(f"Registered {type_name} for sparse attention optimization") + logger.info(f"Registered {type_name} for sparse attention optimization") if registered_count > 0: - print(f"Dynamically registered {registered_count} attention module types for sparsity") + logger.info( + f"Dynamically registered {registered_count} attention module types for sparsity" + ) return registered_count > 0 diff --git a/modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py similarity index 83% rename from modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py rename to modelopt/torch/sparsity/attention_sparsity/sparse_attention.py index a45931224..16b08bf19 100644 --- a/modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py +++ b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py @@ -21,9 +21,8 @@ from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls from modelopt.torch.quantization.utils import replace_function -from ..config import SparseAttentionAttributeConfig -from ..methods import get_sparse_method -from .stats_manager import SparseAttentionStatsManager +from .config import SparseAttentionAttributeConfig +from .methods import get_sparse_method class SparseAttentionModule(DynamicModule): @@ -51,7 +50,7 @@ class SparseAttentionModule(DynamicModule): _enabled: bool Whether sparse attention is enabled _method: str - The sparse attention method 
to use (e.g., "flash_softmax_skip") + The sparse attention method to use (e.g., "flash_skip_softmax") _method_config: dict Configuration dictionary for the sparse method (threshold, br, bc, etc.) _sparse_method_instance: SparseAttentionMethod @@ -67,12 +66,10 @@ def set_from_attribute_config( Args: attribute_cfg: Sparse attention attribute configuration. - If None, uses default SparseAttentionAttributeConfig. """ - # Use default config if not provided - attribute_cfg = ( - attribute_cfg if attribute_cfg is not None else SparseAttentionAttributeConfig() - ) + # Ensure config is validated through Pydantic + if not isinstance(attribute_cfg, SparseAttentionAttributeConfig): + attribute_cfg = SparseAttentionAttributeConfig(**(attribute_cfg or {})) # Store raw config for method initialization self._method_config = {} @@ -87,8 +84,8 @@ def set_from_attribute_config( "method": ("_method", lambda val: str(val)), } - # Process each attribute from config - for attribute, val in attribute_cfg.items(): + # Process each attribute from validated config + for attribute, val in attribute_cfg.model_dump().items(): # Validate attribute if using config class if hasattr(SparseAttentionAttributeConfig, "model_fields"): assert attribute in SparseAttentionAttributeConfig.model_fields, ( @@ -132,10 +129,9 @@ def get_stats(self) -> dict: Returns: Dictionary with sparsity statistics including 'average_sparsity' if available. - Returns empty dict if stats manager is not enabled. + Returns empty dict (statistics collection will be added in calibration PR). """ - if self._stats_manager is not None and self._stats_manager.enabled: - return self._stats_manager.get_summary() + # TODO: Statistics collection will be added in calibration PR return {} def _setup(self): @@ -144,14 +140,6 @@ def _setup(self): if not hasattr(self, "_method"): self.set_from_attribute_config(None) - # Create stats manager if stats collection is enabled - if self._method_config.get("collect_stats", False): - self._stats_manager = SparseAttentionStatsManager( - module_name="sparse_attention", enabled=True - ) - else: - self._stats_manager = None - def forward(self, *args, **kwargs): """Forward with selected sparse attention method. @@ -169,10 +157,6 @@ def forward(self, *args, **kwargs): with context: result = super().forward(*args, **kwargs) - # Collect stats if manager is available - if self._stats_manager is not None and hasattr(self._sparse_method_instance, "_last_stats"): - self._stats_manager.collect(self._sparse_method_instance._last_stats) - return result def _get_sparse_context(self): diff --git a/tests/_test_utils/torch_sparsity/sparse_attention_common.py b/tests/_test_utils/torch_sparsity/sparse_attention_common.py new file mode 100644 index 000000000..7724908b0 --- /dev/null +++ b/tests/_test_utils/torch_sparsity/sparse_attention_common.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
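Aside (editorial note, not part of the patch): the Pydantic normalization added to ``set_from_attribute_config`` above can be pictured with a minimal sketch. It only uses names that appear elsewhere in this series (``SparseAttentionAttributeConfig`` and its ``method``/``threshold``/``br``/``bc``/``enable`` fields); the literal values are illustrative::

    # A raw dict is coerced into the validated Pydantic model first; the module
    # then iterates the validated fields via model_dump(), not the raw dict.
    from modelopt.torch.sparsity.attention_sparsity.config import SparseAttentionAttributeConfig

    raw_cfg = {"method": "flash_skip_softmax", "threshold": 1e-4, "br": 128, "bc": 128, "enable": True}
    attribute_cfg = SparseAttentionAttributeConfig(**raw_cfg)

    for attribute, val in attribute_cfg.model_dump().items():
        print(attribute, val)  # validated attributes, including defaults for omitted fields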
+ +"""Common utilities for sparse attention testing.""" + +import torch +import torch.nn as nn + +import modelopt.torch.opt as mto +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +# Test models for sparse attention +class SimpleAttentionModel(nn.Module): + """Simple attention model for testing.""" + + def __init__(self, hidden_size=256, num_heads=8): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.attention = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True + ) + self.fc = nn.Linear(hidden_size, hidden_size) + + def forward(self, x): + attn_output, _ = self.attention(x, x, x, need_weights=False) + return self.fc(attn_output) + + @classmethod + def get_input(cls, hidden_size=256, seq_len=10, batch_size=2): + """Get input tensor for testing.""" + return torch.randn(batch_size, seq_len, hidden_size) + + +class SimpleTransformerEncoderLayer(nn.Module): + """Simple TransformerEncoderLayer wrapper for testing.""" + + def __init__(self, d_model=128, nhead=4, dim_feedforward=256): + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + batch_first=True, + ) + + def forward(self, x): + return self.layer(x) + + @classmethod + def get_input(cls, d_model=128, seq_len=20, batch_size=2): + """Get input tensor for testing.""" + return torch.randn(batch_size, seq_len, d_model) + + +class SimpleTransformerEncoder(nn.Module): + """Simple TransformerEncoder wrapper for testing.""" + + def __init__(self, d_model=128, nhead=4, num_layers=2): + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.encoder = nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True), + num_layers=num_layers, + ) + + def forward(self, x): + return self.encoder(x) + + @classmethod + def get_input(cls, d_model=128, seq_len=10, batch_size=2): + """Get input tensor for testing.""" + return torch.randn(batch_size, seq_len, d_model) + + +# Test configurations +FLASH_SKIP_SOFTMAX_DEFAULT_CFG = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 1e-4, + "br": 128, + "bc": 128, + "enable": True, + } + }, +} + +FLASH_SKIP_SOFTMAX_PHASE_AWARE_CFG = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "br": 128, + "bc": 128, + "enable": True, + } + }, +} + + +def get_test_configs(): + """Get test configurations for parameterized tests. + + Note: Calibration config excluded (requires GPU and real tokenizers). + """ + return [FLASH_SKIP_SOFTMAX_DEFAULT_CFG, FLASH_SKIP_SOFTMAX_PHASE_AWARE_CFG] + + +def sparsify_model_and_forward(model, config, calib_data): + """Apply sparse attention and run forward passes. 
+ + Args: + model: Model to sparsify + config: Sparse attention configuration + calib_data: List of calibration data tensors + + Returns: + Sparsified model + """ + + def forward_loop(model): + for batch in calib_data: + model(batch) + + # Apply sparse attention + model = sparse_attn.sparsify(model, config, forward_loop=forward_loop) + + # Verify sparse attention modules were inserted + assert any(isinstance(m, SparseAttentionModule) for m in model.modules()), ( + "No sparse attention modules found" + ) + + # Test forward passes + model.eval() + with torch.no_grad(): + for batch in calib_data: + output = model(batch) + assert not torch.isnan(output).any(), "NaN in output" + assert output is not None, "Output is None" + + return model + + +def save_restore_test(model_cls, device, sparse_config): + """Test save and restore of sparse attention state. + + Args: + model_cls: Model class to test + device: Device to run on ('cpu' or 'cuda') + sparse_config: Sparse attention configuration + """ + # Create and sparsify reference model + model_sparse = model_cls().to(device) + calib_data = [model_sparse.get_input().to(device) for _ in range(2)] + + sparsify_model_and_forward(model_sparse, sparse_config, calib_data) + + # Save state + state_dict = mto.modelopt_state(model_sparse) + + # Restore to new model + model_restored = model_cls().to(device) + mto.restore_from_modelopt_state(model_restored, state_dict) + model_restored.load_state_dict(model_sparse.state_dict()) + + # Verify outputs match + test_input = calib_data[0] + model_sparse.eval() + model_restored.eval() + + with torch.no_grad(): + output_sparse = model_sparse(test_input) + output_restored = model_restored(test_input) + + assert torch.allclose(output_sparse, output_restored, atol=1e-6), ( + "Restored model output doesn't match original" + ) diff --git a/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py new file mode 100644 index 000000000..b82303990 --- /dev/null +++ b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test attention sparsity example script.""" + +import pytest +from _test_utils.examples.run_command import extend_cmd_parts, run_example_command +from _test_utils.torch.misc import minimum_gpu + + +def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax", **kwargs): + """Run attention sparsity example script. 
+ + Args: + model: Path to model + method: Sparse attention method (corresponds to --sparse_attn arg) + **kwargs: Additional arguments to pass to the script + """ + kwargs.update( + { + "pyt_ckpt_path": model, + "sparse_attn": method, + } + ) + kwargs.setdefault("seq_len", 128) + kwargs.setdefault("num_samples", 1) + kwargs.setdefault("max_new_tokens", 16) + + cmd_parts = extend_cmd_parts(["python", "hf_sa.py"], **kwargs) + run_example_command(cmd_parts, "llm_sparsity/attention_sparsity") + + +@minimum_gpu(1) +@pytest.mark.parametrize("method", ["skip_softmax"]) +def test_attention_sparsity(tiny_llama_path, tmp_path, method): + """Test sparse attention with TinyLlama.""" + run_attention_sparsity_command( + model=tiny_llama_path, + method=method, + ) diff --git a/tests/examples/llm_sparsity/test_llama_sparsify.py b/tests/examples/llm_sparsity/weight_sparsity/test_llama_sparsify.py similarity index 93% rename from tests/examples/llm_sparsity/test_llama_sparsify.py rename to tests/examples/llm_sparsity/weight_sparsity/test_llama_sparsify.py index 7f9ef929b..7094b2989 100644 --- a/tests/examples/llm_sparsity/test_llama_sparsify.py +++ b/tests/examples/llm_sparsity/weight_sparsity/test_llama_sparsify.py @@ -31,7 +31,7 @@ def run_llm_sparsity_command( kwargs.setdefault("model_max_length", 1024) cmd_parts = extend_cmd_parts(["python", "hf_pts.py"], **kwargs) - run_example_command(cmd_parts, "llm_sparsity") + run_example_command(cmd_parts, "llm_sparsity/weight_sparsity") def run_llm_sparsity_ft_command( @@ -51,13 +51,15 @@ def run_llm_sparsity_ft_command( kwargs.setdefault("eval_bs", 1) cmd_parts = extend_cmd_parts(["bash", "launch_finetune.sh"], **kwargs) - run_example_command(cmd_parts, "llm_sparsity") + run_example_command(cmd_parts, "llm_sparsity/weight_sparsity") @pytest.fixture(scope="session") def data_path(tmp_path_factory): data_path = tmp_path_factory.mktemp("data") - run_example_command(["python", "data_prep.py", "--save_path", data_path], "llm_sparsity") + run_example_command( + ["python", "data_prep.py", "--save_path", data_path], "llm_sparsity/weight_sparsity" + ) # Copy eval data to train path for faster test run_example_command( diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py new file mode 100644 index 000000000..bad077fdb --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py @@ -0,0 +1,144 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""GPU tests for attention sparsity module.""" + +import pytest +import torch +from _test_utils.torch_sparsity.sparse_attention_common import ( + FLASH_SKIP_SOFTMAX_DEFAULT_CFG, + SimpleAttentionModel, + SimpleTransformerEncoder, + SimpleTransformerEncoderLayer, + get_test_configs, + save_restore_test, + sparsify_model_and_forward, +) + +import modelopt.torch.sparsity.attention_sparsity as sparse_attn + +# Skip all tests if GPU is not available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + + +class TestAttentionSparsityGPU: + """GPU tests for attention sparsity.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Setup for each test.""" + self.device = torch.device("cuda") + torch.cuda.empty_cache() + + @pytest.mark.parametrize( + "model_cls", + [SimpleAttentionModel, SimpleTransformerEncoderLayer, SimpleTransformerEncoder], + ) + @pytest.mark.parametrize("config", get_test_configs()) + def test_gpu_forward(self, model_cls, config): + """Test sparse attention forward pass on GPU.""" + model = model_cls().to(self.device) + calib_data = [model.get_input().to(self.device) for _ in range(2)] + + sparsify_model_and_forward(model, config, calib_data) + + # Additional GPU-specific checks + for batch in calib_data: + with torch.no_grad(): + output = model(batch) + assert output.device.type == "cuda" + + @pytest.mark.parametrize( + "model_cls", + [SimpleAttentionModel, SimpleTransformerEncoderLayer, SimpleTransformerEncoder], + ) + def test_save_restore(self, model_cls): + """Test save and restore on GPU.""" + save_restore_test(model_cls, "cuda", FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_different_dtypes(self, dtype): + """Test sparse attention with different dtypes.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).to(self.device).to(dtype) + calib_data = [model.get_input(d_model=256).to(self.device).to(dtype) for _ in range(2)] + + sparse_model = sparsify_model_and_forward(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG, calib_data) + + # Test forward + x = model.get_input(d_model=256).to(self.device).to(dtype) + with torch.no_grad(): + output = sparse_model(x) + + assert output.dtype == dtype + assert not torch.isnan(output).any() + if dtype != torch.bfloat16: # bfloat16 can have inf + assert not torch.isinf(output).any() + + def test_backward_pass(self): + """Test that gradients flow correctly through sparse attention.""" + model = SimpleAttentionModel(hidden_size=128, num_heads=4).to(self.device) + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Enable training mode + model.train() + + x = model.get_input(hidden_size=128, seq_len=32).to(self.device) + x.requires_grad = True + + # Forward + output = model(x) + loss = output.sum() + + # Backward + loss.backward() + + # Check gradients exist + assert x.grad is not None + assert not torch.isnan(x.grad).any() + + # Check model gradients + for name, param in model.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"No gradient for {name}" + + @pytest.mark.parametrize("seq_len", [1, 1024, 2048]) + def test_various_sequence_lengths(self, seq_len): + """Test sparse attention with various sequence lengths.""" + model = SimpleAttentionModel(hidden_size=128, num_heads=4).to(self.device) + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + x = model.get_input(hidden_size=128, seq_len=seq_len, batch_size=1).to(self.device) + + 
model.eval() + with torch.no_grad(): + output = model(x) + + assert output.shape == (1, seq_len, 128) + assert not torch.isnan(output).any() + + @pytest.mark.parametrize("batch_size", [1, 8, 16]) + def test_various_batch_sizes(self, batch_size): + """Test sparse attention with various batch sizes.""" + model = SimpleTransformerEncoderLayer(d_model=128, nhead=4).to(self.device) + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + x = model.get_input(d_model=128, seq_len=64, batch_size=batch_size).to(self.device) + + model.eval() + with torch.no_grad(): + output = model(x) + + assert output.shape == (batch_size, 64, 128) + assert not torch.isnan(output).any() diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py new file mode 100644 index 000000000..586cb3b9d --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration testing with locally created minimal Llama model.""" + +import pytest +import torch +from _test_utils.torch.transformers_models import create_tiny_llama_dir +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + +# Skip all tests if GPU is not available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + + +@pytest.fixture(scope="module") +def tiny_llama_dir(tmp_path_factory): + """Create minimal Llama model locally.""" + return create_tiny_llama_dir( + tmp_path_factory.mktemp("tiny_llama"), + with_tokenizer=True, + num_hidden_layers=2, # Minimal layers for fast testing + hidden_size=512, + intermediate_size=1024, + ) + + +@pytest.fixture(scope="module") +def tinyllama_model(tiny_llama_dir): + """Load locally created tiny Llama model.""" + model = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + attn_implementation="eager", + device_map="cuda", + ) + return model + + +@pytest.fixture(scope="module") +def tinyllama_tokenizer(tiny_llama_dir): + """Load tokenizer for tiny Llama model.""" + tokenizer = AutoTokenizer.from_pretrained(tiny_llama_dir) + return tokenizer + + +class TestTinyLlama: + """TinyLlama sparse attention tests.""" + + def test_load_and_sparsify(self, tinyllama_model): + """Load TinyLlama and apply sparse attention.""" + model = tinyllama_model + + config = SparseAttentionConfig( + sparse_cfg={ + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Verify 
sparse attention modules were added + sparse_count = sum( + 1 for m in sparse_model.modules() if isinstance(m, SparseAttentionModule) + ) + assert sparse_count > 0, "No sparse attention modules found" + + # Our tiny llama has 2 layers, so should have 2 attention modules + assert sparse_count == 2, f"Expected 2 sparse modules, got {sparse_count}" + + def test_forward_prefill(self, tinyllama_model, tinyllama_tokenizer): + """Forward pass with seq_len=64 (prefill).""" + model = tinyllama_model + tokenizer = tinyllama_tokenizer + + config = SparseAttentionConfig( + sparse_cfg={ + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Create prefill input (seq_len > 1) + test_text = "Once upon a time in a land far away" + inputs = tokenizer(test_text, return_tensors="pt").to("cuda") + + # Forward pass + sparse_model.eval() + with torch.no_grad(): + outputs = sparse_model(**inputs) + + # Verify output + assert outputs.logits is not None + assert not torch.isnan(outputs.logits).any() + assert outputs.logits.shape[1] == inputs.input_ids.shape[1] # seq_len preserved + + def test_forward_decode(self, tinyllama_model): + """Forward pass with seq_len=1 (decode).""" + model = tinyllama_model + + config = SparseAttentionConfig( + sparse_cfg={ + "*attn*": { + "threshold": 1e-5, # More conservative for decode + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Create decode input (seq_len = 1) + input_ids = torch.randint(0, 32000, (1, 1), device="cuda") + + # Forward pass + sparse_model.eval() + with torch.no_grad(): + outputs = sparse_model(input_ids) + + # Verify output + assert outputs.logits is not None + assert not torch.isnan(outputs.logits).any() + assert outputs.logits.shape == (1, 1, 32000) # batch=1, seq=1, vocab_size + + def test_gqa_attention(self, tinyllama_model): + """Verify GQA support (num_kv_heads < num_heads).""" + model = tinyllama_model + + # Check if model uses GQA + config = model.config + has_gqa = hasattr(config, "num_key_value_heads") and ( + config.num_key_value_heads < config.num_attention_heads + ) + + if not has_gqa: + pytest.skip("Model does not use GQA") + + # Apply sparse attention + sparse_config = SparseAttentionConfig( + sparse_cfg={ + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, sparse_config) + + # Test forward pass with GQA + input_ids = torch.randint(0, 32000, (1, 32), device="cuda") + + sparse_model.eval() + with torch.no_grad(): + outputs = sparse_model(input_ids) + + assert outputs.logits is not None + assert not torch.isnan(outputs.logits).any() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py b/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py new file mode 100644 index 000000000..b487d8639 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py @@ -0,0 +1,282 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for FlashSkipSoftmax method internals.""" + +import pytest +import torch + +pytest.importorskip("transformers") + +from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import FlashSkipSoftmax + + +class TestFlashSkipSoftmaxMethod: + """Test FlashSkipSoftmax method internals.""" + + def test_phase_inference(self): + """Test phase detection from attention score shape.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Prefill: seq_q > 1 + prefill_scores = torch.randn(2, 4, 64, 64) + assert method._infer_phase(prefill_scores) == "prefill" + + # Decode: seq_q = 1 + decode_scores = torch.randn(2, 4, 1, 64) + assert method._infer_phase(decode_scores) == "decode" + + def test_threshold_update_dict_config(self): + """Test threshold updates with dict config.""" + method = FlashSkipSoftmax( + { + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Initially uses prefill threshold + initial_threshold = method.threshold + + # Update to decode + method._update_threshold("decode") + assert method.threshold == 1e-5 + assert method.threshold != initial_threshold + + # Update back to prefill + method._update_threshold("prefill") + assert method.threshold == 1e-3 + + def test_threshold_update_static_config(self): + """Test threshold with static float config.""" + method = FlashSkipSoftmax( + { + "threshold": 5e-4, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + initial_threshold = method.threshold + assert initial_threshold == 5e-4 + + # Should not change for static config + method._update_threshold("decode") + assert method.threshold == 5e-4 + + def test_block_reshaping_divisible(self): + """Test block reshaping with divisible sequence lengths.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Seq lengths divisible by 128 + attn = torch.randn(2, 4, 256, 256) + blocked, num_br, num_bc, padded_q, padded_k = method._reshape_to_blocks(attn, 128, 128) + + # Verify block dimensions + assert blocked.shape == (2, 4, 2, 128, 2, 128) # 256/128 = 2 blocks + assert num_br == 2 + assert num_bc == 2 + assert padded_q == 256 # No padding + assert padded_k == 256 # No padding + + def test_block_reshaping_with_padding(self): + """Test block reshaping with non-divisible lengths.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Seq lengths NOT divisible by 128 + attn = torch.randn(2, 4, 200, 300) + blocked, num_br, num_bc, padded_q, padded_k = method._reshape_to_blocks(attn, 128, 128) + + # Verify padding applied + assert padded_q == 256 # ceil(200/128) * 128 = 2 * 128 + assert padded_k == 384 # ceil(300/128) * 128 = 3 * 128 + assert num_br == 2 + assert num_bc == 3 + assert blocked.shape == (2, 4, 2, 128, 3, 128) + + def test_correction_factor_calculation_prefill(self): + """Test correction 
factor for prefill phase.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Create simple attention pattern + attn = torch.randn(1, 1, 128, 256) + + mask, stats = method.calc_correction_factor_and_p(attn, "prefill") + + # Verify stats structure + assert "correction_factor" in stats + assert "sparsity" in stats + assert "phase" in stats + assert "total_blocks" in stats + assert stats["phase"] == "prefill" + assert 0 <= stats["correction_factor"] <= 1 + # Sparsity can be negative if threshold is too low (more blocks kept than expected) + assert -1 <= stats["sparsity"] <= 1 + + def test_correction_factor_calculation_decode(self): + """Test correction factor for decode phase.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-5, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Decode: single query + attn = torch.randn(1, 1, 1, 256) + + mask, stats = method.calc_correction_factor_and_p(attn, "decode") + + # Verify stats structure + assert stats["phase"] == "decode" + assert "correction_factor" in stats + assert 0 <= stats["sparsity"] <= 1 + assert mask.shape == (1, 1, 1, 256) + + def test_sparsity_statistics(self): + """Test sparsity statistics structure.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + attn = torch.randn(1, 1, 128, 256) + _, stats = method.calc_correction_factor_and_p(attn, "prefill") + + # Verify statistics are present + assert stats["total_blocks"] > 0 + assert "sparse_blocks" in stats + assert "sample_length" in stats + assert stats["sample_length"] == 256 + + def test_block_mask_correctness(self): + """Test block mask shape and type.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + attn = torch.randn(2, 4, 128, 256) + mask, _ = method.calc_correction_factor_and_p(attn, "prefill") + + # Verify mask properties + assert mask.shape == attn.shape + assert mask.dtype == torch.bool + assert mask.device == attn.device + + def test_causal_vs_noncausal(self): + """Test total_blocks calculation for causal vs non-causal.""" + config_base = { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + } + + method_causal = FlashSkipSoftmax({**config_base, "is_causal": True}) + method_noncausal = FlashSkipSoftmax({**config_base, "is_causal": False}) + + attn = torch.randn(1, 1, 256, 256) # 2x2 blocks + + _, stats_causal = method_causal.calc_correction_factor_and_p(attn, "prefill") + _, stats_noncausal = method_noncausal.calc_correction_factor_and_p(attn, "prefill") + + # Causal: 2*(2+1)/2 = 3 blocks + # Non-causal: 2*2 = 4 blocks + assert stats_causal["total_blocks"] == 3 + assert stats_noncausal["total_blocks"] == 4 + + def test_apply_sparsity_assertions(self): + """Test apply_sparsity input validation.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Test: attention_scores required + with pytest.raises(AssertionError, match="attention_scores must be provided"): + method.apply_sparsity() + + # Test: 4D shape required + with pytest.raises(AssertionError, match="Expected 4D"): + method.apply_sparsity(attention_scores=torch.randn(2, 64, 64)) # 3D + + def test_name_property(self): + """Test method name property.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 
128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + assert method.name == "flash_skip_softmax" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py new file mode 100644 index 000000000..1824825f9 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test sparse attention configuration validation.""" + +import pytest +from pydantic import ValidationError + +pytest.importorskip("transformers") + +from modelopt.torch.sparsity.attention_sparsity.config import ( + SKIP_SOFTMAX_DEFAULT, + FlashSkipSoftmaxConfig, + SparseAttentionAttributeConfig, + SparseAttentionConfig, +) + + +class TestSparseAttentionAttributeConfig: + """Test SparseAttentionAttributeConfig validators.""" + + def test_valid_config(self): + """Test creating valid config.""" + config = SparseAttentionAttributeConfig( + method="flash_skip_softmax", + threshold=1e-4, + br=128, + bc=128, + enable=True, + ) + assert config.method == "flash_skip_softmax" + assert config.threshold == 1e-4 + assert config.br == 128 + assert config.bc == 128 + + def test_method_validation(self): + """Test method must be string.""" + with pytest.raises(ValidationError, match="Input should be a valid string"): + SparseAttentionAttributeConfig(method=123) + + def test_block_size_validation_negative(self): + """Test block sizes must be positive.""" + with pytest.raises(ValidationError, match="Block size must be positive"): + SparseAttentionAttributeConfig(br=-1) + + with pytest.raises(ValidationError, match="Block size must be positive"): + SparseAttentionAttributeConfig(bc=0) + + def test_block_size_validation_large(self): + """Test that large block sizes are accepted.""" + # Large block sizes are allowed (warning removed for simplicity) + config = SparseAttentionAttributeConfig(br=2048) + assert config.br == 2048 + + def test_threshold_validation_range(self): + """Test threshold must be in range (0, 1).""" + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=0) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=-0.1) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=1.0) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=1.5) + + def test_threshold_validation_dict(self): + """Test threshold dict validation.""" + # Valid phase-aware threshold + config = SparseAttentionAttributeConfig(threshold={"prefill": 1e-3, "decode": 1e-5}) + assert config.threshold == {"prefill": 1e-3, "decode": 1e-5} + + # Invalid phase 
key + with pytest.raises(ValidationError, match="Invalid threshold phases"): + SparseAttentionAttributeConfig(threshold={"invalid_phase": 1e-3}) + + # Invalid threshold value in dict (negative) + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": -1e-3}) + + # Invalid threshold value in dict (>= 1.0) + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": 1.0}) + + def test_threshold_validation_type(self): + """Test threshold type validation.""" + with pytest.raises(ValidationError, match="Input should be a valid"): + SparseAttentionAttributeConfig(threshold="invalid") + + +class TestSparseAttentionConfig: + """Test SparseAttentionConfig.""" + + def test_default_config(self): + """Test default configuration.""" + config = SparseAttentionConfig() + assert "sparse_cfg" in config.model_dump() + # Check default pattern has method + assert config.sparse_cfg["*attention*"]["method"] == "flash_skip_softmax" + + def test_predefined_config(self): + """Test pre-defined configuration.""" + assert "sparse_cfg" in SKIP_SOFTMAX_DEFAULT + assert "method" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"]["*attn*"] + assert "*attn*" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"] + + +class TestFlashSkipSoftmaxConfig: + """Test FlashSkipSoftmaxConfig.""" + + def test_default_values(self): + """Test default values for flash_skip_softmax config.""" + config = FlashSkipSoftmaxConfig() + assert "*attention*" in config.sparse_cfg + assert config.sparse_cfg["*attention*"]["method"] == "flash_skip_softmax" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py new file mode 100644 index 000000000..8df8fe476 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
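Aside (editorial note, not part of the patch): the phase-aware threshold behavior exercised by the unit tests above can be summarized in a short sketch; it only uses constructor arguments and methods that the tests in this series already call::

    import torch

    from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import FlashSkipSoftmax

    method = FlashSkipSoftmax(
        {
            "threshold": {"prefill": 1e-3, "decode": 1e-5},
            "br": 128,
            "bc": 128,
            "backend": "pytorch",
            "is_causal": True,
        }
    )

    # A query length of 1 is treated as decode, anything longer as prefill...
    assert method._infer_phase(torch.randn(1, 4, 64, 64)) == "prefill"
    assert method._infer_phase(torch.randn(1, 4, 1, 64)) == "decode"

    # ...and the active threshold follows the detected phase.
    method._update_threshold("decode")
    assert method.threshold == 1e-5
    method._update_threshold("prefill")
    assert method.threshold == 1e-3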
+ +"""Tests for sparse attention conversion and replacement.""" + +import pytest + +pytest.importorskip("transformers") + +import torch.nn as nn +from _test_utils.torch_sparsity.sparse_attention_common import ( + FLASH_SKIP_SOFTMAX_DEFAULT_CFG, + SimpleAttentionModel, + SimpleTransformerEncoderLayer, +) + +import modelopt.torch.opt as mto +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.sparsity.attention_sparsity.conversion import ( + disable_sparse_attention, + enable_sparse_attention, + print_sparse_attention_summary, +) +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +class TestSparseAttentionReplacement: + """Test module replacement logic.""" + + def test_basic_replacement(self): + """Test that attention modules are replaced with sparse versions.""" + model = SimpleAttentionModel() + + # Count original attention modules + original_attention_count = sum( + isinstance(m, nn.MultiheadAttention) for m in model.modules() + ) + assert original_attention_count > 0 + + # Apply sparse attention + sparse_model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Count sparse attention modules + sparse_attention_count = sum( + isinstance(m, SparseAttentionModule) for m in sparse_model.modules() + ) + + # Verify replacement occurred + assert sparse_attention_count > 0 + + def test_enable_disable_toggle(self): + """Test enabling and disabling sparse attention.""" + model = SimpleAttentionModel() + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Check initially enabled + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert module.is_enabled + + # Disable all sparse attention modules + disable_sparse_attention(model, "*") + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert not module.is_enabled + + # Re-enable all sparse attention modules + enable_sparse_attention(model, "*") + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert module.is_enabled + + def test_pattern_based_replacement(self): + """Test pattern-based selective replacement.""" + model = SimpleTransformerEncoderLayer() + + # Apply with pattern + config = { + "sparse_cfg": { + "*self_attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-4, + "br": 128, + "bc": 128, + "enable": True, + }, + "default": {"enable": False}, + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Verify sparse modules exist + has_sparse = any(isinstance(m, SparseAttentionModule) for m in sparse_model.modules()) + assert has_sparse + + +class TestConversionEdgeCases: + """Test edge cases and error paths in conversion.""" + + def test_callable_filter(self): + """Test using callable filter instead of wildcard.""" + model = SimpleAttentionModel() + + # Use callable filter + def filter_func(name): + return "attn" in name + + config = { + "sparse_cfg": { + filter_func: { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + }, + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + has_sparse = any(isinstance(m, SparseAttentionModule) for m in sparse_model.modules()) + assert has_sparse + + def test_no_matching_modules(self): + """Test pattern that matches nothing.""" + model = SimpleAttentionModel() + + config = { + "sparse_cfg": { + "*nonexistent*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + }, + }, + } + + # Should not error, even with no 
matches + sparse_attn.sparsify(model, config) + + def test_disable_enable_functions(self): + """Test disable/enable utility functions.""" + from modelopt.torch.sparsity.attention_sparsity.conversion import ( + disable_sparse_attention, + enable_sparse_attention, + ) + + model = SimpleAttentionModel() + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Disable all + disable_sparse_attention(model, "*") + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert not module.is_enabled + + # Enable all + enable_sparse_attention(model, "*") + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert module.is_enabled + + def test_print_sparse_attention_summary(self, capsys): + """Test print_sparse_attention_summary function.""" + model = SimpleAttentionModel() + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Print summary + print_sparse_attention_summary(model) + + # Capture output + captured = capsys.readouterr() + assert "Total sparse attention modules:" in captured.out + assert "Enabled:" in captured.out + + def test_restore_sparse_attention_model(self): + """Test save/restore via modelopt_state.""" + # Create and sparsify original model + model_orig = SimpleAttentionModel() + model_orig = sparse_attn.sparsify(model_orig, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Save state + state_dict = mto.modelopt_state(model_orig) + + # Restore to new model + model_restored = SimpleAttentionModel() + mto.restore_from_modelopt_state(model_restored, state_dict) + + # Verify restoration + has_sparse = any(isinstance(m, SparseAttentionModule) for m in model_restored.modules()) + assert has_sparse + + # Verify module is configured + for module in model_restored.modules(): + if isinstance(module, SparseAttentionModule): + assert hasattr(module, "_method") + assert module._method == "flash_skip_softmax" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py new file mode 100644 index 000000000..e7e32e153 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
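Aside (editorial note, not part of the patch): taken together, the conversion tests above exercise the workflow sketched below. This is a minimal sketch under the assumption that the APIs behave as the tests assert; ``TinyAttn`` and ``cfg`` are illustrative stand-ins, not part of the patch::

    import torch.nn as nn

    import modelopt.torch.opt as mto
    import modelopt.torch.sparsity.attention_sparsity as sparse_attn
    from modelopt.torch.sparsity.attention_sparsity.conversion import (
        disable_sparse_attention,
        enable_sparse_attention,
    )

    class TinyAttn(nn.Module):
        def __init__(self, hidden=64, heads=4):
            super().__init__()
            self.attention = nn.MultiheadAttention(hidden, heads, batch_first=True)

        def forward(self, x):
            out, _ = self.attention(x, x, x, need_weights=False)
            return out

    cfg = {"sparse_cfg": {"*attention*": {"method": "flash_skip_softmax", "threshold": 1e-4, "enable": True}}}

    model = sparse_attn.sparsify(TinyAttn(), cfg)

    disable_sparse_attention(model, "*")  # fall back to the original attention path
    enable_sparse_attention(model, "*")   # re-enable sparsity on all matched modules

    state = mto.modelopt_state(model)     # capture the sparse-attention conversion state
    restored = TinyAttn()
    mto.restore_from_modelopt_state(restored, state)  # re-apply it to a fresh instance
    restored.load_state_dict(model.state_dict())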
+ +"""Tests for sparse attention mode registry.""" + +import pytest + +pytest.importorskip("transformers") + +from modelopt.torch.opt.mode import _ModeRegistryCls +from modelopt.torch.sparsity.attention_sparsity.mode import SparseAttentionModeRegistry + + +def test_sparse_attention_mode_exists(): + """Test that sparse_attention mode is registered.""" + assert "sparse_attention" in SparseAttentionModeRegistry + + +def test_sparse_attention_mode_descriptor(): + """Test sparse attention mode descriptor properties.""" + mode_descriptor = _ModeRegistryCls.get_from_any("sparse_attention") + + assert mode_descriptor is not None + assert hasattr(mode_descriptor, "config_class") + assert hasattr(mode_descriptor, "convert") + + +def test_mode_registry_get(): + """Test getting mode from registry.""" + mode = SparseAttentionModeRegistry["sparse_attention"] + assert mode is not None From 02182f8990a1e40bca5a8ccaab058a95de490ed1 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Mon, 8 Dec 2025 13:32:20 -0800 Subject: [PATCH 3/6] Address review feedback Signed-off-by: Kai Xu --- .../llm_sparsity/attention_sparsity/hf_sa.py | 4 +- .../sparsity/attention_sparsity/conversion.py | 61 --------- .../attention_sparsity/plugins/huggingface.py | 10 +- .../test_attention_sparsity.py | 2 - .../test_attention_sparsity_gpu.py | 3 - .../test_integration_gpu.py | 7 +- .../test_sparse_attention_config.py | 129 ------------------ .../test_sparse_attention_conversion.py | 14 -- 8 files changed, 6 insertions(+), 224 deletions(-) delete mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py index 2f68cfa68..11564a4ec 100644 --- a/examples/llm_sparsity/attention_sparsity/hf_sa.py +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -295,8 +295,8 @@ def main(args): "--backend", type=str, default="pytorch", - choices=["pytorch", "triton"], - help="Backend to use for sparse attention computation (default: pytorch)", + choices=["pytorch"], + help="Backend for sparse attention (default: pytorch). More backends coming soon.", ) # Sequence length arguments diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index 25347c37f..ad137e9ee 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -299,64 +299,3 @@ def enable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Cal if matched: module.enable() - - -def print_sparse_attention_summary(model: nn.Module): - """Print summary of sparse attention modules in the model. - - Similar to mtq.print_quant_summary for API consistency. 
- - Args: - model: Model with sparse attention applied - - Prints: - - Total sparse attention modules - - Enabled vs disabled count - - Method distribution - - Configuration summary by module - - Example: - >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn - >>> model = sparse_attn.sparsify(model, config) - >>> sparse_attn.print_sparse_attention_summary(model) - """ - sparse_modules = [] - for name, module in model.named_modules(): - if isinstance(module, SparseAttentionModule): - sparse_modules.append((name, module)) - - if not sparse_modules: - print("No sparse attention modules found in model") - return - - enabled_count = sum(1 for _, m in sparse_modules if m.is_enabled) - disabled_count = len(sparse_modules) - enabled_count - - # Count methods - method_counts = {} - for _, module in sparse_modules: - method = getattr(module, "_method", "unknown") - method_counts[method] = method_counts.get(method, 0) + 1 - - print(f"Total sparse attention modules: {len(sparse_modules)}") - print(f"Enabled: {enabled_count}") - print(f"Disabled: {disabled_count}") - - if method_counts: - print("\nMethods:") - for method, count in sorted(method_counts.items()): - print(f"{method}: {count}") - - for name, module in sparse_modules: - method = getattr(module, "_method", "unknown") - threshold = getattr(module, "_threshold", "N/A") - - # Format threshold nicely - if isinstance(threshold, dict): - threshold_str = str(threshold) - elif isinstance(threshold, float): - threshold_str = f"{threshold:.2e}" - else: - threshold_str = str(threshold) - - print(f"{name}: Method: {method}, Threshold: {threshold_str}") diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py index b0cd1dff6..0c4a8baf9 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -15,8 +15,6 @@ """Dynamic sparse attention registration for HuggingFace models.""" -import logging - import torch.nn as nn import transformers @@ -24,8 +22,6 @@ from ..sparse_attention import SparseAttentionModule, SparseAttentionRegistry -logger = logging.getLogger(__name__) - class _GenericSparseAttention(SparseAttentionModule): """Generic sparse attention that works with any HF attention module. 
@@ -94,12 +90,10 @@ def register_sparse_attention_on_the_fly(model: nn.Module) -> bool: SparseAttentionRegistry.register({module_type: type_name})(_GenericSparseAttention) attention_types.add(module_type) registered_count += 1 - logger.info(f"Registered {type_name} for sparse attention optimization") + print(f"Registered {type_name} for sparse attention optimization") if registered_count > 0: - logger.info( - f"Dynamically registered {registered_count} attention module types for sparsity" - ) + print(f"Dynamically registered {registered_count} attention module types for sparsity") return registered_count > 0 diff --git a/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py index b82303990..b70dfab35 100644 --- a/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py +++ b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py @@ -17,7 +17,6 @@ import pytest from _test_utils.examples.run_command import extend_cmd_parts, run_example_command -from _test_utils.torch.misc import minimum_gpu def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax", **kwargs): @@ -42,7 +41,6 @@ def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax", run_example_command(cmd_parts, "llm_sparsity/attention_sparsity") -@minimum_gpu(1) @pytest.mark.parametrize("method", ["skip_softmax"]) def test_attention_sparsity(tiny_llama_path, tmp_path, method): """Test sparse attention with TinyLlama.""" diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py index bad077fdb..d437282d6 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py @@ -29,9 +29,6 @@ import modelopt.torch.sparsity.attention_sparsity as sparse_attn -# Skip all tests if GPU is not available -pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") - class TestAttentionSparsityGPU: """GPU tests for attention sparsity.""" diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py index 586cb3b9d..c90b99bba 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py @@ -24,9 +24,6 @@ from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule -# Skip all tests if GPU is not available -pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") - @pytest.fixture(scope="module") def tiny_llama_dir(tmp_path_factory): @@ -35,8 +32,8 @@ def tiny_llama_dir(tmp_path_factory): tmp_path_factory.mktemp("tiny_llama"), with_tokenizer=True, num_hidden_layers=2, # Minimal layers for fast testing - hidden_size=512, - intermediate_size=1024, + hidden_size=32, + intermediate_size=64, ) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py deleted file mode 100644 index 1824825f9..000000000 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py +++ /dev/null @@ -1,129 +0,0 @@ -# 
SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test sparse attention configuration validation.""" - -import pytest -from pydantic import ValidationError - -pytest.importorskip("transformers") - -from modelopt.torch.sparsity.attention_sparsity.config import ( - SKIP_SOFTMAX_DEFAULT, - FlashSkipSoftmaxConfig, - SparseAttentionAttributeConfig, - SparseAttentionConfig, -) - - -class TestSparseAttentionAttributeConfig: - """Test SparseAttentionAttributeConfig validators.""" - - def test_valid_config(self): - """Test creating valid config.""" - config = SparseAttentionAttributeConfig( - method="flash_skip_softmax", - threshold=1e-4, - br=128, - bc=128, - enable=True, - ) - assert config.method == "flash_skip_softmax" - assert config.threshold == 1e-4 - assert config.br == 128 - assert config.bc == 128 - - def test_method_validation(self): - """Test method must be string.""" - with pytest.raises(ValidationError, match="Input should be a valid string"): - SparseAttentionAttributeConfig(method=123) - - def test_block_size_validation_negative(self): - """Test block sizes must be positive.""" - with pytest.raises(ValidationError, match="Block size must be positive"): - SparseAttentionAttributeConfig(br=-1) - - with pytest.raises(ValidationError, match="Block size must be positive"): - SparseAttentionAttributeConfig(bc=0) - - def test_block_size_validation_large(self): - """Test that large block sizes are accepted.""" - # Large block sizes are allowed (warning removed for simplicity) - config = SparseAttentionAttributeConfig(br=2048) - assert config.br == 2048 - - def test_threshold_validation_range(self): - """Test threshold must be in range (0, 1).""" - with pytest.raises(ValidationError, match="Threshold must be in range"): - SparseAttentionAttributeConfig(threshold=0) - - with pytest.raises(ValidationError, match="Threshold must be in range"): - SparseAttentionAttributeConfig(threshold=-0.1) - - with pytest.raises(ValidationError, match="Threshold must be in range"): - SparseAttentionAttributeConfig(threshold=1.0) - - with pytest.raises(ValidationError, match="Threshold must be in range"): - SparseAttentionAttributeConfig(threshold=1.5) - - def test_threshold_validation_dict(self): - """Test threshold dict validation.""" - # Valid phase-aware threshold - config = SparseAttentionAttributeConfig(threshold={"prefill": 1e-3, "decode": 1e-5}) - assert config.threshold == {"prefill": 1e-3, "decode": 1e-5} - - # Invalid phase key - with pytest.raises(ValidationError, match="Invalid threshold phases"): - SparseAttentionAttributeConfig(threshold={"invalid_phase": 1e-3}) - - # Invalid threshold value in dict (negative) - with pytest.raises(ValidationError, match="must be in range"): - SparseAttentionAttributeConfig(threshold={"prefill": -1e-3}) - - # Invalid threshold value in dict (>= 1.0) - with pytest.raises(ValidationError, match="must be in range"): - 
SparseAttentionAttributeConfig(threshold={"prefill": 1.0}) - - def test_threshold_validation_type(self): - """Test threshold type validation.""" - with pytest.raises(ValidationError, match="Input should be a valid"): - SparseAttentionAttributeConfig(threshold="invalid") - - -class TestSparseAttentionConfig: - """Test SparseAttentionConfig.""" - - def test_default_config(self): - """Test default configuration.""" - config = SparseAttentionConfig() - assert "sparse_cfg" in config.model_dump() - # Check default pattern has method - assert config.sparse_cfg["*attention*"]["method"] == "flash_skip_softmax" - - def test_predefined_config(self): - """Test pre-defined configuration.""" - assert "sparse_cfg" in SKIP_SOFTMAX_DEFAULT - assert "method" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"]["*attn*"] - assert "*attn*" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"] - - -class TestFlashSkipSoftmaxConfig: - """Test FlashSkipSoftmaxConfig.""" - - def test_default_values(self): - """Test default values for flash_skip_softmax config.""" - config = FlashSkipSoftmaxConfig() - assert "*attention*" in config.sparse_cfg - assert config.sparse_cfg["*attention*"]["method"] == "flash_skip_softmax" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py index 8df8fe476..d93e929dc 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -31,7 +31,6 @@ from modelopt.torch.sparsity.attention_sparsity.conversion import ( disable_sparse_attention, enable_sparse_attention, - print_sparse_attention_summary, ) from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule @@ -171,19 +170,6 @@ def test_disable_enable_functions(self): if isinstance(module, SparseAttentionModule): assert module.is_enabled - def test_print_sparse_attention_summary(self, capsys): - """Test print_sparse_attention_summary function.""" - model = SimpleAttentionModel() - model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) - - # Print summary - print_sparse_attention_summary(model) - - # Capture output - captured = capsys.readouterr() - assert "Total sparse attention modules:" in captured.out - assert "Enabled:" in captured.out - def test_restore_sparse_attention_model(self): """Test save/restore via modelopt_state.""" # Create and sparsify original model From 0281f0e0ed5da115b461f4b9c535a2f2d4a70437 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 5 Nov 2025 17:21:28 -0800 Subject: [PATCH 4/6] Add unit and GPU tests for core sparse attention functionality Signed-off-by: Kai Xu --- .../sparsity/attention_sparsity/config.py | 17 ++- .../sparsity/attention_sparsity/conversion.py | 73 ++++++++++ .../test_sparse_attention_config.py | 129 ++++++++++++++++++ .../test_sparse_attention_conversion.py | 14 ++ 4 files changed, 232 insertions(+), 1 deletion(-) create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index e72dacc94..bc533acb2 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -40,6 +40,12 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): description="The sparse attention method to use (e.g., 'flash_skip_softmax').", ) + 
enable: bool = ModeloptField( default=True, title="Enable sparse attention.", @@ -155,6 +161,7 @@ def validate_threshold(cls, v): SKIP_SOFTMAX_DEFAULT = { "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", "threshold": { "prefill": 1e-3, # More aggressive during prefill @@ -179,6 +186,10 @@ class SparseAttentionConfig(ModeloptBaseConfig): # Pattern-based sparse configuration (similar to quant_cfg in quantization) sparse_cfg: SparseAttentionCfgType = ModeloptField( default={ "*attention*": {"method": "flash_skip_softmax", "enable": True}, "default": {"enable": False}, @@ -195,15 +206,17 @@ class SparseAttentionConfig(ModeloptBaseConfig): ) class FlashSkipSoftmaxConfig(SparseAttentionConfig): """Configuration for Flash Attention-aware softmax skip sparse attention.""" # Override sparse_cfg with flash_skip_softmax specific defaults sparse_cfg: SparseAttentionCfgType = ModeloptField( default={ "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "threshold": {"prefill": 1e-3, "decode": 1e-5}, "br": 128, # Flash Attention block rows "bc": 128, # Flash Attention block columns "backend": "pytorch", # Only pytorch backend supported @@ -213,6 +226,7 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): }, title="Flash softmax skip sparse configuration", description="Pattern-based configuration with flash_skip_softmax specific defaults. "
" "Includes FA block sizes (br, bc) and correction factor settings.", validate_default=True, ) @@ -221,6 +235,7 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): __all__ = [ "SKIP_SOFTMAX_DEFAULT", "FlashSkipSoftmaxConfig", + "FlashSkipSoftmaxConfig", "SparseAttentionAttributeConfig", "SparseAttentionCfgType", "SparseAttentionConfig", diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index ad137e9ee..eec48c136 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -28,6 +28,7 @@ from .config import SparseAttentionConfig from .plugins.huggingface import register_sparse_attention_on_the_fly from .sparse_attention import SparseAttentionModule, SparseAttentionRegistry +from .sparse_attention import SparseAttentionModule, SparseAttentionRegistry def is_attn_sparsified(model: nn.Module) -> bool: @@ -68,6 +69,7 @@ def convert_to_sparse_attention_model( # Apply configuration to sparse attention modules sparse_cfg = config.sparse_cfg if hasattr(config, "sparse_cfg") else {} set_sparse_attention_by_cfg(model, sparse_cfg) + set_sparse_attention_by_cfg(model, sparse_cfg) # Create metadata metadata = {} @@ -106,6 +108,7 @@ def _replace_sparse_attention_modules(model: nn.Module, version=None): _replace_sparse_attention_modules(getattr(model, name), version=version) +def set_sparse_attention_by_cfg(model: nn.Module, sparse_cfg: dict): def set_sparse_attention_by_cfg(model: nn.Module, sparse_cfg: dict): """Apply sparse attention configuration to model. @@ -114,17 +117,20 @@ def set_sparse_attention_by_cfg(model: nn.Module, sparse_cfg: dict): Args: model: Model with sparse attention modules sparse_cfg: Sparse configuration dictionary mapping patterns to attributes + sparse_cfg: Sparse configuration dictionary mapping patterns to attributes """ sparse_cfg = sparse_cfg.copy() # Apply default first if exists if "default" in sparse_cfg: + set_sparse_attention_attribute(model, "*", sparse_cfg["default"]) set_sparse_attention_attribute(model, "*", sparse_cfg["default"]) sparse_cfg.pop("default") # Apply pattern-specific configs for pattern, cfg in sparse_cfg.items(): set_sparse_attention_attribute(model, pattern, cfg) + set_sparse_attention_attribute(model, pattern, cfg) def set_sparse_attention_attribute( @@ -140,6 +146,7 @@ def set_sparse_attention_attribute( model: Model to configure wildcard_or_filter: Pattern to match module names attribute_cfg: Attributes to apply (must include 'method') + attribute_cfg: Attributes to apply (must include 'method') """ # Filter out model-level configs that shouldn't be passed to modules module_cfg = {k: v for k, v in attribute_cfg.items() if k != "calibration"} @@ -156,10 +163,12 @@ def set_sparse_attention_attribute( matched = wildcard_or_filter(name) else: raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter)}") + raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter)}") if matched: # Apply config using the same method as TensorQuantizer module.set_from_attribute_config(module_cfg) + module.set_from_attribute_config(module_cfg) def restore_sparse_attention_model( @@ -226,11 +235,14 @@ def update_sparse_attention_metadata( if isinstance(module, SparseAttentionModule): module_name = get_unwrapped_name(name, model) + # Save the method configuration that was used + # _method_config already contains the validated config dict # Save the method configuration that was used # 
_method_config already contains the validated config dict module_state = { "method": module._sparse_method_instance.name, "method_config": module._method_config.copy(), } sparse_state[module_name] = module_state @@ -299,3 +311,64 @@ def enable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Cal if matched: module.enable() + + +def print_sparse_attention_summary(model: nn.Module): + """Print summary of sparse attention modules in the model. + + Similar to mtq.print_quant_summary for API consistency. + + Args: + model: Model with sparse attention applied + + Prints: + - Total sparse attention modules + - Enabled vs disabled count + - Method distribution + - Configuration summary by module + + Example: + >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn + >>> model = sparse_attn.sparsify(model, config) + >>> sparse_attn.print_sparse_attention_summary(model) + """ + sparse_modules = [] + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule): + sparse_modules.append((name, module)) + + if not sparse_modules: + print("No sparse attention modules found in model") + return + + enabled_count = sum(1 for _, m in sparse_modules if m.is_enabled) + disabled_count = len(sparse_modules) - enabled_count + + # Count methods + method_counts = {} + for _, module in sparse_modules: + method = getattr(module, "_method", "unknown") + method_counts[method] = method_counts.get(method, 0) + 1 + + print(f"Total sparse attention modules: {len(sparse_modules)}") + print(f"Enabled: {enabled_count}") + print(f"Disabled: {disabled_count}") + + if method_counts: + print("\nMethods:") + for method, count in sorted(method_counts.items()): + print(f"{method}: {count}") + + for name, module in sparse_modules: + method = getattr(module, "_method", "unknown") + threshold = getattr(module, "_threshold", "N/A") + + # Format threshold nicely + if isinstance(threshold, dict): + threshold_str = str(threshold) + elif isinstance(threshold, float): + threshold_str = f"{threshold:.2e}" + else: + threshold_str = str(threshold) + + print(f"{name}: Method: {method}, Threshold: {threshold_str}") diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py new file mode 100644 index 000000000..1824825f9 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test sparse attention configuration validation.""" + +import pytest +from pydantic import ValidationError + +pytest.importorskip("transformers") + +from modelopt.torch.sparsity.attention_sparsity.config import ( + SKIP_SOFTMAX_DEFAULT, + FlashSkipSoftmaxConfig, + SparseAttentionAttributeConfig, + SparseAttentionConfig, +) + + +class TestSparseAttentionAttributeConfig: + """Test SparseAttentionAttributeConfig validators.""" + + def test_valid_config(self): + """Test creating valid config.""" + config = SparseAttentionAttributeConfig( + method="flash_skip_softmax", + threshold=1e-4, + br=128, + bc=128, + enable=True, + ) + assert config.method == "flash_skip_softmax" + assert config.threshold == 1e-4 + assert config.br == 128 + assert config.bc == 128 + + def test_method_validation(self): + """Test method must be string.""" + with pytest.raises(ValidationError, match="Input should be a valid string"): + SparseAttentionAttributeConfig(method=123) + + def test_block_size_validation_negative(self): + """Test block sizes must be positive.""" + with pytest.raises(ValidationError, match="Block size must be positive"): + SparseAttentionAttributeConfig(br=-1) + + with pytest.raises(ValidationError, match="Block size must be positive"): + SparseAttentionAttributeConfig(bc=0) + + def test_block_size_validation_large(self): + """Test that large block sizes are accepted.""" + # Large block sizes are allowed (warning removed for simplicity) + config = SparseAttentionAttributeConfig(br=2048) + assert config.br == 2048 + + def test_threshold_validation_range(self): + """Test threshold must be in range (0, 1).""" + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=0) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=-0.1) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=1.0) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=1.5) + + def test_threshold_validation_dict(self): + """Test threshold dict validation.""" + # Valid phase-aware threshold + config = SparseAttentionAttributeConfig(threshold={"prefill": 1e-3, "decode": 1e-5}) + assert config.threshold == {"prefill": 1e-3, "decode": 1e-5} + + # Invalid phase key + with pytest.raises(ValidationError, match="Invalid threshold phases"): + SparseAttentionAttributeConfig(threshold={"invalid_phase": 1e-3}) + + # Invalid threshold value in dict (negative) + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": -1e-3}) + + # Invalid threshold value in dict (>= 1.0) + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": 1.0}) + + def test_threshold_validation_type(self): + """Test threshold type validation.""" + with pytest.raises(ValidationError, match="Input should be a valid"): + SparseAttentionAttributeConfig(threshold="invalid") + + +class TestSparseAttentionConfig: + """Test SparseAttentionConfig.""" + + def test_default_config(self): + """Test default configuration.""" + config = SparseAttentionConfig() + assert "sparse_cfg" in config.model_dump() + # Check default pattern has method + assert config.sparse_cfg["*attention*"]["method"] == "flash_skip_softmax" + + def test_predefined_config(self): + """Test pre-defined configuration.""" + assert 
"sparse_cfg" in SKIP_SOFTMAX_DEFAULT + assert "method" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"]["*attn*"] + assert "*attn*" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"] + + +class TestFlashSkipSoftmaxConfig: + """Test FlashSkipSoftmaxConfig.""" + + def test_default_values(self): + """Test default values for flash_skip_softmax config.""" + config = FlashSkipSoftmaxConfig() + assert "*attention*" in config.sparse_cfg + assert config.sparse_cfg["*attention*"]["method"] == "flash_skip_softmax" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py index d93e929dc..8df8fe476 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -31,6 +31,7 @@ from modelopt.torch.sparsity.attention_sparsity.conversion import ( disable_sparse_attention, enable_sparse_attention, + print_sparse_attention_summary, ) from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule @@ -170,6 +171,19 @@ def test_disable_enable_functions(self): if isinstance(module, SparseAttentionModule): assert module.is_enabled + def test_print_sparse_attention_summary(self, capsys): + """Test print_sparse_attention_summary function.""" + model = SimpleAttentionModel() + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Print summary + print_sparse_attention_summary(model) + + # Capture output + captured = capsys.readouterr() + assert "Total sparse attention modules:" in captured.out + assert "Enabled:" in captured.out + def test_restore_sparse_attention_model(self): """Test save/restore via modelopt_state.""" # Create and sparsify original model From 08d9a9fe87c4e9524538cbf2a6ff6a0baaa65a92 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 5 Nov 2025 17:22:17 -0800 Subject: [PATCH 5/6] Add sparsity ratio calibration for skip softmax Signed-off-by: Kai Xu --- .../llm_sparsity/attention_sparsity/hf_sa.py | 8 +- .../calibration/__init__.py | 26 + .../calibration/calibrate.py | 176 +++++ .../calibration/calibrator.py | 308 +++++++++ .../attention_sparsity/calibration/dataset.py | 602 ++++++++++++++++++ .../sparsity/attention_sparsity/config.py | 115 ++++ .../methods/flash_skip_softmax.py | 12 +- .../attention_sparsity/model_sparsify.py | 81 ++- .../attention_sparsity/sparse_attention.py | 18 +- .../attention_sparsity/stats_manager.py | 125 ++++ .../test_attention_sparsity.py | 5 +- .../test_calibration_gpu.py | 388 +++++++++++ .../test_sparse_attention_calibration.py | 442 +++++++++++++ 13 files changed, 2297 insertions(+), 9 deletions(-) create mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/stats_manager.py create mode 100644 tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py index 11564a4ec..2b5f3ade5 100644 --- a/examples/llm_sparsity/attention_sparsity/hf_sa.py +++ 
b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -29,7 +29,10 @@ import modelopt.torch.sparsity.attention_sparsity as mtsa from modelopt.torch.export import export_hf_checkpoint from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig -from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT +from modelopt.torch.sparsity.attention_sparsity.config import ( + SKIP_SOFTMAX_CALIB, + SKIP_SOFTMAX_DEFAULT, +) from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule from modelopt.torch.utils.memory_monitor import launch_memory_monitor @@ -38,9 +41,10 @@ # Enable HuggingFace checkpointing support mto.enable_huggingface_checkpointing() -# You can define custom configurations or use the default +# Sparse attention configuration choices SPARSE_ATTN_CFG_CHOICES = { "skip_softmax": SKIP_SOFTMAX_DEFAULT, + "skip_softmax_calib": SKIP_SOFTMAX_CALIB, } diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py new file mode 100644 index 000000000..3b616e8e3 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration framework for sparse attention methods.""" + +from .calibrate import calibrate_sparse_attention +from .calibrator import DynamicThresholdCalibrator +from .dataset import RulerDatasetBuilder + +__all__ = [ + "DynamicThresholdCalibrator", + "RulerDatasetBuilder", + "calibrate_sparse_attention", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py new file mode 100644 index 000000000..b90adbc0f --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -0,0 +1,176 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Calibration functions for sparse attention.""" + +import warnings +from collections.abc import Callable +from typing import Any + +import torch +import torch.nn as nn +from transformers import AutoTokenizer + +from ..config import CalibrationConfig +from ..sparse_attention import SparseAttentionModule +from .calibrator import DynamicThresholdCalibrator +from .dataset import RulerDatasetBuilder + + +def _extract_tokenizer_from_model(model: nn.Module) -> str: + """Extract tokenizer name/path from model config. + + Args: + model: Model to extract tokenizer from + + Returns: + Tokenizer name or path + + Raises: + ValueError: If tokenizer path cannot be determined from model + """ + # Extract tokenizer path from model config + tokenizer_path = getattr(getattr(model, "config", None), "_name_or_path", None) + + if not tokenizer_path: + raise ValueError("Could not load tokenizer from model.") + + return tokenizer_path + + +def _extract_calibration_config(config: dict[str, Any]) -> CalibrationConfig | None: + """Extract and validate calibration config from sparse_cfg patterns. + + Args: + config: Sparse attention configuration dict + + Returns: + Validated CalibrationConfig or None if not found + """ + # Extract sparse_cfg and search for calibration + sparse_cfg = config.get("sparse_cfg", {}) + + calib_dict = next( + ( + cfg["calibration"] + for cfg in sparse_cfg.values() + if isinstance(cfg, dict) and "calibration" in cfg + ), + None, + ) + + # Create and calidate the calibration config + return CalibrationConfig(**calib_dict) if calib_dict else None + + +def create_calibration_forward_loop( + calibration_data: list[dict[str, Any]], + tokenizer_name_or_path: str, + batch_size: int = 1, +) -> Callable: + """Create forward loop for calibration. + + Args: + calibration_data: List of samples with 'input' and 'length' fields + tokenizer_name_or_path: HuggingFace tokenizer path + batch_size: Batch size (currently unused, always 1) + + Returns: + Forward loop function that takes model as argument + """ + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + if not tokenizer.pad_token: + tokenizer.pad_token = tokenizer.eos_token + + def forward_loop(model: nn.Module) -> None: + device = next(model.parameters()).device + + for sample in calibration_data: + inputs = tokenizer( + sample["input"], return_tensors="pt", truncation=True, max_length=sample["length"] + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + + with torch.no_grad(): + model(**inputs) + + return forward_loop + + +def calibrate_sparse_attention( + model: nn.Module, + config: dict[str, Any], + forward_loop: Callable | None = None, +) -> dict[str, Any]: + """Calibrate sparse attention parameters for optimal sparsity. + + Args: + model: Model with sparse attention modules + config: Sparse attention configuration dict + forward_loop: Callable that forwards calibration data through model. + If None, auto-generates RULER dataset. 
+ + Returns: + Dictionary with calibration results + """ + # Extract and validate calibration config + calib_config = _extract_calibration_config(config) + if not calib_config: + return {} + + # Generate forward_loop if not provided + if not forward_loop: + tokenizer = _extract_tokenizer_from_model(model) + builder = RulerDatasetBuilder( + samples=calib_config.samples, + max_seqlen=calib_config.max_seqlen, + tokenizer_name_or_path=tokenizer, + num_length_bins=calib_config.num_length_bins, + max_length_filter=int(calib_config.max_seqlen * 1.2), + ) + calibration_data = builder.build_calibration_dataset() + print(f"Generated {len(calibration_data)} calibration samples") + forward_loop = create_calibration_forward_loop(calibration_data, tokenizer) + + # Get sparse attention modules + sparse_modules = [ + (name, m) for name, m in model.named_modules() if isinstance(m, SparseAttentionModule) + ] + + if not sparse_modules: + print("No sparse attention modules found for calibration") + return {} + + print(f"Calibrating {len(sparse_modules)} sparse attention modules together...") + + # Run calibration + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=calib_config.target_sparse_ratio, + threshold_trials=calib_config.threshold_trials, + ) + calibration_result = calibrator.calibrate(model, forward_loop) + + if "scale_factor" not in calibration_result: + warnings.warn("Calibration did not produce valid results") + return {} + + # Apply calibrated scale factor to all modules + scale_factor = calibration_result["scale_factor"] + print(f"\nApplying calibrated scale factor={scale_factor:.6f} to {len(sparse_modules)} modules") + + for module_name, module in sparse_modules: + module._sparse_method_instance.threshold_scale_factor = scale_factor + + return {"calibration_results": {name: calibration_result for name, _ in sparse_modules}} diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py new file mode 100644 index 000000000..f933e50ce --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -0,0 +1,308 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration framework for sparse attention methods.""" + +import warnings +from collections.abc import Callable +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from tqdm import tqdm + +from ..sparse_attention import SparseAttentionModule +from ..stats_manager import SparseAttentionStatsManager + + +class DynamicThresholdCalibrator: + """Dynamic threshold calibrator using length-based linear relationship. + + Implements calibration algorithm: + 1. Find hyperparameter 'a' where threshold λ = a / context_length + 2. Use dataset with different lengths and test multiple thresholds + 3. 
For each sample, find optimal threshold closest to target sparsity + 4. Use linear regression to fit: threshold = a * (1/length) + """ + + def __init__( + self, + target_sparse_ratio: float = 0.5, + threshold_trials: list[float] | None = None, + ): + """Initialize dynamic threshold calibrator. + + Args: + target_sparse_ratio: Target sparsity ratio (0.0 to 1.0) + threshold_trials: List of thresholds to try during calibration + + Note: + Calibration only supports prefill phase (seq_len > 1). + Decode phase uses the same calibrated threshold. + """ + self.target_sparse_ratio = target_sparse_ratio + + # Default threshold trials if not provided + self.threshold_trials = threshold_trials or [ + 1e-6, + 5e-6, + 1e-5, + 5e-5, + 1e-4, + 5e-4, + 1e-3, + 5e-3, + 1e-2, + 5e-2, + 1e-1, + 5e-1, + ] + + # Statistics tracking + self.sparsity_results = [] + + def calibrate(self, model: nn.Module, forward_loop: Callable) -> dict[str, Any]: + """Find optimal 'a' parameter for length-based threshold. + + Algorithm: + 1. Test all threshold trials by running forward_loop multiple times + 2. For each sample, find optimal threshold closest to target sparsity + 3. Use regression to find 'a' in: threshold = a / length + + Args: + model: The model with sparse attention modules + forward_loop: Callable that takes model and forwards calibration data + """ + # Extract attention modules + attention_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] + + if not attention_modules: + warnings.warn("No sparse attention modules found for calibration") + return {} + + print("Starting dynamic threshold calibration") + print(f"Target sparsity: {self.target_sparse_ratio}") + print(f"Threshold trials: {len(self.threshold_trials)}") + + # Stage 1: Collect sparsity for all sample-threshold pairs + print("\nStage 1: Collecting sparsity data...") + self.sparsity_results = [] + + # For each threshold, run forward_loop and collect per-sample statistics + for threshold_idx, threshold in enumerate( + tqdm(self.threshold_trials, desc="Testing thresholds") + ): + # Set threshold and enable calibration mode + self._set_threshold(attention_modules, threshold) + self._enable_calibration_mode(attention_modules) + + # Run forward loop and collect stats + with torch.no_grad(): + forward_loop(model) + per_sample_stats = self._extract_calibration_stats(attention_modules) + self._disable_calibration_mode(attention_modules) + + # Store results + for sample_idx, sample_stat in enumerate(per_sample_stats): + if threshold_idx == 0: + # Initialize sample entry on first threshold + sample_length = sample_stat.get("sample_length", 0) + if sample_length > 0: + self.sparsity_results.append( + { + "sample_index": sample_idx, + "length": sample_length, + "threshold_sparsities": {}, + } + ) + + # Add sparsity for this threshold + if sample_idx < len(self.sparsity_results): + sparsity = sample_stat.get("sparsity", 0.0) + self.sparsity_results[sample_idx]["threshold_sparsities"][threshold] = sparsity + + if not self.sparsity_results: + warnings.warn("No valid sparsity measurements collected during calibration") + return {} + + print(f"Collected statistics for {len(self.sparsity_results)} samples") + + # Stage 2: Find optimal threshold for each sample and compute 'a' + print( + f"\nStage 2: Finding 'a' parameter for target sparsity {self.target_sparse_ratio:.2f}" + ) + + # Find optimal threshold for each sample + optimal_pairs = [] + for sample_result in self.sparsity_results: + # Find threshold closest to target sparsity + 
best_threshold, achieved_sparsity = min( + sample_result["threshold_sparsities"].items(), + key=lambda item: abs(item[1] - self.target_sparse_ratio), + ) + + optimal_pairs.append( + { + "length": sample_result["length"], + "optimal_threshold": best_threshold, + "achieved_sparsity": achieved_sparsity, + "target_sparsity": self.target_sparse_ratio, + } + ) + + if not optimal_pairs: + warnings.warn( + f"No optimal threshold pairs found for target sparsity {self.target_sparse_ratio}. " + f"Collected {len(self.sparsity_results)} samples but none achieved target sparsity." + ) + return {} + + # Linear regression: threshold = a * (1/length) + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + # X = 1/length, Y = threshold + x = 1.0 / lengths + y = thresholds + + # Least squares: scale_factor = sum(x*y) / sum(x^2) + scale_factor = np.sum(x * y) / np.sum(x**2) + + # Calculate statistics + scale_factors_per_sample = y * lengths + scale_factor_std = np.std(scale_factors_per_sample) + + # Calculate R-squared for quality metric + y_pred = scale_factor * x + ss_res = np.sum((y - y_pred) ** 2) + ss_tot = np.sum((y - np.mean(y)) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 + + # Calculate average achieved sparsity + avg_achieved_sparsity = np.mean([p["achieved_sparsity"] for p in optimal_pairs]) + + print("\nCalibration Results:") + print(f" Threshold scale factor: {scale_factor:.6f} (std: {scale_factor_std:.6f})") + print(f" R-squared: {r_squared:.4f}") + print( + f" Average achieved sparsity: {avg_achieved_sparsity:.2%} (target: {self.target_sparse_ratio:.2%})" + ) + print(f"\nExample thresholds with λ = {scale_factor:.6f} / length:") + for length in [1024, 2048, 4096, 8192, 16384]: + print(f" Length {length:5d}: threshold = {scale_factor / length:.2e}") + + # Apply the calibrated scale factor to modules + self._apply_length_based_calibration(attention_modules, scale_factor) + + return { + "scale_factor": scale_factor, + "scale_factor_std": scale_factor_std, + "r_squared": r_squared, + "num_samples": len(optimal_pairs), + "target_sparsity": self.target_sparse_ratio, + "avg_achieved_sparsity": avg_achieved_sparsity, + "optimal_pairs": optimal_pairs, + "calibration_type": "length_based_dynamic", + } + + def _apply_length_based_calibration(self, modules: list[nn.Module], scale_factor: float): + """Apply calibrated threshold scale factor to modules. 
+ + Args: + modules: List of attention modules + scale_factor: Calibrated scale factor for λ = scale_factor / length + """ + for module in modules: + module._sparse_method_instance.threshold_scale_factor = scale_factor + + def _enable_calibration_mode(self, modules: list[nn.Module]): + """Enable calibration mode on sparse attention modules.""" + for idx, module in enumerate(modules): + # Create stats manager if needed + if not module._stats_manager: + module._stats_manager = SparseAttentionStatsManager( + module_name=f"sparse_attn_{idx}", enabled=True + ) + else: + # Re-enable if disabled + module._stats_manager.enabled = True + + # Enable calibration mode with fresh stats + module._stats_manager.set_calibration_mode(enabled=True, reset_history=True) + module._sparse_method_instance.set_calibration_mode(True) + + def _disable_calibration_mode(self, modules: list[nn.Module]): + """Disable calibration mode (but keep stats enabled if collect_stats=True).""" + for module in modules: + if module._stats_manager: + module._stats_manager.set_calibration_mode(enabled=False) + + module._sparse_method_instance.set_calibration_mode(False) + + def _extract_calibration_stats(self, modules: list[nn.Module]) -> list[dict]: + """Extract per-sample calibration statistics from modules. + + Args: + modules: List of attention modules + + Returns: + List of per-sample statistics across all modules + """ + # Collect from all stats managers + all_per_sample_stats = [] + + for module in modules: + # Skip modules without stats manager + if not hasattr(module, "_stats_manager") or module._stats_manager is None: + continue + + manager_stats = module._stats_manager.get_calibration_stats() + if manager_stats: + all_per_sample_stats.append(manager_stats) + + if not all_per_sample_stats: + return [] + + # Aggregate across modules by sample index + num_samples = len(all_per_sample_stats[0]) + aggregated_stats = [] + + for sample_idx in range(num_samples): + sparsities = [] + sample_length = 0 + + for module_stats in all_per_sample_stats: + if sample_idx < len(module_stats): + sample_stat = module_stats[sample_idx] + sparsities.append(sample_stat.get("sparsity", 0.0)) + if not sample_length and "sample_length" in sample_stat: + sample_length = sample_stat["sample_length"] + + avg_sparsity = np.mean(sparsities) if sparsities else 0.0 + + aggregated_stats.append( + { + "sparsity": avg_sparsity, + "sample_length": sample_length, + } + ) + + return aggregated_stats + + def _set_threshold(self, modules: list[nn.Module], threshold: float): + """Set threshold on sparse attention modules.""" + for module in modules: + module._sparse_method_instance.threshold = threshold diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py new file mode 100644 index 000000000..727fb189a --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py @@ -0,0 +1,602 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RULER dataset builder for sparse attention calibration.""" + +import random +import string +import uuid +from dataclasses import dataclass +from typing import Any + +from tqdm import tqdm +from transformers import AutoTokenizer + + +def _generate_target_lengths( + max_seqlen: int, num_length_bins: int = 4, min_seqlen: int = 1024 +) -> list[int]: + """Generate target lengths as descending powers of 2. + + Args: + max_seqlen: Maximum sequence length + num_length_bins: Maximum number of length bins to generate + min_seqlen: Minimum sequence length threshold + + Returns: + List of target lengths in descending order + + Examples: + >>> _generate_target_lengths(32768, 4) + [32768, 16384, 8192, 4096] + >>> _generate_target_lengths(2048, 4) + [2048, 1024] + """ + target_lengths = [] + current = max_seqlen + + for _ in range(num_length_bins): + if current < min_seqlen: + break + target_lengths.append(current) + current = current // 2 + + return target_lengths + + +@dataclass +class RulerTask: + """Configuration for a RULER task.""" + + name: str + task_type: str # niah, variable_tracking, freq_words_extraction, qa + tokens_to_generate: int + template: str + answer_prefix: str + args: dict[str, Any] + + +# Task configurations based on RULER benchmark +RULER_TASKS = { + "niah_multikey_2": RulerTask( + name="niah_multikey_2", + task_type="niah", + tokens_to_generate=128, + template=( + "Some special magic {type_needle_v} are hidden within the following text. " + "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" + "{context}\n" + "What are all the special magic {type_needle_v} for {query} mentioned in the provided text?" + ), + answer_prefix=( + " The special magic {type_needle_v} for {query} mentioned in the provided text are" + ), + args={ + "type_haystack": "needle", + "type_needle_k": "words", + "type_needle_v": "numbers", + "num_needle_k": 1, + "num_needle_v": 1, + "num_needle_q": 1, + }, + ), + "niah_multikey_3": RulerTask( + name="niah_multikey_3", + task_type="niah", + tokens_to_generate=128, + template=( + "Some special magic {type_needle_v} are hidden within the following text. " + "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" + "{context}\n" + "What are all the special magic {type_needle_v} for {query} mentioned in the provided text?" + ), + answer_prefix=( + " The special magic {type_needle_v} for {query} mentioned in the provided text are" + ), + args={ + "type_haystack": "needle", + "type_needle_k": "uuids", + "type_needle_v": "uuids", + "num_needle_k": 1, + "num_needle_v": 1, + "num_needle_q": 1, + }, + ), + "vt": RulerTask( + name="vt", + task_type="variable_tracking", + tokens_to_generate=30, + template=( + "Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n" + "{context}\n" + "Question: Find all variables that are assigned the value {query} in the text above." 
+ ), + answer_prefix=( + " Answer: According to the chain(s) of variable assignment in the text above, " + "{num_v} variables are assgined the value {query}, they are: " + ), + args={"num_chains": 1, "num_hops": 4}, + ), + "fwe": RulerTask( + name="fwe", + task_type="freq_words_extraction", + tokens_to_generate=50, + template=( + "Read the following coded text and track the frequency of each coded word. " + "Find the three most frequently appeared coded words. {context}\n" + "Question: Do not provide any explanation. Please ignore the dots '....'. " + "What are the three most frequently appeared words in the above coded text?" + ), + answer_prefix=( + " Answer: According to the coded text above, " + "the three most frequently appeared words are:" + ), + args={}, + ), + "qa_1": RulerTask( + name="qa_1", + task_type="qa", + tokens_to_generate=32, + template=( + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "The following are given documents.\n\n{context}\n\n" + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {query}" + ), + answer_prefix=" Answer:", + args={"dataset": "squad", "demo_selection": "topk_lt_ctx"}, + ), + "qa_2": RulerTask( + name="qa_2", + task_type="qa", + tokens_to_generate=32, + template=( + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "The following are given documents.\n\n{context}\n\n" + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {query}" + ), + answer_prefix=" Answer:", + args={"dataset": "hotpotqa", "demo_selection": "topk_lt_ctx"}, + ), +} + + +class RulerDatasetBuilder: + """Builder for RULER calibration datasets.""" + + def __init__( + self, + samples: int, + max_seqlen: int, + tokenizer_name_or_path: str, + seed: int = 42, + num_length_bins: int = 4, + max_length_filter: int = 65536, + ): + """Initialize RULER dataset builder. + + Args: + samples: Total number of samples to generate (distributed evenly across length bins) + max_seqlen: Maximum sequence length (length bins auto-generated as powers of 2) + tokenizer_name_or_path: HuggingFace tokenizer path + seed: Random seed for reproducibility + num_length_bins: Number of length bins to generate (default: 4) + max_length_filter: Maximum sequence length to keep (default: 65536) + + Note: + Length bins are auto-generated as descending powers of 2: + [max_seqlen, max_seqlen/2, max_seqlen/4, ...] + Generation stops when num_length_bins is reached or length < 1024. + Subtasks are set to all the difficult tasks defined in RULER_TASKS. 
+ """ + # Validate inputs + if samples <= 0: + raise ValueError(f"samples must be positive, got {samples}") + if max_seqlen < 1024: + raise ValueError(f"max_seqlen must be >= 1024, got {max_seqlen}") + + # Store parameters + self.total_samples = samples + self.max_seqlen = max_seqlen + self.num_length_bins = num_length_bins + self.subtasks = list(RULER_TASKS.keys()) + self.tokenizer_name_or_path = tokenizer_name_or_path + self.seed = seed + self.max_length_filter = max_length_filter + + # Generate target lengths and validate + self.target_lengths = _generate_target_lengths(max_seqlen, num_length_bins, min_seqlen=1024) + if not self.target_lengths: + raise ValueError(f"No valid target lengths generated from max_seqlen={max_seqlen}") + + # Distribute samples evenly across lengths + self.samples_per_length = [samples // len(self.target_lengths)] * len(self.target_lengths) + + # Initialize tokenizer and seed + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + random.seed(seed) + + def build_calibration_dataset(self) -> list[dict[str, Any]]: + """Build the complete calibration dataset. + + Returns: + List of calibration samples with 'input' and 'length' fields + """ + all_samples = [] + + # Generate calibration samples + for num_samples, target_length in tqdm( + zip(self.samples_per_length, self.target_lengths), + desc="Generating RULER calibration samples", + total=len(self.target_lengths), + ): + samples_per_task = num_samples // len(self.subtasks) + if samples_per_task <= 0: + continue + + # Generate equal samples for each task + for task_name in self.subtasks: + for sample_idx in range(samples_per_task): + sample = self._generate_sample(task_name, target_length, sample_idx) + if sample and sample["length"] <= self.max_length_filter: + all_samples.append(sample) + + random.shuffle(all_samples) + return all_samples + + def _generate_sample( + self, task_name: str, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a single RULER sample. 
+ + Args: + task_name: Name of the RULER task + target_length: Target sequence length in tokens + sample_idx: Index of the sample (for uniqueness) + + Returns: + Dict with 'input', 'length', and metadata fields + """ + task = RULER_TASKS[task_name] + + if task.task_type == "niah": + return self._generate_niah_sample(task, target_length, sample_idx) + elif task.task_type == "variable_tracking": + return self._generate_vt_sample(task, target_length, sample_idx) + elif task.task_type == "freq_words_extraction": + return self._generate_fwe_sample(task, target_length, sample_idx) + elif task.task_type == "qa": + return self._generate_qa_sample(task, target_length, sample_idx) + else: + raise ValueError(f"Unknown task type: {task.task_type}") + + def _generate_niah_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a needle-in-haystack sample.""" + args = task.args + + # Generate needles based on type + if args["type_needle_k"] == "words": + needle_keys = [self._generate_random_word() for _ in range(args["num_needle_k"])] + else: # uuids + needle_keys = [str(uuid.uuid4()) for _ in range(args["num_needle_k"])] + + if args["type_needle_v"] == "numbers": + needle_values = [ + str(random.randint(100000, 999999)) for _ in range(args["num_needle_v"]) + ] + else: # uuids + needle_values = [str(uuid.uuid4()) for _ in range(args["num_needle_v"])] + + # Select query needles + query_keys = random.sample(needle_keys, min(args["num_needle_q"], len(needle_keys))) + + # Generate context with needles + context = self._generate_niah_context( + target_length, needle_keys, needle_values, args["type_haystack"] + ) + + # Format template + template = task.template.format( + type_needle_v=args["type_needle_v"], + context=context, + query=query_keys[0] if query_keys else needle_keys[0], + ) + + # Add answer prefix + full_input = template + task.answer_prefix.format( + type_needle_v=args["type_needle_v"], + query=query_keys[0] if query_keys else needle_keys[0], + ) + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _generate_vt_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a variable tracking sample.""" + args = task.args + num_chains = args["num_chains"] + num_hops = args["num_hops"] + + # Generate variable chains + variables = [] + chains = [] + for _ in range(num_chains): + chain = [self._generate_random_variable() for _ in range(num_hops + 1)] + variables.extend(chain) + chains.append(chain) + + # Generate assignments + assignments = [ + f"VAR {chain[i]} = {chain[i + 1]}" for chain in chains for i in range(len(chain) - 1) + ] + + # Create context with padding + context = self._pad_context_with_text( + "\n".join(assignments), target_length, "variable tracking context" + ) + + # Select a query value + query_value = random.choice([chain[-1] for chain in chains]) + + # Format template + template = task.template.format(context=context, query=query_value) + + # Count variables with the query value + num_v = sum(1 for chain in chains if chain[-1] == query_value) + + # Add answer prefix + full_input = template + task.answer_prefix.format(num_v=num_v, query=query_value) + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": 
full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _generate_fwe_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a frequency word extraction sample.""" + # Generate coded words with frequencies + num_unique_words = 50 + coded_words = [self._generate_coded_word() for _ in range(num_unique_words)] + + # Assign frequencies (make top 3 clearly more frequent) + frequencies = {} + for i, word in enumerate(coded_words): + if i < 3: + frequencies[word] = random.randint(20, 30) # High frequency + else: + frequencies[word] = random.randint(1, 10) # Low frequency + + # Generate the coded text + word_list = [] + for word, freq in frequencies.items(): + word_list.extend([word] * freq) + random.shuffle(word_list) + + # Add dots for separation + coded_text = " .... ".join(word_list) + + # Pad to target length + context = self._pad_context_with_text(coded_text, target_length, "coded text padding") + + # Format template + template = task.template.format(context=context) + full_input = template + task.answer_prefix + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _generate_qa_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a QA sample.""" + # Generate synthetic documents + num_docs = 5 + documents = [] + + # Create a simple QA pair + answer = self._generate_random_phrase() + question = f"What is the special code mentioned in document {random.randint(1, num_docs)}?" + + for i in range(num_docs): + doc_text = self._generate_document_text(200) # Base document + if i == 2: # Insert answer in one document + doc_text += f" The special code is {answer}. " + documents.append(f"Document {i + 1}:\n{doc_text}\n") + + # Combine documents + context_base = "\n".join(documents) + + # Pad to target length + context = self._pad_context_with_text( + context_base, target_length, "additional document text" + ) + + # Format template + template = task.template.format(context=context, query=question) + full_input = template + task.answer_prefix + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _generate_niah_context( + self, + target_length: int, + needle_keys: list[str], + needle_values: list[str], + haystack_type: str, + ) -> str: + """Generate context for needle-in-haystack tasks.""" + # Create needle sentences + needles = [] + for key, value in zip(needle_keys, needle_values): + needles.append(f"The magic number for {key} is {value}.") + + # Generate haystack based on type + if haystack_type == "repeat": + base_text = "The grass is green. The sky is blue. The sun is yellow. 
" * 100 + elif haystack_type == "essay": + base_text = self._generate_essay_text(500) + else: # needle type - more needles as distractors + base_text = self._generate_needle_haystack(500) + + # Insert needles at random positions + words = base_text.split() + for needle in needles: + insert_pos = random.randint(0, len(words)) + words.insert(insert_pos, needle) + + context = " ".join(words) + + # Adjust to target length + tokens = self.tokenizer.encode(context, add_special_tokens=False) + while len(tokens) < target_length * 0.7: # Leave room for template + context += " " + self._generate_essay_text(100) + tokens = self.tokenizer.encode(context, add_special_tokens=False) + + if len(tokens) > target_length * 0.9: + # Truncate if too long + context = self.tokenizer.decode(tokens[: int(target_length * 0.8)]) + + return context + + def _pad_context_with_text( + self, base_context: str, target_length: int, padding_type: str + ) -> str: + """Pad context to approach target length.""" + tokens = self.tokenizer.encode(base_context, add_special_tokens=False) + + while len(tokens) < target_length * 0.7: # Leave room for template + if padding_type == "variable tracking context": + padding = ( + f" VAR {self._generate_random_variable()} = {self._generate_random_variable()}." + ) + elif padding_type == "coded text padding": + padding = f" .... {self._generate_coded_word()} .... " + else: + padding = " " + self._generate_essay_text(50) + + base_context += padding + tokens = self.tokenizer.encode(base_context, add_special_tokens=False) + + if len(tokens) > target_length * 0.9: + # Truncate if too long + base_context = self.tokenizer.decode(tokens[: int(target_length * 0.8)]) + + return base_context + + def _generate_random_word(self) -> str: + """Generate a random word.""" + return "".join(random.choices(string.ascii_lowercase, k=random.randint(5, 10))) + + def _generate_random_variable(self) -> str: + """Generate a random variable name.""" + return "".join(random.choices(string.ascii_uppercase, k=1)) + "".join( + random.choices(string.digits, k=3) + ) + + def _generate_coded_word(self) -> str: + """Generate a coded word.""" + return "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) + + def _generate_random_phrase(self) -> str: + """Generate a random phrase.""" + words = [self._generate_random_word() for _ in range(random.randint(2, 4))] + return " ".join(words) + + def _generate_essay_text(self, num_words: int) -> str: + """Generate essay-like text.""" + topics = [ + "technology", + "science", + "nature", + "history", + "culture", + "education", + "health", + "economics", + "politics", + "philosophy", + "art", + "literature", + ] + + sentences = [] + words_generated = 0 + + while words_generated < num_words: + topic = random.choice(topics) + word1 = self._generate_random_word() + word2 = self._generate_random_word() + word3 = self._generate_random_word() + sentence = f"The {topic} of {word1} is {word2} and {word3}. 
" + sentences.append(sentence) + words_generated += len(sentence.split()) + + return " ".join(sentences) + + def _generate_document_text(self, num_words: int) -> str: + """Generate document-like text.""" + return self._generate_essay_text(num_words) + + def _generate_needle_haystack(self, num_items: int) -> str: + """Generate haystack with many needle-like distractors.""" + items = [] + for _ in range(num_items): + key = self._generate_random_word() + value = str(random.randint(100000, 999999)) + items.append(f"The value for {key} is {value}.") + + return " ".join(items) diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index bc533acb2..e19b74a2a 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -156,6 +156,97 @@ def validate_threshold(cls, v): return v +class CalibrationConfig(ModeloptBaseConfig): + """Configuration for automatic threshold calibration using RULER dataset. + + Calibration learns a dynamic threshold λ = scale_factor / sequence_length that + achieves target sparsity. Only supports prefill phase (seq_len > 1). + """ + + target_sparse_ratio: float = ModeloptField( + default=0.5, + title="Target sparsity ratio", + description="Target ratio of sparse attention blocks (0.0 to 1.0).", + ) + + samples: int = ModeloptField( + default=24, + title="Calibration samples", + description="Total number of RULER samples for calibration (distributed across length bins).", + ) + + max_seqlen: int = ModeloptField( + default=32768, + title="Maximum sequence length", + description="Maximum sequence length for calibration (length bins auto-generated as powers of 2).", + ) + + num_length_bins: int = ModeloptField( + default=4, + title="Number of length bins", + description="Number of length bins to generate (hidden parameter, default: 4).", + ) + + threshold_trials: list[float] | None = ModeloptField( + default=None, + title="Threshold trials", + description=( + "List of threshold values to test during calibration. 
" + "If None, uses default: [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1]" + ), + ) + + @field_validator("threshold_trials") + @classmethod + def validate_threshold_trials(cls, v): + """Validate threshold_trials are in valid range.""" + if v is not None: + if not isinstance(v, list): + raise ValueError(f"threshold_trials must be a list, got {type(v)}") + if len(v) == 0: + raise ValueError("threshold_trials must not be empty") + for threshold in v: + if not isinstance(threshold, (int, float)): + raise ValueError(f"All threshold_trials must be numbers, got {type(threshold)}") + if threshold <= 0 or threshold >= 1: + raise ValueError( + f"All threshold_trials must be in range (0, 1), got {threshold}" + ) + return v + + @field_validator("target_sparse_ratio") + @classmethod + def validate_target_sparse_ratio(cls, v): + """Validate target sparsity ratio is between 0 and 1.""" + if not 0.0 <= v <= 1.0: + raise ValueError(f"target_sparse_ratio must be between 0.0 and 1.0, got {v}") + return v + + @field_validator("samples") + @classmethod + def validate_samples(cls, v): + """Validate samples is positive.""" + if v <= 0: + raise ValueError(f"samples must be positive, got {v}") + return v + + @field_validator("max_seqlen") + @classmethod + def validate_max_seqlen(cls, v): + """Validate max_seqlen is at least 1024.""" + if v < 1024: + raise ValueError(f"max_seqlen must be >= 1024, got {v}") + return v + + @field_validator("num_length_bins") + @classmethod + def validate_num_length_bins(cls, v): + """Validate num_length_bins is positive.""" + if v <= 0: + raise ValueError(f"num_length_bins must be positive, got {v}") + return v + + # Pre-defined Sparse Attention Configuration # Default configuration with block-wise sparsity optimized for Flash Attention SKIP_SOFTMAX_DEFAULT = { @@ -232,8 +323,32 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): ) +# Configuration with RULER calibration +# Note: threshold field is omitted - calibration determines dynamic threshold λ = a / length +# The calibrated threshold adapts to sequence length for optimal sparsity +SKIP_SOFTMAX_CALIB = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "br": 128, + "bc": 128, + "backend": "pytorch", # Only pytorch backend supported + "enable": True, + "calibration": { + "target_sparse_ratio": 0.3, + "samples": 12, + "max_seqlen": 1024, + }, + }, + "default": {"enable": False}, + }, +} + + __all__ = [ + "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_DEFAULT", + "CalibrationConfig", "FlashSkipSoftmaxConfig", "FlashSkipSoftmaxConfig", "SparseAttentionAttributeConfig", diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index 8801bafb0..b4a971ece 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -55,6 +55,9 @@ def __init__(self, method_config: dict | None = None): self.enable_correction_factor = config.get("enable_correction_factor", True) self.phase = config.get("phase", None) + # Calibration mode: when True, prevent threshold updates to preserve calibrator's test threshold + self._calibration_mode = False + # Initialize threshold if isinstance(self.threshold_config, dict): self.threshold = self.threshold_config.get( @@ -63,6 +66,10 @@ def __init__(self, method_config: dict | None = None): else: self.threshold = self.threshold_config + def set_calibration_mode(self, 
enabled: bool): + """Set calibration mode to prevent _update_threshold from modifying the threshold.""" + self._calibration_mode = enabled + def _update_threshold(self, phase: str): """Update threshold based on phase.""" if isinstance(self.threshold_config, dict): @@ -278,8 +285,9 @@ def apply_sparsity( # Infer phase from tensor shape phase = self._infer_phase(attention_scores) - # Update threshold for the detected phase - self._update_threshold(phase) + # Update threshold for the detected phase (skip during calibration) + if not self._calibration_mode: + self._update_threshold(phase) # Apply block-wise sparsity sparse_mask, stats = self.calc_correction_factor_and_p(attention_scores, phase) diff --git a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py index 88434e746..0e7d221f9 100644 --- a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py +++ b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py @@ -19,13 +19,15 @@ import torch -from modelopt.torch.opt.conversion import apply_mode +from modelopt.torch.opt.conversion import ModeloptStateManager, apply_mode from modelopt.torch.opt.searcher import ForwardLoop +from .calibration import calibrate_sparse_attention from .config import SparseAttentionConfig from .mode import SparseAttentionModeRegistry __all__ = [ + "calibrate", "sparsify", ] @@ -58,12 +60,36 @@ def sparsify( .. code-block::python config = { - "method": "flash_skip_softmax", "sparse_cfg": { + # Phase-aware thresholds with backend selection "*attention*": { + "method": "flash_skip_softmax", "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "backend": "pytorch", # Only pytorch backend supported + "enable": True, + }, + # Disable for specific layers + "*layer.0*": {"enable": False}, + # Default settings + "default": {"enable": False}, + }, + } + + For automatic threshold calibration using RULER dataset: + + .. code-block::python + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", "backend": "pytorch", "enable": True, + "calibration": { # Enables automatic threshold calibration + "target_sparse_ratio": 0.5, + "samples": 48, + "max_seqlen": 8192, + }, }, "default": {"enable": False}, }, @@ -126,4 +152,55 @@ def forward_loop(model) -> float: model, mode=[("sparse_attention", config)], registry=SparseAttentionModeRegistry ) + # Calibrate the sparsity ratio of the attention modules + return calibrate(model, config, forward_loop=forward_loop) + + +def calibrate( + model: torch.nn.Module, + config: dict[str, Any] | SparseAttentionConfig | None = None, + forward_loop: ForwardLoop | None = None, +) -> torch.nn.Module: + """Calibrates sparse attention thresholds based on target sparsity. + + This function performs calibration to find optimal thresholds that achieve + the target sparsity ratio specified in the config. + + Args: + model: Model with sparse attention modules + config: Sparse attention config (extracted from modelopt state if None) + forward_loop: Optional callable that forwards calibration data through the model. + If provided, uses this for calibration data. + If None, will auto-generate RULER dataset for calibration. + + Returns: + The calibrated model with optimized sparse attention thresholds. + If no calibration is configured, returns the model unchanged. 
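+
+    Example usage (a minimal sketch; assumes ``sparsify`` was already applied
+    with a ``calibration`` section in ``sparse_cfg``):
+
+    .. code-block::python
+
+        # Both arguments may be left as None: the config is then read from the
+        # modelopt state, and the RULER calibration data is auto-generated
+        # from the model's tokenizer.
+        model = calibrate(model, config=None, forward_loop=None)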
+ """ + # Get config from model if not provided + if config is None: + manager = ModeloptStateManager(model) + if manager.last_mode and manager.last_mode.name == "sparse_attention": + config = manager._last_config + else: + # No sparse attention applied, return model unchanged + return model + + # Extract sparse_cfg + if isinstance(config, dict): + sparse_cfg = config.get("sparse_cfg", {}) + else: + sparse_cfg = config.sparse_cfg + + # Check if calibration is configured in any sparse_cfg pattern + has_calibration = any( + isinstance(cfg, dict) and "calibration" in cfg for cfg in sparse_cfg.values() + ) + + if not has_calibration: + return model + + # Run calibration (handles stats collection internally) + calibrate_sparse_attention(model, config, forward_loop=forward_loop) + return model diff --git a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py index 16b08bf19..857e15d0f 100644 --- a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py +++ b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py @@ -23,6 +23,7 @@ from .config import SparseAttentionAttributeConfig from .methods import get_sparse_method +from .stats_manager import SparseAttentionStatsManager class SparseAttentionModule(DynamicModule): @@ -129,9 +130,10 @@ def get_stats(self) -> dict: Returns: Dictionary with sparsity statistics including 'average_sparsity' if available. - Returns empty dict (statistics collection will be added in calibration PR). + Returns empty dict if stats manager is not enabled. """ - # TODO: Statistics collection will be added in calibration PR + if self._stats_manager is not None and self._stats_manager.enabled: + return self._stats_manager.get_summary() return {} def _setup(self): @@ -140,6 +142,14 @@ def _setup(self): if not hasattr(self, "_method"): self.set_from_attribute_config(None) + # Create stats manager if stats collection is enabled + if self._method_config.get("collect_stats", False): + self._stats_manager = SparseAttentionStatsManager( + module_name="sparse_attention", enabled=True + ) + else: + self._stats_manager = None + def forward(self, *args, **kwargs): """Forward with selected sparse attention method. @@ -157,6 +167,10 @@ def forward(self, *args, **kwargs): with context: result = super().forward(*args, **kwargs) + # Collect stats if manager is available + if self._stats_manager is not None and hasattr(self._sparse_method_instance, "_last_stats"): + self._stats_manager.collect(self._sparse_method_instance._last_stats) + return result def _get_sparse_context(self): diff --git a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py new file mode 100644 index 000000000..9862e6de4 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Statistics manager for sparse attention modules.""" + + +class SparseAttentionStatsManager: + """Centralized statistics manager for sparse attention. + + This class is the single source of truth for all statistics collection + in sparse attention modules. It handles both runtime aggregation and + per-sample calibration statistics. 
+ + Design principles: + - Single responsibility: only stats management + - No computation: receives pre-computed stats from methods + - Optional: can be None if stats collection disabled + - Zero overhead when disabled + """ + + def __init__(self, module_name: str, enabled: bool = True): + """Initialize stats manager. + + Args: + module_name: Name of the module this manager is attached to + enabled: Whether stats collection is enabled + """ + self.module_name = module_name + self.enabled = enabled + self.calibration_mode = False + + # Aggregated stats (running totals across all forward passes) + self.aggregated_stats: dict = { + "total_calls": 0, + "total_blocks": 0, + "sparse_blocks": 0, + "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0}, + } + + # Per-sample stats (only populated during calibration) + self.per_sample_stats: list[dict] = [] + + def collect(self, stats: dict): + """Collect statistics from a single forward pass. + + Args: + stats: Dictionary containing statistics from method computation. + Expected keys: sparsity, phase, total_blocks, sparse_blocks, + sample_length (optional) + """ + if not self.enabled: + return + + # Update aggregated stats + self.aggregated_stats["total_calls"] += 1 + self.aggregated_stats["total_blocks"] += stats.get("total_blocks", 0) + self.aggregated_stats["sparse_blocks"] += stats.get("sparse_blocks", 0) + + phase = stats.get("phase", "unknown") + if phase in self.aggregated_stats["phase_counts"]: + self.aggregated_stats["phase_counts"][phase] += 1 + + # In calibration mode, store per-sample stats + if self.calibration_mode: + self.per_sample_stats.append( + { + "module": self.module_name, + "sparsity": stats.get("sparsity", 0.0), + "sample_length": stats.get("sample_length", 0), + "phase": phase, + } + ) + + def get_summary(self) -> dict: + """Get aggregated statistics summary. + + Returns: + Dictionary with module name, total calls, average sparsity, + and phase distribution. + """ + total_blocks = self.aggregated_stats["total_blocks"] + if total_blocks > 0: + avg_sparsity = self.aggregated_stats["sparse_blocks"] / total_blocks + else: + avg_sparsity = 0.0 + + return { + "module": self.module_name, + "total_calls": self.aggregated_stats["total_calls"], + "average_sparsity": avg_sparsity, + "phase_distribution": self.aggregated_stats["phase_counts"].copy(), + } + + def set_calibration_mode(self, enabled: bool, reset_history: bool = True): + """Enable or disable calibration mode. + + In calibration mode, per-sample statistics are stored for detailed + analysis. Otherwise, only aggregated stats are maintained. + + Args: + enabled: Whether to enable calibration mode + reset_history: Whether to clear per_sample_stats when enabling + """ + self.calibration_mode = enabled + if enabled and reset_history: + self.per_sample_stats = [] + + def reset(self): + """Reset all statistics to initial state.""" + self.aggregated_stats = { + "total_calls": 0, + "total_blocks": 0, + "sparse_blocks": 0, + "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0}, + } + self.per_sample_stats = [] + + def get_calibration_stats(self) -> list[dict]: + """Get per-sample calibration statistics. + + Returns: + List of per-sample statistics dictionaries. + Empty list if not in calibration mode. 
+ """ + return self.per_sample_stats diff --git a/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py index b70dfab35..6cbd5ff5b 100644 --- a/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py +++ b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py @@ -43,8 +43,11 @@ def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax", @pytest.mark.parametrize("method", ["skip_softmax"]) def test_attention_sparsity(tiny_llama_path, tmp_path, method): - """Test sparse attention with TinyLlama.""" + """Test sparse attention with TinyLlama (with and without calibration).""" run_attention_sparsity_command( model=tiny_llama_path, method=method, + seq_len=128, + num_samples=1, + max_new_tokens=10, ) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py new file mode 100644 index 000000000..035f2c24f --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py @@ -0,0 +1,388 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""GPU tests for sparse attention calibration.""" + +import pytest +import torch +from _test_utils.torch_sparsity.sparse_attention_common import SimpleTransformerEncoderLayer + +import modelopt.torch.opt as mto +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.calibration import RulerDatasetBuilder +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + +# Skip all tests if no GPU available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required") + + +class TestRulerDatasetBuilderGPU: + """Test RULER dataset generation with real tokenizers on GPU.""" + + def test_ruler_generation_with_real_tokenizer(self): + """Test RULER generation with GPT2 tokenizer.""" + builder = RulerDatasetBuilder( + samples=6, # Need at least 6 samples (1 per task) + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should generate 6 samples (1 per task) + assert len(dataset) == 6 + + # All samples should have valid structure + for sample in dataset: + assert "input" in sample + assert "length" in sample + assert sample["length"] > 0 + + def test_generated_length_accuracy(self): + """Test that generated token counts are accurate.""" + builder = RulerDatasetBuilder( + samples=3, + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Check that lengths are within reasonable range of target + for sample in dataset: + # RULER aims for 70-90% of target for context + assert 700 < sample["length"] < 1400 + + def test_multiple_subtasks(self): + """Test generation with multiple RULER subtasks.""" + builder = RulerDatasetBuilder( + samples=12, # Need at least 6 for 1 per task, use 12 for 2 per task + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Check task distribution (should have multiple tasks from RULER_TASKS) + tasks_found = {s["task"] for s in dataset} + assert len(tasks_found) >= 2 # At least 2 different tasks + + def test_large_context_lengths(self): + """Test with larger context lengths.""" + builder = RulerDatasetBuilder( + samples=24, # 4 lengths * 6 tasks = need 24 for 1 per (length, task) + max_seqlen=8192, # Generates: [8192, 4096, 2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + assert len(dataset) == 24 + + # Verify we have different lengths + lengths = [s["length"] for s in dataset] + # Should have variety of lengths across the bins + assert len(set(lengths)) > 1 # At least 2 different target lengths used + + +class TestCalibrationGPU: + """Test calibration with real models on GPU.""" + + @pytest.fixture + def simple_model(self): + """Create simple attention model for testing.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + return model + + def test_calibration_simple_model(self, simple_model): + """Test calibration with simple attention model.""" + model = simple_model + + config = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "br": 64, + "bc": 64, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + # Simple forward loop for calibration + pass + 
+ # Apply sparse attention with calibration + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Verify sparse attention modules exist + sparse_modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + assert len(sparse_modules) > 0 + + # Verify calibration was applied + for module in sparse_modules: + method = module._sparse_method_instance + # Check if calibrated threshold scale factor is set + if hasattr(method, "threshold_scale_factor") and method.threshold_scale_factor: + assert method.threshold_scale_factor > 0 + + def test_calibration_pytorch_backend(self, simple_model): + """Test calibration with pytorch backend.""" + model = simple_model + + config = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Check backend is set correctly + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + method = module._sparse_method_instance + assert hasattr(method, "backend") + assert method.backend == "pytorch" + + def test_simplified_calibration(self, simple_model): + """Test simplified calibration (prefill phase only).""" + model = simple_model + + config = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Should complete without errors + assert sparse_model is not None + + def test_calibration_persistence(self, simple_model): + """Test save and restore of calibrated model.""" + model = simple_model + + config = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate model + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Save modelopt state + modelopt_state = mto.modelopt_state(sparse_model) + + # Create new model and restore + model_restored = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + + restored = mto.restore_from_modelopt_state(model_restored, modelopt_state) + + # Check that sparse attention is restored + has_sparse = any(isinstance(m, SparseAttentionModule) for m in restored.modules()) + assert has_sparse + + +class TestCalibrationEndToEnd: + """Integration tests with inference.""" + + @pytest.fixture + def simple_model_setup(self): + """Setup simple model.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + return model + + def test_calibrated_model_inference(self, simple_model_setup): + """Test inference with calibrated model.""" + model = simple_model_setup + + config = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate model + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Test inference + test_input = 
SimpleTransformerEncoderLayer.get_input(d_model=256, seq_len=10).cuda() + + sparse_model.eval() + with torch.no_grad(): + output = sparse_model(test_input) + + # Check output is valid + assert output is not None + assert not torch.isnan(output).any() + + def test_calibrated_vs_fixed_threshold(self, simple_model_setup): + """Compare calibrated vs fixed threshold models.""" + # Config with calibration + config_calibrated = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + # Config with fixed threshold (no calibration) + config_fixed = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "enable": True, + } + }, + } + + def forward_loop(model): + pass + + # Test both can be created + model_calibrated = sparsify( + SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda(), + config_calibrated, + forward_loop=forward_loop, + ) + + model_fixed = sparsify( + SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda(), + config_fixed, + ) + + # Both should work for inference + test_input = SimpleTransformerEncoderLayer.get_input(d_model=256, seq_len=10).cuda() + + with torch.no_grad(): + output_calibrated = model_calibrated(test_input) + output_fixed = model_fixed(test_input) + + assert output_calibrated is not None + assert output_fixed is not None + + def test_memory_usage(self, simple_model_setup): + """Test that calibration doesn't cause memory issues.""" + model = simple_model_setup + + # Clear cache before test + torch.cuda.empty_cache() + initial_memory = torch.cuda.memory_allocated() + + config = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate + sparsify(model, config, forward_loop=forward_loop) + + # Check memory didn't explode + final_memory = torch.cuda.memory_allocated() + memory_increase = final_memory - initial_memory + + # Memory should be reasonable (not more than 2GB increase) + assert memory_increase < 2 * 1024**3 # 2GB + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py new file mode 100644 index 000000000..716cf54fe --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py @@ -0,0 +1,442 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for sparse attention calibration.""" + +import pytest + +pytest.importorskip("transformers") + +import numpy as np +import pytest +from _test_utils.torch_sparsity.sparse_attention_common import ( + SimpleAttentionModel, + SimpleTransformerEncoder, +) + +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.calibration import ( + DynamicThresholdCalibrator, + RulerDatasetBuilder, +) +from modelopt.torch.sparsity.attention_sparsity.calibration.dataset import _generate_target_lengths +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +class TestLengthGeneration: + """Test automatic target length generation.""" + + def test_generate_target_lengths_default(self): + """Test default 4 bins generation.""" + lengths = _generate_target_lengths(32768, num_length_bins=4) + assert lengths == [32768, 16384, 8192, 4096] + + def test_generate_target_lengths_stops_at_minimum(self): + """Test generation stops at minimum threshold.""" + lengths = _generate_target_lengths(2048, num_length_bins=4) + assert lengths == [2048, 1024] # Stops at 1024 + + def test_generate_target_lengths_fewer_bins(self): + """Test with fewer bins.""" + lengths = _generate_target_lengths(16384, num_length_bins=2) + assert lengths == [16384, 8192] + + def test_generate_target_lengths_more_bins(self): + """Test with more bins.""" + lengths = _generate_target_lengths(65536, num_length_bins=6) + assert lengths == [65536, 32768, 16384, 8192, 4096, 2048] + + def test_generate_target_lengths_exactly_minimum(self): + """Test when max_seqlen equals minimum.""" + lengths = _generate_target_lengths(1024, num_length_bins=4) + assert lengths == [1024] + + +class TestRulerDatasetBuilder: + """Test RULER dataset generation without requiring real tokenizers.""" + + def test_builder_initialization(self): + """Test that builder initializes correctly.""" + builder = RulerDatasetBuilder( + samples=12, + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + assert builder.total_samples == 12 + assert builder.max_seqlen == 2048 + assert builder.target_lengths == [2048, 1024] + assert builder.samples_per_length == [6, 6] # Evenly distributed + assert len(builder.subtasks) == 6 # All RULER_TASKS + assert builder.seed == 42 + + def test_builder_initialization_invalid_config(self): + """Test that builder raises error for invalid inputs.""" + # Test invalid samples + with pytest.raises(ValueError, match="samples must be positive"): + RulerDatasetBuilder( + samples=0, + max_seqlen=2048, + tokenizer_name_or_path="gpt2", + ) + + # Test max_seqlen below minimum + with pytest.raises(ValueError, match="max_seqlen must be >= 1024"): + RulerDatasetBuilder( + samples=4, + max_seqlen=512, # Below minimum + tokenizer_name_or_path="gpt2", + ) + + def test_dataset_generation_minimal(self): + """Test generating small dataset.""" + builder = RulerDatasetBuilder( + samples=12, # 6 tasks x 2 lengths = need 12 for 1 per task per length + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should generate 12 samples (6 tasks x 1 sample per task x 2 lengths) + assert len(dataset) == 12 + assert all(isinstance(sample, dict) for sample in dataset) + + def test_dataset_structure(self): + """Test that dataset has correct structure.""" + builder = RulerDatasetBuilder( + samples=6, # Need at least 6 (1 per task) + 
max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + sample = dataset[0] + + # Check required fields + assert "input" in sample + assert "length" in sample + assert "task" in sample + assert "target_length" in sample + + # Check field types + assert isinstance(sample["input"], str) + assert isinstance(sample["length"], int) + assert isinstance(sample["task"], str) + assert sample["length"] > 0 + + def test_sample_distribution(self): + """Test that samples are distributed across lengths and subtasks.""" + builder = RulerDatasetBuilder( + samples=24, # 6 tasks x 2 lengths x 2 samples = 24 + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should have 24 samples (12 per length, 2 per task) + assert len(dataset) == 24 + + # Check task distribution (should have variety from all RULER_TASKS) + tasks = [s["task"] for s in dataset] + # Verify we have all 6 tasks represented + assert len(set(tasks)) == 6 + + def test_length_targeting(self): + """Test that generated lengths are close to targets.""" + builder = RulerDatasetBuilder( + samples=6, # 1 per task + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Lengths should be within reasonable range of target + # RULER aims for 70-90% of target length for context + for sample in dataset: + assert 700 < sample["length"] < 1400 # Reasonable range around 1024 + + def test_uneven_sample_distribution(self): + """Test that samples are distributed evenly (remainder dropped).""" + builder = RulerDatasetBuilder( + samples=50, # 50 samples across 4 lengths + max_seqlen=8192, # Generates: [8192, 4096, 2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + # Even distribution: 50//4 = 12 per length + assert builder.total_samples == 50 + assert builder.target_lengths == [8192, 4096, 2048, 1024] + assert builder.samples_per_length == [12, 12, 12, 12] + assert sum(builder.samples_per_length) == 48 # 2 samples dropped (remainder) + + # Actual generated samples: 12//6=2 per task, 4 lengths, 6 tasks + # Total: 2 x 6 x 4 = 48 + dataset = builder.build_calibration_dataset() + assert len(dataset) == 48 + + +class TestDynamicThresholdCalibrator: + """Test calibration algorithm correctness.""" + + def test_calibrator_initialization(self): + """Test that calibrator initializes correctly.""" + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=0.5, + threshold_trials=[1e-4, 1e-3, 1e-2], + ) + + assert calibrator.target_sparse_ratio == 0.5 + assert len(calibrator.threshold_trials) == 3 + + def test_calibrator_default_threshold_trials(self): + """Test that calibrator has default threshold trials.""" + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=0.5, + ) + + # Should have default threshold trials + assert calibrator.threshold_trials is not None + assert len(calibrator.threshold_trials) == 12 + # Check they are positive and in valid range + trials = calibrator.threshold_trials + assert all(0 < t < 1 for t in trials) + + def test_regression_calculation_synthetic(self): + """Test 'a' parameter calculation with synthetic data.""" + # Create synthetic optimal pairs + # If threshold = a / length, then with perfect data: + # length=1000, threshold=10 => a=10000 + # length=2000, threshold=5 => a=10000 + optimal_pairs = [ + {"length": 1000, "optimal_threshold": 10.0, 
"achieved_sparsity": 0.5}, + {"length": 2000, "optimal_threshold": 5.0, "achieved_sparsity": 0.5}, + {"length": 4000, "optimal_threshold": 2.5, "achieved_sparsity": 0.5}, + ] + + # Manual regression calculation + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + x = 1.0 / lengths + y = thresholds + + # Calculate 'a' using least squares + a_parameter = np.sum(x * y) / np.sum(x**2) + + # Should be close to 10000 + assert 9500 < a_parameter < 10500 + + # Test individual 'a' values + a_per_sample = y * lengths + assert np.allclose(a_per_sample, 10000, rtol=0.05) + + def test_multiple_samples_different_lengths(self): + """Test regression with varied lengths.""" + # More realistic scenario with some variance + optimal_pairs = [ + {"length": 500, "optimal_threshold": 20.0, "achieved_sparsity": 0.5}, + {"length": 1000, "optimal_threshold": 10.5, "achieved_sparsity": 0.51}, + {"length": 2000, "optimal_threshold": 5.2, "achieved_sparsity": 0.49}, + {"length": 4000, "optimal_threshold": 2.4, "achieved_sparsity": 0.50}, + ] + + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + x = 1.0 / lengths + y = thresholds + + a_parameter = np.sum(x * y) / np.sum(x**2) + + # Should still be around 10000 with some tolerance for variance + assert 9000 < a_parameter < 11000 + + def test_r_squared_calculation(self): + """Test R-squared calculation for regression quality.""" + # Perfect fit data + optimal_pairs = [ + {"length": 1000, "optimal_threshold": 10.0}, + {"length": 2000, "optimal_threshold": 5.0}, + {"length": 4000, "optimal_threshold": 2.5}, + ] + + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + x = 1.0 / lengths + y = thresholds + + a_parameter = np.sum(x * y) / np.sum(x**2) + + # Calculate R-squared + y_pred = a_parameter * x + ss_res = np.sum((y - y_pred) ** 2) + ss_tot = np.sum((y - np.mean(y)) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 + + # Perfect fit should have R^2 close to 1 + assert r_squared > 0.99 + + +class TestCalibrationIntegration: + """Test end-to-end calibration without GPU.""" + + def test_calibration_disabled(self): + """Test that no calibration occurs when disabled.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + # No forward_loop needed when calibration disabled + sparse_model = sparsify(model, config) + + # Check that sparse attention is applied but not calibrated + has_sparse = any(isinstance(m, SparseAttentionModule) for m in sparse_model.modules()) + assert has_sparse + + # Check that no calibration is set + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + method = module._sparse_method_instance + assert not getattr(method, "threshold_scale_factor", None) + + def test_sparsify_with_calibration_requires_forward_loop(self): + """Test that calibration requires forward_loop or proper model config.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "br": 64, + "bc": 64, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + "max_seqlen": 1024, + }, + } 
+ }, + } + + # Without forward_loop and without model.config._name_or_path, should raise ValueError + with pytest.raises(ValueError, match="Could not load tokenizer"): + sparsify(model, config, forward_loop=None) + + def test_multiple_sparse_modules(self): + """Test that calibration handles multiple attention layers.""" + model = SimpleTransformerEncoder() + + config = { + "sparse_cfg": {"*attn*": {"threshold": 1e-3, "br": 64, "bc": 64, "enable": True}}, + } + + sparse_model = sparsify(model, config) + + # Count sparse attention modules + sparse_count = sum( + 1 for m in sparse_model.modules() if isinstance(m, SparseAttentionModule) + ) + + # Should have 2 sparse attention modules + assert sparse_count == 2 + + def test_calibration_config_validation(self): + """Test CalibrationConfig validation.""" + from modelopt.torch.sparsity.attention_sparsity.config import CalibrationConfig + + # Valid config + config = CalibrationConfig( + target_sparse_ratio=0.5, + samples=48, + max_seqlen=32768, + ) + assert config.target_sparse_ratio == 0.5 + assert config.samples == 48 + assert config.max_seqlen == 32768 + + # Invalid target_sparse_ratio (> 1.0) + with pytest.raises(ValueError, match="target_sparse_ratio must be between"): + CalibrationConfig(target_sparse_ratio=1.5, samples=48, max_seqlen=32768) + + # Invalid target_sparse_ratio (< 0.0) + with pytest.raises(ValueError, match="target_sparse_ratio must be between"): + CalibrationConfig(target_sparse_ratio=-0.1, samples=48, max_seqlen=32768) + + # Invalid samples + with pytest.raises(ValueError, match="samples must be positive"): + CalibrationConfig(target_sparse_ratio=0.5, samples=0, max_seqlen=32768) + + # Invalid max_seqlen + with pytest.raises(ValueError, match="max_seqlen must be >= 1024"): + CalibrationConfig(target_sparse_ratio=0.5, samples=48, max_seqlen=512) + + def test_threshold_trials_validation(self): + """Test threshold_trials validation.""" + from modelopt.torch.sparsity.attention_sparsity.config import CalibrationConfig + + # Valid custom threshold_trials + config = CalibrationConfig( + target_sparse_ratio=0.5, + threshold_trials=[1e-5, 1e-4, 1e-3, 1e-2], + ) + assert config.threshold_trials == [1e-5, 1e-4, 1e-3, 1e-2] + + # None (use defaults) + config_default = CalibrationConfig(target_sparse_ratio=0.5) + assert config_default.threshold_trials is None + + # Invalid: empty list + with pytest.raises(ValueError, match="threshold_trials must not be empty"): + CalibrationConfig(threshold_trials=[]) + + # Invalid: threshold out of range (>= 1.0) + with pytest.raises(ValueError, match="must be in range"): + CalibrationConfig(threshold_trials=[1e-4, 1.0]) + + # Invalid: threshold out of range (<= 0) + with pytest.raises(ValueError, match="must be in range"): + CalibrationConfig(threshold_trials=[1e-4, 0]) + + # Invalid: not a list (Pydantic raises ValidationError, not ValueError) + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="Input should be a valid list"): + CalibrationConfig(threshold_trials=1e-4) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From d29d64898c26483e8c659ef6ff51d48738b6ee66 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Mon, 10 Nov 2025 16:50:04 -0800 Subject: [PATCH 6/6] Add sparse attention integration to llm_eval Signed-off-by: Kai Xu --- examples/llm_eval/lm_eval_hf.py | 28 + examples/llm_eval/mmlu.py | 25 + examples/llm_eval/modeling.py | 5 + examples/llm_eval/sparse_attention_utils.py | 111 ++++ .../llm_sparsity/attention_sparsity/README.md | 161 ++++++ 
.../llm_sparsity/attention_sparsity/hf_sa.py | 144 ++---- .../attention_sparsity/requirements.txt | 2 + .../calibration/calibrate.py | 40 +- .../calibration/calibrator.py | 3 +- .../attention_sparsity/calibration/dataset.py | 148 ++---- .../calibration/download_ruler_data.sh | 50 ++ .../calibration/ruler_utils.py | 487 ++++++++++++++++++ .../sparsity/attention_sparsity/config.py | 67 +-- .../sparsity/attention_sparsity/conversion.py | 79 +-- .../methods/flash_skip_softmax.py | 61 ++- .../attention_sparsity/methods/registry.py | 13 + .../attention_sparsity/model_sparsify.py | 37 +- .../attention_sparsity/sparse_attention.py | 28 +- .../attention_sparsity/stats_manager.py | 12 + setup.py | 2 + .../torch_sparsity/sparse_attention_common.py | 10 +- tests/examples/llm_eval/test_llm_eval.py | 17 + .../test_calibration_gpu.py | 16 +- .../test_sparse_attention_calibration.py | 211 +++++++- .../test_sparse_attention_conversion.py | 97 ++++ .../attention_sparsity/test_stats_manager.py | 334 ++++++++++++ .../attention_sparsity/test_threshold_info.py | 270 ++++++++++ 27 files changed, 2072 insertions(+), 386 deletions(-) create mode 100644 examples/llm_eval/sparse_attention_utils.py create mode 100644 examples/llm_sparsity/attention_sparsity/README.md create mode 100644 examples/llm_sparsity/attention_sparsity/requirements.txt create mode 100755 modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh create mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py diff --git a/examples/llm_eval/lm_eval_hf.py b/examples/llm_eval/lm_eval_hf.py index 31103ff86..24dcb28f6 100755 --- a/examples/llm_eval/lm_eval_hf.py +++ b/examples/llm_eval/lm_eval_hf.py @@ -43,9 +43,11 @@ from lm_eval.api.model import T from lm_eval.models.huggingface import HFLM from quantization_utils import quantize_model +from sparse_attention_utils import sparsify_model import modelopt.torch.opt as mto from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T: @@ -60,9 +62,20 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | calib_size = arg_dict.pop("calib_size", 512) compress = arg_dict.pop("compress", False) + # Sparse attention arguments + sparse_cfg = arg_dict.pop("sparse_cfg", None) + additional_config = {} if additional_config is None else additional_config additional_config = {k: v for k, v in additional_config.items() if v is not None} + # Force eager attention if sparse attention is requested + if sparse_cfg: + additional_config["attn_implementation"] = "eager" + warnings.warn( + "Sparse attention requires attn_implementation='eager'. " + "Forcing eager attention implementation." 
+ ) + # Enable automatic save/load of modelopt state huggingface checkpointing mto.enable_huggingface_checkpointing() @@ -91,6 +104,15 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | auto_quantize_checkpoint=auto_quantize_checkpoint, ) + if sparse_cfg: + if is_attn_sparsified(model_obj.model): + warnings.warn("Skipping sparse attention: model already has sparse attention applied.") + else: + sparsify_model( + model=model_obj, + sparse_cfg=sparse_cfg, + ) + return model_obj @@ -152,6 +174,11 @@ def setup_parser_with_modelopt_args(): action="store_true", help="Compress the model after quantization", ) + parser.add_argument( + "--sparse_cfg", + type=str, + help="Sparse attention configuration (e.g., SKIP_SOFTMAX_DEFAULT, SKIP_SOFTMAX_CALIB)", + ) return parser @@ -177,6 +204,7 @@ def setup_parser_with_modelopt_args(): "calib_batch_size": args.calib_batch_size, "calib_size": args.calib_size, "compress": args.compress, + "sparse_cfg": args.sparse_cfg, } ) diff --git a/examples/llm_eval/mmlu.py b/examples/llm_eval/mmlu.py index ca244052b..0bf47fcd3 100755 --- a/examples/llm_eval/mmlu.py +++ b/examples/llm_eval/mmlu.py @@ -48,6 +48,7 @@ from fire import Fire from modeling import EvalModel, select_model from quantization_utils import MAX_SEQ_LEN, get_tokenizer, quantize_model +from sparse_attention_utils import sparsify_model from tqdm import tqdm try: @@ -56,6 +57,7 @@ LLM = None # type: ignore[misc] import modelopt.torch.opt as mto from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -230,6 +232,7 @@ def main( auto_quantize_method: str = "gradient", auto_quantize_score_size: int = 128, auto_quantize_checkpoint: str | None = None, + sparse_cfg: str | None = None, **kwargs, ): random.seed(RAND_SEED) @@ -266,6 +269,14 @@ def main( max_batch_size=1, ) else: + # Force eager attention if sparse attention is requested + if sparse_cfg: + kwargs["attn_implementation"] = "eager" + warnings.warn( + "Sparse attention requires attn_implementation='eager'. " + "Forcing eager attention implementation." + ) + model = select_model( max_input_length=MAX_SEQ_LEN, max_output_length=2, dtype=dtype, **kwargs ) @@ -289,6 +300,20 @@ def main( auto_quantize_checkpoint=auto_quantize_checkpoint, ) + # Apply sparse attention if requested + if sparse_cfg: + model.load() + + if is_attn_sparsified(model.model): + warnings.warn( + "Skipping sparse attention: model already has sparse attention applied." 
+ ) + else: + sparsify_model( + model=model, + sparse_cfg=sparse_cfg, + ) + for subject in tqdm(subjects): dev_df = pd.read_csv(os.path.join(data_dir, "dev", subject + "_dev.csv"), header=None)[ :ntrain diff --git a/examples/llm_eval/modeling.py b/examples/llm_eval/modeling.py index 747b95d5b..d06d05560 100644 --- a/examples/llm_eval/modeling.py +++ b/examples/llm_eval/modeling.py @@ -179,6 +179,7 @@ class SeqToSeqModel(EvalModel): lora_path: str = "" device: str = "cuda" load_8bit: bool = False + attn_implementation: str | None = None def load(self): if self.model is None: @@ -188,6 +189,8 @@ def load(self): if self.load_8bit: args.update(device_map="auto", load_in_8bit=True) args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto") + if self.attn_implementation: + args["attn_implementation"] = self.attn_implementation self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_path, **args) print_gpu_utilization() if self.lora_path: @@ -241,6 +244,8 @@ def load(self): if self.load_8bit: args.update(device_map="auto", load_in_8bit=True) args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto") + if self.attn_implementation: + args["attn_implementation"] = self.attn_implementation self.model = AutoModelForCausalLM.from_pretrained( self.model_path, trust_remote_code=True, **args ) diff --git a/examples/llm_eval/sparse_attention_utils.py b/examples/llm_eval/sparse_attention_utils.py new file mode 100644 index 000000000..dc7a1b14e --- /dev/null +++ b/examples/llm_eval/sparse_attention_utils.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for sparse attention integration with llm_eval.""" + +import modelopt.torch.sparsity.attention_sparsity as mtsa + +# Custom sparse attention configurations +CUSTOM_SPARSE_CONFIG = { + "SPARSE_CONSERVATIVE": { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 5e-4, "decode": 1e-5}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + }, + "default": {"enable": False}, + }, + }, + "SPARSE_AGGRESSIVE": { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 5e-3, "decode": 5e-4}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + }, + "default": {"enable": False}, + }, + }, +} + + +def _extract_model(model_obj): + """Extract actual model from wrapper (HFLM or EvalModel).""" + if hasattr(model_obj, "gpt2"): + return model_obj.gpt2 + elif hasattr(model_obj, "model"): + return model_obj.model + else: + return model_obj + + +def sparsify_model( + model, + sparse_cfg: str, + backend=None, +): + """Apply sparse attention to model with optional RULER calibration. 
+ + Args: + model: Model wrapper (HFLM or EvalModel) or raw model + sparse_cfg: Sparse attention config name or dict + backend: Backend to use (optional, overrides config backend) + + Returns: + The model with sparse attention applied + + Note: + Calibration is automatically triggered if the config contains a 'calibration' field. + The calibration will auto-generate RULER dataset from the model's tokenizer. + """ + # Extract actual model + net = _extract_model(model) + + # Resolve config + if isinstance(sparse_cfg, str): + # Try custom configs first + mtsa_cfg = CUSTOM_SPARSE_CONFIG.get(sparse_cfg) + if mtsa_cfg is None: + # Try predefined configs + mtsa_cfg = getattr(mtsa, sparse_cfg, None) + if mtsa_cfg is None: + raise ValueError(f"Unknown sparse_cfg: {sparse_cfg}") + else: + mtsa_cfg = sparse_cfg + + # Override backend if specified + if backend: + if isinstance(mtsa_cfg, dict) and "sparse_cfg" in mtsa_cfg: + modified_sparse_cfg = {} + for pattern, cfg in mtsa_cfg["sparse_cfg"].items(): + modified_cfg = cfg.copy() if isinstance(cfg, dict) else cfg + if isinstance(modified_cfg, dict): + modified_cfg["backend"] = backend + modified_sparse_cfg[pattern] = modified_cfg + mtsa_cfg = {"sparse_cfg": modified_sparse_cfg} + + # Apply sparsification + print(f"\nApplying sparse attention with config: {sparse_cfg}") + mtsa.sparsify(net, mtsa_cfg) + print("Sparse attention applied successfully!") + + return model diff --git a/examples/llm_sparsity/attention_sparsity/README.md b/examples/llm_sparsity/attention_sparsity/README.md new file mode 100644 index 000000000..708947683 --- /dev/null +++ b/examples/llm_sparsity/attention_sparsity/README.md @@ -0,0 +1,161 @@ +# Attention Sparsity for HuggingFace Models + +In this tutorial, we demonstrate how to use NVIDIA TensorRT Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation. + +## Getting Started + +### Quick Example + +```python +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT + +# Load your model +model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen3-8B", + attn_implementation="eager", # Required for sparse attention + torch_dtype=torch.bfloat16, +) + +# Apply sparse attention +model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT) +``` + +> [!Note] +> `attn_implementation="eager"` is required for sparse attention to work properly. Flash Attention 2 or SDPA would bypass the softmax patching needed for stats collection. + +## Configuration Options + +Two pre-defined configurations are available: + +### 1. Fixed Threshold (SKIP_SOFTMAX_DEFAULT) + +Uses a fixed threshold value. Simple but may not be optimal for all sequence lengths. + +```python +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT + +model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT) +``` + +### 2. Calibrated Threshold (SKIP_SOFTMAX_CALIB) + +Uses RULER-based calibration to determine an optimal dynamic threshold that adapts to sequence length. Recommended for production use. 
+ +```python +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_CALIB + +model = mtsa.sparsify(model, config=SKIP_SOFTMAX_CALIB) +``` + +## Prerequisites + +### Install Requirements + +```bash +pip install -r requirements.txt +``` + +### Download RULER Calibration Data (Required for Calibration) + +If using `SKIP_SOFTMAX_CALIB`, you need to download the RULER calibration dataset first: + +```bash +bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh +``` + +This downloads the Paul Graham essays dataset used for generating calibration samples. + +## Run Sparse Attention on HuggingFace Models + +### Basic Usage (Without Calibration) + +Apply sparse attention with a fixed threshold: + +```bash +python examples/llm_sparsity/attention_sparsity/hf_sa.py \ + --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax +``` + +### With RULER Calibration + +Apply sparse attention with calibrated thresholds for optimal sparsity: + +```bash +python examples/llm_sparsity/attention_sparsity/hf_sa.py \ + --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax_calib +``` + +The calibration process: + +1. Generates RULER calibration samples +2. Collects attention statistics during forward passes +3. Determines optimal threshold scale factor for target sparsity ratio + +### Command Line Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--pyt_ckpt_path` | Required | HuggingFace model path or name | +| `--sparse_attn` | `skip_softmax` | Configuration: `skip_softmax` or `skip_softmax_calib` | +| `--backend` | `pytorch` | Backend: `pytorch` or `triton` | +| `--seq_len` | `2048` | Maximum sequence length for input prompts | +| `--export_dir` | `None` | Directory to export the sparsified model | + +## Output Comparison + +The script automatically compares outputs before and after applying sparse attention: + +1. Loads a test sample from the NarrativeQA dataset +2. Generates text before sparse attention is applied +3. Applies sparse attention (with optional calibration) +4. Generates text after sparse attention is applied +5. Compares and displays both outputs + +## Export Model + +Export the sparsified model to a HuggingFace checkpoint: + +```bash +python examples/llm_sparsity/attention_sparsity/hf_sa.py \ + --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax_calib \ + --export_dir ./exported_sparse_model +``` + +The exported model can be loaded and used with standard HuggingFace APIs. 
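+
+For example, a minimal sketch of reloading the exported checkpoint (the path is the `--export_dir` value used above; restoring the saved sparse-attention state through `mto.enable_huggingface_checkpointing()` is assumed):
+
+```python
+import modelopt.torch.opt as mto
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# Restore the saved ModelOpt (sparse attention) state automatically on load
+mto.enable_huggingface_checkpointing()
+
+model = AutoModelForCausalLM.from_pretrained(
+    "./exported_sparse_model",    # directory passed via --export_dir
+    attn_implementation="eager",  # eager attention is still required
+    torch_dtype="auto",
+)
+tokenizer = AutoTokenizer.from_pretrained("./exported_sparse_model")
+```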
+ +## Custom Configuration + +You can create custom sparse attention configurations: + +```python +custom_config = { + "sparse_cfg": { + "calibration": { # Optional: omit for fixed threshold + "target_sparse_ratio": 0.5, # Target 50% sparsity + "samples": 128, # Number of calibration samples + "max_seqlen": 8192, # Maximum sequence length + }, + "*attn*": { # Pattern to match attention modules + "method": "flash_skip_softmax", + "threshold": 1e-4, # Fixed threshold (ignored if calibration is used) + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", + "collect_stats": True, + "enable": True, + }, + "default": {"enable": False}, + }, +} + +model = mtsa.sparsify(model, config=custom_config) +``` + +## References + +- [TensorRT Model Optimizer Documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/) +- [RULER: What's the Real Context Size of Your Long-Context Language Models?](https://github.com/NVIDIA/RULER) diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py index 2b5f3ade5..29a2b53aa 100644 --- a/examples/llm_sparsity/attention_sparsity/hf_sa.py +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -28,12 +28,10 @@ import modelopt.torch.opt as mto import modelopt.torch.sparsity.attention_sparsity as mtsa from modelopt.torch.export import export_hf_checkpoint -from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig from modelopt.torch.sparsity.attention_sparsity.config import ( SKIP_SOFTMAX_CALIB, SKIP_SOFTMAX_DEFAULT, ) -from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule from modelopt.torch.utils.memory_monitor import launch_memory_monitor RAND_SEED = 1234 @@ -120,30 +118,23 @@ def truncate_text(text: str, tokenizer, max_length: int): return begin_text + " [...] " + end_text -def verify_outputs(model, tokenizer, args): - """Compare outputs between baseline and sparse attention models.""" - # Update seq_len to match calibration max_seqlen if calibration was used - base_config = SPARSE_ATTN_CFG_CHOICES.get(args.sparse_attn, {}) - if "calibration" in base_config and "max_seqlen" in base_config["calibration"]: - calib_max_seqlen = base_config["calibration"]["max_seqlen"] - if args.seq_len != calib_max_seqlen: - print( - f"\nNote: Updating test seq_len from {args.seq_len} to {calib_max_seqlen} " - f"to match calibration config" - ) - args.seq_len = calib_max_seqlen +def generate_sample_output(model, tokenizer, args): + """Generate sample output for comparison. + + Args: + model: The model to generate with + tokenizer: Tokenizer for encoding/decoding + args: Command line arguments - # Load and prepare a single test prompt - print(f"\nLoading test sample (will be tokenized up to {args.seq_len} tokens)") + Returns: + Tuple of (generated_text, input_prompt, input_ids) + """ + # Load test sample prompts = get_narrativeqa_samples(num_samples=1) prompt = prompts[0] # Prepare inputs truncated_prompt = truncate_text(prompt, tokenizer, args.seq_len) - display_prompt = ( - truncated_prompt[:150] + "..." 
if len(truncated_prompt) > 150 else truncated_prompt - ) - inputs = tokenizer( truncated_prompt, return_tensors="pt", @@ -154,14 +145,7 @@ def verify_outputs(model, tokenizer, args): if torch.cuda.is_available(): inputs = {k: v.cuda() for k, v in inputs.items()} - print("\n" + "=" * 60) - print("BASELINE vs SPARSE ATTENTION COMPARISON") - print("=" * 60) - print(f"\nTest prompt: {display_prompt}") - print(f"Input tokens: {inputs['input_ids'].shape[1]}") - - # Helper function to generate text - def generate_text(model, inputs, args, tokenizer): + # Generate with torch.no_grad(): outputs = model.generate( **inputs, @@ -172,60 +156,9 @@ def generate_text(model, inputs, args, tokenizer): ) input_length = inputs["input_ids"].shape[1] generated_ids = outputs[0][input_length:] - return tokenizer.decode(generated_ids, skip_special_tokens=True) - - # Find all sparse attention modules - sparse_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] - - # Generate baseline by temporarily disabling sparse attention - print("\n" + "-" * 60) - print("Generating baseline (sparse attention disabled)...") - for module in sparse_modules: - module.disable() - baseline_text = generate_text(model, inputs, args, tokenizer) + generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True) - # Generate with sparse attention enabled - print("\nGenerating with sparse attention (calibrated thresholds)...") - for module in sparse_modules: - module.enable() - sparse_text = generate_text(model, inputs, args, tokenizer) - - # Display comparison - print("\n" + "-" * 60) - print("RESULTS:") - baseline_display = baseline_text[:300] + "..." if len(baseline_text) > 300 else baseline_text - sparse_display = sparse_text[:300] + "..." if len(sparse_text) > 300 else sparse_text - - print(f"\nBaseline: {baseline_display}") - print(f"With Sparse: {sparse_display}") - - if baseline_text == sparse_text: - print("\nOutputs are identical") - else: - print("\nOutputs differ") - - -def sparsify_model(model, args): - """Apply sparse attention to the model with optional calibration.""" - print(f"\nApplying sparse attention: {args.sparse_attn} with backend: {args.backend}") - base_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn] - - # Create modified config with selected backend - modified_sparse_cfg = {} - for pattern, cfg in base_config["sparse_cfg"].items(): - modified_cfg = cfg.copy() - modified_cfg["backend"] = args.backend - modified_sparse_cfg[pattern] = modified_cfg - - # Create new config with modified settings - sparse_config = SparseAttentionConfig(sparse_cfg=modified_sparse_cfg) - - # Sparsify the model - model = mtsa.sparsify(model, config=sparse_config) - - print("Sparse attention applied successfully!") - - return model + return generated_text, truncated_prompt, inputs["input_ids"] def main(args): @@ -258,12 +191,40 @@ def main(args): model = model.cuda() print("Model moved to CUDA") - # Apply sparse attention to the model (with calibration if configured) - model = sparsify_model(model, args) + # Generate sample output BEFORE sparse attention + print("\nGenerating sample output before sparse attention...") + output_before, test_prompt, input_ids = generate_sample_output(model, tokenizer, args) + + # Apply sparse attention with optional calibration + print(f"\nApplying sparse attention: {args.sparse_attn}") + sparse_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn] + model = mtsa.sparsify(model, config=sparse_config) + print("Sparse attention applied successfully!") + + # Generate sample 
output AFTER sparse attention + print("\nGenerating sample output after sparse attention...") + output_after, _, _ = generate_sample_output(model, tokenizer, args) + + # Display comparison + print("\n" + "=" * 60) + print("OUTPUT COMPARISON (Before vs After Sparse Attention)") + print("=" * 60) + display_prompt = test_prompt[:150] + "..." if len(test_prompt) > 150 else test_prompt + print(f"\nTest prompt: {display_prompt}") + print(f"Input tokens: {input_ids.shape[1]}") + + output_before_display = ( + output_before[:300] + "..." if len(output_before) > 300 else output_before + ) + output_after_display = output_after[:300] + "..." if len(output_after) > 300 else output_after + + print(f"\nBefore sparse attention: {output_before_display}") + print(f"After sparse attention: {output_after_display}") - # Verify outputs if requested (compares baseline vs calibrated sparse model) - if args.verify_output: - verify_outputs(model, tokenizer, args) + if output_before == output_after: + print("\nOutputs are identical") + else: + print("\nOutputs differ") # Export if requested if args.export_dir: @@ -310,12 +271,6 @@ def main(args): default=2048, help="Maximum sequence length for input prompts (will be truncated if longer)", ) - parser.add_argument( - "--num_samples", - type=int, - default=3, - help="Number of samples to use from NarrativeQA dataset", - ) # Generation arguments parser.add_argument( @@ -325,11 +280,6 @@ def main(args): parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling") # Operation arguments - parser.add_argument( - "--verify_output", - action="store_true", - help="Verify that sparse attention outputs match baseline", - ) parser.add_argument( "--export_dir", type=str, diff --git a/examples/llm_sparsity/attention_sparsity/requirements.txt b/examples/llm_sparsity/attention_sparsity/requirements.txt new file mode 100644 index 000000000..a3e0dfa17 --- /dev/null +++ b/examples/llm_sparsity/attention_sparsity/requirements.txt @@ -0,0 +1,2 @@ +nltk +wonderwords diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py index b90adbc0f..f6e66ae00 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -24,6 +24,7 @@ from transformers import AutoTokenizer from ..config import CalibrationConfig +from ..conversion import print_sparse_attention_summary from ..sparse_attention import SparseAttentionModule from .calibrator import DynamicThresholdCalibrator from .dataset import RulerDatasetBuilder @@ -51,28 +52,31 @@ def _extract_tokenizer_from_model(model: nn.Module) -> str: def _extract_calibration_config(config: dict[str, Any]) -> CalibrationConfig | None: - """Extract and validate calibration config from sparse_cfg patterns. + """Extract and validate calibration config from sparse_cfg. 
Args: config: Sparse attention configuration dict Returns: - Validated CalibrationConfig or None if not found + Validated CalibrationConfig instance, or None if calibration is not configured + + Raises: + ValueError: If calibration config has invalid type or contains invalid values """ - # Extract sparse_cfg and search for calibration sparse_cfg = config.get("sparse_cfg", {}) - calib_dict = next( - ( - cfg["calibration"] - for cfg in sparse_cfg.values() - if isinstance(cfg, dict) and "calibration" in cfg - ), - None, - ) + # Calibration is optional + if "calibration" not in sparse_cfg: + return None + + calib_dict = sparse_cfg["calibration"] - # Create and calidate the calibration config - return CalibrationConfig(**calib_dict) if calib_dict else None + # Validate calibration is a dict + if not isinstance(calib_dict, dict): + raise ValueError(f"Calibration config must be a dict, got {type(calib_dict).__name__}. ") + + # Create and validate CalibrationConfig + return CalibrationConfig(**calib_dict) def create_calibration_forward_loop( @@ -127,7 +131,9 @@ def calibrate_sparse_attention( """ # Extract and validate calibration config calib_config = _extract_calibration_config(config) - if not calib_config: + + # Skip calibration if not configured + if calib_config is None: return {} # Generate forward_loop if not provided @@ -138,7 +144,7 @@ def calibrate_sparse_attention( max_seqlen=calib_config.max_seqlen, tokenizer_name_or_path=tokenizer, num_length_bins=calib_config.num_length_bins, - max_length_filter=int(calib_config.max_seqlen * 1.2), + max_length_filter=int(calib_config.max_seqlen * 1.5), ) calibration_data = builder.build_calibration_dataset() print(f"Generated {len(calibration_data)} calibration samples") @@ -162,6 +168,10 @@ def calibrate_sparse_attention( ) calibration_result = calibrator.calibrate(model, forward_loop) + # Print calibration statistics (regardless of success/failure for debugging) + print("\nCalibration complete!") + print_sparse_attention_summary(model) + if "scale_factor" not in calibration_result: warnings.warn("Calibration did not produce valid results") return {} diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py index f933e50ce..2914651f7 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -90,8 +90,7 @@ def calibrate(self, model: nn.Module, forward_loop: Callable) -> dict[str, Any]: attention_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] if not attention_modules: - warnings.warn("No sparse attention modules found for calibration") - return {} + raise ValueError("No sparse attention modules found for calibration") print("Starting dynamic threshold calibration") print(f"Target sparsity: {self.target_sparse_ratio}") diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py index 727fb189a..7603b4e1d 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py @@ -17,13 +17,14 @@ import random import string -import uuid from dataclasses import dataclass from typing import Any from tqdm import tqdm from transformers import AutoTokenizer +from . 
import ruler_utils + def _generate_target_lengths( max_seqlen: int, num_length_bins: int = 4, min_seqlen: int = 1024 @@ -143,7 +144,7 @@ class RulerTask: " Answer: According to the coded text above, " "the three most frequently appeared words are:" ), - args={}, + args={"alpha": 2.0}, ), "qa_1": RulerTask( name="qa_1", @@ -158,7 +159,7 @@ class RulerTask: "Question: {query}" ), answer_prefix=" Answer:", - args={"dataset": "squad", "demo_selection": "topk_lt_ctx"}, + args={"dataset": "squad"}, ), "qa_2": RulerTask( name="qa_2", @@ -173,7 +174,7 @@ class RulerTask: "Question: {query}" ), answer_prefix=" Answer:", - args={"dataset": "hotpotqa", "demo_selection": "topk_lt_ctx"}, + args={"dataset": "hotpotqa"}, ), } @@ -185,17 +186,17 @@ def __init__( self, samples: int, max_seqlen: int, - tokenizer_name_or_path: str, - seed: int = 42, + tokenizer_name_or_path: str | object, num_length_bins: int = 4, max_length_filter: int = 65536, + seed: int = 42, ): """Initialize RULER dataset builder. Args: samples: Total number of samples to generate (distributed evenly across length bins) max_seqlen: Maximum sequence length (length bins auto-generated as powers of 2) - tokenizer_name_or_path: HuggingFace tokenizer path + tokenizer_name_or_path: HuggingFace tokenizer path or tokenizer object seed: Random seed for reproducibility num_length_bins: Number of length bins to generate (default: 4) max_length_filter: Maximum sequence length to keep (default: 65536) @@ -229,8 +230,11 @@ def __init__( # Distribute samples evenly across lengths self.samples_per_length = [samples // len(self.target_lengths)] * len(self.target_lengths) - # Initialize tokenizer and seed - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + # Initialize tokenizer + if isinstance(tokenizer_name_or_path, str): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + else: + self.tokenizer = tokenizer_name_or_path random.seed(seed) def build_calibration_dataset(self) -> list[dict[str, Any]]: @@ -247,9 +251,7 @@ def build_calibration_dataset(self) -> list[dict[str, Any]]: desc="Generating RULER calibration samples", total=len(self.target_lengths), ): - samples_per_task = num_samples // len(self.subtasks) - if samples_per_task <= 0: - continue + samples_per_task = max(num_samples // len(self.subtasks), 1) # Generate equal samples for each task for task_name in self.subtasks: @@ -293,50 +295,43 @@ def _generate_niah_sample( """Generate a needle-in-haystack sample.""" args = task.args - # Generate needles based on type - if args["type_needle_k"] == "words": - needle_keys = [self._generate_random_word() for _ in range(args["num_needle_k"])] - else: # uuids - needle_keys = [str(uuid.uuid4()) for _ in range(args["num_needle_k"])] - - if args["type_needle_v"] == "numbers": - needle_values = [ - str(random.randint(100000, 999999)) for _ in range(args["num_needle_v"]) - ] - else: # uuids - needle_values = [str(uuid.uuid4()) for _ in range(args["num_needle_v"])] - - # Select query needles - query_keys = random.sample(needle_keys, min(args["num_needle_q"], len(needle_keys))) - - # Generate context with needles - context = self._generate_niah_context( - target_length, needle_keys, needle_values, args["type_haystack"] + # Find optimal haystack size for target length + optimal_haystack = ruler_utils.find_optimal_haystack_size( + tokenizer=self.tokenizer, + max_seq_length=target_length, + template=task.template, + answer_prefix=task.answer_prefix, + tokens_to_generate=task.tokens_to_generate, + 
type_haystack=args.get("type_haystack", "essay"), + type_needle_k=args.get("type_needle_k", "words"), + type_needle_v=args.get("type_needle_v", "numbers"), + num_needle_k=args.get("num_needle_k", 1), + num_needle_v=args.get("num_needle_v", 1), + num_needle_q=args.get("num_needle_q", 1), ) - # Format template - template = task.template.format( - type_needle_v=args["type_needle_v"], - context=context, - query=query_keys[0] if query_keys else needle_keys[0], + # Generate sample using official RULER implementation + sample = ruler_utils.generate_niah_sample( + num_haystack=optimal_haystack, + tokenizer=self.tokenizer, + template=task.template, + answer_prefix=task.answer_prefix, + tokens_to_generate=task.tokens_to_generate, + type_haystack=args.get("type_haystack", "essay"), + type_needle_k=args.get("type_needle_k", "words"), + type_needle_v=args.get("type_needle_v", "numbers"), + num_needle_k=args.get("num_needle_k", 1), + num_needle_v=args.get("num_needle_v", 1), + num_needle_q=args.get("num_needle_q", 1), + random_seed=self.seed + sample_idx, ) - # Add answer prefix - full_input = template + task.answer_prefix.format( - type_needle_v=args["type_needle_v"], - query=query_keys[0] if query_keys else needle_keys[0], - ) + # Add task metadata + sample["task"] = task.name + sample["target_length"] = target_length + sample["sample_idx"] = sample_idx - # Tokenize to get actual length - tokens = self.tokenizer.encode(full_input, add_special_tokens=False) - - return { - "input": full_input, - "length": len(tokens), - "task": task.name, - "target_length": target_length, - "sample_idx": sample_idx, - } + return sample def _generate_vt_sample( self, task: RulerTask, target_length: int, sample_idx: int @@ -471,47 +466,6 @@ def _generate_qa_sample( "sample_idx": sample_idx, } - def _generate_niah_context( - self, - target_length: int, - needle_keys: list[str], - needle_values: list[str], - haystack_type: str, - ) -> str: - """Generate context for needle-in-haystack tasks.""" - # Create needle sentences - needles = [] - for key, value in zip(needle_keys, needle_values): - needles.append(f"The magic number for {key} is {value}.") - - # Generate haystack based on type - if haystack_type == "repeat": - base_text = "The grass is green. The sky is blue. The sun is yellow. 
" * 100 - elif haystack_type == "essay": - base_text = self._generate_essay_text(500) - else: # needle type - more needles as distractors - base_text = self._generate_needle_haystack(500) - - # Insert needles at random positions - words = base_text.split() - for needle in needles: - insert_pos = random.randint(0, len(words)) - words.insert(insert_pos, needle) - - context = " ".join(words) - - # Adjust to target length - tokens = self.tokenizer.encode(context, add_special_tokens=False) - while len(tokens) < target_length * 0.7: # Leave room for template - context += " " + self._generate_essay_text(100) - tokens = self.tokenizer.encode(context, add_special_tokens=False) - - if len(tokens) > target_length * 0.9: - # Truncate if too long - context = self.tokenizer.decode(tokens[: int(target_length * 0.8)]) - - return context - def _pad_context_with_text( self, base_context: str, target_length: int, padding_type: str ) -> str: @@ -590,13 +544,3 @@ def _generate_essay_text(self, num_words: int) -> str: def _generate_document_text(self, num_words: int) -> str: """Generate document-like text.""" return self._generate_essay_text(num_words) - - def _generate_needle_haystack(self, num_items: int) -> str: - """Generate haystack with many needle-like distractors.""" - items = [] - for _ in range(num_items): - key = self._generate_random_word() - value = str(random.randint(100000, 999999)) - items.append(f"The value for {key} is {value}.") - - return " ".join(items) diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh b/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh new file mode 100755 index 000000000..54797f2a5 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Download RULER calibration data for attention sparsity. + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DATA_DIR="${SCRIPT_DIR}/data" +ESSAYS_DIR="${DATA_DIR}/essays" +URLS_FILE="${DATA_DIR}/PaulGrahamEssays_URLs.txt" +URLS_URL="https://raw.githubusercontent.com/NVIDIA/RULER/main/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt" + +mkdir -p "${ESSAYS_DIR}" + +# Download URL list if not exists +if [ ! -f "${URLS_FILE}" ]; then + echo "Downloading URL list..." + curl -fsSL "${URLS_URL}" -o "${URLS_FILE}" +fi + +# Download essays from GitHub URLs +echo -n "Downloading essays" +count=0 +while IFS= read -r url || [ -n "$url" ]; do + if [[ "${url}" == https://github.com*.txt ]]; then + filename=$(basename "${url}") + filepath="${ESSAYS_DIR}/${filename}" + if [ ! -f "${filepath}" ]; then + raw_url="${url/github.com/raw.githubusercontent.com}" + raw_url="${raw_url/\/raw\//\/}" + curl -fsSL "${raw_url}" -o "${filepath}" 2>/dev/null && echo -n "." 
+ count=$((count + 1)) + fi + fi +done < "${URLS_FILE}" +echo " done" + +echo "Downloaded ${count} essays to ${ESSAYS_DIR}" diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py new file mode 100644 index 000000000..70d4da81b --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py @@ -0,0 +1,487 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copied and Adapted from https://github.com/NVIDIA/RULER +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +"""Official RULER dataset generation utilities adapted for Model Optimizer. + +This module contains core logic from the RULER benchmark (https://github.com/NVIDIA/RULER) +adapted to work as a library for calibration purposes. The generation logic closely follows +the official RULER implementation to ensure dataset consistency. + +Key adaptations from official RULER: +- Converted from CLI scripts to library functions +- Works with HuggingFace tokenizers directly +- Removed file I/O, returns data structures +- Simplified for calibration use case (primarily NIAH tasks) +""" + +import logging +import random +import re +import uuid +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +# Needle/Haystack template from official RULER +NEEDLE_TEMPLATE = "One of the special magic {type_needle_v} for {key} is: {value}." + +# Depth positions for needle insertion (from official RULER) +DEPTHS = [ + 0, + 2, + 5, + 7, + 10, + 12, + 15, + 18, + 20, + 23, + 25, + 28, + 30, + 33, + 35, + 38, + 40, + 43, + 45, + 48, + 50, + 53, + 55, + 58, + 60, + 62, + 65, + 67, + 70, + 72, + 75, + 77, + 80, + 82, + 85, + 87, + 90, + 92, + 95, + 97, + 100, +] + +# Data directory for RULER calibration files (downloaded via download_ruler_data.sh) +DATA_DIR = Path(__file__).parent / "data" +RULER_URLS_FILE = DATA_DIR / "PaulGrahamEssays_URLs.txt" +ESSAYS_DIR = DATA_DIR / "essays" + + +def _get_data_dir() -> Path: + """Get data directory for RULER data. 
+ + Returns: + Path to data directory under calibration/ (created if doesn't exist) + """ + data_dir = Path(__file__).parent / "data" + data_dir.mkdir(parents=True, exist_ok=True) + return data_dir + + +def _load_paul_graham_essays_from_files() -> str: + """Load Paul Graham essays from local files. + + Reads essay .txt files from the data/essays directory. + Files must be downloaded first using download_ruler_data.sh. + + Returns: + Combined essay text + + Raises: + RuntimeError: If essays directory doesn't exist or is empty + """ + if not ESSAYS_DIR.exists(): + raise RuntimeError( + f"Essays directory not found at {ESSAYS_DIR}.\n" + "Please run the download script first:\n" + " bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh" + ) + + essay_files = list(ESSAYS_DIR.glob("*.txt")) + if not essay_files: + raise RuntimeError( + f"No essay files found in {ESSAYS_DIR}.\n" + "Please run the download script first:\n" + " bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh" + ) + + logger.info(f"Loading {len(essay_files)} Paul Graham essays from local files...") + + all_essays = [] + for filepath in essay_files: + text = filepath.read_text() + all_essays.append(text) + + combined_text = " ".join(all_essays) + logger.info(f"Loaded {len(all_essays)} essays successfully") + + return combined_text + + +def _load_paul_graham_essays() -> str: + """Load Paul Graham essays from local files. + + Essay files must be downloaded first using download_ruler_data.sh. + + Returns: + Essay text as string + """ + essay_text = _load_paul_graham_essays_from_files() + return re.sub(r"\s+", " ", essay_text) + + +def _load_word_lists(): + """Load word lists for random word generation. + + Returns: + List of words (adj-noun combinations) + """ + import wonderwords + + # Load wonderwords lists (same as official RULER) + nouns = wonderwords.random_word._get_words_from_text_file("nounlist.txt") + adjs = wonderwords.random_word._get_words_from_text_file("adjectivelist.txt") + words = [f"{adj}-{noun}" for adj in adjs for noun in nouns] + words = sorted(set(words)) + return words + + +# Global word list (loaded once) +_WORD_LIST = None + + +def generate_random_number(num_digits=7) -> str: + """Generate random number (from official RULER).""" + lower_bound = 10 ** (num_digits - 1) + upper_bound = 10**num_digits - 1 + return str(random.randint(lower_bound, upper_bound)) + + +def generate_random_word() -> str: + """Generate random word (from official RULER).""" + global _WORD_LIST + if _WORD_LIST is None: + _WORD_LIST = _load_word_lists() + return random.choice(_WORD_LIST) + + +def generate_random_uuid() -> str: + """Generate random UUID (from official RULER).""" + return str(uuid.UUID(int=random.getrandbits(128), version=4)) + + +def generate_random(type_needle: str) -> str: + """Generate random needle value based on type (from official RULER). 
+ + Args: + type_needle: Type of needle ('numbers', 'words', 'uuids') + + Returns: + Random value as string + """ + if type_needle == "numbers": + return generate_random_number() + elif type_needle == "words": + return generate_random_word() + elif type_needle == "uuids": + return generate_random_uuid() + else: + raise ValueError(f"Unknown needle type: {type_needle}") + + +def generate_niah_sample( + num_haystack: int, + tokenizer, + template: str, + answer_prefix: str, + tokens_to_generate: int = 128, + type_haystack: str = "essay", + type_needle_k: str = "words", + type_needle_v: str = "numbers", + num_needle_k: int = 1, + num_needle_v: int = 1, + num_needle_q: int = 1, + random_seed: int = 42, +) -> dict[str, Any]: + """Generate a single NIAH (Needle in a Haystack) sample. + + This function implements the core generation logic from official RULER's niah.py, + adapted to work as a library function. + + Args: + num_haystack: Number of haystack items/words + tokenizer: HuggingFace tokenizer (AutoTokenizer instance) + template: NIAH question template + answer_prefix: Answer prefix template + tokens_to_generate: Expected number of generation tokens + type_haystack: Type of haystack ('essay', 'noise', 'needle') + type_needle_k: Type of needle keys ('numbers', 'words', 'uuids') + type_needle_v: Type of needle values ('numbers', 'words', 'uuids') + num_needle_k: Number of needle keys + num_needle_v: Number of needle values per key + num_needle_q: Number of needles to query + random_seed: Random seed for this sample + + Returns: + Dictionary with 'input', 'outputs', 'length' keys + """ + import nltk + from nltk.tokenize import sent_tokenize + + try: + nltk.data.find("tokenizers/punkt") + except LookupError: + nltk.download("punkt", quiet=True) + nltk.download("punkt_tab", quiet=True) + + if random_seed is not None: + random.seed(random_seed) + + # Ensure num_needle_k >= num_needle_q + num_needle_k = max(num_needle_k, num_needle_q) + + # Generate needles (keys and values) + keys, values, needles = [], [], [] + for _ in range(num_needle_k): + keys.append(generate_random(type_needle_k)) + value = [] + for _ in range(num_needle_v): + value.append(generate_random(type_needle_v)) + needles.append( + NEEDLE_TEMPLATE.format( + type_needle_v=type_needle_v, + key=keys[-1], + value=value[-1], + ) + ) + values.append(value) + + random.shuffle(needles) + + # Generate context based on haystack type + if type_haystack == "essay": + # Load essay corpus + essay_text = _load_paul_graham_essays() + haystack = essay_text.split(" ") + + # Create text from haystack + if num_haystack <= len(haystack): + text = " ".join(haystack[:num_haystack]) + else: + # Repeat haystack as needed + repeats = (num_haystack + len(haystack) - 1) // len(haystack) + text = " ".join((haystack * repeats)[:num_haystack]) + + # Insert needles at various depths + document_sents = sent_tokenize(text.strip()) + insertion_positions = [ + 0, + *sorted( + int(len(document_sents) * (depth / 100)) + for depth in random.sample(DEPTHS, len(needles)) + ), + len(document_sents), + ] + + document_sents_list = [] + for i in range(1, len(insertion_positions)): + last_pos = insertion_positions[i - 1] + next_pos = insertion_positions[i] + document_sents_list.append(" ".join(document_sents[last_pos:next_pos])) + if i - 1 < len(needles): + document_sents_list.append(needles[i - 1]) + + context = " ".join(document_sents_list) + + if type_haystack == "noise": + haystack_sent = "The grass is green. The sky is blue. The sun is yellow. Here we go. 
There and back again." + sentences = [haystack_sent] * num_haystack + indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True) + for index, element in zip(indexes, needles): + sentences.insert(index, element) + context = "\n".join(sentences) + + elif type_haystack == "needle": + sentences = [ + NEEDLE_TEMPLATE.format( + type_needle_v=type_needle_v, + key=generate_random(type_needle_k), + value=generate_random(type_needle_v), + ) + for _ in range(num_haystack) + ] + + indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True) + for index, element in zip(indexes, needles): + sentences.insert(index, element) + context = "\n".join(sentences) + + # Generate query and answer + indices = random.sample(range(num_needle_k), num_needle_q) + queries = [keys[i] for i in indices] + answers = [a for i in indices for a in values[i]] + query = ", ".join(queries[:-1]) + ", and " + queries[-1] if len(queries) > 1 else queries[0] + + # Format template (adjust for singular vs plural) + type_needle_v_display = type_needle_v + formatted_template = template + if num_needle_q * num_needle_v == 1: + formatted_template = formatted_template.replace("Some", "A") + formatted_template = formatted_template.replace("are all", "is") + formatted_template = formatted_template.replace("are", "is") + formatted_template = formatted_template.replace("answers", "answer") + type_needle_v_display = type_needle_v[:-1] # remove "s" + + input_text = formatted_template.format( + type_needle_v=type_needle_v_display, + context=context, + query=query, + ) + + # Add answer prefix + formatted_answer_prefix = answer_prefix.format( + type_needle_v=type_needle_v_display, + query=query, + ) + input_text = input_text + formatted_answer_prefix + + # Calculate actual length + if hasattr(tokenizer, "encode"): + # HuggingFace tokenizer + tokens = tokenizer.encode(input_text, add_special_tokens=False) + length = len(tokens) + tokens_to_generate + else: + # Fallback + length = len(input_text.split()) + tokens_to_generate + + return { + "input": input_text, + "outputs": answers, + "length": length, + } + + +def find_optimal_haystack_size( + tokenizer, + max_seq_length: int, + template: str, + answer_prefix: str, + tokens_to_generate: int = 128, + type_haystack: str = "essay", + **kwargs, +) -> int: + """Find optimal haystack size using binary search (from official RULER). 
+ + Args: + tokenizer: HuggingFace tokenizer + max_seq_length: Maximum sequence length + tokens_to_generate: Expected generation tokens + type_haystack: Type of haystack + template: NIAH question template + answer_prefix: Answer prefix template + **kwargs: Additional arguments for generate_niah_sample + + Returns: + Optimal number of haystack items + """ + # Determine incremental step based on haystack type + if type_haystack == "essay": + incremental = 500 + elif type_haystack in ["noise", "needle"]: + incremental = 25 + else: + incremental = 100 + + if max_seq_length < 4096 and type_haystack != "essay": + incremental = 5 + + # Estimate tokens per haystack item + sample = generate_niah_sample( + incremental, + tokenizer, + template, + answer_prefix, + tokens_to_generate, + type_haystack=type_haystack, + **kwargs, + ) + + if hasattr(tokenizer, "encode"): + sample_tokens = len(tokenizer.encode(sample["input"], add_special_tokens=False)) + else: + sample_tokens = len(sample["input"].split()) + + tokens_per_haystack = sample_tokens / incremental + estimated_max = int((max_seq_length / tokens_per_haystack) * 3) + + # Binary search for optimal size + lower_bound = incremental + upper_bound = max(estimated_max, incremental * 2) + optimal_num_haystack = None + + logger.info(f"Estimated {tokens_per_haystack:.1f} tokens per haystack") + logger.info(f"Binary search bounds: {lower_bound} to {upper_bound}") + + while lower_bound <= upper_bound: + mid = (lower_bound + upper_bound) // 2 + sample = generate_niah_sample( + mid, + tokenizer, + template, + answer_prefix, + tokens_to_generate, + type_haystack=type_haystack, + **kwargs, + ) + total_tokens = sample["length"] + + logger.debug(f"Testing haystack size: {mid}, tokens: {total_tokens}/{max_seq_length}") + + if total_tokens <= max_seq_length: + optimal_num_haystack = mid + lower_bound = mid + 1 + else: + upper_bound = mid - 1 + + final_size = optimal_num_haystack if optimal_num_haystack is not None else incremental + logger.info(f"Optimal haystack size: {final_size}") + + return final_size diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index e19b74a2a..177c5d238 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -40,12 +40,6 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): description="The sparse attention method to use (e.g., 'flash_skip_softmax').", ) - method: str = ModeloptField( - default="flash_skip_softmax", - title="Sparse attention method.", - description="The sparse attention method to use (e.g., 'flash_skip_softmax').", - ) - enable: bool = ModeloptField( default=True, title="Enable sparse attention.", @@ -83,6 +77,12 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): ), ) + collect_stats: bool = ModeloptField( + default=False, + title="Collect statistics.", + description="Whether to collect sparsity statistics during forward pass for monitoring.", + ) + is_causal: bool = ModeloptField( default=True, title="Causal attention flag.", @@ -93,16 +93,6 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): ), ) - calibration: dict | None = ModeloptField( - default=None, - title="Calibration configuration", - description=( - "Calibration settings for this pattern. " - "If provided, enables automatic threshold calibration. " - "Only one pattern should have calibration enabled." 
- ), - ) - @field_validator("method") @classmethod def validate_method(cls, v): @@ -252,7 +242,6 @@ def validate_num_length_bins(cls, v): SKIP_SOFTMAX_DEFAULT = { "sparse_cfg": { "*attn*": { - "method": "flash_skip_softmax", "method": "flash_skip_softmax", "threshold": { "prefill": 1e-3, # More aggressive during prefill @@ -277,17 +266,14 @@ class SparseAttentionConfig(ModeloptBaseConfig): # Pattern-based sparse configuration (similar to quant_cfg in quantization) sparse_cfg: SparseAttentionCfgType = ModeloptField( - default={ - "*attention*": {"method": "flash_skip_softmax", "enable": True}, - "default": {"enable": False}, - }, default={ "*attention*": {"method": "flash_skip_softmax", "enable": True}, "default": {"enable": False}, }, title="Sparse attention configuration", - description="Pattern-based configuration for sparse attention. Keys are patterns to match module names, " - "values are configuration dicts with parameters like 'threshold', 'enable', and 'calibration'.", + description="Pattern-based configuration for sparse attention. Keys are patterns to match module names " + "(or 'calibration' for global calibration settings), values are configuration dicts with parameters like " + "'threshold', 'enable', etc.", validate_default=True, ) @@ -297,7 +283,6 @@ class SparseAttentionConfig(ModeloptBaseConfig): ) -class FlashSkipSoftmaxConfig(SparseAttentionConfig): class FlashSkipSoftmaxConfig(SparseAttentionConfig): """Configuration for Flash Attention-aware softmax skip sparse attention.""" @@ -311,34 +296,56 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): "br": 128, # Flash Attention block rows "bc": 128, # Flash Attention block columns "backend": "pytorch", # Only pytorch backend supported + "collect_stats": True, # Enable statistics collection "enable": True, }, "default": {"enable": False}, }, title="Flash softmax skip sparse configuration", description="Pattern-based configuration with flash_skip_softmax specific defaults. " - description="Pattern-based configuration with flash_skip_softmax specific defaults. 
" "Includes FA block sizes (br, bc) and correction factor settings.", validate_default=True, ) +# Pre-defined Sparse Attention Configuration +# Default configuration with block-wise sparsity optimized for Flash Attention +SKIP_SOFTMAX_DEFAULT = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": { + "prefill": 1e-3, # More aggressive during prefill + "decode": 1e-4, # Conservative during decode + }, + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", # Only pytorch backend supported + "collect_stats": True, + "enable": True, + }, + "default": {"enable": False}, + }, +} + + # Configuration with RULER calibration # Note: threshold field is omitted - calibration determines dynamic threshold λ = a / length # The calibrated threshold adapts to sequence length for optimal sparsity SKIP_SOFTMAX_CALIB = { "sparse_cfg": { + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 128, + "max_seqlen": 8192, + }, "*attn*": { "method": "flash_skip_softmax", "br": 128, "bc": 128, "backend": "pytorch", # Only pytorch backend supported + "collect_stats": True, "enable": True, - "calibration": { - "target_sparse_ratio": 0.3, - "samples": 12, - "max_seqlen": 1024, - }, }, "default": {"enable": False}, }, diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index eec48c136..1ca998f70 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -28,7 +28,6 @@ from .config import SparseAttentionConfig from .plugins.huggingface import register_sparse_attention_on_the_fly from .sparse_attention import SparseAttentionModule, SparseAttentionRegistry -from .sparse_attention import SparseAttentionModule, SparseAttentionRegistry def is_attn_sparsified(model: nn.Module) -> bool: @@ -108,7 +107,6 @@ def _replace_sparse_attention_modules(model: nn.Module, version=None): _replace_sparse_attention_modules(getattr(model, name), version=version) -def set_sparse_attention_by_cfg(model: nn.Module, sparse_cfg: dict): def set_sparse_attention_by_cfg(model: nn.Module, sparse_cfg: dict): """Apply sparse attention configuration to model. 
@@ -117,20 +115,17 @@ def set_sparse_attention_by_cfg(model: nn.Module, sparse_cfg: dict): Args: model: Model with sparse attention modules sparse_cfg: Sparse configuration dictionary mapping patterns to attributes - sparse_cfg: Sparse configuration dictionary mapping patterns to attributes """ sparse_cfg = sparse_cfg.copy() # Apply default first if exists if "default" in sparse_cfg: - set_sparse_attention_attribute(model, "*", sparse_cfg["default"]) set_sparse_attention_attribute(model, "*", sparse_cfg["default"]) sparse_cfg.pop("default") # Apply pattern-specific configs for pattern, cfg in sparse_cfg.items(): set_sparse_attention_attribute(model, pattern, cfg) - set_sparse_attention_attribute(model, pattern, cfg) def set_sparse_attention_attribute( @@ -146,7 +141,6 @@ def set_sparse_attention_attribute( model: Model to configure wildcard_or_filter: Pattern to match module names attribute_cfg: Attributes to apply (must include 'method') - attribute_cfg: Attributes to apply (must include 'method') """ # Filter out model-level configs that shouldn't be passed to modules module_cfg = {k: v for k, v in attribute_cfg.items() if k != "calibration"} @@ -163,12 +157,10 @@ def set_sparse_attention_attribute( matched = wildcard_or_filter(name) else: raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter)}") - raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter)}") if matched: # Apply config using the same method as TensorQuantizer module.set_from_attribute_config(module_cfg) - module.set_from_attribute_config(module_cfg) def restore_sparse_attention_model( @@ -242,7 +234,6 @@ def update_sparse_attention_metadata( module_state = { "method": module._sparse_method_instance.name, "method_config": module._method_config.copy(), - "method_config": module._method_config.copy(), } sparse_state[module_name] = module_state @@ -313,62 +304,42 @@ def enable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Cal module.enable() +def _format_threshold(info: dict) -> str: + """Format threshold info for display.""" + t = info.get("type") + if t == "dynamic": + return f"λ={info.get('scale_factor', 0):.2f}" + if t == "static": + v = info.get("value") + if isinstance(v, dict): + return f"threshold={v}" + return f"threshold={v:.2e}" if isinstance(v, float) else f"threshold={v}" + return "threshold=N/A" + + def print_sparse_attention_summary(model: nn.Module): """Print summary of sparse attention modules in the model. - Similar to mtq.print_quant_summary for API consistency. 
- Args: model: Model with sparse attention applied - - Prints: - - Total sparse attention modules - - Enabled vs disabled count - - Method distribution - - Configuration summary by module - - Example: - >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn - >>> model = sparse_attn.sparsify(model, config) - >>> sparse_attn.print_sparse_attention_summary(model) """ - sparse_modules = [] - for name, module in model.named_modules(): - if isinstance(module, SparseAttentionModule): - sparse_modules.append((name, module)) + sparse_modules = [ + (name, m) for name, m in model.named_modules() if isinstance(m, SparseAttentionModule) + ] if not sparse_modules: - print("No sparse attention modules found in model") + print("No sparse attention modules found") return - enabled_count = sum(1 for _, m in sparse_modules if m.is_enabled) - disabled_count = len(sparse_modules) - enabled_count + enabled = sum(1 for _, m in sparse_modules if m.is_enabled) + print(f"Sparse attention: {enabled}/{len(sparse_modules)} modules enabled") - # Count methods - method_counts = {} + # Group by (method, threshold) + groups: dict[tuple[str, str], int] = {} for _, module in sparse_modules: method = getattr(module, "_method", "unknown") - method_counts[method] = method_counts.get(method, 0) + 1 - - print(f"Total sparse attention modules: {len(sparse_modules)}") - print(f"Enabled: {enabled_count}") - print(f"Disabled: {disabled_count}") - - if method_counts: - print("\nMethods:") - for method, count in sorted(method_counts.items()): - print(f"{method}: {count}") - - for name, module in sparse_modules: - method = getattr(module, "_method", "unknown") - threshold = getattr(module, "_threshold", "N/A") - - # Format threshold nicely - if isinstance(threshold, dict): - threshold_str = str(threshold) - elif isinstance(threshold, float): - threshold_str = f"{threshold:.2e}" - else: - threshold_str = str(threshold) + threshold = _format_threshold(module.get_threshold_info()) + groups[(method, threshold)] = groups.get((method, threshold), 0) + 1 - print(f"Method: {method}, Threshold: {threshold_str}") + for (method, threshold), count in sorted(groups.items()): + print(f" {method}: {count} layers, {threshold}") diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index b4a971ece..458fbeb50 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -20,6 +20,7 @@ """ import math +from typing import Any import numpy as np import torch @@ -44,7 +45,7 @@ def __init__(self, method_config: dict | None = None): """ config = method_config or {} - # Extract configuration (defaults handled by Pydantic) + # Extract configuration self.threshold_config = config["threshold"] self.br = config["br"] self.bc = config["bc"] @@ -52,7 +53,6 @@ def __init__(self, method_config: dict | None = None): self.is_causal = config["is_causal"] # Optional parameters not in Pydantic config - self.enable_correction_factor = config.get("enable_correction_factor", True) self.phase = config.get("phase", None) # Calibration mode: when True, prevent threshold updates to preserve calibrator's test threshold @@ -191,18 +191,15 @@ def calc_correction_factor_and_p( element_mask = element_mask[:, :, :seq_q, :seq_k] # Step 8: Calculate sparsity statistics - # Count kept blocks (averaged across batch and heads) - kept_blocks = block_mask.sum().item() 
/ (batch_size * num_heads) - - # Total valid blocks (lower triangle only for causal attention) - # Note: Causal mask pre-applied by attention module, so block_mask naturally - # has zeros in upper triangle. We only count lower triangle for denominator. - total_blocks = ( - num_block_rows * (num_block_rows + 1) // 2 # Causal: N(N+1)/2 - if self.is_causal - else num_block_rows * num_block_cols # Non-causal: N*N - ) - sparsity = 1 - (kept_blocks / total_blocks) + # density = sum(mask) / numel(mask) * N / (N+1) for causal + if self.is_causal: + density = float(block_mask.sum() / block_mask.numel()) * ( + num_block_rows / (num_block_rows + 1) + ) + else: + density = float(block_mask.sum() / block_mask.numel()) + sparsity = 1 - density + total_blocks = num_block_rows * num_block_cols else: # decode blocked_attn, _, num_block_cols, _, padded_seq_k = self._reshape_to_blocks( attn_weights, 1, self.bc @@ -239,14 +236,14 @@ def calc_correction_factor_and_p( element_mask = element_mask.reshape(batch_size, num_heads, 1, padded_seq_k) element_mask = element_mask[:, :, :seq_q, :seq_k] - # Step 7: Calculate statistics - kept_blocks = block_mask.sum().item() / (batch_size * num_heads) + # Step 7: Calculate sparsity statistics + density = float(block_mask.sum() / block_mask.numel()) + sparsity = 1 - density total_blocks = num_block_cols - sparsity = 1 - (kept_blocks / total_blocks) # Create stats dictionary stats = { - "correction_factor": correction_factor if self.enable_correction_factor else 1.0, + "correction_factor": correction_factor, "sparsity": sparsity, "phase": phase, "total_blocks": total_blocks, @@ -301,6 +298,34 @@ def apply_sparsity( return query, key, value, sparse_scores + def get_threshold_info(self) -> dict[str, Any]: + """Get threshold information for this method. + + Returns: + Dictionary with threshold configuration and calibration info. + """ + threshold_scale_factor = getattr(self, "threshold_scale_factor", None) + + if threshold_scale_factor is not None: + # Calibrated dynamic threshold + return { + "type": "dynamic", + "scale_factor": threshold_scale_factor, + "formula": "λ / length", + "example_lengths": { + 1024: threshold_scale_factor / 1024, + 2048: threshold_scale_factor / 2048, + 4096: threshold_scale_factor / 4096, + 8192: threshold_scale_factor / 8192, + }, + } + else: + # Static threshold (single value or phase-specific dict) + return { + "type": "static", + "value": self.threshold_config, + } + @property def name(self) -> str: """Method identifier.""" diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index df7b5853b..2d79d9a40 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -18,6 +18,7 @@ import re import warnings from abc import ABC, abstractmethod +from typing import Any import torch @@ -45,6 +46,18 @@ def apply_sparsity( Tuple of (query, key, value, attention_scores) with sparsity applied """ + def get_threshold_info(self) -> dict[str, Any]: + """Get threshold information for display/debugging. + + Returns: + Dictionary with threshold information. 
Should include: + - 'type': 'static', 'dynamic', or 'none' + - 'value': threshold value (for static) + - 'scale_factor': scale factor (for dynamic) + - Other method-specific info + """ + return {"type": "none", "value": None} + @property @abstractmethod def name(self) -> str: diff --git a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py index 0e7d221f9..b6b1e809f 100644 --- a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py +++ b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py @@ -19,7 +19,7 @@ import torch -from modelopt.torch.opt.conversion import ModeloptStateManager, apply_mode +from modelopt.torch.opt.conversion import apply_mode from modelopt.torch.opt.searcher import ForwardLoop from .calibration import calibrate_sparse_attention @@ -136,7 +136,7 @@ def forward_loop(model) -> float: from transformers import AutoModelForCausalLM - model = AutoModelForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( model_path, attn_implementation="eager", # Required for sparse attention torch_dtype=torch.bfloat16, @@ -158,49 +158,20 @@ def forward_loop(model) -> float: def calibrate( model: torch.nn.Module, - config: dict[str, Any] | SparseAttentionConfig | None = None, + config: dict[str, Any] | SparseAttentionConfig, forward_loop: ForwardLoop | None = None, ) -> torch.nn.Module: """Calibrates sparse attention thresholds based on target sparsity. - This function performs calibration to find optimal thresholds that achieve - the target sparsity ratio specified in the config. - Args: model: Model with sparse attention modules - config: Sparse attention config (extracted from modelopt state if None) + config: Sparse attention configuration with calibration settings forward_loop: Optional callable that forwards calibration data through the model. If provided, uses this for calibration data. If None, will auto-generate RULER dataset for calibration. Returns: The calibrated model with optimized sparse attention thresholds. - If no calibration is configured, returns the model unchanged.
""" - # Get config from model if not provided - if config is None: - manager = ModeloptStateManager(model) - if manager.last_mode and manager.last_mode.name == "sparse_attention": - config = manager._last_config - else: - # No sparse attention applied, return model unchanged - return model - - # Extract sparse_cfg - if isinstance(config, dict): - sparse_cfg = config.get("sparse_cfg", {}) - else: - sparse_cfg = config.sparse_cfg - - # Check if calibration is configured in any sparse_cfg pattern - has_calibration = any( - isinstance(cfg, dict) and "calibration" in cfg for cfg in sparse_cfg.values() - ) - - if not has_calibration: - return model - - # Run calibration (handles stats collection internally) calibrate_sparse_attention(model, config, forward_loop=forward_loop) - return model diff --git a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py index 857e15d0f..d31a9e882 100644 --- a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py +++ b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py @@ -15,6 +15,8 @@ """Extensible sparse attention module.""" +from typing import Any + import torch import torch.nn.functional as F @@ -104,6 +106,14 @@ def set_from_attribute_config( # Initialize sparse method instance self._init_sparse_method() + # Create stats manager based on config + if self._method_config.get("collect_stats", False): + self._stats_manager = SparseAttentionStatsManager( + module_name="sparse_attention", enabled=True + ) + else: + self._stats_manager = None + def _init_sparse_method(self): """Initialize the sparse method instance.""" method_class = get_sparse_method(self._method) @@ -136,20 +146,22 @@ def get_stats(self) -> dict: return self._stats_manager.get_summary() return {} + def get_threshold_info(self) -> dict[str, Any]: + """Get threshold information from the sparse method instance. + + Returns: + Dictionary with threshold information from the sparse method. + """ + if hasattr(self, "_sparse_method_instance") and self._sparse_method_instance is not None: + return self._sparse_method_instance.get_threshold_info() + return {"type": "none", "value": None} + def _setup(self): """Setup called by DynamicModule.""" # Apply default configuration if not yet configured if not hasattr(self, "_method"): self.set_from_attribute_config(None) - # Create stats manager if stats collection is enabled - if self._method_config.get("collect_stats", False): - self._stats_manager = SparseAttentionStatsManager( - module_name="sparse_attention", enabled=True - ) - else: - self._stats_manager = None - def forward(self, *args, **kwargs): """Forward with selected sparse attention method. diff --git a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py index 9862e6de4..9fc57a0b1 100644 --- a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py +++ b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 """Statistics manager for sparse attention modules."""
 
diff --git a/setup.py b/setup.py
index dd124e10e..cbb3e5eca 100644
--- a/setup.py
+++ b/setup.py
@@ -83,6 +83,8 @@
         "torch-geometric",
         "tox>4.18",
         "tox-current-env>=0.0.12",
+        "nltk",
+        "wonderwords",
     ],
     # docs
     "dev-docs": [
diff --git a/tests/_test_utils/torch_sparsity/sparse_attention_common.py b/tests/_test_utils/torch_sparsity/sparse_attention_common.py
index 7724908b0..5ed079966 100644
--- a/tests/_test_utils/torch_sparsity/sparse_attention_common.py
+++ b/tests/_test_utils/torch_sparsity/sparse_attention_common.py
@@ -153,13 +153,15 @@ def forward_loop(model):
        with torch.no_grad():
            for batch in calib_data:
                output = model(batch)
-                assert not torch.isnan(output).any(), "NaN in output"
-                assert output is not None, "Output is None"
+                assert output is not None, f"Output is None for batch shape {batch.shape}"
+                assert not torch.isnan(output).any(), (
+                    f"NaN detected in output for batch shape {batch.shape}"
+                )
 
     return model
 
 
-def save_restore_test(model_cls, device, sparse_config):
+def save_restore_test(model_cls, device, sparse_config, atol=1e-6):
     """Test save and restore of sparse attention state.
 
     Args:
@@ -190,6 +192,6 @@ def save_restore_test(model_cls, device, sparse_config):
     output_sparse = model_sparse(test_input)
     output_restored = model_restored(test_input)
 
-    assert torch.allclose(output_sparse, output_restored, atol=1e-6), (
+    assert torch.allclose(output_sparse, output_restored, atol=atol), (
         "Restored model output doesn't match original"
     )
diff --git a/tests/examples/llm_eval/test_llm_eval.py b/tests/examples/llm_eval/test_llm_eval.py
index 0abf78b53..88d29dedc 100644
--- a/tests/examples/llm_eval/test_llm_eval.py
+++ b/tests/examples/llm_eval/test_llm_eval.py
@@ -36,3 +36,20 @@ def test_llama_eval_fp8():
     finally:
         # Force kill llm-serve if it's still running
         subprocess.run(["pkill", "-f", "llm-serve"], check=False)
+
+
+def test_llama_eval_sparse_attention(tiny_llama_path):
+    """Test sparse attention with llm_eval integration."""
+    try:
+        # Test with default sparse attention config (no quantization)
+        run_llm_ptq_command(
+            model=tiny_llama_path,
+            quant="none",  # No quantization, only sparse attention
+            tasks="lm_eval",
+            lm_eval_tasks="hellaswag",
+            lm_eval_limit=0.05,  # Small limit for fast test
+            sparse_cfg="SKIP_SOFTMAX_DEFAULT",
+            batch=4,
+        )
+    finally:
+        subprocess.run(["pkill", "-f", "llm-serve"], check=False)
diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
index 035f2c24f..913dc24a0 100644
--- a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
+++ b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
@@ -115,9 +115,9 @@ def test_calibration_simple_model(self, simple_model):
         model = simple_model
 
         config = {
-            "method": "flash_softmax_skip",
             "sparse_cfg": {
                 "*attn*": {
+                    "method": "flash_skip_softmax",
                     "threshold": 1e-3,
                     "br": 64,
                     "bc": 64,
@@ -155,9 +155,9 @@ def test_calibration_pytorch_backend(self, simple_model):
         model = simple_model
 
         config = {
-            "method":
"flash_softmax_skip", "sparse_cfg": { "*attn*": { + "method": "flash_skip_softmax", "threshold": 1e-3, "backend": "pytorch", "enable": True, @@ -187,9 +187,9 @@ def test_simplified_calibration(self, simple_model): model = simple_model config = { - "method": "flash_softmax_skip", "sparse_cfg": { "*attn*": { + "method": "flash_skip_softmax", "threshold": 1e-3, "enable": True, "calibration": { @@ -214,9 +214,9 @@ def test_calibration_persistence(self, simple_model): model = simple_model config = { - "method": "flash_softmax_skip", "sparse_cfg": { "*attn*": { + "method": "flash_skip_softmax", "threshold": 1e-3, "enable": True, "calibration": { @@ -261,9 +261,9 @@ def test_calibrated_model_inference(self, simple_model_setup): model = simple_model_setup config = { - "method": "flash_softmax_skip", "sparse_cfg": { "*attn*": { + "method": "flash_skip_softmax", "threshold": 1e-3, "backend": "pytorch", "enable": True, @@ -297,9 +297,9 @@ def test_calibrated_vs_fixed_threshold(self, simple_model_setup): """Compare calibrated vs fixed threshold models.""" # Config with calibration config_calibrated = { - "method": "flash_softmax_skip", "sparse_cfg": { "*attn*": { + "method": "flash_skip_softmax", "threshold": 1e-3, "enable": True, "calibration": { @@ -313,9 +313,9 @@ def test_calibrated_vs_fixed_threshold(self, simple_model_setup): # Config with fixed threshold (no calibration) config_fixed = { - "method": "flash_softmax_skip", "sparse_cfg": { "*attn*": { + "method": "flash_skip_softmax", "threshold": 1e-3, "enable": True, } @@ -356,9 +356,9 @@ def test_memory_usage(self, simple_model_setup): initial_memory = torch.cuda.memory_allocated() config = { - "method": "flash_softmax_skip", "sparse_cfg": { "*attn*": { + "method": "flash_skip_softmax", "threshold": 1e-3, "enable": True, "calibration": { diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py index 716cf54fe..4558ca22b 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py @@ -20,18 +20,24 @@ pytest.importorskip("transformers") import numpy as np -import pytest from _test_utils.torch_sparsity.sparse_attention_common import ( SimpleAttentionModel, SimpleTransformerEncoder, ) +from pydantic import ValidationError from modelopt.torch.sparsity.attention_sparsity import sparsify from modelopt.torch.sparsity.attention_sparsity.calibration import ( DynamicThresholdCalibrator, RulerDatasetBuilder, ) +from modelopt.torch.sparsity.attention_sparsity.calibration.calibrate import ( + _extract_calibration_config, + calibrate_sparse_attention, + create_calibration_forward_loop, +) from modelopt.torch.sparsity.attention_sparsity.calibration.dataset import _generate_target_lengths +from modelopt.torch.sparsity.attention_sparsity.config import CalibrationConfig from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule @@ -337,18 +343,18 @@ def test_sparsify_with_calibration_requires_forward_loop(self): config = { "sparse_cfg": { + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + "max_seqlen": 1024, + }, "*attention*": { "method": "flash_skip_softmax", "threshold": 1e-3, "br": 64, "bc": 64, "enable": True, - "calibration": { - "target_sparse_ratio": 0.5, - "samples": 4, - "max_seqlen": 1024, - }, - } + }, }, } @@ -376,8 +382,6 @@ def 
test_multiple_sparse_modules(self): def test_calibration_config_validation(self): """Test CalibrationConfig validation.""" - from modelopt.torch.sparsity.attention_sparsity.config import CalibrationConfig - # Valid config config = CalibrationConfig( target_sparse_ratio=0.5, @@ -406,8 +410,6 @@ def test_calibration_config_validation(self): def test_threshold_trials_validation(self): """Test threshold_trials validation.""" - from modelopt.torch.sparsity.attention_sparsity.config import CalibrationConfig - # Valid custom threshold_trials config = CalibrationConfig( target_sparse_ratio=0.5, @@ -432,11 +434,190 @@ def test_threshold_trials_validation(self): CalibrationConfig(threshold_trials=[1e-4, 0]) # Invalid: not a list (Pydantic raises ValidationError, not ValueError) - from pydantic import ValidationError - with pytest.raises(ValidationError, match="Input should be a valid list"): CalibrationConfig(threshold_trials=1e-4) -if __name__ == "__main__": - pytest.main([__file__, "-v"]) +class TestDynamicThresholdCalibratorMethods: + """Test individual methods of DynamicThresholdCalibrator.""" + + def test_set_threshold(self): + """Test _set_threshold method.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + sparse_model = sparsify(model, config) + + # Get sparse modules + modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + assert len(modules) > 0 + + # Create calibrator and set threshold + calibrator = DynamicThresholdCalibrator(target_sparse_ratio=0.5) + calibrator._set_threshold(modules, 0.05) + + # Verify threshold was set + for module in modules: + assert module._sparse_method_instance.threshold == 0.05 + + def test_enable_disable_calibration_mode(self): + """Test _enable_calibration_mode and _disable_calibration_mode.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + sparse_model = sparsify(model, config) + + modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + + calibrator = DynamicThresholdCalibrator(target_sparse_ratio=0.5) + + # Enable calibration mode + calibrator._enable_calibration_mode(modules) + + for module in modules: + assert module._stats_manager is not None + assert module._stats_manager.enabled is True + assert module._stats_manager.calibration_mode is True + assert module._sparse_method_instance._calibration_mode is True + + # Disable calibration mode + calibrator._disable_calibration_mode(modules) + + for module in modules: + assert module._stats_manager.calibration_mode is False + assert module._sparse_method_instance._calibration_mode is False + + def test_extract_calibration_stats_no_stats(self): + """Test _extract_calibration_stats when no stats collected.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + sparse_model = sparsify(model, config) + + modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + + calibrator = DynamicThresholdCalibrator(target_sparse_ratio=0.5) + + # Extract stats without running any forward passes + stats = 
calibrator._extract_calibration_stats(modules) + + # Should return empty list + assert stats == [] + + def test_calibrator_with_single_sample(self): + """Test calibrator edge case with only one sample.""" + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=0.5, + threshold_trials=[0.001, 0.01, 0.1], + ) + + # Even with one sample, regression should work + assert calibrator.target_sparse_ratio == 0.5 + assert len(calibrator.threshold_trials) == 3 + + +class TestCalibrateFunction: + """Test calibrate_sparse_attention function.""" + + def test_calibrate_no_config(self): + """Test calibration when config has no calibration section.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + # Config without calibration + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + # Should return empty dict when no calibration config + result = calibrate_sparse_attention(model, config) + + assert result == {} + + def test_extract_calibration_config(self): + """Test _extract_calibration_config function.""" + # Config with calibration + config = { + "sparse_cfg": { + "calibration": { + "target_sparse_ratio": 0.3, + "samples": 12, + "max_seqlen": 2048, + }, + "*attn*": { + "method": "flash_skip_softmax", + }, + }, + } + + calib_config = _extract_calibration_config(config) + + assert calib_config is not None + assert calib_config.target_sparse_ratio == 0.3 + assert calib_config.samples == 12 + assert calib_config.max_seqlen == 2048 + + def test_extract_calibration_config_none(self): + """Test _extract_calibration_config when no calibration.""" + # Config without calibration + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + } + }, + } + + calib_config = _extract_calibration_config(config) + + assert calib_config is None + + def test_create_calibration_forward_loop(self): + """Test create_calibration_forward_loop function.""" + calibration_data = [ + {"input": "This is a test sample.", "length": 512}, + {"input": "Another test sample.", "length": 1024}, + ] + + forward_loop = create_calibration_forward_loop( + calibration_data=calibration_data, + tokenizer_name_or_path="gpt2", + ) + + # Should return a callable + assert callable(forward_loop) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py index 8df8fe476..1ba86c143 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -206,3 +206,100 @@ def test_restore_sparse_attention_model(self): if isinstance(module, SparseAttentionModule): assert hasattr(module, "_method") assert module._method == "flash_skip_softmax" + + +class TestSparseAttentionModuleMethods: + """Test SparseAttentionModule methods.""" + + def test_get_stats_with_stats_manager(self): + """Test get_stats() when stats manager exists and is enabled.""" + model = SimpleAttentionModel() + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "collect_stats": True, # Enable stats collection + "enable": True, + } + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Find sparse module + sparse_module = None + for module in sparse_model.modules(): + if isinstance(module, 
SparseAttentionModule): + sparse_module = module + break + + assert sparse_module is not None + assert sparse_module._stats_manager is not None + + # Get stats (should return summary) + stats = sparse_module.get_stats() + + assert isinstance(stats, dict) + assert "module" in stats + assert "total_calls" in stats + assert "average_sparsity" in stats + + def test_get_stats_without_stats_manager(self): + """Test get_stats() when stats manager is None.""" + model = SimpleAttentionModel() + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "collect_stats": False, # Disable stats collection + "enable": True, + } + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Find sparse module + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + # Stats manager should be None + assert module._stats_manager is None + + # get_stats should return empty dict + stats = module.get_stats() + assert stats == {} + break + + def test_get_threshold_info(self): + """Test get_threshold_info() method.""" + model = SimpleAttentionModel() + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.005, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Find sparse module and test threshold info + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + info = module.get_threshold_info() + + assert isinstance(info, dict) + assert "type" in info + assert info["type"] == "static" + assert info["value"] == 0.005 + break diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py b/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py new file mode 100644 index 000000000..02188e97a --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py @@ -0,0 +1,334 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for SparseAttentionStatsManager.""" + +import pytest + +pytest.importorskip("transformers") + +from modelopt.torch.sparsity.attention_sparsity.stats_manager import SparseAttentionStatsManager + + +class TestStatsManagerInitialization: + """Test stats manager initialization.""" + + def test_initialization_defaults(self): + """Test default initialization.""" + manager = SparseAttentionStatsManager(module_name="test_module") + + assert manager.module_name == "test_module" + assert manager.enabled is True + assert manager.calibration_mode is False + assert manager.aggregated_stats["total_calls"] == 0 + assert manager.aggregated_stats["total_blocks"] == 0 + assert manager.aggregated_stats["sparse_blocks"] == 0 + assert manager.per_sample_stats == [] + + def test_initialization_disabled(self): + """Test initialization with disabled stats.""" + manager = SparseAttentionStatsManager(module_name="test_module", enabled=False) + + assert manager.enabled is False + assert manager.calibration_mode is False + + def test_initialization_custom_name(self): + """Test initialization with custom module name.""" + manager = SparseAttentionStatsManager(module_name="custom.attention.module") + + assert manager.module_name == "custom.attention.module" + + +class TestStatsCollection: + """Test statistics collection functionality.""" + + def test_collect_stats_enabled(self): + """Test collecting stats when enabled.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + + manager.collect(stats) + + assert manager.aggregated_stats["total_calls"] == 1 + assert manager.aggregated_stats["total_blocks"] == 100 + assert manager.aggregated_stats["sparse_blocks"] == 50 + assert manager.aggregated_stats["phase_counts"]["prefill"] == 1 + assert manager.aggregated_stats["phase_counts"]["decode"] == 0 + + def test_collect_stats_disabled(self): + """Test that collect() is no-op when disabled.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=False) + + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + } + + manager.collect(stats) + + # Should remain at initial values + assert manager.aggregated_stats["total_calls"] == 0 + assert manager.aggregated_stats["total_blocks"] == 0 + assert manager.aggregated_stats["sparse_blocks"] == 0 + + def test_collect_multiple_calls(self): + """Test accumulation over multiple collect calls.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect multiple times + for i in range(5): + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + } + manager.collect(stats) + + assert manager.aggregated_stats["total_calls"] == 5 + assert manager.aggregated_stats["total_blocks"] == 500 + assert manager.aggregated_stats["sparse_blocks"] == 250 + assert manager.aggregated_stats["phase_counts"]["prefill"] == 5 + + def test_collect_different_phases(self): + """Test phase counting.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect prefill stats + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + + # Collect decode stats + manager.collect({"phase": "decode", "total_blocks": 10, "sparse_blocks": 5}) + + assert 
manager.aggregated_stats["phase_counts"]["prefill"] == 2 + assert manager.aggregated_stats["phase_counts"]["decode"] == 1 + assert manager.aggregated_stats["phase_counts"]["unknown"] == 0 + + +class TestCalibrationMode: + """Test calibration mode functionality.""" + + def test_calibration_mode_per_sample_collection(self): + """Test that calibration mode stores per-sample stats.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Enable calibration mode + manager.set_calibration_mode(enabled=True) + + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + + manager.collect(stats) + + # Should store in per_sample_stats + assert len(manager.per_sample_stats) == 1 + assert manager.per_sample_stats[0]["module"] == "test" + assert manager.per_sample_stats[0]["sparsity"] == 0.5 + assert manager.per_sample_stats[0]["sample_length"] == 1024 + assert manager.per_sample_stats[0]["phase"] == "prefill" + + def test_calibration_mode_off(self): + """Test that per-sample stats are not collected when calibration mode is off.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + # Calibration mode is off by default + + stats = {"sparsity": 0.5, "phase": "prefill", "total_blocks": 100, "sparse_blocks": 50} + + manager.collect(stats) + + # Should NOT store in per_sample_stats + assert len(manager.per_sample_stats) == 0 + + # But should still aggregate + assert manager.aggregated_stats["total_calls"] == 1 + + def test_set_calibration_mode_with_reset(self): + """Test set_calibration_mode with reset_history=True.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect some stats in calibration mode + manager.set_calibration_mode(enabled=True) + manager.collect( + { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + ) + assert len(manager.per_sample_stats) == 1 + + # Re-enable with reset + manager.set_calibration_mode(enabled=True, reset_history=True) + assert len(manager.per_sample_stats) == 0 # Should be cleared + + def test_set_calibration_mode_without_reset(self): + """Test set_calibration_mode with reset_history=False.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect some stats + manager.set_calibration_mode(enabled=True) + manager.collect( + { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + ) + assert len(manager.per_sample_stats) == 1 + + # Disable without reset + manager.set_calibration_mode(enabled=False, reset_history=False) + assert len(manager.per_sample_stats) == 1 # Should be preserved + + +class TestGetSummary: + """Test get_summary() functionality.""" + + def test_get_summary_with_data(self): + """Test get_summary returns correct averages.""" + manager = SparseAttentionStatsManager(module_name="test_module", enabled=True) + + # Collect stats + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 30}) + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + + summary = manager.get_summary() + + assert summary["module"] == "test_module" + assert summary["total_calls"] == 2 + # Average sparsity: (30+50) / (100+100) = 80/200 = 0.4 + assert summary["average_sparsity"] == 0.4 + assert summary["phase_distribution"]["prefill"] == 2 + + def test_get_summary_no_data(self): + """Test get_summary with no collected data.""" + 
manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + summary = manager.get_summary() + + assert summary["module"] == "test" + assert summary["total_calls"] == 0 + assert summary["average_sparsity"] == 0.0 + assert summary["phase_distribution"]["prefill"] == 0 + + def test_get_summary_zero_blocks(self): + """Test get_summary when total_blocks is zero.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect stats with zero blocks + manager.collect({"phase": "prefill", "total_blocks": 0, "sparse_blocks": 0}) + + summary = manager.get_summary() + + assert summary["average_sparsity"] == 0.0 # Should handle division by zero + + +class TestGetCalibrationStats: + """Test get_calibration_stats() functionality.""" + + def test_get_calibration_stats(self): + """Test retrieving per-sample calibration stats.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + manager.set_calibration_mode(enabled=True) + + # Collect multiple samples + for i in range(3): + manager.collect( + { + "sparsity": 0.3 + i * 0.1, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 30, + "sample_length": 1024 + i * 512, + } + ) + + calib_stats = manager.get_calibration_stats() + + assert len(calib_stats) == 3 + assert calib_stats[0]["sparsity"] == 0.3 + assert calib_stats[1]["sparsity"] == 0.4 + assert calib_stats[2]["sparsity"] == 0.5 + + def test_get_calibration_stats_empty(self): + """Test get_calibration_stats when no calibration data.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + calib_stats = manager.get_calibration_stats() + + assert calib_stats == [] + + +class TestReset: + """Test reset functionality.""" + + def test_reset(self): + """Test reset() clears all statistics.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + manager.set_calibration_mode(enabled=True) + + # Collect some stats + manager.collect( + { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + ) + manager.collect( + { + "sparsity": 0.3, + "phase": "decode", + "total_blocks": 10, + "sparse_blocks": 3, + "sample_length": 128, + } + ) + + # Verify stats exist + assert manager.aggregated_stats["total_calls"] == 2 + assert len(manager.per_sample_stats) == 2 + + # Reset + manager.reset() + + # All stats should be cleared + assert manager.aggregated_stats["total_calls"] == 0 + assert manager.aggregated_stats["total_blocks"] == 0 + assert manager.aggregated_stats["sparse_blocks"] == 0 + assert manager.per_sample_stats == [] + assert manager.aggregated_stats["phase_counts"]["prefill"] == 0 + assert manager.aggregated_stats["phase_counts"]["decode"] == 0 diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py new file mode 100644 index 000000000..ac9f46a54 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for threshold calibration functionality.""" + +import pytest + +pytest.importorskip("transformers") + +from _test_utils.torch_sparsity.sparse_attention_common import SimpleAttentionModel + +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import FlashSkipSoftmax +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +class TestFlashSkipSoftmaxThresholdInfo: + """Test FlashSkipSoftmax.get_threshold_info() method.""" + + def test_static_threshold(self): + """Test threshold info for static threshold.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": 0.001, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + info = method.get_threshold_info() + + assert info["type"] == "static" + assert info["value"] == 0.001 + + def test_phased_threshold(self): + """Test threshold info for phase-specific thresholds.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": {"prefill": 0.001, "decode": 0.0001}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + info = method.get_threshold_info() + + assert info["type"] == "static_phased" + assert "thresholds" in info + assert info["thresholds"]["prefill"] == 0.001 + assert info["thresholds"]["decode"] == 0.0001 + assert "current" in info + + def test_dynamic_calibrated_threshold(self): + """Test threshold info for calibrated dynamic threshold.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": 0.001, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Simulate calibration setting scale factor + method.threshold_scale_factor = 437.5 + + info = method.get_threshold_info() + + assert info["type"] == "dynamic" + assert info["scale_factor"] == 437.5 + assert info["formula"] == "λ / length" + assert "example_lengths" in info + assert abs(info["example_lengths"][1024] - 437.5 / 1024) < 1e-6 + assert abs(info["example_lengths"][2048] - 437.5 / 2048) < 1e-6 + + def test_threshold_info_structure(self): + """Test that threshold info has expected structure.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": 0.001, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + info = method.get_threshold_info() + + # Should always have 'type' key + assert "type" in info + assert isinstance(info, dict) + + +class TestSparseAttentionModuleThresholdInfo: + """Test SparseAttentionModule.get_threshold_info() delegation.""" + + def test_module_delegates_to_method(self): + """Test that module correctly delegates to sparse method instance.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.005, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Find sparse attention module + sparse_module = None + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + 
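+                # get_threshold_info() on the wrapper is expected to simply forward to
+                # the underlying FlashSkipSoftmax instance; the first match is enough.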
sparse_module = module + break + + assert sparse_module is not None + + # Test get_threshold_info + info = sparse_module.get_threshold_info() + + assert info["type"] == "static" + assert info["value"] == 0.005 + + def test_module_with_calibrated_threshold(self): + """Test module reports calibrated threshold correctly.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Find module and set calibrated threshold + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + module._sparse_method_instance.threshold_scale_factor = 500.0 + break + + # Get threshold info + info = module.get_threshold_info() + + assert info["type"] == "dynamic" + assert info["scale_factor"] == 500.0 + + def test_module_without_method_instance(self): + """Test get_threshold_info when sparse method instance doesn't exist.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Find module + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + # Remove sparse method instance to test fallback + delattr(module, "_sparse_method_instance") + + info = module.get_threshold_info() + + assert info["type"] == "none" + assert info["value"] is None + break + + +class TestPrintSparseAttentionSummaryIntegration: + """Test integration with print_sparse_attention_summary.""" + + def test_summary_displays_static_threshold(self, capsys): + """Test that print function displays static thresholds.""" + from modelopt.torch.sparsity.attention_sparsity.conversion import ( + print_sparse_attention_summary, + ) + + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + print_sparse_attention_summary(sparse_model) + + captured = capsys.readouterr() + assert "Static (1.00e-03)" in captured.out + assert "flash_skip_softmax" in captured.out + + def test_summary_displays_dynamic_threshold(self, capsys): + """Test that print function displays dynamic thresholds.""" + from modelopt.torch.sparsity.attention_sparsity.conversion import ( + print_sparse_attention_summary, + ) + + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Set calibrated threshold + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + module._sparse_method_instance.threshold_scale_factor = 437.5 + + print_sparse_attention_summary(sparse_model) + + captured = capsys.readouterr() + assert "Dynamic (λ=437.500000)" in captured.out + assert "flash_skip_softmax" in captured.out
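+
+
+# Informal note on the dynamic case exercised above (based on the "λ / length"
+# formula reported by get_threshold_info): a calibrated scale factor λ = 437.5
+# yields an effective threshold of about 437.5 / 1024 ≈ 0.427 at 1024 tokens and
+# 437.5 / 2048 ≈ 0.214 at 2048 tokens, i.e. the per-block skip threshold shrinks
+# proportionally with sequence length.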