Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ def __init__(self,
self.wrapper = wrapper
self.get_activation_quantizer_holder = get_activation_quantizer_holder_fn
self.reuse_groups = {}
self._add_modules()
self._reused_nodes = []

self._add_all_modules()

# todo: Move to parent class BaseModelBuilder
@property
Expand Down Expand Up @@ -286,16 +288,30 @@ def wrap(self, node):
node_op = self.wrapper(node, node_builder(node))
return node_op

def _add_modules(self):
def _add_all_modules(self):
"""
Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel
"""
for node in self.node_sort:
if node.reuse:
# If the node is reused, retrieve the original module
if node.reuse_group not in self.reuse_groups:
Logger.critical(f"Reuse group {node.reuse_group} not found for node {node.name}")
self._add_modules(reused_nodes_only=False)
self._add_modules(reused_nodes_only=True)

def _add_modules(self, reused_nodes_only=False):
"""
Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel
:param reuse_nodes_only: whether to go over the reuse nodes list or not.
In case reuse_nodes_only is False - will go over all nodes, and add reused nodes to self.reuse_nodes
In case reuse_nodes_only is True - will go over self._reused_nodes only.

"""
nodes = self._reused_nodes if reused_nodes_only else self.node_sort
for node in nodes:
if node.reuse: # If the node is reused, retrieve the original module
if node.reuse_group not in self.reuse_groups: # original module wasn't created yet
if reused_nodes_only:
Logger.critical(f"Reuse group {node.reuse_group} not found for node {node.name}")
else: # add node to reused list, and go over the list after all other nodes were created
self._reused_nodes.append(node)
continue
node_op = self.reuse_groups[node.reuse_group]
else:
# If it's not reused, create a new module
Expand Down