Skip to content

Commit a8cfdd3

Browse files
authored
fix PACT quant_aware abnormal training accuracy (PaddlePaddle#1111)
1 parent 0bb7772 commit a8cfdd3

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

paddleslim/quant/quanter.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,31 @@
9595
}
9696

9797

98+
# TODO: Hard-code, remove it when Paddle 2.3.1
99+
class OutScaleForTrainingPassV2(OutScaleForTrainingPass):
100+
def __init__(self, scope=None, place=None, moving_rate=0.9):
101+
OutScaleForTrainingPass.__init__(
102+
self, scope=scope, place=place, moving_rate=moving_rate)
103+
104+
def _scale_name(self, var_name):
105+
"""
106+
Return the scale name for the var named `var_name`.
107+
"""
108+
return "%s@scale" % (var_name)
109+
110+
111+
# TODO: Hard-code, remove it when Paddle 2.3.1
112+
class OutScaleForInferencePassV2(OutScaleForInferencePass):
113+
def __init__(self, scope=None):
114+
OutScaleForInferencePass.__init__(self, scope=scope)
115+
116+
def _scale_name(self, var_name):
117+
"""
118+
Return the scale name for the var named `var_name`.
119+
"""
120+
return "%s@scale" % (var_name)
121+
122+
98123
def load_dict():
99124
with open(VARS_MAPPING_TABLE, 'r') as file:
100125
data = file.read()
@@ -298,7 +323,7 @@ def quant_aware(program,
298323
quantizable_op_type=quant_dequant_ops)
299324
quant_dequant_pass.apply(main_graph)
300325

301-
out_scale_training_pass = OutScaleForTrainingPass(
326+
out_scale_training_pass = OutScaleForTrainingPassV2(
302327
scope=scope, place=place, moving_rate=config['moving_rate'])
303328
out_scale_training_pass.apply(main_graph)
304329

@@ -509,7 +534,7 @@ def convert(program,
509534
quant_weight_pass = QuantWeightPass(scope, place)
510535
quant_weight_pass.apply(test_graph)
511536
else:
512-
out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
537+
out_scale_infer_pass = OutScaleForInferencePassV2(scope=scope)
513538
out_scale_infer_pass.apply(test_graph)
514539
# Freeze the graph after training by adjusting the quantize
515540
# operators' order for the inference.

0 commit comments

Comments
 (0)