Skip to content

Commit aa731e6

Browse files
author
Wojciech Uss
authored
update scale collection and propagation algorithm (#31783) (#31810)
1 parent f3b0f8d commit aa731e6

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,8 @@ def __init__(self,
6262
self._ops_to_quantize = _ops_to_quantize
6363
self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set(
6464
[-1])
65-
self._scale_immutable_ops = [
66-
'transpose2', 'reshape2', 'pool2d', 'scale'
67-
]
65+
self._scale_immutable_ops = ['transpose2', 'reshape2', 'pool2d']
66+
self._scale_ops = ['scale']
6867
self._conv_ops = ['conv2d', 'depthwise_conv2d']
6968
self._pool_ops = ['pool2d']
7069
self._mul_ops = ['mul']
@@ -87,8 +86,8 @@ def apply(self, graph):
8786
self._reset_pass_idx_and_group('int8')
8887
graph = self._label_skip_quantized_op(graph)
8988
graph = self._gather_weight_thresholds_from_fake(graph)
90-
graph = self._gather_output_scales_from_attr(graph)
9189
graph = self._gather_input_scales_from_fake(graph)
90+
graph = self._gather_output_scales_from_attr(graph)
9291
graph = self._remove_fake_ops(graph)
9392
graph = self._dequantize_weights(graph)
9493
graph = self._optimize_fp32_graph(graph)
@@ -160,12 +159,16 @@ def _label_skip_quantized_op(self, graph):
160159
op_node.op()._set_attr("skip_quant", True)
161160
return graph
162161

163-
def _gather_input_scales_from_fake(self, graph):
164-
def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
165-
scales = self._var_quant_scales
166-
for var_name in var_names:
162+
def _add_scale_for_vars(self, var_names, use_unsigned_int, lod_tensor):
163+
"""
164+
Save quantization scales for variables. Do not overwrite.
165+
"""
166+
scales = self._var_quant_scales
167+
for var_name in var_names:
168+
if var_name not in scales:
167169
scales[var_name] = (use_unsigned_int, lod_tensor)
168170

171+
def _gather_input_scales_from_fake(self, graph):
169172
# fake_quantize_dequantize_abs_max doesn't have scale value
170173
fake_ops = ['fake_quantize_dequantize_moving_average_abs_max']
171174
fake_ops.extend(self._fake_quantize_types)
@@ -185,8 +188,8 @@ def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
185188
scale[scale == np.Inf] = 0.0
186189
lod_tensor = self._convert_scale2tensor(scale)
187190
use_unsigned_int = False
188-
_add_scale_for_vars([input_name, output_name], use_unsigned_int,
189-
lod_tensor)
191+
self._add_scale_for_vars([input_name, output_name],
192+
use_unsigned_int, lod_tensor)
190193

191194
return graph
192195

@@ -219,8 +222,8 @@ def _gather_output_scales_from_attr(self, graph):
219222
use_unsigned_int = False
220223
for output_name in op.op().outputs():
221224
for out_var_name in op.op().output(output_name):
222-
self._var_quant_scales[out_var_name] = (
223-
use_unsigned_int, scale_lod_tensor)
225+
self._add_scale_for_vars(
226+
[out_var_name], use_unsigned_int, scale_lod_tensor)
224227

225228
return graph
226229

@@ -239,24 +242,21 @@ def _update_scales(graph):
239242
output_name = op.output("Out")[0]
240243
tensor_names = [input_name, output_name]
241244

242-
# Scale is not quantized, so if it doesn't have any scales
243-
# to propagate, its tensors won't be added to the waiting list.
244-
if all(name not in self._var_quant_scales for name in tensor_names) \
245-
and op.name() != 'scale':
245+
if all(name not in self._var_quant_scales
246+
for name in tensor_names):
246247
waiting_for_scale.update(tensor_names)
247248
continue
248-
249-
if input_name in self._var_quant_scales:
249+
elif input_name in self._var_quant_scales:
250250
self._var_quant_scales[
251251
output_name] = self._var_quant_scales[input_name]
252252
elif output_name in self._var_quant_scales:
253-
if op.name() == 'scale':
254-
_update_scale_op_in_scale(op, input_name,
255-
output_name)
256-
else:
257-
self._var_quant_scales[
258-
input_name] = self._var_quant_scales[
259-
output_name]
253+
self._var_quant_scales[
254+
input_name] = self._var_quant_scales[output_name]
255+
elif op.name() in self._scale_ops:
256+
input_name = op.input("X")[0]
257+
output_name = op.output("Out")[0]
258+
if output_name in self._var_quant_scales:
259+
_update_scale_op_in_scale(op, input_name, output_name)
260260
return waiting_for_scale
261261

262262
waiting_for_scale = _update_scales(graph)

0 commit comments

Comments
 (0)