|
13 | 13 | # limitations under the License. |
14 | 14 | # ============================================================================== |
15 | 15 | import numpy as np |
| 16 | +import torch |
16 | 17 | from torch.nn import Conv2d |
17 | 18 |
|
18 | 19 | from model_compression_toolkit.core import ResourceUtilization |
@@ -49,7 +50,7 @@ def get_core_configs(self): |
49 | 50 | return {"mixed_precision_model": mct.core.CoreConfig(quantization_config=qc, mixed_precision_config=mpc)} |
50 | 51 |
|
def create_feature_network(self, input_shape):
    """Build the float feature model under mixed-precision test.

    Overrides the base-class stub (which raised NotImplementedError) to
    return the concrete Conv-BN-Conv-ReLU network.

    Args:
        input_shape: list of batched input shapes; the first entry is
            (batch, channels, height, width).

    Returns:
        A MixedPrecisionNet torch module built for the given input shape.
    """
    return MixedPrecisionNet(input_shape)
53 | 54 |
|
54 | 55 | def compare(self, quantized_model, float_model, input_x=None, quantization_info: UserInformation = None): |
55 | 56 | # This is a base test, so it does not check a thing. Only actual tests of mixed precision |
@@ -132,3 +133,93 @@ def get_mixed_precision_config(self): |
132 | 133 |
|
133 | 134 | def compare(self, quantized_models, float_model, input_x=None, quantization_info=None): |
134 | 135 | self.compare_results(quantization_info, quantized_models, float_model, 1) |
| 136 | + |
| 137 | + |
class MixedPrecisionNet(torch.nn.Module):
    """Small Conv-BN-Conv-ReLU network used as the float model in
    mixed-precision quantization tests.

    Layer sizes are deliberately tiny (3 and 4 output channels) so the
    mixed-precision search runs quickly in unit tests.
    """

    def __init__(self, input_shape):
        """Args:
            input_shape: list of batched input shapes; the first entry is
                (batch, channels, height, width) — only channels is used.
        """
        super().__init__()
        # Only the channel dimension of the first input shape matters here.
        _, in_channels, _, _ = input_shape[0]
        self.conv1 = torch.nn.Conv2d(in_channels, 3, kernel_size=3)
        self.bn1 = torch.nn.BatchNorm2d(3)
        self.conv2 = torch.nn.Conv2d(3, 4, kernel_size=5)
        self.relu = torch.nn.ReLU()

    def forward(self, inp):
        """Run conv1 -> bn1 -> conv2 -> relu and return the activation."""
        x = self.conv1(inp)
        x = self.bn1(x)
        x = self.conv2(x)
        return self.relu(x)
| 153 | + |
| 154 | + |
class MixedPrecisionWeightsConfigurableActivations(MixedPrecisionBaseTest):
    """Mixed-precision test where weights are searched over candidate
    bit-widths while activations have configurable (8/4/2-bit) candidates
    with weight quantization disabled on the activation opset.

    The TPC maps Conv2d layers to the "Weights" opset and ReLU/add to the
    "Activations" opset, then verifies the search picks the expected
    weights configuration under the given resource budget.
    """

    def __init__(self, unit_test):
        super().__init__(unit_test)
        # Expected chosen index into the mixed-precision candidate list.
        self.expected_config = [1]

    def get_core_configs(self):
        """Map framework layers to the custom 'Weights'/'Activations' opsets."""
        custom_opsets = {
            "Weights": CustomOpsetLayers(
                [torch.nn.Conv2d],
                {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
                 BIAS_ATTR: DefaultDict(default_value=BIAS)}),
            "Activations": CustomOpsetLayers([torch.nn.ReLU, torch.add]),
        }
        qc = QuantizationConfig(custom_tpc_opset_to_layer=custom_opsets)
        return {"mixed_precision_model": CoreConfig(quantization_config=qc)}

    def get_tpc(self):
        """Build a TPC with 8/4/2-bit activation candidates (no weight
        attrs) and weight-only mixed-precision candidates."""
        cfg, mixed_precision_cfg_list, _ = get_op_quantization_configs()

        # Activation-only candidates: strip weight attribute configs.
        act_candidates = [
            cfg.clone_and_edit(activation_n_bits=n_bits,
                               attr_weights_configs_mapping={})
            for n_bits in (8, 4, 2)
        ]

        # Weight candidates: disable activation quantization entirely.
        weight_candidates = [
            c.clone_and_edit(enable_activation_quantization=False)
            for c in mixed_precision_cfg_list
        ]
        base_weight_cfg = weight_candidates[0]

        act_mixed_cfg = QuantizationConfigOptions(
            quantization_configurations=tuple(act_candidates),
            base_config=act_candidates[0],
        )
        weight_mixed_cfg = QuantizationConfigOptions(
            quantization_configurations=tuple(weight_candidates),
            base_config=base_weight_cfg,
        )

        tpc = TargetPlatformCapabilities(
            default_qco=QuantizationConfigOptions(
                quantization_configurations=(base_weight_cfg,),
                base_config=base_weight_cfg),
            tpc_minor_version=None,
            tpc_patch_version=None,
            tpc_platform_type=None,
            operator_set=(
                OperatorsSet(name="Activations", qc_options=act_mixed_cfg),
                OperatorsSet(name="Weights", qc_options=weight_mixed_cfg)),
            name="mp_weights_conf_act_test")

        return {'mixed_precision_model': tpc}

    def create_feature_network(self, input_shape):
        """Return the Conv-Add-ReLU float model used by this test."""
        return MixedPrecisionWeightsTestNet(input_shape)

    def get_resource_utilization(self):
        # Weight-memory budget (bytes) forcing a non-maximal candidate.
        return ResourceUtilization(80)

    def compare(self, quantized_models, float_model, input_x=None, quantization_info=None):
        """Assert the search selected the expected mixed-precision config."""
        self.unit_test.assertTrue(
            quantization_info.mixed_precision_cfg == self.expected_config)
| 212 | + |
| 213 | + |
class MixedPrecisionWeightsTestNet(torch.nn.Module):
    """Conv -> add(x, x) -> ReLU network; the torch.add node exercises the
    configurable-activation opset in the mixed-precision weights test.
    """

    def __init__(self, input_shape):
        """Args:
            input_shape: list of batched input shapes; the first entry is
                (batch, channels, height, width) — only channels is used.
        """
        super().__init__()
        _, in_channels, _, _ = input_shape[0]
        self.conv1 = torch.nn.Conv2d(in_channels, 3, kernel_size=(3, 3))
        self.relu = torch.nn.ReLU()

    def forward(self, inp):
        """Run conv1, double the activation via torch.add, apply ReLU."""
        x = self.conv1(inp)
        x = torch.add(x, x)
        return self.relu(x)
0 commit comments