
Commit 5993dde

[Cherry-pick] Fix some bugs for quantization (#24852)
* Update sigmoid output from Y to Out, test=develop (#24765)
* Collecting concat output threshold, test=develop (#24742)
* Add output threshold for ops that have several output activations, test=develop (#24726)
* [Fix bug] Init scale node in OutScaleForTrainingPass and enable test_quantization_scale_pass UT (#24393)
  * Init scale node in OutScaleForTrainingPass, test=develop
  * Enable test_quantization_scale, test=develop
1 parent 8c40ebd commit 5993dde
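
The fixes above all feed one mechanism: collecting an "out_threshold" for quantizable ops. During training, a per-output scale (roughly a running record of the activation's abs-max) is kept in the scope; at export time the inference pass writes that value onto the op as an attribute. The snippet below is a minimal, framework-free sketch of that idea only; the OutScaleRecorder class, its momentum value, and the fake variable name are illustrative placeholders, not Paddle APIs.

import numpy as np

class OutScaleRecorder:
    # Toy stand-in for the training-time out-scale collection (illustrative only).

    def __init__(self, momentum=0.9):
        self.momentum = momentum
        self.scales = {}  # output variable name -> moving average of abs-max

    def observe(self, var_name, activation):
        # Update the running abs-max estimate for one output activation.
        abs_max = float(np.abs(activation).max())
        prev = self.scales.get(var_name, abs_max)
        self.scales[var_name] = self.momentum * prev + (1.0 - self.momentum) * abs_max

    def export(self, var_name):
        # Mirrors what the inference pass does with op._set_attr("out_threshold", ...).
        return {"out_threshold": self.scales[var_name]}

recorder = OutScaleRecorder()
for _ in range(10):                        # pretend these are training iterations
    out = np.random.randn(8, 16)           # a fake activation, e.g. a concat output
    recorder.observe("concat_0.tmp_0", out)
print(recorder.export("concat_0.tmp_0"))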

File tree

2 files changed, 24 insertions(+), 22 deletions(-)


python/paddle/fluid/contrib/slim/quantization/quantization_pass.py

Lines changed: 24 additions & 19 deletions
@@ -43,7 +43,7 @@
 _out_scale_op_list = [
     "conv2d", "depthwise_conv2d", "mul", "matmul", "relu", "leaky_relu",
     "relu6", "sigmoid", "tanh", "prelu", "swish", "softmax", "batch_norm",
-    "elementwise_add", "pool2d", "reshape2", "transpose2"
+    "elementwise_add", "pool2d", "reshape2", "transpose2", "concat"
 ]
 
 # list op real input and output names, to avoid processing input such as AxisTensor.
@@ -83,7 +83,7 @@
     "swish": [["X"], ["Out"]],
     "dropout": [["X"], ["Out"]],
     "batch_norm": [["X"], ["Y"]],
-    "sigmoid": [["X"], ["Y"]],
+    "sigmoid": [["X"], ["Out"]],
 }
 
 
@@ -1156,20 +1156,27 @@ def apply(self, graph):
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
         self._is_test = graph.is_test()
-        ops = graph.all_op_nodes()
-        for op_node in ops:
-            name = op_node.name()
-            if name in self._teller_set:
-                if len(op_node.output_arg_names()) != 1:
-                    continue
-                in_node = graph._find_node_by_name(
-                    op_node.outputs, op_node.output_arg_names()[0])
+        target_ops = []
+        for op in graph.all_op_nodes():
+            if op.name() in self._teller_set:
+                target_ops.append(op)
+        for op in target_ops:
+            for output_var_name in _get_op_output_var_names(op):
+                in_node = graph._find_node_by_name(op.outputs, output_var_name)
                 out_node = graph.create_var_node_from_desc(in_node.var())
                 scale_node = graph.create_persistable_node(
                     name=self._scale_name(in_node.name()),
                     var_type=core.VarDesc.VarType.LOD_TENSOR,
                     shape=[1],
                     var_dtype=in_node.dtype())
+                data_type = 'float64' if in_node.dtype() \
+                    == core.VarDesc.VarType.FP64 else 'float32'
+                _init_var_node(
+                    scale_node,
+                    np.ones(
+                        [1], dtype=data_type),
+                    self._scope,
+                    self._place)
                 ins = {'X': in_node}
                 outs = {'Out': out_node, 'OutScale': scale_node}
                 if not self._is_test:
@@ -1178,8 +1185,6 @@ def apply(self, graph):
                         var_type=core.VarDesc.VarType.LOD_TENSOR,
                         var_dtype=in_node.dtype(),
                         shape=[1])
-                    data_type = 'float64' if in_node.dtype(
-                    ) == core.VarDesc.VarType.FP64 else 'float32'
                     _init_var_node(
                         state_in_node,
                         np.ones(
@@ -1257,13 +1262,13 @@ def apply(self, graph):
         """
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
-        ops = graph.all_op_nodes()
-        for op_node in ops:
-            name = op_node.name()
-            if name in self._teller_set:
-                if len(op_node.output_arg_names()) != 1:
-                    continue
-                scale_name = self._scale_name(op_node.output_arg_names()[0])
+        op_nodes = graph.all_op_nodes()
+        for op_node in op_nodes:
+            if op_node.name() in self._teller_set:
+                output_var_name = _get_op_output_var_names(op_node)
+                assert len(output_var_name) == 1, "Only support collecting " \
+                    "output for op that only has an activation output for now."
+                scale_name = self._scale_name(output_var_name[0])
                 scale_v = np.array(
                     self._scope.find_var(scale_name).get_tensor())[0]
                 op_node.op()._set_attr("out_threshold", float(scale_v))
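
Both rewritten loops call _get_op_output_var_names, which lives elsewhere in quantization_pass.py and is not part of this diff. A plausible sketch of what it does, based on the _op_real_in_out_name table touched above (an assumption about the helper, not the file's exact code; the op_node.output() call on the graph node is likewise assumed):

def _get_op_output_var_names(op_node):
    # Sketch: map the op type to its real output slots (e.g. "sigmoid" -> ["Out"]
    # after this commit) and collect the variable names bound to those slots.
    op_name = op_node.name()
    if op_name not in _op_real_in_out_name:
        return []
    var_names = []
    for slot in _op_real_in_out_name[op_name][1]:  # entry is [input_slots, output_slots]
        slot_vars = op_node.output(slot)
        var_names.extend(slot_vars if isinstance(slot_vars, list) else [slot_vars])
    return var_names

Read this way, adding "concat" to _out_scale_op_list and correcting the sigmoid output slot from Y to Out is what makes the lookup return the right names, so the training pass can create and initialize a scale node per output, and the inference pass can assert a single output before writing out_threshold.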

python/paddle/fluid/contrib/slim/tests/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
@@ -114,9 +114,6 @@ if(WIN32)
     list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
 endif()
 
-# Disable unittest for random error temporary
-list(REMOVE_ITEM TEST_OPS test_quantization_scale_pass)
-
 if(LINUX AND WITH_MKLDNN)
 
     #### Image classification dataset: ImageNet (small)
