@@ -1785,6 +1785,7 @@ class InsertQuantizeLinear(object):
1785
1785
equal to 0, it will quantize per channel; otherwise it will quantize per layer.
1786
1786
Default is -1.
1787
1787
channel_wise(bool, optional): Whether to quantize per channel. Default is False.
1788
+ moving_rate(float, optional): The rate for the 'moving average' method. Default is 0.9.
1788
1789
is_test(bool, optional): Whether the pass runs in test mode rather than training mode. Default is True.
1789
1790
"""
1790
1791
@@ -1794,22 +1795,24 @@ def __init__(self,
1794
1795
quant_bits = 8 ,
1795
1796
quant_axis = - 1 ,
1796
1797
channel_wise = False ,
1798
+ moving_rate = 0.9 ,
1797
1799
is_test = True ):
1798
1800
self ._place = place
1799
1801
self ._scope = scope
1800
1802
self .quant_bits = quant_bits
1801
1803
self .quant_axis = quant_axis
1802
1804
self .channel_wise = channel_wise
1803
1805
self ._is_test = is_test
1806
+ self ._moving_rate = moving_rate
1804
1807
1805
- def insert_quant_op (self , graph , var_node ):
1808
+ def insert_quant_op (self , graph , var_node , var_name = None ):
1806
1809
assert var_node .is_var (), '{} is not a var' .format (var_node .name ())
1807
-
1808
- quant_var_node = graph .create_var_node (name = self . _quantized_var_name (
1809
- var_node . name () ),
1810
- var_type = var_node .type (),
1811
- shape = var_node .shape (),
1812
- var_dtype = var_node .dtype ())
1810
+ var_name = var_node . name () if not var_name else var_name
1811
+ quant_var_node = graph .create_var_node (
1812
+ name = self . _quantized_var_name ( var_name ),
1813
+ var_type = var_node .type (),
1814
+ shape = var_node .shape (),
1815
+ var_dtype = var_node .dtype ())
1813
1816
data_type = 'float64' if var_node .dtype (
1814
1817
) == core .VarDesc .VarType .FP64 else 'float32'
1815
1818
if self .channel_wise :
@@ -1821,7 +1824,7 @@ def insert_quant_op(self, graph, var_node):
1821
1824
scale_var_type = var_node .type ()
1822
1825
init_scale_value = np .array ([_SCALE_DEFAULT_VALUE ], dtype = data_type )
1823
1826
scale_var_node = graph .create_persistable_node (
1824
- name = self ._quantized_scale_name (var_node . name () ),
1827
+ name = self ._quantized_scale_name (var_name ),
1825
1828
var_type = scale_var_type ,
1826
1829
shape = [scale_var_shape ],
1827
1830
var_dtype = var_node .dtype ())
@@ -1844,13 +1847,39 @@ def insert_quant_op(self, graph, var_node):
1844
1847
inputs ["ZeroPoint" ] = zero_point_node
1845
1848
1846
1849
attrs = {"quant_axis" : self .quant_axis , "bit_length" : self .quant_bits }
1850
+ attrs ["op_role" ] = core .op_proto_and_checker_maker .OpRole .Forward
1847
1851
outputs = {"Y" : quant_var_node }
1848
1852
if not self ._is_test :
1849
- attrs ["is_test" ] = self ._is_test
1850
- attrs ["op_role" ] = core .op_proto_and_checker_maker .OpRole .Forward
1851
1853
scale_out_node = graph .create_var_node_from_desc (
1852
1854
scale_var_node .var ())
1855
+ state_in_node = graph .create_persistable_node (
1856
+ name = unique_name .generate ('state' ),
1857
+ var_type = core .VarDesc .VarType .LOD_TENSOR ,
1858
+ var_dtype = var_node .dtype (),
1859
+ shape = [1 ])
1860
+ data_type = 'float64' if var_node .dtype (
1861
+ ) == core .VarDesc .VarType .FP64 else 'float32'
1862
+ _init_var_node (state_in_node , np .ones ([1 ], dtype = data_type ),
1863
+ self ._scope , self ._place )
1864
+ accum_in_node = graph .create_persistable_node (
1865
+ name = unique_name .generate ('accum' ),
1866
+ var_type = core .VarDesc .VarType .LOD_TENSOR ,
1867
+ var_dtype = var_node .dtype (),
1868
+ shape = [1 ])
1869
+ _init_var_node (accum_in_node , np .ones ([1 ], dtype = data_type ),
1870
+ self ._scope , self ._place )
1871
+ state_out_node = graph .create_var_node_from_desc (
1872
+ state_in_node .var ())
1873
+ accum_out_node = graph .create_var_node_from_desc (
1874
+ accum_in_node .var ())
1875
+
1853
1876
outputs ["OutScale" ] = scale_out_node
1877
+ inputs ['InState' ] = state_in_node
1878
+ inputs ['InAccum' ] = accum_in_node
1879
+ outputs ['OutState' ] = state_out_node
1880
+ outputs ['OutAccum' ] = accum_out_node
1881
+ attrs ["is_test" ] = self ._is_test
1882
+ attrs ['moving_rate' ] = self ._moving_rate
1854
1883
1855
1884
quant_op_node = graph .create_op_node (op_type = "quantize_linear" ,
1856
1885
attrs = attrs ,
@@ -1863,6 +1892,10 @@ def insert_quant_op(self, graph, var_node):
1863
1892
graph .link_to (zero_point_node , quant_op_node )
1864
1893
graph .link_to (quant_op_node , quant_var_node )
1865
1894
if not self ._is_test :
1895
+ graph .link_to (state_in_node , quant_op_node )
1896
+ graph .link_to (accum_in_node , quant_op_node )
1897
+ graph .link_to (quant_op_node , state_out_node )
1898
+ graph .link_to (quant_op_node , accum_out_node )
1866
1899
graph .link_to (quant_op_node , scale_out_node )
1867
1900
return quant_var_node , scale_var_node
1868
1901
@@ -1891,8 +1924,7 @@ def insert_dequant_op(self, graph, var_node, scale_var_node):
1891
1924
inputs ["ZeroPoint" ] = zero_point_node
1892
1925
1893
1926
attrs = {"quant_axis" : self .quant_axis , "bit_length" : self .quant_bits }
1894
- if not self ._is_test :
1895
- attrs ["op_role" ] = core .op_proto_and_checker_maker .OpRole .Forward
1927
+ attrs ["op_role" ] = core .op_proto_and_checker_maker .OpRole .Forward
1896
1928
1897
1929
quant_op_node = graph .create_op_node (op_type = "dequantize_linear" ,
1898
1930
attrs = attrs ,
@@ -1931,10 +1963,10 @@ def _zero_point_name(self, var_name):
1931
1963
return "%s@zero_point" % (var_name )
1932
1964
1933
1965
1934
- class QuantizationTransformPassV2 (object ):
1966
+ class QuantizationTransformPassV2 (QuantizationTransformPass ):
1935
1967
"""
1936
1968
Quantize the ops that have weights. Add quant and dequant ops for
1937
- the quantized ops's inputs.
1969
+ the quantized ops' inputs. It is used for the new quantization format.
1938
1970
"""
1939
1971
1940
1972
def __init__ (self ,
@@ -2130,13 +2162,13 @@ def _transform_forward(self, graph, op):
2130
2162
if is_weight and self ._weight_quantize_func is not None :
2131
2163
target_out_node = self ._insert_func (
2132
2164
graph , self ._weight_quantize_func , var_node , op )
2133
- processed_vars .append (name )
2165
+ self . processed_vars .append (name )
2134
2166
continue
2135
2167
elif not is_weight and self ._act_quantize_func is not None :
2136
2168
target_out_node = self ._insert_func (graph ,
2137
2169
self ._act_quantize_func ,
2138
2170
var_node , op )
2139
- processed_vars .append (name )
2171
+ self . processed_vars .append (name )
2140
2172
continue
2141
2173
2142
2174
quant_bits = self ._weight_bits if var_node .name () in self .persistable_vars \
@@ -2155,9 +2187,10 @@ def _transform_forward(self, graph, op):
2155
2187
quant_bits = quant_bits ,
2156
2188
quant_axis = quant_axis ,
2157
2189
channel_wise = channel_wise ,
2190
+ moving_rate = self ._moving_rate ,
2158
2191
is_test = self ._is_test )
2159
2192
quant_var_node , scale_var_node = insert_quant_pass .insert_quant_op (
2160
- graph , var_node )
2193
+ graph , var_node , var_name = name )
2161
2194
dequant_var_node = insert_quant_pass .insert_dequant_op (
2162
2195
graph , quant_var_node , scale_var_node )
2163
2196
@@ -2182,24 +2215,6 @@ def _has_weight(self, op):
2182
2215
has_weight = True
2183
2216
return has_weight
2184
2217
2185
- def _is_skip_quant (self , graph , op_node ):
2186
- """
2187
- Analyse whether the op node skips quantization.
2188
- """
2189
- is_skip = False
2190
- if op_node .op ().has_attr ("skip_quant" ) and \
2191
- op_node .op ().attr ("skip_quant" ):
2192
- is_skip = True
2193
- # if the inputs of mul and matmul are not all persistable, use
2194
- # AddQuantDequantPassV2 to quantize them.
2195
- if op_node .name () in ["mul" , "matmul" , "matmul_v2" ] and \
2196
- _is_input_all_not_persistable (graph , op_node ):
2197
- is_skip = True
2198
- if op_node .op ().has_attr ("quantization_type" ) and \
2199
- op_node .op ().attr ("quantization_type" ) == "qat_without_weight" :
2200
- is_skip = True
2201
- return is_skip
2202
-
2203
2218
def apply (self , graph ):
2204
2219
"""
2205
2220
Quantize the graph for training process. According to weight and
@@ -2250,7 +2265,7 @@ def apply(self, graph):
2250
2265
class AddQuantDequantPassV2 (object ):
2251
2266
"""
2252
2267
Quantize the ops that do not have weights, and add quant_linear and dequant_linear
2253
- op for the quantized ops's inputs.
2268
+ ops for the quantized ops' inputs. It is used for the new quantization format.
2254
2269
"""
2255
2270
2256
2271
# To be compatible with PaddleSlim, not remove _activation_type for now
@@ -2377,6 +2392,7 @@ def apply(self, graph):
2377
2392
quant_bits = self ._quant_bits ,
2378
2393
quant_axis = - 1 ,
2379
2394
channel_wise = False ,
2395
+ moving_rate = self ._moving_rate ,
2380
2396
is_test = self ._is_test )
2381
2397
quant_var_node , scale_var_node = insert_quant_pass .insert_quant_op (
2382
2398
graph , in_node )
0 commit comments