Skip to content

Commit 4fbfc01

Browse files
Restore quantization and distillation stategy before loading persistables. (#16959)
test=develop
1 parent 5a2d6d6 commit 4fbfc01

File tree

4 files changed

+11
-6
lines changed

4 files changed

+11
-6
lines changed

python/paddle/fluid/contrib/slim/core/compressor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,9 @@ def _load_checkpoint(self, context):
363363
strategies = pickle.load(
364364
strategy_file, encoding='bytes')
365365

366+
for strategy in strategies:
367+
strategy.restore_from_checkpoint(context)
368+
366369
if os.path.exists(model_path):
367370
exe = SlimGraphExecutor(context.place)
368371
with scope_guard(context.scope):

python/paddle/fluid/contrib/slim/core/strategy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,6 @@ def on_batch_end(self, context):
4646

4747
def on_compression_end(self, context):
4848
pass
49+
50+
def restore_from_checkpoint(self, context):
51+
pass

python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(self, distillers=None, start_epoch=0, end_epoch=0):
3838
super(DistillationStrategy, self).__init__(start_epoch, end_epoch)
3939
self.distillers = distillers
4040

41-
def on_compression_begin(self, context):
41+
def restore_from_checkpoint(self, context):
4242
# load from checkpoint
4343
if context.epoch_id > 0:
4444
if context.epoch_id > self.start_epoch and context.epoch_id < self.end_epoch:

python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def __init__(self,
8888
self.save_out_nodes = save_out_nodes
8989
self.save_in_nodes = save_in_nodes
9090

91-
def on_compression_begin(self, context):
91+
def restore_from_checkpoint(self, context):
9292
"""
9393
Restore graph when the compressoin task is inited from checkpoint.
9494
"""
@@ -143,10 +143,9 @@ def _modify_graph_for_quantization(self, context):
143143
train_ir_graph.graph).with_data_parallel(
144144
loss_name=context.optimize_graph.out_nodes['loss'],
145145
build_strategy=build_strategy)
146-
# for evaluation. And program compiled from ir graph must be with data parallel.
147-
context.eval_graph.compiled_graph = CompiledProgram(
148-
test_ir_graph.graph).with_data_parallel(
149-
build_strategy=build_strategy)
146+
147+
context.eval_graph.program = test_ir_graph.to_program()
148+
150149
# for saving inference model after training
151150
context.put('quantization_test_ir_graph_backup', test_ir_graph)
152151

0 commit comments

Comments
 (0)