
Commit 023159a

add hf unified ckpt export for sparse attention

Signed-off-by: Kai Xu <[email protected]>
Parent: 9c15dbc

File tree: 10 files changed, +1099 -11 lines

examples/llm_sparse_attention/README.md

Lines changed: 398 additions & 0 deletions
Large diffs are not rendered by default.

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+accelerate
+datasets
+transformers
+
modelopt/torch/export/unified_export_hf.py

Lines changed: 128 additions & 0 deletions

@@ -338,6 +338,129 @@ def _export_quantized_weight(
     sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale)
 
 
+def _get_sparse_attention_config(model: nn.Module) -> dict[str, Any]:
+    """Extract sparse attention configuration from model for export.
+
+    Args:
+        model: Model with sparse attention modules
+
+    Returns:
+        Dictionary with sparse attention config in format:
+        {
+            "config_groups": {
+                "group_0": {
+                    "sparse_algo": "softmax_skip",
+                    "threshold": 1e-4,  # only if not calibrated
+                    "targets": ["LlamaAttention"]
+                }
+            },
+            "threshold_scale_factor": 0.001234,  # global, if calibrated
+            "target_sparsity": 0.5,  # global, if calibrated
+            "producer": {"name": "modelopt", "version": "..."}
+        }
+    """
+    from modelopt import __version__
+    from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule
+
+    # Collect all enabled sparse attention modules
+    sparse_modules = []
+    for name, module in model.named_modules():
+        if isinstance(module, SparseAttentionModule) and module.is_enabled:
+            sparse_modules.append((name, module))
+
+    if not sparse_modules:
+        return {}
+
+    sparse_config = {
+        "config_groups": {},
+        "producer": {
+            "name": "modelopt",
+            "version": __version__,
+        },
+    }
+
+    # Check first module for global calibration parameters
+    # (all modules share the same calibration parameters)
+    first_module = sparse_modules[0][1]
+    method_instance = first_module._sparse_method_instance
+    threshold_scale_factor = getattr(method_instance, "threshold_scale_factor", None)
+
+    if threshold_scale_factor is not None:
+        # Model was calibrated: add global calibration parameters
+        sparse_config["threshold_scale_factor"] = float(threshold_scale_factor)
+
+        target_sparsity = getattr(method_instance, "target_sparsity", None)
+        if target_sparsity is not None:
+            sparse_config["target_sparsity"] = float(target_sparsity)
+
+    # Group modules by configuration
+    # Key: (sparse_algo, threshold_repr), Value: group config with a set of target class names
+    config_to_targets = {}
+
+    for name, module in sparse_modules:
+        method_instance = module._sparse_method_instance
+
+        # Extract sparse algorithm name from method name,
+        # e.g., "flash_softmax_skip" -> "softmax_skip"
+        method_name = method_instance.name
+        if method_name.startswith("flash_"):
+            sparse_algo = method_name[6:]  # Remove "flash_" prefix
+        else:
+            sparse_algo = method_name
+
+        # Get the module's original class name for targets,
+        # i.e., the class name before SparseAttentionModule wrapping
+        original_cls = module.get_original_cls_by_level(level=0)
+        target_class_name = original_cls.__name__
+
+        # Build config key for grouping
+        if threshold_scale_factor is None:
+            # Not calibrated: include threshold in grouping
+            threshold_config = getattr(method_instance, "threshold_config", None)
+            if isinstance(threshold_config, dict):
+                # Convert dict to tuple for hashable key
+                threshold_repr = tuple(sorted(threshold_config.items()))
+            else:
+                threshold_repr = threshold_config
+        else:
+            # Calibrated: no threshold in per-layer config
+            threshold_repr = None
+
+        config_key = (sparse_algo, threshold_repr)
+
+        if config_key not in config_to_targets:
+            config_to_targets[config_key] = {
+                "sparse_algo": sparse_algo,
+                "threshold_config": threshold_config if threshold_scale_factor is None else None,
+                "targets": set(),
+            }
+
+        config_to_targets[config_key]["targets"].add(target_class_name)
+
+    # Convert grouped configs to config_groups format
+    for group_idx, ((sparse_algo, threshold_repr), group_data) in enumerate(
+        config_to_targets.items()
+    ):
+        group_name = f"group_{group_idx}"
+        group_config = {
+            "sparse_algo": group_data["sparse_algo"],
+            "targets": sorted(group_data["targets"]),
+        }
+
+        # Add threshold only if not calibrated
+        if group_data["threshold_config"] is not None:
+            threshold_config = group_data["threshold_config"]
+            if isinstance(threshold_config, dict):
+                # Convert to JSON-serializable format
+                group_config["threshold"] = {k: float(v) for k, v in threshold_config.items()}
+            else:
+                group_config["threshold"] = float(threshold_config)
+
+        sparse_config["config_groups"][group_name] = group_config
+
+    return sparse_config
+
+
 def _export_hf_checkpoint(
     model: nn.Module, dtype: torch.dtype | None = None
 ) -> tuple[dict[str, Any], dict[str, Any]]:
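
For reference, here is a minimal sketch of the dictionary this helper returns for an uncalibrated Llama-style model whose attention layers share a single softmax_skip threshold (values are illustrative, following the docstring above; a calibrated model would instead drop the per-group "threshold" and carry top-level "threshold_scale_factor" and "target_sparsity" keys):

    # Illustrative output of _get_sparse_attention_config(); the version string
    # is whatever modelopt.__version__ reports at export time.
    {
        "config_groups": {
            "group_0": {
                "sparse_algo": "softmax_skip",
                "threshold": 1e-4,  # present only because no calibration ran
                "targets": ["LlamaAttention"],
            }
        },
        "producer": {"name": "modelopt", "version": "..."},
    }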
@@ -543,6 +666,11 @@ def export_hf_checkpoint(
 
     config_data["quantization_config"] = hf_quant_config
 
+    # Add sparse attention config if model has sparse attention
+    sparse_attention_config = _get_sparse_attention_config(model)
+    if sparse_attention_config:
+        config_data["sparse_attention_config"] = sparse_attention_config
+
     with open(original_config, "w") as file:
         json.dump(config_data, file, indent=4)
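
A usage sketch for the exporter side (hedged: the export_dir and dtype arguments are assumed from the _export_hf_checkpoint signature above and common ModelOpt export usage, and the model is assumed to already contain enabled sparse attention modules):

    import torch
    from modelopt.torch.export import export_hf_checkpoint

    # Writes the HF checkpoint; when enabled sparse attention modules are
    # present, the saved config.json also gains a "sparse_attention_config" entry.
    export_hf_checkpoint(model, dtype=torch.bfloat16, export_dir="sparse_attn_export")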

modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py

Lines changed: 1 addition & 0 deletions

@@ -172,5 +172,6 @@ def calibrate_sparse_attention(
 
     for module_name, module in sparse_modules:
         module._sparse_method_instance.threshold_scale_factor = scale_factor
+        module._sparse_method_instance.target_sparsity = calib_config.target_sparse_ratio
 
     return {"calibration_results": {name: calibration_result for name, _ in sparse_modules}}
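
Because both values now live on the method instance, they can be read back after calibration; a small inspection sketch (import path as in the export diff above):

    from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import (
        SparseAttentionModule,
    )

    # All calibrated modules share the same global pair of values.
    for name, module in model.named_modules():
        if isinstance(module, SparseAttentionModule) and module.is_enabled:
            inst = module._sparse_method_instance
            print(name, inst.threshold_scale_factor, inst.target_sparsity)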

modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py

Lines changed: 2 additions & 2 deletions

@@ -140,9 +140,9 @@ def calibrate(self, model: nn.Module, forward_loop: Callable) -> dict[str, Any]:
 
         print(f"Collected statistics for {len(self.sparsity_results)} samples")
 
-        # Stage 2: Find optimal threshold for each sample and compute 'a'
+        # Stage 2: Find optimal threshold for each sample and compute scale factor
         print(
-            f"\nStage 2: Finding 'a' parameter for target sparsity {self.target_sparse_ratio:.2f}"
+            f"\nStage 2: Finding threshold scale factor for target sparsity {self.target_sparse_ratio:.2f}"
         )
 
         # Find optimal threshold for each sample

modelopt/torch/sparsity/attention_sparsity/config.py

Lines changed: 3 additions & 3 deletions

@@ -279,9 +279,9 @@ def validate_num_length_bins(cls, v):
         "backend": "pytorch",  # Only pytorch backend supported
         "enable": True,
         "calibration": {
-            "target_sparse_ratio": 0.5,
-            "samples": 120,
-            "max_seqlen": 8192,
+            "target_sparse_ratio": 0.3,
+            "samples": 12,
+            "max_seqlen": 1024,
         },
     },
     "default": {"enable": False},

modelopt/torch/sparsity/attention_sparsity/conversion.py

Lines changed: 21 additions & 2 deletions

@@ -213,12 +213,23 @@ def restore_sparse_attention_state(model: nn.Module, state_dict: dict[str, Any])
         module._method = module_state["method"]
         if "method_config" in module_state:
             # Restore config attributes
+            # Separate method instance attributes from module attributes
+            method_instance_attrs = {"threshold_scale_factor", "target_sparsity"}
+
             for key, val in module_state["method_config"].items():
-                setattr(module, f"_{key}", val)
+                if key not in method_instance_attrs:
+                    # Set on module
+                    setattr(module, f"_{key}", val)
 
         # Re-setup with restored config
         module._setup()
 
+        # Restore method instance attributes after _setup
+        if "method_config" in module_state:
+            for key, val in module_state["method_config"].items():
+                if key in {"threshold_scale_factor", "target_sparsity"}:
+                    setattr(module._sparse_method_instance, key, val)
+
 
 def update_sparse_attention_metadata(
     model: nn.Module, config: SparseAttentionConfig, metadata: MetadataDict
@@ -243,8 +254,16 @@ def update_sparse_attention_metadata(
             if k.startswith("_") and k not in ("_method", "_enabled", "_sparse_method_instance")
         }
 
+        # Also collect calibration-related attributes from method instance
+        method_instance = module._sparse_method_instance
+        for attr in ["threshold_scale_factor", "target_sparsity"]:
+            if hasattr(method_instance, attr):
+                val = getattr(method_instance, attr)
+                if val is not None:
+                    method_config[attr] = val
+
         module_state = {
-            "method": module._sparse_method_instance.name,
+            "method": method_instance.name,
             "method_config": method_config,
        }
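
Taken together, a saved module_state entry for a calibrated module would look roughly like this (values illustrative; plain keys are restored onto the module as "_<key>", while the two calibration keys are set on the method instance after _setup()):

    module_state = {
        "method": "flash_softmax_skip",
        "method_config": {
            "threshold": 1e-4,                   # restored as module._threshold
            "threshold_scale_factor": 0.001234,  # restored on the method instance
            "target_sparsity": 0.5,              # restored on the method instance
        },
    }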

tests/_test_utils/torch_sparsity/sparse_attention_common.py

Lines changed: 4 additions & 4 deletions

@@ -93,13 +93,13 @@ def get_input(cls, d_model=128, seq_len=10, batch_size=2):
 # Test configurations
 FLASH_SOFTMAX_SKIP_DEFAULT_CFG = {
     "method": "flash_softmax_skip",
-    "sparse_cfg": {"*attention*": {"threshold": 1e-4, "br": 128, "bc": 128, "enable": True}},
+    "sparse_cfg": {"*attn*": {"threshold": 1e-4, "br": 128, "bc": 128, "enable": True}},
 }
 
 FLASH_SOFTMAX_SKIP_PHASE_AWARE_CFG = {
     "method": "flash_softmax_skip",
     "sparse_cfg": {
-        "*attention*": {
+        "*attn*": {
             "threshold": {"prefill": 1e-3, "decode": 1e-5},
             "br": 128,
             "bc": 128,
@@ -112,7 +112,7 @@ def get_input(cls, d_model=128, seq_len=10, batch_size=2):
     "method": "flash_softmax_skip",
     "collect_stats": True,
     "sparse_cfg": {
-        "*attention*": {
+        "*attn*": {
             "threshold": 1e-4,
             "br": 128,
             "bc": 128,
@@ -125,7 +125,7 @@ def get_input(cls, d_model=128, seq_len=10, batch_size=2):
 FLASH_SOFTMAX_SKIP_CALIBRATION_CFG = {
     "method": "flash_softmax_skip",
     "sparse_cfg": {
-        "*attention*": {
+        "*attn*": {
             "br": 128,
             "bc": 128,
             "enable": True,