Commit d68a02a

Merge pull request #16456 from wzzju/fix_quan_hang

Fix quantization hang bugs.

2 parents: f0070d9 + 27d0520

File tree: 2 files changed, +159 −82 lines

2 files changed

+159
-82
lines changed

python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
88 additions, 68 deletions
@@ -14,15 +14,10 @@
 
 import collections
 import numpy as np
-import six
 from ..... import compat as cpt
 from .... import core
-from .... import Executor
 from ....framework import IrGraph
 from ....framework import IrNode
-from ....framework import Program
-from ....initializer import Constant
-from ....initializer import NumpyArrayInitializer
 from .... import unique_name
 
 __all__ = [
@@ -107,7 +102,6 @@ def __init__(self,
         self._window_size = window_size
         self._moving_rate = moving_rate
 
-        self._need_initialized = collections.OrderedDict()
         self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
         self._conv_ops = ['conv2d', 'depthwise_conv2d']
         self._quantizable_grad_ops = [
@@ -127,14 +121,17 @@ def apply(self, graph):
         """
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
-        self._need_initialized.clear()
+        #sequential_execution = core.get_pass('sequential_execution_pass')
+        #sequential_execution.apply(graph.graph)
         self._is_test = graph.is_test()
         # marked the variable which has been dequantized.
         dequantized_vars = collections.OrderedDict()
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
 
         def _transform_forward(graph, op):
             for var_node in op.inputs:
+                if var_node.name() not in op.input_arg_names():
+                    continue
                 if var_node.name() in dequantized_vars:
                     dequant_var_node = dequantized_vars[var_node.name()]
                 else:
@@ -168,6 +165,8 @@ def _transform_forward(graph, op):
         def _transform_backward(graph, op):
             no_dequanted_input_vars = True
             for var_node in op.inputs:
+                if var_node.name() not in op.input_arg_names():
+                    continue
                 if var_node.name() in dequantized_vars:
                     dequant_var_node = dequantized_vars[var_node.name()]
                     graph.update_input_link(var_node, dequant_var_node, op)
@@ -188,25 +187,7 @@ def _transform_backward(graph, op):
         for op in ops:
             if op.name() in self._quantizable_grad_ops:
                 _transform_backward(graph, op)
-
-        if len(self._need_initialized) > 0:
-            assert self._scope is not None, \
-            'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
-            assert self._place is not None, \
-            'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
-            init_program = Program()
-            for var_desc, initializer in six.iteritems(self._need_initialized):
-                var = init_program.global_block().create_var(
-                    name=var_desc.name(),
-                    shape=var_desc.shape(),
-                    dtype=var_desc.dtype(),
-                    type=var_desc.type(),
-                    lod_level=var_desc.lod_level(),
-                    persistable=var_desc.persistable())
-                initializer(var, init_program.global_block())
-            exe = Executor(self._place)
-            exe.run(program=init_program, scope=self._scope)
-
+        graph.resolve_hazard()
         return graph
 
     def _create_global_step(self, graph):
@@ -222,8 +203,9 @@ def _create_global_step(self, graph):
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 shape=[1],
                 var_dtype=core.VarDesc.VarType.INT64)
-            self._need_initialized[global_step_in.var()] = \
-                Constant(value=0, force_cpu=True)
+            self._init_var_node(
+                global_step_in, np.zeros(
+                    [1], dtype='int64'))
             global_step_out = graph.create_var_node_from_desc(
                 global_step_in.var())
             # The attribute of `op_role` is needed by ParallelExecutor.
@@ -300,7 +282,9 @@ def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
             var_dtype=var_node.dtype())
-        self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
+        data_type = 'float64' if var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))
 
         scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
         inputs = {'X': var_node, 'InScale': scale_in_node}
@@ -313,7 +297,11 @@ def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 shape=[self._window_size],
                 var_dtype=var_node.dtype())
-            self._need_initialized[scales_node.var()] = Constant(value=0)
+            data_type = 'float64' if var_node.dtype(
+            ) == core.VarDesc.VarType.FP64 else 'float32'
+            self._init_var_node(
+                scales_node, np.zeros(
+                    [self._window_size], dtype=data_type))
             inputs['Iter'] = self._global_step
             outputs['OutScales'] = scales_node
         attrs = {
@@ -353,7 +341,9 @@ def _insert_quant_moving_average_abs_max_op(self, graph, var_node,
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
             var_dtype=var_node.dtype())
-        self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
+        data_type = 'float64' if var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))
 
         scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
         ins = {'X': var_node, 'InScale': scale_in_node}
@@ -364,13 +354,15 @@
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 var_dtype=var_node.dtype(),
                 shape=[1])
-            self._need_initialized[state_in_node.var()] = Constant(value=1)
+            data_type = 'float64' if var_node.dtype(
+            ) == core.VarDesc.VarType.FP64 else 'float32'
+            self._init_var_node(scale_in_node, np.ones([1], dtype=data_type))
             accum_in_node = graph.create_persistable_node(
                 name=unique_name.generate('accum'),
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 var_dtype=var_node.dtype(),
                 shape=[1])
-            self._need_initialized[accum_in_node.var()] = Constant(value=1)
+            self._init_var_node(accum_in_node, np.ones([1], dtype=data_type))
             state_out_node = graph.create_var_node_from_desc(state_in_node.var(
             ))
             accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
@@ -490,6 +482,16 @@ def _insert_channel_dequant_op(self, graph, var_node, scale_var_nodes,
         graph.link_to(dequant_op_node, dequant_var_node)
         return dequant_var_node
 
+    def _init_var_node(self, var_node, value):
+        assert isinstance(
+            value, np.ndarray), 'The type of value should be numpy array.'
+        assert self._scope is not None, \
+        'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
+        assert self._place is not None, \
+        'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
+        tensor = self._scope.var(var_node.name()).get_tensor()
+        tensor.set(value, self._place)
+
     def _quantized_var_name(self, var_name):
         """
         Return quantized variable name for the input `var_name`.
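
Note: the `_init_var_node` helper added above writes the initial value straight into the variable's tensor in the scope, instead of building a temporary `Program` and running it through an `Executor` as the pass did before. A minimal, self-contained sketch of that pattern; the variable name and value below are illustrative only, not part of the patch:

    import numpy as np
    import paddle.fluid as fluid

    # Sketch: initialize a persistable variable by setting its scope tensor
    # directly, mirroring what _init_var_node does in this patch.
    place = fluid.CPUPlace()
    scope = fluid.global_scope()
    init_value = np.array([0.001], dtype='float32')          # illustrative value
    tensor = scope.var('quant_scale_example').get_tensor()   # hypothetical variable name
    tensor.set(init_value, place)                            # copy the numpy data onto `place`

Because no per-variable startup program is executed any more, the `Executor`, `Program`, `Constant`, and `NumpyArrayInitializer` imports removed at the top of this file are no longer needed.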
@@ -592,7 +594,8 @@ def apply(self, graph):
                                                    self._weight_bits)
                     self._restore_var(input_arg_name, quantized_param_v)
                 else:
-                    scale_v = graph.var_node(op_node.output('OutScale')[0])
+                    scale_v = self._to_node(op_node.outputs,
+                                            op_node.output('OutScale')[0])
                     self._var_scale_map[input_arg_name] = scale_v
 
         ops = graph.all_op_nodes()
@@ -613,32 +616,35 @@ def apply(self, graph):
         for op_node in ops:
             # insert dequant_op after fc/conv, need to rename inputs of the followed ops
             for var_node in op_node.inputs:
-                name = var_node.name()
-                if name in self._op_output_rename_map:
-                    old_in = graph.var_node(name)
-                    new_in = self._op_output_rename_map[name]
+                if var_node.node in self._op_output_rename_map:
+                    old_in = var_node
+                    new_in = self._op_output_rename_map[var_node.node]
                     graph.update_input_link(old_in, new_in, op_node)
 
         # remove the unused var node in the graph
         self._remove_unused_var_nodes(graph)
+        graph.resolve_hazard()
         return graph
 
     def _remove_fake_quant_and_dequant_op(self, graph, op_node):
-        k = op_node.output('Out')[0]
-        v = op_node.input('X')[0]
-        if v not in self._op_input_rename_map:
-            self._op_input_rename_map[k] = v
+        k = self._to_node(op_node.outputs, op_node.output('Out')[0])
+        v = self._to_node(op_node.inputs, op_node.input('X')[0])
+        if v.node not in self._op_input_rename_map:
+            self._op_input_rename_map[k.node] = v
         else:
-            self._op_input_rename_map[k] = self._op_input_rename_map[v]
+            self._op_input_rename_map[k.node] = self._op_input_rename_map[
+                v.node]
         graph.safe_remove_nodes(op_node)
 
     def _insert_post_channel_dequant_op(self, graph, op_node):
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
         for var_node in op_node.inputs:
             name = var_node.name()
-            if name in self._op_input_rename_map:
-                old_in = graph.var_node(name)
-                new_in = graph.var_node(self._op_input_rename_map[name])
+            if name not in op_node.input_arg_names():
+                continue
+            if var_node.node in self._op_input_rename_map:
+                old_in = var_node
+                new_in = self._op_input_rename_map[var_node.node]
                 new_in.clear_outputs()
                 graph.update_input_link(old_in, new_in, op_node)
             original_var_name = self._original_var_name(name)
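
Note: in the hunk above, `_op_input_rename_map` and `_op_output_rename_map` switch from keys based on the variable's name string to keys based on the underlying `node` object (`var_node.node`), so two distinct IR nodes that happen to share a name are no longer conflated. A small illustration with plain Python objects standing in for IR var nodes; the class and names are invented for the example:

    class VarNode(object):
        def __init__(self, name):
            self.node = object()   # unique identity, like the underlying ir::Node
            self._name = name

        def name(self):
            return self._name

    a = VarNode('conv2d_0.tmp_0')
    b = VarNode('conv2d_0.tmp_0')   # same name, different node

    by_name = {a.name(): 'renamed'}
    by_node = {a.node: 'renamed'}

    print(b.name() in by_name)   # True  -- name keys treat a and b as the same var
    print(b.node in by_node)     # False -- node keys keep them distinct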
@@ -653,28 +659,20 @@ def _insert_post_channel_dequant_op(self, graph, op_node):
             assert isinstance(scale_v, IrNode)
             scale_var_node = self._var_scale_map[original_var_name]
 
-        if len(op_node.outputs) != 1:
+        if len(op_node.output_arg_names()) != 1:
             raise ValueError("Only support one output, but op %s has"
                              " more than one output." % (op_node.name()))
 
-        output_var_node = op_node.outputs[0]
+        output_var_node = self._to_node(op_node.outputs,
+                                        op_node.output_arg_names()[0])
         weight_scale_node = graph.create_persistable_node(
             name=unique_name.generate('channel_scale'),
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[channel_scale.shape[0]],
             var_dtype=output_var_node.dtype())
-        init_program = Program()
-        weight_scale_var = init_program.global_block().create_var(
-            name=weight_scale_node.name(),
-            shape=weight_scale_node.shape(),
-            dtype=weight_scale_node.dtype(),
-            type=weight_scale_node.type(),
-            lod_level=weight_scale_node.var().lod_level(),
-            persistable=weight_scale_node.persistable())
-        initializer = NumpyArrayInitializer(value=channel_scale)
-        initializer(weight_scale_var, init_program.global_block())
-        exe = Executor(self._place)
-        exe.run(program=init_program, scope=self._scope)
+        data_type = 'float64' if output_var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        self._init_var_node(weight_scale_node, channel_scale.astype(data_type))
         dequant_var_node = graph.create_var_node(
             name=self._dequantized_var_name(output_var_node.name()),
             var_type=output_var_node.type(),
@@ -695,16 +693,18 @@ def _insert_post_channel_dequant_op(self, graph, op_node):
         graph.link_to(scale_var_node, dequant_op_node)
         graph.link_to(weight_scale_node, dequant_op_node)
         graph.link_to(dequant_op_node, dequant_var_node)
-        self._op_output_rename_map[output_var_node.name()] = dequant_var_node
+        self._op_output_rename_map[output_var_node.node] = dequant_var_node
         return dequant_var_node
 
     def _insert_post_dequant_op(self, graph, op_node):
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
         for var_node in op_node.inputs:
             name = var_node.name()
-            if name in self._op_input_rename_map:
-                old_in = graph.var_node(name)
-                new_in = graph.var_node(self._op_input_rename_map[name])
+            if name not in op_node.input_arg_names():
+                continue
+            if var_node.node in self._op_input_rename_map:
+                old_in = var_node
+                new_in = self._op_input_rename_map[var_node.node]
                 new_in.clear_outputs()
                 graph.update_input_link(old_in, new_in, op_node)
             original_var_name = self._original_var_name(name)
@@ -720,11 +720,12 @@ def _insert_post_dequant_op(self, graph, op_node):
             assert isinstance(scale_v, IrNode)
             scale_var_node = self._var_scale_map[original_var_name]
 
-        if len(op_node.outputs) != 1:
+        if len(op_node.output_arg_names()) != 1:
             raise ValueError("Only support one output, but op %s has"
                              " more than one output." % (op_node.name()))
 
-        output_var_node = op_node.outputs[0]
+        output_var_node = self._to_node(op_node.outputs,
+                                        op_node.output_arg_names()[0])
         dequant_var_node = graph.create_var_node(
             name=self._dequantized_var_name(output_var_node.name()),
             var_type=output_var_node.type(),
@@ -742,9 +743,27 @@ def _insert_post_dequant_op(self, graph, op_node):
         graph.link_to(output_var_node, dequant_op_node)
         graph.link_to(scale_var_node, dequant_op_node)
         graph.link_to(dequant_op_node, dequant_var_node)
-        self._op_output_rename_map[output_var_node.name()] = dequant_var_node
+        self._op_output_rename_map[output_var_node.node] = dequant_var_node
         return dequant_var_node
 
+    def _init_var_node(self, var_node, value):
+        assert isinstance(
+            value, np.ndarray), 'The type of value should be numpy array.'
+        assert self._scope is not None, \
+        'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
+        assert self._place is not None, \
+        'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
+        tensor = self._scope.var(var_node.name()).get_tensor()
+        tensor.set(value, self._place)
+
+    def _to_node(self, nodes, node_name):
+        target_node = None
+        for n in nodes:
+            if n.name() == node_name:
+                target_node = n
+        assert target_node is not None, "Cannot find the target node in the giving set."
+        return target_node
+
     def _load_var(self, name):
         return np.array(self._scope.find_var(name).get_tensor())
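
Note: the new `_to_node` helper resolves a variable name against the nodes attached to a single op (`op_node.inputs` / `op_node.outputs`) rather than searching the whole graph with `graph.var_node(name)`, which is ambiguous once several IR nodes share a name. A hedged, runnable sketch of the lookup; the `FakeNode` class and names are invented for illustration:

    # Toy stand-in for an IR var node; only name() matters for the lookup.
    class FakeNode(object):
        def __init__(self, name):
            self._name = name

        def name(self):
            return self._name

    def to_node(nodes, node_name):
        # Same lookup as the _to_node method added in this patch.
        target_node = None
        for n in nodes:
            if n.name() == node_name:
                target_node = n
        assert target_node is not None, "Cannot find the target node in the given set."
        return target_node

    outputs = [FakeNode('fc_0.tmp_0'), FakeNode('fc_0.tmp_1')]
    print(to_node(outputs, 'fc_0.tmp_1').name())   # -> fc_0.tmp_1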

@@ -848,6 +867,7 @@ def apply(self, graph):
 
         # remove the unused var node in the graph
         self._remove_unused_var_nodes(graph)
+        graph.resolve_hazard()
         return graph
 
     def _convert_to_int8(self, graph, var_node):
@@ -930,5 +950,5 @@ def apply(self, graph):
                 for output_node in op_node.outputs:
                     graph.link_to(dequant_node, output_node)
                 graph.safe_remove_nodes(op_node)
-
+        graph.resolve_hazard()
         return graph
