Skip to content

Commit 3e5837c

Browse files
committed
add the calibration algorithm for skip softmax
Signed-off-by: Kai Xu <[email protected]>
1 parent 9fa8991 commit 3e5837c

File tree

16 files changed

+1551
-1585
lines changed

16 files changed

+1551
-1585
lines changed

examples/llm_sparse_attention/hf_spar_attn.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,17 @@ def truncate_text(text: str, tokenizer, max_length: int):
137137

138138
def verify_outputs(model, tokenizer, args):
139139
"""Compare outputs between baseline and sparse attention models."""
140+
# Update seq_len to match calibration max_seqlen if calibration was used
141+
base_config = SPARSE_ATTN_CFG_CHOICES.get(args.sparse_attn, {})
142+
if "calibration" in base_config and "max_seqlen" in base_config["calibration"]:
143+
calib_max_seqlen = base_config["calibration"]["max_seqlen"]
144+
if args.seq_len != calib_max_seqlen:
145+
print(
146+
f"\nNote: Updating test seq_len from {args.seq_len} to {calib_max_seqlen} "
147+
f"to match calibration config"
148+
)
149+
args.seq_len = calib_max_seqlen
150+
140151
# Load and prepare a single test prompt
141152
print(f"\nLoading test sample (will be tokenized up to {args.seq_len} tokens)")
142153
prompts = get_narrativeqa_samples(num_samples=1)
@@ -225,36 +236,13 @@ def sparsify_model(model, args):
225236

226237
# Create new config with modified settings
227238
sparse_config = SparseAttentionConfig(
228-
method=base_config["method"], sparse_cfg=modified_sparse_cfg
239+
method=base_config["method"],
240+
sparse_cfg=modified_sparse_cfg,
241+
collect_stats=True, # Enable stats collection for monitoring
229242
)
230243

231-
# Check if calibration is present in config
232-
has_calibration = any(
233-
"calibration" in cfg for cfg in modified_sparse_cfg.values() if isinstance(cfg, dict)
234-
)
235-
236-
if has_calibration:
237-
print("\n" + "=" * 60)
238-
print("CALIBRATION")
239-
print("=" * 60)
240-
print("Config includes calibration - running automatic threshold calibration...")
241-
242-
# Display calibration settings
243-
for cfg in modified_sparse_cfg.values():
244-
if isinstance(cfg, dict) and "calibration" in cfg:
245-
calib = cfg["calibration"]
246-
print(f" Target sparsity: {calib.get('target_sparse_ratio', 0.5)}")
247-
print(f" Samples: {calib.get('samples', 48)}")
248-
print(f" Max sequence length: {calib.get('max_seqlen', 32768)}")
249-
print(" Tokenizer: Auto-extracted from model")
250-
print(" Dataset: RULER (6 default tasks)")
251-
break
252-
253-
# Sparsify with calibration - framework will auto-generate RULER dataset
254-
model = mtsa.sparsify(model, config=sparse_config)
255-
print("\nCalibration complete! Model now uses dynamic threshold: λ = a / context_length")
256-
else:
257-
model = mtsa.sparsify(model, config=sparse_config)
244+
# Sparsify with optional calibration - framework handles calibration automatically
245+
model = mtsa.sparsify(model, config=sparse_config)
258246

