Skip to content

Commit 6db7c2a

Browse files
committed
Fix checkpoint of quantization.
1 parent e41d581 commit 6db7c2a

File tree

2 files changed

+75
-31
lines changed

2 files changed

+75
-31
lines changed

python/paddle/fluid/contrib/slim/graph/graph_wrapper.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,10 @@ def __init__(self, program=None, in_nodes=[], out_nodes=[]):
204204
"""
205205
super(GraphWrapper, self).__init__()
206206
self.program = Program() if program is None else program
207+
self.persistables = {}
208+
for var in self.program.list_vars():
209+
if var.persistable:
210+
self.persistables[var.name] = var
207211
self.compiled_graph = None
208212
self.in_nodes = OrderedDict(in_nodes)
209213
self.out_nodes = OrderedDict(out_nodes)
@@ -467,7 +471,12 @@ def save_persistables(self, path, exe):
467471
path(str): The path to save the persistables.
468472
exe(framework.Executor): The executor used to save the persistables.
469473
"""
470-
io.save_persistables(exe.exe, path, main_program=self.program)
474+
# update persistables from program
475+
for var in self.program.list_vars():
476+
if var.persistable and var.name not in self.persistables:
477+
self.persistables[var.name] = var
478+
479+
io.save_vars(exe.exe, path, vars=self.persistables.values())
471480

472481
def load_persistables(self, path, exe):
473482
"""
@@ -481,7 +490,7 @@ def if_exist(var):
481490
return os.path.exists(os.path.join(path, var.name))
482491

483492
io.load_vars(
484-
exe.exe, path, main_program=self.program, predicate=if_exist)
493+
exe.exe, path, vars=self.persistables.values(), predicate=if_exist)
485494

486495
def update_param_shape(self, scope):
487496
"""

python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py

Lines changed: 64 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .... import core
2121
from ....compiler import CompiledProgram
2222
from ....compiler import BuildStrategy
23-
from ....framework import IrGraph
23+
from ....framework import IrGraph, Variable, Program
2424
from ..core.strategy import Strategy
2525
from .quantization_pass import *
2626

@@ -84,40 +84,75 @@ def __init__(self,
8484
self.save_out_nodes = save_out_nodes
8585
self.save_in_nodes = save_in_nodes
8686

87+
def on_compression_begin(self, context):
88+
"""
89+
Restore graph when the compression task is initialized from a checkpoint.
90+
"""
91+
# It is inited from checkpoint and has missed start epoch.
92+
if context.epoch_id != 0 and context.epoch_id > self.start_epoch:
93+
_logger.info("Restore quantization task from checkpoint")
94+
self._modify_graph_for_quantization(context)
95+
_logger.info("Finish restoring quantization task from checkpoint")
96+
97+
def _modify_graph_for_quantization(self, context):
98+
"""
99+
Insert fake_quantize_op and fake_dequantize_op before training and testing.
100+
"""
101+
train_ir_graph = IrGraph(
102+
core.Graph(context.optimize_graph.program.clone().desc),
103+
for_test=False)
104+
test_ir_graph = IrGraph(
105+
core.Graph(context.eval_graph.program.clone().desc), for_test=True)
106+
transform_pass = QuantizationTransformPass(
107+
scope=context.scope,
108+
place=context.place,
109+
weight_bits=self.weight_bits,
110+
activation_bits=self.activation_bits,
111+
activation_quantize_type=self.activation_quantize_type)
112+
transform_pass.apply(train_ir_graph)
113+
transform_pass.apply(test_ir_graph)
114+
# Put persistables created by transform_pass into context.optimize_graph.persistables
115+
# for saving checkpoint.
116+
program_persistables = set()
117+
for var in context.optimize_graph.program.list_vars():
118+
if var.persistable:
119+
program_persistables.add(var.name)
120+
121+
program = Program()
122+
for var_node in train_ir_graph.all_persistable_nodes():
123+
if var_node.name() not in program_persistables:
124+
var_desc = var_node.var()
125+
var = program.global_block().create_var(
126+
name=var_node.name(),
127+
shape=var_desc.shape(),
128+
dtype=var_desc.dtype(),
129+
type=var_desc.type(),
130+
lod_level=var_desc.lod_level())
131+
context.optimize_graph.persistables[var.name] = var
132+
133+
build_strategy = BuildStrategy()
134+
build_strategy.enable_inplace = False
135+
build_strategy.memory_optimize = False
136+
# for quantization training
137+
context.optimize_graph.compiled_graph = CompiledProgram(
138+
train_ir_graph.graph).with_data_parallel(
139+
loss_name=context.optimize_graph.out_nodes['loss'],
140+
build_strategy=build_strategy)
141+
# for evaluation. And program compiled from ir graph must be with data parallel.
142+
context.eval_graph.compiled_graph = CompiledProgram(
143+
test_ir_graph.graph).with_data_parallel(
144+
build_strategy=build_strategy)
145+
# for saving inference model after training
146+
context.put('quantization_test_ir_graph_backup', test_ir_graph)
147+
87148
def on_epoch_begin(self, context):
88149
"""
89150
Insert fake_quantize_op and fake_dequantize_op before training and testing.
90151
"""
91-
super(QuantizationStrategy, self).on_compression_begin(context)
152+
super(QuantizationStrategy, self).on_epoch_begin(context)
92153
if self.start_epoch == context.epoch_id:
93154
_logger.info('QuantizationStrategy::on_epoch_begin')
94-
train_ir_graph = IrGraph(
95-
core.Graph(context.optimize_graph.program.desc), for_test=False)
96-
test_ir_graph = IrGraph(
97-
core.Graph(context.eval_graph.program.desc), for_test=True)
98-
transform_pass = QuantizationTransformPass(
99-
scope=context.scope,
100-
place=context.place,
101-
weight_bits=self.weight_bits,
102-
activation_bits=self.activation_bits,
103-
activation_quantize_type=self.activation_quantize_type)
104-
transform_pass.apply(train_ir_graph)
105-
transform_pass.apply(test_ir_graph)
106-
107-
build_strategy = BuildStrategy()
108-
build_strategy.enable_inplace = False
109-
build_strategy.memory_optimize = False
110-
# for quantization training
111-
context.optimize_graph.compiled_graph = CompiledProgram(
112-
train_ir_graph.graph).with_data_parallel(
113-
loss_name=context.optimize_graph.out_nodes['loss'],
114-
build_strategy=build_strategy)
115-
# for evaluation. And program compiled from ir graph must be with data parallel.
116-
context.eval_graph.compiled_graph = CompiledProgram(
117-
test_ir_graph.graph).with_data_parallel(
118-
build_strategy=build_strategy)
119-
# for saving inference model after training
120-
context.put('quantization_test_ir_graph_backup', test_ir_graph)
155+
self._modify_graph_for_quantization(context)
121156
_logger.info('Finish QuantizationStrategy::on_epoch_begin')
122157

123158
def on_epoch_end(self, context):

0 commit comments

Comments
 (0)