@@ -42,7 +42,7 @@ def linear_fc(num):
     return loss
 
 
-def residual_block(num):
+def residual_block(num, quant_skip_pattern=None):
     def conv_bn_layer(input,
                       ch_out,
                       filter_size,
@@ -67,8 +67,14 @@ def conv_bn_layer(input,
         conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
         short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
         hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
-    pool = fluid.layers.pool2d(
-        input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
+
+    if quant_skip_pattern:
+        with fluid.name_scope(quant_skip_pattern):
+            pool = fluid.layers.pool2d(
+                input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
+    else:
+        pool = fluid.layers.pool2d(
+            input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
     fc = fluid.layers.fc(input=pool, size=10)
     loss = fluid.layers.cross_entropy(input=fc, label=label)
     loss = fluid.layers.mean(loss)
@@ -134,7 +140,10 @@ def check_program(self, program):
                         arg_name.endswith('.quantized.dequantized'))
                     self.assertTrue(arg_name in quantized_ops)
 
-    def linear_fc_quant(self, activation_quant_type, for_ci=True):
+    def linear_fc_quant(self,
+                        activation_quant_type,
+                        weight_quantize_type,
+                        for_ci=True):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):
@@ -146,7 +155,8 @@ def linear_fc_quant(self, activation_quant_type, for_ci=True):
         transform_pass = QuantizationTransformPass(
             scope=fluid.global_scope(),
             place=place,
-            activation_quantize_type=activation_quant_type)
+            activation_quantize_type=activation_quant_type,
+            weight_quantize_type=weight_quantize_type)
         transform_pass.apply(graph)
         if not for_ci:
             marked_nodes = set()
@@ -167,15 +177,19 @@ def linear_fc_quant(self, activation_quant_type, for_ci=True):
                            val_marked_nodes)
 
     def test_linear_fc_quant_abs_max(self):
-        self.linear_fc_quant('abs_max', for_ci=True)
+        self.linear_fc_quant('abs_max', 'abs_max', for_ci=True)
 
     def test_linear_fc_quant_range_abs_max(self):
-        self.linear_fc_quant('range_abs_max', for_ci=True)
+        self.linear_fc_quant('range_abs_max', 'abs_max', for_ci=True)
 
     def test_linear_fc_quant_moving_average_abs_max(self):
-        self.linear_fc_quant('moving_average_abs_max', for_ci=True)
+        self.linear_fc_quant(
+            'moving_average_abs_max', 'channel_wise_abs_max', for_ci=True)
 
-    def residual_block_quant(self, activation_quant_type, for_ci=True):
+    def residual_block_quant(self,
+                             activation_quant_type,
+                             weight_quantize_type,
+                             for_ci=True):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):
@@ -187,7 +201,8 @@ def residual_block_quant(self, activation_quant_type, for_ci=True):
         transform_pass = QuantizationTransformPass(
             scope=fluid.global_scope(),
             place=place,
-            activation_quantize_type=activation_quant_type)
+            activation_quantize_type=activation_quant_type,
+            weight_quantize_type=weight_quantize_type)
         transform_pass.apply(graph)
         if not for_ci:
             marked_nodes = set()
@@ -208,13 +223,14 @@ def residual_block_quant(self, activation_quant_type, for_ci=True):
                            val_marked_nodes)
 
     def test_residual_block_abs_max(self):
-        self.residual_block_quant('abs_max', for_ci=True)
+        self.residual_block_quant('abs_max', 'abs_max', for_ci=True)
 
     def test_residual_block_range_abs_max(self):
-        self.residual_block_quant('range_abs_max', for_ci=True)
+        self.residual_block_quant('range_abs_max', 'abs_max', for_ci=True)
 
     def test_residual_block_moving_average_abs_max(self):
-        self.residual_block_quant('moving_average_abs_max', for_ci=True)
+        self.residual_block_quant(
+            'moving_average_abs_max', 'channel_wise_abs_max', for_ci=True)
 
 
 class TestQuantizationFreezePass(unittest.TestCase):
@@ -494,11 +510,14 @@ def setUp(self):
         self._target_ops = {'elementwise_add', 'pool2d'}
         self._target_grad_ops = {'elementwise_add_grad', 'pool2d_grad'}
 
-    def check_graph(self, graph):
+    def check_graph(self, graph, skip_pattern=None):
         ops = graph.all_op_nodes()
-
         for op_node in ops:
             if op_node.name() in self._target_ops:
+                if skip_pattern and op_node.op().has_attr("op_namescope") and \
+                        op_node.op().attr("op_namescope").find(skip_pattern) != -1:
+                    continue
+
                 in_nodes_all_not_persistable = True
                 for input_name in op_node.input_arg_names():
                     in_node = graph._find_node_by_name(op_node.inputs,
@@ -508,20 +527,15 @@ def check_graph(self, graph):
                         not in_node.persistable())
                 if not in_nodes_all_not_persistable:
                     continue
-
-                if op_node.op().has_attr("pooling_type") and \
-                        op_node.op().attr("pooling_type") == 'max':
-                    continue
-
                 input_names = op_node.input_arg_names()
                 for input_name in input_names:
                     self.assertTrue(input_name.endswith('.quant_dequant'))
 
-    def residual_block_quant(self, for_ci=True):
+    def residual_block_quant(self, skip_pattern=None, for_ci=True):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):
-            loss = residual_block(1)
+            loss = residual_block(2, skip_pattern)
             opt = fluid.optimizer.Adam(learning_rate=0.001)
             opt.minimize(loss)
         place = fluid.CPUPlace()
@@ -535,7 +549,7 @@ def residual_block_quant(self, for_ci=True):
                 if op.name().find('quant') > -1:
                     marked_nodes.add(op)
             graph.draw('.', 'add_quant_dequant_graph', marked_nodes)
-        self.check_graph(graph)
+        self.check_graph(graph, skip_pattern)
         program = graph.to_program()
         val_graph = IrGraph(core.Graph(program.desc), for_test=False)
         if not for_ci:
@@ -546,7 +560,10 @@ def residual_block_quant(self, for_ci=True):
             val_graph.draw('.', 'val_add_quant_dequant_graph', val_marked_nodes)
 
     def test_residual_block(self):
-        self.residual_block_quant(for_ci=True)
+        self.residual_block_quant(skip_pattern=None, for_ci=True)
+
+    def test_residual_block_skip_pattern(self):
+        self.residual_block_quant(skip_pattern='skip_quant', for_ci=True)
 
 
 if __name__ == '__main__':
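Taken together, the changes above thread a quantization skip pattern through the tests: residual_block can now build its pool2d under fluid.name_scope(quant_skip_pattern), check_graph skips any target op whose "op_namescope" attribute contains the pattern, and the *_quant helpers pass an explicit weight_quantize_type (e.g. 'channel_wise_abs_max') to QuantizationTransformPass. A minimal sketch of the marking mechanism, assuming the fluid 1.x API used in the diff; the small network below is hypothetical, for illustration only:

import paddle.fluid as fluid

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    image = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
    conv = fluid.layers.conv2d(input=image, num_filters=16, filter_size=3)
    # pool2d is one of the pass's target ops, but creating it inside the
    # skip scope records 'skip_quant' in its "op_namescope" attribute,
    # which the updated check_graph() inspects before asserting that the
    # op's inputs were rewritten to *.quant_dequant.
    with fluid.name_scope('skip_quant'):
        pool = fluid.layers.pool2d(
            input=conv, pool_size=2, pool_type='avg', pool_stride=2)

# The same attribute test the new check_graph() performs:
for op in main.global_block().ops:
    if op.has_attr("op_namescope") and \
            op.attr("op_namescope").find('skip_quant') != -1:
        print(op.type, 'would be skipped by the quant/dequant pass')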