
Commit 24b3bbd

yghstill and RachelXu7 authored
[cherry-pick] fix QuantizeLinear pass and support reduce_max in quantization (#44872)
* fix QuantizeLinear kernel and pass in QAT (#44784)
* Add Reduce Max in Quant (#44825)

Co-authored-by: Chang Xu <[email protected]>
1 parent 245005d commit 24b3bbd

7 files changed: 148 additions, 56 deletions


paddle/fluid/operators/fake_quantize_op.h

Lines changed: 2 additions & 1 deletion
@@ -139,7 +139,8 @@ struct FindMovingAverageAbsMaxFunctor {
   void operator()(const DeviceContext &ctx,
                   const framework::Tensor &in_accum,
                   const framework::Tensor &in_state,
-                  const framework::Tensor &cur_scale,
+                  const T *cur_scale,
+                  const float rate,
                   framework::Tensor *out_state,
                   framework::Tensor *out_accum,
                   framework::Tensor *out_scale);

paddle/fluid/operators/quantize_linear_op.cc

Lines changed: 26 additions & 3 deletions
@@ -93,6 +93,12 @@ class QuantizeLinearOp : public framework::OperatorWithKernel {
         ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[quant_axis]});
       }
     }
+    if (ctx->HasOutput("OutState")) {
+      ctx->SetOutputDim("OutState", {1});
+    }
+    if (ctx->HasOutput("OutAccum")) {
+      ctx->SetOutputDim("OutAccum", {1});
+    }
     ctx->ShareLoD("X", /*->*/ "Y");
   }

@@ -113,7 +119,25 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Y",
               "(Tensor) Output of quantized low level tensor, "
               "but also saved as float data type.");
-    AddOutput("OutScale", "(Tensor) Current scale").AsDispensable().AsExtra();
+    AddInput("InAccum", "Last accum.")
+        .AsDispensable()
+        .AsExtra();  // only qat use
+    AddInput("InState", "Last state.")
+        .AsDispensable()
+        .AsExtra();  // only qat use
+    AddOutput("OutState", "(Tensor) state buffer.")
+        .AsDispensable()
+        .AsExtra();  // only qat use
+    AddOutput("OutAccum", "(Tensor) accum buffer.")
+        .AsDispensable()
+        .AsExtra();  // only qat use
+    AddOutput("OutScale", "(Tensor) Current scale")
+        .AsDispensable()
+        .AsExtra();  // only qat use
+    AddAttr<float>("moving_rate",
+                   "(float, default 0.9) moving rate.")  // only qat use
+        .SetDefault(0.9)
+        .AsExtra();
     AddAttr<int>("quant_axis",
                  "(int, default 0) The axis for quantization. "
                  "For conv2d, depthwise_conv2d, conv2d_transpose "
@@ -154,8 +178,7 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
                       "nearest ties to even and 1 is rounding to nearest "
                       "ties away from zero.but the received is %d",
                       round_type));
-        })
-        .AsExtra();
+        });
     AddAttr<bool>("is_test",
                   "(bool, default false) Set to true for inference only, false "
                   "for training. Some layers may run faster when this is true.")

paddle/fluid/operators/quantize_linear_op.h

Lines changed: 24 additions & 3 deletions
@@ -56,10 +56,31 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {

     if (quant_axis < 0) {
       if (!is_test) {
-        auto* out_scale = context.Output<framework::Tensor>("OutScale");
-        T* out_s = out_scale->mutable_data<T>(context.GetPlace());
+        // training
+        auto* in_accum = context.Input<framework::Tensor>("InAccum");
+        auto* in_state = context.Input<framework::Tensor>("InState");
+        auto cur_scale = memory::Alloc(dev_ctx, sizeof(T));
+        T* cur_scale_data = static_cast<T*>(cur_scale->ptr());
+
         FindAbsMaxFunctor<DeviceContext, T>()(
-            dev_ctx, in->data<T>(), in->numel(), out_s);
+            dev_ctx, in->data<T>(), in->numel(), cur_scale_data);
+
+        auto* out_state = context.Output<framework::Tensor>("OutState");
+        auto* out_accum = context.Output<framework::Tensor>("OutAccum");
+        auto* out_scale = context.Output<framework::Tensor>("OutScale");
+        out_state->mutable_data<T>(context.GetPlace());
+        out_accum->mutable_data<T>(context.GetPlace());
+        out_scale->mutable_data<T>(context.GetPlace());
+        float moving_rate = context.Attr<float>("moving_rate");
+
+        FindMovingAverageAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
+                                                           *in_accum,
+                                                           *in_state,
+                                                           cur_scale_data,
+                                                           moving_rate,
+                                                           out_state,
+                                                           out_accum,
+                                                           out_scale);
         ClipAndFakeQuantFunctor<DeviceContext, T>()(
             dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
       } else {
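
For reference, the arithmetic this training branch now performs can be sketched in NumPy. It is a minimal illustration of the FindAbsMax plus FindMovingAverageAbsMax combination wired up above, matching the expected values computed in the updated unit test further down; the function name and signature are illustrative, not Paddle API.

import numpy as np

def moving_average_abs_max_scale(x, in_accum, in_state, moving_rate=0.9):
    # FindAbsMaxFunctor: absolute max of the tensor being quantized.
    cur_abs_max = np.max(np.abs(x))
    # FindMovingAverageAbsMaxFunctor: decayed accumulation of abs-max values
    # and of the step count; their ratio is the moving-average scale.
    out_accum = moving_rate * in_accum + cur_abs_max
    out_state = moving_rate * in_state + 1.0
    out_scale = out_accum / out_state
    return out_accum, out_state, out_scale

# The state/accum buffers start at ones, matching the initialization done by
# the quantization pass below.
accum, state, scale = moving_average_abs_max_scale(
    np.random.randn(31, 65).astype("float32"), np.ones(1), np.ones(1))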

python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py

Lines changed: 12 additions & 6 deletions
@@ -418,8 +418,7 @@ def quantize(self):
         self._update_program()

         # save out_threshold for quantized ops.
-        if not self._onnx_format:
-            self._save_output_threshold()
+        self._save_output_threshold()

         if any(op_type in self._quantizable_op_type
                for op_type in self._dynamic_quantize_op_type):
@@ -996,16 +995,23 @@ def _save_output_threshold(self):
         '''
         Save output threshold to the quantized op.
         '''
+        self._calibration_scales = {}

         def save_info(op_node, out_var_name, threshold_map, out_info_name,
                       quantized_type):
             assert out_var_name in threshold_map, \
                 "The output ({}) of {} node does not have threshold.".format(
                 out_var_name, op_node.type)
-            op_node._set_attr(out_info_name, threshold_map[var_name])
-            op_node._set_attr("with_quant_attr", True)
-            if op_node.type in self._quantizable_op_type:
-                op._set_attr("quantization_type", quantized_type)
+            if self._onnx_format:
+                # For easy extension, every var_node set a dict to save parameters of quant.
+                self._calibration_scales[var_name] = {}
+                self._calibration_scales[var_name]['scale'] = threshold_map[
+                    var_name]
+            else:
+                op_node._set_attr(out_info_name, threshold_map[var_name])
+                op_node._set_attr("with_quant_attr", True)
+                if op_node.type in self._quantizable_op_type:
+                    op._set_attr("quantization_type", quantized_type)

         def analysis_and_save_info(op_node, out_var_name):
             argname_index = utils._get_output_name_index(op_node, out_var_name)
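
In short, save_info now branches on the export format: with the ONNX/new format the threshold is kept in the in-memory _calibration_scales dict keyed by variable name, otherwise it is still written onto the op as attributes. Below is a hedged sketch of that logic; the attribute name 'out_threshold' and the variable name are illustrative, and in the real pass they come from analysis_and_save_info.

def save_threshold(calibration_scales, op_attrs, var_name, threshold, onnx_format):
    # Simplified stand-in for save_info() above.
    if onnx_format:
        # new format: keep a per-variable dict of quant parameters
        calibration_scales[var_name] = {'scale': threshold}
    else:
        # old behaviour: write the threshold onto the op as attributes
        op_attrs['out_threshold'] = threshold
        op_attrs['with_quant_attr'] = True

scales, attrs = {}, {}
save_threshold(scales, attrs, 'conv2d_0.tmp_0', 6.218, onnx_format=True)
# scales == {'conv2d_0.tmp_0': {'scale': 6.218}}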

python/paddle/fluid/contrib/slim/quantization/quantization_pass.py

Lines changed: 52 additions & 36 deletions
@@ -1785,6 +1785,7 @@ class InsertQuantizeLinear(object):
         equal to 0, it will quantization with per channel, else quantization with per layer.
         Default is -1.
         channel_wise(bool, optional): Whether quantization with per channel or not. Default is False.
+        moving_rate(float): the rate for 'moving average' method.
         is_test(bool, optional): Whether quantization with training or not. Default is True.
     """

@@ -1794,22 +1795,24 @@ def __init__(self,
                  quant_bits=8,
                  quant_axis=-1,
                  channel_wise=False,
+                 moving_rate=0.9,
                  is_test=True):
         self._place = place
         self._scope = scope
         self.quant_bits = quant_bits
         self.quant_axis = quant_axis
         self.channel_wise = channel_wise
         self._is_test = is_test
+        self._moving_rate = moving_rate

-    def insert_quant_op(self, graph, var_node):
+    def insert_quant_op(self, graph, var_node, var_name=None):
         assert var_node.is_var(), '{} is not a var'.format(var_node.name())
-
-        quant_var_node = graph.create_var_node(name=self._quantized_var_name(
-            var_node.name()),
-                                               var_type=var_node.type(),
-                                               shape=var_node.shape(),
-                                               var_dtype=var_node.dtype())
+        var_name = var_node.name() if not var_name else var_name
+        quant_var_node = graph.create_var_node(
+            name=self._quantized_var_name(var_name),
+            var_type=var_node.type(),
+            shape=var_node.shape(),
+            var_dtype=var_node.dtype())
         data_type = 'float64' if var_node.dtype(
         ) == core.VarDesc.VarType.FP64 else 'float32'
         if self.channel_wise:
@@ -1821,7 +1824,7 @@ def insert_quant_op(self, graph, var_node):
             scale_var_type = var_node.type()
         init_scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type)
         scale_var_node = graph.create_persistable_node(
-            name=self._quantized_scale_name(var_node.name()),
+            name=self._quantized_scale_name(var_name),
             var_type=scale_var_type,
             shape=[scale_var_shape],
             var_dtype=var_node.dtype())
@@ -1844,13 +1847,39 @@ def insert_quant_op(self, graph, var_node):
             inputs["ZeroPoint"] = zero_point_node

         attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
+        attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
         outputs = {"Y": quant_var_node}
         if not self._is_test:
-            attrs["is_test"] = self._is_test
-            attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
             scale_out_node = graph.create_var_node_from_desc(
                 scale_var_node.var())
+            state_in_node = graph.create_persistable_node(
+                name=unique_name.generate('state'),
+                var_type=core.VarDesc.VarType.LOD_TENSOR,
+                var_dtype=var_node.dtype(),
+                shape=[1])
+            data_type = 'float64' if var_node.dtype(
+            ) == core.VarDesc.VarType.FP64 else 'float32'
+            _init_var_node(state_in_node, np.ones([1], dtype=data_type),
+                           self._scope, self._place)
+            accum_in_node = graph.create_persistable_node(
+                name=unique_name.generate('accum'),
+                var_type=core.VarDesc.VarType.LOD_TENSOR,
+                var_dtype=var_node.dtype(),
+                shape=[1])
+            _init_var_node(accum_in_node, np.ones([1], dtype=data_type),
+                           self._scope, self._place)
+            state_out_node = graph.create_var_node_from_desc(
+                state_in_node.var())
+            accum_out_node = graph.create_var_node_from_desc(
+                accum_in_node.var())
+
             outputs["OutScale"] = scale_out_node
+            inputs['InState'] = state_in_node
+            inputs['InAccum'] = accum_in_node
+            outputs['OutState'] = state_out_node
+            outputs['OutAccum'] = accum_out_node
+            attrs["is_test"] = self._is_test
+            attrs['moving_rate'] = self._moving_rate

         quant_op_node = graph.create_op_node(op_type="quantize_linear",
                                              attrs=attrs,
@@ -1863,6 +1892,10 @@ def insert_quant_op(self, graph, var_node):
         graph.link_to(zero_point_node, quant_op_node)
         graph.link_to(quant_op_node, quant_var_node)
         if not self._is_test:
+            graph.link_to(state_in_node, quant_op_node)
+            graph.link_to(accum_in_node, quant_op_node)
+            graph.link_to(quant_op_node, state_out_node)
+            graph.link_to(quant_op_node, accum_out_node)
             graph.link_to(quant_op_node, scale_out_node)
         return quant_var_node, scale_var_node

@@ -1891,8 +1924,7 @@ def insert_dequant_op(self, graph, var_node, scale_var_node):
             inputs["ZeroPoint"] = zero_point_node

         attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
-        if not self._is_test:
-            attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
+        attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward

         quant_op_node = graph.create_op_node(op_type="dequantize_linear",
                                              attrs=attrs,
@@ -1931,10 +1963,10 @@ def _zero_point_name(self, var_name):
         return "%s@zero_point" % (var_name)


-class QuantizationTransformPassV2(object):
+class QuantizationTransformPassV2(QuantizationTransformPass):
     """
     Quantize the ops that have weights. Add quant and dequant ops for
-    the quantized ops's inputs.
+    the quantized ops's inputs. It is used in the new format of quantization.
     """

     def __init__(self,
@@ -2130,13 +2162,13 @@ def _transform_forward(self, graph, op):
             if is_weight and self._weight_quantize_func is not None:
                 target_out_node = self._insert_func(
                     graph, self._weight_quantize_func, var_node, op)
-                processed_vars.append(name)
+                self.processed_vars.append(name)
                 continue
             elif not is_weight and self._act_quantize_func is not None:
                 target_out_node = self._insert_func(graph,
                                                     self._act_quantize_func,
                                                     var_node, op)
-                processed_vars.append(name)
+                self.processed_vars.append(name)
                 continue

             quant_bits = self._weight_bits if var_node.name() in self.persistable_vars \
@@ -2155,9 +2187,10 @@ def _transform_forward(self, graph, op):
                 quant_bits=quant_bits,
                 quant_axis=quant_axis,
                 channel_wise=channel_wise,
+                moving_rate=self._moving_rate,
                 is_test=self._is_test)
             quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
-                graph, var_node)
+                graph, var_node, var_name=name)
             dequant_var_node = insert_quant_pass.insert_dequant_op(
                 graph, quant_var_node, scale_var_node)

@@ -2182,24 +2215,6 @@ def _has_weight(self, op):
                 has_weight = True
         return has_weight

-    def _is_skip_quant(self, graph, op_node):
-        """
-        Analyse whether the op node skips quantization.
-        """
-        is_skip = False
-        if op_node.op().has_attr("skip_quant") and \
-            op_node.op().attr("skip_quant"):
-            is_skip = True
-        # if the inputs of mul and matmul are not all persistable, use
-        # AddQuantDequantPassV2 to quantize them.
-        if op_node.name() in ["mul", "matmul", "matmul_v2"] and \
-            _is_input_all_not_persistable(graph, op_node):
-            is_skip = True
-        if op_node.op().has_attr("quantization_type") and \
-            op_node.op().attr("quantization_type") == "qat_without_weight":
-            is_skip = True
-        return is_skip
-
     def apply(self, graph):
         """
         Quantize the graph for training process. According to weight and
@@ -2250,7 +2265,7 @@ def apply(self, graph):
 class AddQuantDequantPassV2(object):
     """
     Quantize the ops that do not have weights, and add quant_linear and dequant_linear
-    op for the quantized ops's inputs.
+    op for the quantized ops's inputs. It is used in the new format of quantization.
     """

     # To be compatible with PaddleSlim, not remove _activation_type for now
@@ -2377,6 +2392,7 @@ def apply(self, graph):
                     quant_bits=self._quant_bits,
                     quant_axis=-1,
                     channel_wise=False,
+                    moving_rate=self._moving_rate,
                     is_test=self._is_test)
                 quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
                     graph, in_node)
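
Putting the pass changes together: in training mode the quantize_linear op is now built with the moving-average buffers as extra inputs/outputs and with is_test and moving_rate attributes set. The dict below is a rough, hand-written picture of the op description that insert_quant_op() assembles; node names are placeholders, in the real pass these are IrGraph var nodes, and op_role is additionally set to OpRole.Forward.

quantize_linear_training_op = {
    'attrs': {
        'quant_axis': -1,
        'bit_length': 8,
        'is_test': False,
        'moving_rate': 0.9,
    },
    'inputs': {
        'X': 'relu_0.tmp_0',                    # activation being quantized
        'Scale': 'relu_0.tmp_0.scale',          # persistable scale var (illustrative name)
        'ZeroPoint': 'relu_0.tmp_0@zero_point',
        'InState': 'state_0',                   # persistable, shape [1], initialized to ones
        'InAccum': 'accum_0',                   # persistable, shape [1], initialized to ones
    },
    'outputs': {
        'Y': 'relu_0.tmp_0.quantized',
        'OutScale': 'relu_0.tmp_0.scale',
        'OutState': 'state_0',
        'OutAccum': 'accum_0',
    },
}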

python/paddle/fluid/contrib/slim/quantization/utils.py

Lines changed: 2 additions & 0 deletions
@@ -109,6 +109,7 @@
     "square",
     "softplus",
     "shuffle_channel",
+    "reduce_max",
 ]

 _out_scale_op_list = list(
@@ -213,6 +214,7 @@
     "square": [["X"], ["Out"]],
     "softplus": [["X"], ["Out"]],
     "shuffle_channel": [["X"], ["Out"]],
+    "reduce_max": [["X"], ["Out"]],
 }

python/paddle/fluid/tests/unittests/test_fake_quantize_op.py

Lines changed: 30 additions & 7 deletions
@@ -550,18 +550,41 @@ def set_args(self):
     def setUp(self):
         self.set_args()
         self.op_type = "quantize_linear"
-        x = np.random.randn(31, 65).astype(self.data_type)
-        yq, scale = quantize_max_abs(x, self.max_range)
-        scale = np.array(scale).astype(self.data_type)
-        zero_point = np.zeros(scale.shape, dtype="int32")
-
-        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
         self.attrs = {
             'bit_length': self.bit_length,
             'quant_axis': self.quant_axis,
+            'moving_rate': 0.9,
             'is_test': self.is_test
         }
-        self.outputs = {'Y': yq, 'OutScale': scale}
+
+        x = np.random.randn(31, 65).astype(self.data_type)
+        scale = np.array([0.001]).astype(self.data_type)
+        zero_point = np.zeros(scale.shape, dtype="int32")
+        in_accum = np.ones(1).astype(self.data_type)
+        in_state = np.ones(1).astype(self.data_type)
+        out_accum = np.zeros(1).astype(self.data_type)
+        out_state = np.zeros(1).astype(self.data_type)
+        out_accum[0] = self.attrs['moving_rate'] * in_accum[0] + np.max(
+            np.abs(x))
+        out_state[0] = self.attrs['moving_rate'] * in_state[0] + 1.0
+        out_scale = out_accum / out_state
+
+        round_out = np.round(x / out_scale * self.max_range)
+        quant_data = np.clip(round_out, -self.max_range - 1, self.max_range)
+
+        self.inputs = {
+            'X': x,
+            'Scale': scale,
+            'ZeroPoint': zero_point,
+            'InAccum': in_accum,
+            'InState': in_state,
+        }
+        self.outputs = {
+            'Y': quant_data,
+            'OutScale': out_scale,
+            'OutAccum': out_accum,
+            'OutState': out_state,
+        }

     def test_check_output(self):
         self.check_output()
