
Commit 847d70d

Migrate CoreMLQuantizer to ET
Differential Revision: D90200393
Pull Request resolved: #16473
1 parent 5b4900c

File tree

5 files changed: +1628 -12 lines

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
# Copyright (c) 2024, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from typing import Optional as _Optional

import torch as _torch

from attr import define as _define

from coremltools.optimize.torch.quantization.quantization_config import (
    ModuleLinearQuantizerConfig as _ModuleLinearQuantizerConfig,
    QuantizationScheme as _QuantizationScheme,
)

from torchao.quantization.pt2e.fake_quantize import FakeQuantize as _FakeQuantize

from torchao.quantization.pt2e.observer import (
    MinMaxObserver as _MinMaxObserver,
    MovingAverageMinMaxObserver as _MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver as _MovingAveragePerChannelMinMaxObserver,
    PerChannelMinMaxObserver as _PerChannelMinMaxObserver,
)
from torchao.quantization.pt2e.quantizer import (
    QuantizationSpec as _TorchQuantizationSpec,
)


def _get_observer(observer_type, is_per_channel: bool):
    _str_to_observer_map = {
        "moving_average_min_max": _MovingAverageMinMaxObserver,
        "min_max": _MinMaxObserver,
        "moving_average_min_max_per_channel": _MovingAveragePerChannelMinMaxObserver,
        "min_max_per_channel": _PerChannelMinMaxObserver,
    }
    observer_name = observer_type.value
    if is_per_channel:
        observer_name = f"{observer_name}_per_channel"
    if observer_name not in _str_to_observer_map:
        raise ValueError(f"Unsupported observer type: {observer_name}")
    return _str_to_observer_map[observer_name]


@_define
class AnnotationConfig:
    """
    Module/Operator level configuration class for :py:class:`CoreMLQuantizer`.

    For each module/operator, defines the dtype, quantization scheme, and observer
    type for input(s), output, and weights (if any).
    """

    input_activation: _Optional[_TorchQuantizationSpec] = None
    output_activation: _Optional[_TorchQuantizationSpec] = None
    weight: _Optional[_TorchQuantizationSpec] = None

    @staticmethod
    def _normalize_dtype(dtype: _torch.dtype) -> _torch.dtype:
        """
        The PyTorch export quantizer only supports uint8 and int8 data types,
        so we map the quantized dtypes to the corresponding supported dtype.
        """
        dtype_map = {
            _torch.quint8: _torch.uint8,
            _torch.qint8: _torch.int8,
        }
        return dtype_map.get(dtype, dtype)

    @classmethod
    def from_quantization_config(
        cls,
        quantization_config: _Optional[_ModuleLinearQuantizerConfig],
    ) -> _Optional["AnnotationConfig"]:
        """
        Creates an :py:class:`AnnotationConfig` from a ``ModuleLinearQuantizerConfig``.
        """
        if (
            quantization_config is None
            or quantization_config.weight_dtype == _torch.float32
        ):
            return None

        # Activation QSpec
        if quantization_config.activation_dtype == _torch.float32:
            output_activation_qspec = None
        else:
            activation_qscheme = _QuantizationScheme.get_qscheme(
                quantization_config.quantization_scheme,
                is_per_channel=False,
            )
            activation_dtype = cls._normalize_dtype(
                quantization_config.activation_dtype
            )
            output_activation_qspec = _TorchQuantizationSpec(
                observer_or_fake_quant_ctr=_FakeQuantize.with_args(
                    observer=_get_observer(
                        quantization_config.activation_observer,
                        is_per_channel=False,
                    ),
                    dtype=activation_dtype,
                    qscheme=activation_qscheme,
                ),
                dtype=activation_dtype,
                qscheme=activation_qscheme,
            )

        # Weight QSpec
        weight_qscheme = _QuantizationScheme.get_qscheme(
            quantization_config.quantization_scheme,
            is_per_channel=quantization_config.weight_per_channel,
        )
        weight_dtype = cls._normalize_dtype(quantization_config.weight_dtype)
        weight_qspec = _TorchQuantizationSpec(
            observer_or_fake_quant_ctr=_FakeQuantize.with_args(
                observer=_get_observer(
                    quantization_config.weight_observer,
                    is_per_channel=quantization_config.weight_per_channel,
                ),
                dtype=weight_dtype,
                qscheme=weight_qscheme,
            ),
            dtype=weight_dtype,
            qscheme=weight_qscheme,
        )
        return AnnotationConfig(
            input_activation=output_activation_qspec,
            output_activation=output_activation_qspec,
            weight=weight_qspec,
        )
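
The file's two entry points are easy to exercise in isolation. Below is an illustrative sketch, not part of the commit, of how the _get_observer lookup resolves a per-channel observer; it assumes coremltools' ObserverType enum, whose .value strings match the keys of _str_to_observer_map:

# Illustrative sketch only (not part of the commit). Assumes coremltools'
# ObserverType enum, whose .value is a string such as "moving_average_min_max".
from coremltools.optimize.torch.quantization import ObserverType

# "_per_channel" is appended to the enum value before the table lookup,
# so this resolves to MovingAveragePerChannelMinMaxObserver.
observer_cls = _get_observer(ObserverType.moving_average_min_max, is_per_channel=True)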

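And a minimal end-to-end sketch of the config conversion, assuming a default-constructed ModuleLinearQuantizerConfig (symmetric scheme, torch.qint8 weights, torch.quint8 activations):

# Minimal usage sketch (not part of the commit); assumes the defaults of
# ModuleLinearQuantizerConfig: symmetric scheme, qint8 weights, quint8 activations.
import torch
from coremltools.optimize.torch.quantization import ModuleLinearQuantizerConfig

annotation_config = AnnotationConfig.from_quantization_config(
    ModuleLinearQuantizerConfig()
)

# _normalize_dtype maps the quantized torch dtypes to the plain dtypes the
# PT2E quantizer accepts, so the weight spec carries torch.int8, not torch.qint8.
assert annotation_config.weight.dtype == torch.int8
# The conversion reuses one activation spec for both input and output.
assert annotation_config.input_activation is annotation_config.output_activation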