
"""Module for advanced quantization algorithms."""

+ import fnmatch
import gc
- import hashlib
- import json
import types
import warnings
from collections import defaultdict
from modelopt.torch.utils import create_param_grad_clear_hook, print_rank_0, report_memory
from modelopt.torch.utils.distributed import DistributedProcessGroup, is_master

+ from . import config as mtq_config
from . import model_calib
- from .config import FP8_DEFAULT_CFG, NVFP4_DEFAULT_CFG, QuantizeConfig, QuantizerAttributeConfig
+ from .config import QuantizeConfig, QuantizerAttributeConfig
from .conversion import set_quantizer_by_cfg
- from .nn import QuantLinearConvBase, SequentialQuantizer, TensorQuantizer
+ from .nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer
from .utils import is_quantized_linear, multi_context


@@ -82,49 +82,71 @@ def estimate_quant_compression_for_quantizer(quantizer_attr_cfg):


class QuantRecipe(CustomHPType):
-     """A subclass of QuantizeConfig enabling auto_quantize specific configurations."""
+     """A subclass of QuantizeConfig enabling auto_quantize specific configurations.

-     def __init__(
-         self, quant_cfg: dict[str, Any] | None = None, quant_format_idx: int | None = None
-     ):
+     Args:
+         quant_cfg: str or dict or None. dict is used for custom quantization formats.
+         name: name for custom quantization formats. Only used if quantization format is a custom
+             format not available in :mod:`modelopt.torch.quantization.config`.
+     """
+
+     def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | None = None):
        """Initialize the QuantRecipe with the quantization configuration."""
+         name = self.get_auto_name_for_config(quant_cfg) or name
+
        if quant_cfg is None:
-             self.config = QuantizeConfig(quant_cfg={"*": {"enable": False}}, algorithm="max")
+             quant_cfg = {"quant_cfg": {"*": {"enable": False}}}
+         elif isinstance(quant_cfg, str):
+             assert hasattr(mtq_config, quant_cfg), f"Unknown quantization format {quant_cfg}"
+             quant_cfg = getattr(mtq_config, quant_cfg)
        else:
-             self.config = QuantizeConfig(**quant_cfg)
+             assert name is not None, "name must be provided for custom quantization formats"
+
+         self.config = mtq_config.QuantizeConfig(**quant_cfg)  # type: ignore[arg-type]

        # Disable KV Cache quantization
        # Currently KV Cache quantization is enabled for some quantization formats and disabled for others
        # This breaks the monotonicity of the quantization formats in terms of weight compression Vs accuracy
-         self.config.quant_cfg["*output_quantizer"] = QuantizerAttributeConfig(enable=False)
+         self.config.quant_cfg["*output_quantizer"] = mtq_config.QuantizerAttributeConfig(
+             enable=False
+         )

        self.compression = estimate_quant_compression(self.config)

-         self.str_repr = (
-             f"quantization_formats[{quant_format_idx}]:effective-bits-{self.compression * 16}"
-         )
+         self._str_repr: str = f"{name}(effective-bits: {self.compression * 16})"
+
+     @staticmethod
+     def get_auto_name_for_config(quant_cfg: str | dict[str, Any] | None) -> str | None:
+         """Get a name for the quantization configuration."""
+         if quant_cfg is None:
+             return "NONE"
+         if isinstance(quant_cfg, str):
+             return quant_cfg
+         for quant_cfg_name in mtq_config.choices:
+             if quant_cfg == getattr(mtq_config, quant_cfg_name):
+                 return quant_cfg_name
+         return None

    @property
    def num_bits(self) -> int:
        """Get the number of bits for the quantization format."""
        return int(self.compression * 16)

    def __str__(self) -> str:
-         return f"{self.str_repr}"
+         return self._str_repr

    def __repr__(self) -> str:
-         return f"{self.config}"
+         return self._str_repr

    def __lt__(self, other: "QuantRecipe"):
        return self.compression < other.compression

    def __eq__(self, other: object):
        assert isinstance(other, QuantRecipe)
-         return self.config == other.config
+         return self._str_repr == other._str_repr

    def __hash__(self) -> int:
-         sorted_json = json.dumps(json.loads(self.config.model_dump_json()), sort_keys=True)
-         return int(hashlib.md5(sorted_json.encode("utf-8"), usedforsecurity=False).hexdigest(), 16)
+         return hash(self._str_repr)

    @staticmethod
    def disable_folding_pqs_to_weights():
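
For illustration, a minimal usage sketch of the new constructor (not part of the diff; the custom config dict and the printed effective-bits value are assumptions):

    # Built-in formats can now be referenced by name; the string is resolved via getattr(mtq_config, ...).
    recipe = QuantRecipe(quant_cfg="FP8_DEFAULT_CFG")
    print(recipe)  # e.g. "FP8_DEFAULT_CFG(effective-bits: 8.0)"

    # Custom formats are still passed as a dict, but now require an explicit name.
    custom = QuantRecipe(
        quant_cfg={"quant_cfg": {"*weight_quantizer": {"num_bits": 8}, "default": {"enable": False}}},
        name="MY_CUSTOM_INT8",  # hypothetical name; required because the dict is not a known built-in config
    )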
@@ -154,13 +176,12 @@ class QuantRecipeHparam(Hparam):

    def __init__(
        self,
-         choices: Sequence[QuantRecipe],
-         original: QuantRecipe | None = None,
+         choices: Sequence[QuantRecipe] | None = None,
        nn_modules: list[nn.Module] | None = None,
    ) -> None:
        """Initializes Hparam with original value and choices."""
-         choices = sorted(set(choices) | {QuantRecipe(quant_cfg=None)})
-         super().__init__(choices, original)
+         choices = sorted({*(choices if choices else []), QuantRecipe(quant_cfg=None)})
+         super().__init__(choices, original=choices[0])
        self.nn_modules = nn_modules if nn_modules else []

        # This is a hack; We dont want to make the input_quantizer, weight_quantizer, output_quantizer
@@ -253,14 +274,15 @@ class AutoQuantizeSearcher(BaseSearcher):
    def default_search_config(self):
        """Get the default config for the searcher."""
        return {
-             "quantization_formats": [NVFP4_DEFAULT_CFG, FP8_DEFAULT_CFG],
+             "quantization_formats": ["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"],
            "data_loader": None,
            "forward_step": None,
            "loss_func": None,
            "forward_backward_step": None,
            "num_calib_steps": 512,
            "num_score_steps": 128,
            "deployment": None,
+             "disabled_layers": None,
            "verbose": is_master(),
            "checkpoint": None,
        }
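
For context, a hedged sketch of a user-facing search config built on these defaults (the example patterns are assumptions, not part of this diff):

    config = {
        "quantization_formats": ["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"],  # format names instead of config dicts
        "disabled_layers": ["*lm_head*"],  # fnmatch-style patterns; matching layers are kept unquantized
        "num_calib_steps": 512,
        "num_score_steps": 128,
        "verbose": True,
    }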
@@ -271,6 +293,7 @@ def default_state_dict(self) -> SearchStateDict:
        return {
            "candidate_stats": defaultdict(dict),
            "best": {"recipe": {}, "constraints": {}, "score": float("inf"), "is_satisfied": False},
+             "constraints": {},
        }

    def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
@@ -297,15 +320,19 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:

    @staticmethod
    def _is_auto_quantize_module(module):
-         return is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
+         return (
+             is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
+         ) and isinstance(module, QuantModule)

    @staticmethod
    def _get_search_recipes(quantization_formats):
        return sorted(
-             [
-                 QuantRecipe(quant_cfg=q, quant_format_idx=i)
-                 for i, q in enumerate(quantization_formats)
-             ]
+             {
+                 QuantRecipe(quant_cfg=q[0], name=q[1])
+                 if isinstance(q, tuple)
+                 else QuantRecipe(quant_cfg=q)
+                 for q in quantization_formats
+             }
        )

    @classmethod
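
The set comprehension above also accepts (config_dict, name) tuples alongside plain strings; a hedged sketch of what quantization_formats may therefore contain (MY_CUSTOM_CFG_DICT is a hypothetical placeholder):

    quantization_formats = [
        "NVFP4_DEFAULT_CFG",                    # built-in format, referenced by name
        (MY_CUSTOM_CFG_DICT, "MY_CUSTOM_FMT"),  # custom config dict plus its required name
    ]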
@@ -337,7 +364,7 @@ def forward_backward_step(model, data):
    def _estimate_auto_quantize_scores(self):
        # TODO: remove the no-quant recipe
        def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
-             module.quant_recipe = QuantRecipe(quant_cfg=None, quant_format_idx=None)
+             module.quant_recipe = QuantRecipe(quant_cfg=None)
            output = module._forward_original(input, *args, **kwargs)

            # If gradient checkpointing is enabled, gradient will not be enabled in the global forward pass.
@@ -372,6 +399,9 @@ def backward_hook(module, grad_input, grad_output):
            del module.output_diff_dict

        def setup_params_for_score_estimation(name, param, params_metadata):
+             # Let us delete the gradients as soon as they are computed to save memory
+             # In addition, this method enables gradients for all parameters
+             # This is needed to make sure re-entrant activation checkpointing works
            params_metadata[name] = {"requires_grad": param.requires_grad}
            param.requires_grad = True
            accum_grad, handle = create_param_grad_clear_hook(param)
@@ -394,15 +424,17 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
            params_metadata[name]["handle"].remove()

        for name, module in self.model.named_modules():
-             if self._is_auto_quantize_module(module):
+             if (
+                 self._is_auto_quantize_module(module)
+                 and module.get_hparam("quant_recipe").is_configurable
+             ):
                # Monkey patch the forward methods to cache Y(Q(W), Q(X)) - Y(W,X)
                setup_module_for_score_estimation(module)

        params_metadata = {}
        for name, param in self.model.named_parameters():
-             # Let us delete the gradient as soon as they are computed to save memory
-             # In addition, this method enables gradient for all parameters
-             # This is needed to make sure the re-entrant activation checkpointing works
+             # TODO: Enabling gradients for all parameters is not needed and makes the backward pass slow
+             # We need to enable gradients only for the first parameter of the module, such as embedding weights
            setup_params_for_score_estimation(name, param, params_metadata)

        gc.collect()
@@ -420,7 +452,10 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
        report_memory("AutoQuantize: After score estimation")

        for name, module in self.model.named_modules():
-             if self._is_auto_quantize_module(module):
+             if (
+                 self._is_auto_quantize_module(module)
+                 and module.get_hparam("quant_recipe").is_configurable
+             ):
                cleanup_module_after_score_estimation(module)

        for name, param in self.model.named_parameters():
@@ -431,15 +466,29 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
        gc.collect()

    @classmethod
-     def insert_hparams_after_merge_rules(cls, model, quant_recipes):
+     def insert_hparams_after_merge_rules(cls, model, quant_recipes, disabled_layers=None):
        """Restrict the search space using the merge rules and insert the hparams for the model."""
        # TRTLLM fuses linear layers such as q_proj, k_proj, v_proj into same layer
        # Hence we need to restrict the search space so that all these layers share the same recipe
        # Lets group the modules based on the rules and insert the same hparam for all the modules in the group
-         search_map: dict[str, list[nn.Module]] = {}
+
+         if disabled_layers is None:
+             disabled_layers = []
+         elif isinstance(disabled_layers, str):
+             disabled_layers = [disabled_layers]
+
+         search_map: dict[str, list[tuple[nn.Module, bool]]] = {}
        for name, module in model.named_modules():
            if not cls._is_auto_quantize_module(module):
                continue
+
+             # Skip layers that match disabled_layers patterns
+             disabled = False
+             for pattern in disabled_layers:
+                 if fnmatch.fnmatch(name, pattern):
+                     disabled = True
+                     break
+
            prefix = name
            for rule in cls.rules:
                pattern = re.compile(rule)
@@ -449,15 +498,17 @@ def insert_hparams_after_merge_rules(cls, model, quant_recipes):
                    # We support only one rule for matching per module
                    break
            if prefix not in search_map:
-                 search_map[prefix] = [module]
+                 search_map[prefix] = [(module, disabled)]
            else:
-                 search_map[prefix].append(module)
-
-         for prefix, modules in search_map.items():
-             hparam = QuantRecipeHparam(
-                 quant_recipes,
-                 original=quant_recipes[0],
-                 nn_modules=modules,
+                 search_map[prefix].append((module, disabled))
+
+         for prefix, module_info_list in search_map.items():
+             modules = [module for module, _ in module_info_list]
+             disabled = any(disabled for _, disabled in module_info_list)
+             hparam = (
+                 QuantRecipeHparam(None, nn_modules=modules)
+                 if disabled
+                 else QuantRecipeHparam(quant_recipes, nn_modules=modules)
            )
            for module in modules:
                module._register_hparam("quant_recipe", hparam)
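
A small self-contained illustration of how the disabled_layers globs behave against module names (the module names and patterns here are made up):

    import fnmatch

    patterns = ["*lm_head*", "*mlp.gate*"]
    for name in ["model.layers.0.self_attn.q_proj", "lm_head", "model.layers.3.mlp.gate"]:
        print(name, any(fnmatch.fnmatch(name, p) for p in patterns))
    # model.layers.0.self_attn.q_proj False
    # lm_head True
    # model.layers.3.mlp.gate True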
@@ -495,7 +546,9 @@ def before_search(self):

        search_recipes = self._get_search_recipes(self.config["quantization_formats"])
        self._verify_constraint(search_recipes)
-         self.insert_hparams_after_merge_rules(self.model, search_recipes)
+         self.insert_hparams_after_merge_rules(
+             self.model, search_recipes, self.config["disabled_layers"]
+         )

        QuantRecipe.disable_folding_pqs_to_weights()

@@ -557,18 +610,11 @@ def get_total_weight_size(modules):
                for module in modules
            )

-         def _get_constraints_for_search(lower_bound=None):
-             total_model_weight_size = get_total_weight_size(self.model.modules())
-
-             upper_bound = self._get_formatted_weight_compression_constraint()
-
-             if lower_bound:
-                 lower_bound = lower_bound * upper_bound
-
+         def _get_constraints_for_search(max_weight_size, lower_bound=None):
            constraints = {
                "weight_size_after_compression": (
-                     lower_bound * total_model_weight_size if lower_bound else lower_bound,
-                     upper_bound * total_model_weight_size,
+                     lower_bound * max_weight_size if lower_bound else lower_bound,
+                     max_weight_size,
                )
            }
            return constraints, "weight_size_after_compression"
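
The helper now receives the precomputed budget instead of deriving it internally. Assuming the compression constraint equals effective_bits / 16, as the effective-bits reporting later in this diff implies, the relationship is (illustrative numbers):

    total_weight_size = 7.0e9                            # e.g. ~7B weight elements
    max_weight_size = total_weight_size * (8.0 / 16)     # effective-bits constraint of 8 -> 3.5e9
    lower_bound = 0.99                                   # optional retry factor
    constraints = {"weight_size_after_compression": (lower_bound * max_weight_size, max_weight_size)}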
@@ -579,16 +625,20 @@ def _get_constraints_for_search(lower_bound=None):
            f"Got {self.constraints.keys()}"
        )

-         search_recipes = self._get_search_recipes(self.config["quantization_formats"])
-         for name, hparam in named_hparams(self.model, configurable=True):
+         compression = self._get_formatted_weight_compression_constraint()
+         total_weight_size = get_total_weight_size(self.model.modules())
+         weight_size_after_compression = total_weight_size * compression
+
+         for name, hparam in named_hparams(self.model, unique=True):
            if not isinstance(hparam, QuantRecipeHparam):
                continue
+
            formats, scores, costs = [], [], []
            prev_score = float("inf")
-             for recipe in search_recipes:
+             for recipe in hparam.choices:
                formats.append(recipe)
                score = hparam.importance[recipe]
-                 cost = get_total_weight_size(hparam.nn_modules) * recipe.compression
+                 cost = get_total_weight_size(hparam.nn_modules) * recipe.compression  # type: ignore[union-attr]

                # Lets get the score across Data Parallel (DP) and Tensor Parallel (TP) groups
                # This way we constraint the same quantization format for the same layer across the DP/TP groups
@@ -602,6 +652,7 @@ def _get_constraints_for_search(lower_bound=None):
                scores.append(min(score, prev_score))
                costs.append(cost)
                prev_score = score
+
            self.candidate_stats[name]["formats"] = formats
            self.candidate_stats[name]["scores"] = scores
            self.candidate_stats[name]["costs"] = costs
@@ -611,7 +662,9 @@ def _get_constraints_for_search(lower_bound=None):
            # specified. I dont know why this happens.
            # As a workaround, lets specify a lower bound for the weight compression if previous
            # search without lower bound fails.
-             constraints, constraint_name = _get_constraints_for_search(lower_bound)
+             constraints, constraint_name = _get_constraints_for_search(
+                 weight_size_after_compression, lower_bound
+             )

            lps = LPS(
                name="AutoQuantize",
@@ -664,8 +717,14 @@ def _get_constraints_for_search(lower_bound=None):
                    f"AutoQuantize best recipe for {name.replace('.quant_recipe', '')}: {best_recipe[name]}"
                )

+         effective_bits_from_search = (best_constraints / total_weight_size) * 16
+         if verbose:
+             print_rank_0(
+                 f"AutoQuantize effective bits from search: {effective_bits_from_search:.2f}"
+             )
+
        self.best["recipe"] = best_recipe
-         self.best["constraints"] = {constraint_name: best_constraints}
+         self.best["constraints"] = {"effective_bits": effective_bits_from_search}
        self.best["score"] = best_scores

        QuantRecipe.fold_pqs_to_weights(self.model)
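
With this change, best["constraints"] reports the effective bits recovered from the LP solution rather than a raw weight-size value; roughly what a caller might inspect afterwards (the searcher name and the values shown are illustrative):

    searcher.best["constraints"]  # e.g. {"effective_bits": 8.25}
    searcher.best["recipe"]       # mapping of "<module>.quant_recipe" -> chosen QuantRecipe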