1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414# ==============================================================================
15+ from dataclasses import dataclass , field
16+
1517from typing import List , Set
1618
1719from model_compression_toolkit .core .common import BaseNode
1820from model_compression_toolkit .core .common .graph .memory_graph .memory_element import MemoryElements
1921
2022
23+ @dataclass (frozen = True )
2124class Cut :
2225 """
2326 A Cut object that contains a set of ordered nodes and their memory elements.
27+
28+ Args:
29+ op_order: A list of the cut's nodes (model layers), ordered by their addition to the cut (first-to-last).
30+ op_record: A (unordered) set of the nodes in the cut.
31+ mem_elements: MemoryElements object which represents the activation tensors of the cut's nodes.
2432 """
33+ op_order : List [BaseNode ]
34+ op_record : Set [BaseNode ]
35+ mem_elements : MemoryElements
2536
26- def __init__ (self , op_order : List [BaseNode ], op_record : Set [BaseNode ], mem_elements : MemoryElements ):
27- """
28- Args:
29- op_order: A list of the cut's nodes (model layers), ordered by their addition to the cut (first-to-last).
30- op_record: A (unordered) set of the nodes in the cut.
31- mem_elements: MemoryElements object which represents the activation tensors of the cut's nodes.
32- """
37+ records_names : Set [str ] = field (init = False )
38+ _sorted_elements_signature : str = field (init = False , default = None )
39+
40+ def __post_init__ (self ):
41+ # set frozen attributes
42+ object .__setattr__ (self , 'records_names' , {op .name for op in self .op_record })
3343
34- self .op_order = op_order
35- self .op_record = op_record
36- self .mem_elements = mem_elements
44+ @property
45+ def sorted_elements_signature (self ):
46+ if self ._sorted_elements_signature is None :
47+ object .__setattr__ (self , '_sorted_elements_signature' ,
48+ '_' .join (sorted ([e .node_name for e in self .mem_elements .elements ])))
49+ return self ._sorted_elements_signature
3750
3851 def memory_size (self ) -> float :
3952 """
@@ -42,15 +55,6 @@ def memory_size(self) -> float:
4255
4356 return self .mem_elements .total_size
4457
45- def get_record_names (self ) -> Set [str ]:
46- """
47- Builds a set of the cut nodes' names.
48-
49- Returns: a set with the nodes' names.
50- """
51-
52- return {op .name for op in self .op_record }
53-
5458 def __eq__ (self , other ) -> bool :
5559 """
5660 Overrides the class equality method.
@@ -71,7 +75,3 @@ def __hash__(self):
7175
7276 def __repr__ (self ):
7377 return f"<Cut: Nodes={ [e .node_name for e in self .mem_elements .elements ]} , size={ self .memory_size ()} >" # pragma: no cover
74-
75- def get_sorted_node_names (self ):
76- """ Return sorted node names of memory elements. """
77- return sorted ([e .node_name for e in self .mem_elements .elements ])
0 commit comments