@@ -62,9 +62,8 @@ def __init__(self,
         self._ops_to_quantize = _ops_to_quantize
         self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set(
             [-1])
-        self._scale_immutable_ops = [
-            'transpose2', 'reshape2', 'pool2d', 'scale'
-        ]
+        self._scale_immutable_ops = ['transpose2', 'reshape2', 'pool2d']
+        self._scale_ops = ['scale']
         self._conv_ops = ['conv2d', 'depthwise_conv2d']
         self._pool_ops = ['pool2d']
         self._mul_ops = ['mul']
@@ -87,8 +86,8 @@ def apply(self, graph):
         self._reset_pass_idx_and_group('int8')
         graph = self._label_skip_quantized_op(graph)
         graph = self._gather_weight_thresholds_from_fake(graph)
-        graph = self._gather_output_scales_from_attr(graph)
         graph = self._gather_input_scales_from_fake(graph)
+        graph = self._gather_output_scales_from_attr(graph)
         graph = self._remove_fake_ops(graph)
         graph = self._dequantize_weights(graph)
         graph = self._optimize_fp32_graph(graph)
@@ -160,12 +159,16 @@ def _label_skip_quantized_op(self, graph):
                     op_node.op()._set_attr("skip_quant", True)
         return graph
 
-    def _gather_input_scales_from_fake(self, graph):
-        def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
-            scales = self._var_quant_scales
-            for var_name in var_names:
+    def _add_scale_for_vars(self, var_names, use_unsigned_int, lod_tensor):
+        """
+        Save quantization scales for variables. Do not overwrite.
+        """
+        scales = self._var_quant_scales
+        for var_name in var_names:
+            if var_name not in scales:
                 scales[var_name] = (use_unsigned_int, lod_tensor)
 
+    def _gather_input_scales_from_fake(self, graph):
         # fake_quantize_dequantize_abs_max doesn't have scale value
         fake_ops = ['fake_quantize_dequantize_moving_average_abs_max']
         fake_ops.extend(self._fake_quantize_types)
@@ -185,8 +188,8 @@ def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
                 scale[scale == np.Inf] = 0.0
                 lod_tensor = self._convert_scale2tensor(scale)
                 use_unsigned_int = False
-                _add_scale_for_vars([input_name, output_name], use_unsigned_int,
-                                    lod_tensor)
+                self._add_scale_for_vars([input_name, output_name],
+                                         use_unsigned_int, lod_tensor)
 
         return graph
 
@@ -219,8 +222,8 @@ def _gather_output_scales_from_attr(self, graph):
                 use_unsigned_int = False
                 for output_name in op.op().outputs():
                     for out_var_name in op.op().output(output_name):
-                        self._var_quant_scales[out_var_name] = (
-                            use_unsigned_int, scale_lod_tensor)
+                        self._add_scale_for_vars(
+                            [out_var_name], use_unsigned_int, scale_lod_tensor)
 
         return graph
 
@@ -239,24 +242,21 @@ def _update_scales(graph):
                     output_name = op.output("Out")[0]
                     tensor_names = [input_name, output_name]
 
-                    # Scale is not quantized, so if it doesn't have any scales
-                    # to propagate, its tensors won't be added to the waiting list.
-                    if all(name not in self._var_quant_scales for name in tensor_names) \
-                            and op.name() != 'scale':
+                    if all(name not in self._var_quant_scales
+                           for name in tensor_names):
                         waiting_for_scale.update(tensor_names)
                         continue
-
-                    if input_name in self._var_quant_scales:
+                    elif input_name in self._var_quant_scales:
                         self._var_quant_scales[
                             output_name] = self._var_quant_scales[input_name]
                     elif output_name in self._var_quant_scales:
-                        if op.name() == 'scale':
-                            _update_scale_op_in_scale(op, input_name,
-                                                      output_name)
-                        else:
-                            self._var_quant_scales[
-                                input_name] = self._var_quant_scales[
-                                    output_name]
+                        self._var_quant_scales[
+                            input_name] = self._var_quant_scales[output_name]
+                elif op.name() in self._scale_ops:
+                    input_name = op.input("X")[0]
+                    output_name = op.output("Out")[0]
+                    if output_name in self._var_quant_scales:
+                        _update_scale_op_in_scale(op, input_name, output_name)
             return waiting_for_scale
 
         waiting_for_scale = _update_scales(graph)
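
To make the intent of the patch concrete, here is a minimal standalone sketch of the two behaviors it introduces: first-writer-wins scale registration (the new `_add_scale_for_vars` refuses to overwrite a known scale) and propagation of a known scale across a scale-immutable op, with `scale` ops now kept out of that list and handled separately. All names and values below are illustrative only, not the module's API.

# Sketch under assumed names; scales are plain floats here rather than
# (use_unsigned_int, lod_tensor) pairs.
var_quant_scales = {}

def add_scale_for_vars(var_names, scale):
    # Save quantization scales for variables. Do not overwrite.
    for name in var_names:
        if name not in var_quant_scales:
            var_quant_scales[name] = scale

add_scale_for_vars(["conv_out"], 0.12)  # first writer wins
add_scale_for_vars(["conv_out"], 0.99)  # ignored: scale already recorded
assert var_quant_scales["conv_out"] == 0.12

# A scale-immutable op (e.g. reshape2) shares one scale between its input
# and output, so a scale known on either side is copied to the other side.
reshape_in, reshape_out = "conv_out", "reshape_out"
if reshape_in in var_quant_scales:
    var_quant_scales[reshape_out] = var_quant_scales[reshape_in]
elif reshape_out in var_quant_scales:
    var_quant_scales[reshape_in] = var_quant_scales[reshape_out]

print(var_quant_scales)  # {'conv_out': 0.12, 'reshape_out': 0.12}

A `scale` op, by contrast, changes the tensor's numeric range, so its input scale cannot simply be copied; the pass instead computes it from the output scale via `_update_scale_op_in_scale`, which is why `'scale'` moves from `_scale_immutable_ops` into its own `_scale_ops` list.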