Skip to content

Commit 2d7d824

Browse files
irenabirenab
authored andcommitted
convert Cut into dataclass, add sorted names signature computed once
1 parent dd561e9 commit 2d7d824

File tree

3 files changed

+20
-17
lines changed

3 files changed

+20
-17
lines changed

model_compression_toolkit/core/common/graph/memory_graph/cut.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,36 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15+
from dataclasses import dataclass, field
16+
1517
from typing import List, Set
1618

1719
from model_compression_toolkit.core.common import BaseNode
1820
from model_compression_toolkit.core.common.graph.memory_graph.memory_element import MemoryElements
1921

2022

23+
@dataclass(frozen=True)
2124
class Cut:
2225
"""
2326
A Cut object that contains a set of ordered nodes and their memory elements.
27+
28+
Args:
29+
op_order: A list of the cut's nodes (model layers), ordered by their addition to the cut (first-to-last).
30+
op_record: A (unordered) set of the nodes in the cut.
31+
mem_elements: MemoryElements object which represents the activation tensors of the cut's nodes.
2432
"""
33+
op_order: List[BaseNode]
34+
op_record: Set[BaseNode]
35+
mem_elements: MemoryElements
2536

26-
def __init__(self, op_order: List[BaseNode], op_record: Set[BaseNode], mem_elements: MemoryElements):
27-
"""
28-
Args:
29-
op_order: A list of the cut's nodes (model layers), ordered by their addition to the cut (first-to-last).
30-
op_record: A (unordered) set of the nodes in the cut.
31-
mem_elements: MemoryElements object which represents the activation tensors of the cut's nodes.
32-
"""
37+
_sorted_elements_signature: str = field(init=False, default=None)
3338

34-
self.op_order = op_order
35-
self.op_record = op_record
36-
self.mem_elements = mem_elements
39+
@property
40+
def sorted_elements_signature(self):
41+
if self._sorted_elements_signature is None:
42+
object.__setattr__(self, '_sorted_elements_signature',
43+
'_'.join(sorted([e.node_name for e in self.mem_elements.elements])))
44+
return self._sorted_elements_signature
3745

3846
def memory_size(self) -> float:
3947
"""
@@ -71,7 +79,3 @@ def __hash__(self):
7179

7280
def __repr__(self):
7381
return f"<Cut: Nodes={[e.node_name for e in self.mem_elements.elements]}, size={self.memory_size()}>" # pragma: no cover
74-
75-
def get_sorted_node_names(self):
76-
""" Return sorted node names of memory elements. """
77-
return sorted([e.node_name for e in self.mem_elements.elements])

model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def _get_cut_to_expand(self, open_list: Set[Cut], costs: Dict[Cut, float], route
233233
ordered_cuts_list = sorted(open_list,
234234
key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate)),
235235
max_cut_len - len(routes[c]),
236-
''.join(c.get_sorted_node_names())))
236+
c.sorted_elements_signature))
237237

238238
assert len(ordered_cuts_list) > 0
239239
return ordered_cuts_list[0]

tests/keras_tests/function_tests/test_graph_max_cut.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,5 @@ def test_graph_max_cut_deterministic_order(self):
151151
assert len(set(max_cut_sizes)) == 1
152152
# nodes within each cut can be in different order, and cuts can be in different order inside cuts list,
153153
# but overall the cuts should be identical between different runs
154-
sorted_cuts_solutions = [sorted(cut.get_sorted_node_names() for cut in cuts) for cuts in cuts_solutions]
154+
sorted_cuts_solutions = [sorted(cut.sorted_elements_signature for cut in cuts) for cuts in cuts_solutions]
155155
assert all(cuts == sorted_cuts_solutions[0] for cuts in sorted_cuts_solutions[1:])
156-

0 commit comments

Comments
 (0)