@@ -68,15 +68,27 @@ def compress(self, custom_evaluate=None):
            self.custom_evaluate is not None
        ), "Custom model using DynaBERT strategy needs to pass in parameters `custom_evaluate`."
        _dynabert(self, self.model, args.output_dir)
-        if "ptq" in args.strategy:
-            self.args.input_filename_prefix = "pruned_model"
+        if "ptq" in args.strategy or "qat" in args.strategy:
            output_dir_list = []
            for width_mult in args.width_mult_list:
                output_dir_width = os.path.join(args.output_dir, "width_mult_" + str(round(width_mult, 2)))
-                output_dir_list += self.quant(output_dir_width, "ptq")
-        if "embeddings" in args.strategy:
-            for output_dir in output_dir_list:
-                self.quant(os.path.join(output_dir, args.output_filename_prefix), "embeddings")
+                if "ptq" in args.strategy:
+                    output_dir_list += self.quant(output_dir_width, "ptq")
+                elif "qat" in args.strategy:
+                    self.quant(output_dir_width, "qat")
+                    output_dir_list.append(output_dir_width)
+        if "embeddings" in args.strategy:
+            if "ptq" not in args.strategy and "qat" not in args.strategy:
+                output_dir_list = []
+                for width_mult in args.width_mult_list:
+                    output_dir_width = os.path.join(
+                        args.output_dir, "width_mult_" + str(round(width_mult, 2)), args.input_filename_prefix
+                    )
+                    self.quant(output_dir_width, "embeddings")
+            else:
+                for output_dir in output_dir_list:
+                    self.quant(os.path.join(output_dir, args.output_filename_prefix), "embeddings")
+
    elif "ptq" in args.strategy:
        # When input model is an inference model
        if args.input_infer_model_path is not None:
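After this change, a combined strategy such as `dynabert+qat` runs quantization-aware training on each pruned sub-model instead of falling through to PTQ, and `dynabert+embeddings` works even when no intermediate quantization step ran. A minimal sketch of the new routing, runnable on its own; the strategy string, width list, and directory names are illustrative, not part of the patch:

```python
import os

# Hypothetical inputs mirroring self.args; values are for illustration only.
strategy = "dynabert+qat+embeddings"
width_mult_list = [0.75]
output_dir = "compress_output"

output_dir_list = []
for width_mult in width_mult_list:
    output_dir_width = os.path.join(output_dir, "width_mult_" + str(round(width_mult, 2)))
    if "ptq" in strategy:
        print("PTQ:", output_dir_width)  # grid-search PTQ may emit several dirs
    elif "qat" in strategy:
        print("QAT:", output_dir_width)  # QAT emits one dir per width
        output_dir_list.append(output_dir_width)
if "embeddings" in strategy:
    for d in output_dir_list:
        print("Embedding quantization:", d)
```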
@@ -153,8 +165,9 @@ def _dynabert(self, model, output_dir):
        self, ofa_model, model, teacher_model, train_dataloader, eval_dataloader, args.num_train_epochs
    )
    self.reset_optimizer_and_scheduler()
+
    # Each width_mult best model would be exported.
-    _dynabert_export(self, ofa_model)
+    _dynabert_export(self)

    ofa_model, ofa_model.model = _recover_transformer_func(ofa_model, True), _recover_transformer_func(
        ofa_model.model, True
@@ -500,44 +513,89 @@ def _dynabert_training(self, ofa_model, model, teacher_model, train_dataloader,
    return ofa_model


-def _dynabert_export(self, ofa_model):
-    from paddleslim.nas.ofa import utils
+def _get_dynabert_model(model, width_mult):
+    for layer in model.base_model.encoder.layers:
+        # Multi-Head Attention
+        layer.self_attn.num_heads = int(layer.self_attn.num_heads * width_mult)
+        layer.self_attn.q_proj = nn.Linear(
+            layer.self_attn.q_proj.weight.shape[0],
+            int(layer.self_attn.q_proj.weight.shape[1] * width_mult),
+            layer.self_attn.q_proj._weight_attr,
+            layer.self_attn.q_proj._bias_attr,
+        )
+        layer.self_attn.k_proj = nn.Linear(
+            layer.self_attn.k_proj.weight.shape[0],
+            int(layer.self_attn.k_proj.weight.shape[1] * width_mult),
+            layer.self_attn.k_proj._weight_attr,
+            layer.self_attn.k_proj._bias_attr,
+        )
+        layer.self_attn.v_proj = nn.Linear(
+            layer.self_attn.v_proj.weight.shape[0],
+            int(layer.self_attn.v_proj.weight.shape[1] * width_mult),
+            layer.self_attn.v_proj._weight_attr,
+            layer.self_attn.v_proj._bias_attr,
+        )
+        layer.self_attn.out_proj = nn.Linear(
+            int(layer.self_attn.out_proj.weight.shape[0] * width_mult),
+            layer.self_attn.out_proj.weight.shape[1],
+            layer.self_attn.out_proj._weight_attr,
+            layer.self_attn.out_proj._bias_attr,
+        )

-    ofa_model._add_teacher = False
-    ofa_model, ofa_model.model = _recover_transformer_func(ofa_model), _recover_transformer_func(ofa_model.model)
-    if isinstance(ofa_model.model, paddle.DataParallel):
-        ori_num_heads = ofa_model.model._layers.base_model.encoder.layers[0].self_attn.num_heads
-    else:
-        ori_num_heads = ofa_model.model.base_model.encoder.layers[0].self_attn.num_heads
-    for width_mult in self.args.width_mult_list:
-        model_dir = os.path.join(self.args.output_dir, "width_mult_" + str(round(width_mult, 2)))
-        state_dict = paddle.load(os.path.join(model_dir, "model_state.pdparams"))
-        origin_model = self.model.__class__.from_pretrained(model_dir)
-        ofa_model.model.set_state_dict(state_dict)
-        best_config = utils.dynabert_config(ofa_model, width_mult)
-        best_config = check_dynabert_config(best_config, width_mult)
-        input_spec = generate_input_spec(self.model, self.train_dataset)
-        origin_model_new = ofa_model.export(
-            best_config,
-            input_shapes=[[1, 1]] * len(input_spec),
-            input_dtypes=["int64"] * len(input_spec),
-            origin_model=origin_model,
+        # Feed Forward
+        layer.linear1 = nn.Linear(
+            layer.linear1.weight.shape[0],
+            int(layer.linear1.weight.shape[1] * width_mult),
+            layer.linear1._weight_attr,
+            layer.linear1._bias_attr,
+        )
+        layer.linear2 = nn.Linear(
+            int(layer.linear2.weight.shape[0] * width_mult),
+            layer.linear2.weight.shape[1],
+            layer.linear2._weight_attr,
+            layer.linear2._bias_attr,
        )
-        for name, sublayer in origin_model_new.named_sublayers():
-            if isinstance(sublayer, paddle.nn.MultiHeadAttention):
-                sublayer.num_heads = int(width_mult * sublayer.num_heads)
-
-        pruned_infer_model_dir = os.path.join(model_dir, "pruned_model")
-        net = paddle.jit.to_static(origin_model_new, input_spec=input_spec)
-        paddle.jit.save(net, pruned_infer_model_dir)
-        # Recover num_heads of ofa_model.model
-        if isinstance(ofa_model.model, paddle.DataParallel):
-            for layer in ofa_model.model._layers.base_model.encoder.layers:
-                layer.self_attn.num_heads = ori_num_heads
+    return model
+
+
+def _load_parameters(dynabert_model, ori_state_dict):
+    dynabert_state_dict = dynabert_model.state_dict()
+    for key in ori_state_dict.keys():
+        # Removes '.fn' from ofa model parameters
+        dynabert_key = key.replace(".fn", "")
+        if dynabert_key not in dynabert_state_dict.keys():
+            logger.warning("Failed to export parameter %s" % key)
        else:
-            for layer in ofa_model.model.base_model.encoder.layers:
-                layer.self_attn.num_heads = ori_num_heads
-    logger.info("Pruned models have been exported.")
+            dynabert_shape = dynabert_state_dict[dynabert_key].shape
+            if len(dynabert_shape) == 2:
+                dynabert_state_dict[dynabert_key] = ori_state_dict[key][: dynabert_shape[0], : dynabert_shape[1]]
+            elif len(dynabert_shape) == 1:
+                dynabert_state_dict[dynabert_key] = ori_state_dict[key][: dynabert_shape[0]]
+            else:
+                raise ValueError("Please check input model. Length of shape should be 1 or 2 for any parameter.")
+    dynabert_model.set_state_dict(dynabert_state_dict)
+    return dynabert_model
+
+
+def _export_dynamic_dynabert_model(self, width_mult):
+    model_dir = os.path.join(self.args.output_dir, "width_mult_" + str(round(width_mult, 2)))
+    state_dict = paddle.load(os.path.join(model_dir, "model_state.pdparams"))
+    origin_model = self.model.__class__.from_pretrained(model_dir)
+    dynabert_model = _get_dynabert_model(origin_model, width_mult)
+    dynabert_model = _load_parameters(dynabert_model, state_dict)
+    return dynabert_model
+
+
+def _dynabert_export(self):
+    for width_mult in self.args.width_mult_list:
+        dynabert_model = _export_dynamic_dynabert_model(self, width_mult)
+        self.model = dynabert_model
+        if "qat" not in self.args.strategy:
+            input_spec = generate_input_spec(self.model, self.train_dataset)
+            pruned_infer_model_dir = os.path.join(self.args.output_dir, "width_mult_" + str(round(width_mult, 2)))
+            export_model(model=dynabert_model, input_spec=input_spec, path=pruned_infer_model_dir)
+            self.args.input_filename_prefix = "model"
+    logger.info("Pruned models have been exported.")


def _post_training_quantization_grid_search(self, model_dir):
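The replacement export path relies on DynaBERT having reordered attention heads and FFN neurons by importance during training, so the sub-network worth keeping is always the leading `width_mult` fraction of each tensor: `_get_dynabert_model` shrinks the `nn.Linear` layers, and `_load_parameters` fills them by slicing the full-width state dict. A toy sketch of that slicing, with illustrative shapes:

```python
import paddle

width_mult = 0.75
full_weight = paddle.randn([768, 768])  # full-width projection weight, [in, out]
full_bias = paddle.randn([768])

# Target shape of the shrunk nn.Linear created by _get_dynabert_model.
target_shape = [768, int(768 * width_mult)]  # -> [768, 576]

# 2-D parameters keep the leading rows/columns, 1-D parameters the leading
# entries -- the same slicing _load_parameters applies per state-dict key.
narrow_weight = full_weight[: target_shape[0], : target_shape[1]]
narrow_bias = full_bias[: target_shape[1]]
print(narrow_weight.shape, narrow_bias.shape)  # [768, 576] [576]
```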
@@ -649,10 +707,10 @@ def _quant_aware_training_dynamic(self, input_dir):
        "onnx_format": args.onnx_format,
    }

-    if not os.path.exists(args.output_dir):
-        os.makedirs(args.output_dir)
+    if not os.path.exists(input_dir):
+        os.makedirs(input_dir)

-    output_param_path = os.path.join(args.output_dir, "best_quant.pdparams")
+    output_param_path = os.path.join(input_dir, "best_quant.pdparams")

    train_dataloader = self.get_train_dataloader()
    eval_dataloader = self.get_eval_dataloader(self.eval_dataset)
@@ -667,9 +725,10 @@ def _quant_aware_training_dynamic(self, input_dir):
    self.create_optimizer_and_scheduler(num_training_steps=args.num_training_steps)

-    logger.info("FP32 model's evaluation starts.")
+    logger.info("Evaluating FP32 model before quantization aware training.")

    tic_eval = time.time()
+
    acc = evaluate(self, self.model, eval_dataloader)
    logger.info("eval done total: %s s" % (time.time() - tic_eval))
@@ -680,6 +739,7 @@ def _quant_aware_training_dynamic(self, input_dir):
    global_step = 0
    tic_train = time.time()
    best_acc, acc = 0.0, 0.0
+
    logger.info("Quant aware training starts.")
    # Train self.model
    for epoch in range(args.num_train_epochs):
@@ -701,7 +761,6 @@ def _quant_aware_training_dynamic(self, input_dir):
            for key in batch:
                if key in model_para_keys:
                    inputs[key] = batch[key]
-
            logits = self.model(**inputs)
            loss = self.criterion(logits, labels)
            loss.backward()
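The surrounding loop feeds the model only the batch entries named in `model_para_keys`. A standalone sketch of the same filtering; deriving the keys with `inspect` is an assumption about how the trainer builds that list:

```python
import inspect

def filter_batch(model, batch):
    # Keep only the batch entries the model's forward() actually accepts.
    model_para_keys = set(inspect.signature(model.forward).parameters)
    return {key: value for key, value in batch.items() if key in model_para_keys}

# Example: a labels column is dropped before calling the model.
class Dummy:
    def forward(self, input_ids, token_type_ids=None):
        pass

print(filter_batch(Dummy(), {"input_ids": [1], "labels": [0]}))  # {'input_ids': [1]}
```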
@@ -736,11 +795,12 @@ def _quant_aware_training_dynamic(self, input_dir):
    quanter.save_quantized_model(
        self.model, os.path.join(input_dir, args.output_filename_prefix), input_spec=input_spec
    )
-    if os.path.exists(output_param_path):
-        os.remove(output_param_path)

    self.model = _recover_auto_model_forward(self.model)
-    logger.info("Quant aware training ends and quantized models are saved.")
+    logger.info(
+        "Quant aware training ends and quantized models are saved to %s."
+        % os.path.join(input_dir, args.output_filename_prefix)
+    )


def _quant_embeddings(self, input_prefix):
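The new log line reports where the quantized inference model was written: under `input_dir` with the prefix `args.output_filename_prefix`. A sketch of loading that exported model for inference; the concrete path and prefix below are placeholders:

```python
import paddle

# Placeholder: <input_dir>/<args.output_filename_prefix> from the run above.
model_prefix = "compress_output/width_mult_0.75/quant_model"

# save_quantized_model writes <prefix>.pdmodel / <prefix>.pdiparams,
# which paddle.jit.load reads back as a static inference model.
model = paddle.jit.load(model_prefix)
model.eval()
```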