Skip to content

Commit 8649104

Browse files
irenabirenab
authored andcommitted
convert Cut into dataclass, add sorted names signature computed once
1 parent dd561e9 commit 8649104

File tree

4 files changed

+27
-29
lines changed

4 files changed

+27
-29
lines changed

model_compression_toolkit/core/common/graph/memory_graph/cut.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,41 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15+
from dataclasses import dataclass, field
16+
1517
from typing import List, Set
1618

1719
from model_compression_toolkit.core.common import BaseNode
1820
from model_compression_toolkit.core.common.graph.memory_graph.memory_element import MemoryElements
1921

2022

23+
@dataclass(frozen=True)
2124
class Cut:
2225
"""
2326
A Cut object that contains a set of ordered nodes and their memory elements.
27+
28+
Args:
29+
op_order: A list of the cut's nodes (model layers), ordered by their addition to the cut (first-to-last).
30+
op_record: A (unordered) set of the nodes in the cut.
31+
mem_elements: MemoryElements object which represents the activation tensors of the cut's nodes.
2432
"""
33+
op_order: List[BaseNode]
34+
op_record: Set[BaseNode]
35+
mem_elements: MemoryElements
2536

26-
def __init__(self, op_order: List[BaseNode], op_record: Set[BaseNode], mem_elements: MemoryElements):
27-
"""
28-
Args:
29-
op_order: A list of the cut's nodes (model layers), ordered by their addition to the cut (first-to-last).
30-
op_record: A (unordered) set of the nodes in the cut.
31-
mem_elements: MemoryElements object which represents the activation tensors of the cut's nodes.
32-
"""
37+
records_names: Set[str] = field(init=False)
38+
_sorted_elements_signature: str = field(init=False, default=None)
39+
40+
def __post_init__(self):
41+
# set frozen attributes
42+
object.__setattr__(self, 'records_names', {op.name for op in self.op_record})
3343

34-
self.op_order = op_order
35-
self.op_record = op_record
36-
self.mem_elements = mem_elements
44+
@property
45+
def sorted_elements_signature(self):
46+
if self._sorted_elements_signature is None:
47+
object.__setattr__(self, '_sorted_elements_signature',
48+
'_'.join(sorted([e.node_name for e in self.mem_elements.elements])))
49+
return self._sorted_elements_signature
3750

3851
def memory_size(self) -> float:
3952
"""
@@ -42,15 +55,6 @@ def memory_size(self) -> float:
4255

4356
return self.mem_elements.total_size
4457

45-
def get_record_names(self) -> Set[str]:
46-
"""
47-
Builds a set of the cut nodes' names.
48-
49-
Returns: a set with the nodes' names.
50-
"""
51-
52-
return {op.name for op in self.op_record}
53-
5458
def __eq__(self, other) -> bool:
5559
"""
5660
Overrides the class equality method.
@@ -71,7 +75,3 @@ def __hash__(self):
7175

7276
def __repr__(self):
7377
return f"<Cut: Nodes={[e.node_name for e in self.mem_elements.elements]}, size={self.memory_size()}>" # pragma: no cover
74-
75-
def get_sorted_node_names(self):
76-
""" Return sorted node names of memory elements. """
77-
return sorted([e.node_name for e in self.mem_elements.elements])

model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def _get_cut_to_expand(self, open_list: Set[Cut], costs: Dict[Cut, float], route
233233
ordered_cuts_list = sorted(open_list,
234234
key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate)),
235235
max_cut_len - len(routes[c]),
236-
''.join(c.get_sorted_node_names())))
236+
c.sorted_elements_signature))
237237

238238
assert len(ordered_cuts_list) > 0
239239
return ordered_cuts_list[0]
@@ -250,8 +250,7 @@ def clean_memory_for_next_step(self, cut: Cut) -> Cut:
250250
251251
"""
252252

253-
cut_records_names = cut.get_record_names()
254-
filtered_memory_elements = set(filter(lambda elm: not all(child.name in cut_records_names for child in
253+
filtered_memory_elements = set(filter(lambda elm: not all(child.name in cut.records_names for child in
255254
self.memory_graph.activation_tensor_children(elm)),
256255
cut.mem_elements.elements))
257256

model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def compute_resource_utilization_data(in_model: Any,
6868
fw_impl,
6969
fqc,
7070
bit_width_config=core_config.bit_width_config,
71-
mixed_precision_enable=mixed_precision_enable,
71+
mixed_precision_enable=True,
7272
running_gptq=False)
7373

7474
ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl, fw_info)

tests/keras_tests/function_tests/test_graph_max_cut.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,5 @@ def test_graph_max_cut_deterministic_order(self):
151151
assert len(set(max_cut_sizes)) == 1
152152
# nodes within each cut can be in different order, and cuts can be in different order inside cuts list,
153153
# but overall the cuts should be identical between different runs
154-
sorted_cuts_solutions = [sorted(cut.get_sorted_node_names() for cut in cuts) for cuts in cuts_solutions]
154+
sorted_cuts_solutions = [sorted(cut.sorted_elements_signature for cut in cuts) for cuts in cuts_solutions]
155155
assert all(cuts == sorted_cuts_solutions[0] for cuts in sorted_cuts_solutions[1:])
156-

0 commit comments

Comments
 (0)