Skip to content

Commit 247ed43

Browse files
irenab
authored and committed
fix merge and old test
1 parent e53b4cf commit 247ed43

File tree

2 files changed

+93
-2
lines changed

2 files changed

+93
-2
lines changed

model_compression_toolkit/core/common/graph/base_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from copy import copy, deepcopy
1818
from functools import wraps
19-
from typing import List, Tuple, Any, Callable
19+
from typing import List, Tuple, Any, Callable, Dict
2020

2121
import networkx as nx
2222
import numpy as np

tests/pytorch_tests/model_tests/feature_models/mixed_precision_weights_test.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
import numpy as np
16+
import torch
1617
from torch.nn import Conv2d
1718

1819
from model_compression_toolkit.core import ResourceUtilization
@@ -49,7 +50,7 @@ def get_core_configs(self):
4950
return {"mixed_precision_model": mct.core.CoreConfig(quantization_config=qc, mixed_precision_config=mpc)}
5051

5152
def create_feature_network(self, input_shape):
52-
raise NotImplementedError()
53+
return MixedPrecisionNet(input_shape)
5354

5455
def compare(self, quantized_model, float_model, input_x=None, quantization_info: UserInformation = None):
5556
# This is a base test, so it does not check a thing. Only actual tests of mixed precision
@@ -132,3 +133,93 @@ def get_mixed_precision_config(self):
132133

133134
def compare(self, quantized_models, float_model, input_x=None, quantization_info=None):
134135
self.compare_results(quantization_info, quantized_models, float_model, 1)
136+
137+
138+
class MixedPrecisionNet(torch.nn.Module):
    """Small conv -> bn -> conv -> relu feature network for mixed-precision weight tests.

    Args:
        input_shape: List of input shapes; the first entry is an
            (batch, channels, height, width) tuple from which the input
            channel count is taken.
    """

    def __init__(self, input_shape):
        # Zero-argument super() — the Python-2-style
        # super(MixedPrecisionNet, self) form is redundant in Python 3.
        super().__init__()
        _, in_channels, _, _ = input_shape[0]
        self.conv1 = torch.nn.Conv2d(in_channels, 3, kernel_size=3)
        self.bn1 = torch.nn.BatchNorm2d(3)
        self.conv2 = torch.nn.Conv2d(3, 4, kernel_size=5)
        self.relu = torch.nn.ReLU()

    def forward(self, inp):
        """Run conv1 -> bn1 -> conv2 -> relu and return the activation map."""
        x = self.conv1(inp)
        x = self.bn1(x)
        x = self.conv2(x)
        output = self.relu(x)
        return output
153+
154+
155+
class MixedPrecisionWeightsConfigurableActivations(MixedPrecisionBaseTest):
    """Mixed-precision test with configurable weight candidates AND
    configurable activation bit-widths (8/4/2) supplied via a custom TPC."""

    def __init__(self, unit_test):
        super().__init__(unit_test)
        # Expected mixed-precision configuration chosen by the search.
        # NOTE(review): presumably one configurable node selecting candidate
        # index 1 — confirm against the framework's candidate ordering.
        self.expected_config = [1]

    def get_core_configs(self):
        # Map concrete torch layers onto the custom opsets ("Weights",
        # "Activations") declared in get_tpc() below.
        return {"mixed_precision_model": CoreConfig(quantization_config=QuantizationConfig(
            custom_tpc_opset_to_layer={"Weights": CustomOpsetLayers([torch.nn.Conv2d],
                                                                    {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
                                                                     BIAS_ATTR: DefaultDict(default_value=BIAS)}),
                                       "Activations": CustomOpsetLayers([torch.nn.ReLU, torch.add])}
        ))}

    def get_tpc(self):
        cfg, mixed_precision_cfg_list, _ = get_op_quantization_configs()

        # Activation-only candidates (empty weight-attribute mapping) at
        # 8, 4 and 2 bits.
        act_eight_bit_cfg = cfg.clone_and_edit(activation_n_bits=8,
                                               attr_weights_configs_mapping={})
        act_four_bit_cfg = cfg.clone_and_edit(activation_n_bits=4,
                                              attr_weights_configs_mapping={})
        act_two_bit_cfg = cfg.clone_and_edit(activation_n_bits=2,
                                             attr_weights_configs_mapping={})

        # Weight candidates keep their weight bit-widths but have activation
        # quantization disabled, so only the "Activations" opset configures
        # activation precision.
        mixed_precision_cfg_list = \
            [c.clone_and_edit(enable_activation_quantization=False) for c in mixed_precision_cfg_list]
        cfg = mixed_precision_cfg_list[0]

        act_mixed_cfg = QuantizationConfigOptions(quantization_configurations=tuple(
            [act_eight_bit_cfg, act_four_bit_cfg, act_two_bit_cfg]),
            base_config=act_eight_bit_cfg,
        )

        weight_mixed_cfg = QuantizationConfigOptions(quantization_configurations=tuple(
            mixed_precision_cfg_list),
            base_config=cfg,
        )

        tpc = TargetPlatformCapabilities(
            default_qco=QuantizationConfigOptions(quantization_configurations=tuple([cfg]), base_config=cfg),
            tpc_minor_version=None,
            tpc_patch_version=None,
            tpc_platform_type=None,
            operator_set=tuple([
                OperatorsSet(name="Activations", qc_options=act_mixed_cfg),
                OperatorsSet(name="Weights", qc_options=weight_mixed_cfg)]),
            name="mp_weights_conf_act_test")

        return {'mixed_precision_model': tpc}

    def create_feature_network(self, input_shape):
        # Conv + add + relu network matching the opsets declared above.
        return MixedPrecisionWeightsTestNet(input_shape)

    def get_resource_utilization(self):
        # Weights-memory budget driving the mixed-precision search.
        # NOTE(review): presumably bytes — confirm against ResourceUtilization.
        return ResourceUtilization(80)

    def compare(self, quantized_models, float_model, input_x=None, quantization_info=None):
        # The search must land on the expected per-node candidate indices.
        self.unit_test.assertTrue(quantization_info.mixed_precision_cfg == self.expected_config)
212+
213+
214+
class MixedPrecisionWeightsTestNet(torch.nn.Module):
    """Minimal conv -> add -> relu network for the configurable-activations test.

    Args:
        input_shape: List of input shapes; the first entry is an
            (batch, channels, height, width) tuple from which the input
            channel count is taken.
    """

    def __init__(self, input_shape):
        # Zero-argument super() — the Python-2-style
        # super(MixedPrecisionWeightsTestNet, self) form is redundant in Python 3.
        super().__init__()
        _, in_channels, _, _ = input_shape[0]
        self.conv1 = torch.nn.Conv2d(in_channels, 3, kernel_size=(3, 3))
        self.relu = torch.nn.ReLU()

    def forward(self, inp):
        """Convolve, add the result to itself (torch.add opset hook), then ReLU."""
        x = self.conv1(inp)
        x = torch.add(x, x)
        output = self.relu(x)
        return output

0 commit comments

Comments
 (0)