diff --git a/model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py b/model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py index e4a3063d6..90cfe34ba 100644 --- a/model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py +++ b/model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py @@ -38,9 +38,6 @@ def __init__(self, model_graph: Graph): Args: model_graph: A graph representation of a model. """ - - self.model_graph = model_graph - nodes = list(model_graph.nodes) memory_tensors = [] node_to_tensor = [] diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py index a2c805403..b1daf1ba9 100644 --- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py @@ -78,14 +78,18 @@ def __init__(self, self.mp_topo_configurable_nodes = self.mp_graph.get_configurable_sorted_nodes(fw_info) self.ru_targets = target_resource_utilization.get_restricted_targets() - self.ru_helper = MixedPrecisionRUHelper(self.mp_graph, fw_info, fw_impl) + self.ru_helper = MixedPrecisionRUHelper(self.original_graph, fw_info, fw_impl) self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config(fw_info) self.max_ru_config: Dict[BaseNode, int] = self.mp_graph.get_max_candidates_config(fw_info) - self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, self.min_ru_config) self.config_reconstruction_helper = ConfigReconstructionHelper(virtual_graph=self.mp_graph, original_graph=self.original_graph) + if self.using_virtual_graph: + real_min_ru_config: Dict[BaseNode, int] = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(self.min_ru_config) + self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, real_min_ru_config) + else: + self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, self.min_ru_config) def search(self) -> Dict[BaseNode, int]: """ @@ -251,7 +255,8 @@ def _compute_relative_ru_matrices(self) -> Dict[RUTarget, np.ndarray]: else: cfg = self.min_ru_config.copy() cfg[node] = candidate_idx - candidate_rus = self.ru_helper.compute_utilization(self.ru_targets, cfg) + real_cfg = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(cfg) + candidate_rus = self.ru_helper.compute_utilization(self.ru_targets, real_cfg) for target, ru in candidate_rus.items(): rus_per_candidate[target].append(ru) diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py index 406e9660e..c367271b5 100644 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py @@ -18,6 +18,8 @@ from enum import Enum, auto from typing import Dict, NamedTuple, Optional, Tuple, List, Iterable, Union, Literal, Sequence +from model_compression_toolkit.core.common.fusion.graph_fuser import GraphFuser + from model_compression_toolkit.constants import FLOAT_BITWIDTH from model_compression_toolkit.core import FrameworkInfo from model_compression_toolkit.core.common import Graph, BaseNode @@ -145,8 +147,14 @@ def cuts(self) -> Dict[Cut, List[BaseNode]]: raise RuntimeError("Failed to calculate activation memory cuts for graph.") cuts = [cut for cut in cuts if cut.mem_elements.elements] # cache cuts nodes for future use, so do not filter by target - self._cuts = {cut: [self.graph.find_node_by_name(m.node_name)[0] for m in cut.mem_elements.elements] - for cut in cuts} + self._cuts = { + cut: [ + node + for m in cut.mem_elements.elements + for node in (self.graph.fusing_info.get_fused_nodes(m.node_name) or (self.graph.find_node_by_name(m.node_name)[0],)) + ] + for cut in cuts + } return self._cuts def compute_resource_utilization(self, @@ -580,7 +588,9 @@ def compute_node_bops(self, def _compute_cuts(self): """ Compute activation cuts of the graph. """ - memory_graph = MemoryGraph(deepcopy(self.graph)) + # Compute memory graph on fused graph with fused nodes + graph = GraphFuser().apply_node_fusion(self.graph) + memory_graph = MemoryGraph(deepcopy(graph)) _, _, cuts = compute_graph_max_cut(memory_graph) return cuts diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py index e521041b0..cae80baa9 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import math import typing import abc @@ -53,7 +54,8 @@ def get_base_mp_nbits_candidates(): class MixedPrecisionActivationBaseTest(BaseKerasFeatureNetworkTest): def __init__(self, unit_test, activation_layers_idx, num_calibration_iter=1): super().__init__(unit_test, num_calibration_iter=num_calibration_iter) - + # for the model that is used here, the two last tensors compose the max cut + self.max_cut = 10 * 10 * 32 + 13 * 13 * 32 self.activation_layers_idx = activation_layers_idx def get_core_config(self): @@ -135,14 +137,15 @@ def __init__(self, unit_test): super().__init__(unit_test, activation_layers_idx=[1, 2, 4]) def get_resource_utilization(self): - return ResourceUtilization(weights_memory=17919, activation_memory=5407) + return ResourceUtilization(weights_memory=17919, activation_memory=self.max_cut-1) def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): # verify chosen activation bitwidth config # resource utilization is infinity -> should give best model - 8bits holder_layers = get_layers_from_model_by_type(quantized_model, KerasActivationQuantizationHolder) activation_bits = [layer.activation_holder_quantizer.get_config()['num_bits'] for layer in holder_layers] - self.unit_test.assertTrue((activation_bits == [8, 4, 8])) + # Since the max cut is the last two tensors, one of them have to get 4 bits + self.unit_test.assertIn(activation_bits, ([8, 4, 8], [8, 8, 4])) self.verify_quantization(quantized_model, input_x, weights_layers_idx=[2, 3], @@ -157,7 +160,7 @@ def __init__(self, unit_test): def get_resource_utilization(self): # resource utilization is for 4 bits on average - return ResourceUtilization(weights_memory=17920 * 4 / 8, activation_memory=4300) + return ResourceUtilization(weights_memory=17920 * 4 / 8, activation_memory=math.ceil(self.max_cut*4/8)) def get_tpc(self): eight_bits = generate_test_op_qc(**generate_test_attr_configs()) @@ -180,7 +183,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= # then there is no guarantee that the activation bitwidth for each layer would be 4-bit, # this assertion tests the expected result for this specific # test with its current setup (therefore, we don't check the input layer's bitwidth) - self.unit_test.assertTrue((activation_bits == [4, 8])) + self.unit_test.assertTrue((activation_bits == [4, 4])) class MixedPrecisionActivationSearch2BitsAvgTest(MixedPrecisionActivationBaseTest): @@ -189,7 +192,7 @@ def __init__(self, unit_test): def get_resource_utilization(self): # resource utilization is for 2 bits on average - return ResourceUtilization(weights_memory=17920.0 * 2 / 8, activation_memory=1544) + return ResourceUtilization(weights_memory=17920.0 * 2 / 8, activation_memory=math.ceil(self.max_cut * 2 / 8)) def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): # verify chosen activation bitwidth config @@ -213,7 +216,8 @@ def __init__(self, unit_test): super().__init__(unit_test, activation_layers_idx=[1, 3]) def get_resource_utilization(self): - return ResourceUtilization(47, 767) + # 638 = round_up((16*16*3+13*13*3)/2) -> so it must choose (4,4) + return ResourceUtilization(47, 638) def create_networks(self): inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) @@ -225,10 +229,9 @@ def create_networks(self): def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): # verify chosen activation bitwidth config - # resource utilization is infinity -> should give best model - 8bits holder_layers = get_layers_from_model_by_type(quantized_model, KerasActivationQuantizationHolder) activation_bits = [layer.activation_holder_quantizer.get_config()['num_bits'] for layer in holder_layers] - self.unit_test.assertTrue((activation_bits == [4, 8])) + self.unit_test.assertTrue((activation_bits == [4, 4])) class MixedPrecisionActivationDepthwise4BitTest(MixedPrecisionActivationBaseTest): @@ -236,7 +239,7 @@ def __init__(self, unit_test): super().__init__(unit_test, activation_layers_idx=[1]) def get_resource_utilization(self): - return ResourceUtilization(48.0 * 4 / 8, 768.0 * 4 / 8) + return ResourceUtilization(48.0 * 4 / 8, math.ceil((16*16*3+13*13*3) * 4 / 8)) def get_tpc(self): eight_bits = generate_test_op_qc(**generate_test_attr_configs()) @@ -464,7 +467,7 @@ def __init__(self, unit_test): def get_resource_utilization(self): # 17920: 8-bit weights, 6176: max cut of input+conv_bn - return ResourceUtilization(np.inf, np.inf, total_memory=(17920 + 6176) * 4 / 8) + return ResourceUtilization(np.inf, np.inf, total_memory=(17920 + self.max_cut) * 4 / 8) def _compare(self, quantized_model, float_model, input_x=None, quantization_info: UserInformation = None): # verify chosen activation bitwidth config @@ -485,7 +488,7 @@ def __init__(self, unit_test): def get_resource_utilization(self): weights = 17920 * 4 / 8 - activation = 6176 * 4 / 8 + activation = math.ceil(self.max_cut * 4 / 8) return ResourceUtilization(weights, activation, total_memory=weights + activation) def _compare(self, quantized_model, float_model, input_x=None, quantization_info: UserInformation = None): @@ -514,7 +517,7 @@ def __init__(self, unit_test): def get_resource_utilization(self): weights = 17920 * 4 / 8 - activation = 6176 * 4 / 8 # max cut of input + conv_bn + activation = math.ceil(self.max_cut * 4 / 8) return ResourceUtilization(weights, activation, total_memory=(weights + activation) / 2) def _compare(self, quantized_model, float_model, input_x=None, quantization_info: UserInformation = None): diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py index 60b728065..52224424f 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py @@ -286,11 +286,13 @@ def compare(self, qat_model, finalize=False, input_x=None, quantization_info=Non class QATWrappersMixedPrecisionCfgTest(MixedPrecisionActivationBaseTest): - def __init__(self, unit_test, ru_weights=17919, ru_activation=5407, expected_mp_cfg=[0, 4, 0, 0]): - self.ru_weights = ru_weights - self.ru_activation = ru_activation - self.expected_mp_cfg = expected_mp_cfg + def __init__(self, unit_test, ru_weights=17919, ru_activation=None, expected_mp_cfg=None): super().__init__(unit_test, activation_layers_idx=[1, 3, 6]) + self.ru_weights = ru_weights + # The default test case is that the max cut (which is the fused conv-relu layer tensors, input and output) + # must be reduced to 4 bits on average. + self.ru_activation = ru_activation or (self.max_cut * 4 / 8) + self.expected_mp_cfg = expected_mp_cfg or [0, 4, 0, 1] # input, conv, conv2, relu def run_test(self, **kwargs): model_float = self.create_networks() diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py index 5e4c2401a..7d67dcc80 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== import abc +import math import typing @@ -46,6 +47,7 @@ class MixedPrecisionBaseTest(BaseKerasFeatureNetworkTest): def __init__(self, unit_test, val_batch_size=1, num_calibration_iter=1): super().__init__(unit_test, val_batch_size=val_batch_size, num_calibration_iter=num_calibration_iter) + self.max_cut = 10 * 10 * 32 + 13 * 13 * 32 def get_quantization_config(self): return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.MSE, mct.core.QuantizationErrorMethod.MSE, @@ -361,7 +363,7 @@ class MixedPrecisionSearchTotalMemoryNonConfNodesTest(MixedPrecisionBaseTest): def __init__(self, unit_test): super().__init__(unit_test) # Total ResourceUtilization for weights in 2 bit avg and non-configurable activation in 8 bit - self.target_total_ru = ResourceUtilization(total_memory=17920 * 2 / 8 + 6176) + self.target_total_ru = ResourceUtilization(total_memory=17920 * 2 / 8 + math.ceil(self.max_cut * 8 / 8)) def get_resource_utilization(self): return self.target_total_ru diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index 4c54343d8..7f09d2734 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -831,8 +831,8 @@ def test_qat(self): QuantizationAwareTrainingQuantizersTest(self).run_test() QuantizationAwareTrainingQuantizerHolderTest(self).run_test() QATWrappersMixedPrecisionCfgTest(self).run_test() - QATWrappersMixedPrecisionCfgTest(self, ru_weights=17920 * 4 / 8, ru_activation=5408 * 4 / 8, - expected_mp_cfg=[0, 5, 1, 1]).run_test() + QATWrappersMixedPrecisionCfgTest(self, ru_weights=17920 * 4 / 8, ru_activation=8608 * 4 / 8, + expected_mp_cfg=[0, 4, 1, 1]).run_test() def test_bn_attributes_quantization(self): BNAttributesQuantization(self, quantize_linear=False).run_test() diff --git a/tests/keras_tests/graph_tests/test_memory_graph.py b/tests/keras_tests/graph_tests/test_memory_graph.py index 4f34b0b2b..9aee4d2a0 100644 --- a/tests/keras_tests/graph_tests/test_memory_graph.py +++ b/tests/keras_tests/graph_tests/test_memory_graph.py @@ -84,7 +84,7 @@ def test_memory_graph_build(self): self.assertTrue(len(memory_graph.a_nodes) == 4) self.assertTrue(len(memory_graph.b_nodes) == 4) - self.assertTrue(graph.get_topo_sorted_nodes()[0] in memory_graph.sources_a) + self.assertTrue(graph.get_topo_sorted_nodes()[0].name in [node.name for node in memory_graph.sources_a]) self.assertTrue(len(memory_graph.sinks_b) == 1) self.assertTrue(memory_graph.memory_lbound_single_op == 264) @@ -99,7 +99,7 @@ def test_memory_graph_node_with_multiple_outputs(self): self.assertTrue(len(memory_graph.a_nodes) == 5) self.assertTrue(len(memory_graph.b_nodes) == 6) - self.assertTrue(graph.get_topo_sorted_nodes()[0] in memory_graph.sources_a) + self.assertTrue(graph.get_topo_sorted_nodes()[0].name in [node.name for node in memory_graph.sources_a]) self.assertTrue(len(memory_graph.sinks_b) == 1) self.assertTrue(memory_graph.memory_lbound_single_op == 576) @@ -117,7 +117,7 @@ def test_memory_graph_with_residual(self): self.assertTrue(len(memory_graph.a_nodes) == 5) self.assertTrue(len(memory_graph.b_nodes) == 5) - self.assertTrue(graph.get_topo_sorted_nodes()[0] in memory_graph.sources_a) + self.assertTrue(graph.get_topo_sorted_nodes()[0].name in [node.name for node in memory_graph.sources_a]) self.assertTrue(len(memory_graph.sinks_b) == 1) self.assertTrue(memory_graph.memory_lbound_single_op == 199) diff --git a/tests_pytest/common_tests/unit_tests/core/mixed_precision/resource_utilization_tools/test_resource_utilization_calculator.py b/tests_pytest/common_tests/unit_tests/core/mixed_precision/resource_utilization_tools/test_resource_utilization_calculator.py index 2ee1559ba..6e26f92ad 100644 --- a/tests_pytest/common_tests/unit_tests/core/mixed_precision/resource_utilization_tools/test_resource_utilization_calculator.py +++ b/tests_pytest/common_tests/unit_tests/core/mixed_precision/resource_utilization_tools/test_resource_utilization_calculator.py @@ -12,15 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import random from types import MethodType from unittest.mock import Mock import numpy as np import pytest +from model_compression_toolkit.core.common.graph.base_graph import OutTensor from model_compression_toolkit.constants import FLOAT_BITWIDTH from model_compression_toolkit.core import ResourceUtilization from model_compression_toolkit.core.common import Graph +from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo from model_compression_toolkit.core.common.graph.edge import Edge from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import compute_graph_max_cut from model_compression_toolkit.core.common.graph.memory_graph.cut import Cut @@ -498,6 +501,8 @@ def prepare_compute_cuts(self, graph_mock, fw_impl_mock, fw_info_mock, mocker): graph_mock.find_node_by_name = MethodType(Graph.find_node_by_name, graph_mock) graph_mock.retrieve_preserved_quantization_node = lambda x: mp2 if x.name == 'qp' else x + graph_mock.fusing_info = FusingInfo() + # we should not use total size, setting it to bad number cut_elems1 = MemoryElements(elements={ActivationMemoryTensor(mp_reuse.output_shape, 'mp_reuse', 0)}, total_size=-1) cut_elems2 = MemoryElements(elements={ActivationMemoryTensor(mp_reuse.output_shape, 'mp_reuse', 0), @@ -515,6 +520,112 @@ def prepare_compute_cuts(self, graph_mock, fw_impl_mock, fw_info_mock, mocker): ru_calc = ResourceUtilizationCalculator(graph_mock, fw_impl_mock, fw_info_mock) return ru_calc, cuts, nodes + @pytest.mark.parametrize('seed', list(range(42, 52))) + @pytest.mark.parametrize("disable_quantization", [True, False]) + def test_compute_cuts_random_fusion_valid_utilization(self, seed, disable_quantization, fw_impl_mock, fw_info_mock, mocker): + random.seed(seed) + + num_nodes = random.randint(5, 8) + node_names = [f"n{i}" for i in range(num_nodes)] + nodes = [] + edges = [] + classes = [] + + # Build nodes with matching input/output shapes + input_shape = (None, random.randint(5, 10), random.randint(5, 10)) + for i, name in enumerate(node_names): + output_shape = (None, random.randint(5, 10), random.randint(5, 10)) if i < num_nodes - 1 else input_shape + layer_class = f"class_{i}" + node = build_node(name, layer_class=layer_class, qcs=[build_qc()], + input_shape=input_shape, output_shape=output_shape) + nodes.append(node) + classes.append(layer_class) + input_shape = output_shape + + for i in range(num_nodes - 1): + edges.append(Edge(nodes[i], nodes[i + 1], 0, 0)) + + # Generate random fused groups + fused_patterns = [] + fused_data = {} + i = 1 + while i < num_nodes - 1: + if random.random() < 0.5: + fuse_len = random.choice([2, 3]) + if i + fuse_len <= num_nodes: + fused = tuple(nodes[j] for j in range(i, i + fuse_len)) + fused_name = f"FusedNode_{'_'.join(n.name for n in fused)}" + fused_patterns.append([n.layer_class for n in fused]) + fused_data[fused_name] = fused + i += fuse_len + else: + break + else: + i += 1 + + fusing_info = FusingInfo(fusing_patterns=fused_patterns, fusing_data=fused_data) + graph = Graph("g", input_nodes=[nodes[0]], nodes=nodes, + output_nodes=[OutTensor(node=nodes[-1], node_out_index=0)], edge_list=edges) + graph.fusing_info = fusing_info + + if disable_quantization: + graph.disable_fused_nodes_activation_quantization() + + graph.find_node_by_name = MethodType(Graph.find_node_by_name, graph) + + ru_calc = ResourceUtilizationCalculator(graph, fw_impl_mock, fw_info_mock) + + # Patch max cut computation + mocker.patch( + 'model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.' + 'resource_utilization_calculator.compute_graph_max_cut', + wraps=compute_graph_max_cut + ) + + cuts = ru_calc.cuts + + # --- Assert cut structure --- + assert all(isinstance(c, Cut) for c in cuts) + for cut_nodes in cuts.values(): + assert all(isinstance(n.name, str) for n in cut_nodes) + + # --- Utilization --- + total, per_cut, per_cut_per_node = ru_calc.compute_activation_utilization_by_cut( + target_criterion=TIC.AnyQuantized, bitwidth_mode=BM.QDefaultSP + ) + + # Structure checks + assert isinstance(per_cut, dict) + assert isinstance(per_cut_per_node, dict) + assert all(isinstance(k, Cut) for k in per_cut) + assert all(isinstance(k, Cut) for k in per_cut_per_node) + assert all(isinstance(v, Utilization) for v in per_cut.values()) + assert all(isinstance(vv, Utilization) for v in per_cut_per_node.values() for vv in v.values()) + + # Value checks: per_cut == sum(per_cut_per_node) + for cut, node_utils in per_cut_per_node.items(): + summed = sum((u for u in node_utils.values()), Utilization(0, 0)) + assert per_cut[cut] == summed + + # Total check + assert total == max(u.bytes for u in per_cut.values()) + + # Check the utilization bytes per node + for cut, node_utils in per_cut_per_node.items(): + for node_name, util in node_utils.items(): + node = next((n for n in nodes if n.name == node_name), None) + assert node is not None, f"Node {node_name} not found in graph" + + expected_volume = 1 + for dim in node.output_shape: + if dim is not None: + expected_volume *= dim + + assert util.bytes == expected_volume, ( + f"Utilization mismatch for node '{node_name}': " + f"got {util.bytes}, expected {expected_volume} from shape {node.output_shape}" + ) + def test_get_cut_target_nodes(self, prepare_compute_cuts): ru_calc, (cut1, cut2, cut3, cut4), (mp_reuse, mp, noq, sp, mp2, qp) = prepare_compute_cuts assert len(TIC) == 4