Skip to content

Commit 79a7b16

Browse files
zzjjayceci3
andauthored
Fix sparse quant (PaddlePaddle#1076)
* Remove 'mask' nodes in sparse model * Fixed sparsity of compressed model. * Fixed sparsity of compressed model. Co-authored-by: ceci3 <[email protected]>
1 parent 2350af8 commit 79a7b16

File tree

3 files changed

+43
-15
lines changed

3 files changed

+43
-15
lines changed

paddleslim/auto_compression/compressor.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from ..common import get_logger
3030
from ..common.patterns import get_patterns
3131
from ..analysis import TableLatencyPredictor
32-
from .create_compressed_program import build_distill_program, build_quant_program, build_prune_program
32+
from .create_compressed_program import build_distill_program, build_quant_program, build_prune_program, remove_unused_var_nodes
3333
from .strategy_config import ProgramInfo, merge_config
3434
from .auto_strategy import prepare_strategy, get_final_quant_config, create_strategy_config
3535

@@ -274,6 +274,15 @@ def _prepare_program(self, program, feed_target_names, fetch_targets,
274274
train_program_info, test_program_info, self._quant_config = build_quant_program(
275275
self._exe, self._places, config_dict, train_program_info,
276276
test_program_info)
277+
if self.train_config.sparse_model:
278+
from ..prune.unstructured_pruner import UnstructuredPruner
279+
self._pruner = UnstructuredPruner(
280+
train_program_info.program,
281+
mode='ratio',
282+
ratio=0.75,
283+
prune_params_type='conv1x1_only',
284+
place=self._places)
285+
self._pruner.set_static_masks()
277286

278287
self._exe.run(train_program_info.startup_program)
279288

@@ -402,7 +411,9 @@ def single_strategy_compress(self, strategy, config, strategy_idx):
402411
train_program_info, test_program_info = self._prepare_program(
403412
inference_program, feed_target_names, fetch_targets, patterns,
404413
default_distill_node_pair, strategy, config)
405-
414+
if 'unstructure' in self._strategy:
415+
test_program_info.program._program = remove_unused_var_nodes(
416+
test_program_info.program._program)
406417
test_program_info = self._start_train(train_program_info,
407418
test_program_info, strategy)
408419
self._save_model(test_program_info, strategy, strategy_idx)
@@ -462,6 +473,9 @@ def _start_train(self, train_program_info, test_program_info, strategy):
462473
"Not set eval function, so unable to test accuracy performance."
463474
)
464475

476+
if 'unstructure' in self._strategy or self.train_config.sparse_model:
477+
self._pruner.update_params()
478+
465479
return test_program_info
466480

467481
def _save_model(self, test_program_info, strategy, strategy_idx):

paddleslim/auto_compression/create_compressed_program.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525

2626
_logger = get_logger(__name__, level=logging.INFO)
2727
__all__ = [
28-
'build_distill_program', 'build_quant_program', 'build_prune_program'
28+
'build_distill_program', 'build_quant_program', 'build_prune_program',
29+
'remove_unused_var_nodes'
2930
]
3031

3132

@@ -425,3 +426,25 @@ def build_prune_program(executor,
425426
format(config['prune_algo']))
426427

427428
return pruner, train_program_info
429+
430+
431+
def remove_unused_var_nodes(program):
432+
'''
433+
This function is called before saving the sparse model to remove redundant nodes.
434+
Args:
435+
program(paddle.static.Program): The sparse model to be saved.
436+
Returns:
437+
program(paddle.static.Program): The sparse model.
438+
'''
439+
from paddle.fluid import core
440+
from paddle.fluid.framework import IrGraph
441+
graph = IrGraph(core.Graph(program.desc), for_test=True)
442+
removed_nodes = set()
443+
ops = graph.all_op_nodes()
444+
for op_node in ops:
445+
for input_node in op_node.inputs:
446+
if '_mask' in input_node.name():
447+
removed_nodes.add(op_node)
448+
graph.safe_remove_nodes(removed_nodes)
449+
program = graph.to_program()
450+
return program

paddleslim/auto_compression/strategy_config.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -103,18 +103,9 @@
103103

104104
### Train
105105
TrainConfig = namedtuple("Train", [
106-
"epochs",
107-
"learning_rate",
108-
"optimizer",
109-
"optim_args",
110-
"eval_iter",
111-
"logging_iter",
112-
"origin_metric",
113-
"target_metric",
114-
"use_fleet",
115-
"amp_config",
116-
"recompute_config",
117-
"sharding_config",
106+
"epochs", "learning_rate", "optimizer", "optim_args", "eval_iter",
107+
"logging_iter", "origin_metric", "target_metric", "use_fleet", "amp_config",
108+
"recompute_config", "sharding_config", "sparse_model"
118109
])
119110

120111
TrainConfig.__new__.__defaults__ = (None, ) * len(TrainConfig._fields)

0 commit comments

Comments
 (0)