import gc
import hashlib
import json
+ import random
import types
import warnings
from collections import defaultdict
@@ -44,6 +45,31 @@
from .utils import is_quantized_linear, multi_context


+ def get_total_weight_size(modules):
+     """Get the total weight size of the modules."""
+     return sum(
+         (module.weight.numel() if AutoQuantizeSearcher._is_auto_quantize_module(module) else 0)
+         for module in modules
+     )
+
+
+ def get_base_time_cost(modules):
+     """Get the base time cost of the modules."""
+     return 42  # TODO: Implement this
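+     # NOTE: placeholder constant; presumably to be replaced by measured per-module latency (see TODO above).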
+
+
+ def get_total_linear_time(modules):
+     """Get the total linear time of the modules."""
+     return sum(
+         (
+             get_base_time_cost(modules)
+             if AutoQuantizeSearcher._is_auto_quantize_module(module)
+             else 0
+         )
+         for module in modules
+     )
+
+
def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float:
    """Estimate the compression ratio of a quantization configuration.
@@ -221,6 +247,22 @@ def importance(self) -> dict:
            for quant_recipe, importance_dict in self._importance_dict.items()
        }

+     @property
+     def weight_sizes(self) -> list[float]:
+         """Return the compressed weight size for each quantization recipe choice."""
+         return [
+             get_total_weight_size(self.nn_modules) * quant_recipe.compression
+             for quant_recipe in self.choices
+         ]
+
+     @property
+     def time_costs(self) -> list[float]:
+         """Return the estimated time cost for each quantization recipe choice."""
+         return [
+             get_base_time_cost(self.nn_modules) * (0.1 + 0.9 * random.random())
+             for quant_recipe in self.choices
+         ]
+

class AutoQuantizeSearcher(BaseSearcher):
    """A searcher for AutoQuantize algorithm.
@@ -238,7 +280,7 @@ class AutoQuantizeSearcher(BaseSearcher):
    for other models such as ResNet.
    """

-     candidate_stats: dict[str, dict[str, list[float]]]
+     candidate_stats: dict[str, dict[str, list]]
    best: dict[str, Any]
    gradient_checkpointing_enable_contexts: list[tuple[Callable, Callable]] = []

@@ -462,21 +504,19 @@ def insert_hparams_after_merge_rules(cls, model, quant_recipes):
        for module in modules:
            module._register_hparam("quant_recipe", hparam)

-     def _get_formatted_weight_compression_constraint(self):
-         effective_bits = self.constraints["effective_bits"]
-         assert effective_bits > 0 and effective_bits <= 16, (
-             "effective_bits should be between 0 and 16."
-         )
-         weight_compression = self.constraints["effective_bits"] / 16.0
-
-         return weight_compression
-
    def _verify_constraint(self, search_recipes):
        assert self.constraints["effective_bits"] >= search_recipes[0].num_bits, (
            f"The effective_bits {self.constraints['effective_bits']} constraint cannot be lower than the "
            f"num_bits of the most aggressive quantization format for this search, which is "
            f"{search_recipes[0]} whose num_bits = {search_recipes[0].num_bits}."
        )
+         assert (
+             self.constraints["effective_bits"] > 0 and self.constraints["effective_bits"] <= 16
+         ), "effective_bits should be between 0 and 16."
+         assert "effective_bits" in self.constraints and set(self.constraints) <= {"effective_bits", "linear_speedup"}, (
+             f"`constraints` must contain 'effective_bits' and may optionally include 'linear_speedup'. "
+             f"Got {self.constraints.keys()}"
+         )

    def _run_func(self, func, num_iters=1, desc=""):
        for i, data in tqdm(
@@ -544,51 +584,15 @@ def forward_loop(model):
        ):
            self._estimate_auto_quantize_scores()

-     def run_search(self):
-         """Search for the best per-layer quantization configuration and return the best model and configuration.
-
-         AutoQuantize uses Linear Programming Solver to find the optimal quantization configuration which
-         minimizes the sum of per-layer auto_quantize scores while meeting the specified constraint.
-         """
-
-         def get_total_weight_size(modules):
-             return sum(
-                 (module.weight.numel() if self._is_auto_quantize_module(module) else 0)
-                 for module in modules
-             )
-
-         def _get_constraints_for_search(lower_bound=None):
-             total_model_weight_size = get_total_weight_size(self.model.modules())
-
-             upper_bound = self._get_formatted_weight_compression_constraint()
-
-             if lower_bound:
-                 lower_bound = lower_bound * upper_bound
-
-             constraints = {
-                 "weight_size_after_compression": (
-                     lower_bound * total_model_weight_size if lower_bound else lower_bound,
-                     upper_bound * total_model_weight_size,
-                 )
-             }
-             return constraints, "weight_size_after_compression"
-
-         verbose = self.config["verbose"]
-         assert len(self.constraints) == 1 and "effective_bits" in self.constraints, (
-             f"`constraints` must contain only 'effective_bits' constraint. "
-             f"Got {self.constraints.keys()}"
-         )
-
-         search_recipes = self._get_search_recipes(self.config["quantization_formats"])
+     def _get_candidate_stats(self):
+         """Collect per-hparam candidate stats: recipe choices, scores, weight sizes, and time costs."""
        for name, hparam in named_hparams(self.model, configurable=True):
            if not isinstance(hparam, QuantRecipeHparam):
                continue
-             formats, scores, costs = [], [], []
+             scores = []
            prev_score = float("inf")
-             for recipe in search_recipes:
-                 formats.append(recipe)
+             for recipe in hparam.choices:
                score = hparam.importance[recipe]
-                 cost = get_total_weight_size(hparam.nn_modules) * recipe.compression

                # Let's get the score across Data Parallel (DP) and Tensor Parallel (TP) groups.
                # This way we constrain the same quantization format for the same layer across the DP/TP groups
@@ -600,27 +604,71 @@ def _get_constraints_for_search(lower_bound=None):
                )

                scores.append(min(score, prev_score))
-                 costs.append(cost)
                prev_score = score
-             self.candidate_stats[name]["formats"] = formats
-             self.candidate_stats[name]["scores"] = scores
-             self.candidate_stats[name]["costs"] = costs
+
+             self.candidate_stats[name] = {
+                 "choices": list(hparam.choices),
+                 "scores": scores,
+                 "weight_sizes": hparam.weight_sizes,
+                 "time_costs": hparam.time_costs,
+             }
+         return self.candidate_stats
+
+     def _get_constraints_kwargs(self, lower_bound=None):
+         """Build the LP constraint bounds and the per-constraint candidate cost lists."""
+         constraints, constraints_to_candidate_costs = {}, {}
+
+         if "effective_bits" in self.constraints:
+             upper_bound = self.constraints["effective_bits"] / 16.0
+             if lower_bound:
+                 lower_bound = lower_bound * upper_bound
+             constraints["total_weight_size"] = (
+                 lower_bound * get_total_weight_size(self.model.modules())
+                 if lower_bound
+                 else lower_bound,
+                 upper_bound * get_total_weight_size(self.model.modules()),
+             )
+             constraints_to_candidate_costs["total_weight_size"] = [
+                 candidate_stat["weight_sizes"] for candidate_stat in self.candidate_stats.values()
+             ]
+
+         if "linear_speedup" in self.constraints:
+             upper_bound = self.constraints["linear_speedup"]
+             if lower_bound:
+                 lower_bound = lower_bound * upper_bound
+             constraints["total_linear_time"] = (
+                 (1 / lower_bound) * get_total_linear_time(self.model.modules())
+                 if lower_bound
+                 else lower_bound,
+                 (1 / upper_bound) * get_total_linear_time(self.model.modules()),
+             )
+             constraints_to_candidate_costs["total_linear_time"] = [
+                 candidate_stat["time_costs"] for candidate_stat in self.candidate_stats.values()
+             ]
+
+         return constraints, constraints_to_candidate_costs
+
+     def run_search(self):
+         """Search for the best per-layer quantization configuration and return the best model and configuration.
+
+         AutoQuantize uses a Linear Programming solver to find the optimal quantization configuration which
+         minimizes the sum of per-layer auto_quantize scores while meeting the specified constraints.
+         """
+         verbose = self.config["verbose"]
+
+         self.candidate_stats = self._get_candidate_stats()

        for lower_bound in [None, 0.99, 0.90]:
            # The LP solver for auto_quantize sometimes fails to find a solution if a lower bound is not
            # specified. I don't know why this happens.
            # As a workaround, let's specify a lower bound for the weight compression if the previous
            # search without a lower bound fails.
-             constraints, constraint_name = _get_constraints_for_search(lower_bound)
+             constraints, constraints_to_candidate_costs = self._get_constraints_kwargs(lower_bound)

            lps = LPS(
                name="AutoQuantize",
                constraints=constraints,
-                 constraints_to_candidate_costs={
-                     constraint_name: [
-                         candidate_stat["costs"] for candidate_stat in self.candidate_stats.values()
-                     ]
-                 },
+                 constraints_to_candidate_costs=constraints_to_candidate_costs,
                candidate_scores=[
                    candidate_stat["scores"] for candidate_stat in self.candidate_stats.values()
                ],
@@ -642,9 +690,9 @@ def _get_constraints_for_search(lower_bound=None):
            self.best["is_satisfied"] = True

        best_recipe = {}
-         best_constraints, best_scores = 0, 0
+         best_weight_size, best_linear_time, best_scores = 0, 0, 0
        for name, selected_idx in zip(self.candidate_stats.keys(), selections):
-             best_recipe_for_name = self.candidate_stats[name]["formats"][selected_idx]
+             best_recipe_for_name = self.candidate_stats[name]["choices"][selected_idx]

            # The LP solver could give different solutions for the same layer across DP/TP groups even though
            # the scores and costs are the same. Let's make sure the same quantization format is selected across DP/TP
@@ -657,15 +705,19 @@ def _get_constraints_for_search(lower_bound=None):

            best_recipe[name] = best_recipe_for_name
            get_hparam(self.model, name).active = best_recipe_for_name
-             best_constraints += self.candidate_stats[name]["costs"][selected_idx]
+             best_weight_size += self.candidate_stats[name]["weight_sizes"][selected_idx]
+             best_linear_time += self.candidate_stats[name]["time_costs"][selected_idx]
            best_scores += self.candidate_stats[name]["scores"][selected_idx]
            if verbose:
                print_rank_0(
                    f"AutoQuantize best recipe for {name.replace('.quant_recipe', '')}: {best_recipe[name]}"
                )

        self.best["recipe"] = best_recipe
-         self.best["constraints"] = {constraint_name: best_constraints}
+         self.best["constraints"] = {
+             "total_weight_size": best_weight_size,
+             "total_linear_time": best_linear_time,
+         }
        self.best["score"] = best_scores

        QuantRecipe.fold_pqs_to_weights(self.model)