Skip to content

Commit 965b45e

Browse files
authored
Move pool2d to add_quant_dequant_pass, test=develop (#20586) (#20675)
* move pool2d to add_quant_dequant_pass, test=develop
1 parent e083f14 commit 965b45e

File tree

2 files changed: +57 −35 lines changed

python/paddle/fluid/contrib/slim/quantization/quantization_pass.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
'AddQuantDequantPass'
2727
]
2828

29-
_quantizable_op_list = ['conv2d', 'depthwise_conv2d', 'mul', 'pool2d']
29+
_quantizable_op_list = ['conv2d', 'depthwise_conv2d', 'mul']
3030

3131
_fake_quant_op_list = [
3232
'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
@@ -161,13 +161,11 @@ def apply(self, graph):
161161
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
162162

163163
def _quant_preprocess(op_node):
164-
pool_skipped = op_node.op().has_attr("pooling_type") and \
165-
op_node.op().attr("pooling_type") == 'avg'
166164
user_skipped = isinstance(self._skip_pattern, str) and \
167165
op_node.op().has_attr("op_namescope") and \
168166
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
169167

170-
if pool_skipped or user_skipped:
168+
if user_skipped:
171169
op_node.op()._set_attr("skip_quant", True)
172170

173171
def _transform_forward(graph, op):
@@ -1163,10 +1161,15 @@ def _scale_name(self, var_name):
11631161

11641162

11651163
class AddQuantDequantPass(object):
1166-
def __init__(self, scope=None, place=None, moving_rate=0.9, quant_bits=8):
1164+
def __init__(self,
1165+
scope=None,
1166+
place=None,
1167+
moving_rate=0.9,
1168+
quant_bits=8,
1169+
skip_pattern='skip_quant'):
11671170
"""
11681171
This pass is used to add quant_dequant op for some ops, such as the
1169-
'elementwise_add' and 'average pool2d' op.
1172+
'elementwise_add' and 'pool2d' op.
11701173
"""
11711174
self._scope = scope
11721175
self._place = place
@@ -1175,11 +1178,12 @@ def __init__(self, scope=None, place=None, moving_rate=0.9, quant_bits=8):
11751178
self._is_test = None
11761179
self._target_ops = ["elementwise_add", "pool2d"]
11771180
self._target_grad_ops = ['%s_grad' % (op) for op in self._target_ops]
1181+
self._skip_pattern = skip_pattern
11781182

11791183
def apply(self, graph):
11801184
"""
11811185
Add quant_dequant before some ops, such as the 'elementwise_add'
1182-
and 'average pool2d' op.
1186+
and 'pool2d' op.
11831187
Args:
11841188
graph(IrGraph): the target graph.
11851189
"""
@@ -1191,6 +1195,11 @@ def apply(self, graph):
11911195

11921196
for op_node in ops:
11931197
if op_node.name() in self._target_ops:
1198+
if isinstance(self._skip_pattern, str) and \
1199+
op_node.op().has_attr("op_namescope") and \
1200+
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1:
1201+
continue
1202+
11941203
in_nodes_all_not_persistable = True
11951204
for input_name in op_node.input_arg_names():
11961205
in_node = graph._find_node_by_name(op_node.inputs,
@@ -1201,10 +1210,6 @@ def apply(self, graph):
12011210
if not in_nodes_all_not_persistable:
12021211
continue
12031212

1204-
if op_node.op().has_attr("pooling_type") and \
1205-
op_node.op().attr("pooling_type") == 'max':
1206-
continue
1207-
12081213
input_names = op_node.input_arg_names()
12091214
for input_name in input_names:
12101215
in_node = graph._find_node_by_name(op_node.inputs,

python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def linear_fc(num):
4242
return loss
4343

4444

45-
def residual_block(num):
45+
def residual_block(num, quant_skip_pattern=None):
4646
def conv_bn_layer(input,
4747
ch_out,
4848
filter_size,
@@ -67,8 +67,14 @@ def conv_bn_layer(input,
6767
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
6868
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
6969
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
70-
pool = fluid.layers.pool2d(
71-
input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
70+
71+
if quant_skip_pattern:
72+
with fluid.name_scope(quant_skip_pattern):
73+
pool = fluid.layers.pool2d(
74+
input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
75+
else:
76+
pool = fluid.layers.pool2d(
77+
input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
7278
fc = fluid.layers.fc(input=pool, size=10)
7379
loss = fluid.layers.cross_entropy(input=fc, label=label)
7480
loss = fluid.layers.mean(loss)
@@ -134,7 +140,10 @@ def check_program(self, program):
134140
arg_name.endswith('.quantized.dequantized'))
135141
self.assertTrue(arg_name in quantized_ops)
136142

137-
def linear_fc_quant(self, activation_quant_type, for_ci=True):
143+
def linear_fc_quant(self,
144+
activation_quant_type,
145+
weight_quantize_type,
146+
for_ci=True):
138147
main = fluid.Program()
139148
startup = fluid.Program()
140149
with fluid.program_guard(main, startup):
@@ -146,7 +155,8 @@ def linear_fc_quant(self, activation_quant_type, for_ci=True):
146155
transform_pass = QuantizationTransformPass(
147156
scope=fluid.global_scope(),
148157
place=place,
149-
activation_quantize_type=activation_quant_type)
158+
activation_quantize_type=activation_quant_type,
159+
weight_quantize_type=weight_quantize_type)
150160
transform_pass.apply(graph)
151161
if not for_ci:
152162
marked_nodes = set()
@@ -167,15 +177,19 @@ def linear_fc_quant(self, activation_quant_type, for_ci=True):
167177
val_marked_nodes)
168178

169179
def test_linear_fc_quant_abs_max(self):
170-
self.linear_fc_quant('abs_max', for_ci=True)
180+
self.linear_fc_quant('abs_max', 'abs_max', for_ci=True)
171181

172182
def test_linear_fc_quant_range_abs_max(self):
173-
self.linear_fc_quant('range_abs_max', for_ci=True)
183+
self.linear_fc_quant('range_abs_max', 'abs_max', for_ci=True)
174184

175185
def test_linear_fc_quant_moving_average_abs_max(self):
176-
self.linear_fc_quant('moving_average_abs_max', for_ci=True)
186+
self.linear_fc_quant(
187+
'moving_average_abs_max', 'channel_wise_abs_max', for_ci=True)
177188

178-
def residual_block_quant(self, activation_quant_type, for_ci=True):
189+
def residual_block_quant(self,
190+
activation_quant_type,
191+
weight_quantize_type,
192+
for_ci=True):
179193
main = fluid.Program()
180194
startup = fluid.Program()
181195
with fluid.program_guard(main, startup):
@@ -187,7 +201,8 @@ def residual_block_quant(self, activation_quant_type, for_ci=True):
187201
transform_pass = QuantizationTransformPass(
188202
scope=fluid.global_scope(),
189203
place=place,
190-
activation_quantize_type=activation_quant_type)
204+
activation_quantize_type=activation_quant_type,
205+
weight_quantize_type=weight_quantize_type)
191206
transform_pass.apply(graph)
192207
if not for_ci:
193208
marked_nodes = set()
@@ -208,13 +223,14 @@ def residual_block_quant(self, activation_quant_type, for_ci=True):
208223
val_marked_nodes)
209224

210225
def test_residual_block_abs_max(self):
211-
self.residual_block_quant('abs_max', for_ci=True)
226+
self.residual_block_quant('abs_max', 'abs_max', for_ci=True)
212227

213228
def test_residual_block_range_abs_max(self):
214-
self.residual_block_quant('range_abs_max', for_ci=True)
229+
self.residual_block_quant('range_abs_max', 'abs_max', for_ci=True)
215230

216231
def test_residual_block_moving_average_abs_max(self):
217-
self.residual_block_quant('moving_average_abs_max', for_ci=True)
232+
self.residual_block_quant(
233+
'moving_average_abs_max', 'channel_wise_abs_max', for_ci=True)
218234

219235

220236
class TestQuantizationFreezePass(unittest.TestCase):
@@ -494,11 +510,14 @@ def setUp(self):
494510
self._target_ops = {'elementwise_add', 'pool2d'}
495511
self._target_grad_ops = {'elementwise_add_grad', 'pool2d_grad'}
496512

497-
def check_graph(self, graph):
513+
def check_graph(self, graph, skip_pattern=None):
498514
ops = graph.all_op_nodes()
499-
500515
for op_node in ops:
501516
if op_node.name() in self._target_ops:
517+
if skip_pattern and op_node.op().has_attr("op_namescope") and \
518+
op_node.op().attr("op_namescope").find(skip_pattern) != -1:
519+
continue
520+
502521
in_nodes_all_not_persistable = True
503522
for input_name in op_node.input_arg_names():
504523
in_node = graph._find_node_by_name(op_node.inputs,
@@ -508,20 +527,15 @@ def check_graph(self, graph):
508527
not in_node.persistable())
509528
if not in_nodes_all_not_persistable:
510529
continue
511-
512-
if op_node.op().has_attr("pooling_type") and \
513-
op_node.op().attr("pooling_type") == 'max':
514-
continue
515-
516530
input_names = op_node.input_arg_names()
517531
for input_name in input_names:
518532
self.assertTrue(input_name.endswith('.quant_dequant'))
519533

520-
def residual_block_quant(self, for_ci=True):
534+
def residual_block_quant(self, skip_pattern=None, for_ci=True):
521535
main = fluid.Program()
522536
startup = fluid.Program()
523537
with fluid.program_guard(main, startup):
524-
loss = residual_block(1)
538+
loss = residual_block(2, skip_pattern)
525539
opt = fluid.optimizer.Adam(learning_rate=0.001)
526540
opt.minimize(loss)
527541
place = fluid.CPUPlace()
@@ -535,7 +549,7 @@ def residual_block_quant(self, for_ci=True):
535549
if op.name().find('quant') > -1:
536550
marked_nodes.add(op)
537551
graph.draw('.', 'add_quant_dequant_graph', marked_nodes)
538-
self.check_graph(graph)
552+
self.check_graph(graph, skip_pattern)
539553
program = graph.to_program()
540554
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
541555
if not for_ci:
@@ -546,7 +560,10 @@ def residual_block_quant(self, for_ci=True):
546560
val_graph.draw('.', 'val_add_quant_dequant_graph', val_marked_nodes)
547561

548562
def test_residual_block(self):
549-
self.residual_block_quant(for_ci=True)
563+
self.residual_block_quant(skip_pattern=None, for_ci=True)
564+
565+
def test_residual_block_skip_pattern(self):
566+
self.residual_block_quant(skip_pattern='skip_quant', for_ci=True)
550567

551568

552569
if __name__ == '__main__':

Comments (0): this commit has no comments.