diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py
index b99311e6..4adac26c 100644
--- a/modelopt/torch/quantization/algorithms.py
+++ b/modelopt/torch/quantization/algorithms.py
@@ -15,9 +15,8 @@
 """Module for advanced quantization algorithms."""
 
+import fnmatch
 import gc
-import hashlib
-import json
 import types
 import warnings
 from collections import defaultdict
@@ -37,10 +36,11 @@
 from modelopt.torch.utils import create_param_grad_clear_hook, print_rank_0, report_memory
 from modelopt.torch.utils.distributed import DistributedProcessGroup, is_master
 
+from . import config as mtq_config
 from . import model_calib
-from .config import FP8_DEFAULT_CFG, NVFP4_DEFAULT_CFG, QuantizeConfig, QuantizerAttributeConfig
+from .config import QuantizeConfig, QuantizerAttributeConfig
 from .conversion import set_quantizer_by_cfg
-from .nn import QuantLinearConvBase, SequentialQuantizer, TensorQuantizer
+from .nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer
 from .utils import is_quantized_linear, multi_context
 
@@ -82,27 +82,50 @@ def estimate_quant_compression_for_quantizer(quantizer_attr_cfg):
 
 
 class QuantRecipe(CustomHPType):
-    """A subclass of QuantizeConfig enabling auto_quantize specific configurations."""
+    """A subclass of QuantizeConfig enabling auto_quantize specific configurations.
 
-    def __init__(
-        self, quant_cfg: dict[str, Any] | None = None, quant_format_idx: int | None = None
-    ):
+    Args:
+        quant_cfg: The name of a predefined quantization format, a custom config dict, or None.
+        name: Name for custom quantization formats. Only used if the quantization format is a
+            custom format not available in :mod:`modelopt.torch.quantization.config`.
+    """
+
+    def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | None = None):
        """Initialize the QuantRecipe with the quantization configuration."""
+        name = self.get_auto_name_for_config(quant_cfg) or name
+
         if quant_cfg is None:
-            self.config = QuantizeConfig(quant_cfg={"*": {"enable": False}}, algorithm="max")
+            quant_cfg = {"quant_cfg": {"*": {"enable": False}}}
+        elif isinstance(quant_cfg, str):
+            assert hasattr(mtq_config, quant_cfg), f"Unknown quantization format {quant_cfg}"
+            quant_cfg = getattr(mtq_config, quant_cfg)
         else:
-            self.config = QuantizeConfig(**quant_cfg)
+            assert name is not None, "name must be provided for custom quantization formats"
+
+        self.config = mtq_config.QuantizeConfig(**quant_cfg)  # type: ignore [arg-type]
 
         # Disable KV Cache quantization
         # Currently KV Cache quantization is enabled for some quantization formats and disabled for others
         # This breaks the monotonicity of the quantization formats in terms of weight compression Vs accuracy
-        self.config.quant_cfg["*output_quantizer"] = QuantizerAttributeConfig(enable=False)
+        self.config.quant_cfg["*output_quantizer"] = mtq_config.QuantizerAttributeConfig(
+            enable=False
+        )
 
         self.compression = estimate_quant_compression(self.config)
-        self.str_repr = (
-            f"quantization_formats[{quant_format_idx}]:effective-bits-{self.compression * 16}"
-        )
+        self._str_repr: str = f"{name}(effective-bits: {self.compression * 16})"
+
+    @staticmethod
+    def get_auto_name_for_config(quant_cfg: str | dict[str, Any] | None) -> str | None:
+        """Get a name for the quantization configuration."""
+        if quant_cfg is None:
+            return "NONE"
+        if isinstance(quant_cfg, str):
+            return quant_cfg
+        for quant_cfg_name in mtq_config.choices:
+            if quant_cfg == getattr(mtq_config, quant_cfg_name):
+                return quant_cfg_name
+        return None
 
     @property
     def num_bits(self) -> int:
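With this change a QuantRecipe can be constructed from the name of a predefined format, from the matching config dict, or from a custom dict plus an explicit display name, and equality/hashing are based on the resulting name string. A minimal sketch of the intended usage (``MY_CUSTOM_CFG`` is a hypothetical config, not part of this patch)::

    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.algorithms import QuantRecipe

    # Predefined formats resolve to the same name whether given as a string or as the dict itself.
    r_by_name = QuantRecipe("FP8_DEFAULT_CFG")
    r_by_dict = QuantRecipe(mtq.FP8_DEFAULT_CFG)
    assert r_by_name == r_by_dict  # __eq__/__hash__ now compare the name-based repr

    # A custom config dict needs an explicit display name.
    MY_CUSTOM_CFG = {
        "quant_cfg": {
            "*weight_quantizer": {"num_bits": 8, "axis": 0},
            "*input_quantizer": {"enable": False},
        },
        "algorithm": "max",
    }
    r_custom = QuantRecipe(MY_CUSTOM_CFG, name="CUSTOM_W8A16")

    # None disables quantization and is displayed as "NONE".
    r_none = QuantRecipe(None)
    print(r_by_name, r_custom, r_none)  # e.g. "FP8_DEFAULT_CFG(effective-bits: 8.0)", ...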
@@ -110,21 +133,20 @@ def num_bits(self) -> int:
         return int(self.compression * 16)
 
     def __str__(self) -> str:
-        return f"{self.str_repr}"
+        return self._str_repr
 
     def __repr__(self) -> str:
-        return f"{self.config}"
+        return self._str_repr
 
     def __lt__(self, other: "QuantRecipe"):
         return self.compression < other.compression
 
     def __eq__(self, other: object):
         assert isinstance(other, QuantRecipe)
-        return self.config == other.config
+        return self._str_repr == other._str_repr
 
     def __hash__(self) -> int:
-        sorted_json = json.dumps(json.loads(self.config.model_dump_json()), sort_keys=True)
-        return int(hashlib.md5(sorted_json.encode("utf-8"), usedforsecurity=False).hexdigest(), 16)
+        return hash(self._str_repr)
 
     @staticmethod
     def disable_folding_pqs_to_weights():
@@ -154,13 +176,12 @@ class QuantRecipeHparam(Hparam):
 
     def __init__(
         self,
-        choices: Sequence[QuantRecipe],
-        original: QuantRecipe | None = None,
+        choices: Sequence[QuantRecipe] | None = None,
         nn_modules: list[nn.Module] | None = None,
     ) -> None:
         """Initializes Hparam with original value and choices."""
-        choices = sorted(set(choices) | {QuantRecipe(quant_cfg=None)})
-        super().__init__(choices, original)
+        choices = sorted({*(choices if choices else []), QuantRecipe(quant_cfg=None)})
+        super().__init__(choices, original=choices[0])
         self.nn_modules = nn_modules if nn_modules else []
 
         # This is a hack; We dont want to make the input_quantizer, weight_quantizer, output_quantizer
@@ -253,7 +274,7 @@ class AutoQuantizeSearcher(BaseSearcher):
     def default_search_config(self):
         """Get the default config for the searcher."""
         return {
-            "quantization_formats": [NVFP4_DEFAULT_CFG, FP8_DEFAULT_CFG],
+            "quantization_formats": ["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"],
             "data_loader": None,
             "forward_step": None,
             "loss_func": None,
@@ -261,6 +282,7 @@ def default_search_config(self):
             "num_calib_steps": 512,
             "num_score_steps": 128,
             "deployment": None,
+            "disabled_layers": None,
             "verbose": is_master(),
             "checkpoint": None,
         }
@@ -271,6 +293,7 @@ def default_state_dict(self) -> SearchStateDict:
         return {
             "candidate_stats": defaultdict(dict),
             "best": {"recipe": {}, "constraints": {}, "score": float("inf"), "is_satisfied": False},
+            "constraints": {},
         }
 
     def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
@@ -297,15 +320,19 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
 
     @staticmethod
     def _is_auto_quantize_module(module):
-        return is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
+        return (
+            is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
+        ) and isinstance(module, QuantModule)
 
     @staticmethod
     def _get_search_recipes(quantization_formats):
         return sorted(
-            [
-                QuantRecipe(quant_cfg=q, quant_format_idx=i)
-                for i, q in enumerate(quantization_formats)
-            ]
+            {
+                QuantRecipe(quant_cfg=q[0], name=q[1])
+                if isinstance(q, tuple)
+                else QuantRecipe(quant_cfg=q)
+                for q in quantization_formats
+            }
         )
 
     @classmethod
@@ -337,7 +364,7 @@ def forward_backward_step(model, data):
     def _estimate_auto_quantize_scores(self):
         # TODO: remove the no-quant recipe
         def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
-            module.quant_recipe = QuantRecipe(quant_cfg=None, quant_format_idx=None)
+            module.quant_recipe = QuantRecipe(quant_cfg=None)
             output = module._forward_original(input, *args, **kwargs)
 
             # If gradient checkpointing is enabled, gradient will not be enabled in the global forward pass.
@@ -372,6 +399,9 @@ def backward_hook(module, grad_input, grad_output):
             del module.output_diff_dict
 
         def setup_params_for_score_estimation(name, param, params_metadata):
+            # Let us delete the gradients as soon as they are computed to save memory
+            # In addition, this method enables gradients for all parameters
+            # This is needed to make sure the re-entrant activation checkpointing works
             params_metadata[name] = {"requires_grad": param.requires_grad}
             param.requires_grad = True
             accum_grad, handle = create_param_grad_clear_hook(param)
@@ -394,15 +424,17 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
            params_metadata[name]["handle"].remove()
 
         for name, module in self.model.named_modules():
-            if self._is_auto_quantize_module(module):
+            if (
+                self._is_auto_quantize_module(module)
+                and module.get_hparam("quant_recipe").is_configurable
+            ):
                 # Monkey patch the forward methods to cache Y(Q(W), Q(X)) - Y(W,X)
                 setup_module_for_score_estimation(module)
 
         params_metadata = {}
         for name, param in self.model.named_parameters():
-            # Let us delete the gradient as soon as they are computed to save memory
-            # In addition, this method enables gradient for all parameters
-            # This is needed to make sure the re-entrant activation checkpointing works
+            # TODO: Enabling gradient for all parameters is not needed and makes the backward pass slow
+            # We need to enable gradient only for the first parameter of the module, such as embedding weights
             setup_params_for_score_estimation(name, param, params_metadata)
 
         gc.collect()
@@ -420,7 +452,10 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
         report_memory("AutoQuantize: After score estimation")
 
         for name, module in self.model.named_modules():
-            if self._is_auto_quantize_module(module):
+            if (
+                self._is_auto_quantize_module(module)
+                and module.get_hparam("quant_recipe").is_configurable
+            ):
                 cleanup_module_after_score_estimation(module)
 
         for name, param in self.model.named_parameters():
@@ -431,15 +466,29 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
         gc.collect()
 
     @classmethod
-    def insert_hparams_after_merge_rules(cls, model, quant_recipes):
+    def insert_hparams_after_merge_rules(cls, model, quant_recipes, disabled_layers=None):
         """Restrict the search space using the merge rules and insert the hparams for the model."""
         # TRTLLM fuses linear layers such as q_proj, k_proj, v_proj into same layer
         # Hence we need to restrict the search space so that all these layers share the same recipe
         # Lets group the modules based on the rules and insert the same hparam for all the modules in the group
-        search_map: dict[str, list[nn.Module]] = {}
+
+        if disabled_layers is None:
+            disabled_layers = []
+        elif isinstance(disabled_layers, str):
+            disabled_layers = [disabled_layers]
+
+        search_map: dict[str, list[tuple[nn.Module, bool]]] = {}
         for name, module in model.named_modules():
             if not cls._is_auto_quantize_module(module):
                 continue
+
+            # Skip layers that match disabled_layers patterns
+            disabled = False
+            for pattern in disabled_layers:
+                if fnmatch.fnmatch(name, pattern):
+                    disabled = True
+                    break
+
             prefix = name
             for rule in cls.rules:
                 pattern = re.compile(rule)
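The new ``disabled_layers`` option takes one or more ``fnmatch``-style wildcard patterns; any auto-quantizable module whose fully qualified name matches a pattern is pinned to the no-quantization recipe instead of being searched. A rough illustration of the matching only (the module names below are hypothetical)::

    import fnmatch

    disabled_layers = ["*lm_head*", "*.mlp.down_proj"]
    module_names = [
        "model.layers.0.self_attn.q_proj",
        "model.layers.0.mlp.down_proj",
        "lm_head",
    ]
    for name in module_names:
        skipped = any(fnmatch.fnmatch(name, pattern) for pattern in disabled_layers)
        print(f"{name}: {'pinned to no-quantization' if skipped else 'searched'}")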
@@ -449,15 +498,17 @@ def insert_hparams_after_merge_rules(cls, model, quant_recipes):
                     # We support only one rule for matching per module
                     break
             if prefix not in search_map:
-                search_map[prefix] = [module]
+                search_map[prefix] = [(module, disabled)]
             else:
-                search_map[prefix].append(module)
-
-        for prefix, modules in search_map.items():
-            hparam = QuantRecipeHparam(
-                quant_recipes,
-                original=quant_recipes[0],
-                nn_modules=modules,
+                search_map[prefix].append((module, disabled))
+
+        for prefix, module_info_list in search_map.items():
+            modules = [module for module, _ in module_info_list]
+            disabled = any(disabled for _, disabled in module_info_list)
+            hparam = (
+                QuantRecipeHparam(None, nn_modules=modules)
+                if disabled
+                else QuantRecipeHparam(quant_recipes, nn_modules=modules)
             )
             for module in modules:
                 module._register_hparam("quant_recipe", hparam)
@@ -495,7 +546,9 @@ def before_search(self):
         search_recipes = self._get_search_recipes(self.config["quantization_formats"])
 
         self._verify_constraint(search_recipes)
-        self.insert_hparams_after_merge_rules(self.model, search_recipes)
+        self.insert_hparams_after_merge_rules(
+            self.model, search_recipes, self.config["disabled_layers"]
+        )
 
         QuantRecipe.disable_folding_pqs_to_weights()
@@ -557,18 +610,11 @@ def get_total_weight_size(modules):
                 for module in modules
             )
 
-        def _get_constraints_for_search(lower_bound=None):
-            total_model_weight_size = get_total_weight_size(self.model.modules())
-
-            upper_bound = self._get_formatted_weight_compression_constraint()
-
-            if lower_bound:
-                lower_bound = lower_bound * upper_bound
-
+        def _get_constraints_for_search(max_weight_size, lower_bound=None):
             constraints = {
                 "weight_size_after_compression": (
-                    lower_bound * total_model_weight_size if lower_bound else lower_bound,
-                    upper_bound * total_model_weight_size,
+                    lower_bound * max_weight_size if lower_bound else lower_bound,
+                    max_weight_size,
                 )
             }
             return constraints, "weight_size_after_compression"
@@ -579,16 +625,20 @@ def _get_constraints_for_search(lower_bound=None):
             f"Got {self.constraints.keys()}"
         )
 
-        search_recipes = self._get_search_recipes(self.config["quantization_formats"])
-        for name, hparam in named_hparams(self.model, configurable=True):
+        compression = self._get_formatted_weight_compression_constraint()
+        total_weight_size = get_total_weight_size(self.model.modules())
+        weight_size_after_compression = total_weight_size * compression
+
+        for name, hparam in named_hparams(self.model, unique=True):
             if not isinstance(hparam, QuantRecipeHparam):
                 continue
+
             formats, scores, costs = [], [], []
             prev_score = float("inf")
-            for recipe in search_recipes:
+            for recipe in hparam.choices:
                 formats.append(recipe)
                 score = hparam.importance[recipe]
-                cost = get_total_weight_size(hparam.nn_modules) * recipe.compression
+                cost = get_total_weight_size(hparam.nn_modules) * recipe.compression  # type: ignore [union-attr]
 
                 # Lets get the score across Data Parallel (DP) and Tensor Parallel (TP) groups
                 # This way we constraint the same quantization format for the same layer across the DP/TP groups
@@ -602,6 +652,7 @@ def _get_constraints_for_search(lower_bound=None):
                 scores.append(min(score, prev_score))
                 costs.append(cost)
                 prev_score = score
+
             self.candidate_stats[name]["formats"] = formats
             self.candidate_stats[name]["scores"] = scores
             self.candidate_stats[name]["costs"] = costs
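The budget handed to the LP solver is now computed once up front: the requested ``effective_bits`` becomes a fraction of the 16-bit footprint and is multiplied by the total weight size, while each candidate recipe contributes ``weight_size * recipe.compression`` as its cost. A back-of-the-envelope sketch (illustrative numbers; weight size is assumed to be in whatever units ``get_total_weight_size`` reports)::

    total_weight_size = 7.0e9      # e.g. a 7B-parameter model
    effective_bits = 8.0           # constraints={"effective_bits": 8.0}

    compression = effective_bits / 16.0                               # fraction of the 16-bit footprint
    weight_size_after_compression = total_weight_size * compression   # LP upper bound

    # Per-layer candidate cost is layer_weight_size * recipe.compression,
    # e.g. NVFP4 ~ 0.25, FP8 ~ 0.5, no-quantization = 1.0.
    print(f"weight budget: {weight_size_after_compression:.3e}")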
@@ -611,7 +662,9 @@ def _get_constraints_for_search(lower_bound=None):
            # specified. I dont know why this happens.
            # As a workaround, lets specify a lower bound for the weight compression if previous
            # search without lower bound fails.
-            constraints, constraint_name = _get_constraints_for_search(lower_bound)
+            constraints, constraint_name = _get_constraints_for_search(
+                weight_size_after_compression, lower_bound
+            )
 
             lps = LPS(
                 name="AutoQuantize",
@@ -664,8 +717,14 @@ def _get_constraints_for_search(lower_bound=None):
                     f"AutoQuantize best recipe for {name.replace('.quant_recipe', '')}: {best_recipe[name]}"
                 )
 
+        effective_bits_from_search = (best_constraints / total_weight_size) * 16
+        if verbose:
+            print_rank_0(
+                f"AutoQuantize effective bits from search: {effective_bits_from_search:.2f}"
+            )
+
         self.best["recipe"] = best_recipe
-        self.best["constraints"] = {constraint_name: best_constraints}
+        self.best["constraints"] = {"effective_bits": effective_bits_from_search}
         self.best["score"] = best_scores
 
         QuantRecipe.fold_pqs_to_weights(self.model)
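At the API level the search can now be driven entirely by predefined format names, and ``disabled_layers`` is resolved inside the search itself rather than by disabling quantizers afterwards. A toy sketch of a call using the updated interface (tiny stand-in model and data; the constraint value is illustrative)::

    import torch
    import torch.nn as nn
    import modelopt.torch.quantization as mtq

    model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64), nn.Linear(64, 64))  # toy model
    calib_loader = [torch.randn(4, 64) for _ in range(2)]

    best_model, search_state = mtq.auto_quantize(
        model,
        constraints={"effective_bits": 9.0},
        quantization_formats=["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"],  # predefined formats by name
        data_loader=calib_loader,
        forward_step=lambda model, batch: model(batch),
        loss_func=lambda output, batch: output.sum(),
        num_calib_steps=2,
        num_score_steps=2,
        disabled_layers=["2"],  # fnmatch pattern; keeps the last Linear out of the search
        verbose=True,
    )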
" + f"This config will be displayed as {name}" ) - processed_quantization_formats.append(quant_cfg) + processed_quantization_formats.append((quant_cfg, name)) assert len(processed_quantization_formats) > 0, "`quantization_formats` should not be empty" model = apply_mode( @@ -435,18 +435,13 @@ def forward_backward_step(model, batch) -> None: "forward_backward_step": forward_backward_step, "num_calib_steps": num_calib_steps, "num_score_steps": num_score_steps, + "disabled_layers": disabled_layers, "verbose": verbose, } # Disable all quantizers; AutoQuantize will enable the needed ones set_quantizer_by_cfg(model, {"*": {"enable": False}}) searcher.search(model, constraints, config=search_config) # type: ignore[arg-type] - if disabled_layers: - if isinstance(disabled_layers, str): - disabled_layers = [disabled_layers] - for layer_pattern in disabled_layers: - disable_quantizer(model, layer_pattern) - return model, searcher.state_dict() diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py index c0ca4e3f..a673729e 100644 --- a/tests/unit/torch/quantization/test_autoquant.py +++ b/tests/unit/torch/quantization/test_autoquant.py @@ -92,11 +92,10 @@ def test_quant_recipe_hparam(): ] hparam = QuantRecipeHparam( search_recipes, - original=search_recipes[0], nn_modules=[model_test], ) model_test._register_hparam("quant_recipe", hparam) - assert model_test.quant_recipe == QuantRecipe(mtq.INT8_DEFAULT_CFG) + assert model_test.quant_recipe == QuantRecipe(mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG) assert model_test.get_hparam("quant_recipe").choices == sorted( [*search_recipes, QuantRecipe(quant_cfg=None)] ) @@ -125,14 +124,16 @@ def test_quant_recipe_hparam(): [SimpleConv, SimpleConvLinear, SimpleLinear, TransformerBlock], ) @pytest.mark.parametrize( - "search_formats", + ("search_formats", "min_bits", "search_bits"), [ - [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG], - [mtq.INT4_AWQ_CFG, mtq.INT8_SMOOTHQUANT_CFG], - [mtq.INT4_AWQ_CFG, INT8_CUSTOM_QUANT_TEST_CFG], + ([mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG], 4.0, 6.0), + ([mtq.INT4_AWQ_CFG, mtq.INT8_SMOOTHQUANT_CFG], 4.0, 6.0), + ([mtq.INT4_AWQ_CFG, INT8_CUSTOM_QUANT_TEST_CFG], 4.0, 6.0), + ([mtq.INT8_SMOOTHQUANT_CFG], 8.0, 11.0), + ([None, mtq.INT8_SMOOTHQUANT_CFG], 8.0, 11.0), ], ) -def test_auto_quantize(model_cls, search_formats): +def test_auto_quantize(model_cls, search_formats, min_bits, search_bits): model = model_cls() def loss_func(output): @@ -140,7 +141,7 @@ def loss_func(output): best_model, search_history = mtq.auto_quantize( model, - constraints={"effective_bits": 11.0}, + constraints={"effective_bits": search_bits}, quantization_formats=search_formats, data_loader=[model.get_input() for _ in range(2)], forward_step=lambda model, batch: model(batch), @@ -151,6 +152,10 @@ def loss_func(output): ) assert isinstance(search_history, dict) assert search_history["best"]["is_satisfied"] + effective_bits_from_search = search_history["best"]["constraints"]["effective_bits"] + assert effective_bits_from_search <= search_bits and effective_bits_from_search >= min_bits, ( + "Search failed!" 
diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py
index c0ca4e3f..a673729e 100644
--- a/tests/unit/torch/quantization/test_autoquant.py
+++ b/tests/unit/torch/quantization/test_autoquant.py
@@ -92,11 +92,10 @@ def test_quant_recipe_hparam():
     ]
     hparam = QuantRecipeHparam(
         search_recipes,
-        original=search_recipes[0],
         nn_modules=[model_test],
     )
     model_test._register_hparam("quant_recipe", hparam)
-    assert model_test.quant_recipe == QuantRecipe(mtq.INT8_DEFAULT_CFG)
+    assert model_test.quant_recipe == QuantRecipe(mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG)
     assert model_test.get_hparam("quant_recipe").choices == sorted(
         [*search_recipes, QuantRecipe(quant_cfg=None)]
     )
@@ -125,14 +124,16 @@ def test_quant_recipe_hparam():
     [SimpleConv, SimpleConvLinear, SimpleLinear, TransformerBlock],
 )
 @pytest.mark.parametrize(
-    "search_formats",
+    ("search_formats", "min_bits", "search_bits"),
     [
-        [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG],
-        [mtq.INT4_AWQ_CFG, mtq.INT8_SMOOTHQUANT_CFG],
-        [mtq.INT4_AWQ_CFG, INT8_CUSTOM_QUANT_TEST_CFG],
+        ([mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG], 4.0, 6.0),
+        ([mtq.INT4_AWQ_CFG, mtq.INT8_SMOOTHQUANT_CFG], 4.0, 6.0),
+        ([mtq.INT4_AWQ_CFG, INT8_CUSTOM_QUANT_TEST_CFG], 4.0, 6.0),
+        ([mtq.INT8_SMOOTHQUANT_CFG], 8.0, 11.0),
+        ([None, mtq.INT8_SMOOTHQUANT_CFG], 8.0, 11.0),
     ],
 )
-def test_auto_quantize(model_cls, search_formats):
+def test_auto_quantize(model_cls, search_formats, min_bits, search_bits):
     model = model_cls()
 
     def loss_func(output):
@@ -140,7 +141,7 @@ def loss_func(output):
 
     best_model, search_history = mtq.auto_quantize(
         model,
-        constraints={"effective_bits": 11.0},
+        constraints={"effective_bits": search_bits},
         quantization_formats=search_formats,
         data_loader=[model.get_input() for _ in range(2)],
         forward_step=lambda model, batch: model(batch),
@@ -151,6 +152,10 @@ def loss_func(output):
     )
     assert isinstance(search_history, dict)
     assert search_history["best"]["is_satisfied"]
+    effective_bits_from_search = search_history["best"]["constraints"]["effective_bits"]
+    assert min_bits <= effective_bits_from_search <= search_bits, (
+        "Search failed!"
+    )
 
     if model_cls == TransformerBlock:
         hparam = model.attn.q_proj.get_hparam("quant_recipe")
@@ -198,35 +203,6 @@ def loss_func(output):
         assert not best_model.mlp.input_quantizer.is_enabled
 
 
-def test_auto_quantize_vs_quantize():
-    model_ref = SimpleLinear()
-    state_dict = copy.deepcopy(model_ref.state_dict())
-    dataloader = [model_ref.get_input() for _ in range(2)]
-
-    def calibrate(model):
-        for input in dataloader:
-            model(input)
-
-    mtq.quantize(model_ref, mtq.INT8_SMOOTHQUANT_CFG, calibrate)
-
-    model_test = SimpleLinear()
-    model_test.load_state_dict(state_dict)
-
-    best_model, search_history = mtq.auto_quantize(
-        model_test,
-        constraints={"effective_bits": 11.0},
-        quantization_formats=[mtq.INT8_SMOOTHQUANT_CFG],
-        data_loader=dataloader,
-        forward_step=lambda model, batch: model(batch),
-        loss_func=lambda output, data: output.sum(),
-        num_calib_steps=2,
-        num_score_steps=2,
-        verbose=True,
-    )
-
-    assert torch.allclose(best_model(dataloader[0]), model_ref(dataloader[0]))
-
-
 INT4INT8_AWQ_CFG = {
     "quant_cfg": {
         "*weight_quantizer": [