Skip to content

Commit 228c1d7

Browse files
authored
Fix the error of save_quantized_model (#30587)
When a dygraph Conv2D layer is saved as an inference model, the corresponding op may be either conv2d or depthwise_conv2d. The current save_quantized_model interface does not account for the depthwise_conv2d case, which can cause the out_scale value to be saved incorrectly. This PR fixes that problem.
1 parent 09aed38 commit 228c1d7

File tree

1 file changed

+9
-5
lines changed
  • python/paddle/fluid/contrib/slim/quantization/imperative

1 file changed

+9
-5
lines changed

python/paddle/fluid/contrib/slim/quantization/imperative/qat.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
_op_real_in_out_name = {
3838
"conv2d": [["Input", "Filter"], ["Output"]],
39-
"conv2d_transpose": [["Input", "Filter"], ["Output"]],
39+
"depthwise_conv2d": [["Input", "Filter"], ["Output"]],
4040
"pool2d": [["X"], ["Out"]],
4141
"elementwise_add": [["X", "Y"], ["Out"]],
4242
"softmax": [["X"], ["Out"]],
@@ -329,9 +329,9 @@ def __init__(self, moving_rate=0.9):
329329
super(ImperativeCalcOutScale, self).__init__()
330330
self._moving_rate = moving_rate
331331
self._out_scale_layer_type_list = (
332-
BatchNorm, BatchNorm1D, BatchNorm2D, BatchNorm3D, Conv2D,
333-
Conv2DTranspose, LeakyReLU, Linear, PReLU, Pool2D, MaxPool1D,
334-
MaxPool2D, ReLU, ReLU6, Sigmoid, Softmax, Tanh, Swish)
332+
BatchNorm, BatchNorm1D, BatchNorm2D, BatchNorm3D, Conv2D, LeakyReLU,
333+
Linear, PReLU, Pool2D, MaxPool1D, MaxPool2D, ReLU, ReLU6, Sigmoid,
334+
Softmax, Tanh, Swish)
335335
self._register_hook_handle_list = []
336336
self._out_scale_dict = collections.OrderedDict()
337337

@@ -415,9 +415,10 @@ def save_quantized_model(self, layer, path, input_spec=None, **config):
415415

416416
# Traverse all ops in the program and find out the op matching
417417
# the Layer in the dynamic graph.
418-
layer_var_dict = {}
418+
layer_var_dict = collections.OrderedDict()
419419
ops_list = [key for key, _ in self._out_scale_dict.items()]
420420
op_count = 0
421+
conv_count = 0
421422
for block in inference_program.blocks:
422423
for op in block.ops:
423424
if op.type in _op_real_in_out_name:
@@ -472,6 +473,9 @@ def save_quantized_model(self, layer, path, input_spec=None, **config):
472473
layer_name = layer_name.replace('prelu', 'p_re_lu')
473474
if 'relu' in layer_name:
474475
layer_name = layer_name.replace('relu', 're_lu')
476+
if 'conv2d' in layer_name:
477+
layer_name = 'conv2d_' + str(conv_count)
478+
conv_count = conv_count + 1
475479
if layer_name not in self._out_scale_dict:
476480
continue
477481
var_name_op_list[1]._set_attr('out_threshold',

0 commit comments

Comments
 (0)