Skip to content

Commit a4bddad

Browse files
authored
Deterministic max cut (#1383)
* max cut - deterministic cuts * convert Cut into dataclass, add sorted names signature computed once
1 parent b816b3c commit a4bddad

File tree

3 files changed

+36
-13
lines changed

3 files changed

+36
-13
lines changed

model_compression_toolkit/core/common/graph/memory_graph/cut.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,36 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15+
from dataclasses import dataclass, field
16+
1517
from typing import List, Set
1618

1719
from model_compression_toolkit.core.common import BaseNode
1820
from model_compression_toolkit.core.common.graph.memory_graph.memory_element import MemoryElements
1921

2022

23+
@dataclass(frozen=True)
2124
class Cut:
2225
"""
2326
A Cut object that contains a set of ordered nodes and their memory elements.
24-
"""
2527
26-
def __init__(self, op_order: List[BaseNode], op_record: Set[BaseNode], mem_elements: MemoryElements):
27-
"""
28-
Args:
29-
op_order: A list of the cut's nodes (model layers), ordered by their addition to the cut (first-to-last).
30-
op_record: A (unordered) set of the nodes in the cut.
31-
mem_elements: MemoryElements object which represents the activation tensors of the cut's nodes.
32-
"""
33-
34-
self.op_order = op_order
35-
self.op_record = op_record
36-
self.mem_elements = mem_elements
28+
Args:
29+
op_order: A list of the cut's nodes (model layers), ordered by their addition to the cut (first-to-last).
30+
op_record: A (unordered) set of the nodes in the cut.
31+
mem_elements: MemoryElements object which represents the activation tensors of the cut's nodes.
32+
"""
33+
op_order: List[BaseNode]
34+
op_record: Set[BaseNode]
35+
mem_elements: MemoryElements
36+
37+
_sorted_elements_signature: str = field(init=False, default=None)
38+
39+
@property
40+
def sorted_elements_signature(self):
41+
if self._sorted_elements_signature is None:
42+
object.__setattr__(self, '_sorted_elements_signature',
43+
'_'.join(sorted([e.node_name for e in self.mem_elements.elements])))
44+
return self._sorted_elements_signature
3745

3846
def memory_size(self) -> float:
3947
"""

model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,8 @@ def _get_cut_to_expand(self, open_list: Set[Cut], costs: Dict[Cut, float], route
232232
max_cut_len = max([len(routes[c]) for c in open_list])
233233
ordered_cuts_list = sorted(open_list,
234234
key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate)),
235-
max_cut_len - len(routes[c])))
235+
max_cut_len - len(routes[c]),
236+
c.sorted_elements_signature))
236237

237238
assert len(ordered_cuts_list) > 0
238239
return ordered_cuts_list[0]

tests/keras_tests/function_tests/test_graph_max_cut.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,17 @@ def test_graph_max_cut_plain_graph_real_model(self):
139139
self.assertIsNotNone(cuts)
140140
self.assertTrue(len(cuts) > 0)
141141
self.assertTrue(max_cut_size >= memory_graph.memory_lbound_single_op)
142+
143+
def test_graph_max_cut_deterministic_order(self):
144+
input_shape = (8, 8, 3)
145+
model = complex_model(input_shape)
146+
graph = model_reader(model)
147+
148+
solutions = [compute_graph_max_cut(MemoryGraph(graph)) for _ in range(10)]
149+
150+
schedules, max_cut_sizes, cuts_solutions = zip(*solutions)
151+
assert len(set(max_cut_sizes)) == 1
152+
# nodes within each cut can be in different order, and cuts can be in different order inside cuts list,
153+
# but overall the cuts should be identical between different runs
154+
sorted_cuts_solutions = [sorted(cut.sorted_elements_signature for cut in cuts) for cuts in cuts_solutions]
155+
assert all(cuts == sorted_cuts_solutions[0] for cuts in sorted_cuts_solutions[1:])

0 commit comments

Comments
 (0)