diff --git a/model_compression_toolkit/core/common/graph/memory_graph/cut.py b/model_compression_toolkit/core/common/graph/memory_graph/cut.py index bd21f502b..8d4b5bae8 100644 --- a/model_compression_toolkit/core/common/graph/memory_graph/cut.py +++ b/model_compression_toolkit/core/common/graph/memory_graph/cut.py @@ -12,28 +12,36 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from dataclasses import dataclass, field + from typing import List, Set from model_compression_toolkit.core.common import BaseNode from model_compression_toolkit.core.common.graph.memory_graph.memory_element import MemoryElements +@dataclass(frozen=True) class Cut: """ A Cut object that contains a set of ordered nodes and their memory elements. - """ - def __init__(self, op_order: List[BaseNode], op_record: Set[BaseNode], mem_elements: MemoryElements): - """ - Args: - op_order: A list of the cut's nodes (model layers), ordered by their addition to the cut (first-to-last). - op_record: A (unordered) set of the nodes in the cut. - mem_elements: MemoryElements object which represents the activation tensors of the cut's nodes. - """ - - self.op_order = op_order - self.op_record = op_record - self.mem_elements = mem_elements + Args: + op_order: A list of the cut's nodes (model layers), ordered by their addition to the cut (first-to-last). + op_record: A (unordered) set of the nodes in the cut. + mem_elements: MemoryElements object which represents the activation tensors of the cut's nodes. 
+ """ + op_order: List[BaseNode] + op_record: Set[BaseNode] + mem_elements: MemoryElements + + _sorted_elements_signature: str = field(init=False, default=None) + + @property + def sorted_elements_signature(self): + if self._sorted_elements_signature is None: + object.__setattr__(self, '_sorted_elements_signature', + '_'.join(sorted([e.node_name for e in self.mem_elements.elements]))) + return self._sorted_elements_signature def memory_size(self) -> float: """ diff --git a/model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py b/model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py index 2c121fd4c..a346c0043 100644 --- a/model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +++ b/model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py @@ -232,7 +232,8 @@ def _get_cut_to_expand(self, open_list: Set[Cut], costs: Dict[Cut, float], route max_cut_len = max([len(routes[c]) for c in open_list]) ordered_cuts_list = sorted(open_list, key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate)), - max_cut_len - len(routes[c]))) + max_cut_len - len(routes[c]), + c.sorted_elements_signature)) assert len(ordered_cuts_list) > 0 return ordered_cuts_list[0] diff --git a/tests/keras_tests/function_tests/test_graph_max_cut.py b/tests/keras_tests/function_tests/test_graph_max_cut.py index 780b2e3b4..e8a49a8a5 100644 --- a/tests/keras_tests/function_tests/test_graph_max_cut.py +++ b/tests/keras_tests/function_tests/test_graph_max_cut.py @@ -139,3 +139,17 @@ def test_graph_max_cut_plain_graph_real_model(self): self.assertIsNotNone(cuts) self.assertTrue(len(cuts) > 0) self.assertTrue(max_cut_size >= memory_graph.memory_lbound_single_op) + + def test_graph_max_cut_deterministic_order(self): + input_shape = (8, 8, 3) + model = complex_model(input_shape) + graph = model_reader(model) + + solutions = [compute_graph_max_cut(MemoryGraph(graph)) for _ in range(10)] + + schedules, max_cut_sizes, cuts_solutions 
= zip(*solutions) + assert len(set(max_cut_sizes)) == 1 + # nodes within each cut can be in different order, and cuts can be in different order inside the cuts list, + # but overall the cuts should be identical between different runs + sorted_cuts_solutions = [sorted(cut.sorted_elements_signature for cut in cuts) for cuts in cuts_solutions] + assert all(cuts == sorted_cuts_solutions[0] for cuts in sorted_cuts_solutions[1:])