From dd561e9cd5dc1558eef5c8e4a7eb3f7cdb4b0365 Mon Sep 17 00:00:00 2001 From: irenab Date: Tue, 11 Mar 2025 17:10:03 +0200 Subject: [PATCH 1/2] max cut - deterministic cuts --- .../core/common/graph/memory_graph/cut.py | 4 ++++ .../common/graph/memory_graph/max_cut_astar.py | 3 ++- .../function_tests/test_graph_max_cut.py | 15 +++++++++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/model_compression_toolkit/core/common/graph/memory_graph/cut.py b/model_compression_toolkit/core/common/graph/memory_graph/cut.py index bd21f502b..ee307e687 100644 --- a/model_compression_toolkit/core/common/graph/memory_graph/cut.py +++ b/model_compression_toolkit/core/common/graph/memory_graph/cut.py @@ -71,3 +71,7 @@ def __hash__(self): def __repr__(self): return f"" # pragma: no cover + + def get_sorted_node_names(self): + """ Return sorted node names of memory elements. """ + return sorted([e.node_name for e in self.mem_elements.elements]) diff --git a/model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py b/model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py index 2c121fd4c..25ae58bbb 100644 --- a/model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +++ b/model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py @@ -232,7 +232,8 @@ def _get_cut_to_expand(self, open_list: Set[Cut], costs: Dict[Cut, float], route max_cut_len = max([len(routes[c]) for c in open_list]) ordered_cuts_list = sorted(open_list, key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate)), - max_cut_len - len(routes[c]))) + max_cut_len - len(routes[c]), + ''.join(c.get_sorted_node_names()))) assert len(ordered_cuts_list) > 0 return ordered_cuts_list[0] diff --git a/tests/keras_tests/function_tests/test_graph_max_cut.py b/tests/keras_tests/function_tests/test_graph_max_cut.py index 780b2e3b4..43a0626a0 100644 --- a/tests/keras_tests/function_tests/test_graph_max_cut.py +++ b/tests/keras_tests/function_tests/test_graph_max_cut.py @@ -139,3 +139,18 @@ def test_graph_max_cut_plain_graph_real_model(self): self.assertIsNotNone(cuts) self.assertTrue(len(cuts) > 0) self.assertTrue(max_cut_size >= memory_graph.memory_lbound_single_op) + + def test_graph_max_cut_deterministic_order(self): + input_shape = (8, 8, 3) + model = complex_model(input_shape) + graph = model_reader(model) + + solutions = [compute_graph_max_cut(MemoryGraph(graph)) for _ in range(10)] + + schedules, max_cut_sizes, cuts_solutions = zip(*solutions) + assert len(set(max_cut_sizes)) == 1 + # nodes within each cut can be in different order, and cuts can be in different order inside cuts list, + # but overall the cuts should be identical between different runs + sorted_cuts_solutions = [sorted(cut.get_sorted_node_names() for cut in cuts) for cuts in cuts_solutions] + assert all(cuts == sorted_cuts_solutions[0] for cuts in sorted_cuts_solutions[1:]) + From 2d7d8249f7babdac969b5297bea429bb1be77d69 Mon Sep 17 00:00:00 2001 From: irenab Date: Tue, 11 Mar 2025 18:11:14 +0200 Subject: [PATCH 2/2] convert Cut into dataclass, add sorted names signature computed once --- .../core/common/graph/memory_graph/cut.py | 32 +++++++++++-------- .../graph/memory_graph/max_cut_astar.py | 2 +- .../function_tests/test_graph_max_cut.py | 3 +- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/model_compression_toolkit/core/common/graph/memory_graph/cut.py b/model_compression_toolkit/core/common/graph/memory_graph/cut.py index ee307e687..8d4b5bae8 100644 --- a/model_compression_toolkit/core/common/graph/memory_graph/cut.py +++ b/model_compression_toolkit/core/common/graph/memory_graph/cut.py @@ -12,28 +12,36 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from dataclasses import dataclass, field + from typing import List, Set from model_compression_toolkit.core.common import BaseNode from model_compression_toolkit.core.common.graph.memory_graph.memory_element import MemoryElements +@dataclass(frozen=True) class Cut: """ A Cut object that contains a set of ordered nodes and their memory elements. + + Args: + op_order: A list of the cut's nodes (model layers), ordered by their addition to the cut (first-to-last). + op_record: A (unordered) set of the nodes in the cut. + mem_elements: MemoryElements object which represents the activation tensors of the cut's nodes. """ + op_order: List[BaseNode] + op_record: Set[BaseNode] + mem_elements: MemoryElements - def __init__(self, op_order: List[BaseNode], op_record: Set[BaseNode], mem_elements: MemoryElements): - """ - Args: - op_order: A list of the cut's nodes (model layers), ordered by their addition to the cut (first-to-last). - op_record: A (unordered) set of the nodes in the cut. - mem_elements: MemoryElements object which represents the activation tensors of the cut's nodes. - """ + _sorted_elements_signature: str = field(init=False, default=None) - self.op_order = op_order - self.op_record = op_record - self.mem_elements = mem_elements + @property + def sorted_elements_signature(self): + if self._sorted_elements_signature is None: + object.__setattr__(self, '_sorted_elements_signature', + '_'.join(sorted([e.node_name for e in self.mem_elements.elements]))) + return self._sorted_elements_signature def memory_size(self) -> float: """ @@ -71,7 +79,3 @@ def __hash__(self): def __repr__(self): return f"" # pragma: no cover - - def get_sorted_node_names(self): - """ Return sorted node names of memory elements. """ - return sorted([e.node_name for e in self.mem_elements.elements]) diff --git a/model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py b/model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py index 25ae58bbb..a346c0043 100644 --- a/model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +++ b/model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py @@ -233,7 +233,7 @@ def _get_cut_to_expand(self, open_list: Set[Cut], costs: Dict[Cut, float], route ordered_cuts_list = sorted(open_list, key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate)), max_cut_len - len(routes[c]), - ''.join(c.get_sorted_node_names()))) + c.sorted_elements_signature)) assert len(ordered_cuts_list) > 0 return ordered_cuts_list[0] diff --git a/tests/keras_tests/function_tests/test_graph_max_cut.py b/tests/keras_tests/function_tests/test_graph_max_cut.py index 43a0626a0..e8a49a8a5 100644 --- a/tests/keras_tests/function_tests/test_graph_max_cut.py +++ b/tests/keras_tests/function_tests/test_graph_max_cut.py @@ -151,6 +151,5 @@ def test_graph_max_cut_deterministic_order(self): assert len(set(max_cut_sizes)) == 1 # nodes within each cut can be in different order, and cuts can be in different order inside cuts list, # but overall the cuts should be identical between different runs - sorted_cuts_solutions = [sorted(cut.get_sorted_node_names() for cut in cuts) for cuts in cuts_solutions] + sorted_cuts_solutions = [sorted(cut.sorted_elements_signature for cut in cuts) for cuts in cuts_solutions] assert all(cuts == sorted_cuts_solutions[0] for cuts in sorted_cuts_solutions[1:]) -