|
24 | 24 | import paddle.distributed.fleet as fleet
|
25 | 25 | if platform.system().lower() == 'linux':
|
26 | 26 | from ..quant import quant_post_hpo
|
27 |
| -from ..quant.quanter import convert |
| 27 | +from ..quant.quanter import convert, quant_post |
28 | 28 | from ..common.recover_program import recover_inference_program
|
29 | 29 | from ..common import get_logger
|
30 | 30 | from ..common.patterns import get_patterns
|
31 | 31 | from ..analysis import TableLatencyPredictor
|
32 | 32 | from .create_compressed_program import build_distill_program, build_quant_program, build_prune_program, remove_unused_var_nodes
|
33 | 33 | from .strategy_config import ProgramInfo, merge_config
|
34 |
| -from .auto_strategy import prepare_strategy, get_final_quant_config, create_strategy_config |
| 34 | +from .auto_strategy import prepare_strategy, get_final_quant_config, create_strategy_config, create_train_config |
35 | 35 |
|
36 | 36 | _logger = get_logger(__name__, level=logging.INFO)
|
37 | 37 |
|
@@ -143,6 +143,11 @@ def __init__(self,
|
143 | 143 | self._strategy, self._config = self._prepare_strategy(
|
144 | 144 | self.strategy_config)
|
145 | 145 |
|
| 146 | + # If train_config is None, set default train_config |
| 147 | + if self.train_config is None: |
| 148 | + self.train_config = create_train_config(self.strategy_config, |
| 149 | + self.model_type) |
| 150 | + |
146 | 151 | def _prepare_envs(self):
|
147 | 152 | devices = paddle.device.get_device().split(':')[0]
|
148 | 153 | places = paddle.device._convert_to_place(devices)
|
@@ -248,8 +253,9 @@ def _prepare_program(self, program, feed_target_names, fetch_targets,
|
248 | 253 | feed_target_names, fetch_targets)
|
249 | 254 |
|
250 | 255 | config_dict = dict(config._asdict())
|
251 |
| - if config_dict["prune_strategy"] == "gmp" and config_dict[ |
252 |
| - 'gmp_config'] is None: |
| 256 | + if "prune_strategy" in config_dict and config_dict[ |
| 257 | + "prune_strategy"] == "gmp" and config_dict[ |
| 258 | + 'gmp_config'] is None: |
253 | 259 | _logger.info(
|
254 | 260 | "Calculating the iterations per epoch……(It will take some time)")
|
255 | 261 | # NOTE:XXX: This way of calculating the iters needs to be improved.
|
@@ -351,20 +357,57 @@ def compress(self):
|
351 | 357 | ).lower() == 'linux':
|
352 | 358 | ptq_loss = quant_post_hpo.g_min_emd_loss
|
353 | 359 |
|
354 |
| - final_quant_config = get_final_quant_config(ptq_loss) |
| 360 | + final_quant_config = get_final_quant_config( |
| 361 | + ptq_loss, mode='DistilQuant') |
355 | 362 | quant_strategy, quant_config = self._prepare_strategy(
|
356 | 363 | final_quant_config)
|
357 | 364 | self.single_strategy_compress(quant_strategy[0], quant_config[0],
|
358 | 365 | strategy_idx)
|
359 |
| - old_model_path = os.path.join( |
| 366 | + tmp_model_path = os.path.join( |
360 | 367 | self.save_dir, 'strategy_{}'.format(str(strategy_idx + 1)))
|
361 | 368 | final_model_path = os.path.join(self.final_dir)
|
362 |
| - shutil.move(old_model_path, final_model_path) |
| 369 | + if not os.path.exists(final_model_path): |
| 370 | + os.makedirs(final_model_path) |
| 371 | + tmp_model_file = os.path.join(tmp_model_path, 'model.pdmodel') |
| 372 | + tmp_params_file = os.path.join(tmp_model_path, 'model.pdiparams') |
| 373 | + final_model_file = os.path.join(final_model_path, 'model.pdmodel') |
| 374 | + final_params_file = os.path.join(final_model_path, 'model.pdiparams') |
| 375 | + shutil.move(tmp_model_file, final_model_file) |
| 376 | + shutil.move(tmp_params_file, final_params_file) |
| 377 | + _logger.info( |
| 378 | + "==> Finished the ACT process and the final model is saved in:{}". |
| 379 | + format(final_model_path)) |
363 | 380 | os._exit(0)
|
364 | 381 |
|
365 | 382 | def single_strategy_compress(self, strategy, config, strategy_idx):
|
366 |
| - ### start compress, including train/eval model |
367 |
| - if strategy == 'ptq_hpo': |
| 383 | + # start compress, including train/eval model |
| 384 | + # TODO: add the emd loss of evaluation model. |
| 385 | + if strategy == 'quant_post': |
| 386 | + quant_post( |
| 387 | + self._exe, |
| 388 | + model_dir=self.model_dir, |
| 389 | + quantize_model_path=os.path.join( |
| 390 | + self.save_dir, 'strategy_{}'.format(str(strategy_idx + 1))), |
| 391 | + data_loader=self.train_dataloader, |
| 392 | + model_filename=self.model_filename, |
| 393 | + params_filename=self.params_filename, |
| 394 | + save_model_filename=self.model_filename, |
| 395 | + save_params_filename=self.params_filename, |
| 396 | + batch_size=1, |
| 397 | + batch_nums=config.batch_num, |
| 398 | + algo=config.ptq_algo, |
| 399 | + round_type='round', |
| 400 | + bias_correct=config.bias_correct, |
| 401 | + hist_percent=config.hist_percent, |
| 402 | + quantizable_op_type=config.quantize_op_types, |
| 403 | + is_full_quantize=config.is_full_quantize, |
| 404 | + weight_bits=config.weight_bits, |
| 405 | + activation_bits=config.activation_bits, |
| 406 | + activation_quantize_type='range_abs_max', |
| 407 | + weight_quantize_type=config.weight_quantize_type, |
| 408 | + onnx_format=False) |
| 409 | + |
| 410 | + elif strategy == 'ptq_hpo': |
368 | 411 | if platform.system().lower() != 'linux':
|
369 | 412 | raise NotImplementedError(
|
370 | 413 | "post-quant-hpo is not support in system other than linux")
|
@@ -503,11 +546,12 @@ def _save_model(self, test_program_info, strategy, strategy_idx):
|
503 | 546 | test_program_info.program,
|
504 | 547 | paddle.static.CompiledProgram) else test_program_info.program
|
505 | 548 |
|
506 |
| - paddle.static.load(test_program, |
507 |
| - os.path.join(self.save_dir, 'best_model')) |
508 |
| - os.remove(os.path.join(self.save_dir, 'best_model.pdmodel')) |
509 |
| - os.remove(os.path.join(self.save_dir, 'best_model.pdopt')) |
510 |
| - os.remove(os.path.join(self.save_dir, 'best_model.pdparams')) |
| 549 | + if os.path.exists(os.path.join(self.save_dir, 'best_model.pdparams')): |
| 550 | + paddle.static.load(test_program, |
| 551 | + os.path.join(self.save_dir, 'best_model')) |
| 552 | + os.remove(os.path.join(self.save_dir, 'best_model.pdmodel')) |
| 553 | + os.remove(os.path.join(self.save_dir, 'best_model.pdopt')) |
| 554 | + os.remove(os.path.join(self.save_dir, 'best_model.pdparams')) |
511 | 555 |
|
512 | 556 | if 'qat' in strategy:
|
513 | 557 | float_program, int8_program = convert(test_program_info.program._program, self._places, self._quant_config, \
|
|
0 commit comments