Skip to content

Commit cbab018

Browse files
ceci3 and ZhangHandi authored
[Cherry pick] fix quant scale name (#44903)
* fix quant scale name (#44116)
* fix acc diff problem caused by PR #44116 (#44311)

Co-authored-by: handiz <[email protected]>
1 parent 2676281 commit cbab018

File tree

2 files changed

+19
-12
lines changed

2 files changed

+19
-12
lines changed

python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -962,10 +962,10 @@ def _update_program(self):
962962
else:
963963
scale_dict = self._quantized_threshold
964964
for key, val in scale_dict.items():
965-
utils.set_variable_data(self._scope, self._place, key + ".scale",
965+
utils.set_variable_data(self._scope, self._place, key + "@scale",
966966
np.array([val], dtype=np.float32))
967967
utils.set_variable_data(self._scope, self._place,
968-
key + ".quant_dequant.scale",
968+
key + ".quant_dequant@scale",
969969
np.array([val], dtype=np.float32))
970970

971971
if not self._onnx_format:

python/paddle/fluid/contrib/slim/quantization/quantization_pass.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -906,7 +906,7 @@ def _quantized_scale_name(self, var_name):
906906
"""
907907
Return the scale name of quantized variable for the input `var_name`.
908908
"""
909-
return "%s.scale" % (var_name)
909+
return "%s@scale" % (var_name)
910910

911911
def _is_skip_quant(self, graph, op_node):
912912
"""
@@ -1246,8 +1246,8 @@ def _original_var_name(self, var_name):
12461246
return var_name[:-len('.quantized')]
12471247
if var_name.endswith('.dequantized'):
12481248
return var_name[:-len('.dequantized')]
1249-
if var_name.endswith('.scale'):
1250-
return var_name[:-len('.scale')]
1249+
if var_name.endswith('@scale'):
1250+
return var_name[:-len('@scale')]
12511251
else:
12521252
return var_name
12531253

@@ -1440,11 +1440,18 @@ def apply(self, graph):
14401440
[core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
14411441
continue
14421442

1443-
scale_node = graph.create_persistable_node(
1444-
name=self._scale_name(in_node.name()),
1445-
var_type=core.VarDesc.VarType.LOD_TENSOR,
1446-
shape=[1],
1447-
var_dtype=in_node.dtype())
1443+
try:
1444+
graph._find_node_by_name(
1445+
graph.all_var_nodes(),
1446+
self._scale_name(in_node.name()))
1447+
continue
1448+
except:
1449+
scale_node = graph.create_persistable_node(
1450+
name=self._scale_name(in_node.name()),
1451+
var_type=core.VarDesc.VarType.LOD_TENSOR,
1452+
shape=[1],
1453+
var_dtype=in_node.dtype())
1454+
14481455
data_type = 'float64' if in_node.dtype() \
14491456
== core.VarDesc.VarType.FP64 else 'float32'
14501457
_init_var_node(scale_node, np.ones([1], dtype=data_type),
@@ -1705,7 +1712,7 @@ def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
17051712
shape=var_node.shape(),
17061713
var_dtype=var_node.dtype())
17071714
scale_in_node = graph.create_persistable_node(
1708-
name="{}.quant_dequant.scale".format(var_node.name()),
1715+
name="{}.quant_dequant@scale".format(var_node.name()),
17091716
var_type=core.VarDesc.VarType.LOD_TENSOR,
17101717
shape=[1],
17111718
var_dtype=var_node.dtype())
@@ -1954,7 +1961,7 @@ def _quantized_scale_name(self, var_name):
19541961
"""
19551962
Return the scale name of quantized variable for the input `var_name`.
19561963
"""
1957-
return "%s.scale" % (var_name)
1964+
return "%s@scale" % (var_name)
19581965

19591966
def _zero_point_name(self, var_name):
19601967
"""

0 commit comments

Comments (0)