|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 | from enum import IntEnum, unique |
| 7 | +from functools import partial |
7 | 8 | from typing import Callable, Optional, Sequence, Set |
8 | 9 |
|
9 | 10 | import torch |
@@ -67,28 +68,44 @@ class QuantDtype(IntEnum): |
67 | 68 | # PTQ |
68 | 69 | (QuantDtype.use_16a16w, False): ( |
69 | 70 | get_16a16w_qnn_ptq_config, |
70 | | - get_ptq_per_channel_quant_config(torch.uint16, torch.int16), |
| 71 | + partial( |
| 72 | + get_ptq_per_channel_quant_config, |
| 73 | + act_dtype=torch.uint16, |
| 74 | + weight_dtype=torch.int16, |
| 75 | + ), |
71 | 76 | ), |
72 | 77 | (QuantDtype.use_16a8w, False): ( |
73 | 78 | get_16a8w_qnn_ptq_config, |
74 | | - get_ptq_per_channel_quant_config(torch.uint16, torch.int8), |
| 79 | + partial( |
| 80 | + get_ptq_per_channel_quant_config, |
| 81 | + act_dtype=torch.uint16, |
| 82 | + weight_dtype=torch.int8, |
| 83 | + ), |
75 | 84 | ), |
76 | 85 | (QuantDtype.use_16a4w, False): ( |
77 | 86 | get_16a4w_qnn_ptq_config, |
78 | | - get_ptq_per_channel_quant_config(torch.uint16, "int4"), |
| 87 | + partial( |
| 88 | + get_ptq_per_channel_quant_config, |
| 89 | + act_dtype=torch.uint16, |
| 90 | + weight_dtype="int4", |
| 91 | + ), |
79 | 92 | ), |
80 | 93 | (QuantDtype.use_8a8w, False): ( |
81 | 94 | get_8a8w_qnn_ptq_config, |
82 | | - get_ptq_per_channel_quant_config(), |
| 95 | + partial(get_ptq_per_channel_quant_config), |
83 | 96 | ), |
84 | 97 | # QAT, |
85 | 98 | (QuantDtype.use_16a4w, True): ( |
86 | 99 | get_16a4w_qnn_qat_config, |
87 | | - get_qat_per_channel_quant_config(torch.uint16, "int4"), |
| 100 | + partial( |
| 101 | + get_qat_per_channel_quant_config, |
| 102 | + act_dtype=torch.uint16, |
| 103 | + weight_dtype="int4", |
| 104 | + ), |
88 | 105 | ), |
89 | 106 | (QuantDtype.use_8a8w, True): ( |
90 | 107 | get_8a8w_qnn_qat_config, |
91 | | - get_qat_per_channel_quant_config(), |
| 108 | + partial(get_qat_per_channel_quant_config), |
92 | 109 | ), |
93 | 110 | } |
94 | 111 |
|
@@ -176,11 +193,18 @@ def set_quant_config( |
176 | 193 | f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support" |
177 | 194 | ) |
178 | 195 |
|
179 | | - quant_config_fuc, self.per_channel_quant_config = quant_config_dict[ |
| 196 | + quant_config_fuc, per_channel_quant_config_fuc = quant_config_dict[ |
180 | 197 | (quant_dtype, is_qat) |
181 | 198 | ] |
182 | 199 | self.quant_config = ( |
183 | | - quant_config_fuc(act_observer) if act_observer else quant_config_fuc() |
| 200 | + quant_config_fuc(act_observer=act_observer) |
| 201 | + if act_observer |
| 202 | + else quant_config_fuc() |
| 203 | + ) |
| 204 | + self.per_channel_quant_config = ( |
| 205 | + per_channel_quant_config_fuc(act_observer=act_observer) |
| 206 | + if act_observer |
| 207 | + else per_channel_quant_config_fuc() |
184 | 208 | ) |
185 | 209 |
|
186 | 210 | def set_per_channel_conv_quant(self, enable: bool) -> None: |
|
0 commit comments