|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | # ============================================================================== |
| 15 | +from dataclasses import dataclass, field |
| 16 | + |
15 | 17 | from typing import List, Set |
16 | 18 |
|
17 | 19 | from model_compression_toolkit.core.common import BaseNode |
18 | 20 | from model_compression_toolkit.core.common.graph.memory_graph.memory_element import MemoryElements |
19 | 21 |
|
20 | 22 |
|
| 23 | +@dataclass(frozen=True) |
21 | 24 | class Cut: |
22 | 25 | """ |
23 | 26 | A Cut object that contains a set of ordered nodes and their memory elements. |
| 27 | +
|
| 28 | + Args: |
| 29 | + op_order: A list of the cut's nodes (model layers), ordered by their addition to the cut (first-to-last). |
| 30 | + op_record: A (unordered) set of the nodes in the cut. |
| 31 | + mem_elements: MemoryElements object which represents the activation tensors of the cut's nodes. |
24 | 32 | """ |
| 33 | + op_order: List[BaseNode] |
| 34 | + op_record: Set[BaseNode] |
| 35 | + mem_elements: MemoryElements |
25 | 36 |
|
26 | | - def __init__(self, op_order: List[BaseNode], op_record: Set[BaseNode], mem_elements: MemoryElements): |
27 | | - """ |
28 | | - Args: |
29 | | - op_order: A list of the cut's nodes (model layers), ordered by their addition to the cut (first-to-last). |
30 | | - op_record: A (unordered) set of the nodes in the cut. |
31 | | - mem_elements: MemoryElements object which represents the activation tensors of the cut's nodes. |
32 | | - """ |
| 37 | + _sorted_elements_signature: str = field(init=False, default=None) |
33 | 38 |
|
34 | | - self.op_order = op_order |
35 | | - self.op_record = op_record |
36 | | - self.mem_elements = mem_elements |
| 39 | + @property |
| 40 | + def sorted_elements_signature(self): |
| 41 | + if self._sorted_elements_signature is None: |
| 42 | + object.__setattr__(self, '_sorted_elements_signature', |
| 43 | + '_'.join(sorted([e.node_name for e in self.mem_elements.elements]))) |
| 44 | + return self._sorted_elements_signature |
37 | 45 |
|
38 | 46 | def memory_size(self) -> float: |
39 | 47 | """ |
@@ -71,7 +79,3 @@ def __hash__(self): |
71 | 79 |
|
72 | 80 | def __repr__(self): |
73 | 81 | return f"<Cut: Nodes={[e.node_name for e in self.mem_elements.elements]}, size={self.memory_size()}>" # pragma: no cover |
74 | | - |
75 | | - def get_sorted_node_names(self): |
76 | | - """ Return sorted node names of memory elements. """ |
77 | | - return sorted([e.node_name for e in self.mem_elements.elements]) |
|
0 commit comments