Commit 96c364a

feat: new constraint in autoquant

Signed-off-by: h-guo18 <[email protected]>

1 parent 00a7e60 commit 96c364a

File tree

1 file changed: +116 -64 lines changed

modelopt/torch/quantization/algorithms.py

Lines changed: 116 additions & 64 deletions
@@ -18,6 +18,7 @@
 import gc
 import hashlib
 import json
+import random
 import types
 import warnings
 from collections import defaultdict
@@ -44,6 +45,31 @@
 from .utils import is_quantized_linear, multi_context


+def get_total_weight_size(modules):
+    """Get the total weight size of the modules."""
+    return sum(
+        (module.weight.numel() if AutoQuantizeSearcher._is_auto_quantize_module(module) else 0)
+        for module in modules
+    )
+
+
+def get_base_time_cost(modules):
+    """Get the base time cost of the modules."""
+    return 42  # TODO: Implement this
+
+
+def get_total_linear_time(modules):
+    """Get the total linear time of the modules."""
+    return sum(
+        (
+            get_base_time_cost(modules)
+            if AutoQuantizeSearcher._is_auto_quantize_module(module)
+            else 0
+        )
+        for module in modules
+    )
+
+
 def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float:
     """Estimate the compression ratio of a quantization configuration.
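
For orientation: these new module-level helpers feed the LP constraints added later in this commit. get_total_weight_size tallies the weight elements of quantizable modules, while get_total_linear_time still rests on the stubbed get_base_time_cost. Below is a minimal sketch of the weight-size tally on a toy model, using a stand-in for AutoQuantizeSearcher._is_auto_quantize_module (assumed here to accept plain nn.Linear layers):

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))

    def is_auto_quantize_module(m):
        # Stand-in for AutoQuantizeSearcher._is_auto_quantize_module; the real
        # predicate may be stricter than "any nn.Linear".
        return isinstance(m, nn.Linear)

    total = sum(
        m.weight.numel() if is_auto_quantize_module(m) else 0 for m in model.modules()
    )
    print(total)  # 16*32 + 32*8 = 768 weight elements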
@@ -221,6 +247,22 @@ def importance(self) -> dict:
             for quant_recipe, importance_dict in self._importance_dict.items()
         }

+    @property
+    def weight_sizes(self) -> list[float]:
+        """Return the weight size for each candidate quantization recipe."""
+        return [
+            get_total_weight_size(self.nn_modules) * quant_recipe.compression
+            for quant_recipe in self.choices
+        ]
+
+    @property
+    def time_costs(self) -> list[float]:
+        """Return the time cost for each candidate quantization recipe."""
+        return [
+            get_base_time_cost(self.nn_modules) * (0.1 + 0.9 * random.random())
+            for quant_recipe in self.choices
+        ]
+

 class AutoQuantizeSearcher(BaseSearcher):
     """A searcher for AutoQuantize algorithm.
@@ -238,7 +280,7 @@ class AutoQuantizeSearcher(BaseSearcher):
     for other models such as ResNet.
     """

-    candidate_stats: dict[str, dict[str, list[float]]]
+    candidate_stats: dict[str, dict[str, list]]
     best: dict[str, Any]
     gradient_checkpointing_enable_contexts: list[tuple[Callable, Callable]] = []
@@ -462,21 +504,19 @@ def insert_hparams_after_merge_rules(cls, model, quant_recipes):
         for module in modules:
             module._register_hparam("quant_recipe", hparam)

-    def _get_formatted_weight_compression_constraint(self):
-        effective_bits = self.constraints["effective_bits"]
-        assert effective_bits > 0 and effective_bits <= 16, (
-            "effective_bits should be between 0 and 16."
-        )
-        weight_compression = self.constraints["effective_bits"] / 16.0
-
-        return weight_compression
-
     def _verify_constraint(self, search_recipes):
         assert self.constraints["effective_bits"] >= search_recipes[0].num_bits, (
             f"The effective_bits {self.constraints['effective_bits']} constraint cannot be lower than the "
             f"num_bits of most aggressive quantization format for this search which is "
             f"{search_recipes[0]} whose num_bits = {search_recipes[0].num_bits}."
         )
+        assert (
+            self.constraints["effective_bits"] > 0 and self.constraints["effective_bits"] <= 16
+        ), "effective_bits should be between 0 and 16."
+        assert set(self.constraints) <= {"effective_bits", "linear_speedup"}, (
+            f"`constraints` may contain only 'effective_bits' and 'linear_speedup'. "
+            f"Got {self.constraints.keys()}"
+        )

     def _run_func(self, func, num_iters=1, desc=""):
         for i, data in tqdm(
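
For reference, the effective_bits constraint translates into a weight-compression upper bound of effective_bits / 16 (relative to 16-bit weights), which _get_constraints_kwargs later in this diff applies to the total weight size. A worked example with an illustrative value:

    constraints = {"effective_bits": 8.0}
    assert 0 < constraints["effective_bits"] <= 16
    upper_bound = constraints["effective_bits"] / 16.0
    print(upper_bound)  # 0.5: compressed weights may use at most half the 16-bit size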
@@ -544,51 +584,15 @@ def forward_loop(model):
         ):
             self._estimate_auto_quantize_scores()

-    def run_search(self):
-        """Search for the best per-layer quantization configuration and return the best model and configuration.
-
-        AutoQuantize uses Linear Programming Solver to find the optimal quantization configuration which
-        minimizes the sum of per-layer auto_quantize scores while meeting the specified constraint.
-        """
-
-        def get_total_weight_size(modules):
-            return sum(
-                (module.weight.numel() if self._is_auto_quantize_module(module) else 0)
-                for module in modules
-            )
-
-        def _get_constraints_for_search(lower_bound=None):
-            total_model_weight_size = get_total_weight_size(self.model.modules())
-
-            upper_bound = self._get_formatted_weight_compression_constraint()
-
-            if lower_bound:
-                lower_bound = lower_bound * upper_bound
-
-            constraints = {
-                "weight_size_after_compression": (
-                    lower_bound * total_model_weight_size if lower_bound else lower_bound,
-                    upper_bound * total_model_weight_size,
-                )
-            }
-            return constraints, "weight_size_after_compression"
-
-        verbose = self.config["verbose"]
-        assert len(self.constraints) == 1 and "effective_bits" in self.constraints, (
-            f"`constraints` must contain only 'effective_bits' constraint. "
-            f"Got {self.constraints.keys()}"
-        )
-
-        search_recipes = self._get_search_recipes(self.config["quantization_formats"])
+    def _get_candidate_stats(self):
+        """Calculate the candidate stats including choices, scores, weight sizes, and time costs."""
         for name, hparam in named_hparams(self.model, configurable=True):
             if not isinstance(hparam, QuantRecipeHparam):
                 continue
-            formats, scores, costs = [], [], []
+            scores = []
             prev_score = float("inf")
-            for recipe in search_recipes:
-                formats.append(recipe)
+            for recipe in hparam.choices:
                 score = hparam.importance[recipe]
-                cost = get_total_weight_size(hparam.nn_modules) * recipe.compression

                 # Lets get the score across Data Parallel (DP) and Tensor Parallel (TP) groups
                 # This way we constraint the same quantization format for the same layer across the DP/TP groups
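
The scoring loop that begins here (and continues in the next hunk) applies a running clamp: each recipe's score is capped by the immediately preceding recipe's raw score, so a higher-precision recipe is never scored worse than the more aggressive recipe visited just before it. With made-up numbers:

    raw_scores = [3.0, 5.0, 1.0]  # illustrative scores, most aggressive recipe first
    scores, prev_score = [], float("inf")
    for score in raw_scores:
        scores.append(min(score, prev_score))  # clamp by the previous raw score
        prev_score = score
    print(scores)  # [3.0, 3.0, 1.0]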
@@ -600,27 +604,71 @@ def _get_constraints_for_search(lower_bound=None):
                 )

                 scores.append(min(score, prev_score))
-                costs.append(cost)
                 prev_score = score
-            self.candidate_stats[name]["formats"] = formats
-            self.candidate_stats[name]["scores"] = scores
-            self.candidate_stats[name]["costs"] = costs
+
+            self.candidate_stats[name] = {
+                "choices": list(hparam.choices),
+                "scores": scores,
+                "weight_sizes": hparam.weight_sizes,
+                "time_costs": hparam.time_costs,
+            }
+        return self.candidate_stats
+
+    def _get_constraints_kwargs(self, lower_bound=None):
+        """Build the constraint bounds and the per-constraint candidate cost lists."""
+        constraints, constraints_to_candidate_costs = {}, {}
+
+        if "effective_bits" in self.constraints:
+            upper_bound = self.constraints["effective_bits"] / 16.0
+            if lower_bound:
+                lower_bound = lower_bound * upper_bound
+            constraints["total_weight_size"] = (
+                lower_bound * get_total_weight_size(self.model.modules())
+                if lower_bound
+                else lower_bound,
+                upper_bound * get_total_weight_size(self.model.modules()),
+            )
+            constraints_to_candidate_costs["total_weight_size"] = [
+                candidate_stat["weight_sizes"] for candidate_stat in self.candidate_stats.values()
+            ]
+
+        if "linear_speedup" in self.constraints:
+            upper_bound = self.constraints["linear_speedup"]
+            if lower_bound:
+                lower_bound = lower_bound * upper_bound
+            constraints["total_linear_time"] = (
+                (1 / lower_bound) * get_total_linear_time(self.model.modules())
+                if lower_bound
+                else lower_bound,
+                (1 / upper_bound) * get_total_linear_time(self.model.modules()),
+            )
+            constraints_to_candidate_costs["total_linear_time"] = [
+                candidate_stat["time_costs"] for candidate_stat in self.candidate_stats.values()
+            ]
+
+        return constraints, constraints_to_candidate_costs
+
+    def run_search(self):
+        """Search for the best per-layer quantization configuration and return the best model and configuration.
+
+        AutoQuantize uses a Linear Programming solver to find the optimal quantization configuration which
+        minimizes the sum of per-layer auto_quantize scores while meeting the specified constraints.
+        """
+        verbose = self.config["verbose"]
+
+        self.candidate_stats = self._get_candidate_stats()

         for lower_bound in [None, 0.99, 0.90]:
             # The LP solver for auto_quantize sometimes fails to find a solution if a lower bound is not
             # specified. I dont know why this happens.
             # As a workaround, lets specify a lower bound for the weight compression if previous
             # search without lower bound fails.
-            constraints, constraint_name = _get_constraints_for_search(lower_bound)
+            constraints, constraints_to_candidate_costs = self._get_constraints_kwargs(lower_bound)

             lps = LPS(
                 name="AutoQuantize",
                 constraints=constraints,
-                constraints_to_candidate_costs={
-                    constraint_name: [
-                        candidate_stat["costs"] for candidate_stat in self.candidate_stats.values()
-                    ]
-                },
+                constraints_to_candidate_costs=constraints_to_candidate_costs,
                 candidate_scores=[
                     candidate_stat["scores"] for candidate_stat in self.candidate_stats.values()
                 ],
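
The LPS call receives one score list and one cost list per layer, with one entry per candidate recipe; the solver picks one index per layer to minimize the total score while keeping each constraint's summed cost inside its (lower, upper) bounds, where a None lower bound means unconstrained from below (as the code above implies). A hypothetical shape of those inputs, with made-up numbers (LPS itself is ModelOpt's internal linear-programming solver):

    candidate_scores = [
        [0.9, 0.4, 0.0],  # layer_0 score per recipe, most aggressive first
        [2.5, 1.1, 0.0],  # layer_1
    ]
    constraints_to_candidate_costs = {
        "total_weight_size": [
            [192.0, 384.0, 768.0],    # layer_0 weight size per recipe
            [512.0, 1024.0, 2048.0],  # layer_1
        ],
    }
    constraints = {"total_weight_size": (None, 1408.0)}  # (lower, upper) bounds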
@@ -642,9 +690,9 @@ def _get_constraints_for_search(lower_bound=None):
             self.best["is_satisfied"] = True

         best_recipe = {}
-        best_constraints, best_scores = 0, 0
+        best_weight_size, best_linear_time, best_scores = 0, 0, 0
         for name, selected_idx in zip(self.candidate_stats.keys(), selections):
-            best_recipe_for_name = self.candidate_stats[name]["formats"][selected_idx]
+            best_recipe_for_name = self.candidate_stats[name]["choices"][selected_idx]

             # LP solver could give different solutions for the same layer across DP/TP groups even though
             # the scores and costs are the same. Lets make sure the same quantization format is selected across DP/TP
@@ -657,15 +705,19 @@

             best_recipe[name] = best_recipe_for_name
             get_hparam(self.model, name).active = best_recipe_for_name
-            best_constraints += self.candidate_stats[name]["costs"][selected_idx]
+            best_weight_size += self.candidate_stats[name]["weight_sizes"][selected_idx]
+            best_linear_time += self.candidate_stats[name]["time_costs"][selected_idx]
             best_scores += self.candidate_stats[name]["scores"][selected_idx]
             if verbose:
                 print_rank_0(
                     f"AutoQuantize best recipe for {name.replace('.quant_recipe', '')}: {best_recipe[name]}"
                 )

         self.best["recipe"] = best_recipe
-        self.best["constraints"] = {constraint_name: best_constraints}
+        self.best["constraints"] = {
+            "total_weight_size": best_weight_size,
+            "total_linear_time": best_linear_time,
+        }
         self.best["score"] = best_scores

         QuantRecipe.fold_pqs_to_weights(self.model)
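
End to end, the new constraint would be exercised through the public auto_quantize entry point. A hedged usage sketch: the 'linear_speedup' key is what this commit adds (read from the code as a target speedup over the unquantized linear layers); the remaining keyword names and config constants follow existing ModelOpt examples, may differ by version, and are assumptions here, as are the predefined `model` and `calib_loader`:

    import modelopt.torch.quantization as mtq

    model, search_state = mtq.auto_quantize(
        model,  # assumed: an already-constructed torch.nn.Module
        constraints={"effective_bits": 4.8, "linear_speedup": 1.5},
        quantization_formats=[mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG],
        data_loader=calib_loader,  # assumed: a calibration DataLoader
        forward_step=lambda m, batch: m(batch),
    )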
