398 changes: 398 additions & 0 deletions examples/llm_sparse_attention/README.md

Large diffs are not rendered by default.

368 changes: 368 additions & 0 deletions examples/llm_sparse_attention/hf_spar_attn.py
@@ -0,0 +1,368 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example script for applying sparse attention to HuggingFace models."""

import argparse
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.export import export_hf_checkpoint
from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig
from modelopt.torch.sparsity.attention_sparsity.config import (
    SKIP_SOFTMAX_CALIB,
    SKIP_SOFTMAX_DEFAULT,
)
from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule
from modelopt.torch.utils.memory_monitor import launch_memory_monitor

RAND_SEED = 1234

# You can define custom configurations or use the default
SPARSE_ATTN_CFG_CHOICES = {
    "skip_softmax": SKIP_SOFTMAX_DEFAULT,
    "skip_softmax_calib": SKIP_SOFTMAX_CALIB,
}
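
# Shape of these configs, as inferred from how they are used below
# (illustrative sketch, not the authoritative modelopt schema):
#   {
#       "method": ...,                              # read in sparsify_model()
#       "sparse_cfg": {"<module pattern>": {...}},  # per-pattern settings; "backend" is overridden from the CLI
#       "calibration": {"max_seqlen": ...},         # optional; consulted in verify_outputs()
#   }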


def print_sparsity_stats(model: nn.Module):
    """Print sparsity statistics if available."""
    module_stats = []
    for name, module in model.named_modules():
        if hasattr(module, "get_stats"):
            stats = module.get_stats()
            if stats and "average_sparsity" in stats:
                module_stats.append((name, stats["average_sparsity"]))

    if not module_stats:
        print("No sparsity statistics available")
        return

    # Check if all modules have the same sparsity
    sparsities = [s for _, s in module_stats]
    if len(set(sparsities)) == 1:
        # All identical - show summary
        print(f"Average sparsity across all {len(module_stats)} modules: {sparsities[0]:.2%}")
    else:
        # Different sparsities - show individual values
        avg_sparsity = sum(sparsities) / len(sparsities)
        print(f"Average sparsity: {avg_sparsity:.2%}")
        print("Per-module breakdown:")
        for name, sparsity in module_stats:
            print(f"  {name}: {sparsity:.2%} sparse")


def get_narrativeqa_samples(num_samples=3):
    """Load samples from the NarrativeQA dataset for testing.

    Args:
        num_samples: Number of samples to generate
    """
    # Load NarrativeQA dataset
    dataset = load_dataset("narrativeqa", split="test", streaming=True)

    samples = []
    for i, item in enumerate(dataset):
        if i >= num_samples:
            break

        # Combine document context and question
        context = item.get("document", {}).get("text", "")
        question = item.get("question", {}).get("text", "")

        if context and question:
            # Use the full context as-is
            prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
            samples.append(prompt)

    if not samples:
        raise ValueError("Could not load NarrativeQA samples")

    print(f"Loaded {len(samples)} NarrativeQA samples")
    return samples


def truncate_text(text: str, tokenizer, max_length: int):
    """Truncate text from the middle to preserve the beginning and end.

    Args:
        text: Input text to truncate
        tokenizer: Tokenizer to use for encoding
        max_length: Maximum number of tokens

    Returns:
        Truncated text that fits within max_length tokens
    """
    # First tokenize to see if truncation is needed
    tokens = tokenizer.encode(text, add_special_tokens=True)

    if len(tokens) <= max_length:
        return text

    # Need to truncate - preserve beginning and end.
    # Reserve some tokens for special tokens.
    available_tokens = max_length - 2  # Account for special tokens

    # Split tokens roughly in half for beginning and end
    begin_tokens = available_tokens // 2
    end_tokens = available_tokens - begin_tokens

    # Decode beginning and end parts
    begin_text = tokenizer.decode(tokens[:begin_tokens], skip_special_tokens=True)
    end_text = tokenizer.decode(tokens[-end_tokens:], skip_special_tokens=True)

    # Combine with ellipsis marker
    return begin_text + " [...] " + end_text
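
# Illustrative usage (not part of the original script; `tok` is any HF tokenizer):
#   truncate_text(long_prompt, tok, max_length=2048)
# keeps roughly 1023 tokens from each end of the prompt, joined by " [...] ".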


def verify_outputs(model, tokenizer, args):
    """Compare outputs between baseline and sparse attention models."""
    # Update seq_len to match calibration max_seqlen if calibration was used
    base_config = SPARSE_ATTN_CFG_CHOICES.get(args.sparse_attn, {})
    if "calibration" in base_config and "max_seqlen" in base_config["calibration"]:
        calib_max_seqlen = base_config["calibration"]["max_seqlen"]
        if args.seq_len != calib_max_seqlen:
            print(
                f"\nNote: Updating test seq_len from {args.seq_len} to {calib_max_seqlen} "
                f"to match calibration config"
            )
            args.seq_len = calib_max_seqlen

    # Load and prepare a single test prompt
    print(f"\nLoading test sample (will be tokenized up to {args.seq_len} tokens)")
    prompts = get_narrativeqa_samples(num_samples=1)
    prompt = prompts[0]

    # Prepare inputs
    truncated_prompt = truncate_text(prompt, tokenizer, args.seq_len)
    display_prompt = (
        truncated_prompt[:150] + "..." if len(truncated_prompt) > 150 else truncated_prompt
    )

    inputs = tokenizer(
        truncated_prompt,
        return_tensors="pt",
        max_length=args.seq_len,
        truncation=True,
        padding=False,
    )
    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}

    print("\n" + "=" * 60)
    print("BASELINE vs SPARSE ATTENTION COMPARISON")
    print("=" * 60)
    print(f"\nTest prompt: {display_prompt}")
    print(f"Input tokens: {inputs['input_ids'].shape[1]} (max: {args.seq_len})")
    if "[...]" in truncated_prompt:
        print("Note: Text was middle-truncated to fit token limit")

    # Helper function to generate text
    def generate_text(model, inputs, args, tokenizer):
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=args.do_sample,
                temperature=args.temperature if args.do_sample else 1.0,
                pad_token_id=tokenizer.pad_token_id,
            )
        input_length = inputs["input_ids"].shape[1]
        generated_ids = outputs[0][input_length:]
        return tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Find all sparse attention modules
    sparse_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)]

    # Generate baseline by temporarily disabling sparse attention
    print("\n" + "-" * 60)
    print("Generating baseline (sparse attention disabled)...")
    for module in sparse_modules:
        module.disable()
    baseline_text = generate_text(model, inputs, args, tokenizer)

    # Generate with sparse attention enabled
    print("\nGenerating with sparse attention (calibrated thresholds)...")
    for module in sparse_modules:
        module.enable()
    sparse_text = generate_text(model, inputs, args, tokenizer)

    # Display comparison
    print("\n" + "-" * 60)
    print("RESULTS:")
    baseline_display = baseline_text[:300] + "..." if len(baseline_text) > 300 else baseline_text
    sparse_display = sparse_text[:300] + "..." if len(sparse_text) > 300 else sparse_text

    print(f"\nBaseline: {baseline_display}")
    print(f"With Sparse: {sparse_display}")

    if baseline_text == sparse_text:
        print("\nOutputs are identical")
    else:
        print("\nOutputs differ")


def sparsify_model(model, args):
    """Apply sparse attention to the model with optional calibration."""
    print(f"\nApplying sparse attention: {args.sparse_attn} with backend: {args.backend}")
    base_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn]

    # Create modified config with the selected backend
    modified_sparse_cfg = {}
    for pattern, cfg in base_config["sparse_cfg"].items():
        modified_cfg = cfg.copy()
        modified_cfg["backend"] = args.backend
        modified_sparse_cfg[pattern] = modified_cfg

    # Create new config with modified settings
    sparse_config = SparseAttentionConfig(
        method=base_config["method"],
        sparse_cfg=modified_sparse_cfg,
        collect_stats=True,  # Enable stats collection for monitoring
    )

    # Sparsify with optional calibration - the framework handles calibration automatically
    model = mtsa.sparsify(model, config=sparse_config)

    print("Sparse attention applied successfully!")

    # Show sparsity statistics
    print("\n" + "=" * 60)
    print("Sparsity Statistics")
    print("=" * 60)
    print_sparsity_stats(model)

    return model


def main(args):
    """Main function to run the selected mode."""
    if not torch.cuda.is_available():
        raise OSError("GPU is required for inference.")

    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)
    launch_memory_monitor()

    print(f"Loading model: {args.pyt_ckpt_path}")

    # Load model and tokenizer.
    # Note: attn_implementation="eager" is required for calibration to work properly
    # (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection).
    model = AutoModelForCausalLM.from_pretrained(
        args.pyt_ckpt_path,
        attn_implementation="eager",
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)

    # Set pad token if not set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Move model to GPU if available
    if torch.cuda.is_available():
        model = model.cuda()
        print("Model moved to CUDA")

    # Apply sparse attention to the model (with calibration if configured)
    model = sparsify_model(model, args)

Comment on lines +273 to +290
⚠️ Potential issue | 🟠 Major

Call model.eval() before running inference

The model stays in training mode after loading (dropout active, batch-norm running stats updating, etc.), so the sparse-attention calibration and the generation comparison run with stochastic behavior. That breaks verification and produces inconsistent results. Move the model to eval mode before sparsifying/exporting.

     model = AutoModelForCausalLM.from_pretrained(
         args.pyt_ckpt_path,
         attn_implementation="eager",
         torch_dtype=torch.bfloat16,
     )
     tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)
 
     # Set pad token if not set
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
     # Move model to GPU if available
     if torch.cuda.is_available():
         model = model.cuda()
         print("Model moved to CUDA")
 
+    model.eval()
+
     # Apply sparse attention to the model (with calibration if configured)
     model = sparsify_model(model, args)
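A minimal, runnable sketch of the failure mode (standalone illustration, not from the PR):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    drop = nn.Dropout(p=0.5)
    x = torch.ones(8)

    drop.train()    # training mode: elements are randomly zeroed, survivors scaled by 1/(1-p)
    print(drop(x))  # stochastic output, changes from run to run

    drop.eval()     # what model.eval() switches on for every submodule
    print(drop(x))  # deterministic: returns x unchanged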

    # Verify outputs if requested (compares baseline vs calibrated sparse model)
    if args.verify_output:
        verify_outputs(model, tokenizer, args)

    # Export if requested
    if args.export_dir:
        print(f"\nExporting model to: {args.export_dir}")
        export_dir = Path(args.export_dir)
        export_dir.mkdir(parents=True, exist_ok=True)

        with torch.inference_mode():
            export_hf_checkpoint(model, export_dir=export_dir)

        tokenizer.save_pretrained(export_dir)
        print(f"Model exported successfully to: {export_dir}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)

    # Model arguments
    parser.add_argument(
        "--pyt_ckpt_path",
        type=str,
        required=True,
        help="Path to the PyTorch checkpoint to load",
    )
    parser.add_argument(
        "--sparse_attn",
        type=str,
        default="skip_softmax",
        choices=list(SPARSE_ATTN_CFG_CHOICES.keys()),
        help="Sparse attention configuration to apply.",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="pytorch",
        choices=["pytorch", "triton"],
        help="Backend to use for sparse attention computation (default: pytorch)",
Comment on lines +326 to +330
⚠️ Potential issue | 🟠 Major

Remove the unsupported triton backend option.

SparseAttentionAttributeConfig.validate_backend hard-rejects any backend other than "pytorch", so choosing "triton" causes an immediate crash. Until Triton support lands, constrain the CLI to the actually supported backend (and mirror the README update):

-    parser.add_argument(
-        "--backend",
-        type=str,
-        default="pytorch",
-        choices=["pytorch", "triton"],
-        help="Backend to use for sparse attention computation (default: pytorch)",
-    )
+    parser.add_argument(
+        "--backend",
+        type=str,
+        default="pytorch",
+        choices=["pytorch"],
+        help="Backend to use for sparse attention computation (currently only 'pytorch').",
+    )
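For reference, a minimal sketch of the kind of guard the comment describes; the class name and the use of pydantic here are assumptions for illustration, not the actual modelopt code:

    from pydantic import BaseModel, field_validator

    class SparseAttnConfigSketch(BaseModel):
        backend: str = "pytorch"

        @field_validator("backend")
        @classmethod
        def validate_backend(cls, v: str) -> str:
            # Hard-reject anything but the supported PyTorch backend,
            # mirroring the behavior described above.
            if v != "pytorch":
                raise ValueError(f"Unsupported backend {v!r}; only 'pytorch' is supported")
            return v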

    )

    # Sequence length arguments
    parser.add_argument(
        "--seq_len",
        type=int,
        default=2048,
        help="Maximum sequence length for input prompts (longer inputs are truncated)",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=3,
        help="Number of samples to use from the NarrativeQA dataset",
    )

    # Generation arguments
    parser.add_argument(
        "--max_new_tokens", type=int, default=50, help="Maximum new tokens to generate"
    )
    parser.add_argument("--do_sample", action="store_true", help="Use sampling for generation")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling")

    # Operation arguments
    parser.add_argument(
        "--verify_output",
        action="store_true",
        help="Verify that sparse attention outputs match the baseline",
    )
    parser.add_argument(
        "--export_dir",
        type=str,
        default=None,
        help="Directory to export the model with sparse attention applied",
    )

    args = parser.parse_args()
    main(args)
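
An illustrative invocation of the script (model path and flag values are placeholders, not from the PR):

    python hf_spar_attn.py --pyt_ckpt_path <hf-model-id-or-local-path> \
        --sparse_attn skip_softmax --backend pytorch --verify_output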
4 changes: 4 additions & 0 deletions examples/llm_sparse_attention/requirements.txt
@@ -0,0 +1,4 @@
accelerate
datasets
transformers