259247
print("Sparse attention applied successfully!")
260248

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Calibration functions for sparse attention."""
17+
18+
import warnings
19+
from collections.abc import Callable
20+
from typing import Any
21+
22+
import torch
23+
import torch.nn as nn
24+
from transformers import AutoTokenizer
25+
26+
from ..config import CalibrationConfig
27+
from ..nn.sparse_attention import SparseAttentionModule
28+
from .calibrator import DynamicThresholdCalibrator
29+
from .dataset import RulerDatasetBuilder
30+
31+
32+
def _extract_tokenizer_from_model(model: nn.Module) -> str:
    """Return the tokenizer name or path recorded on the model's config.

    The value is read from ``model.config._name_or_path``.

    Args:
        model: Model whose config carries the tokenizer location.

    Returns:
        The tokenizer name or path stored on the config.

    Raises:
        ValueError: If the model has no ``config`` attribute, or the config
            has no non-empty ``_name_or_path``.
    """
    # Both lookups default to None so a missing config and a missing
    # attribute are reported through the same error below.
    model_config = getattr(model, "config", None)
    name_or_path = getattr(model_config, "_name_or_path", None)

    if name_or_path:
        return name_or_path

    raise ValueError("Could not load tokenizer from model.")
51+
52+
53+
def _extract_calibration_config(config: dict[str, Any]) -> CalibrationConfig | None:
    """Find the first ``calibration`` entry in ``sparse_cfg`` and validate it.

    Args:
        config: Sparse attention configuration dict. Its ``"sparse_cfg"``
            entry (default: empty) maps patterns to per-pattern settings.

    Returns:
        A ``CalibrationConfig`` built from the first per-pattern dict that
        has a ``"calibration"`` key, or ``None`` when no such entry exists
        or the entry found is empty.
    """
    # Only dict-valued pattern entries can carry a calibration section;
    # stop at the first one that does.
    for pattern_cfg in config.get("sparse_cfg", {}).values():
        if isinstance(pattern_cfg, dict) and "calibration" in pattern_cfg:
            calib_kwargs = pattern_cfg["calibration"]
            break
    else:
        calib_kwargs = None

    if not calib_kwargs:
        return None

    # Constructing the config object is what validates the settings.
    return CalibrationConfig(**calib_kwargs)
76+
77+
78+
def create_calibration_forward_loop(
    calibration_data: list[dict[str, Any]],
    tokenizer_name_or_path: str,
    batch_size: int = 1,
) -> Callable:
    """Build a callable that runs every calibration sample through a model.

    Args:
        calibration_data: Samples to forward; each one is read through its
            ``"input"`` and ``"length"`` keys.
        tokenizer_name_or_path: Identifier passed to
            ``AutoTokenizer.from_pretrained``.
        batch_size: Accepted for interface compatibility but not used;
            samples are forwarded one at a time.

    Returns:
        A function taking the model as its only argument and returning
        ``None``.
    """
    tok = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if not tok.pad_token:
        tok.pad_token = tok.eos_token

    def forward_loop(model: nn.Module) -> None:
        # Inputs are placed on the device of the model's first parameter.
        target_device = next(model.parameters()).device

        def _encode(sample: dict[str, Any]) -> dict[str, Any]:
            # Tokenize one sample, truncating to the sample's own length.
            encoded = tok(
                sample["input"],
                return_tensors="pt",
                truncation=True,
                max_length=sample["length"],
            )
            return {key: tensor.to(target_device) for key, tensor in encoded.items()}

        for sample in calibration_data:
            model_inputs = _encode(sample)
            # Outputs are discarded; the pass is run without gradients.
            with torch.no_grad():
                model(**model_inputs)

    return forward_loop
110+
111+
112+
def calibrate_sparse_attention(
    model: nn.Module,
    config: dict[str, Any],
    forward_loop: Callable | None = None,
) -> dict[str, Any]:
    """Calibrate the threshold scale factor shared by all sparse attention modules.

    Args:
        model: Model containing ``SparseAttentionModule`` instances.
        config: Sparse attention configuration dict; its calibration
            settings are read via ``_extract_calibration_config``.
        forward_loop: Callable that forwards calibration data through the
            model. If ``None``, a RULER calibration dataset is generated
            using the tokenizer path recorded on the model's config.

    Returns:
        ``{"calibration_results": {module_name: result, ...}}`` where every
        module name maps to the same calibration result, or an empty dict
        when the config has no calibration section, the model has no sparse
        attention modules, or calibration yields no ``"scale_factor"``.

    Raises:
        ValueError: If ``forward_loop`` is ``None`` and the tokenizer path
            cannot be determined from the model.
    """
    # Extract and validate calibration config
    calib_config = _extract_calibration_config(config)
    if not calib_config:
        return {}

    # Look for calibration targets first, so the dataset below is not
    # generated (and no tokenizer lookup is attempted) when there is
    # nothing to calibrate.
    sparse_modules = [
        (name, m) for name, m in model.named_modules() if isinstance(m, SparseAttentionModule)
    ]

    if not sparse_modules:
        print("No sparse attention modules found for calibration")
        return {}

    # Generate forward_loop only if the caller did not supply one
    if forward_loop is None:
        tokenizer = _extract_tokenizer_from_model(model)
        builder = RulerDatasetBuilder(
            samples=calib_config.samples,
            max_seqlen=calib_config.max_seqlen,
            tokenizer_name_or_path=tokenizer,
            num_length_bins=calib_config.num_length_bins,
            max_length_filter=int(calib_config.max_seqlen * 1.2),
        )
        calibration_data = builder.build_calibration_dataset()
        print(f"Generated {len(calibration_data)} calibration samples")
        forward_loop = create_calibration_forward_loop(calibration_data, tokenizer)

    print(f"Calibrating {len(sparse_modules)} sparse attention modules together...")

    # Run calibration
    calibrator = DynamicThresholdCalibrator(
        target_sparse_ratio=calib_config.target_sparse_ratio,
        threshold_trials=calib_config.threshold_trials,
    )
    calibration_result = calibrator.calibrate(model, forward_loop)

    if "scale_factor" not in calibration_result:
        # stacklevel=2 attributes the warning to the caller's line.
        warnings.warn("Calibration did not produce valid results", stacklevel=2)
        return {}

    # Apply the single calibrated scale factor to every module
    scale_factor = calibration_result["scale_factor"]
    print(f"\nApplying calibrated scale factor={scale_factor:.6f} to {len(sparse_modules)} modules")

    for _, module in sparse_modules:
        module._sparse_method_instance.threshold_scale_factor = scale_factor

    return {"calibration_results": {name: calibration_result for name, _ in sparse_modules}}

0 commit comments

Comments
 (0)