
Commit be95a10

authored by realAsma
Bug fix for AutoQuantize: Add quantization skipping back; Add disable_layers to the search; cleaner names for search configs (#332)
Signed-off-by: realAsma <[email protected]>
1 parent 00a7e60 commit be95a10

File tree

3 files changed (+143, -113 lines)


modelopt/torch/quantization/algorithms.py

Lines changed: 121 additions & 62 deletions
@@ -15,9 +15,8 @@
 
 """Module for advanced quantization algorithms."""
 
+import fnmatch
 import gc
-import hashlib
-import json
 import types
 import warnings
 from collections import defaultdict
@@ -37,10 +36,11 @@
 from modelopt.torch.utils import create_param_grad_clear_hook, print_rank_0, report_memory
 from modelopt.torch.utils.distributed import DistributedProcessGroup, is_master
 
+from . import config as mtq_config
 from . import model_calib
-from .config import FP8_DEFAULT_CFG, NVFP4_DEFAULT_CFG, QuantizeConfig, QuantizerAttributeConfig
+from .config import QuantizeConfig, QuantizerAttributeConfig
 from .conversion import set_quantizer_by_cfg
-from .nn import QuantLinearConvBase, SequentialQuantizer, TensorQuantizer
+from .nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer
 from .utils import is_quantized_linear, multi_context
 
 
@@ -82,49 +82,71 @@ def estimate_quant_compression_for_quantizer(quantizer_attr_cfg):
 
 
 class QuantRecipe(CustomHPType):
-    """A subclass of QuantizeConfig enabling auto_quantize specific configurations."""
+    """A subclass of QuantizeConfig enabling auto_quantize specific configurations.
 
-    def __init__(
-        self, quant_cfg: dict[str, Any] | None = None, quant_format_idx: int | None = None
-    ):
+    Args:
+        quant_cfg: str or dict or None. dict is used for custom quantization formats.
+        name: name for custom quantization formats. Only used if quantization format is a custom
+            format not available in :mod:`modelopt.torch.quantization.config`.
+    """
+
+    def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | None = None):
         """Initialize the QuantRecipe with the quantization configuration."""
+        name = self.get_auto_name_for_config(quant_cfg) or name
+
         if quant_cfg is None:
-            self.config = QuantizeConfig(quant_cfg={"*": {"enable": False}}, algorithm="max")
+            quant_cfg = {"quant_cfg": {"*": {"enable": False}}}
+        elif isinstance(quant_cfg, str):
+            assert hasattr(mtq_config, quant_cfg), f"Unknown quantization format {quant_cfg}"
+            quant_cfg = getattr(mtq_config, quant_cfg)
         else:
-            self.config = QuantizeConfig(**quant_cfg)
+            assert name is not None, "name must be provided for custom quantization formats"
+
+        self.config = mtq_config.QuantizeConfig(**quant_cfg)  # type: ignore [arg-type]
 
         # Disable KV Cache quantization
         # Currently KV Cache quantization is enabled for some quantization formats and disabled for others
         # This breaks the monotonicity of the quantization formats in terms of weight compression Vs accuracy
-        self.config.quant_cfg["*output_quantizer"] = QuantizerAttributeConfig(enable=False)
+        self.config.quant_cfg["*output_quantizer"] = mtq_config.QuantizerAttributeConfig(
+            enable=False
+        )
 
         self.compression = estimate_quant_compression(self.config)
 
-        self.str_repr = (
-            f"quantization_formats[{quant_format_idx}]:effective-bits-{self.compression * 16}"
-        )
+        self._str_repr: str = f"{name}(effective-bits: {self.compression * 16})"
+
+    @staticmethod
+    def get_auto_name_for_config(quant_cfg: str | dict[str, Any] | None) -> str | None:
+        """Get a name for the quantization configuration."""
+        if quant_cfg is None:
+            return "NONE"
+        if isinstance(quant_cfg, str):
+            return quant_cfg
+        for quant_cfg_name in mtq_config.choices:
+            if quant_cfg == getattr(mtq_config, quant_cfg_name):
+                return quant_cfg_name
+        return None
 
     @property
     def num_bits(self) -> int:
         """Get the number of bits for the quantization format."""
         return int(self.compression * 16)
 
     def __str__(self) -> str:
-        return f"{self.str_repr}"
+        return self._str_repr
 
     def __repr__(self) -> str:
-        return f"{self.config}"
+        return self._str_repr
 
     def __lt__(self, other: "QuantRecipe"):
         return self.compression < other.compression
 
     def __eq__(self, other: object):
         assert isinstance(other, QuantRecipe)
-        return self.config == other.config
+        return self._str_repr == other._str_repr
 
     def __hash__(self) -> int:
-        sorted_json = json.dumps(json.loads(self.config.model_dump_json()), sort_keys=True)
-        return int(hashlib.md5(sorted_json.encode("utf-8"), usedforsecurity=False).hexdigest(), 16)
+        return hash(self._str_repr)
 
     @staticmethod
     def disable_folding_pqs_to_weights():
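
The reworked constructor accepts a format name (a string attribute of `modelopt.torch.quantization.config`), a built-in config dict (whose canonical name is auto-detected), or a custom dict plus an explicit `name`. A minimal usage sketch, not part of the commit; the custom config and the printed representation are illustrative:

```python
# Illustrative sketch of the new QuantRecipe naming behavior.
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.algorithms import QuantRecipe

# Built-in format referenced by name: the string is looked up on the config module.
r1 = QuantRecipe(quant_cfg="FP8_DEFAULT_CFG")
print(r1)  # e.g. FP8_DEFAULT_CFG(effective-bits: 8.0)

# Built-in format passed as a dict: get_auto_name_for_config recovers the canonical name.
r2 = QuantRecipe(quant_cfg=mtq.FP8_DEFAULT_CFG)

# Custom format: an explicit name is now required, since none can be inferred.
my_cfg = {"quant_cfg": {"*weight_quantizer": {"num_bits": 8}}, "algorithm": "max"}
r3 = QuantRecipe(quant_cfg=my_cfg, name="MY_INT8_WEIGHT_ONLY")
```
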
@@ -154,13 +176,12 @@ class QuantRecipeHparam(Hparam):
 
     def __init__(
         self,
-        choices: Sequence[QuantRecipe],
-        original: QuantRecipe | None = None,
+        choices: Sequence[QuantRecipe] | None = None,
         nn_modules: list[nn.Module] | None = None,
     ) -> None:
         """Initializes Hparam with original value and choices."""
-        choices = sorted(set(choices) | {QuantRecipe(quant_cfg=None)})
-        super().__init__(choices, original)
+        choices = sorted({*(choices if choices else []), QuantRecipe(quant_cfg=None)})
+        super().__init__(choices, original=choices[0])
         self.nn_modules = nn_modules if nn_modules else []
 
         # This is a hack; We dont want to make the input_quantizer, weight_quantizer, output_quantizer
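
`QuantRecipeHparam` no longer takes an explicit `original`: the choice list is de-duplicated, always augmented with the no-quantization recipe, sorted by compression, and the most compressed choice becomes the original. A small sketch of the resulting behavior (names follow the diff above; not an exact reproduction of the class internals):

```python
# Sketch of the new QuantRecipeHparam defaults.
from modelopt.torch.quantization.algorithms import QuantRecipe, QuantRecipeHparam

recipes = [QuantRecipe("FP8_DEFAULT_CFG"), QuantRecipe("NVFP4_DEFAULT_CFG")]

# Choices are de-duplicated, sorted by compression, and always include the
# no-quantization recipe; `original` now defaults to the most compressed choice.
hp = QuantRecipeHparam(recipes)
print([str(c) for c in hp.choices])

# A `None` choice list yields a single-choice, non-configurable hparam; this is
# how disabled layers are represented later in this commit.
hp_disabled = QuantRecipeHparam(None)
```
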
@@ -253,14 +274,15 @@ class AutoQuantizeSearcher(BaseSearcher):
     def default_search_config(self):
         """Get the default config for the searcher."""
         return {
-            "quantization_formats": [NVFP4_DEFAULT_CFG, FP8_DEFAULT_CFG],
+            "quantization_formats": ["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"],
             "data_loader": None,
             "forward_step": None,
             "loss_func": None,
             "forward_backward_step": None,
             "num_calib_steps": 512,
             "num_score_steps": 128,
             "deployment": None,
+            "disabled_layers": None,
             "verbose": is_master(),
             "checkpoint": None,
         }
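
The default search config now refers to the built-in formats by name and gains a `disabled_layers` entry. A hedged end-to-end sketch of how these options might surface through `mtq.auto_quantize`, assuming the search-config keys above are forwarded as keyword arguments and that `model`, `calib_loader`, and the step callables already exist:

```python
# Hedged usage sketch; keyword names mirror the search-config keys above.
import modelopt.torch.quantization as mtq

model, search_state = mtq.auto_quantize(
    model,                                   # an already-loaded torch model (assumed)
    constraints={"effective_bits": 4.8},
    # Formats can now be given by name instead of raw config dicts.
    quantization_formats=["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"],
    data_loader=calib_loader,                # assumed calibration dataloader
    forward_step=lambda model, batch: model(**batch),
    loss_func=lambda output, batch: output.loss,
    num_calib_steps=512,
    num_score_steps=128,
    # New in this commit: exclude matching layers from the quantization search.
    disabled_layers=["*lm_head*"],
)
```
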
@@ -271,6 +293,7 @@ def default_state_dict(self) -> SearchStateDict:
         return {
             "candidate_stats": defaultdict(dict),
             "best": {"recipe": {}, "constraints": {}, "score": float("inf"), "is_satisfied": False},
+            "constraints": {},
         }
 
     def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
@@ -297,15 +320,19 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
 
     @staticmethod
     def _is_auto_quantize_module(module):
-        return is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
+        return (
+            is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
+        ) and isinstance(module, QuantModule)
 
     @staticmethod
     def _get_search_recipes(quantization_formats):
         return sorted(
-            [
-                QuantRecipe(quant_cfg=q, quant_format_idx=i)
-                for i, q in enumerate(quantization_formats)
-            ]
+            {
+                QuantRecipe(quant_cfg=q[0], name=q[1])
+                if isinstance(q, tuple)
+                else QuantRecipe(quant_cfg=q)
+                for q in quantization_formats
+            }
        )
 
     @classmethod
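
`_get_search_recipes` now also accepts `(config_dict, name)` tuples for custom formats and de-duplicates recipes via a set before sorting. A sketch with an illustrative custom config (`my_w8a8` and its name are hypothetical):

```python
# Sketch of the tuple form accepted for custom formats.
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.algorithms import AutoQuantizeSearcher

my_w8a8 = {
    "quant_cfg": {"*weight_quantizer": {"num_bits": 8}, "*input_quantizer": {"num_bits": 8}},
    "algorithm": "max",
}

recipes = AutoQuantizeSearcher._get_search_recipes(
    [
        "FP8_DEFAULT_CFG",         # built-in format by name
        mtq.NVFP4_DEFAULT_CFG,     # built-in format as a dict (name auto-detected)
        (my_w8a8, "MY_W8A8"),      # custom format as a (config, name) tuple
    ]
)
print([str(r) for r in recipes])   # sorted by compression, duplicates removed
```
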
@@ -337,7 +364,7 @@ def forward_backward_step(model, data):
     def _estimate_auto_quantize_scores(self):
         # TODO: remove the no-quant recipe
         def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
-            module.quant_recipe = QuantRecipe(quant_cfg=None, quant_format_idx=None)
+            module.quant_recipe = QuantRecipe(quant_cfg=None)
             output = module._forward_original(input, *args, **kwargs)
 
             # If gradient checkpointing is enabled, gradient will not be enabled in the global forward pass.
@@ -372,6 +399,9 @@ def backward_hook(module, grad_input, grad_output):
             del module.output_diff_dict
 
         def setup_params_for_score_estimation(name, param, params_metadata):
+            # Let us delete the gradient as soon as they are computed to save memory
+            # In addition, this method enables gradient for all parameters
+            # This is needed to make sure the re-entrant activation checkpointing works
             params_metadata[name] = {"requires_grad": param.requires_grad}
             param.requires_grad = True
             accum_grad, handle = create_param_grad_clear_hook(param)
@@ -394,15 +424,17 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
             params_metadata[name]["handle"].remove()
 
         for name, module in self.model.named_modules():
-            if self._is_auto_quantize_module(module):
+            if (
+                self._is_auto_quantize_module(module)
+                and module.get_hparam("quant_recipe").is_configurable
+            ):
                 # Monkey patch the forward methods to cache Y(Q(W), Q(X)) - Y(W,X)
                 setup_module_for_score_estimation(module)
 
         params_metadata = {}
         for name, param in self.model.named_parameters():
-            # Let us delete the gradient as soon as they are computed to save memory
-            # In addition, this method enables gradient for all parameters
-            # This is needed to make sure the re-entrant activation checkpointing works
+            # TODO: Enabling gradient for all parameters is not needed and making backward slow
+            # We need to enable gradient only for the the first parameter of the module such as embedding weights
             setup_params_for_score_estimation(name, param, params_metadata)
 
         gc.collect()
@@ -420,7 +452,10 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
         report_memory("AutoQuantize: After score estimation")
 
         for name, module in self.model.named_modules():
-            if self._is_auto_quantize_module(module):
+            if (
+                self._is_auto_quantize_module(module)
+                and module.get_hparam("quant_recipe").is_configurable
+            ):
                 cleanup_module_after_score_estimation(module)
 
         for name, param in self.model.named_parameters():
@@ -431,15 +466,29 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
         gc.collect()
 
     @classmethod
-    def insert_hparams_after_merge_rules(cls, model, quant_recipes):
+    def insert_hparams_after_merge_rules(cls, model, quant_recipes, disabled_layers=None):
         """Restrict the search space using the merge rules and insert the hparams for the model."""
         # TRTLLM fuses linear layers such as q_proj, k_proj, v_proj into same layer
         # Hence we need to restrict the search space so that all these layers share the same recipe
         # Lets group the modules based on the rules and insert the same hparam for all the modules in the group
-        search_map: dict[str, list[nn.Module]] = {}
+
+        if disabled_layers is None:
+            disabled_layers = []
+        elif isinstance(disabled_layers, str):
+            disabled_layers = [disabled_layers]
+
+        search_map: dict[str, list[tuple[nn.Module, bool]]] = {}
         for name, module in model.named_modules():
             if not cls._is_auto_quantize_module(module):
                 continue
+
+            # Skip layers that match disabled_layers patterns
+            disabled = False
+            for pattern in disabled_layers:
+                if fnmatch.fnmatch(name, pattern):
+                    disabled = True
+                    break
+
             prefix = name
             for rule in cls.rules:
                 pattern = re.compile(rule)
@@ -449,15 +498,17 @@ def insert_hparams_after_merge_rules(cls, model, quant_recipes):
                 # We support only one rule for matching per module
                 break
             if prefix not in search_map:
-                search_map[prefix] = [module]
+                search_map[prefix] = [(module, disabled)]
             else:
-                search_map[prefix].append(module)
-
-        for prefix, modules in search_map.items():
-            hparam = QuantRecipeHparam(
-                quant_recipes,
-                original=quant_recipes[0],
-                nn_modules=modules,
+                search_map[prefix].append((module, disabled))
+
+        for prefix, module_info_list in search_map.items():
+            modules = [module for module, _ in module_info_list]
+            disabled = any(disabled for _, disabled in module_info_list)
+            hparam = (
+                QuantRecipeHparam(None, nn_modules=modules)
+                if disabled
+                else QuantRecipeHparam(quant_recipes, nn_modules=modules)
             )
             for module in modules:
                 module._register_hparam("quant_recipe", hparam)
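
Disabled layers are matched by shell-style wildcards against the module name; matching groups still receive a `quant_recipe` hparam, but one whose only choice is the no-quantization recipe, so they are skipped both by the LP search and by the score-estimation hooks above. A small standalone demo of the matching rule:

```python
# Standalone demo of the fnmatch-based matching used for disabled_layers.
import fnmatch

disabled_layers = ["*lm_head*", "*layers.0.*"]
for name in [
    "model.lm_head",
    "model.layers.0.self_attn.q_proj",
    "model.layers.1.mlp.down_proj",
]:
    disabled = any(fnmatch.fnmatch(name, pattern) for pattern in disabled_layers)
    print(f"{name}: {'disabled' if disabled else 'searchable'}")
# model.lm_head: disabled
# model.layers.0.self_attn.q_proj: disabled
# model.layers.1.mlp.down_proj: searchable
```
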
@@ -495,7 +546,9 @@ def before_search(self):
 
         search_recipes = self._get_search_recipes(self.config["quantization_formats"])
         self._verify_constraint(search_recipes)
-        self.insert_hparams_after_merge_rules(self.model, search_recipes)
+        self.insert_hparams_after_merge_rules(
+            self.model, search_recipes, self.config["disabled_layers"]
+        )
 
         QuantRecipe.disable_folding_pqs_to_weights()
 
@@ -557,18 +610,11 @@ def get_total_weight_size(modules):
                for module in modules
            )
 
-        def _get_constraints_for_search(lower_bound=None):
-            total_model_weight_size = get_total_weight_size(self.model.modules())
-
-            upper_bound = self._get_formatted_weight_compression_constraint()
-
-            if lower_bound:
-                lower_bound = lower_bound * upper_bound
-
+        def _get_constraints_for_search(max_weight_size, lower_bound=None):
             constraints = {
                 "weight_size_after_compression": (
-                    lower_bound * total_model_weight_size if lower_bound else lower_bound,
-                    upper_bound * total_model_weight_size,
+                    lower_bound * max_weight_size if lower_bound else lower_bound,
+                    max_weight_size,
                 )
             }
             return constraints, "weight_size_after_compression"
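
`_get_constraints_for_search` now receives the pre-computed maximum compressed weight size instead of recomputing it on every call. A worked example of the arithmetic, assuming the formatted weight-compression constraint is `effective_bits / 16` (consistent with `compression * 16` used for effective bits elsewhere in this file); all numbers are illustrative:

```python
# Worked example of the constraint arithmetic (illustrative values).
total_weight_size = 7e9        # total number of weight elements in the model
effective_bits = 4.8           # user constraint, e.g. constraints={"effective_bits": 4.8}

compression = effective_bits / 16.0                  # 0.3, relative to 16-bit weights
max_weight_size = total_weight_size * compression    # 2.1e9, upper bound for the LP search

# Optional retry knob used by the searcher if the unbounded LP fails:
lower_bound = 0.99
constraints = {
    "weight_size_after_compression": (lower_bound * max_weight_size, max_weight_size),
}
```
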
@@ -579,16 +625,20 @@ def _get_constraints_for_search(lower_bound=None):
            f"Got {self.constraints.keys()}"
        )
 
-        search_recipes = self._get_search_recipes(self.config["quantization_formats"])
-        for name, hparam in named_hparams(self.model, configurable=True):
+        compression = self._get_formatted_weight_compression_constraint()
+        total_weight_size = get_total_weight_size(self.model.modules())
+        weight_size_after_compression = total_weight_size * compression
+
+        for name, hparam in named_hparams(self.model, unique=True):
             if not isinstance(hparam, QuantRecipeHparam):
                 continue
+
             formats, scores, costs = [], [], []
             prev_score = float("inf")
-            for recipe in search_recipes:
+            for recipe in hparam.choices:
                 formats.append(recipe)
                 score = hparam.importance[recipe]
-                cost = get_total_weight_size(hparam.nn_modules) * recipe.compression
+                cost = get_total_weight_size(hparam.nn_modules) * recipe.compression  # type: ignore [union-attr]
 
                 # Lets get the score across Data Parallel (DP) and Tensor Parallel (TP) groups
                 # This way we constraint the same quantization format for the same layer across the DP/TP groups
@@ -602,6 +652,7 @@ def _get_constraints_for_search(lower_bound=None):
                scores.append(min(score, prev_score))
                costs.append(cost)
                prev_score = score
+
            self.candidate_stats[name]["formats"] = formats
            self.candidate_stats[name]["scores"] = scores
            self.candidate_stats[name]["costs"] = costs
@@ -611,7 +662,9 @@ def _get_constraints_for_search(lower_bound=None):
        # specified. I dont know why this happens.
        # As a workaround, lets specify a lower bound for the weight compression if previous
        # search without lower bound fails.
-        constraints, constraint_name = _get_constraints_for_search(lower_bound)
+        constraints, constraint_name = _get_constraints_for_search(
+            weight_size_after_compression, lower_bound
+        )
 
         lps = LPS(
             name="AutoQuantize",
@@ -664,8 +717,14 @@ def _get_constraints_for_search(lower_bound=None):
                f"AutoQuantize best recipe for {name.replace('.quant_recipe', '')}: {best_recipe[name]}"
            )
 
+        effective_bits_from_search = (best_constraints / total_weight_size) * 16
+        if verbose:
+            print_rank_0(
+                f"AutoQuantize effective bits from search: {effective_bits_from_search: .2f}"
+            )
+
         self.best["recipe"] = best_recipe
-        self.best["constraints"] = {constraint_name: best_constraints}
+        self.best["constraints"] = {"effective_bits": effective_bits_from_search}
         self.best["score"] = best_scores
 
         QuantRecipe.fold_pqs_to_weights(self.model)
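
The searcher now also reports the effective bits actually achieved by the LP solution and stores them under `best["constraints"]["effective_bits"]` instead of the raw weight-size bound. A hedged sketch of inspecting that value from the search state returned by `mtq.auto_quantize` (the returned-tuple convention is assumed from current ModelOpt behavior):

```python
# Hedged sketch: `search_state` is the second value returned by mtq.auto_quantize.
best = search_state["best"]
print(best["recipe"])                          # chosen format per layer group
print(best["constraints"]["effective_bits"])   # e.g. 4.76, achieved by the LP solution
print(best["score"], best["is_satisfied"])
```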
