242 changes: 152 additions & 90 deletions modelopt/torch/quantization/algorithms.py
@@ -44,6 +44,19 @@
from .utils import is_quantized_linear, multi_context


def _get_total_weight_size(modules):
"""Helper function to get the total weight size of the modules."""
return sum(
(module.weight.numel() if AutoQuantizeSearcher._is_auto_quantize_module(module) else 0)
for module in modules
)


def _get_estimate_latency(modules):
"""Helper function to get the estimated latency of the modules."""
return 42 # TODO: implement this


def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float:
"""Estimate the compression ratio of a quantization configuration.

@@ -270,6 +283,33 @@ class AutoQuantizeSearcher(BaseSearcher):
r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$", # dbrx experts
]

# Registry that maps a user input constraint (e.g. effective_bits) to actual search costs (e.g. weight_size).
# Each entry defines:
# - cost_name: The name of the cost metric for the constraint.
# - cost_fn: Function to compute the cost for the actual search.
# - cost_upper_bound_fn: Mapping from user-specified constraint value to cost upper bound during search.
# - cost_to_constraint_value_fn: Reverse mapping from cost to user-specified constraint value.
# Add new constraints here as needed for additional search objectives.
constraint_registry: dict[str, dict[str, str | Callable]] = {
"effective_bits": {
"cost_name": "weight_size_after_compression",
"cost_fn": lambda modules, recipe: _get_total_weight_size(modules) * recipe.compression,
"cost_upper_bound_fn": lambda modules, effective_bits: _get_total_weight_size(modules)
* (effective_bits / 16.0),
"cost_to_contraint_value_fn": lambda modules, weight_size: weight_size
/ _get_total_weight_size(modules)
* 16.0,
},
"linear_speedup": {
"cost_name": "latency_after_compression",
"cost_fn": lambda modules, recipe: _get_estimate_latency(modules),
"cost_upper_bound_fn": lambda modules, linear_speedup: _get_estimate_latency(modules)
/ linear_speedup,
"cost_to_contraint_value_fn": lambda modules, latency: _get_estimate_latency(modules)
/ latency,
},
}
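# For illustration, a new constraint would follow the same four-field shape as the entries above.
# A hypothetical sketch (the "activation_bits" key and the _get_total_activation_size helper are
# assumptions for illustration, not code that exists in this module):
#
#     "activation_bits": {
#         "cost_name": "activation_size_after_compression",
#         "cost_fn": lambda modules, recipe: _get_total_activation_size(modules) * recipe.compression,
#         "cost_upper_bound_fn": lambda modules, activation_bits: _get_total_activation_size(modules)
#         * (activation_bits / 16.0),
#         "cost_to_constraint_value_fn": lambda modules, size: size
#         / _get_total_activation_size(modules)
#         * 16.0,
#     },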

@property
def default_search_config(self):
"""Get the default config for the searcher."""
@@ -513,21 +553,21 @@ def insert_hparams_after_merge_rules(cls, model, quant_recipes, disabled_layers=
for module in modules:
module._register_hparam("quant_recipe", hparam)

def _get_formatted_weight_compression_constraint(self):
effective_bits = self.constraints["effective_bits"]
assert effective_bits > 0 and effective_bits <= 16, (
"effective_bits should be between 0 and 16."
)
weight_compression = self.constraints["effective_bits"] / 16.0

return weight_compression

def _verify_constraint(self, search_recipes):
assert self.constraints["effective_bits"] >= search_recipes[0].num_bits, (
f"The effective_bits {self.constraints['effective_bits']} constraint cannot be lower than the "
f"num_bits of most aggressive quantization format for this search which is "
f"{search_recipes[0]} whose num_bits = {search_recipes[0].num_bits}."
)
for constraint_name in self.constraints:
assert constraint_name in self.constraint_registry, (
f"Constraint {constraint_name} is not supported. "
f"Supported constraints are {self.constraint_registry.keys()}"
)
if "effective_bits" in self.constraints:
assert self.constraints["effective_bits"] >= search_recipes[0].num_bits, (
f"The effective_bits {self.constraints['effective_bits']} constraint cannot be lower than the "
f"num_bits of most aggressive quantization format for this search which is "
f"{search_recipes[0]} whose num_bits = {search_recipes[0].num_bits}."
)
assert (
self.constraints["effective_bits"] > 0 and self.constraints["effective_bits"] <= 16
), "effective_bits should be between 0 and 16."

def _run_func(self, func, num_iters=1, desc=""):
for i, data in tqdm(
@@ -537,13 +577,74 @@ def _run_func(self, func, num_iters=1, desc=""):
):
func(self.model, data)

def _make_scores_monotonic(self, scores):
"""Ensure that the scores are monotonically decreasing for the correctness of LPS solver."""
monotonic_scores = []
prev_score = float("inf")
for score in scores:
score = min(score, prev_score)
monotonic_scores.append(score)
prev_score = score
return monotonic_scores
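# Illustrative example for _make_scores_monotonic (made-up scores):
# [5.0, 7.0, 3.0] -> [5.0, 5.0, 3.0]; each score is clamped by its predecessor so the
# sequence over the recipe choices never increases, which the LPS formulation relies on.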

def _populate_candidate_stats(self):
"""Populate the self.candidate_stats with scores, costs, etc. for the candidate quantization recipes."""
for name, hparam in named_hparams(self.model, unique=True):
if not isinstance(hparam, QuantRecipeHparam):
continue

scores = []
for recipe in hparam.choices:
score = hparam.importance[recipe]
# Let's get the score across Data Parallel (DP) and Tensor Parallel (TP) groups.
# This way we constrain the same quantization format for the same layer across the DP/TP groups.
# The cost we use here is weight size, which is the same across DP/TP groups.
_ps = self.model.get_submodule(name.split(".quant_recipe")[0]).parallel_state
# The score is the sum of the scores across DP and TP groups
scores.append(
DistributedProcessGroup.get_dist_syncd_obj(
score, [_ps.data_parallel_group, _ps.tensor_parallel_group], sum
)
)

self.candidate_stats[name]["recipes"] = hparam.choices
self.candidate_stats[name]["scores"] = self._make_scores_monotonic(scores)
for constraint_name in self.constraint_registry:
cost_name: str
cost_name, cost_fn = (
self.constraint_registry[constraint_name]["cost_name"],
self.constraint_registry[constraint_name]["cost_fn"],
)
self.candidate_stats[name][cost_name] = [
cost_fn(hparam.nn_modules, recipe) for recipe in hparam.choices
]
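# After this method runs, each candidate_stats[name] holds parallel lists indexed by recipe
# choice, e.g. (shape only; values and recipe names are made up):
#     {"recipes": [<int4_recipe>, <fp8_recipe>], "scores": [0.37, 0.12],
#      "weight_size_after_compression": [2.5e8, 5.0e8], "latency_after_compression": [42, 42]}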

def _get_search_constraints(self, user_constraints, lower_bound=None):
"""Convert user constraints (e.g. effective_bits) to search constraints (e.g. weight_size_after_compression)."""
search_constraints, constraints_to_candidate_costs = {}, {}
for constraint_name, constraint_value in user_constraints.items():
cost_name: str
cost_name = self.constraint_registry[constraint_name]["cost_name"]
cost_upper_bound = self.constraint_registry[constraint_name]["cost_upper_bound_fn"](
self.model.modules(), constraint_value
)
search_constraints[cost_name] = (
lower_bound * cost_upper_bound if lower_bound else None,
cost_upper_bound,
)
constraints_to_candidate_costs[cost_name] = [
candidate_stat[cost_name] for candidate_stat in self.candidate_stats.values()
]
return search_constraints, constraints_to_candidate_costs
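# Worked example for the method above (hypothetical numbers): for
# constraints={"effective_bits": 8.0} on a model with 1e9 weight elements,
# cost_upper_bound = 1e9 * (8.0 / 16.0) = 5e8, so with lower_bound=0.99 this returns
# search_constraints = {"weight_size_after_compression": (4.95e8, 5e8)}.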

def before_search(self):
"""Prepare the model for search by calibrating the quantizers and collecting ``AutoQuantize`` score."""
# Import here to avoid circular import
from modelopt.torch.quantization.model_quant import calibrate

super().before_search()

self.verbose = self.config["verbose"]
search_recipes = self._get_search_recipes(self.config["quantization_formats"])
self._verify_constraint(search_recipes)
self.insert_hparams_after_merge_rules(
@@ -597,93 +698,43 @@ def forward_loop(model):
):
self._estimate_auto_quantize_scores()

# Populate self.candidate_stats with scores, costs, etc. for the search
self._populate_candidate_stats()

def run_search(self):
"""Search for the best per-layer quantization configuration and return the best model and configuration.
"""Search for the best per-layer quantization configuration and produce selections of recipes for each layer.

AutoQuantize uses a linear programming (LP) solver to find the optimal quantization configuration that
minimizes the sum of per-layer auto_quantize scores while meeting the specified constraints.
"""

def get_total_weight_size(modules):
return sum(
(module.weight.numel() if self._is_auto_quantize_module(module) else 0)
for module in modules
)

def _get_constraints_for_search(max_weight_size, lower_bound=None):
constraints = {
"weight_size_after_compression": (
lower_bound * max_weight_size if lower_bound else lower_bound,
max_weight_size,
)
}
return constraints, "weight_size_after_compression"

verbose = self.config["verbose"]
assert len(self.constraints) == 1 and "effective_bits" in self.constraints, (
f"`constraints` must contain only 'effective_bits' constraint. "
f"Got {self.constraints.keys()}"
)

compression = self._get_formatted_weight_compression_constraint()
total_weight_size = get_total_weight_size(self.model.modules())
weight_size_after_compression = total_weight_size * compression

for name, hparam in named_hparams(self.model, unique=True):
if not isinstance(hparam, QuantRecipeHparam):
continue

formats, scores, costs = [], [], []
prev_score = float("inf")
for recipe in hparam.choices:
formats.append(recipe)
score = hparam.importance[recipe]
cost = get_total_weight_size(hparam.nn_modules) * recipe.compression # type: ignore [union-attr]

# Let's get the score across Data Parallel (DP) and Tensor Parallel (TP) groups.
# This way we constrain the same quantization format for the same layer across the DP/TP groups.
# The cost we use here is weight size, which is the same across DP/TP groups.
_ps = self.model.get_submodule(name.split(".quant_recipe")[0]).parallel_state
# The score is the sum of the scores across DP and TP groups
score = DistributedProcessGroup.get_dist_syncd_obj(
score, [_ps.data_parallel_group, _ps.tensor_parallel_group], sum
)

scores.append(min(score, prev_score))
costs.append(cost)
prev_score = score

self.candidate_stats[name]["formats"] = formats
self.candidate_stats[name]["scores"] = scores
self.candidate_stats[name]["costs"] = costs

for lower_bound in [None, 0.99, 0.90]:
# The LP solver for auto_quantize sometimes fails to find a solution if a lower bound is not
# specified. I don't know why this happens.
# As a workaround, let's specify a lower bound for the weight compression if the previous
# search without a lower bound fails.
constraints, constraint_name = _get_constraints_for_search(
weight_size_after_compression, lower_bound

# Convert user-specified constraints (e.g. effective_bits) to
# actual search constraints (e.g. weight_size_after_compression) and corresponding bounds.
search_constraints, constraints_to_candidate_costs = self._get_search_constraints(
self.constraints, lower_bound
)

lps = LPS(
name="AutoQuantize",
constraints=constraints,
constraints_to_candidate_costs={
constraint_name: [
candidate_stat["costs"] for candidate_stat in self.candidate_stats.values()
]
},
constraints=search_constraints,
constraints_to_candidate_costs=constraints_to_candidate_costs,
candidate_scores=[
candidate_stat["scores"] for candidate_stat in self.candidate_stats.values()
],
objective_type="minimize",
verbose=verbose,
verbose=self.verbose,
)
selections, self.status = lps()
self.selections, self.status = lps()
if self.status == "Optimal":
break

def after_search(self):
"""Post-process the searched selections and produce the best model and configuration."""
self.best = {}

if self.status != "Optimal":
Expand All @@ -695,9 +746,10 @@ def _get_constraints_for_search(max_weight_size, lower_bound=None):
self.best["is_satisfied"] = True

best_recipe = {}
best_constraints, best_scores = 0, 0
for name, selected_idx in zip(self.candidate_stats.keys(), selections):
best_recipe_for_name = self.candidate_stats[name]["formats"][selected_idx]
best_scores = 0
best_constraints = dict.fromkeys(self.constraint_registry, 0)
for name, selected_idx in zip(self.candidate_stats.keys(), self.selections):
best_recipe_for_name = self.candidate_stats[name]["recipes"][selected_idx]

# LP solver could give different solutions for the same layer across DP/TP groups even though
# the scores and costs are the same. Lets make sure the same quantization format is selected across DP/TP
Expand All @@ -710,21 +762,31 @@ def _get_constraints_for_search(max_weight_size, lower_bound=None):

best_recipe[name] = best_recipe_for_name
get_hparam(self.model, name).active = best_recipe_for_name
best_constraints += self.candidate_stats[name]["costs"][selected_idx]
for constraint_name in self.constraint_registry:
cost_name: str
cost_name = self.constraint_registry[constraint_name]["cost_name"]
best_constraints[constraint_name] += self.candidate_stats[name][cost_name][
selected_idx
]
best_scores += self.candidate_stats[name]["scores"][selected_idx]
if verbose:
if self.verbose:
print_rank_0(
f"AutoQuantize best recipe for {name.replace('.quant_recipe', '')}: {best_recipe[name]}"
)

effective_bits_from_search = (best_constraints / total_weight_size) * 16
if verbose:
print_rank_0(
f"AutoQuantize effective bits from search: {effective_bits_from_search: .2f}"
)
# Map the search constraints (e.g. weight_size_after_compression) back
# to user-specified constraints (e.g. effective_bits)
for constraint_name in self.constraint_registry:
best_constraints[constraint_name] = self.constraint_registry[constraint_name][
"cost_to_contraint_value_fn"
](self.model.modules(), best_constraints[constraint_name])
if self.verbose:
print_rank_0(
f"AutoQuantize {constraint_name} from search: {best_constraints[constraint_name]}"
)

self.best["recipe"] = best_recipe
self.best["constraints"] = {"effective_bits": effective_bits_from_search}
self.best["constraints"] = best_constraints
self.best["score"] = best_scores

QuantRecipe.fold_pqs_to_weights(self.model)