@@ -70,36 +70,50 @@ def compress(self, custom_evaluate=None):
         _dynabert(self, self.model, args.output_dir)
         if "ptq" in args.strategy:
             self.args.input_filename_prefix = "pruned_model"
+            output_dir_list = []
             for width_mult in args.width_mult_list:
                 output_dir_width = os.path.join(args.output_dir, "width_mult_" + str(round(width_mult, 2)))
-                self.quant(output_dir_width, "ptq")
-    elif args.strategy == "ptq":
-        # Input model is an inference model
+                output_dir_list += self.quant(output_dir_width, "ptq")
+            if "embeddings" in args.strategy:
+                for output_dir in output_dir_list:
+                    self.quant(os.path.join(output_dir, args.output_filename_prefix), "embeddings")
+    elif "ptq" in args.strategy:
+        # When the input model is an inference model.
         if args.input_infer_model_path is not None:
             model_dir = os.path.dirname(args.input_infer_model_path)
             self.args.input_filename_prefix = os.path.basename(args.input_infer_model_path)
-            self.quant(model_dir, args.strategy)
+            output_dir_list = self.quant(model_dir, "ptq")
         # Input model is loaded from Trainer API in dygraph.
         else:
+            # When the input model is a dygraph model,
+            # export it first and then run 'ptq'.
             # Prefix of `export_model` is 'model'
             self.args.input_filename_prefix = "model"
             input_spec = generate_input_spec(self.model, self.train_dataset)
             input_dir = args.output_dir
             export_model(model=self.model, input_spec=input_spec, path=input_dir)
-            self.quant(input_dir, args.strategy)
-    elif args.strategy == "qat":
+            output_dir_list = self.quant(input_dir, "ptq")
+        if "embeddings" in args.strategy:
+            for output_dir in output_dir_list:
+                self.quant(os.path.join(output_dir, args.output_filename_prefix), "embeddings")
+    elif "qat" in args.strategy:
         global_try_import_slim()
-        self.quant(args.output_dir, args.strategy)
+        self.quant(args.output_dir, "qat")
+        if "embeddings" in args.strategy:
+            self.quant(os.path.join(args.output_dir, args.output_filename_prefix), "embeddings")


 def quant(self, model_dir, strategy):
     """
-    Supports Post-Training Quantization now.
+    Supports Post-Training Quantization, Quantization Aware Training and
+    Embedding Quantization.
     """
     if strategy == "ptq":
-        _post_training_quantization_grid_search(self, model_dir)
+        return _post_training_quantization_grid_search(self, model_dir)
     elif strategy == "qat":
         _quant_aware_training_dynamic(self, model_dir)
+    elif strategy == "embeddings":
+        _quant_embeddings(self, model_dir)


 def generate_input_spec(model, dataset):
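For context, a minimal standalone sketch (plain Python, no PaddleNLP required) of the substring-based dispatch the rewritten compress() above implements: compound strategies such as "dynabert+ptq+embeddings" are matched with `in`, and the directory list returned by the "ptq" pass feeds the "embeddings" pass. The strategy string, the quant_stub helper, and the paths are illustrative assumptions, not the library's API.

import os

def quant_stub(model_dir, strategy):
    # Stands in for Trainer.quant: "ptq" returns the quantized model dirs
    # (one per grid point); "embeddings" quantizes in place and returns None.
    print(f"quant({model_dir!r}, {strategy!r})")
    return [os.path.join(model_dir, "KL4_1")] if strategy == "ptq" else None

strategy = "dynabert+ptq+embeddings"  # compound strategy: plain substring checks
if "ptq" in strategy:
    output_dir_list = []
    for width_mult in [0.75]:
        output_dir_list += quant_stub("width_mult_" + str(width_mult), "ptq")
    if "embeddings" in strategy:
        # Embedding quantization runs on every PTQ output collected above.
        for output_dir in output_dir_list:
            quant_stub(os.path.join(output_dir, "int8"), "embeddings")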
@@ -138,7 +152,7 @@ def _dynabert(self, model, output_dir):
     ofa_model = _dynabert_training(
         self, ofa_model, model, teacher_model, train_dataloader, eval_dataloader, args.num_train_epochs
     )
-
+    self.reset_optimizer_and_scheduler()
     # Each width_mult best model would be exported.
     _dynabert_export(self, ofa_model)

@@ -540,6 +554,7 @@ def _post_training_quantization_grid_search(self, model_dir):
     exe = paddle.static.Executor(place)

     args.output_filename_prefix = "int8"
+    output_dir_list = []

     def _post_training_quantization(algo, batch_size, batch_nums):
         try:
@@ -587,11 +602,13 @@ def _batch_generator_func():
             optimize_model=False,
         )
         post_training_quantization.quantize()
+        save_model_path = os.path.join(model_dir, algo + "_".join([str(batch_size), str(batch_nums)]))
         post_training_quantization.save_quantized_model(
-            save_model_path=os.path.join(model_dir, algo + "_".join([str(batch_size), str(batch_nums)])),
+            save_model_path=save_model_path,
             model_filename=args.output_filename_prefix + ".pdmodel",
             params_filename=args.output_filename_prefix + ".pdiparams",
         )
+        output_dir_list.append(save_model_path)

     logger.info("Post training quantization starts.")
     for algo in args.algo_list:
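A standalone sketch of how the grid search above names its outputs: the algorithm name is concatenated directly to the "batchsize_batchnums" pair, and every resulting path is collected in output_dir_list for compress() to consume. The model_dir and grid values below are illustrative.

import os

model_dir = "output/width_mult_0.75"
output_dir_list = []
for algo in ["KL", "mse"]:
    for batch_size in [4, 8]:
        for batch_nums in [1]:
            save_model_path = os.path.join(model_dir, algo + "_".join([str(batch_size), str(batch_nums)]))
            output_dir_list.append(save_model_path)

print(output_dir_list)
# ['output/width_mult_0.75/KL4_1', 'output/width_mult_0.75/KL8_1',
#  'output/width_mult_0.75/mse4_1', 'output/width_mult_0.75/mse8_1']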
@@ -601,6 +618,7 @@ def _batch_generator_func():

     paddle.disable_static()
     logger.info("Post training quantization ends and quantized models are saved.")
+    return output_dir_list


 def _quant_aware_training_dynamic(self, input_dir):
@@ -725,6 +743,35 @@ def _quant_aware_training_dynamic(self, input_dir):
     logger.info("Quant aware training ends and quantized models are saved.")


+def _quant_embeddings(self, input_prefix):
+    import paddleslim.quant as quant
+
+    self.args.output_filename_prefix = "quant_emb"
+
+    paddle.enable_static()
+    place = paddle.set_device(self.args.device)
+    exe = paddle.static.Executor(place)
+    main_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(input_prefix, exe)
+
+    config = {"quantize_op_types": ["lookup_table_v2"], "lookup_table_v2": {"quantize_type": "log"}}
+
+    quant_emb_program = quant.quant_embedding(main_program, place, config)
+
+    input_dir = os.path.dirname(input_prefix)
+
+    paddle.fluid.io.save_inference_model(
+        input_dir,
+        feed_target_names,
+        fetch_targets,
+        exe,
+        quant_emb_program,
+        model_filename=self.args.output_filename_prefix + ".pdmodel",
+        params_filename=self.args.output_filename_prefix + ".pdiparams",
+        export_for_deployment=True,
+        program_only=False,
+    )
+
+
 def auto_model_dynabert_forward(
     self,
     input_ids,
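A small sketch of the two pieces of the new _quant_embeddings() worth calling out: the PaddleSlim config that log-quantizes every lookup_table_v2 (static-graph embedding lookup) op, and the path handling, since the function receives a file prefix and writes quant_emb.pdmodel / quant_emb.pdiparams beside the input model. The concrete paths below are illustrative assumptions.

import os

# Config passed to paddleslim.quant.quant_embedding in the hunk above.
config = {
    "quantize_op_types": ["lookup_table_v2"],     # embedding lookup ops
    "lookup_table_v2": {"quantize_type": "log"},  # log quantization for them
}

input_prefix = "output/width_mult_0.75/KL4_1/int8"   # illustrative PTQ output prefix
input_dir = os.path.dirname(input_prefix)
print(input_dir)                                     # output/width_mult_0.75/KL4_1
print(os.path.join(input_dir, "quant_emb.pdmodel"))  # saved next to the input model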
@@ -865,5 +912,10 @@ def soft_cross_entropy(inp, target):
     return -1.0 * paddle.mean(paddle.sum(inp_likelihood * target_prob, axis=-1))


+def reset_optimizer_and_scheduler(self):
+    self.optimizer, self.lr_scheduler = None, None
+
+
 Trainer.compress = compress
 Trainer.quant = quant
+Trainer.reset_optimizer_and_scheduler = reset_optimizer_and_scheduler
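The assignments above use the usual monkey-patching idiom: module-level functions that take `self` are attached to an existing class, so importing this module extends Trainer without subclassing. A runnable standalone sketch, with DummyTrainer as an illustrative stand-in:

class DummyTrainer:
    pass

def reset_optimizer_and_scheduler(self):
    # Clearing both forces the next training stage (e.g. QAT after
    # DynaBERT) to rebuild its optimizer and LR scheduler from scratch.
    self.optimizer, self.lr_scheduler = None, None

DummyTrainer.reset_optimizer_and_scheduler = reset_optimizer_and_scheduler

t = DummyTrainer()
t.reset_optimizer_and_scheduler()
assert t.optimizer is None and t.lr_scheduler is None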