from .utils import is_quantized_linear, multi_context


+def _get_total_weight_size(modules):
+    """Helper function to get the total weight size of the modules."""
+    return sum(
+        (module.weight.numel() if AutoQuantizeSearcher._is_auto_quantize_module(module) else 0)
+        for module in modules
+    )
+
+
+def _get_estimate_latency(modules):
+    """Helper function to get the estimated latency of the modules."""
+    return 42  # TODO: implement this
+
+
def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float:
    """Estimate the compression ratio of a quantization configuration.
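For intuition about the cost arithmetic these helpers feed into: a recipe's `compression` is the fraction of the 16-bit baseline weight footprint that survives quantization, so a weight-only cost budget falls directly out of the `effective_bits` constraint. A minimal sketch with invented numbers (assuming `compression ≈ num_bits / 16`, which is what the registry below relies on):

```python
# Invented numbers; assumes a 16-bit baseline so a recipe's `compression`
# is roughly num_bits / 16 (e.g. 0.25 for a 4-bit format).
total_weight_size = 7_000_000_000                    # weight elements in the searched modules
cost_4bit = total_weight_size * (4 / 16)             # 1.75e9: everything at 4 bits
cost_8bit = total_weight_size * (8 / 16)             # 3.5e9: everything at 8 bits
cost_upper_bound = total_weight_size * (6.0 / 16.0)  # budget implied by effective_bits=6.0
print(cost_4bit <= cost_upper_bound <= cost_8bit)    # True: the optimum must mix formats
```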
@@ -270,6 +283,33 @@ class AutoQuantizeSearcher(BaseSearcher):
        r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$",  # dbrx experts
    ]

+    # Registry that maps user-input constraints (e.g. effective_bits) to actual search costs (e.g. weight_size).
+    # Each entry defines:
+    #   - cost_name: The name of the cost metric for the constraint.
+    #   - cost_fn: Function to compute the cost for the actual search.
+    #   - cost_upper_bound_fn: Mapping from a user-specified constraint value to the cost upper bound during search.
+    #   - cost_to_constraint_value_fn: Reverse mapping from a cost to the user-specified constraint value.
+    # Add new constraints here as needed for additional search objectives.
+    constraint_registry: dict[str, dict[str, str | Callable]] = {
+        "effective_bits": {
+            "cost_name": "weight_size_after_compression",
+            "cost_fn": lambda modules, recipe: _get_total_weight_size(modules) * recipe.compression,
+            "cost_upper_bound_fn": lambda modules, effective_bits: _get_total_weight_size(modules)
+            * (effective_bits / 16.0),
+            "cost_to_constraint_value_fn": lambda modules, weight_size: weight_size
+            / _get_total_weight_size(modules)
+            * 16.0,
+        },
+        "linear_speedup": {
+            "cost_name": "latency_after_compression",
+            "cost_fn": lambda modules, recipe: _get_estimate_latency(modules),
+            "cost_upper_bound_fn": lambda modules, linear_speedup: _get_estimate_latency(modules)
+            / linear_speedup,
+            "cost_to_constraint_value_fn": lambda modules, latency: _get_estimate_latency(modules)
+            / latency,
+        },
+    }
+
    @property
    def default_search_config(self):
        """Get the default config for the searcher."""
@@ -513,21 +553,21 @@ def insert_hparams_after_merge_rules(cls, model, quant_recipes, disabled_layers=
        for module in modules:
            module._register_hparam("quant_recipe", hparam)

-    def _get_formatted_weight_compression_constraint(self):
-        effective_bits = self.constraints["effective_bits"]
-        assert effective_bits > 0 and effective_bits <= 16, (
-            "effective_bits should be between 0 and 16."
-        )
-        weight_compression = self.constraints["effective_bits"] / 16.0
-
-        return weight_compression
-
    def _verify_constraint(self, search_recipes):
-        assert self.constraints["effective_bits"] >= search_recipes[0].num_bits, (
-            f"The effective_bits {self.constraints['effective_bits']} constraint cannot be lower than the "
-            f"num_bits of most aggressive quantization format for this search which is "
-            f"{search_recipes[0]} whose num_bits = {search_recipes[0].num_bits}."
-        )
+        for constraint_name in self.constraints:
+            assert constraint_name in self.constraint_registry, (
+                f"Constraint {constraint_name} is not supported. "
+                f"Supported constraints are {self.constraint_registry.keys()}"
+            )
+        if "effective_bits" in self.constraints:
+            assert self.constraints["effective_bits"] >= search_recipes[0].num_bits, (
+                f"The effective_bits {self.constraints['effective_bits']} constraint cannot be lower than the "
+                f"num_bits of the most aggressive quantization format for this search, which is "
+                f"{search_recipes[0]} whose num_bits = {search_recipes[0].num_bits}."
+            )
+            assert (
+                self.constraints["effective_bits"] > 0 and self.constraints["effective_bits"] <= 16
+            ), "effective_bits should be between 0 and 16."

    def _run_func(self, func, num_iters=1, desc=""):
        for i, data in tqdm(
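To make the new validation concrete, here are constraint dicts the reworked `_verify_constraint` would accept or reject. The values are invented, and the rejection on the last line assumes the most aggressive recipe in the search is a 4-bit format:

```python
# Invented examples of what _verify_constraint accepts or rejects,
# assuming the most aggressive search recipe has num_bits = 4:
{"effective_bits": 4.8}   # OK: supported key, and 4 <= 4.8 <= 16
{"linear_speedup": 1.5}   # OK: supported via the new registry entry
{"activation_bits": 8}    # AssertionError: not in constraint_registry
{"effective_bits": 2.0}   # AssertionError: below the 4-bit recipe's num_bits
```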
@@ -537,13 +577,74 @@ def _run_func(self, func, num_iters=1, desc=""):
        ):
            func(self.model, data)

+    def _make_scores_monotonic(self, scores):
+        """Ensure that the scores are monotonically decreasing for the correctness of the LPS solver."""
+        monotonic_scores = []
+        prev_score = float("inf")
+        for score in scores:
+            score = min(score, prev_score)
+            monotonic_scores.append(score)
+            prev_score = score
+        return monotonic_scores
+
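A quick sanity check of the running-minimum transform above, with invented scores (written as a standalone copy so it can be run directly):

```python
def make_scores_monotonic(scores):
    # Standalone copy of the running-minimum logic in _make_scores_monotonic.
    monotonic, prev = [], float("inf")
    for s in scores:
        s = min(s, prev)
        monotonic.append(s)
        prev = s
    return monotonic

# Invented per-layer scores, indexed from least to most aggressive recipe:
print(make_scores_monotonic([5.0, 7.0, 3.0, 4.0]))  # [5.0, 5.0, 3.0, 3.0]
```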
+    def _populate_candidate_stats(self):
+        """Populate self.candidate_stats with scores, costs, etc. for the candidate quantization recipes."""
+        for name, hparam in named_hparams(self.model, unique=True):
+            if not isinstance(hparam, QuantRecipeHparam):
+                continue
+
+            scores = []
+            for recipe in hparam.choices:
+                score = hparam.importance[recipe]
+                # Let's get the score across Data Parallel (DP) and Tensor Parallel (TP) groups.
+                # This way we constrain the same quantization format for the same layer across the DP/TP groups.
+                # The cost we use here is weight size, which is the same across DP/TP groups.
+                _ps = self.model.get_submodule(name.split(".quant_recipe")[0]).parallel_state
+                # The score is the sum of the scores across DP and TP groups.
+                scores.append(
+                    DistributedProcessGroup.get_dist_syncd_obj(
+                        score, [_ps.data_parallel_group, _ps.tensor_parallel_group], sum
+                    )
+                )
+
+            self.candidate_stats[name]["recipes"] = hparam.choices
+            self.candidate_stats[name]["scores"] = self._make_scores_monotonic(scores)
+            for constraint_name in self.constraint_registry:
+                cost_name: str
+                cost_name, cost_fn = (
+                    self.constraint_registry[constraint_name]["cost_name"],
+                    self.constraint_registry[constraint_name]["cost_fn"],
+                )
+                self.candidate_stats[name][cost_name] = [
+                    cost_fn(hparam.nn_modules, recipe) for recipe in hparam.choices
+                ]
+
+    def _get_search_constraints(self, user_constraints, lower_bound=None):
+        """Convert user constraints (e.g. effective_bits) to search constraints (e.g. weight_size_after_compression)."""
+        search_constraints, constraints_to_candidate_costs = {}, {}
+        for constraint_name, constraint_value in user_constraints.items():
+            cost_name: str
+            cost_name = self.constraint_registry[constraint_name]["cost_name"]
+            cost_upper_bound = self.constraint_registry[constraint_name]["cost_upper_bound_fn"](
+                self.model.modules(), constraint_value
+            )
+            search_constraints[cost_name] = (
+                lower_bound * cost_upper_bound if lower_bound else lower_bound,
+                cost_upper_bound,
+            )
+            constraints_to_candidate_costs[cost_name] = [
+                candidate_stat[cost_name] for candidate_stat in self.candidate_stats.values()
+            ]
+        return search_constraints, constraints_to_candidate_costs
+
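For orientation, roughly what `_get_search_constraints` produces for a hypothetical model. All sizes are invented; this assumes two searched layers of 500M parameters each, two candidate recipes (8-bit and 4-bit), and `constraints={"effective_bits": 6.0}`:

```python
# Hypothetical output for constraints={"effective_bits": 6.0} on a model whose
# searched weights total 1e9 elements, with lower_bound=0.99:
search_constraints = {
    # (lower, upper) bounds on the summed cost; upper = 1e9 * (6.0 / 16.0)
    "weight_size_after_compression": (0.99 * 3.75e8, 3.75e8),
}
constraints_to_candidate_costs = {
    # One list per layer, one cost per candidate recipe (8-bit, then 4-bit):
    "weight_size_after_compression": [
        [5e8 * 0.5, 5e8 * 0.25],  # layer 1: 500M params
        [5e8 * 0.5, 5e8 * 0.25],  # layer 2: 500M params
    ],
}
```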
    def before_search(self):
        """Prepare the model for search by calibrating the quantizers and collecting ``AutoQuantize`` score."""
        # Import here to avoid circular import
        from modelopt.torch.quantization.model_quant import calibrate

        super().before_search()

+        self.verbose = self.config["verbose"]
        search_recipes = self._get_search_recipes(self.config["quantization_formats"])
        self._verify_constraint(search_recipes)
        self.insert_hparams_after_merge_rules(
@@ -597,93 +698,43 @@ def forward_loop(model):
        ):
            self._estimate_auto_quantize_scores()

+        # Populate self.candidate_stats with scores, costs, etc. for the search
+        self._populate_candidate_stats()
+
    def run_search(self):
-        """Search for the best per-layer quantization configuration and return the best model and configuration.
+        """Search for the best per-layer quantization configuration and produce selections of recipes for each layer.

        AutoQuantize uses a Linear Programming solver to find the optimal quantization configuration which
        minimizes the sum of per-layer auto_quantize scores while meeting the specified constraint.
        """
-
-        def get_total_weight_size(modules):
-            return sum(
-                (module.weight.numel() if self._is_auto_quantize_module(module) else 0)
-                for module in modules
-            )
-
-        def _get_constraints_for_search(max_weight_size, lower_bound=None):
-            constraints = {
-                "weight_size_after_compression": (
-                    lower_bound * max_weight_size if lower_bound else lower_bound,
-                    max_weight_size,
-                )
-            }
-            return constraints, "weight_size_after_compression"
-
-        verbose = self.config["verbose"]
-        assert len(self.constraints) == 1 and "effective_bits" in self.constraints, (
-            f"`constraints` must contain only 'effective_bits' constraint. "
-            f"Got {self.constraints.keys()}"
-        )
-
-        compression = self._get_formatted_weight_compression_constraint()
-        total_weight_size = get_total_weight_size(self.model.modules())
-        weight_size_after_compression = total_weight_size * compression
-
-        for name, hparam in named_hparams(self.model, unique=True):
-            if not isinstance(hparam, QuantRecipeHparam):
-                continue
-
-            formats, scores, costs = [], [], []
-            prev_score = float("inf")
-            for recipe in hparam.choices:
-                formats.append(recipe)
-                score = hparam.importance[recipe]
-                cost = get_total_weight_size(hparam.nn_modules) * recipe.compression  # type: ignore[union-attr]
-
-                # Lets get the score across Data Parallel (DP) and Tensor Parallel (TP) groups
-                # This way we constraint the same quantization format for the same layer across the DP/TP groups
-                # The cost we use here is weight size. They are the same across DP/TP groups.
-                _ps = self.model.get_submodule(name.split(".quant_recipe")[0]).parallel_state
-                # The score is the sum of the scores across DP and TP groups
-                score = DistributedProcessGroup.get_dist_syncd_obj(
-                    score, [_ps.data_parallel_group, _ps.tensor_parallel_group], sum
-                )
-
-                scores.append(min(score, prev_score))
-                costs.append(cost)
-                prev_score = score
-
-            self.candidate_stats[name]["formats"] = formats
-            self.candidate_stats[name]["scores"] = scores
-            self.candidate_stats[name]["costs"] = costs
-
        for lower_bound in [None, 0.99, 0.90]:
            # The LP solver for auto_quantize sometimes fails to find a solution if a lower bound is not
            # specified. I don't know why this happens.
            # As a workaround, let's specify a lower bound for the weight compression if the previous
            # search without a lower bound fails.
-            constraints, constraint_name = _get_constraints_for_search(
-                weight_size_after_compression, lower_bound
+
+            # Convert user-specified constraints (e.g. effective_bits) to
+            # actual search constraints (e.g. weight_size_after_compression) and corresponding bounds.
+            search_constraints, constraints_to_candidate_costs = self._get_search_constraints(
+                self.constraints, lower_bound
            )

            lps = LPS(
                name="AutoQuantize",
-                constraints=constraints,
-                constraints_to_candidate_costs={
-                    constraint_name: [
-                        candidate_stat["costs"] for candidate_stat in self.candidate_stats.values()
-                    ]
-                },
+                constraints=search_constraints,
+                constraints_to_candidate_costs=constraints_to_candidate_costs,
                candidate_scores=[
                    candidate_stat["scores"] for candidate_stat in self.candidate_stats.values()
                ],
                objective_type="minimize",
-                verbose=verbose,
+                verbose=self.verbose,
            )
-            selections, self.status = lps()
+            self.selections, self.status = lps()
            if self.status == "Optimal":
                break

+    def after_search(self):
+        """Post-process the searched selections and produce the best model and configuration."""
        self.best = {}

        if self.status != "Optimal":
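To see what the LPS call is solving, here is a brute-force toy equivalent with two layers and two recipes each. All numbers are invented, and `LPS` (modelopt's LP wrapper) is not reproduced; this only illustrates the objective and constraint structure:

```python
from itertools import product

# Two layers with two candidate recipes each; all numbers invented.
candidate_scores = [[0.1, 2.0], [0.3, 1.5]]  # per-layer AutoQuantize scores per recipe
candidate_costs = [[8e6, 4e6], [6e6, 3e6]]   # weight_size_after_compression per recipe
lower, upper = None, 12e6                    # bounds from search_constraints

feasible = [
    sel
    for sel in product(range(2), repeat=2)  # one recipe index per layer
    if (lower or 0) <= sum(c[i] for c, i in zip(candidate_costs, sel)) <= upper
]
best = min(feasible, key=lambda sel: sum(s[i] for s, i in zip(candidate_scores, sel)))
print(best)  # (0, 1): keep layer 0 high-precision, compress layer 1 to fit the budget
```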
@@ -695,9 +746,10 @@ def _get_constraints_for_search(max_weight_size, lower_bound=None):
            self.best["is_satisfied"] = True

        best_recipe = {}
-        best_constraints, best_scores = 0, 0
-        for name, selected_idx in zip(self.candidate_stats.keys(), selections):
-            best_recipe_for_name = self.candidate_stats[name]["formats"][selected_idx]
+        best_scores = 0
+        best_constraints = dict.fromkeys(self.constraint_registry, 0)
+        for name, selected_idx in zip(self.candidate_stats.keys(), self.selections):
+            best_recipe_for_name = self.candidate_stats[name]["recipes"][selected_idx]

            # The LP solver could give different solutions for the same layer across DP/TP groups even though
            # the scores and costs are the same. Let's make sure the same quantization format is selected across DP/TP
@@ -710,21 +762,31 @@ def _get_constraints_for_search(max_weight_size, lower_bound=None):

            best_recipe[name] = best_recipe_for_name
            get_hparam(self.model, name).active = best_recipe_for_name
-            best_constraints += self.candidate_stats[name]["costs"][selected_idx]
+            for constraint_name in self.constraint_registry:
+                cost_name: str
+                cost_name = self.constraint_registry[constraint_name]["cost_name"]
+                best_constraints[constraint_name] += self.candidate_stats[name][cost_name][
+                    selected_idx
+                ]
            best_scores += self.candidate_stats[name]["scores"][selected_idx]
-            if verbose:
+            if self.verbose:
                print_rank_0(
                    f"AutoQuantize best recipe for {name.replace('.quant_recipe', '')}: {best_recipe[name]}"
                )

-        effective_bits_from_search = (best_constraints / total_weight_size) * 16
-        if verbose:
-            print_rank_0(
-                f"AutoQuantize effective bits from search: {effective_bits_from_search:.2f}"
-            )
+        # Map the search constraints (e.g. weight_size_after_compression) back
+        # to user-specified constraints (e.g. effective_bits)
+        for constraint_name in self.constraint_registry:
+            best_constraints[constraint_name] = self.constraint_registry[constraint_name][
+                "cost_to_constraint_value_fn"
+            ](self.model.modules(), best_constraints[constraint_name])
+            if self.verbose:
+                print_rank_0(
+                    f"AutoQuantize {constraint_name} from search: {best_constraints[constraint_name]}"
+                )

        self.best["recipe"] = best_recipe
-        self.best["constraints"] = {"effective_bits": effective_bits_from_search}
+        self.best["constraints"] = best_constraints
        self.best["score"] = best_scores

        QuantRecipe.fold_pqs_to_weights(self.model)
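As a worked check of the reverse mapping for `effective_bits` (numbers invented), `cost_to_constraint_value_fn` applied by hand:

```python
# cost_to_constraint_value_fn for "effective_bits", evaluated manually.
total_weight_size = 7_000_000_000      # sum of weight.numel() over searched modules
weight_size_after_compression = 2.1e9  # summed best-selection cost from the LP solution
effective_bits = weight_size_after_compression / total_weight_size * 16.0
print(f"{effective_bits:.2f}")  # 4.80
```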