
Commit 5d542a2

Add Embedding quantization (#4159)
* add quant emb
* support quant embeddings
* remove useless log
1 parent 6f5c287 commit 5d542a2
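
For orientation, a minimal usage sketch of what this commit enables. The driver script and the exact strategy string are assumptions inferred from the substring checks added in compress() below; they are not part of the commit:

    # Hypothetical driver; the "dynabert+ptq+embeddings" format is an assumption
    # inferred from the substring checks in compress().
    from paddlenlp.trainer import CompressionArguments

    compress_args = CompressionArguments(
        output_dir="compress_out",
        strategy="dynabert+ptq+embeddings",  # "embeddings" enables the new pass
    )
    # A Trainer built with these args (model and datasets omitted here) would run:
    #   trainer.compress()  # DynaBERT pruning -> PTQ -> embedding quantization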

File tree

1 file changed: +63 −11 lines changed

paddlenlp/trainer/trainer_compress.py

Lines changed: 63 additions & 11 deletions
@@ -70,36 +70,50 @@ def compress(self, custom_evaluate=None):
         _dynabert(self, self.model, args.output_dir)
         if "ptq" in args.strategy:
             self.args.input_filename_prefix = "pruned_model"
+            output_dir_list = []
             for width_mult in args.width_mult_list:
                 output_dir_width = os.path.join(args.output_dir, "width_mult_" + str(round(width_mult, 2)))
-                self.quant(output_dir_width, "ptq")
-    elif args.strategy == "ptq":
-        # Input model is an inference model
+                output_dir_list += self.quant(output_dir_width, "ptq")
+            if "embeddings" in args.strategy:
+                for output_dir in output_dir_list:
+                    self.quant(os.path.join(output_dir, args.output_filename_prefix), "embeddings")
+    elif "ptq" in args.strategy:
+        # When input model is an inference model
         if args.input_infer_model_path is not None:
             model_dir = os.path.dirname(args.input_infer_model_path)
             self.args.input_filename_prefix = os.path.basename(args.input_infer_model_path)
-            self.quant(model_dir, args.strategy)
+            output_dir_list = self.quant(model_dir, "ptq")
         # Input model is load from Trainer API in dygraph.
         else:
+            # When input model is a dygraph.
+            # exports model and then do 'ptq'
             # Prefix of `export_model` is 'model'
             self.args.input_filename_prefix = "model"
             input_spec = generate_input_spec(self.model, self.train_dataset)
             input_dir = args.output_dir
             export_model(model=self.model, input_spec=input_spec, path=input_dir)
-            self.quant(input_dir, args.strategy)
-    elif args.strategy == "qat":
+            output_dir_list = self.quant(input_dir, "ptq")
+        if "embeddings" in args.strategy:
+            for output_dir in output_dir_list:
+                self.quant(os.path.join(output_dir, args.output_filename_prefix), "embeddings")
+    elif "qat" in args.strategy:
         global_try_import_slim()
-        self.quant(args.output_dir, args.strategy)
+        self.quant(args.output_dir, "qat")
+        if "embeddings" in args.strategy:
+            self.quant(os.path.join(args.output_dir, args.output_filename_prefix), "embeddings")
 
 
 def quant(self, model_dir, strategy):
     """
-    Supports Post-Training Quantization now.
+    Supports Post-Training Quantization, Quantization Aware Training and
+    Embedding Quantization.
     """
     if strategy == "ptq":
-        _post_training_quantization_grid_search(self, model_dir)
+        return _post_training_quantization_grid_search(self, model_dir)
     elif strategy == "qat":
         _quant_aware_training_dynamic(self, model_dir)
+    elif strategy == "embeddings":
+        _quant_embeddings(self, model_dir)
 
 
 def generate_input_spec(model, dataset):
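
The dispatch above keys off substring membership, so one strategy string composes several passes, and the "embeddings" check only fires inside the ptq/qat branches (the enclosing dynabert branch is implied by the context lines). A small standalone illustration, with assumed example strategy values:

    # Which branch a few assumed strategy strings would take (illustrative only;
    # mirrors the substring checks in compress()).
    for strategy in ("dynabert+ptq+embeddings", "ptq+embeddings", "qat+embeddings"):
        if "dynabert" in strategy:
            branch = "dynabert, then ptq per pruned width, then embeddings"
        elif "ptq" in strategy:
            branch = "ptq, then embeddings on each quantized export"
        elif "qat" in strategy:
            branch = "qat, then embeddings on the exported model"
        print(f"{strategy} -> {branch}")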
@@ -138,7 +152,7 @@ def _dynabert(self, model, output_dir):
     ofa_model = _dynabert_training(
         self, ofa_model, model, teacher_model, train_dataloader, eval_dataloader, args.num_train_epochs
     )
-
+    self.reset_optimizer_and_scheduler()
     # Each width_mult best model would be exported.
     _dynabert_export(self, ofa_model)
 
@@ -540,6 +554,7 @@ def _post_training_quantization_grid_search(self, model_dir):
     exe = paddle.static.Executor(place)
 
     args.output_filename_prefix = "int8"
+    output_dir_list = []
 
     def _post_training_quantization(algo, batch_size, batch_nums):
         try:
@@ -587,11 +602,13 @@ def _batch_generator_func():
                 optimize_model=False,
             )
             post_training_quantization.quantize()
+            save_model_path = os.path.join(model_dir, algo + "_".join([str(batch_size), str(batch_nums)]))
             post_training_quantization.save_quantized_model(
-                save_model_path=os.path.join(model_dir, algo + "_".join([str(batch_size), str(batch_nums)])),
+                save_model_path=save_model_path,
                 model_filename=args.output_filename_prefix + ".pdmodel",
                 params_filename=args.output_filename_prefix + ".pdiparams",
             )
+            output_dir_list.append(save_model_path)
 
     logger.info("Post training quantization starts.")
     for algo in args.algo_list:
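
The save path composed above is what the new "embeddings" pass later consumes as a model prefix. A runnable sketch of the resulting on-disk layout; the width, algo, and batch values are examples, not defaults:

    import os

    # Example values; mirrors the path arithmetic in this diff.
    model_dir = "width_mult_0.75"  # one pruned-width export dir
    algo, batch_size, batch_nums = "avg", 4, 1

    # _post_training_quantization() saves each grid-search trial here ...
    save_model_path = os.path.join(model_dir, algo + "_".join([str(batch_size), str(batch_nums)]))
    # ... as int8.pdmodel / int8.pdiparams (args.output_filename_prefix is "int8"),
    # and compress() then passes this prefix to self.quant(..., "embeddings"):
    ptq_prefix = os.path.join(save_model_path, "int8")

    print(save_model_path)  # width_mult_0.75/avg4_1
    print(ptq_prefix)       # width_mult_0.75/avg4_1/int8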
@@ -601,6 +618,7 @@ def _batch_generator_func():
 
     paddle.disable_static()
     logger.info("Post training quantization ends and quantized models are saved.")
+    return output_dir_list
 
 
 def _quant_aware_training_dynamic(self, input_dir):
@@ -725,6 +743,35 @@ def _quant_aware_training_dynamic(self, input_dir):
     logger.info("Quant aware training ends and quantized models are saved.")
 
 
+def _quant_embeddings(self, input_prefix):
+    import paddleslim.quant as quant
+
+    self.args.output_filename_prefix = "quant_emb"
+
+    paddle.enable_static()
+    place = paddle.set_device(self.args.device)
+    exe = paddle.static.Executor(place)
+    main_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(input_prefix, exe)
+
+    config = {"quantize_op_types": ["lookup_table_v2"], "lookup_table_v2": {"quantize_type": "log"}}
+
+    quant_emb_program = quant.quant_embedding(main_program, place, config)
+
+    input_dir = os.path.dirname(input_prefix)
+
+    paddle.fluid.io.save_inference_model(
+        input_dir,
+        feed_target_names,
+        fetch_targets,
+        exe,
+        quant_emb_program,
+        model_filename=self.args.output_filename_prefix + ".pdmodel",
+        params_filename=self.args.output_filename_prefix + ".pdiparams",
+        export_for_deployment=True,
+        program_only=False,
+    )
+
+
 def auto_model_dynabert_forward(
     self,
     input_ids,
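
In the config above, `lookup_table_v2` is the static-graph operator behind embedding lookups, and PaddleSlim's `quant_embedding` rewrites its weights according to the given `quantize_type` (here "log"). A sketch of loading the resulting model for inference; the path prefix is an example location from the PTQ grid search, and the feed data is omitted:

    import paddle

    paddle.enable_static()
    exe = paddle.static.Executor(paddle.set_device("cpu"))

    # Load the embedding-quantized model written by _quant_embeddings()
    # (quant_emb.pdmodel / quant_emb.pdiparams under the PTQ output dir).
    program, feed_names, fetch_targets = paddle.static.load_inference_model(
        "width_mult_0.75/avg4_1/quant_emb", exe
    )
    # Run with a feed dict keyed by feed_names, e.g.:
    # results = exe.run(program, feed={...}, fetch_list=fetch_targets)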
@@ -865,5 +912,10 @@ def soft_cross_entropy(inp, target):
     return -1.0 * paddle.mean(paddle.sum(inp_likelihood * target_prob, axis=-1))
 
 
+def reset_optimizer_and_scheduler(self):
+    self.optimizer, self.lr_scheduler = None, None
+
+
 Trainer.compress = compress
 Trainer.quant = quant
+Trainer.reset_optimizer_and_scheduler = reset_optimizer_and_scheduler
