
Commit 48e0634

Tracin and zhangqi3 authored

[Misc] Update about STPU. (#232)

Co-authored-by: zhangqi3 <[email protected]>
1 parent e2f6d78 commit 48e0634

7 files changed (+67, −44 lines)


mqbench/custom_quantizer/model_quantizer.py

Lines changed: 2 additions & 4 deletions

@@ -233,9 +233,7 @@ def _find_act_quants(self, model: GraphModule) -> List:
             ((node.op == 'call_function' or node.op == 'call_method') and
                 node.target in self.function_type_to_quant_input) or node.name in self.additional_node_name:
             input_node_list = self._flatten_args(node.args)
-            # Means this is not Tensor + Tensor.
-            if not all([isinstance(_node, torch.fx.node.Node) for _node in input_node_list]):
-                continue
+            input_node_list = [_node for _node in input_node_list if isinstance(_node, torch.fx.node.Node)]
             for _node in input_node_list:
                 if self._is_implicit_merge(modules, (node, _node)):
                     logger.info("Implicit merge: {} + {}".format(_node.name, node.name))
@@ -272,4 +270,4 @@ def _convert(self, module, mapping=None, inplace=False, scope=''):
         for key, value in reassign.items():
            module._modules[key] = value

-            return module
+        return module
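Why this change helps: the old guard dropped an op entirely whenever any flattened argument was a constant, so mixed Tensor/scalar calls such as `x + 1` were never quantized. A minimal sketch of the difference (plain torch.fx, independent of MQBench):

```python
import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1  # call_function with args (Node, int)


gm = torch.fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == 'call_function':
        args = list(node.args)
        # Old check: bail out because `1` is not an fx Node.
        old_skips = not all(isinstance(a, torch.fx.node.Node) for a in args)
        # New behaviour: keep only the Node inputs and continue.
        tensor_inputs = [a for a in args if isinstance(a, torch.fx.node.Node)]
        print(node.target, old_skips, tensor_inputs)  # add, True, [x]
```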

mqbench/custom_quantizer/total_int_quantizer.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def _passed_func_type(self):
3232
@property
3333
def _passed_module_type(self):
3434
return (
35+
torch.nn.Dropout2d,
3536
torch.nn.ReLU,
3637
torch.nn.ReLU6
3738
)
@@ -50,9 +51,11 @@ def _find_act_quants(self, model: GraphModule) -> list:
5051
((node.op == 'call_function' or node.op == 'call_method') and
5152
node.target in self.function_type_to_quant_input):
5253
for next_node in node.users:
53-
if not ((next_node.op == 'call_function' and next_node.target in self._passed_func_type) or
54+
if ((next_node.op == 'call_function' and next_node.target in self._passed_func_type) or
5455
(next_node.op == 'call_module' and isinstance(modules[next_node.target], self._passed_module_type))):
55-
node_need_to_quantize_output.append(node)
56-
else:
5756
node_need_to_quantize_output.append(next_node)
58-
return node_need_to_quantize_output
57+
elif self._is_implicit_merge(modules, (next_node, node)):
58+
continue
59+
else:
60+
node_need_to_quantize_output.append(node)
61+
return node_need_to_quantize_output
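The rewritten branch keeps the pass-through rule (quantize after a ReLU/Dropout rather than before it), now also skips implicitly merged pairs, and otherwise quantizes the producing node itself. A hedged sketch of the placement rule on a toy graph (MQBench itself not required):

```python
import torch
import torch.fx

passed_module_type = (torch.nn.Dropout2d, torch.nn.ReLU, torch.nn.ReLU6)


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


gm = torch.fx.symbolic_trace(M())
modules = dict(gm.named_modules())
to_quantize = []
for node in gm.graph.nodes:
    if node.op == 'call_module' and isinstance(modules[node.target], torch.nn.Conv2d):
        for next_node in node.users:
            if (next_node.op == 'call_module'
                    and isinstance(modules[next_node.target], passed_module_type)):
                to_quantize.append(next_node)  # fake-quant lands after the ReLU
            else:
                to_quantize.append(node)       # fake-quant lands on the conv output
print(to_quantize)  # [relu]
```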

mqbench/deploy/common.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,13 @@ def prepare_initializer(graph):
225225
return named_initializer
226226

227227

228+
def insert_initializer(graph, new_init):
229+
for init in graph.initializer:
230+
if init.name == new_init.name:
231+
graph.initializer.remove(init)
232+
graph.initializer.append(new_init)
233+
234+
228235
def parse_attrs(node_attrs):
229236
attrs = {}
230237
for attr in node_attrs:
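Usage sketch for the new helper (assumes this commit's MQBench is importable; the tensor name `w` is invented): the remove-then-append semantics replace any same-named initializer in place, which the STPU Conv+Mul fold below relies on when it overwrites folded weights.

```python
import numpy as np
from onnx import helper, numpy_helper

from mqbench.deploy.common import insert_initializer

graph = helper.make_graph(
    nodes=[], name='toy', inputs=[], outputs=[],
    initializer=[numpy_helper.from_array(np.zeros(3, np.float32), name='w')],
)
insert_initializer(graph, numpy_helper.from_array(np.ones(3, np.float32), name='w'))
assert len(graph.initializer) == 1  # replaced in place, not duplicated
assert numpy_helper.to_array(graph.initializer[0]).sum() == 3.0
```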

mqbench/deploy/deploy_linear.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def deal_with_activation_fakequant(self, node, inp2node):
7575
next_nodes = inp2node[node.output[0]]
7676
for next_node, idx in next_nodes:
7777
next_node.input[idx] = node.input[0]
78-
return
7978

8079
def parse_qparams(self, node, name2data):
8180
tensor_name, scale, zero_point = node.input[:3]
@@ -119,13 +118,13 @@ def post_process_clip_ranges(self, clip_ranges, graph, inp2node):
119118
def find_the_closest_clip_range(node):
120119
if node.input[0] in clip_ranges:
121120
return node.input[0]
122-
elif node.op_type in ['Flatten', 'Resize'] and node.output[0] in inp2node:
121+
elif node.op_type in ['Flatten', 'Resize', 'Reshape'] and node.output[0] in inp2node:
123122
return find_the_closest_clip_range(inp2node[node.output[0]][0][0])
124123
else:
125124
return None
126125

127126
for node in graph.node:
128-
if node.op_type in ['Flatten', 'Resize']:
127+
if node.op_type in ['Flatten', 'Resize', 'Reshape']:
129128
tensor_name = find_the_closest_clip_range(node)
130129
if tensor_name:
131130
clip_ranges[node.input[0]] = clip_ranges[tensor_name]
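Rationale: `Reshape`, like `Flatten` and `Resize`, moves values without changing them, so a clip range known downstream of such an op is valid for its input too. A toy restatement of the rule (invented tensor names):

```python
# Shape-only ops preserve values, so ranges can cross them unchanged.
clip_ranges = {'reshape_out': {'min': -4.0, 'max': 4.0}}


def forward_range(op_type, input_name, output_name):
    # Mirrors post_process_clip_ranges for a one-op chain.
    if op_type in ['Flatten', 'Resize', 'Reshape'] and output_name in clip_ranges:
        clip_ranges[input_name] = clip_ranges[output_name]


forward_range('Reshape', 'backbone_out', 'reshape_out')
assert clip_ranges['backbone_out'] == clip_ranges['reshape_out']
```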

mqbench/deploy/deploy_stpu.py

Lines changed: 33 additions & 4 deletions

@@ -3,9 +3,10 @@
 from collections import OrderedDict

 import onnx
+from onnx import numpy_helper

 from mqbench.deploy.common import (get_constant_inputs, prepare_data,
-                                   prepare_initializer,
+                                   prepare_initializer, insert_initializer,
                                    update_inp2node_out2node)
 from mqbench.deploy.deploy_linear import (PERTENSOR_FAKEQUANTIZER,
                                           LinearQuantizer_process)
@@ -17,10 +18,8 @@ class STPU_process(LinearQuantizer_process):
     def remove_fakequantize_and_collect_params(self, onnx_path, model_name):
         model = onnx.load(onnx_path)
         graph = model.graph
-        out2node, inp2node = update_inp2node_out2node(graph)
         name2data = prepare_data(graph)
         named_initializer = prepare_initializer(graph)
-
         out2node, inp2node = update_inp2node_out2node(graph)

         quant_params = OrderedDict()
@@ -57,6 +56,35 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name):
                     "min": -127 * scale,
                     "max": 127 * scale
                 }
+        # Merge Conv + Mul.
+        for conv_node in graph.node:
+            # Network output.
+            if conv_node.output[0] not in inp2node or len(inp2node[conv_node.output[0]]) < 1:
+                continue
+            mul_node = inp2node[conv_node.output[0]][0][0]
+            if conv_node.op_type == 'Conv' and mul_node.op_type == 'Mul':
+                mul_scale = numpy_helper.to_array(out2node[mul_node.input[1]].attribute[0].t)
+                weight_name = named_initializer[conv_node.input[1]].name
+                bias_name = named_initializer[conv_node.input[2]].name
+                weight = numpy_helper.to_array(named_initializer[conv_node.input[1]])
+                bias = numpy_helper.to_array(named_initializer[conv_node.input[2]])
+                new_weight = numpy_helper.from_array(weight * mul_scale)
+                new_bias = numpy_helper.from_array(bias * mul_scale)
+                new_weight.name = weight_name
+                new_bias.name = bias_name
+                insert_initializer(graph, new_weight)
+                insert_initializer(graph, new_bias)
+                quant_params[conv_node.name + '_weights']['min'] *= mul_scale
+                quant_params[conv_node.name + '_weights']['max'] *= mul_scale
+                # Delete mul node.
+                nodes_to_be_removed.append(mul_node)
+                conv_node.output[0] = mul_node.output[0]
+        # Pass concat ranges to inputs.
+        for node in graph.node:
+            if node.op_type == 'Concat' and node.output[0] in quant_params:
+                for node_input in node.input:
+                    quant_params[node_input] = quant_params[node.output[0]]
+                    logger.info(f'Pass {node.output[0]} range to {node.name} input {node_input}.')
         # Update bias scale = input scale * weight scale
         for node in graph.node:
             if node.op_type in ['Gemm', 'Conv'] and len(node.input) == 3:
@@ -74,6 +102,7 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name):
             }
         quant_params = self.post_process_clip_ranges(quant_params, graph, inp2node)
         self.merge_relu_layer(graph, quant_params, out2node)
+        # Update emin.
         for node in graph.node:
             self.update_emin(node, quant_params, named_initializer)
         # Delete node and init.
@@ -131,7 +160,7 @@ def find_conv_emin(i_vmax, w_vmax, o_vmax, n, r):

         if node.op_type in ['Upsample', 'DynamicUpsample']:
             emin = find_interp_emin(quant_params[node.output[0]]['max'], 2)
-        quant_params[node.output[0]]['emin'] = emin
+            quant_params[node.output[0]]['emin'] = emin
         if node.op_type in ['Conv', 'ConvTranspose']:
             weight_shape = named_initializer[node.input[1]].dims
             n = weight_shape[1] * weight_shape[2] * weight_shape[3]
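The Conv+Mul fold is sound because a per-tensor `Mul` by a constant `s` commutes with the affine conv: `s * (W ⊛ x + b) = (s * W) ⊛ x + s * b`, which is why both the initializers and the recorded weight min/max are scaled by `mul_scale`. A hedged numeric check in plain PyTorch (nothing STPU-specific):

```python
import torch

conv = torch.nn.Conv2d(3, 8, 3, bias=True)
scale = 0.5  # stands in for the Mul constant
folded = torch.nn.Conv2d(3, 8, 3, bias=True)
with torch.no_grad():
    folded.weight.copy_(conv.weight * scale)  # W' = s * W
    folded.bias.copy_(conv.bias * scale)      # b' = s * b

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(conv(x) * scale, folded(x), atol=1e-6)
```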

mqbench/fuser_method_mappings.py

Lines changed: 14 additions & 27 deletions

@@ -1,5 +1,3 @@
-from typing import Optional, Type
-
 import torch
 import torch.nn as nn
 from torch.quantization.fx.fusion_patterns import ConvBNReLUFusion, ModuleReLUFusion
@@ -13,7 +11,7 @@
 from mqbench.nn.modules import FrozenBatchNorm2d


-class ConvFreezebnReLUFusion(ConvBNReLUFusion):
+class ConvExtendBnReLUFusion(ConvBNReLUFusion):
     def __init__(self, quantizer: QuantizerCls, node: Node):
         super(ConvBNReLUFusion, self).__init__(quantizer, node)
         self.relu_node = None
@@ -87,39 +85,27 @@ def fuse_deconv_bn_relu(deconv, bn, relu):
 def fuse_conv_freezebn(conv, bn):
     assert bn.training is False, "Freezebn must be eval."

-    fused_module_class_map = {
-        nn.Conv2d: qnni.ConvFreezebn2d,
-    }
-
     if conv.training:
         assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
         assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
         assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
-        fused_module_class = fused_module_class_map.get((type(conv)), None)
-        return fused_module_class(conv, bn)
+        return qnni.ConvFreezebn2d(conv, bn)
     else:
         return nn.utils.fuse_conv_bn_eval(conv, bn)


 def fuse_conv_freezebn_relu(conv, bn, relu):
-    assert conv.training == relu.training and bn.training is False, "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
-    fused_module : Optional[Type[nn.Sequential]] = None
+    assert conv.training == relu.training and bn.training is False, \
+        "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
+
     if conv.training:
-        map_to_fused_module_train = {
-            nn.Conv2d: qnni.ConvFreezebnReLU2d,
-        }
         assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
         assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
         assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
-        fused_module = map_to_fused_module_train.get(type(conv), None)
-        return fused_module(conv, bn, relu)
+        return qnni.ConvFreezebnReLU2d(conv, bn, relu)
     else:
-        map_to_fused_module_eval = {
-            nn.Conv2d: nn.intrinsic.ConvReLU2d,
-        }
-        fused_module = map_to_fused_module_eval.get(type(conv), None)
-        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
-        return fused_module(fused_conv, relu)
+        fused_conv = nn.utils.fuse_conv_bn_eval(conv, bn)
+        return nn.intrinsic.ConvReLU2d(fused_conv, relu)
@@ -135,7 +121,8 @@ def fuse_deconv_freezebn(deconv, bn):


 def fuse_deconv_freezebn_relu(deconv, bn, relu):
-    assert deconv.training == relu.training and bn.training is False, "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
+    assert deconv.training == relu.training and bn.training is False, \
+        "Conv and relu both must be in the same mode (train or eval) and bn must be eval."

     if deconv.training:
         assert bn.num_features == deconv.out_channels, 'Output channel of ConvTranspose2d must match num_features of BatchNorm2d'
@@ -171,13 +158,13 @@ def fuse_deconv_freezebn_relu(deconv, bn, relu):
         (torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.ConvTranspose2d)):
             ConvBNReLUFusion,
         (torch.nn.ReLU, (FrozenBatchNorm2d, torch.nn.Conv2d)):
-            ConvFreezebnReLUFusion,
+            ConvExtendBnReLUFusion,
         (FrozenBatchNorm2d, torch.nn.Conv2d):
-            ConvFreezebnReLUFusion,
+            ConvExtendBnReLUFusion,
         (torch.nn.ReLU, (FrozenBatchNorm2d, torch.nn.ConvTranspose2d)):
-            ConvFreezebnReLUFusion,
+            ConvExtendBnReLUFusion,
         (FrozenBatchNorm2d, torch.nn.ConvTranspose2d):
-            ConvFreezebnReLUFusion,
+            ConvExtendBnReLUFusion,
     },
     "additional_qat_module_mappings": {
         nn.ConvTranspose2d: qnn.qat.ConvTranspose2d,
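The rename to `ConvExtendBnReLUFusion` reflects that the handler covers frozen-BN patterns generally, and the lookup-dict indirection is dropped because only `nn.Conv2d` was ever mapped. The simplified eval path is plain PyTorch and can be checked directly (a sketch, assuming a recent torch that exports these helpers):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, 3).eval()
bn = nn.BatchNorm2d(8).eval()
relu = nn.ReLU().eval()

fused_conv = nn.utils.fuse_conv_bn_eval(conv, bn)  # fold BN into the conv
fused = nn.intrinsic.ConvReLU2d(fused_conv, relu)  # then pair with ReLU

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(fused(x), relu(bn(conv(x))), atol=1e-5)
```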

mqbench/prepare_by_platform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ class BackendType(Enum):
123123
default_weight_observer=MinMaxObserver,
124124
default_act_observer=EMAMinMaxObserver),
125125
BackendType.STPU: dict(qtype="affine",
126-
w_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8),
127-
a_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8),
126+
w_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8, symmetric_range=True),
127+
a_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8, symmetric_range=True),
128128
default_weight_quantize=FixedFakeQuantize,
129129
default_act_quantize=FixedFakeQuantize,
130130
default_weight_observer=MinMaxObserver,
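As used here, `symmetric_range=True` appears to restrict the signed 8-bit grid to [-127, 127] rather than [-128, 127], which lines up with the `-127 * scale` / `127 * scale` clip ranges emitted in deploy_stpu.py. A one-liner sketch of the quant bounds:

```python
bit, symmetric_range = 8, True
qmin, qmax = -2 ** (bit - 1), 2 ** (bit - 1) - 1  # (-128, 127)
if symmetric_range:
    qmin = -qmax                                  # (-127, 127)
print(qmin, qmax)
```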
