diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py index a2c805403..3455abbd5 100644 --- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py @@ -78,15 +78,19 @@ def __init__(self, self.mp_topo_configurable_nodes = self.mp_graph.get_configurable_sorted_nodes(fw_info) self.ru_targets = target_resource_utilization.get_restricted_targets() - self.ru_helper = MixedPrecisionRUHelper(self.mp_graph, fw_info, fw_impl) + self.ru_helper = MixedPrecisionRUHelper(self.original_graph, fw_info, fw_impl) self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config(fw_info) self.max_ru_config: Dict[BaseNode, int] = self.mp_graph.get_max_candidates_config(fw_info) - self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, self.min_ru_config) self.config_reconstruction_helper = ConfigReconstructionHelper(virtual_graph=self.mp_graph, original_graph=self.original_graph) + self.real_min_ru_config: Dict[BaseNode, int] = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(self.min_ru_config) + self.real_max_ru_config: Dict[BaseNode, int] = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(self.max_ru_config) + + self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, self.real_min_ru_config) + def search(self) -> Dict[BaseNode, int]: """ Run mixed precision search. 
@@ -251,7 +256,8 @@ def _compute_relative_ru_matrices(self) -> Dict[RUTarget, np.ndarray]: else: cfg = self.min_ru_config.copy() cfg[node] = candidate_idx - candidate_rus = self.ru_helper.compute_utilization(self.ru_targets, cfg) + real_cfg = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(cfg) + candidate_rus = self.ru_helper.compute_utilization(self.ru_targets, real_cfg) for target, ru in candidate_rus.items(): rus_per_candidate[target].append(ru) diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py index 49e4c76cf..fb8ce86cf 100644 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py @@ -146,8 +146,15 @@ def cuts(self) -> Dict[Cut, List[BaseNode]]: raise RuntimeError("Failed to calculate activation memory cuts for graph.") cuts = [cut for cut in cuts if cut.mem_elements.elements] # cache cuts nodes for future use, so do not filter by target - self._cuts = {cut: [self.graph.find_node_by_name(m.node_name)[0] for m in cut.mem_elements.elements] - for cut in cuts} + self._cuts = { + cut: [ + node + for m in cut.mem_elements.elements + for node in (self.graph.fusing_info.get_fused_nodes(m.node_name) or ( + self.graph.find_node_by_name(m.node_name)[0],)) + ] + for cut in cuts + } return self._cuts def compute_resource_utilization(self, @@ -581,7 +588,10 @@ def compute_node_bops(self, def _compute_cuts(self): """ Compute activation cuts of the graph. 
""" - memory_graph = MemoryGraph(deepcopy(self.graph)) + from model_compression_toolkit.core.common.fusion.graph_fuser import GraphFuser + gf = GraphFuser() + graph = gf.apply_node_fusion(self.graph) + memory_graph = MemoryGraph(deepcopy(graph)) _, _, cuts = compute_graph_max_cut(memory_graph) return cuts diff --git a/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py b/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py index 7faf91772..bb9f26bce 100644 --- a/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +++ b/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py @@ -16,7 +16,7 @@ import numpy as np from pulp import * -from typing import Dict, Tuple, Any +from typing import Dict, Tuple, Any, List from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py index e521041b0..cae80baa9 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== +import math import typing import abc @@ -53,7 +54,8 @@ def get_base_mp_nbits_candidates(): class MixedPrecisionActivationBaseTest(BaseKerasFeatureNetworkTest): def __init__(self, unit_test, activation_layers_idx, num_calibration_iter=1): super().__init__(unit_test, num_calibration_iter=num_calibration_iter) - + # for the model that is used here, the last two tensors compose the max cut + self.max_cut = 10 * 10 * 32 + 13 * 13 * 32 self.activation_layers_idx = activation_layers_idx def get_core_config(self): @@ -135,14 +137,15 @@ def __init__(self, unit_test): super().__init__(unit_test, activation_layers_idx=[1, 2, 4]) def get_resource_utilization(self): - return ResourceUtilization(weights_memory=17919, activation_memory=5407) + return ResourceUtilization(weights_memory=17919, activation_memory=self.max_cut-1) def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): # verify chosen activation bitwidth config # resource utilization is infinity -> should give best model - 8bits holder_layers = get_layers_from_model_by_type(quantized_model, KerasActivationQuantizationHolder) activation_bits = [layer.activation_holder_quantizer.get_config()['num_bits'] for layer in holder_layers] - self.unit_test.assertTrue((activation_bits == [8, 4, 8])) + # Since the max cut is the last two tensors, one of them has to get 4 bits + self.unit_test.assertIn(activation_bits, ([8, 4, 8], [8, 8, 4])) self.verify_quantization(quantized_model, input_x, weights_layers_idx=[2, 3], @@ -157,7 +160,7 @@ def __init__(self, unit_test): def get_resource_utilization(self): # resource utilization is for 4 bits on average - return ResourceUtilization(weights_memory=17920 * 4 / 8, activation_memory=4300) + return ResourceUtilization(weights_memory=17920 * 4 / 8, activation_memory=math.ceil(self.max_cut*4/8)) def get_tpc(self): eight_bits = generate_test_op_qc(**generate_test_attr_configs())
@@ -180,7 +183,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= # then there is no guarantee that the activation bitwidth for each layer would be 4-bit, # this assertion tests the expected result for this specific # test with its current setup (therefore, we don't check the input layer's bitwidth) - self.unit_test.assertTrue((activation_bits == [4, 8])) + self.unit_test.assertTrue((activation_bits == [4, 4])) class MixedPrecisionActivationSearch2BitsAvgTest(MixedPrecisionActivationBaseTest): @@ -189,7 +192,7 @@ def __init__(self, unit_test): def get_resource_utilization(self): # resource utilization is for 2 bits on average - return ResourceUtilization(weights_memory=17920.0 * 2 / 8, activation_memory=1544) + return ResourceUtilization(weights_memory=17920.0 * 2 / 8, activation_memory=math.ceil(self.max_cut * 2 / 8)) def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): # verify chosen activation bitwidth config @@ -213,7 +216,8 @@ def __init__(self, unit_test): super().__init__(unit_test, activation_layers_idx=[1, 3]) def get_resource_utilization(self): - return ResourceUtilization(47, 767) + # 638 = round_up((16*16*3+13*13*3)/2) -> so it must choose (4,4) + return ResourceUtilization(47, 638) def create_networks(self): inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) @@ -225,10 +229,9 @@ def create_networks(self): def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): # verify chosen activation bitwidth config - # resource utilization is infinity -> should give best model - 8bits holder_layers = get_layers_from_model_by_type(quantized_model, KerasActivationQuantizationHolder) activation_bits = [layer.activation_holder_quantizer.get_config()['num_bits'] for layer in holder_layers] - self.unit_test.assertTrue((activation_bits == [4, 8])) + self.unit_test.assertTrue((activation_bits == [4, 4])) class 
MixedPrecisionActivationDepthwise4BitTest(MixedPrecisionActivationBaseTest): @@ -236,7 +239,7 @@ def __init__(self, unit_test): super().__init__(unit_test, activation_layers_idx=[1]) def get_resource_utilization(self): - return ResourceUtilization(48.0 * 4 / 8, 768.0 * 4 / 8) + return ResourceUtilization(48.0 * 4 / 8, math.ceil((16*16*3+13*13*3) * 4 / 8)) def get_tpc(self): eight_bits = generate_test_op_qc(**generate_test_attr_configs()) @@ -464,7 +467,7 @@ def __init__(self, unit_test): def get_resource_utilization(self): # 17920: 8-bit weights, 6176: max cut of input+conv_bn - return ResourceUtilization(np.inf, np.inf, total_memory=(17920 + 6176) * 4 / 8) + return ResourceUtilization(np.inf, np.inf, total_memory=(17920 + self.max_cut) * 4 / 8) def _compare(self, quantized_model, float_model, input_x=None, quantization_info: UserInformation = None): # verify chosen activation bitwidth config @@ -485,7 +488,7 @@ def __init__(self, unit_test): def get_resource_utilization(self): weights = 17920 * 4 / 8 - activation = 6176 * 4 / 8 + activation = math.ceil(self.max_cut * 4 / 8) return ResourceUtilization(weights, activation, total_memory=weights + activation) def _compare(self, quantized_model, float_model, input_x=None, quantization_info: UserInformation = None): @@ -514,7 +517,7 @@ def __init__(self, unit_test): def get_resource_utilization(self): weights = 17920 * 4 / 8 - activation = 6176 * 4 / 8 # max cut of input + conv_bn + activation = math.ceil(self.max_cut * 4 / 8) return ResourceUtilization(weights, activation, total_memory=(weights + activation) / 2) def _compare(self, quantized_model, float_model, input_x=None, quantization_info: UserInformation = None): diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py index 60b728065..52224424f 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py +++ 
b/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py @@ -286,11 +286,13 @@ def compare(self, qat_model, finalize=False, input_x=None, quantization_info=Non class QATWrappersMixedPrecisionCfgTest(MixedPrecisionActivationBaseTest): - def __init__(self, unit_test, ru_weights=17919, ru_activation=5407, expected_mp_cfg=[0, 4, 0, 0]): - self.ru_weights = ru_weights - self.ru_activation = ru_activation - self.expected_mp_cfg = expected_mp_cfg + def __init__(self, unit_test, ru_weights=17919, ru_activation=None, expected_mp_cfg=None): super().__init__(unit_test, activation_layers_idx=[1, 3, 6]) + self.ru_weights = ru_weights + # The default test case is that the max cut (which is the fused conv-relu layer tensors, input and output) + # must be reduced to 4 bits on average. + self.ru_activation = ru_activation or (self.max_cut * 4 / 8) + self.expected_mp_cfg = expected_mp_cfg or [0, 4, 0, 1] # input, conv, conv2, relu def run_test(self, **kwargs): model_float = self.create_networks() diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py index 5e4c2401a..7d67dcc80 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py @@ -13,6 +13,7 @@ # limitations under the License. 
# ============================================================================== import abc +import math import typing @@ -46,6 +47,7 @@ class MixedPrecisionBaseTest(BaseKerasFeatureNetworkTest): def __init__(self, unit_test, val_batch_size=1, num_calibration_iter=1): super().__init__(unit_test, val_batch_size=val_batch_size, num_calibration_iter=num_calibration_iter) + self.max_cut = 10 * 10 * 32 + 13 * 13 * 32 def get_quantization_config(self): return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.MSE, mct.core.QuantizationErrorMethod.MSE, @@ -361,7 +363,7 @@ class MixedPrecisionSearchTotalMemoryNonConfNodesTest(MixedPrecisionBaseTest): def __init__(self, unit_test): super().__init__(unit_test) # Total ResourceUtilization for weights in 2 bit avg and non-configurable activation in 8 bit - self.target_total_ru = ResourceUtilization(total_memory=17920 * 2 / 8 + 6176) + self.target_total_ru = ResourceUtilization(total_memory=17920 * 2 / 8 + math.ceil(self.max_cut * 8 / 8)) def get_resource_utilization(self): return self.target_total_ru diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index 9e5166bbd..bd9d47d07 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -831,8 +831,8 @@ def test_qat(self): QuantizationAwareTrainingQuantizersTest(self).run_test() QuantizationAwareTrainingQuantizerHolderTest(self).run_test() QATWrappersMixedPrecisionCfgTest(self).run_test() - QATWrappersMixedPrecisionCfgTest(self, ru_weights=17920 * 4 / 8, ru_activation=5408 * 4 / 8, - expected_mp_cfg=[0, 5, 1, 1]).run_test() + QATWrappersMixedPrecisionCfgTest(self, ru_weights=17920 * 4 / 8, ru_activation=8608 * 4 / 8, + expected_mp_cfg=[0, 4, 1, 1]).run_test() def test_bn_attributes_quantization(self): BNAttributesQuantization(self, quantize_linear=False).run_test()