@@ -15,9 +15,8 @@
 
 """Module for advanced quantization algorithms."""
 
+import fnmatch
 import gc
-import hashlib
-import json
 import types
 import warnings
 from collections import defaultdict
@@ -37,10 +36,11 @@
 from modelopt.torch.utils import create_param_grad_clear_hook, print_rank_0, report_memory
 from modelopt.torch.utils.distributed import DistributedProcessGroup, is_master
 
+from . import config as mtq_config
 from . import model_calib
-from .config import FP8_DEFAULT_CFG, NVFP4_DEFAULT_CFG, QuantizeConfig, QuantizerAttributeConfig
+from .config import QuantizeConfig, QuantizerAttributeConfig
 from .conversion import set_quantizer_by_cfg
-from .nn import QuantLinearConvBase, SequentialQuantizer, TensorQuantizer
+from .nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer
 from .utils import is_quantized_linear, multi_context
 
 
@@ -82,49 +82,71 @@ def estimate_quant_compression_for_quantizer(quantizer_attr_cfg):
 
 
 class QuantRecipe(CustomHPType):
-    """A subclass of QuantizeConfig enabling auto_quantize specific configurations."""
+    """A subclass of QuantizeConfig enabling auto_quantize specific configurations.
 
-    def __init__(
-        self, quant_cfg: dict[str, Any] | None = None, quant_format_idx: int | None = None
-    ):
+    Args:
+        quant_cfg: str or dict or None. dict is used for custom quantization formats.
+        name: name for custom quantization formats. Only used if quantization format is a custom
+            format not available in :mod:`modelopt.torch.quantization.config`.
+    """
+
+    def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | None = None):
         """Initialize the QuantRecipe with the quantization configuration."""
+        name = self.get_auto_name_for_config(quant_cfg) or name
+
         if quant_cfg is None:
-            self.config = QuantizeConfig(quant_cfg={"*": {"enable": False}}, algorithm="max")
+            quant_cfg = {"quant_cfg": {"*": {"enable": False}}}
+        elif isinstance(quant_cfg, str):
+            assert hasattr(mtq_config, quant_cfg), f"Unknown quantization format {quant_cfg}"
+            quant_cfg = getattr(mtq_config, quant_cfg)
         else:
-            self.config = QuantizeConfig(**quant_cfg)
+            assert name is not None, "name must be provided for custom quantization formats"
+
+        self.config = mtq_config.QuantizeConfig(**quant_cfg)  # type: ignore[arg-type]
 
         # Disable KV Cache quantization
         # Currently KV Cache quantization is enabled for some quantization formats and disabled for others
         # This breaks the monotonicity of the quantization formats in terms of weight compression Vs accuracy
-        self.config.quant_cfg["*output_quantizer"] = QuantizerAttributeConfig(enable=False)
+        self.config.quant_cfg["*output_quantizer"] = mtq_config.QuantizerAttributeConfig(
+            enable=False
+        )
 
         self.compression = estimate_quant_compression(self.config)
 
-        self.str_repr = (
-            f"quantization_formats[{quant_format_idx}]:effective-bits-{self.compression * 16}"
-        )
+        self._str_repr: str = f"{name} (effective-bits: {self.compression * 16})"
+
+    @staticmethod
+    def get_auto_name_for_config(quant_cfg: str | dict[str, Any] | None) -> str | None:
+        """Get a name for the quantization configuration."""
+        if quant_cfg is None:
+            return "NONE"
+        if isinstance(quant_cfg, str):
+            return quant_cfg
+        for quant_cfg_name in mtq_config.choices:
+            if quant_cfg == getattr(mtq_config, quant_cfg_name):
+                return quant_cfg_name
+        return None
 
     @property
     def num_bits(self) -> int:
         """Get the number of bits for the quantization format."""
         return int(self.compression * 16)
 
     def __str__(self) -> str:
-        return f"{self.str_repr}"
+        return self._str_repr
 
     def __repr__(self) -> str:
-        return f"{self.config}"
+        return self._str_repr
 
     def __lt__(self, other: "QuantRecipe"):
         return self.compression < other.compression
 
     def __eq__(self, other: object):
         assert isinstance(other, QuantRecipe)
-        return self.config == other.config
+        return self._str_repr == other._str_repr
 
     def __hash__(self) -> int:
-        sorted_json = json.dumps(json.loads(self.config.model_dump_json()), sort_keys=True)
-        return int(hashlib.md5(sorted_json.encode("utf-8"), usedforsecurity=False).hexdigest(), 16)
+        return hash(self._str_repr)
 
     @staticmethod
     def disable_folding_pqs_to_weights():
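With this rework, a QuantRecipe can be built from the name of a predefined format in modelopt.torch.quantization.config, from a custom config dict plus an explicit name, or from None for the no-quantization recipe. A minimal usage sketch; the import path and the printed effective-bits value are assumptions for illustration, not part of this diff:

```python
# Illustrative sketch of the new QuantRecipe constructor; import path assumed.
from modelopt.torch.quantization.algorithms import QuantRecipe

# Predefined format referenced by its name in modelopt.torch.quantization.config.
fp8 = QuantRecipe(quant_cfg="FP8_DEFAULT_CFG")
print(fp8)  # something like "FP8_DEFAULT_CFG (effective-bits: 8.0)"

# Custom format: a config dict must be accompanied by an explicit name.
custom_cfg = {"quant_cfg": {"*weight_quantizer": {"num_bits": 8}}}
custom = QuantRecipe(quant_cfg=custom_cfg, name="MY_INT8_CFG")

# No quantization: get_auto_name_for_config(None) yields the name "NONE".
noquant = QuantRecipe(quant_cfg=None)
assert noquant.num_bits == 16  # no compression
```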
@@ -154,13 +176,12 @@ class QuantRecipeHparam(Hparam):
 
     def __init__(
         self,
-        choices: Sequence[QuantRecipe],
-        original: QuantRecipe | None = None,
+        choices: Sequence[QuantRecipe] | None = None,
         nn_modules: list[nn.Module] | None = None,
     ) -> None:
         """Initializes Hparam with original value and choices."""
-        choices = sorted(set(choices) | {QuantRecipe(quant_cfg=None)})
-        super().__init__(choices, original)
+        choices = sorted({*(choices if choices else []), QuantRecipe(quant_cfg=None)})
+        super().__init__(choices, original=choices[0])
         self.nn_modules = nn_modules if nn_modules else []
 
         # This is a hack; We dont want to make the input_quantizer, weight_quantizer, output_quantizer
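QuantRecipeHparam now tolerates `choices=None`, always folds the no-quantization recipe into the choice set, and picks `original=choices[0]` itself instead of taking it as an argument. A rough sketch of the resulting behavior, assuming the classes defined in this diff:

```python
# Hypothetical check of the new defaults (names taken from the classes in this diff).
pinned = QuantRecipeHparam(None, nn_modules=[])
# With no choices given, the only candidate is the no-quantization recipe.
assert list(pinned.choices) == [QuantRecipe(quant_cfg=None)]

searchable = QuantRecipeHparam([QuantRecipe("FP8_DEFAULT_CFG")], nn_modules=[])
# The no-quantization recipe is always added alongside the requested formats.
assert QuantRecipe(quant_cfg=None) in searchable.choices
```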
@@ -253,14 +274,15 @@ class AutoQuantizeSearcher(BaseSearcher):
     def default_search_config(self):
         """Get the default config for the searcher."""
         return {
-            "quantization_formats": [NVFP4_DEFAULT_CFG, FP8_DEFAULT_CFG],
+            "quantization_formats": ["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"],
             "data_loader": None,
             "forward_step": None,
             "loss_func": None,
             "forward_backward_step": None,
             "num_calib_steps": 512,
             "num_score_steps": 128,
             "deployment": None,
+            "disabled_layers": None,
             "verbose": is_master(),
             "checkpoint": None,
         }
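Note that `quantization_formats` now defaults to format names rather than imported config dicts, and the new `disabled_layers` entry accepts glob patterns for layers that must stay unquantized. A hedged example of a user-supplied config in this shape; the pattern values are illustrative only:

```python
# Example search config following the keys above; values are illustrative.
search_config = {
    "quantization_formats": ["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"],  # names, not config dicts
    "disabled_layers": ["*lm_head*", "*embed*"],  # fnmatch-style patterns, see insert_hparams_after_merge_rules
    "num_calib_steps": 512,
    "num_score_steps": 128,
    "verbose": True,
}
```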
@@ -271,6 +293,7 @@ def default_state_dict(self) -> SearchStateDict:
         return {
             "candidate_stats": defaultdict(dict),
             "best": {"recipe": {}, "constraints": {}, "score": float("inf"), "is_satisfied": False},
+            "constraints": {},
         }
 
     def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
@@ -297,15 +320,19 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
 
     @staticmethod
     def _is_auto_quantize_module(module):
-        return is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
+        return (
+            is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
+        ) and isinstance(module, QuantModule)
 
     @staticmethod
     def _get_search_recipes(quantization_formats):
         return sorted(
-            [
-                QuantRecipe(quant_cfg=q, quant_format_idx=i)
-                for i, q in enumerate(quantization_formats)
-            ]
+            {
+                QuantRecipe(quant_cfg=q[0], name=q[1])
+                if isinstance(q, tuple)
+                else QuantRecipe(quant_cfg=q)
+                for q in quantization_formats
+            }
         )
 
     @classmethod
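As `_get_search_recipes` now shows, each entry of `quantization_formats` may be a format name (or config dict) on its own, or a `(custom_cfg, name)` tuple for custom formats. A small sketch of the mapping; the custom config here is a made-up example:

```python
# Each entry is turned into a QuantRecipe; tuples carry a custom config plus its display name.
quantization_formats = [
    "FP8_DEFAULT_CFG",  # predefined format by name
    ({"quant_cfg": {"*weight_quantizer": {"num_bits": 4}}}, "MY_W4_CFG"),  # custom cfg + name
]
recipes = AutoQuantizeSearcher._get_search_recipes(quantization_formats)
# The set comprehension plus sorted() deduplicates recipes and orders them by compression.
```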
@@ -337,7 +364,7 @@ def forward_backward_step(model, data):
     def _estimate_auto_quantize_scores(self):
         # TODO: remove the no-quant recipe
         def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
-            module.quant_recipe = QuantRecipe(quant_cfg=None, quant_format_idx=None)
+            module.quant_recipe = QuantRecipe(quant_cfg=None)
             output = module._forward_original(input, *args, **kwargs)
 
             # If gradient checkpointing is enabled, gradient will not be enabled in the global forward pass.
@@ -372,6 +399,9 @@ def backward_hook(module, grad_input, grad_output):
             del module.output_diff_dict
 
         def setup_params_for_score_estimation(name, param, params_metadata):
+            # Let us delete the gradient as soon as they are computed to save memory
+            # In addition, this method enables gradient for all parameters
+            # This is needed to make sure the re-entrant activation checkpointing works
             params_metadata[name] = {"requires_grad": param.requires_grad}
             param.requires_grad = True
             accum_grad, handle = create_param_grad_clear_hook(param)
@@ -394,15 +424,17 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
             params_metadata[name]["handle"].remove()
 
         for name, module in self.model.named_modules():
-            if self._is_auto_quantize_module(module):
+            if (
+                self._is_auto_quantize_module(module)
+                and module.get_hparam("quant_recipe").is_configurable
+            ):
                 # Monkey patch the forward methods to cache Y(Q(W), Q(X)) - Y(W,X)
                 setup_module_for_score_estimation(module)
 
         params_metadata = {}
         for name, param in self.model.named_parameters():
-            # Let us delete the gradient as soon as they are computed to save memory
-            # In addition, this method enables gradient for all parameters
-            # This is needed to make sure the re-entrant activation checkpointing works
+            # TODO: Enabling gradient for all parameters is not needed and makes backward slow
+            # We need to enable gradient only for the first parameter of the module such as embedding weights
             setup_params_for_score_estimation(name, param, params_metadata)
 
         gc.collect()
@@ -420,7 +452,10 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
         report_memory("AutoQuantize: After score estimation")
 
         for name, module in self.model.named_modules():
-            if self._is_auto_quantize_module(module):
+            if (
+                self._is_auto_quantize_module(module)
+                and module.get_hparam("quant_recipe").is_configurable
+            ):
                 cleanup_module_after_score_estimation(module)
 
         for name, param in self.model.named_parameters():
@@ -431,15 +466,29 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
         gc.collect()
 
     @classmethod
-    def insert_hparams_after_merge_rules(cls, model, quant_recipes):
+    def insert_hparams_after_merge_rules(cls, model, quant_recipes, disabled_layers=None):
         """Restrict the search space using the merge rules and insert the hparams for the model."""
         # TRTLLM fuses linear layers such as q_proj, k_proj, v_proj into same layer
         # Hence we need to restrict the search space so that all these layers share the same recipe
         # Lets group the modules based on the rules and insert the same hparam for all the modules in the group
-        search_map: dict[str, list[nn.Module]] = {}
+
+        if disabled_layers is None:
+            disabled_layers = []
+        elif isinstance(disabled_layers, str):
+            disabled_layers = [disabled_layers]
+
+        search_map: dict[str, list[tuple[nn.Module, bool]]] = {}
         for name, module in model.named_modules():
             if not cls._is_auto_quantize_module(module):
                 continue
+
+            # Skip layers that match disabled_layers patterns
+            disabled = False
+            for pattern in disabled_layers:
+                if fnmatch.fnmatch(name, pattern):
+                    disabled = True
+                    break
+
             prefix = name
             for rule in cls.rules:
                 pattern = re.compile(rule)
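The `disabled_layers` patterns are matched against fully qualified module names with `fnmatch`, i.e. shell-style wildcards rather than regular expressions. For example (the module names are illustrative):

```python
import fnmatch

# Shell-style wildcard matching against dotted module names.
fnmatch.fnmatch("model.layers.0.self_attn.q_proj", "*q_proj")    # True
fnmatch.fnmatch("lm_head", "*lm_head*")                          # True
fnmatch.fnmatch("model.layers.10.mlp.down_proj", "*layers.1.*")  # False: "10" does not match "1"
```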
@@ -449,15 +498,17 @@ def insert_hparams_after_merge_rules(cls, model, quant_recipes):
                     # We support only one rule for matching per module
                     break
             if prefix not in search_map:
-                search_map[prefix] = [module]
+                search_map[prefix] = [(module, disabled)]
             else:
-                search_map[prefix].append(module)
-
-        for prefix, modules in search_map.items():
-            hparam = QuantRecipeHparam(
-                quant_recipes,
-                original=quant_recipes[0],
-                nn_modules=modules,
+                search_map[prefix].append((module, disabled))
+
+        for prefix, module_info_list in search_map.items():
+            modules = [module for module, _ in module_info_list]
+            disabled = any(disabled for _, disabled in module_info_list)
+            hparam = (
+                QuantRecipeHparam(None, nn_modules=modules)
+                if disabled
+                else QuantRecipeHparam(quant_recipes, nn_modules=modules)
             )
             for module in modules:
                 module._register_hparam("quant_recipe", hparam)
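Because fused layers share one hparam group, a single disabled member now disables the entire group: it falls back to `QuantRecipeHparam(None, ...)`, whose only choice is the no-quantization recipe. Conceptually:

```python
# Conceptual check of the group-level disabling above (dummy names stand in for modules).
module_info_list = [("q_proj", False), ("k_proj", False), ("v_proj", True)]  # v_proj hit a disabled pattern
disabled = any(flag for _, flag in module_info_list)
assert disabled  # the whole fused q/k/v group is kept unquantized
```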
@@ -495,7 +546,9 @@ def before_search(self):
 
         search_recipes = self._get_search_recipes(self.config["quantization_formats"])
         self._verify_constraint(search_recipes)
-        self.insert_hparams_after_merge_rules(self.model, search_recipes)
+        self.insert_hparams_after_merge_rules(
+            self.model, search_recipes, self.config["disabled_layers"]
+        )
 
         QuantRecipe.disable_folding_pqs_to_weights()
 
@@ -557,18 +610,11 @@ def get_total_weight_size(modules):
                 for module in modules
             )
 
-        def _get_constraints_for_search(lower_bound=None):
-            total_model_weight_size = get_total_weight_size(self.model.modules())
-
-            upper_bound = self._get_formatted_weight_compression_constraint()
-
-            if lower_bound:
-                lower_bound = lower_bound * upper_bound
-
+        def _get_constraints_for_search(max_weight_size, lower_bound=None):
             constraints = {
                 "weight_size_after_compression": (
-                    lower_bound * total_model_weight_size if lower_bound else lower_bound,
-                    upper_bound * total_model_weight_size,
+                    lower_bound * max_weight_size if lower_bound else lower_bound,
+                    max_weight_size,
                 )
             }
             return constraints, "weight_size_after_compression"
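`_get_constraints_for_search` now receives the weight-size budget precomputed by the caller instead of deriving it internally. The budget is simply `total_weight_size * compression`, with `compression = effective_bits / 16` (consistent with how effective bits are reported later in this diff). A worked example with made-up numbers:

```python
# Worked example with illustrative numbers.
total_weight_size = 7.0e9           # total weight elements in the model
effective_bits = 6.4                # requested constraint
compression = effective_bits / 16   # 0.4
weight_size_after_compression = total_weight_size * compression  # 2.8e9

# With lower_bound = 0.99, the LP solution must land inside [0.99 * budget, budget]:
lower_bound = 0.99
constraint = (lower_bound * weight_size_after_compression, weight_size_after_compression)
print(constraint)  # approximately (2.772e9, 2.8e9)
```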
@@ -579,16 +625,20 @@ def _get_constraints_for_search(lower_bound=None):
             f"Got {self.constraints.keys()}"
         )
 
-        search_recipes = self._get_search_recipes(self.config["quantization_formats"])
-        for name, hparam in named_hparams(self.model, configurable=True):
+        compression = self._get_formatted_weight_compression_constraint()
+        total_weight_size = get_total_weight_size(self.model.modules())
+        weight_size_after_compression = total_weight_size * compression
+
+        for name, hparam in named_hparams(self.model, unique=True):
             if not isinstance(hparam, QuantRecipeHparam):
                 continue
+
             formats, scores, costs = [], [], []
             prev_score = float("inf")
-            for recipe in search_recipes:
+            for recipe in hparam.choices:
                 formats.append(recipe)
                 score = hparam.importance[recipe]
-                cost = get_total_weight_size(hparam.nn_modules) * recipe.compression
+                cost = get_total_weight_size(hparam.nn_modules) * recipe.compression  # type: ignore[union-attr]
 
                 # Lets get the score across Data Parallel (DP) and Tensor Parallel (TP) groups
                 # This way we constraint the same quantization format for the same layer across the DP/TP groups
@@ -602,6 +652,7 @@ def _get_constraints_for_search(lower_bound=None):
                 scores.append(min(score, prev_score))
                 costs.append(cost)
                 prev_score = score
+
             self.candidate_stats[name]["formats"] = formats
             self.candidate_stats[name]["scores"] = scores
             self.candidate_stats[name]["costs"] = costs
@@ -611,7 +662,9 @@ def _get_constraints_for_search(lower_bound=None):
             # specified. I dont know why this happens.
             # As a workaround, lets specify a lower bound for the weight compression if previous
             # search without lower bound fails.
-            constraints, constraint_name = _get_constraints_for_search(lower_bound)
+            constraints, constraint_name = _get_constraints_for_search(
+                weight_size_after_compression, lower_bound
+            )
 
             lps = LPS(
                 name="AutoQuantize",
@@ -664,8 +717,14 @@ def _get_constraints_for_search(lower_bound=None):
                     f"AutoQuantize best recipe for {name.replace('.quant_recipe', '')}: {best_recipe[name]}"
                 )
 
+        effective_bits_from_search = (best_constraints / total_weight_size) * 16
+        if verbose:
+            print_rank_0(
+                f"AutoQuantize effective bits from search: {effective_bits_from_search:.2f}"
+            )
+
         self.best["recipe"] = best_recipe
-        self.best["constraints"] = {constraint_name: best_constraints}
+        self.best["constraints"] = {"effective_bits": effective_bits_from_search}
         self.best["score"] = best_scores
 
         QuantRecipe.fold_pqs_to_weights(self.model)
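Reporting the constraint as effective bits rather than a raw compressed weight size makes the search result directly comparable to the requested constraint. Continuing the made-up numbers from the earlier sketch:

```python
# If the LP settles on ~2.7e9 weight units out of 7e9 (at 16 bits each when unquantized):
best_constraints = 2.7e9
total_weight_size = 7.0e9
effective_bits_from_search = (best_constraints / total_weight_size) * 16
print(f"AutoQuantize effective bits from search: {effective_bits_from_search:.2f}")  # 6.17
```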