diff --git a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py index 2ba1986ef..6f873abc9 100644 --- a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +++ b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py @@ -234,7 +234,9 @@ def __init__(self, self.wrapper = wrapper self.get_activation_quantizer_holder = get_activation_quantizer_holder_fn self.reuse_groups = {} - self._add_modules() + self._reused_nodes = [] + + self._add_all_modules() # todo: Move to parent class BaseModelBuilder @property @@ -286,17 +288,37 @@ def wrap(self, node): node_op = self.wrapper(node, node_builder(node)) return node_op - def _add_modules(self): + def _add_all_modules(self): + """ + Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel. + To assure all required nodes for the reused nodes are already initialized, adds none-reused nodes first, + then adds the reused nodes. + """ + self._add_modules(reused_nodes_only=False) # add none-reused nodes + self._add_modules(reused_nodes_only=True) # add reused nodes + + def _add_modules(self, reused_nodes_only=False): """ Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel + Args: + reused_nodes_only: whether to go over the reuse nodes list or not. + In case reuse_nodes_only is False - will go over all nodes, and add reused nodes to self._reused_nodes + In case reuse_nodes_only is True - will go over self._reused_nodes only. 
+ """ - for node in self.node_sort: - if node.reuse: - # If the node is reused, retrieve the original module + nodes = self._reused_nodes if reused_nodes_only else self.node_sort + for node in nodes: + if node.reuse and reused_nodes_only: if node.reuse_group not in self.reuse_groups: - Logger.critical(f"Reuse group {node.reuse_group} not found for node {node.name}") + raise Exception(f"Reuse group {node.reuse_group} not found for node {node.name}. " + f"Make sure you first call the method with reused_nodes_only=False") + else: + node_op = self.reuse_groups[node.reuse_group] # retrieve the original module + + elif node.reuse: # add node to reused list, and go over the list after all other nodes were created + self._reused_nodes.append(node) + continue - node_op = self.reuse_groups[node.reuse_group] else: # If it's not reused, create a new module node_op = self.wrap(node) diff --git a/tests_pytest/pytorch_tests/integration_tests/core/back2framework/__init__.py b/tests_pytest/pytorch_tests/integration_tests/core/back2framework/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_pytest/pytorch_tests/integration_tests/core/back2framework/test_weights_reuse.py b/tests_pytest/pytorch_tests/integration_tests/core/back2framework/test_weights_reuse.py new file mode 100644 index 000000000..e82fca631 --- /dev/null +++ b/tests_pytest/pytorch_tests/integration_tests/core/back2framework/test_weights_reuse.py @@ -0,0 +1,101 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import re

import pytest
import torch
from torch import nn

from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PytorchModel
from tests_pytest.pytorch_tests.torch_test_util.torch_test_mixin import BaseTorchIntegrationTest


def get_data_generator():
    """Return a data generator that yields a single random (1, 3, 5, 5) input batch."""
    def data_gen():
        yield [torch.rand(1, 3, 5, 5)]
    return data_gen


def get_model_with_reused_weights():
    """
    Build a model whose forward applies the same conv module twice, so the second
    occurrence becomes a reused node in the MCT graph.
    """
    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.shared_conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

        def forward(self, x):
            # Both calls route through the same module instance -> weight reuse.
            x1 = self.shared_conv(x)  # shape: (batch_size, 8, 3, 3) for a (1, 3, 5, 5) input
            x2 = self.shared_conv(x)  # shape: (batch_size, 8, 3, 3) for a (1, 3, 5, 5) input
            return x1 + x2

    return Model()


class TestWeightsReuse(BaseTorchIntegrationTest):
    """
    Test that reused nodes are always initiated after their reuse-group's original node
    was initiated.
    We test it by validating that building the PyTorch model back from the graph succeeds
    with no errors (specifically the _add_all_modules method), independently of the graph
    nodes' topological order.
    """

    def test_weights_reuse_toposort(self, minimal_tpc):
        """
        Building the model with nodes in topological order succeeds, and exactly one
        reused node is collected for deferred initialization.
        """
        model = get_model_with_reused_weights()
        data_generator = get_data_generator()
        graph = self.run_graph_preparation(model=model, datagen=data_generator, tpc=minimal_tpc)
        pytorch_model = PytorchModel(graph=graph)
        assert len(pytorch_model._reused_nodes) == 1

    def test_weights_reuse_reversed_toposort(self, minimal_tpc):
        """
        Building the model with nodes in reversed topological order still succeeds:
        the reused node is deferred to the second pass, and exactly one is collected.
        """
        model = get_model_with_reused_weights()
        data_generator = get_data_generator()
        graph = self.run_graph_preparation(model=model, datagen=data_generator, tpc=minimal_tpc)
        pytorch_model = PytorchModel(graph=graph)

        # Rebuild from scratch with the node order reversed to stress the two-pass logic.
        pytorch_model.node_sort.reverse()
        pytorch_model.reuse_groups = {}
        pytorch_model._reused_nodes = []
        pytorch_model._add_all_modules()
        assert len(pytorch_model._reused_nodes) == 1

    def test_reused_only_initialization(self, minimal_tpc):
        """
        Running the reused-only pass before the non-reused pass (empty reuse groups)
        must raise an exception.
        """
        model = get_model_with_reused_weights()
        data_generator = get_data_generator()
        graph = self.run_graph_preparation(model=model, datagen=data_generator, tpc=minimal_tpc)
        pytorch_model = PytorchModel(graph=graph)
        pytorch_model.reuse_groups = {}
        reused_node = pytorch_model._reused_nodes[0]
        with pytest.raises(Exception, match=re.escape(f"Reuse group {reused_node.reuse_group} not found for node "
                                                      f"{reused_node.name}. Make sure you first call the method with "
                                                      f"reused_nodes_only=False")):
            pytorch_model._add_modules(reused_nodes_only=True)