|
29 | 29 | from ..common import get_logger
|
30 | 30 | from ..common.patterns import get_patterns
|
31 | 31 | from ..analysis import TableLatencyPredictor
|
32 |
| -from .create_compressed_program import build_distill_program, build_quant_program, build_prune_program |
| 32 | +from .create_compressed_program import build_distill_program, build_quant_program, build_prune_program, remove_unused_var_nodes |
33 | 33 | from .strategy_config import ProgramInfo, merge_config
|
34 | 34 | from .auto_strategy import prepare_strategy, get_final_quant_config, create_strategy_config
|
35 | 35 |
|
@@ -274,6 +274,15 @@ def _prepare_program(self, program, feed_target_names, fetch_targets,
|
274 | 274 | train_program_info, test_program_info, self._quant_config = build_quant_program(
|
275 | 275 | self._exe, self._places, config_dict, train_program_info,
|
276 | 276 | test_program_info)
|
| 277 | + if self.train_config.sparse_model: |
| 278 | + from ..prune.unstructured_pruner import UnstructuredPruner |
| 279 | + self._pruner = UnstructuredPruner( |
| 280 | + train_program_info.program, |
| 281 | + mode='ratio', |
| 282 | + ratio=0.75, |
| 283 | + prune_params_type='conv1x1_only', |
| 284 | + place=self._places) |
| 285 | + self._pruner.set_static_masks() |
277 | 286 |
|
278 | 287 | self._exe.run(train_program_info.startup_program)
|
279 | 288 |
|
@@ -402,7 +411,9 @@ def single_strategy_compress(self, strategy, config, strategy_idx):
|
402 | 411 | train_program_info, test_program_info = self._prepare_program(
|
403 | 412 | inference_program, feed_target_names, fetch_targets, patterns,
|
404 | 413 | default_distill_node_pair, strategy, config)
|
405 |
| - |
| 414 | + if 'unstructure' in self._strategy: |
| 415 | + test_program_info.program._program = remove_unused_var_nodes( |
| 416 | + test_program_info.program._program) |
406 | 417 | test_program_info = self._start_train(train_program_info,
|
407 | 418 | test_program_info, strategy)
|
408 | 419 | self._save_model(test_program_info, strategy, strategy_idx)
|
@@ -462,6 +473,9 @@ def _start_train(self, train_program_info, test_program_info, strategy):
|
462 | 473 | "Not set eval function, so unable to test accuracy performance."
|
463 | 474 | )
|
464 | 475 |
|
| 476 | + if 'unstructure' in self._strategy or self.train_config.sparse_model: |
| 477 | + self._pruner.update_params() |
| 478 | + |
465 | 479 | return test_program_info
|
466 | 480 |
|
467 | 481 | def _save_model(self, test_program_info, strategy, strategy_idx):
|
|
0 commit comments