183 changes: 121 additions & 62 deletions modelopt/torch/quantization/algorithms.py
@@ -15,9 +15,8 @@

"""Module for advanced quantization algorithms."""

import fnmatch
import gc
import hashlib
import json
import types
import warnings
from collections import defaultdict
@@ -37,10 +36,11 @@
from modelopt.torch.utils import create_param_grad_clear_hook, print_rank_0, report_memory
from modelopt.torch.utils.distributed import DistributedProcessGroup, is_master

from . import config as mtq_config
from . import model_calib
from .config import FP8_DEFAULT_CFG, NVFP4_DEFAULT_CFG, QuantizeConfig, QuantizerAttributeConfig
from .config import QuantizeConfig, QuantizerAttributeConfig
from .conversion import set_quantizer_by_cfg
from .nn import QuantLinearConvBase, SequentialQuantizer, TensorQuantizer
from .nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer
from .utils import is_quantized_linear, multi_context


@@ -82,49 +82,71 @@ def estimate_quant_compression_for_quantizer(quantizer_attr_cfg):


class QuantRecipe(CustomHPType):
"""A subclass of QuantizeConfig enabling auto_quantize specific configurations."""
"""A subclass of QuantizeConfig enabling auto_quantize specific configurations.

def __init__(
self, quant_cfg: dict[str, Any] | None = None, quant_format_idx: int | None = None
):
Args:
quant_cfg: str or dict or None. A dict is used for custom quantization formats.
name: Name for a custom quantization format. Only used if the quantization format is a custom
format not available in :mod:`modelopt.torch.quantization.config`.
"""

def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | None = None):
"""Initialize the QuantRecipe with the quantization configuration."""
name = self.get_auto_name_for_config(quant_cfg) or name

if quant_cfg is None:
self.config = QuantizeConfig(quant_cfg={"*": {"enable": False}}, algorithm="max")
quant_cfg = {"quant_cfg": {"*": {"enable": False}}}
elif isinstance(quant_cfg, str):
assert hasattr(mtq_config, quant_cfg), f"Unknown quantization format {quant_cfg}"
quant_cfg = getattr(mtq_config, quant_cfg)
else:
self.config = QuantizeConfig(**quant_cfg)
assert name is not None, "name must be provided for custom quantization formats"

self.config = mtq_config.QuantizeConfig(**quant_cfg) # type: ignore [arg-type]

# Disable KV Cache quantization
# Currently KV Cache quantization is enabled for some quantization formats and disabled for others
# This breaks the monotonicity of the quantization formats in terms of weight compression vs. accuracy
self.config.quant_cfg["*output_quantizer"] = QuantizerAttributeConfig(enable=False)
self.config.quant_cfg["*output_quantizer"] = mtq_config.QuantizerAttributeConfig(
enable=False
)

self.compression = estimate_quant_compression(self.config)

self.str_repr = (
f"quantization_formats[{quant_format_idx}]:effective-bits-{self.compression * 16}"
)
self._str_repr: str = f"{name}(effective-bits: {self.compression * 16})"

@staticmethod
def get_auto_name_for_config(quant_cfg: str | dict[str, Any] | None) -> str | None:
"""Get a name for the quantization configuration."""
if quant_cfg is None:
return "NONE"
if isinstance(quant_cfg, str):
return quant_cfg
for quant_cfg_name in mtq_config.choices:
if quant_cfg == getattr(mtq_config, quant_cfg_name):
return quant_cfg_name
return None

@property
def num_bits(self) -> int:
"""Get the number of bits for the quantization format."""
return int(self.compression * 16)

def __str__(self) -> str:
return f"{self.str_repr}"
return self._str_repr

def __repr__(self) -> str:
return f"{self.config}"
return self._str_repr

def __lt__(self, other: "QuantRecipe"):
return self.compression < other.compression

def __eq__(self, other: object):
assert isinstance(other, QuantRecipe)
return self.config == other.config
return self._str_repr == other._str_repr

def __hash__(self) -> int:
sorted_json = json.dumps(json.loads(self.config.model_dump_json()), sort_keys=True)
return int(hashlib.md5(sorted_json.encode("utf-8"), usedforsecurity=False).hexdigest(), 16)
return hash(self._str_repr)

@staticmethod
def disable_folding_pqs_to_weights():
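For reviewers, a minimal sketch of how the reworked constructor resolves names and effective bits; the custom config, its name, and the printed strings are illustrative, not part of this diff:

```python
from modelopt.torch.quantization.algorithms import QuantRecipe

# Built-in format referenced by name: get_auto_name_for_config returns the string itself.
r_fp8 = QuantRecipe(quant_cfg="FP8_DEFAULT_CFG")
print(r_fp8)  # e.g. FP8_DEFAULT_CFG(effective-bits: 8.0)

# Custom dict format: an explicit name is now required (enforced by the assert above).
r_custom = QuantRecipe(
    quant_cfg={"quant_cfg": {"*weight_quantizer": {"num_bits": 8}, "default": {"enable": False}}},
    name="MY_W8_ONLY",  # hypothetical name
)

# Recipes now hash and compare by their stable string form instead of a config digest.
assert QuantRecipe(quant_cfg=None) == QuantRecipe(quant_cfg=None)
```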
@@ -154,13 +176,12 @@ class QuantRecipeHparam(Hparam):

def __init__(
self,
choices: Sequence[QuantRecipe],
original: QuantRecipe | None = None,
choices: Sequence[QuantRecipe] | None = None,
nn_modules: list[nn.Module] | None = None,
) -> None:
"""Initializes Hparam with original value and choices."""
choices = sorted(set(choices) | {QuantRecipe(quant_cfg=None)})
super().__init__(choices, original)
choices = sorted({*(choices if choices else []), QuantRecipe(quant_cfg=None)})
super().__init__(choices, original=choices[0])
self.nn_modules = nn_modules if nn_modules else []

# This is a hack; we don't want to make the input_quantizer, weight_quantizer, output_quantizer
@@ -253,14 +274,15 @@ class AutoQuantizeSearcher(BaseSearcher):
def default_search_config(self):
"""Get the default config for the searcher."""
return {
"quantization_formats": [NVFP4_DEFAULT_CFG, FP8_DEFAULT_CFG],
"quantization_formats": ["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"],
"data_loader": None,
"forward_step": None,
"loss_func": None,
"forward_backward_step": None,
"num_calib_steps": 512,
"num_score_steps": 128,
"deployment": None,
"disabled_layers": None,
"verbose": is_master(),
"checkpoint": None,
}
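For illustration, a config override exercising the new keys; the keys mirror `default_search_config` above, and the pattern value is hypothetical:

```python
# Hypothetical override of the searcher defaults shown above.
search_config = {
    "quantization_formats": ["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"],
    "disabled_layers": ["*lm_head*"],  # fnmatch-style patterns; matching modules are not searched
    "num_calib_steps": 512,
    "num_score_steps": 128,
    "verbose": True,
}
```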
@@ -271,6 +293,7 @@ def default_state_dict(self) -> SearchStateDict:
return {
"candidate_stats": defaultdict(dict),
"best": {"recipe": {}, "constraints": {}, "score": float("inf"), "is_satisfied": False},
"constraints": {},
}

def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
@@ -297,15 +320,19 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:

@staticmethod
def _is_auto_quantize_module(module):
return is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
return (
is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
) and isinstance(module, QuantModule)

@staticmethod
def _get_search_recipes(quantization_formats):
return sorted(
[
QuantRecipe(quant_cfg=q, quant_format_idx=i)
for i, q in enumerate(quantization_formats)
]
{
QuantRecipe(quant_cfg=q[0], name=q[1])
if isinstance(q, tuple)
else QuantRecipe(quant_cfg=q)
for q in quantization_formats
}
)

@classmethod
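The formats list may now mix named built-ins with `(config_dict, name)` tuples; a sketch, with the custom dict and its name hypothetical:

```python
quantization_formats = [
    "FP8_DEFAULT_CFG",  # resolved via getattr on modelopt.torch.quantization.config
    (
        {"quant_cfg": {"*weight_quantizer": {"num_bits": 4}, "default": {"enable": False}}},
        "MY_W4_ONLY",  # hypothetical custom format name
    ),
]
recipes = AutoQuantizeSearcher._get_search_recipes(quantization_formats)
# The set comprehension deduplicates recipes; sorted() orders them by compression.
```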
@@ -337,7 +364,7 @@ def forward_backward_step(model, data):
def _estimate_auto_quantize_scores(self):
# TODO: remove the no-quant recipe
def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
module.quant_recipe = QuantRecipe(quant_cfg=None, quant_format_idx=None)
module.quant_recipe = QuantRecipe(quant_cfg=None)
output = module._forward_original(input, *args, **kwargs)

# If gradient checkpointing is enabled, gradient will not be enabled in the global forward pass.
@@ -372,6 +399,9 @@ def backward_hook(module, grad_input, grad_output):
del module.output_diff_dict

def setup_params_for_score_estimation(name, param, params_metadata):
# Let us delete the gradients as soon as they are computed to save memory
# In addition, this method enables gradients for all parameters
# This is needed to make sure that re-entrant activation checkpointing works
params_metadata[name] = {"requires_grad": param.requires_grad}
param.requires_grad = True
accum_grad, handle = create_param_grad_clear_hook(param)
@@ -394,15 +424,17 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
params_metadata[name]["handle"].remove()

for name, module in self.model.named_modules():
if self._is_auto_quantize_module(module):
if (
self._is_auto_quantize_module(module)
and module.get_hparam("quant_recipe").is_configurable
):
# Monkey patch the forward methods to cache Y(Q(W), Q(X)) - Y(W,X)
setup_module_for_score_estimation(module)

params_metadata = {}
for name, param in self.model.named_parameters():
# Let us delete the gradient as soon as they are computed to save memory
# In addition, this method enables gradient for all parameters
# This is needed to make sure the re-entrant activation checkpointing works
# TODO: Enabling gradients for all parameters is not needed and makes backward slow
# We need to enable gradients only for the first parameter of the module, such as embedding weights
setup_params_for_score_estimation(name, param, params_metadata)

gc.collect()
@@ -420,7 +452,10 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
report_memory("AutoQuantize: After score estimation")

for name, module in self.model.named_modules():
if self._is_auto_quantize_module(module):
if (
self._is_auto_quantize_module(module)
and module.get_hparam("quant_recipe").is_configurable
):
cleanup_module_after_score_estimation(module)

for name, param in self.model.named_parameters():
@@ -431,15 +466,29 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
gc.collect()

@classmethod
def insert_hparams_after_merge_rules(cls, model, quant_recipes):
def insert_hparams_after_merge_rules(cls, model, quant_recipes, disabled_layers=None):
"""Restrict the search space using the merge rules and insert the hparams for the model."""
# TRTLLM fuses linear layers such as q_proj, k_proj, v_proj into the same layer
# Hence we need to restrict the search space so that all these layers share the same recipe
# Let's group the modules based on the rules and insert the same hparam for all the modules in the group
search_map: dict[str, list[nn.Module]] = {}

if disabled_layers is None:
disabled_layers = []
elif isinstance(disabled_layers, str):
disabled_layers = [disabled_layers]

search_map: dict[str, list[tuple[nn.Module, bool]]] = {}
for name, module in model.named_modules():
if not cls._is_auto_quantize_module(module):
continue

# Skip layers that match disabled_layers patterns
disabled = False
for pattern in disabled_layers:
if fnmatch.fnmatch(name, pattern):
disabled = True
break

prefix = name
for rule in cls.rules:
pattern = re.compile(rule)
@@ -449,15 +498,17 @@ def insert_hparams_after_merge_rules(cls, model, quant_recipes):
# We support only one rule for matching per module
break
if prefix not in search_map:
search_map[prefix] = [module]
search_map[prefix] = [(module, disabled)]
else:
search_map[prefix].append(module)

for prefix, modules in search_map.items():
hparam = QuantRecipeHparam(
quant_recipes,
original=quant_recipes[0],
nn_modules=modules,
search_map[prefix].append((module, disabled))

for prefix, module_info_list in search_map.items():
modules = [module for module, _ in module_info_list]
disabled = any(disabled for _, disabled in module_info_list)
hparam = (
QuantRecipeHparam(None, nn_modules=modules)
if disabled
else QuantRecipeHparam(quant_recipes, nn_modules=modules)
)
for module in modules:
module._register_hparam("quant_recipe", hparam)
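The disabling logic is plain `fnmatch` over the full module name, and a disabled group gets `QuantRecipeHparam(None, ...)`, whose only choice is the no-quant recipe, presumably leaving it non-configurable so the scoring and search loops skip it. A small runnable sketch with illustrative module names:

```python
import fnmatch

patterns = ["*lm_head*", "model.layers.0.*"]  # illustrative disabled_layers value
for name in ["model.lm_head", "model.layers.0.self_attn.q_proj", "model.layers.1.mlp.fc1"]:
    disabled = any(fnmatch.fnmatch(name, p) for p in patterns)
    print(name, "->", "disabled" if disabled else "searchable")
# model.lm_head -> disabled
# model.layers.0.self_attn.q_proj -> disabled
# model.layers.1.mlp.fc1 -> searchable
```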
@@ -495,7 +546,9 @@ def before_search(self):

search_recipes = self._get_search_recipes(self.config["quantization_formats"])
self._verify_constraint(search_recipes)
self.insert_hparams_after_merge_rules(self.model, search_recipes)
self.insert_hparams_after_merge_rules(
self.model, search_recipes, self.config["disabled_layers"]
)

QuantRecipe.disable_folding_pqs_to_weights()

@@ -557,18 +610,11 @@ def get_total_weight_size(modules):
for module in modules
)

def _get_constraints_for_search(lower_bound=None):
total_model_weight_size = get_total_weight_size(self.model.modules())

upper_bound = self._get_formatted_weight_compression_constraint()

if lower_bound:
lower_bound = lower_bound * upper_bound

def _get_constraints_for_search(max_weight_size, lower_bound=None):
constraints = {
"weight_size_after_compression": (
lower_bound * total_model_weight_size if lower_bound else lower_bound,
upper_bound * total_model_weight_size,
lower_bound * max_weight_size if lower_bound else lower_bound,
max_weight_size,
)
}
return constraints, "weight_size_after_compression"
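A worked example of the arithmetic feeding this helper, assuming `_get_formatted_weight_compression_constraint` returns effective-bits / 16 (numbers illustrative):

```python
# Worked example (illustrative): 7e9 weights with an "effective 6 bits" target.
total_weight_size = 7.0e9
compression = 6.0 / 16                             # assumed formatted-constraint value
max_weight_size = total_weight_size * compression  # 2.625e9 weight elements

lower_bound = 0.99  # optional; None disables the lower bound (see the LP fallback below)
constraints = {
    "weight_size_after_compression": (
        lower_bound * max_weight_size if lower_bound else lower_bound,  # 2.59875e9
        max_weight_size,
    )
}
```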
@@ -579,16 +625,20 @@ def _get_constraints_for_search(lower_bound=None):
f"Got {self.constraints.keys()}"
)

search_recipes = self._get_search_recipes(self.config["quantization_formats"])
for name, hparam in named_hparams(self.model, configurable=True):
compression = self._get_formatted_weight_compression_constraint()
total_weight_size = get_total_weight_size(self.model.modules())
weight_size_after_compression = total_weight_size * compression

for name, hparam in named_hparams(self.model, unique=True):
if not isinstance(hparam, QuantRecipeHparam):
continue

formats, scores, costs = [], [], []
prev_score = float("inf")
for recipe in search_recipes:
for recipe in hparam.choices:
formats.append(recipe)
score = hparam.importance[recipe]
cost = get_total_weight_size(hparam.nn_modules) * recipe.compression
cost = get_total_weight_size(hparam.nn_modules) * recipe.compression # type: ignore [union-attr]

# Let's get the score across Data Parallel (DP) and Tensor Parallel (TP) groups
# This way we constrain the same quantization format for the same layer across the DP/TP groups
@@ -602,6 +652,7 @@ def _get_constraints_for_search(lower_bound=None):
scores.append(min(score, prev_score))
costs.append(cost)
prev_score = score

self.candidate_stats[name]["formats"] = formats
self.candidate_stats[name]["scores"] = scores
self.candidate_stats[name]["costs"] = costs
@@ -611,7 +662,9 @@ def _get_constraints_for_search(lower_bound=None):
# specified. I don't know why this happens.
# As a workaround, let's specify a lower bound for the weight compression if the previous
# search without a lower bound fails.
constraints, constraint_name = _get_constraints_for_search(lower_bound)
constraints, constraint_name = _get_constraints_for_search(
weight_size_after_compression, lower_bound
)

lps = LPS(
name="AutoQuantize",
@@ -664,8 +717,14 @@ def _get_constraints_for_search(lower_bound=None):
f"AutoQuantize best recipe for {name.replace('.quant_recipe', '')}: {best_recipe[name]}"
)

effective_bits_from_search = (best_constraints / total_weight_size) * 16
if verbose:
print_rank_0(
f"AutoQuantize effective bits from search: {effective_bits_from_search: .2f}"
)

self.best["recipe"] = best_recipe
self.best["constraints"] = {constraint_name: best_constraints}
self.best["constraints"] = {"effective_bits": effective_bits_from_search}
self.best["score"] = best_scores

QuantRecipe.fold_pqs_to_weights(self.model)
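A worked example of the new effective-bits report, with illustrative numbers:

```python
best_constraints = 2.6e9   # compressed weight size selected by the LP search
total_weight_size = 7.0e9  # total weight size used for the constraint
effective_bits_from_search = (best_constraints / total_weight_size) * 16
print(f"AutoQuantize effective bits from search: {effective_bits_from_search: .2f}")
# AutoQuantize effective bits from search:  5.94
```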