Skip to content

Commit 4c1ec41

Browse files
authored
Merge pull request #16531 from wanghaoshuang/quan_ck
[slim] Fix checkpoint of quantization strategy.
2 parents e18ab78 + d41b623 commit 4c1ec41

File tree

2 files changed

+76
-32
lines changed

2 files changed

+76
-32
lines changed

python/paddle/fluid/contrib/slim/graph/graph_wrapper.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,10 @@ def __init__(self, program=None, in_nodes=[], out_nodes=[]):
204204
"""
205205
super(GraphWrapper, self).__init__()
206206
self.program = Program() if program is None else program
207+
self.persistables = {}
208+
for var in self.program.list_vars():
209+
if var.persistable:
210+
self.persistables[var.name] = var
207211
self.compiled_graph = None
208212
self.in_nodes = OrderedDict(in_nodes)
209213
self.out_nodes = OrderedDict(out_nodes)
@@ -467,7 +471,12 @@ def save_persistables(self, path, exe):
467471
path(str): The path to save the persistables.
468472
exe(framework.Executor): The executor used to save the persistables.
469473
"""
470-
io.save_persistables(exe.exe, path, main_program=self.program)
474+
# update persistables from program
475+
for var in self.program.list_vars():
476+
if var.persistable and var.name not in self.persistables:
477+
self.persistables[var.name] = var
478+
479+
io.save_vars(exe.exe, path, vars=self.persistables.values())
471480

472481
def load_persistables(self, path, exe):
473482
"""
@@ -481,7 +490,7 @@ def if_exist(var):
481490
return os.path.exists(os.path.join(path, var.name))
482491

483492
io.load_vars(
484-
exe.exe, path, main_program=self.program, predicate=if_exist)
493+
exe.exe, path, vars=self.persistables.values(), predicate=if_exist)
485494

486495
def update_param_shape(self, scope):
487496
"""

python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py

Lines changed: 65 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .... import core
2121
from ....compiler import CompiledProgram
2222
from ....compiler import BuildStrategy
23-
from ....framework import IrGraph
23+
from ....framework import IrGraph, Variable, Program
2424
from ..core.strategy import Strategy
2525
from .quantization_pass import *
2626

@@ -88,41 +88,76 @@ def __init__(self,
8888
self.save_out_nodes = save_out_nodes
8989
self.save_in_nodes = save_in_nodes
9090

91+
def on_compression_begin(self, context):
92+
"""
93+
Restore the graph when the compression task is initialized from a checkpoint.
94+
"""
95+
# It is initialized from a checkpoint and has missed the start epoch.
96+
if context.epoch_id != 0 and context.epoch_id > self.start_epoch:
97+
_logger.info("Restore quantization task from checkpoint")
98+
self._modify_graph_for_quantization(context)
99+
_logger.info("Finish restoring quantization task from checkpoint")
100+
101+
def _modify_graph_for_quantization(self, context):
102+
"""
103+
Insert fake_quantize_op and fake_dequantize_op before training and testing.
104+
"""
105+
train_ir_graph = IrGraph(
106+
core.Graph(context.optimize_graph.program.clone().desc),
107+
for_test=False)
108+
test_ir_graph = IrGraph(
109+
core.Graph(context.eval_graph.program.clone().desc), for_test=True)
110+
transform_pass = QuantizationTransformPass(
111+
scope=context.scope,
112+
place=context.place,
113+
weight_bits=self.weight_bits,
114+
activation_bits=self.activation_bits,
115+
activation_quantize_type=self.activation_quantize_type,
116+
weight_quantize_type=self.weight_quantize_type)
117+
transform_pass.apply(train_ir_graph)
118+
transform_pass.apply(test_ir_graph)
119+
# Put persistables created by transform_pass into context.optimize_graph.persistables
120+
# for saving checkpoint.
121+
program_persistables = set()
122+
for var in context.optimize_graph.program.list_vars():
123+
if var.persistable:
124+
program_persistables.add(var.name)
125+
126+
program = Program()
127+
for var_node in train_ir_graph.all_persistable_nodes():
128+
if var_node.name() not in program_persistables:
129+
var_desc = var_node.var()
130+
var = program.global_block().create_var(
131+
name=var_node.name(),
132+
shape=var_desc.shape(),
133+
dtype=var_desc.dtype(),
134+
type=var_desc.type(),
135+
lod_level=var_desc.lod_level())
136+
context.optimize_graph.persistables[var.name] = var
137+
138+
build_strategy = BuildStrategy()
139+
build_strategy.enable_inplace = False
140+
build_strategy.memory_optimize = False
141+
# for quantization training
142+
context.optimize_graph.compiled_graph = CompiledProgram(
143+
train_ir_graph.graph).with_data_parallel(
144+
loss_name=context.optimize_graph.out_nodes['loss'],
145+
build_strategy=build_strategy)
146+
# for evaluation. And program compiled from ir graph must be with data parallel.
147+
context.eval_graph.compiled_graph = CompiledProgram(
148+
test_ir_graph.graph).with_data_parallel(
149+
build_strategy=build_strategy)
150+
# for saving inference model after training
151+
context.put('quantization_test_ir_graph_backup', test_ir_graph)
152+
91153
def on_epoch_begin(self, context):
92154
"""
93155
Insert fake_quantize_op and fake_dequantize_op before training and testing.
94156
"""
95-
super(QuantizationStrategy, self).on_compression_begin(context)
157+
super(QuantizationStrategy, self).on_epoch_begin(context)
96158
if self.start_epoch == context.epoch_id:
97159
_logger.info('QuantizationStrategy::on_epoch_begin')
98-
train_ir_graph = IrGraph(
99-
core.Graph(context.optimize_graph.program.desc), for_test=False)
100-
test_ir_graph = IrGraph(
101-
core.Graph(context.eval_graph.program.desc), for_test=True)
102-
transform_pass = QuantizationTransformPass(
103-
scope=context.scope,
104-
place=context.place,
105-
weight_bits=self.weight_bits,
106-
activation_bits=self.activation_bits,
107-
activation_quantize_type=self.activation_quantize_type,
108-
weight_quantize_type=self.weight_quantize_type)
109-
transform_pass.apply(train_ir_graph)
110-
transform_pass.apply(test_ir_graph)
111-
112-
build_strategy = BuildStrategy()
113-
build_strategy.enable_inplace = False
114-
build_strategy.memory_optimize = False
115-
# for quantization training
116-
context.optimize_graph.compiled_graph = CompiledProgram(
117-
train_ir_graph.graph).with_data_parallel(
118-
loss_name=context.optimize_graph.out_nodes['loss'],
119-
build_strategy=build_strategy)
120-
# for evaluation. And program compiled from ir graph must be with data parallel.
121-
context.eval_graph.compiled_graph = CompiledProgram(
122-
test_ir_graph.graph).with_data_parallel(
123-
build_strategy=build_strategy)
124-
# for saving inference model after training
125-
context.put('quantization_test_ir_graph_backup', test_ir_graph)
160+
self._modify_graph_for_quantization(context)
126161
_logger.info('Finish QuantizationStrategy::on_epoch_begin')
127162

128163
def on_epoch_end(self, context):

0 commit comments

Comments
 (0)