
Commit 5e68f15

refactor autoquant to support multiple constraints

Signed-off-by: h-guo18 <[email protected]>
1 parent be95a10

File tree

1 file changed: +152 -90 lines

modelopt/torch/quantization/algorithms.py

Lines changed: 152 additions & 90 deletions
@@ -44,6 +44,19 @@
 from .utils import is_quantized_linear, multi_context


+def _get_total_weight_size(modules):
+    """Helper function to get the total weight size of the modules."""
+    return sum(
+        (module.weight.numel() if AutoQuantizeSearcher._is_auto_quantize_module(module) else 0)
+        for module in modules
+    )
+
+
+def _get_estimate_latency(modules):
+    """Helper function to get the estimated latency of the modules."""
+    return 42  # TODO: implement this
+
+
 def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float:
     """Estimate the compression ratio of a quantization configuration.

@@ -270,6 +283,33 @@ class AutoQuantizeSearcher(BaseSearcher):
         r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$",  # dbrx experts
     ]

+    # Registry that maps a user-facing constraint (e.g. effective_bits) to the actual search cost (e.g. weight_size).
+    # Each entry defines:
+    #   - cost_name: the name of the cost metric for the constraint.
+    #   - cost_fn: function to compute the cost for the actual search.
+    #   - cost_upper_bound_fn: maps the user-specified constraint value to the cost upper bound used during search.
+    #   - cost_to_constraint_value_fn: reverse mapping from a cost back to the user-specified constraint value.
+    # Add new constraints here as needed for additional search objectives.
+    constraint_registry: dict[str, dict[str, str | Callable]] = {
+        "effective_bits": {
+            "cost_name": "weight_size_after_compression",
+            "cost_fn": lambda modules, recipe: _get_total_weight_size(modules) * recipe.compression,
+            "cost_upper_bound_fn": lambda modules, effective_bits: _get_total_weight_size(modules)
+            * (effective_bits / 16.0),
+            "cost_to_constraint_value_fn": lambda modules, weight_size: weight_size
+            / _get_total_weight_size(modules)
+            * 16.0,
+        },
+        "linear_speedup": {
+            "cost_name": "latency_after_compression",
+            "cost_fn": lambda modules, recipe: _get_estimate_latency(modules),
+            "cost_upper_bound_fn": lambda modules, linear_speedup: _get_estimate_latency(modules)
+            / linear_speedup,
+            "cost_to_constraint_value_fn": lambda modules, latency: _get_estimate_latency(modules)
+            / latency,
+        },
+    }
+
     @property
     def default_search_config(self):
         """Get the default config for the searcher."""
@@ -513,21 +553,21 @@ def insert_hparams_after_merge_rules(cls, model, quant_recipes, disabled_layers=
         for module in modules:
             module._register_hparam("quant_recipe", hparam)

-    def _get_formatted_weight_compression_constraint(self):
-        effective_bits = self.constraints["effective_bits"]
-        assert effective_bits > 0 and effective_bits <= 16, (
-            "effective_bits should be between 0 and 16."
-        )
-        weight_compression = self.constraints["effective_bits"] / 16.0
-
-        return weight_compression
-
     def _verify_constraint(self, search_recipes):
-        assert self.constraints["effective_bits"] >= search_recipes[0].num_bits, (
-            f"The effective_bits {self.constraints['effective_bits']} constraint cannot be lower than the "
-            f"num_bits of most aggressive quantization format for this search which is "
-            f"{search_recipes[0]} whose num_bits = {search_recipes[0].num_bits}."
-        )
+        for constraint_name in self.constraints:
+            assert constraint_name in self.constraint_registry, (
+                f"Constraint {constraint_name} is not supported. "
+                f"Supported constraints are {self.constraint_registry.keys()}"
+            )
+        if "effective_bits" in self.constraints:
+            assert self.constraints["effective_bits"] >= search_recipes[0].num_bits, (
+                f"The effective_bits {self.constraints['effective_bits']} constraint cannot be lower than the "
+                f"num_bits of most aggressive quantization format for this search which is "
+                f"{search_recipes[0]} whose num_bits = {search_recipes[0].num_bits}."
+            )
+            assert (
+                self.constraints["effective_bits"] > 0 and self.constraints["effective_bits"] <= 16
+            ), "effective_bits should be between 0 and 16."

     def _run_func(self, func, num_iters=1, desc=""):
         for i, data in tqdm(
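A rough sketch of what the new validation accepts and rejects, assuming `searcher` is a configured AutoQuantizeSearcher and `search_recipes` its sorted recipe list (both names are stand-ins, not part of this commit):

# Sketch only; `searcher` and `search_recipes` are assumed stand-ins.
searcher.constraints = {"effective_bits": 4.5}
searcher._verify_constraint(search_recipes)  # passes, if the most aggressive recipe has num_bits <= 4.5

searcher.constraints = {"effective_bits": 8.0, "linear_speedup": 2.0}
searcher._verify_constraint(search_recipes)  # passes: both keys are in constraint_registry

searcher.constraints = {"kv_cache_bits": 8}
searcher._verify_constraint(search_recipes)  # raises AssertionError: unregistered constraint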
@@ -537,13 +577,74 @@ def _run_func(self, func, num_iters=1, desc=""):
         ):
             func(self.model, data)

+    def _make_scores_monotonic(self, scores):
+        """Ensure that the scores are monotonically decreasing for the correctness of the LPS solver."""
+        monotonic_scores = []
+        prev_score = float("inf")
+        for score in scores:
+            score = min(score, prev_score)
+            monotonic_scores.append(score)
+            prev_score = score
+        return monotonic_scores
+
+    def _populate_candidate_stats(self):
+        """Populate self.candidate_stats with scores, costs, etc. for the candidate quantization recipes."""
+        for name, hparam in named_hparams(self.model, unique=True):
+            if not isinstance(hparam, QuantRecipeHparam):
+                continue
+
+            scores = []
+            for recipe in hparam.choices:
+                score = hparam.importance[recipe]
+                # Let's sync the score across Data Parallel (DP) and Tensor Parallel (TP) groups.
+                # This way we constrain a layer to the same quantization format across the DP/TP groups.
+                # The cost we use here is weight size; it is the same across DP/TP groups.
+                _ps = self.model.get_submodule(name.split(".quant_recipe")[0]).parallel_state
+                # The score is the sum of the scores across DP and TP groups
+                scores.append(
+                    DistributedProcessGroup.get_dist_syncd_obj(
+                        score, [_ps.data_parallel_group, _ps.tensor_parallel_group], sum
+                    )
+                )
+
+            self.candidate_stats[name]["recipes"] = hparam.choices
+            self.candidate_stats[name]["scores"] = self._make_scores_monotonic(scores)
+            for constraint_name in self.constraint_registry:
+                cost_name: str
+                cost_name, cost_fn = (
+                    self.constraint_registry[constraint_name]["cost_name"],
+                    self.constraint_registry[constraint_name]["cost_fn"],
+                )
+                self.candidate_stats[name][cost_name] = [
+                    cost_fn(hparam.nn_modules, recipe) for recipe in hparam.choices
+                ]
+
+    def _get_search_constraints(self, user_constraints, lower_bound=None):
+        """Convert user constraints (e.g. effective_bits) to search constraints (e.g. weight_size_after_compression)."""
+        search_constraints, constraints_to_candidate_costs = {}, {}
+        for constraint_name, constraint_value in user_constraints.items():
+            cost_name: str
+            cost_name = self.constraint_registry[constraint_name]["cost_name"]
+            cost_upper_bound = self.constraint_registry[constraint_name]["cost_upper_bound_fn"](
+                self.model.modules(), constraint_value
+            )
+            search_constraints[cost_name] = (
+                lower_bound * cost_upper_bound if lower_bound else lower_bound,
+                cost_upper_bound,
+            )
+            constraints_to_candidate_costs[cost_name] = [
+                candidate_stat[cost_name] for candidate_stat in self.candidate_stats.values()
+            ]
+        return search_constraints, constraints_to_candidate_costs
+
     def before_search(self):
         """Prepare the model for search by calibrating the quantizers and collecting ``AutoQuantize`` score."""
         # Import here to avoid circular import
         from modelopt.torch.quantization.model_quant import calibrate

         super().before_search()

+        self.verbose = self.config["verbose"]
         search_recipes = self._get_search_recipes(self.config["quantization_formats"])
         self._verify_constraint(search_recipes)
         self.insert_hparams_after_merge_rules(
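The monotonic clamp is a pure list transform, so its effect can be checked in isolation. A standalone rendition of _make_scores_monotonic for illustration:

def make_scores_monotonic(scores):
    """Standalone copy of the method above, for illustration only."""
    monotonic_scores, prev_score = [], float("inf")
    for score in scores:
        score = min(score, prev_score)
        monotonic_scores.append(score)
        prev_score = score
    return monotonic_scores

# Recipes are ordered from most to least aggressive quantization, so scores
# should be non-increasing; any upward bump is clamped to the running minimum.
assert make_scores_monotonic([3.0, 5.0, 2.0]) == [3.0, 3.0, 2.0]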
@@ -597,93 +698,43 @@ def forward_loop(model):
         ):
             self._estimate_auto_quantize_scores()

+        # Populate self.candidate_stats with scores, costs, etc. for the search
+        self._populate_candidate_stats()
+
     def run_search(self):
-        """Search for the best per-layer quantization configuration and return the best model and configuration.
+        """Search for the best per-layer quantization configuration and produce the selected recipe for each layer.

         AutoQuantize uses a Linear Programming Solver to find the optimal quantization configuration which
         minimizes the sum of per-layer auto_quantize scores while meeting the specified constraint.
         """
-
-        def get_total_weight_size(modules):
-            return sum(
-                (module.weight.numel() if self._is_auto_quantize_module(module) else 0)
-                for module in modules
-            )
-
-        def _get_constraints_for_search(max_weight_size, lower_bound=None):
-            constraints = {
-                "weight_size_after_compression": (
-                    lower_bound * max_weight_size if lower_bound else lower_bound,
-                    max_weight_size,
-                )
-            }
-            return constraints, "weight_size_after_compression"
-
-        verbose = self.config["verbose"]
-        assert len(self.constraints) == 1 and "effective_bits" in self.constraints, (
-            f"`constraints` must contain only 'effective_bits' constraint. "
-            f"Got {self.constraints.keys()}"
-        )
-
-        compression = self._get_formatted_weight_compression_constraint()
-        total_weight_size = get_total_weight_size(self.model.modules())
-        weight_size_after_compression = total_weight_size * compression
-
-        for name, hparam in named_hparams(self.model, unique=True):
-            if not isinstance(hparam, QuantRecipeHparam):
-                continue
-
-            formats, scores, costs = [], [], []
-            prev_score = float("inf")
-            for recipe in hparam.choices:
-                formats.append(recipe)
-                score = hparam.importance[recipe]
-                cost = get_total_weight_size(hparam.nn_modules) * recipe.compression  # type: ignore [union-attr]
-
-                # Lets get the score across Data Parallel (DP) and Tensor Parallel (TP) groups
-                # This way we constraint the same quantization format for the same layer across the DP/TP groups
-                # The cost we use here is weight size. They are the same across DP/TP groups.
-                _ps = self.model.get_submodule(name.split(".quant_recipe")[0]).parallel_state
-                # The score is the sum of the scores across DP and TP groups
-                score = DistributedProcessGroup.get_dist_syncd_obj(
-                    score, [_ps.data_parallel_group, _ps.tensor_parallel_group], sum
-                )
-
-                scores.append(min(score, prev_score))
-                costs.append(cost)
-                prev_score = score
-
-            self.candidate_stats[name]["formats"] = formats
-            self.candidate_stats[name]["scores"] = scores
-            self.candidate_stats[name]["costs"] = costs
-
         for lower_bound in [None, 0.99, 0.90]:
             # The LP solver for auto_quantize sometimes fails to find a solution if a lower bound is not
             # specified. I don't know why this happens.
             # As a workaround, let's specify a lower bound for the weight compression if the previous
             # search without a lower bound fails.
-            constraints, constraint_name = _get_constraints_for_search(
-                weight_size_after_compression, lower_bound
+
+            # Convert user-specified constraints (e.g. effective_bits) to
+            # actual search constraints (e.g. weight_size_after_compression) and corresponding bounds.
+            search_constraints, constraints_to_candidate_costs = self._get_search_constraints(
+                self.constraints, lower_bound
             )

             lps = LPS(
                 name="AutoQuantize",
-                constraints=constraints,
-                constraints_to_candidate_costs={
-                    constraint_name: [
-                        candidate_stat["costs"] for candidate_stat in self.candidate_stats.values()
-                    ]
-                },
+                constraints=search_constraints,
+                constraints_to_candidate_costs=constraints_to_candidate_costs,
                 candidate_scores=[
                     candidate_stat["scores"] for candidate_stat in self.candidate_stats.values()
                 ],
                 objective_type="minimize",
-                verbose=verbose,
+                verbose=self.verbose,
             )
-            selections, self.status = lps()
+            self.selections, self.status = lps()
             if self.status == "Optimal":
                 break

+    def after_search(self):
+        """Post-process the searched selections and produce the best model and configuration."""
         self.best = {}

         if self.status != "Optimal":
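To make the retry loop concrete, this is the arithmetic the registry performs for an effective_bits constraint; the weight count below is an assumed example value, not taken from the commit:

# Assumed example: a model with 1e9 weight elements and effective_bits = 8.0.
total_weight_size = 1_000_000_000
effective_bits = 8.0

# cost_upper_bound_fn for "effective_bits": constraint value -> cost upper bound.
cost_upper_bound = total_weight_size * (effective_bits / 16.0)  # 5.0e8 weight elements

# _get_search_constraints then pairs it with the optional lower bound:
for lower_bound in [None, 0.99, 0.90]:
    bounds = (lower_bound * cost_upper_bound if lower_bound else lower_bound, cost_upper_bound)
    print(bounds)  # (None, 5.0e8), then (4.95e8, 5.0e8), then (4.5e8, 5.0e8)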
@@ -695,9 +746,10 @@ def _get_constraints_for_search(max_weight_size, lower_bound=None):
         self.best["is_satisfied"] = True

         best_recipe = {}
-        best_constraints, best_scores = 0, 0
-        for name, selected_idx in zip(self.candidate_stats.keys(), selections):
-            best_recipe_for_name = self.candidate_stats[name]["formats"][selected_idx]
+        best_scores = 0
+        best_constraints = dict.fromkeys(self.constraint_registry, 0)
+        for name, selected_idx in zip(self.candidate_stats.keys(), self.selections):
+            best_recipe_for_name = self.candidate_stats[name]["recipes"][selected_idx]

             # LP solver could give different solutions for the same layer across DP/TP groups even though
             # the scores and costs are the same. Let's make sure the same quantization format is selected across DP/TP
@@ -710,21 +762,31 @@ def _get_constraints_for_search(max_weight_size, lower_bound=None):

             best_recipe[name] = best_recipe_for_name
             get_hparam(self.model, name).active = best_recipe_for_name
-            best_constraints += self.candidate_stats[name]["costs"][selected_idx]
+            for constraint_name in self.constraint_registry:
+                cost_name: str
+                cost_name = self.constraint_registry[constraint_name]["cost_name"]
+                best_constraints[constraint_name] += self.candidate_stats[name][cost_name][
+                    selected_idx
+                ]
             best_scores += self.candidate_stats[name]["scores"][selected_idx]
-            if verbose:
+            if self.verbose:
                 print_rank_0(
                     f"AutoQuantize best recipe for {name.replace('.quant_recipe', '')}: {best_recipe[name]}"
                 )

-        effective_bits_from_search = (best_constraints / total_weight_size) * 16
-        if verbose:
-            print_rank_0(
-                f"AutoQuantize effective bits from search: {effective_bits_from_search: .2f}"
-            )
+        # Map the search constraints (e.g. weight_size_after_compression) back
+        # to the user-specified constraints (e.g. effective_bits)
+        for constraint_name in self.constraint_registry:
+            best_constraints[constraint_name] = self.constraint_registry[constraint_name][
+                "cost_to_constraint_value_fn"
+            ](self.model.modules(), best_constraints[constraint_name])
+            if self.verbose:
+                print_rank_0(
+                    f"AutoQuantize {constraint_name} from search: {best_constraints[constraint_name]}"
+                )

         self.best["recipe"] = best_recipe
-        self.best["constraints"] = {"effective_bits": effective_bits_from_search}
+        self.best["constraints"] = best_constraints
         self.best["score"] = best_scores

         QuantRecipe.fold_pqs_to_weights(self.model)
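For the reverse direction used in after_search, the summed per-layer costs map back to a user-facing value; again with assumed example numbers rather than values from the commit:

# Assumed example: the LP solution selected recipes whose compressed weight
# sizes sum to 3.2e8 on a model with 1e9 weight elements.
total_weight_size = 1_000_000_000
best_weight_size = 320_000_000

# cost_to_constraint_value_fn for "effective_bits":
effective_bits_from_search = best_weight_size / total_weight_size * 16.0
print(effective_bits_from_search)  # 5.12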
