Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ def __init__(self,
self.wrapper = wrapper
self.get_activation_quantizer_holder = get_activation_quantizer_holder_fn
self.reuse_groups = {}
self._add_modules()
self._reused_nodes = []

self._add_all_modules()

# todo: Move to parent class BaseModelBuilder
@property
Expand Down Expand Up @@ -286,17 +288,37 @@ def wrap(self, node):
node_op = self.wrapper(node, node_builder(node))
return node_op

def _add_modules(self):
def _add_all_modules(self):
"""
Build and add the modules and functional nodes from node_sort as attributes of the PytorchModel.

To ensure every module required by a reused node already exists, this runs two passes:
first all non-reused nodes are added (populating self.reuse_groups and collecting reused
nodes into self._reused_nodes), then the reused nodes are added from that list.
"""
self._add_modules(reused_nodes_only=False) # pass 1: non-reused nodes (fills reuse_groups)
self._add_modules(reused_nodes_only=True) # pass 2: reused nodes (look up their group's module)

def _add_modules(self, reused_nodes_only=False):
"""
Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel
Args:
reused_nodes_only: whether to go over the reuse nodes list or not.
In case reuse_nodes_only is False - will go over all nodes, and add reused nodes to self._reused_nodes
In case reuse_nodes_only is True - will go over self._reused_nodes only.

"""
for node in self.node_sort:
if node.reuse:
# If the node is reused, retrieve the original module
nodes = self._reused_nodes if reused_nodes_only else self.node_sort
for node in nodes:
if node.reuse and reused_nodes_only:
if node.reuse_group not in self.reuse_groups:
Logger.critical(f"Reuse group {node.reuse_group} not found for node {node.name}")
raise Exception(f"Reuse group {node.reuse_group} not found for node {node.name}. "
f"Make sure you first call the method with reused_nodes_only=False")
else:
node_op = self.reuse_groups[node.reuse_group] # retrieve the original module

elif node.reuse: # add node to reused list, and go over the list after all other nodes were created
self._reused_nodes.append(node)
continue

node_op = self.reuse_groups[node.reuse_group]
else:
# If it's not reused, create a new module
node_op = self.wrap(node)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import re

import pytest
import torch
from torch import nn
import torch.testing as ptt

from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PytorchModel
from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
from tests_pytest.pytorch_tests.torch_test_util.torch_test_mixin import BaseTorchIntegrationTest


def get_data_generator():
    """Return a data-generator callable producing one random input batch.

    Calling the returned function gives a generator that yields a single list
    holding one random tensor of shape (1, 3, 5, 5) — the input shape the test
    models expect.
    """
    def _single_batch_gen():
        batch = [torch.rand(1, 3, 5, 5)]
        yield batch

    return _single_batch_gen


def get_model_with_reused_weights():
    """Build a small model whose forward pass applies the same conv module twice.

    Calling ``shared_conv`` twice makes the second application a reused node in
    the MCT graph. ``another_conv``, ``fc1`` and ``fc2`` are registered but never
    called in ``forward``; they only add extra modules to the network.
    """
    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.shared_conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
            self.another_conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
            self.fc1 = nn.Linear(8 * 3 * 3, 10)
            self.fc2 = nn.Linear(8 * 3 * 3, 10)

        def forward(self, x):
            # Same module applied twice -> the second call becomes a reused node.
            first = self.shared_conv(x)   # (1, 8, 3, 3) for a (1, 3, 5, 5) input
            second = self.shared_conv(x)  # identical module, identical output
            return first + second

    return Model()


class TestWeightsReuse(BaseTorchIntegrationTest):
    """
    Validate that reused nodes are only instantiated after the module of their
    reuse group exists.

    The check is indirect: rebuilding a PytorchModel from the graph (in
    particular its _add_all_modules method) must succeed without errors,
    regardless of the topological order of the graph's nodes.
    """

    def _build_pytorch_model(self, tpc):
        """Prepare a graph for the shared-weights model and wrap it in a PytorchModel."""
        graph = self.run_graph_preparation(model=get_model_with_reused_weights(),
                                           datagen=get_data_generator(),
                                           tpc=tpc)
        return PytorchModel(graph=graph)

    def test_weights_reuse_toposort(self, minimal_tpc):
        """
        Build the model with nodes in topological order and check that exactly
        one reused node was collected during construction.
        """
        pytorch_model = self._build_pytorch_model(minimal_tpc)
        assert len(pytorch_model._reused_nodes) == 1

    def test_weights_reuse_reversed_toposort(self, minimal_tpc):
        """
        Rebuild the model with nodes in reversed topological order and check the
        construction still succeeds with exactly one reused node collected.
        """
        pytorch_model = self._build_pytorch_model(minimal_tpc)

        # Re-run module creation with the node order reversed.
        pytorch_model.node_sort.reverse()
        pytorch_model.reuse_groups = {}
        pytorch_model._reused_nodes = []
        pytorch_model._add_all_modules()
        assert len(pytorch_model._reused_nodes) == 1

    def test_reused_only_initialization(self, minimal_tpc):
        """
        Check that initializing reused nodes before the non-reused ones raises
        an informative exception.
        """
        pytorch_model = self._build_pytorch_model(minimal_tpc)
        pytorch_model.reuse_groups = {}
        reused_node = pytorch_model._reused_nodes[0]
        expected_msg = (f"Reuse group {reused_node.reuse_group} not found for node "
                        f"{reused_node.name}. Make sure you first call the method with "
                        f"reused_nodes_only=False")
        with pytest.raises(Exception, match=re.escape(expected_msg)):
            pytorch_model._add_modules(reused_nodes_only=True)