diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py
index 4adac26c..75a245a8 100644
--- a/modelopt/torch/quantization/algorithms.py
+++ b/modelopt/torch/quantization/algorithms.py
@@ -44,6 +44,19 @@
 from .utils import is_quantized_linear, multi_context
 
 
+def _get_total_weight_size(modules):
+    """Helper function to get the total weight size of the modules."""
+    return sum(
+        (module.weight.numel() if AutoQuantizeSearcher._is_auto_quantize_module(module) else 0)
+        for module in modules
+    )
+
+
+def _get_estimate_latency(modules):
+    """Helper function to get the estimated latency of the modules."""
+    return 42  # TODO: implement this
+
+
 def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float:
     """Estimate the compression ratio of a quantization configuration.
 
@@ -270,6 +283,33 @@ class AutoQuantizeSearcher(BaseSearcher):
         r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$",  # dbrx experts
     ]
 
+    # Registry that maps a user-input constraint (e.g. effective_bits) to actual search costs (e.g. weight_size).
+    # Each entry defines:
+    #   - cost_name: The name of the cost metric for the constraint.
+    #   - cost_fn: Function to compute the cost for the actual search.
+    #   - cost_upper_bound_fn: Mapping from user-specified constraint value to cost upper bound during search.
+    #   - cost_to_constraint_value_fn: Reverse mapping from cost to user-specified constraint value.
+    # Add new constraints here as needed for additional search objectives.
+    constraint_registry: dict[str, dict[str, str | Callable]] = {
+        "effective_bits": {
+            "cost_name": "weight_size_after_compression",
+            "cost_fn": lambda modules, recipe: _get_total_weight_size(modules) * recipe.compression,
+            "cost_upper_bound_fn": lambda modules, effective_bits: _get_total_weight_size(modules)
+            * (effective_bits / 16.0),
+            "cost_to_constraint_value_fn": lambda modules, weight_size: weight_size
+            / _get_total_weight_size(modules)
+            * 16.0,
+        },
+        "linear_speedup": {
+            "cost_name": "latency_after_compression",
+            "cost_fn": lambda modules, recipe: _get_estimate_latency(modules),
+            "cost_upper_bound_fn": lambda modules, linear_speedup: _get_estimate_latency(modules)
+            / linear_speedup,
+            "cost_to_constraint_value_fn": lambda modules, latency: _get_estimate_latency(modules)
+            / latency,
+        },
+    }
+
     @property
     def default_search_config(self):
         """Get the default config for the searcher."""
@@ -513,21 +553,21 @@ def insert_hparams_after_merge_rules(cls, model, quant_recipes, disabled_layers=
         for module in modules:
             module._register_hparam("quant_recipe", hparam)
 
-    def _get_formatted_weight_compression_constraint(self):
-        effective_bits = self.constraints["effective_bits"]
-        assert effective_bits > 0 and effective_bits <= 16, (
-            "effective_bits should be between 0 and 16."
-        )
-        weight_compression = self.constraints["effective_bits"] / 16.0
-
-        return weight_compression
-
     def _verify_constraint(self, search_recipes):
-        assert self.constraints["effective_bits"] >= search_recipes[0].num_bits, (
-            f"The effective_bits {self.constraints['effective_bits']} constraint cannot be lower than the "
-            f"num_bits of most aggressive quantization format for this search which is "
-            f"{search_recipes[0]} whose num_bits = {search_recipes[0].num_bits}."
-        )
+        for constraint_name in self.constraints:
+            assert constraint_name in self.constraint_registry, (
+                f"Constraint {constraint_name} is not supported. "
+                f"Supported constraints are {self.constraint_registry.keys()}"
+            )
+        if "effective_bits" in self.constraints:
+            assert self.constraints["effective_bits"] >= search_recipes[0].num_bits, (
+                f"The effective_bits {self.constraints['effective_bits']} constraint cannot be lower than the "
+                f"num_bits of the most aggressive quantization format for this search, which is "
+                f"{search_recipes[0]} whose num_bits = {search_recipes[0].num_bits}."
+            )
+            assert (
+                self.constraints["effective_bits"] > 0 and self.constraints["effective_bits"] <= 16
+            ), "effective_bits should be between 0 and 16."
 
     def _run_func(self, func, num_iters=1, desc=""):
         for i, data in tqdm(
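# Worked example: the registry above boils each constraint down to three pure
# mappings. Below is a minimal, self-contained sketch of the "effective_bits"
# entry's round trip with toy numbers; `toy_weight_sizes` and the 4-bit recipe
# are hypothetical stand-ins, not part of this patch.

def total_weight_size(weight_sizes):
    # Stand-in for _get_total_weight_size: total weight elements across modules.
    return sum(weight_sizes)

toy_weight_sizes = [1 << 20, 1 << 20]  # two toy linear layers, 1M weights each

# cost_upper_bound_fn: effective_bits=8.0 permits at most half of the 16-bit weight volume.
budget = total_weight_size(toy_weight_sizes) * (8.0 / 16.0)

# cost_fn: a recipe with compression 4/16 costs a quarter of the weight volume.
cost = total_weight_size(toy_weight_sizes) * (4.0 / 16.0)

# cost_to_constraint_value_fn: map the achieved cost back to effective bits.
achieved_bits = cost / total_weight_size(toy_weight_sizes) * 16.0

assert cost <= budget and achieved_bits == 4.0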
" + f"Supported constraints are {self.constraint_registry.keys()}" + ) + if "effective_bits" in self.constraints: + assert self.constraints["effective_bits"] >= search_recipes[0].num_bits, ( + f"The effective_bits {self.constraints['effective_bits']} constraint cannot be lower than the " + f"num_bits of most aggressive quantization format for this search which is " + f"{search_recipes[0]} whose num_bits = {search_recipes[0].num_bits}." + ) + assert ( + self.constraints["effective_bits"] > 0 and self.constraints["effective_bits"] <= 16 + ), "effective_bits should be between 0 and 16." def _run_func(self, func, num_iters=1, desc=""): for i, data in tqdm( @@ -537,6 +577,66 @@ def _run_func(self, func, num_iters=1, desc=""): ): func(self.model, data) + def _make_scores_monotonic(self, scores): + """Ensure that the scores are monotonically decreasing for the correctness of LPS solver.""" + monotonic_scores = [] + prev_score = float("inf") + for score in scores: + score = min(score, prev_score) + monotonic_scores.append(score) + prev_score = score + return monotonic_scores + + def _populate_candidate_stats(self): + """Populate the self.candidate_stats with scores, costs, etc. for the candidate quantization recipes.""" + for name, hparam in named_hparams(self.model, unique=True): + if not isinstance(hparam, QuantRecipeHparam): + continue + + scores = [] + for recipe in hparam.choices: + score = hparam.importance[recipe] + # Lets get the score across Data Parallel (DP) and Tensor Parallel (TP) groups + # This way we constraint the same quantization format for the same layer across the DP/TP groups + # The cost we use here is weight size. They are the same across DP/TP groups. + _ps = self.model.get_submodule(name.split(".quant_recipe")[0]).parallel_state + # The score is the sum of the scores across DP and TP groups + scores.append( + DistributedProcessGroup.get_dist_syncd_obj( + score, [_ps.data_parallel_group, _ps.tensor_parallel_group], sum + ) + ) + + self.candidate_stats[name]["recipes"] = hparam.choices + self.candidate_stats[name]["scores"] = self._make_scores_monotonic(scores) + for constraint_name in self.constraint_registry: + cost_name: str + cost_name, cost_fn = ( + self.constraint_registry[constraint_name]["cost_name"], + self.constraint_registry[constraint_name]["cost_fn"], + ) + self.candidate_stats[name][cost_name] = [ + cost_fn(hparam.nn_modules, recipe) for recipe in hparam.choices + ] + + def _get_search_constraints(self, user_constraints, lower_bound=None): + """Convert user constraints (e.g. effective_bits) to search constraints (e.g. 
weight_size_after_compression).""" + search_constraints, constraints_to_candidate_costs = {}, {} + for constraint_name, constraint_value in user_constraints.items(): + cost_name: str + cost_name = self.constraint_registry[constraint_name]["cost_name"] + cost_upper_bound = self.constraint_registry[constraint_name]["cost_upper_bound_fn"]( + self.model.modules(), constraint_value + ) + search_constraints[cost_name] = ( + lower_bound * cost_upper_bound if lower_bound else lower_bound, + cost_upper_bound, + ) + constraints_to_candidate_costs[cost_name] = [ + candidate_stat[cost_name] for candidate_stat in self.candidate_stats.values() + ] + return search_constraints, constraints_to_candidate_costs + def before_search(self): """Prepare the model for search by calibrating the quantizers and collecting ``AutoQuantize`` score.""" # Import here to avoid circular import @@ -544,6 +644,7 @@ def before_search(self): super().before_search() + self.verbose = self.config["verbose"] search_recipes = self._get_search_recipes(self.config["quantization_formats"]) self._verify_constraint(search_recipes) self.insert_hparams_after_merge_rules( @@ -597,93 +698,43 @@ def forward_loop(model): ): self._estimate_auto_quantize_scores() + # Populate self.candidate_stats with scores, costs, etc. for the search + self._populate_candidate_stats() + def run_search(self): - """Search for the best per-layer quantization configuration and return the best model and configuration. + """Search for the best per-layer quantization configuration and produce selections of recipes for each layer. AutoQuantize uses Linear Programming Solver to find the optimal quantization configuration which minimizes the sum of per-layer auto_quantize scores while meeting the specified constraint. """ - - def get_total_weight_size(modules): - return sum( - (module.weight.numel() if self._is_auto_quantize_module(module) else 0) - for module in modules - ) - - def _get_constraints_for_search(max_weight_size, lower_bound=None): - constraints = { - "weight_size_after_compression": ( - lower_bound * max_weight_size if lower_bound else lower_bound, - max_weight_size, - ) - } - return constraints, "weight_size_after_compression" - - verbose = self.config["verbose"] - assert len(self.constraints) == 1 and "effective_bits" in self.constraints, ( - f"`constraints` must contain only 'effective_bits' constraint. " - f"Got {self.constraints.keys()}" - ) - - compression = self._get_formatted_weight_compression_constraint() - total_weight_size = get_total_weight_size(self.model.modules()) - weight_size_after_compression = total_weight_size * compression - - for name, hparam in named_hparams(self.model, unique=True): - if not isinstance(hparam, QuantRecipeHparam): - continue - - formats, scores, costs = [], [], [] - prev_score = float("inf") - for recipe in hparam.choices: - formats.append(recipe) - score = hparam.importance[recipe] - cost = get_total_weight_size(hparam.nn_modules) * recipe.compression # type: ignore [union-attr] - - # Lets get the score across Data Parallel (DP) and Tensor Parallel (TP) groups - # This way we constraint the same quantization format for the same layer across the DP/TP groups - # The cost we use here is weight size. They are the same across DP/TP groups. 
@@ -597,93 +698,43 @@ def forward_loop(model):
         ):
             self._estimate_auto_quantize_scores()
 
+        # Populate self.candidate_stats with scores, costs, etc. for the search
+        self._populate_candidate_stats()
+
     def run_search(self):
-        """Search for the best per-layer quantization configuration and return the best model and configuration.
+        """Search for the best per-layer quantization configuration and produce a selection of recipes for each layer.
 
         AutoQuantize uses Linear Programming Solver to find the optimal quantization configuration which
         minimizes the sum of per-layer auto_quantize scores while meeting the specified constraint.
         """
-
-        def get_total_weight_size(modules):
-            return sum(
-                (module.weight.numel() if self._is_auto_quantize_module(module) else 0)
-                for module in modules
-            )
-
-        def _get_constraints_for_search(max_weight_size, lower_bound=None):
-            constraints = {
-                "weight_size_after_compression": (
-                    lower_bound * max_weight_size if lower_bound else lower_bound,
-                    max_weight_size,
-                )
-            }
-            return constraints, "weight_size_after_compression"
-
-        verbose = self.config["verbose"]
-        assert len(self.constraints) == 1 and "effective_bits" in self.constraints, (
-            f"`constraints` must contain only 'effective_bits' constraint. "
-            f"Got {self.constraints.keys()}"
-        )
-
-        compression = self._get_formatted_weight_compression_constraint()
-        total_weight_size = get_total_weight_size(self.model.modules())
-        weight_size_after_compression = total_weight_size * compression
-
-        for name, hparam in named_hparams(self.model, unique=True):
-            if not isinstance(hparam, QuantRecipeHparam):
-                continue
-
-            formats, scores, costs = [], [], []
-            prev_score = float("inf")
-            for recipe in hparam.choices:
-                formats.append(recipe)
-                score = hparam.importance[recipe]
-                cost = get_total_weight_size(hparam.nn_modules) * recipe.compression  # type: ignore [union-attr]
-
-                # Lets get the score across Data Parallel (DP) and Tensor Parallel (TP) groups
-                # This way we constraint the same quantization format for the same layer across the DP/TP groups
-                # The cost we use here is weight size. They are the same across DP/TP groups.
-                _ps = self.model.get_submodule(name.split(".quant_recipe")[0]).parallel_state
-                # The score is the sum of the scores across DP and TP groups
-                score = DistributedProcessGroup.get_dist_syncd_obj(
-                    score, [_ps.data_parallel_group, _ps.tensor_parallel_group], sum
-                )
-
-                scores.append(min(score, prev_score))
-                costs.append(cost)
-                prev_score = score
-
-            self.candidate_stats[name]["formats"] = formats
-            self.candidate_stats[name]["scores"] = scores
-            self.candidate_stats[name]["costs"] = costs
-
         for lower_bound in [None, 0.99, 0.90]:
             # The LP solver for auto_quantize sometimes fails to find a solution if a lower bound is not
             # specified. I dont know why this happens.
             # As a workaround, lets specify a lower bound for the weight compression if previous
             # search without lower bound fails.
-            constraints, constraint_name = _get_constraints_for_search(
-                weight_size_after_compression, lower_bound
+
+            # Convert user-specified constraints (e.g. effective_bits) to
+            # actual search constraints (e.g. weight_size_after_compression) and corresponding bounds.
+            search_constraints, constraints_to_candidate_costs = self._get_search_constraints(
+                self.constraints, lower_bound
             )
 
             lps = LPS(
                 name="AutoQuantize",
-                constraints=constraints,
-                constraints_to_candidate_costs={
-                    constraint_name: [
-                        candidate_stat["costs"] for candidate_stat in self.candidate_stats.values()
-                    ]
-                },
+                constraints=search_constraints,
+                constraints_to_candidate_costs=constraints_to_candidate_costs,
                 candidate_scores=[
                     candidate_stat["scores"] for candidate_stat in self.candidate_stats.values()
                 ],
                 objective_type="minimize",
-                verbose=verbose,
+                verbose=self.verbose,
             )
-            selections, self.status = lps()
+            self.selections, self.status = lps()
             if self.status == "Optimal":
                 break
 
+    def after_search(self):
+        """Post-process the searched selections and produce the best model and configuration."""
         self.best = {}
 
         if self.status != "Optimal":
@@ -695,9 +746,10 @@ def _get_constraints_for_search(max_weight_size, lower_bound=None):
         self.best["is_satisfied"] = True
 
         best_recipe = {}
-        best_constraints, best_scores = 0, 0
-        for name, selected_idx in zip(self.candidate_stats.keys(), selections):
-            best_recipe_for_name = self.candidate_stats[name]["formats"][selected_idx]
+        best_scores = 0
+        best_constraints = dict.fromkeys(self.constraint_registry, 0)
+        for name, selected_idx in zip(self.candidate_stats.keys(), self.selections):
+            best_recipe_for_name = self.candidate_stats[name]["recipes"][selected_idx]
 
             # LP solver could give different solutions for the same layer across DP/TP groups even though
             # the scores and costs are the same. Lets make sure the same quantization format is selected across DP/TP
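# Worked example: rolling per-layer LP selections up into aggregate costs, as the
# loop below does via candidate_stats. Two toy layers, one cost metric, and a
# selections vector of per-layer indices as the LP solver would return; all
# numbers are hypothetical.

candidate_stats = {
    "layer1.quant_recipe": {"weight_size_after_compression": [25.0, 50.0, 100.0]},
    "layer2.quant_recipe": {"weight_size_after_compression": [50.0, 100.0, 200.0]},
}
selections = [0, 2]  # most aggressive recipe for layer1, no compression for layer2

total_cost = sum(
    stats["weight_size_after_compression"][idx]
    for stats, idx in zip(candidate_stats.values(), selections)
)
assert total_cost == 225.0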
@@ -710,21 +762,31 @@ def _get_constraints_for_search(max_weight_size, lower_bound=None):
             best_recipe[name] = best_recipe_for_name
             get_hparam(self.model, name).active = best_recipe_for_name
 
-            best_constraints += self.candidate_stats[name]["costs"][selected_idx]
+            for constraint_name in self.constraint_registry:
+                cost_name: str
+                cost_name = self.constraint_registry[constraint_name]["cost_name"]
+                best_constraints[constraint_name] += self.candidate_stats[name][cost_name][
+                    selected_idx
+                ]
             best_scores += self.candidate_stats[name]["scores"][selected_idx]
-            if verbose:
+            if self.verbose:
                 print_rank_0(
                     f"AutoQuantize best recipe for {name.replace('.quant_recipe', '')}: {best_recipe[name]}"
                 )
 
-        effective_bits_from_search = (best_constraints / total_weight_size) * 16
-        if verbose:
-            print_rank_0(
-                f"AutoQuantize effective bits from search: {effective_bits_from_search: .2f}"
-            )
+        # Map the search constraints (e.g. weight_size_after_compression) back
+        # to user-specified constraints (e.g. effective_bits)
+        for constraint_name in self.constraint_registry:
+            best_constraints[constraint_name] = self.constraint_registry[constraint_name][
+                "cost_to_constraint_value_fn"
+            ](self.model.modules(), best_constraints[constraint_name])
+            if self.verbose:
+                print_rank_0(
+                    f"AutoQuantize {constraint_name} from search: {best_constraints[constraint_name]}"
+                )
 
         self.best["recipe"] = best_recipe
-        self.best["constraints"] = {"effective_bits": effective_bits_from_search}
+        self.best["constraints"] = best_constraints
         self.best["score"] = best_scores
 
         QuantRecipe.fold_pqs_to_weights(self.model)
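# Worked example: the final reverse mapping above, which turns the summed search
# cost back into a user-facing constraint value via cost_to_constraint_value_fn.
# Both totals below are hypothetical, for illustration only.

total_weight_size = 2_000_000            # weight elements across quantizable modules
weight_size_after_compression = 625_000  # summed per-layer costs from the LP selection

effective_bits = weight_size_after_compression / total_weight_size * 16.0
assert effective_bits == 5.0  # the value reported as effective_bits from search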