
Commit ff3c0c9

Support dynamic export for dynabert (#3549)
* support dynamic export for dynabert
* save diff
* add error message for input check
* update input model name prefix
* solve conflicts and update compress strategy
1 parent 240f817 commit ff3c0c9

File tree

1 file changed: +110 -50 lines


paddlenlp/trainer/trainer_compress.py

Lines changed: 110 additions & 50 deletions
@@ -68,15 +68,27 @@ def compress(self, custom_evaluate=None):
                 self.custom_evaluate is not None
             ), "Custom model using DynaBERT strategy needs to pass in parameters `custom_evaluate`."
         _dynabert(self, self.model, args.output_dir)
-        if "ptq" in args.strategy:
-            self.args.input_filename_prefix = "pruned_model"
+        if "ptq" in args.strategy or "qat" in args.strategy:
             output_dir_list = []
             for width_mult in args.width_mult_list:
                 output_dir_width = os.path.join(args.output_dir, "width_mult_" + str(round(width_mult, 2)))
-                output_dir_list += self.quant(output_dir_width, "ptq")
-        if "embeddings" in args.strategy:
-            for output_dir in output_dir_list:
-                self.quant(os.path.join(output_dir, args.output_filename_prefix), "embeddings")
+                if "ptq" in args.strategy:
+                    output_dir_list += self.quant(output_dir_width, "ptq")
+                elif "qat" in args.strategy:
+                    self.quant(output_dir_width, "qat")
+                    output_dir_list.append(output_dir_width)
+        if "embeddings" in args.strategy:
+            if "ptq" not in args.strategy and "qat" not in args.strategy:
+                output_dir_list = []
+                for width_mult in args.width_mult_list:
+                    output_dir_width = os.path.join(
+                        args.output_dir, "width_mult_" + str(round(width_mult, 2)), args.input_filename_prefix
+                    )
+                    self.quant(output_dir_width, "embeddings")
+            else:
+                for output_dir in output_dir_list:
+                    self.quant(os.path.join(output_dir, args.output_filename_prefix), "embeddings")
+
     elif "ptq" in args.strategy:
         # When input model is an inference model
         if args.input_infer_model_path is not None:
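
Note (not part of the diff): the branch above only decides which quantization pass runs on each width_mult_* directory after DynaBERT pruning. Below is a minimal standalone sketch of that dispatch, for orientation only; the strategy string, width list, and filename prefixes are placeholder values, and the real call to self.quant() is reduced here to recording a (directory, mode) step.

import os


def plan_compression(strategy, width_mult_list, output_dir,
                     input_filename_prefix="model", output_filename_prefix="int8"):
    """Illustrative only: list the quantization steps the new branch schedules."""
    steps, output_dir_list = [], []
    if "ptq" in strategy or "qat" in strategy:
        for width_mult in width_mult_list:
            width_dir = os.path.join(output_dir, "width_mult_" + str(round(width_mult, 2)))
            if "ptq" in strategy:
                steps.append((width_dir, "ptq"))
                output_dir_list.append(width_dir)  # real code extends with self.quant()'s return value
            elif "qat" in strategy:
                steps.append((width_dir, "qat"))
                output_dir_list.append(width_dir)
    if "embeddings" in strategy:
        if "ptq" not in strategy and "qat" not in strategy:
            # dynabert + embeddings only: quantize the exported pruned models directly
            for width_mult in width_mult_list:
                width_dir = os.path.join(
                    output_dir, "width_mult_" + str(round(width_mult, 2)), input_filename_prefix
                )
                steps.append((width_dir, "embeddings"))
        else:
            # otherwise quantize embeddings on top of the ptq/qat outputs
            for quant_dir in output_dir_list:
                steps.append((os.path.join(quant_dir, output_filename_prefix), "embeddings"))
    return steps


# Example: plan_compression("dynabert+qat+embeddings", [0.75], "./compress")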
@@ -153,8 +165,9 @@ def _dynabert(self, model, output_dir):
         self, ofa_model, model, teacher_model, train_dataloader, eval_dataloader, args.num_train_epochs
     )
     self.reset_optimizer_and_scheduler()
+
     # Each width_mult best model would be exported.
-    _dynabert_export(self, ofa_model)
+    _dynabert_export(self)

     ofa_model, ofa_model.model = _recover_transformer_func(ofa_model, True), _recover_transformer_func(
         ofa_model.model, True
@@ -500,44 +513,89 @@ def _dynabert_training(self, ofa_model, model, teacher_model, train_dataloader,
     return ofa_model


-def _dynabert_export(self, ofa_model):
-    from paddleslim.nas.ofa import utils
+def _get_dynabert_model(model, width_mult):
+    for layer in model.base_model.encoder.layers:
+        # Multi-Head Attention
+        layer.self_attn.num_heads = int(layer.self_attn.num_heads * width_mult)
+        layer.self_attn.q_proj = nn.Linear(
+            layer.self_attn.q_proj.weight.shape[0],
+            int(layer.self_attn.q_proj.weight.shape[1] * width_mult),
+            layer.self_attn.q_proj._weight_attr,
+            layer.self_attn.q_proj._bias_attr,
+        )
+        layer.self_attn.k_proj = nn.Linear(
+            layer.self_attn.k_proj.weight.shape[0],
+            int(layer.self_attn.k_proj.weight.shape[1] * width_mult),
+            layer.self_attn.k_proj._weight_attr,
+            layer.self_attn.k_proj._bias_attr,
+        )
+        layer.self_attn.v_proj = nn.Linear(
+            layer.self_attn.v_proj.weight.shape[0],
+            int(layer.self_attn.v_proj.weight.shape[1] * width_mult),
+            layer.self_attn.v_proj._weight_attr,
+            layer.self_attn.v_proj._bias_attr,
+        )
+        layer.self_attn.out_proj = nn.Linear(
+            int(layer.self_attn.out_proj.weight.shape[0] * width_mult),
+            layer.self_attn.out_proj.weight.shape[1],
+            layer.self_attn.out_proj._weight_attr,
+            layer.self_attn.out_proj._bias_attr,
+        )

-    ofa_model._add_teacher = False
-    ofa_model, ofa_model.model = _recover_transformer_func(ofa_model), _recover_transformer_func(ofa_model.model)
-    if isinstance(ofa_model.model, paddle.DataParallel):
-        ori_num_heads = ofa_model.model._layers.base_model.encoder.layers[0].self_attn.num_heads
-    else:
-        ori_num_heads = ofa_model.model.base_model.encoder.layers[0].self_attn.num_heads
-    for width_mult in self.args.width_mult_list:
-        model_dir = os.path.join(self.args.output_dir, "width_mult_" + str(round(width_mult, 2)))
-        state_dict = paddle.load(os.path.join(model_dir, "model_state.pdparams"))
-        origin_model = self.model.__class__.from_pretrained(model_dir)
-        ofa_model.model.set_state_dict(state_dict)
-        best_config = utils.dynabert_config(ofa_model, width_mult)
-        best_config = check_dynabert_config(best_config, width_mult)
-        input_spec = generate_input_spec(self.model, self.train_dataset)
-        origin_model_new = ofa_model.export(
-            best_config,
-            input_shapes=[[1, 1]] * len(input_spec),
-            input_dtypes=["int64"] * len(input_spec),
-            origin_model=origin_model,
+        # Feed Forward
+        layer.linear1 = nn.Linear(
+            layer.linear1.weight.shape[0],
+            int(layer.linear1.weight.shape[1] * width_mult),
+            layer.linear1._weight_attr,
+            layer.linear1._bias_attr,
+        )
+        layer.linear2 = nn.Linear(
+            int(layer.linear2.weight.shape[0] * width_mult),
+            layer.linear2.weight.shape[1],
+            layer.linear2._weight_attr,
+            layer.linear2._bias_attr,
         )
-        for name, sublayer in origin_model_new.named_sublayers():
-            if isinstance(sublayer, paddle.nn.MultiHeadAttention):
-                sublayer.num_heads = int(width_mult * sublayer.num_heads)
-
-        pruned_infer_model_dir = os.path.join(model_dir, "pruned_model")
-        net = paddle.jit.to_static(origin_model_new, input_spec=input_spec)
-        paddle.jit.save(net, pruned_infer_model_dir)
-        # Recover num_heads of ofa_model.model
-        if isinstance(ofa_model.model, paddle.DataParallel):
-            for layer in ofa_model.model._layers.base_model.encoder.layers:
-                layer.self_attn.num_heads = ori_num_heads
+    return model
+
+
+def _load_parameters(dynabert_model, ori_state_dict):
+    dynabert_state_dict = dynabert_model.state_dict()
+    for key in ori_state_dict.keys():
+        # Removes '.fn' from ofa model parameters
+        dynabert_key = key.replace(".fn", "")
+        if dynabert_key not in dynabert_state_dict.keys():
+            logger.warning("Failed to export parameter %s" % key)
         else:
-            for layer in ofa_model.model.base_model.encoder.layers:
-                layer.self_attn.num_heads = ori_num_heads
-    logger.info("Pruned models have been exported.")
+            dynabert_shape = dynabert_state_dict[dynabert_key].shape
+            if len(dynabert_shape) == 2:
+                dynabert_state_dict[dynabert_key] = ori_state_dict[key][: dynabert_shape[0], : dynabert_shape[1]]
+            elif len(dynabert_shape) == 1:
+                dynabert_state_dict[dynabert_key] = ori_state_dict[key][: dynabert_shape[0]]
+            else:
+                raise ValueError("Please check input model. Length of shape should be 1 or 2 for any parameter.")
+    dynabert_model.set_state_dict(dynabert_state_dict)
+    return dynabert_model
+
+
+def _export_dynamic_dynabert_model(self, width_mult):
+    model_dir = os.path.join(self.args.output_dir, "width_mult_" + str(round(width_mult, 2)))
+    state_dict = paddle.load(os.path.join(model_dir, "model_state.pdparams"))
+    origin_model = self.model.__class__.from_pretrained(model_dir)
+    dynabert_model = _get_dynabert_model(origin_model, width_mult)
+    dynabert_model = _load_parameters(dynabert_model, state_dict)
+    return dynabert_model
+
+
+def _dynabert_export(self):
+    for width_mult in self.args.width_mult_list:
+        dynabert_model = _export_dynamic_dynabert_model(self, width_mult)
+        self.model = dynabert_model
+        if "qat" not in self.args.strategy:
+            input_spec = generate_input_spec(self.model, self.train_dataset)
+            pruned_infer_model_dir = os.path.join(self.args.output_dir, "width_mult_" + str(round(width_mult, 2)))
+            export_model(model=dynabert_model, input_spec=input_spec, path=pruned_infer_model_dir)
+            self.args.input_filename_prefix = "model"
+    logger.info("Pruned models have been exported.")


 def _post_training_quantization_grid_search(self, model_dir):
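
Note (not part of the diff): the new export path first rebuilds a narrower copy of the network (_get_dynabert_model) and then crops every full-width tensor in the saved state dict down to the pruned shapes (_load_parameters). A minimal, self-contained sketch of that cropping rule, using an arbitrary toy Linear layer instead of the real encoder:

import paddle

width_mult = 0.75
full = paddle.nn.Linear(16, 16)                        # stands in for a full-width projection
pruned = paddle.nn.Linear(16, int(16 * width_mult))    # same input size, narrower output

pruned_state = pruned.state_dict()
for key, value in full.state_dict().items():
    target_shape = pruned_state[key].shape
    if len(target_shape) == 2:      # weight matrices: crop both axes to the pruned shape
        pruned_state[key] = value[: target_shape[0], : target_shape[1]]
    elif len(target_shape) == 1:    # biases: crop the single axis
        pruned_state[key] = value[: target_shape[0]]
pruned.set_state_dict(pruned_state)

The same per-parameter pattern applies to the q/k/v/out projections and the feed-forward linears rebuilt in the hunk above.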
@@ -649,10 +707,10 @@ def _quant_aware_training_dynamic(self, input_dir):
         "onnx_format": args.onnx_format,
     }

-    if not os.path.exists(args.output_dir):
-        os.makedirs(args.output_dir)
+    if not os.path.exists(input_dir):
+        os.makedirs(input_dir)

-    output_param_path = os.path.join(args.output_dir, "best_quant.pdparams")
+    output_param_path = os.path.join(input_dir, "best_quant.pdparams")

     train_dataloader = self.get_train_dataloader()
     eval_dataloader = self.get_eval_dataloader(self.eval_dataset)
@@ -667,9 +725,10 @@ def _quant_aware_training_dynamic(self, input_dir):

     self.create_optimizer_and_scheduler(num_training_steps=args.num_training_steps)

-    logger.info("FP32 model's evaluation starts.")
+    logger.info("Evaluating FP32 model before quantization aware training.")

     tic_eval = time.time()
+
     acc = evaluate(self, self.model, eval_dataloader)
     logger.info("eval done total: %s s" % (time.time() - tic_eval))

@@ -680,6 +739,7 @@ def _quant_aware_training_dynamic(self, input_dir):
     global_step = 0
     tic_train = time.time()
     best_acc, acc = 0.0, 0.0
+
     logger.info("Quant aware training starts.")
     # Train self.model
     for epoch in range(args.num_train_epochs):
@@ -701,7 +761,6 @@ def _quant_aware_training_dynamic(self, input_dir):
             for key in batch:
                 if key in model_para_keys:
                     inputs[key] = batch[key]
-
             logits = self.model(**inputs)
             loss = self.criterion(logits, labels)
             loss.backward()
@@ -736,11 +795,12 @@ def _quant_aware_training_dynamic(self, input_dir):
     quanter.save_quantized_model(
         self.model, os.path.join(input_dir, args.output_filename_prefix), input_spec=input_spec
     )
-    if os.path.exists(output_param_path):
-        os.remove(output_param_path)

     self.model = _recover_auto_model_forward(self.model)
-    logger.info("Quant aware training ends and quantized models are saved.")
+    logger.info(
+        "Quant aware training ends and quantized models are saved to %s."
+        % os.path.join(input_dir, args.output_filename_prefix)
+    )


 def _quant_embeddings(self, input_prefix):
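
Note (not part of the diff): with this change the QAT artifacts land next to the pruned model (input_dir, i.e. the width_mult_* directory) instead of args.output_dir, and the log line now reports that path. A hedged sketch of loading the resulting static graph for inference, assuming save_quantized_model wrote a <prefix>.pdmodel / <prefix>.pdiparams pair; the directory and prefix below are placeholders, not the trainer's actual arguments.

import os

import paddle.inference as paddle_infer

input_dir = "./compress/width_mult_0.75"   # placeholder for one width_mult_* directory
prefix = "int8"                            # placeholder for args.output_filename_prefix

config = paddle_infer.Config(
    os.path.join(input_dir, prefix + ".pdmodel"),
    os.path.join(input_dir, prefix + ".pdiparams"),
)
predictor = paddle_infer.create_predictor(config)
print(predictor.get_input_names())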
