
Commit d979c5d

Bug fix for AutoQuantize: Add quantization skipping back; Add disabled_layers to the search; cleaner names for search configs
Signed-off-by: realAsma <[email protected]>
1 parent d94fc1b commit d979c5d
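
This commit changes how AutoQuantize search options are specified: built-in quantization formats are now referenced by config name (or passed as a (config_dict, name) tuple for custom formats), and a new disabled_layers option excludes layers from quantization via fnmatch patterns. Below is a minimal sketch of a search config under the new scheme; the keys and defaults mirror AutoQuantizeSearcher.default_search_config in the diff, while the surrounding call site and the example pattern strings are assumptions for illustration only:

    # Hypothetical usage sketch; only the config keys come from this commit.
    search_config = {
        # Formats are referenced by name instead of by config dict.
        "quantization_formats": ["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"],
        # New in this commit: fnmatch-style patterns for layers that should stay unquantized.
        "disabled_layers": ["*lm_head*"],
        "num_calib_steps": 512,
        "num_score_steps": 128,
        "verbose": True,
    }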

3 files changed: +147 -115 lines changed


modelopt/torch/quantization/algorithms.py

Lines changed: 120 additions & 62 deletions
@@ -15,9 +15,8 @@
 
 """Module for advanced quantization algorithms."""
 
+import fnmatch
 import gc
-import hashlib
-import json
 import types
 import warnings
 from collections import defaultdict
@@ -37,10 +36,11 @@
 from modelopt.torch.utils import create_param_grad_clear_hook, print_rank_0, report_memory
 from modelopt.torch.utils.distributed import DistributedProcessGroup, is_master
 
+from . import config as mtq_config
 from . import model_calib
-from .config import FP8_DEFAULT_CFG, NVFP4_DEFAULT_CFG, QuantizeConfig, QuantizerAttributeConfig
+from .config import QuantizeConfig, QuantizerAttributeConfig
 from .conversion import set_quantizer_by_cfg
-from .nn import QuantLinearConvBase, SequentialQuantizer, TensorQuantizer
+from .nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer
 from .utils import is_quantized_linear, multi_context
 
 
@@ -82,49 +82,70 @@ def estimate_quant_compression_for_quantizer(quantizer_attr_cfg):
 
 
 class QuantRecipe(CustomHPType):
-    """A subclass of QuantizeConfig enabling auto_quantize specific configurations."""
+    """A subclass of QuantizeConfig enabling auto_quantize specific configurations.
 
-    def __init__(
-        self, quant_cfg: dict[str, Any] | None = None, quant_format_idx: int | None = None
-    ):
+    Args:
+        quant_cfg: str or dict or None. dict is used for custom quantization formats.
+        name: name for custom quantization formats. Ignored if quant_cfg is None or str.
+    """
+
+    def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | None = None):
         """Initialize the QuantRecipe with the quantization configuration."""
+        name = name or self.get_auto_name_for_config(quant_cfg)
+
         if quant_cfg is None:
-            self.config = QuantizeConfig(quant_cfg={"*": {"enable": False}}, algorithm="max")
+            quant_cfg = {"quant_cfg": {"*": {"enable": False}}}
+        elif isinstance(quant_cfg, str):
+            assert hasattr(mtq_config, quant_cfg), f"Unknown quantization format {quant_cfg}"
+            quant_cfg = getattr(mtq_config, quant_cfg)
         else:
-            self.config = QuantizeConfig(**quant_cfg)
+            assert name is not None, "name must be provided for custom quantization formats"
+
+        self.config = mtq_config.QuantizeConfig(**quant_cfg)  # type: ignore [arg-type]
 
         # Disable KV Cache quantization
         # Currently KV Cache quantization is enabled for some quantization formats and disabled for others
         # This breaks the monotonicity of the quantization formats in terms of weight compression Vs accuracy
-        self.config.quant_cfg["*output_quantizer"] = QuantizerAttributeConfig(enable=False)
+        self.config.quant_cfg["*output_quantizer"] = mtq_config.QuantizerAttributeConfig(
+            enable=False
+        )
 
         self.compression = estimate_quant_compression(self.config)
 
-        self.str_repr = (
-            f"quantization_formats[{quant_format_idx}]:effective-bits-{self.compression * 16}"
-        )
+        self._str_repr: str = f"{name}(effective-bits: {self.compression * 16})"
+
+    @staticmethod
+    def get_auto_name_for_config(quant_cfg: str | dict[str, Any] | None) -> str | None:
+        """Get a name for the quantization configuration."""
+        if quant_cfg is None:
+            return "NONE"
+        if isinstance(quant_cfg, str):
+            return quant_cfg
+        for quant_cfg_name in mtq_config.choices:
+            if quant_cfg == getattr(mtq_config, quant_cfg_name):
+                return quant_cfg_name
+        return None
 
     @property
     def num_bits(self) -> int:
         """Get the number of bits for the quantization format."""
         return int(self.compression * 16)
 
     def __str__(self) -> str:
-        return f"{self.str_repr}"
+        return self._str_repr
 
     def __repr__(self) -> str:
-        return f"{self.config}"
+        return self._str_repr
 
     def __lt__(self, other: "QuantRecipe"):
         return self.compression < other.compression
 
     def __eq__(self, other: object):
         assert isinstance(other, QuantRecipe)
-        return self.config == other.config
+        return self._str_repr == other._str_repr
 
     def __hash__(self) -> int:
-        sorted_json = json.dumps(json.loads(self.config.model_dump_json()), sort_keys=True)
-        return int(hashlib.md5(sorted_json.encode("utf-8"), usedforsecurity=False).hexdigest(), 16)
+        return hash(self._str_repr)
 
     @staticmethod
     def disable_folding_pqs_to_weights():
@@ -154,13 +175,12 @@ class QuantRecipeHparam(Hparam):
 
     def __init__(
         self,
-        choices: Sequence[QuantRecipe],
-        original: QuantRecipe | None = None,
+        choices: Sequence[QuantRecipe] | None = None,
         nn_modules: list[nn.Module] | None = None,
     ) -> None:
         """Initializes Hparam with original value and choices."""
-        choices = sorted(set(choices) | {QuantRecipe(quant_cfg=None)})
-        super().__init__(choices, original)
+        choices = sorted(set((choices if choices else []) + [QuantRecipe(quant_cfg=None)]))
+        super().__init__(choices, original=choices[0])
         self.nn_modules = nn_modules if nn_modules else []
 
         # This is a hack; We dont want to make the input_quantizer, weight_quantizer, output_quantizer
@@ -253,14 +273,15 @@ class AutoQuantizeSearcher(BaseSearcher):
     def default_search_config(self):
         """Get the default config for the searcher."""
         return {
-            "quantization_formats": [NVFP4_DEFAULT_CFG, FP8_DEFAULT_CFG],
+            "quantization_formats": ["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"],
             "data_loader": None,
             "forward_step": None,
             "loss_func": None,
             "forward_backward_step": None,
            "num_calib_steps": 512,
            "num_score_steps": 128,
            "deployment": None,
+            "disabled_layers": None,
            "verbose": is_master(),
            "checkpoint": None,
        }
@@ -271,6 +292,7 @@ def default_state_dict(self) -> SearchStateDict:
         return {
             "candidate_stats": defaultdict(dict),
             "best": {"recipe": {}, "constraints": {}, "score": float("inf"), "is_satisfied": False},
+            "constraints": {},
         }
 
     def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
@@ -297,15 +319,19 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
 
     @staticmethod
     def _is_auto_quantize_module(module):
-        return is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
+        return (
+            is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
+        ) and isinstance(module, QuantModule)
 
     @staticmethod
     def _get_search_recipes(quantization_formats):
         return sorted(
-            [
-                QuantRecipe(quant_cfg=q, quant_format_idx=i)
-                for i, q in enumerate(quantization_formats)
-            ]
+            {
+                QuantRecipe(quant_cfg=q[0], name=q[1])
+                if isinstance(q, tuple)
+                else QuantRecipe(quant_cfg=q)
+                for q in quantization_formats
+            }
         )
 
     @classmethod
@@ -337,7 +363,7 @@ def forward_backward_step(model, data):
     def _estimate_auto_quantize_scores(self):
         # TODO: remove the no-quant recipe
         def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
-            module.quant_recipe = QuantRecipe(quant_cfg=None, quant_format_idx=None)
+            module.quant_recipe = QuantRecipe(quant_cfg=None)
             output = module._forward_original(input, *args, **kwargs)
 
             # If gradient checkpointing is enabled, gradient will not be enabled in the global forward pass.
@@ -372,6 +398,9 @@ def backward_hook(module, grad_input, grad_output):
             del module.output_diff_dict
 
         def setup_params_for_score_estimation(name, param, params_metadata):
+            # Let us delete the gradient as soon as they are computed to save memory
+            # In addition, this method enables gradient for all parameters
+            # This is needed to make sure the re-entrant activation checkpointing works
            params_metadata[name] = {"requires_grad": param.requires_grad}
            param.requires_grad = True
            accum_grad, handle = create_param_grad_clear_hook(param)
@@ -394,15 +423,17 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
             params_metadata[name]["handle"].remove()
 
         for name, module in self.model.named_modules():
-            if self._is_auto_quantize_module(module):
+            if (
+                self._is_auto_quantize_module(module)
+                and module.get_hparam("quant_recipe").is_configurable
+            ):
                 # Monkey patch the forward methods to cache Y(Q(W), Q(X)) - Y(W,X)
                 setup_module_for_score_estimation(module)
 
         params_metadata = {}
         for name, param in self.model.named_parameters():
-            # Let us delete the gradient as soon as they are computed to save memory
-            # In addition, this method enables gradient for all parameters
-            # This is needed to make sure the re-entrant activation checkpointing works
+            # TODO: Enabling gradient for all parameters is not needed and making backward slow
+            # We need to enable gradient only for the the first parameter of the module such as embedding weights
             setup_params_for_score_estimation(name, param, params_metadata)
 
         gc.collect()
@@ -420,7 +451,10 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
         report_memory("AutoQuantize: After score estimation")
 
         for name, module in self.model.named_modules():
-            if self._is_auto_quantize_module(module):
+            if (
+                self._is_auto_quantize_module(module)
+                and module.get_hparam("quant_recipe").is_configurable
+            ):
                 cleanup_module_after_score_estimation(module)
 
         for name, param in self.model.named_parameters():
@@ -431,15 +465,29 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
         gc.collect()
 
     @classmethod
-    def insert_hparams_after_merge_rules(cls, model, quant_recipes):
+    def insert_hparams_after_merge_rules(cls, model, quant_recipes, disabled_layers=None):
         """Restrict the search space using the merge rules and insert the hparams for the model."""
         # TRTLLM fuses linear layers such as q_proj, k_proj, v_proj into same layer
         # Hence we need to restrict the search space so that all these layers share the same recipe
         # Lets group the modules based on the rules and insert the same hparam for all the modules in the group
-        search_map: dict[str, list[nn.Module]] = {}
+
+        if disabled_layers is None:
+            disabled_layers = []
+        elif isinstance(disabled_layers, str):
+            disabled_layers = [disabled_layers]
+
+        search_map: dict[str, list[tuple[nn.Module, bool]]] = {}
         for name, module in model.named_modules():
             if not cls._is_auto_quantize_module(module):
                 continue
+
+            # Skip layers that match disabled_layers patterns
+            disabled = False
+            for pattern in disabled_layers:
+                if fnmatch.fnmatch(name, pattern):
+                    disabled = True
+                    break
+
             prefix = name
             for rule in cls.rules:
                 pattern = re.compile(rule)
@@ -449,15 +497,17 @@ def insert_hparams_after_merge_rules(cls, model, quant_recipes):
                     # We support only one rule for matching per module
                    break
             if prefix not in search_map:
-                search_map[prefix] = [module]
+                search_map[prefix] = [(module, disabled)]
             else:
-                search_map[prefix].append(module)
-
-        for prefix, modules in search_map.items():
-            hparam = QuantRecipeHparam(
-                quant_recipes,
-                original=quant_recipes[0],
-                nn_modules=modules,
+                search_map[prefix].append((module, disabled))
+
+        for prefix, module_info_list in search_map.items():
+            modules = [module for module, _ in module_info_list]
+            disabled = any(disabled for _, disabled in module_info_list)
+            hparam = (
+                QuantRecipeHparam(None, nn_modules=modules)
+                if disabled
+                else QuantRecipeHparam(quant_recipes, nn_modules=modules)
             )
             for module in modules:
                 module._register_hparam("quant_recipe", hparam)
@@ -495,7 +545,9 @@ def before_search(self):
 
         search_recipes = self._get_search_recipes(self.config["quantization_formats"])
         self._verify_constraint(search_recipes)
-        self.insert_hparams_after_merge_rules(self.model, search_recipes)
+        self.insert_hparams_after_merge_rules(
+            self.model, search_recipes, self.config["disabled_layers"]
+        )
 
         QuantRecipe.disable_folding_pqs_to_weights()
 
@@ -557,18 +609,11 @@ def get_total_weight_size(modules):
                 for module in modules
             )
 
-        def _get_constraints_for_search(lower_bound=None):
-            total_model_weight_size = get_total_weight_size(self.model.modules())
-
-            upper_bound = self._get_formatted_weight_compression_constraint()
-
-            if lower_bound:
-                lower_bound = lower_bound * upper_bound
-
+        def _get_constraints_for_search(max_weight_size, lower_bound=None):
             constraints = {
                 "weight_size_after_compression": (
-                    lower_bound * total_model_weight_size if lower_bound else lower_bound,
-                    upper_bound * total_model_weight_size,
+                    lower_bound * max_weight_size if lower_bound else lower_bound,
+                    max_weight_size,
                 )
             }
             return constraints, "weight_size_after_compression"
@@ -579,16 +624,20 @@ def _get_constraints_for_search(lower_bound=None):
             f"Got {self.constraints.keys()}"
         )
 
-        search_recipes = self._get_search_recipes(self.config["quantization_formats"])
-        for name, hparam in named_hparams(self.model, configurable=True):
+        compression = self._get_formatted_weight_compression_constraint()
+        total_weight_size = get_total_weight_size(self.model.modules())
+        weight_size_after_compression = total_weight_size * compression
+
+        for name, hparam in named_hparams(self.model, unique=True):
             if not isinstance(hparam, QuantRecipeHparam):
                 continue
+
             formats, scores, costs = [], [], []
             prev_score = float("inf")
-            for recipe in search_recipes:
+            for recipe in hparam.choices:
                 formats.append(recipe)
                 score = hparam.importance[recipe]
-                cost = get_total_weight_size(hparam.nn_modules) * recipe.compression
+                cost = get_total_weight_size(hparam.nn_modules) * recipe.compression  # type: ignore [union-attr]
 
                 # Lets get the score across Data Parallel (DP) and Tensor Parallel (TP) groups
                 # This way we constraint the same quantization format for the same layer across the DP/TP groups
@@ -602,6 +651,7 @@ def _get_constraints_for_search(lower_bound=None):
                 scores.append(min(score, prev_score))
                 costs.append(cost)
                 prev_score = score
+
             self.candidate_stats[name]["formats"] = formats
             self.candidate_stats[name]["scores"] = scores
             self.candidate_stats[name]["costs"] = costs
@@ -611,7 +661,9 @@ def _get_constraints_for_search(lower_bound=None):
             # specified. I dont know why this happens.
             # As a workaround, lets specify a lower bound for the weight compression if previous
             # search without lower bound fails.
-            constraints, constraint_name = _get_constraints_for_search(lower_bound)
+            constraints, constraint_name = _get_constraints_for_search(
+                weight_size_after_compression, lower_bound
+            )
 
             lps = LPS(
                 name="AutoQuantize",
@@ -664,8 +716,14 @@ def _get_constraints_for_search(lower_bound=None):
                     f"AutoQuantize best recipe for {name.replace('.quant_recipe', '')}: {best_recipe[name]}"
                 )
 
+        effective_bits_from_search = (best_constraints / total_weight_size) * 16
+        if verbose:
+            print_rank_0(
+                f"AutoQuantize effective bits from search: {effective_bits_from_search: .2f}"
+            )
+
         self.best["recipe"] = best_recipe
-        self.best["constraints"] = {constraint_name: best_constraints}
+        self.best["constraints"] = {"effective_bits": effective_bits_from_search}
         self.best["score"] = best_scores
 
         QuantRecipe.fold_pqs_to_weights(self.model)
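
Taken together, the QuantRecipe changes above replace the old index-based labels (quantization_formats[i]:...) with human-readable names, and equality and hashing now follow that name string instead of a json+md5 digest of the full config. A small sketch of the resulting behavior, based only on the code in this diff; the custom config dict and the effective-bits values mentioned in comments are illustrative and unverified:

    # Sketch of the new QuantRecipe naming behavior (illustrative, not a test from this commit).
    from modelopt.torch.quantization.algorithms import QuantRecipe

    # Built-in formats can be passed by name; str(recipe) now reads like
    # "FP8_DEFAULT_CFG(effective-bits: ...)" rather than "quantization_formats[i]:...".
    fp8 = QuantRecipe(quant_cfg="FP8_DEFAULT_CFG")

    # Custom formats need an explicit name (the new assert rejects unnamed custom dicts);
    # the searcher accepts them as (config_dict, name) tuples in quantization_formats.
    custom = QuantRecipe(
        quant_cfg={"quant_cfg": {"*weight_quantizer": {"num_bits": 8}}},  # hypothetical config
        name="MY_INT8_CFG",
    )

    # Equality and hashing are keyed on the name/effective-bits string.
    assert fp8 == QuantRecipe(quant_cfg="FP8_DEFAULT_CFG")
    assert hash(fp8) == hash(str(fp8))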
