Commit 7883312
ONNX model export and algorithm updates (#231)

* Add sample files
* ONNX model export and algorithm updates
* Add a default argument to the dump_torch_to_onnx function
* EXPORT_OVERLAPPED_CONFIG is now deprecated; use the QuantizationVisiblity attribute on a TQC to control export instead. The attribute has three options: force export, export when the TQC is active, and do not export.
* Update the exporter logic to work with the new QuantizationVisiblity attribute
* Update the ONNX QDQ export logic: activation functions under symmetric quantization are now eliminated whenever possible
* Update the graphwise analyser so it can measure the error of multi-output operators
* Update layerwise equalization: activations can now be included, with support for conv1d, conv2d, conv3d, convtranspose1d, convtranspose2d, convtranspose3d, gemm, and matmul
* Fix a pad quantization error in the passive parameter pass
* Fix an alignment error for pooling operators in the quant alignment pass
* Fix the core quantization function failing to handle CPU tensors when CUDA kernels are enabled
* Change the OpenVINO quantization policy: the negative range can now reach -128 (previously -127)
* Add a new quantization type to the DSP quantizer
* Add test cases
* Fix CI errors
1 parent ec0429b commit 7883312
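
A hedged usage sketch of the new export control (the ppq.core import path and the helper function are assumptions; the enum members come from ppq/core/quant.py in this commit):

from ppq.core import QuantizationVisiblity   # import path is an assumption

def set_export_visibility(tqc, level: str) -> None:
    # tqc is any TensorQuantizationConfig produced by a quantizer.
    tqc.visiblity = {
        'force':       QuantizationVisiblity.FORCE_EXPORT,        # export regardless of state
        'when_active': QuantizationVisiblity.EXPOET_WHEN_ACTIVE,  # default (spelling as in source)
        'internal':    QuantizationVisiblity.INTERNAL,            # never export
    }[level]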

File tree

20 files changed, +696 -490 lines changed


ppq/IR/base/graph.py

Lines changed: 3 additions & 1 deletion
@@ -234,7 +234,9 @@ def copy(self, copy_value: bool = False):
                     'however its value is not an instance of torch.Tensor, '
                     'ppq will automaticall convert it to torch.Tensor now.')
                 self.value = convert_any_to_torch_tensor(self.value)
-        return Variable(name=self.name, value=self.value.clone(), is_parameter=self.is_parameter)
+        if isinstance(self.value, torch.Tensor):
+            value = self.value.clone()
+        return Variable(name=self.name, value=value, is_parameter=self.is_parameter)


 class Operation(OperationBase, Serializable):
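
For context, .clone() exists only on torch.Tensor, which is why the old unconditional call failed for non-tensor values. A standalone illustration of the guard added above (not the Variable class itself):

import torch

value = torch.ones(3)
if isinstance(value, torch.Tensor):   # mirrors the guard added in Variable.copy
    value = value.clone()             # tensors are deep-copied

value = None
if isinstance(value, torch.Tensor):   # non-tensor values skip the clone
    value = value.clone()
# value stays None instead of raising AttributeError: 'NoneType' has no attribute 'clone'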

ppq/api/interface.py

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ def dump_torch_to_onnx(
     model: torch.nn.Module,
     onnx_export_file: str,
     input_shape: List[int],
-    input_dtype: torch.dtype,
+    input_dtype: torch.dtype = torch.float,
     inputs: List[Any] = None,
     device: str = 'cuda'):
     """

ppq/core/common.py

Lines changed: 4 additions & 5 deletions
@@ -2,7 +2,6 @@
 # PPQ System configuration
 # You can modify following codes for your own purpose.

-
 # Minimum scale limit in Observer; any scale smaller than this value is overridden by it
 OBSERVER_MIN_SCALE = 1e-8
 # Manual override attribute for the minimum scale in Observer
@@ -64,9 +63,6 @@
 DEFAULT_OPSET_VERSION = 11
 STRICT_OPSET_CHECKING = False

-# Whether to export qdq nodes whose state is already overlapped
-EXPORT_OVERLAPPED_CONFIG = False
-
 # Weight cache attribute of the LSTM operator
 LSTM_FLATTEN_WEIGHT_ATTRIB = 'LSTM_FLATTEN_WEIGHT_ATTRIB'
 # Weight cache attribute of the GRU operator
@@ -90,4 +86,7 @@
 CHECKPOINT_TOLERANCE = 1

 # Operator types subject to Bias Correction
-BIAS_CORRECTION_INTERST_TYPE = {'Conv', 'Gemm', 'ConvTranspose'}
+BIAS_CORRECTION_INTERST_TYPE = {'Conv', 'Gemm', 'ConvTranspose'}
+
+# Whether to export qdq nodes whose state is already overlapped
+EXPORT_OVERLAPPED_CONFIG = False
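
Note that ppq/core/quant.py (below) binds this flag at import time via `from .common import EXPORT_OVERLAPPED_CONFIG`, so reassigning it after import does not affect can_export(). A hedged sketch of the pitfall:

import ppq.core.common as common

common.EXPORT_OVERLAPPED_CONFIG = True   # quant.py already holds its own copy of False
# To actually change the behaviour, edit the constant in ppq/core/common.py itself;
# the per-TQC visiblity attribute is the preferred control going forward.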

ppq/core/quant.py

Lines changed: 35 additions & 20 deletions
@@ -4,15 +4,20 @@
 """

 import time # for hash generation
-from abc import abstractmethod
 from enum import Enum
 from typing import Any, Iterable, List

 import torch

+from .common import EXPORT_OVERLAPPED_CONFIG
 from .storage import Serializable


+class QuantizationVisiblity(Enum):
+    FORCE_EXPORT = 1
+    EXPOET_WHEN_ACTIVE = 2
+    INTERNAL = 3
+
 class NetworkFramework(Enum):
     PPL = 1
     ONNX = 2
@@ -365,7 +370,7 @@ def __init__(
         offset: Any = None,
         observer_algorithm: str = None,
         detail: Any = None,
-        require_export: bool = None,
+        visiblity: QuantizationVisiblity = QuantizationVisiblity.EXPOET_WHEN_ACTIVE,
         state: QuantizationStates = QuantizationStates.INITIAL
     ):
         """Create a PPQ Tensor Quantization Configuration Instance.
@@ -395,7 +400,13 @@ def __init__(
             detail (Any, optional): Only used by PPQ internal logic, detail is used to store some internal data,
                 you are not supposed to use it.

-            require_export (bool, optional): If require_export == True, PPQ exporter will export this TQC ignoring state checks.
+            visiblity (QuantizationVisiblity): the attribute that controls export logic.
+
+                Currently there are 3 visibility levels in PPQ:
+                    FORCE_EXPORT: the ppq exporter will export this TQC,
+                        ignoring state checks (even if this TQC has been overlapped).
+                    EXPOET_WHEN_ACTIVE: the ppq exporter will export this TQC only when it has been activated.
+                    INTERNAL: this TQC will not be exported.

             state (QuantizationStates, optional):
                 Defaults to QuantizationStates.INITIAL, see QuantizationStates for more detail.
@@ -416,17 +427,25 @@ def __init__(
         self.detail = {} if detail is None else detail
         self._father_config = self # union-find
         self._hash = self.__create_hash()
-        self._require_export = require_export
+        self._visiblity = visiblity
         super().__init__()

-    @ abstractmethod
-    def export(self) -> str:
-        raise Exception('Implement this first')
+    def can_export(self) -> bool:
+        if self.visiblity == QuantizationVisiblity.INTERNAL: return False
+        type_check = isinstance(self.scale, torch.Tensor) and isinstance(self.offset, torch.Tensor)
+        valid_states = {QuantizationStates.BAKED, QuantizationStates.PASSIVE_BAKED}
+
+        if EXPORT_OVERLAPPED_CONFIG: valid_states.add(QuantizationStates.OVERLAPPED)
+        state_check = QuantizationStates.is_activated(self.state) or self.state in valid_states
+
+        if (state_check or self.visiblity == QuantizationVisiblity.FORCE_EXPORT):
+            if type_check: return True
+        return False

     def __eq__(self, o: object) -> bool:
         if not isinstance(o, TensorQuantizationConfig):
-            raise TypeError('Can only compare TensorQuantizationConfig object '\
-                'with another TensorQuantizationConfig object.')
+            raise TypeError('Can only compare TensorQuantizationConfig object '
+                            'with another TensorQuantizationConfig object.')
         return self._hash == o._hash

     def __str__(self) -> str:
@@ -509,17 +528,13 @@ def is_revisable(self):
         })

     @ property
-    def exportable(self) -> bool:
-        value_check = isinstance(self.scale, torch.Tensor)
-        if self._require_export is None:
-            state_check = QuantizationStates.can_export(self.state)
-            return (value_check and state_check)
-        else: return (self._require_export and value_check)
-
-    @ exportable.setter
-    def exportable(self, export_override: bool):
-        self._require_export = export_override
-
+    def visiblity(self) -> QuantizationVisiblity:
+        return self._visiblity
+
+    @ visiblity.setter
+    def visiblity(self, visiblity: QuantizationVisiblity):
+        self._visiblity = visiblity
+
     @ property
     def scale(self) -> torch.Tensor:
         if self.dominated_by == self: return self._scale
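
The export gate above can be summarized as a pure function. The following self-contained sketch mirrors can_export (enum and state names copied from the diff; the tensor and state checks are simplified to plain booleans for illustration):

from enum import Enum

class QuantizationVisiblity(Enum):   # as introduced in this commit
    FORCE_EXPORT = 1
    EXPOET_WHEN_ACTIVE = 2           # spelling as in the source
    INTERNAL = 3

def can_export(visiblity: QuantizationVisiblity,
               state_is_active: bool,          # QuantizationStates.is_activated(...)
               state_is_baked: bool,           # BAKED or PASSIVE_BAKED
               state_is_overlapped: bool,
               has_tensor_scale_and_offset: bool,
               export_overlapped_config: bool = False) -> bool:
    # INTERNAL short-circuits everything: this TQC is never exported.
    if visiblity == QuantizationVisiblity.INTERNAL:
        return False
    state_check = (state_is_active or state_is_baked
                   or (export_overlapped_config and state_is_overlapped))
    # FORCE_EXPORT bypasses the state check, but scale and offset must
    # still be real torch.Tensor values for the exporter to emit them.
    if state_check or visiblity == QuantizationVisiblity.FORCE_EXPORT:
        return has_tensor_scale_and_offset
    return False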

ppq/parser/caffe_exporter.py

Lines changed: 1 addition & 6 deletions
@@ -40,12 +40,7 @@ def export_quantization_config(self, config_path: str, graph: BaseGraph):
         for operation in graph.operations.values():
             if not isinstance(operation, QuantableOperation): continue
             for config, var in operation.config_with_variable:
-                if not QuantizationStates.can_export(config.state):
-                    raise PermissionError(
-                        'Can not export quant config cause not all quantization configurations '
-                        'have been correctly initialized(or some of them has been deactivated). '
-                        f'Operation {operation.name} has an invalid quantization config({config.state}) '
-                        f'at variable {var.name}.')
+                if not config.can_export(): continue

                 # PATCH 2021.11.25
                 # REMOVE BIAS FROM CONFIGURATION

ppq/parser/ncnn_exporter.py

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,8 @@ def export_quantization_config(self, config_path: str, graph: BaseGraph):
             if op.is_computing_op and isinstance(op, QuantableOperation):
                 fd.write(f'{op.name}_param_0 ')
                 param_cfg = op.config.input_quantization_config[1]
+                if not param_cfg.can_export(): continue
+
                 assert param_cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED}\
                     and param_cfg.observer_algorithm in {'minmax', 'Minmax'} and \
                     param_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL)
@@ -32,6 +34,7 @@ def export_quantization_config(self, config_path: str, graph: BaseGraph):
                 for s in scale:
                     fd.write('%f '% s)
                 fd.write('\n')
+
         for op in topo_order:
             if op.is_computing_op and isinstance(op, QuantableOperation):
                 fd.write(f'{op.name} ')

ppq/parser/nxp_exporter.py

Lines changed: 2 additions & 3 deletions
@@ -63,9 +63,8 @@ def export(self, file_path: str, graph: BaseGraph,
             if variable.is_parameter and not export_param: continue
             for config in configs:
                 if config is None: continue # source_op can be None
-                if config.state in {QuantizationStates.ACTIVATED, QuantizationStates.BAKED,
-                                    QuantizationStates.OVERLAPPED, QuantizationStates.PASSIVE_BAKED}:
-                    if config.state == QuantizationStates.OVERLAPPED: config = config.dominated_by
+                if config.can_export():
+
                     tensor_range = config.scale * pow(2, config.num_of_bits - 1)
                     min_val, max_val = -tensor_range, tensor_range - config.scale
                     min_tensor = numpy_helper.from_array(
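
The range arithmetic above follows directly from symmetric quantization: num_of_bits = 8 yields codes -128..127, each worth one scale step. A worked example with illustrative values:

scale, num_of_bits = 0.1, 8
tensor_range = scale * pow(2, num_of_bits - 1)           # 0.1 * 128 = 12.8
min_val, max_val = -tensor_range, tensor_range - scale   # -12.8 and 12.7
# quantized codes -128 .. 127 therefore map to -12.8 .. 12.7 in steps of 0.1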
