@@ -338,6 +338,129 @@ def _export_quantized_weight(
338338 sub_module .register_buffer (quantizer_attrs .weight_scale , weight_scale )
339339
340340
def _get_sparse_attention_config(model: nn.Module) -> dict[str, Any]:
    """Extract sparse attention configuration from model for export.

    Args:
        model: Model with sparse attention modules.

    Returns:
        Dictionary with sparse attention config in format:
        {
            "config_groups": {
                "group_0": {
                    "sparse_algo": "softmax_skip",
                    "threshold": 1e-4,  # only if not calibrated
                    "targets": ["LlamaAttention"]
                }
            },
            "threshold_scale_factor": 0.001234,  # global, if calibrated
            "target_sparsity": 0.5,  # global, if calibrated
            "producer": {"name": "modelopt", "version": "..."}
        }

        An empty dict if the model has no enabled sparse attention modules.
    """
    from modelopt import __version__
    from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule

    # Only enabled sparse attention modules participate in the exported config.
    sparse_modules = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, SparseAttentionModule) and module.is_enabled
    ]
    if not sparse_modules:
        return {}

    sparse_config: dict[str, Any] = {
        "config_groups": {},
        "producer": {
            "name": "modelopt",
            "version": __version__,
        },
    }

    # All modules share the same calibration parameters, so the first module is
    # representative for the global (model-wide) calibration values.
    first_method = sparse_modules[0][1]._sparse_method_instance
    threshold_scale_factor = getattr(first_method, "threshold_scale_factor", None)
    calibrated = threshold_scale_factor is not None

    if calibrated:
        # Model was calibrated: emit global calibration parameters.
        sparse_config["threshold_scale_factor"] = float(threshold_scale_factor)
        target_sparsity = getattr(first_method, "target_sparsity", None)
        if target_sparsity is not None:
            sparse_config["target_sparsity"] = float(target_sparsity)

    # Group modules that share the same (sparse_algo, threshold) configuration.
    # Key: (sparse_algo, hashable threshold repr); value: threshold + target class names.
    config_to_targets: dict[tuple[str, Any], dict[str, Any]] = {}

    for _name, module in sparse_modules:
        method_instance = module._sparse_method_instance

        # Extract sparse algorithm name from the method name,
        # e.g. "flash_softmax_skip" -> "softmax_skip".
        sparse_algo = method_instance.name.removeprefix("flash_")

        # Report the original (pre-SparseAttentionModule-wrapping) class as the target.
        target_class_name = module.get_original_cls_by_level(level=0).__name__

        if calibrated:
            # Calibrated: thresholds are derived globally, none stored per layer.
            threshold_config = None
            threshold_repr = None
        else:
            # Not calibrated: the per-layer threshold participates in grouping.
            threshold_config = getattr(method_instance, "threshold_config", None)
            if isinstance(threshold_config, dict):
                # Convert dict to tuple for a hashable grouping key.
                threshold_repr = tuple(sorted(threshold_config.items()))
            else:
                threshold_repr = threshold_config

        group = config_to_targets.setdefault(
            (sparse_algo, threshold_repr),
            {"threshold_config": threshold_config, "targets": set()},
        )
        group["targets"].add(target_class_name)

    # Convert grouped configs to the JSON-serializable config_groups format.
    for group_idx, ((sparse_algo, _), group_data) in enumerate(config_to_targets.items()):
        group_config: dict[str, Any] = {
            "sparse_algo": sparse_algo,
            "targets": sorted(group_data["targets"]),
        }

        # Add threshold only if not calibrated (and one was actually configured).
        threshold_config = group_data["threshold_config"]
        if threshold_config is not None:
            if isinstance(threshold_config, dict):
                group_config["threshold"] = {k: float(v) for k, v in threshold_config.items()}
            else:
                group_config["threshold"] = float(threshold_config)

        sparse_config["config_groups"][f"group_{group_idx}"] = group_config

    return sparse_config
462+
463+
341464def _export_hf_checkpoint (
342465 model : nn .Module , dtype : torch .dtype | None = None
343466) -> tuple [dict [str , Any ], dict [str , Any ]]:
@@ -543,6 +666,11 @@ def export_hf_checkpoint(
543666
544667 config_data ["quantization_config" ] = hf_quant_config
545668
669+ # Add sparse attention config if model has sparse attention
670+ sparse_attention_config = _get_sparse_attention_config (model )
671+ if sparse_attention_config :
672+ config_data ["sparse_attention_config" ] = sparse_attention_config
673+
546674 with open (original_config , "w" ) as file :
547675 json .dump (config_data , file , indent = 4 )
548676
0 commit comments