From 0c9c1ec1798a4b61c0b59f2d910641fb7df3429c Mon Sep 17 00:00:00 2001 From: yarden-sony Date: Tue, 15 Apr 2025 14:46:49 +0300 Subject: [PATCH 1/8] reuse second pass --- .../back2framework/pytorch_model_builder.py | 31 ++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py index 2ba1986ef..617cdf149 100644 --- a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +++ b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py @@ -234,7 +234,9 @@ def __init__(self, self.wrapper = wrapper self.get_activation_quantizer_holder = get_activation_quantizer_holder_fn self.reuse_groups = {} - self._add_modules() + self.reused_nodes = [] + + self._add_all_modules() # todo: Move to parent class BaseModelBuilder @property @@ -286,16 +288,30 @@ def wrap(self, node): node_op = self.wrapper(node, node_builder(node)) return node_op - def _add_modules(self): + def _add_all_modules(self): """ Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel """ - for node in self.node_sort: - if node.reuse: - # If the node is reused, retrieve the original module - if node.reuse_group not in self.reuse_groups: - Logger.critical(f"Reuse group {node.reuse_group} not found for node {node.name}") + self._add_modules(reused_nodes_only=False) + self._add_modules(reused_nodes_only=True) + def _add_modules(self, reused_nodes_only=False): + """ + Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel + :param reuse_nodes_only: whether to go over the reuse nodes list or not. + In case reuse_nodes_only is False - will go over all nodes, and add reused nodes to self.reuse_nodes + In case reuse_nodes_only is True - will go over self.reused_nodes only. + + """ + nodes = self.reused_nodes if reused_nodes_only else self.node_sort + for node in nodes: + if node.reuse: # If the node is reused, retrieve the original module + if node.reuse_group not in self.reuse_groups: # original module wasn't created yet + if reused_nodes_only: + Logger.critical(f"Reuse group {node.reuse_group} not found for node {node.name}") + else: # add node to reused list, and go over the list after all other nodes were created + self.reused_nodes.append(node) + continue node_op = self.reuse_groups[node.reuse_group] else: # If it's not reused, create a new module @@ -304,6 +320,7 @@ def _add_modules(self): # Store the module for future reuse self.reuse_groups[node.reuse_group] = node_op + if isinstance(node, FunctionalNode): # for functional layers setattr(self, node.name, node_op) From c2ed5907514e57d6a92b04e19b2a07683aa8a910 Mon Sep 17 00:00:00 2001 From: yarden-sony Date: Tue, 15 Apr 2025 14:50:29 +0300 Subject: [PATCH 2/8] delete empty line --- .../core/pytorch/back2framework/pytorch_model_builder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py index 617cdf149..454211b47 100644 --- a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +++ b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py @@ -320,7 +320,6 @@ def _add_modules(self, reused_nodes_only=False): # Store the module for future reuse self.reuse_groups[node.reuse_group] = node_op - if isinstance(node, FunctionalNode): # for functional layers setattr(self, node.name, node_op) From 2f63004a7374227a29c5c99e8f02945c64a5352b Mon Sep 17 00:00:00 2001 From: yarden-sony Date: Tue, 15 Apr 2025 16:06:26 +0300 Subject: [PATCH 3/8] rename field --- .../pytorch/back2framework/pytorch_model_builder.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py index 454211b47..fbb27730d 100644 --- a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +++ b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py @@ -234,7 +234,7 @@ def __init__(self, self.wrapper = wrapper self.get_activation_quantizer_holder = get_activation_quantizer_holder_fn self.reuse_groups = {} - self.reused_nodes = [] + self._reused_nodes = [] self._add_all_modules() @@ -242,7 +242,7 @@ def __init__(self, @property def use_activation_holder_during_model_building(self) -> bool: """ - Returns: Whether or not the model builder uses a PytorchActivationQuantizationHolder during + Returns: Whether the model builder uses a PytorchActivationQuantizationHolder during model building (by adding it as a module when converting the graph to a Pytorch model). If so - the model builder expects the activation quantizers not to be wrapped in a PytorchQuantizeWrapper. @@ -300,17 +300,17 @@ def _add_modules(self, reused_nodes_only=False): Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel :param reuse_nodes_only: whether to go over the reuse nodes list or not. In case reuse_nodes_only is False - will go over all nodes, and add reused nodes to self.reuse_nodes - In case reuse_nodes_only is True - will go over self.reused_nodes only. + In case reuse_nodes_only is True - will go over self._reused_nodes only. """ - nodes = self.reused_nodes if reused_nodes_only else self.node_sort + nodes = self._reused_nodes if reused_nodes_only else self.node_sort for node in nodes: if node.reuse: # If the node is reused, retrieve the original module if node.reuse_group not in self.reuse_groups: # original module wasn't created yet if reused_nodes_only: Logger.critical(f"Reuse group {node.reuse_group} not found for node {node.name}") else: # add node to reused list, and go over the list after all other nodes were created - self.reused_nodes.append(node) + self._reused_nodes.append(node) continue node_op = self.reuse_groups[node.reuse_group] else: From e6c1a976792256b9c12dfd20a8da748f97166899 Mon Sep 17 00:00:00 2001 From: yarden-sony Date: Tue, 15 Apr 2025 16:15:24 +0300 Subject: [PATCH 4/8] revert unrelated change --- .../core/pytorch/back2framework/pytorch_model_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py index fbb27730d..7de8fa649 100644 --- a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +++ b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py @@ -242,7 +242,7 @@ def __init__(self, @property def use_activation_holder_during_model_building(self) -> bool: """ - Returns: Whether the model builder uses a PytorchActivationQuantizationHolder during + Returns: Whether or not the model builder uses a PytorchActivationQuantizationHolder during model building (by adding it as a module when converting the graph to a Pytorch model). If so - the model builder expects the activation quantizers not to be wrapped in a PytorchQuantizeWrapper. From 67560a2c7d0c83ab4f62842294aba883ed8bd1c0 Mon Sep 17 00:00:00 2001 From: yarden-sony Date: Thu, 17 Apr 2025 13:51:52 +0300 Subject: [PATCH 5/8] add reuse test --- .../back2framework/pytorch_model_builder.py | 34 ++++--- .../core/back2framework/__init__.py | 0 .../core/back2framework/test_weights_reuse.py | 88 +++++++++++++++++++ 3 files changed, 108 insertions(+), 14 deletions(-) create mode 100644 tests_pytest/pytorch_tests/unit_tests/core/back2framework/__init__.py create mode 100644 tests_pytest/pytorch_tests/unit_tests/core/back2framework/test_weights_reuse.py diff --git a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py index 7de8fa649..6f873abc9 100644 --- a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +++ b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py @@ -290,29 +290,35 @@ def wrap(self, node): def _add_all_modules(self): """ - Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel + Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel. + To assure all required nodes for the reused nodes are already initialized, adds none-reused nodes first, + then adds the reused nodes. """ - self._add_modules(reused_nodes_only=False) - self._add_modules(reused_nodes_only=True) + self._add_modules(reused_nodes_only=False) # add none-reused nodes + self._add_modules(reused_nodes_only=True) # add reused nodes def _add_modules(self, reused_nodes_only=False): """ Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel - :param reuse_nodes_only: whether to go over the reuse nodes list or not. - In case reuse_nodes_only is False - will go over all nodes, and add reused nodes to self.reuse_nodes - In case reuse_nodes_only is True - will go over self._reused_nodes only. + Args: + reused_nodes_only: whether to go over the reuse nodes list or not. + In case reuse_nodes_only is False - will go over all nodes, and add reused nodes to self._reused_nodes + In case reuse_nodes_only is True - will go over self._reused_nodes only. """ nodes = self._reused_nodes if reused_nodes_only else self.node_sort for node in nodes: - if node.reuse: # If the node is reused, retrieve the original module - if node.reuse_group not in self.reuse_groups: # original module wasn't created yet - if reused_nodes_only: - Logger.critical(f"Reuse group {node.reuse_group} not found for node {node.name}") - else: # add node to reused list, and go over the list after all other nodes were created - self._reused_nodes.append(node) - continue - node_op = self.reuse_groups[node.reuse_group] + if node.reuse and reused_nodes_only: + if node.reuse_group not in self.reuse_groups: + raise Exception(f"Reuse group {node.reuse_group} not found for node {node.name}. " + f"Make sure you first call the method with reused_nodes_only=False") + else: + node_op = self.reuse_groups[node.reuse_group] # retrieve the original module + + elif node.reuse: # add node to reused list, and go over the list after all other nodes were created + self._reused_nodes.append(node) + continue + else: # If it's not reused, create a new module node_op = self.wrap(node) diff --git a/tests_pytest/pytorch_tests/unit_tests/core/back2framework/__init__.py b/tests_pytest/pytorch_tests/unit_tests/core/back2framework/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_pytest/pytorch_tests/unit_tests/core/back2framework/test_weights_reuse.py b/tests_pytest/pytorch_tests/unit_tests/core/back2framework/test_weights_reuse.py new file mode 100644 index 000000000..6aa4f7238 --- /dev/null +++ b/tests_pytest/pytorch_tests/unit_tests/core/back2framework/test_weights_reuse.py @@ -0,0 +1,88 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from model_compression_toolkit.core import QuantizationConfig +from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph, graph_preparation_runner +from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PytorchModel +from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO +from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation +from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \ + AttachTpcToPytorch + +import torch +from torch import nn + + +def data_gen(): + yield [torch.rand(1, 3, 5, 5)] + + +class Model(nn.Module): + def __init__(self): + super().__init__() + self.shared_conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3) + self.fc1 = nn.Linear(8 * 3 * 3, 10) + self.fc2 = nn.Linear(8 * 3 * 3, 10) + + def forward(self, x): + batch_size = x.size(0) + + x1 = self.shared_conv(x) # shape: (batch_size, 8, height-2, width-2) = (1, 8, 3, 3) + x2 = self.shared_conv(x) # shape: (batch_size, 8, height-2, width-2) = (1, 8, 3, 3) + x1 = x1.view(batch_size, -1) # shape: (batch_size, 8 * 3 * 3) + x2 = x2.view(batch_size, -1) # shape: (batch_size, 8 * 3 * 3) + x1 = self.fc1(x1) # shape: (batch_size, 10) + x2 = self.fc2(x2) # shape: (batch_size, 10) + x = x1 + x2 + return x + + +def get_model_graph(model, minimal_tpc): + fw_impl = PytorchImplementation() + fw_info = DEFAULT_PYTORCH_INFO + return graph_preparation_runner(model, + data_gen, + QuantizationConfig(), + fw_info=fw_info, + fw_impl=fw_impl, + fqc=AttachTpcToPytorch().attach(minimal_tpc), + mixed_precision_enable=False, + running_gptq=False) + + +def test_weights_reuse_toposort(minimal_tpc): + """ + Test that reused nodes are successfully initiated after their group node was initiated. + Test it with nodes sorted in topological order. + """ + model = Model() + graph = get_model_graph(model, minimal_tpc) + pytorch_model = PytorchModel(graph=graph) + assert len(pytorch_model._reused_nodes) == 1 + + +def test_weights_reuse_reversed_toposort(minimal_tpc): + """ + Test that reused nodes are successfully initiated after their group node was initiated. + Test it with nodes sorted in reversed topological order. + """ + model = Model() + graph = get_model_graph(model, minimal_tpc) + pytorch_model = PytorchModel(graph=graph) + + pytorch_model.node_sort.reverse() + pytorch_model.reuse_groups = {} + pytorch_model._reused_nodes = [] + pytorch_model._add_all_modules() + assert len(pytorch_model._reused_nodes) == 1 From e2a9d8be26a09a79823c12db7998789c79c04010 Mon Sep 17 00:00:00 2001 From: yarden-sony Date: Sun, 20 Apr 2025 14:22:19 +0300 Subject: [PATCH 6/8] move test to integration tests --- .../core/back2framework/__init__.py | 0 .../core/back2framework/test_weights_reuse.py | 97 +++++++++++++++++++ .../core/back2framework/test_weights_reuse.py | 88 ----------------- 3 files changed, 97 insertions(+), 88 deletions(-) rename tests_pytest/pytorch_tests/{unit_tests => integration_tests}/core/back2framework/__init__.py (100%) create mode 100644 tests_pytest/pytorch_tests/integration_tests/core/back2framework/test_weights_reuse.py delete mode 100644 tests_pytest/pytorch_tests/unit_tests/core/back2framework/test_weights_reuse.py diff --git a/tests_pytest/pytorch_tests/unit_tests/core/back2framework/__init__.py b/tests_pytest/pytorch_tests/integration_tests/core/back2framework/__init__.py similarity index 100% rename from tests_pytest/pytorch_tests/unit_tests/core/back2framework/__init__.py rename to tests_pytest/pytorch_tests/integration_tests/core/back2framework/__init__.py diff --git a/tests_pytest/pytorch_tests/integration_tests/core/back2framework/test_weights_reuse.py b/tests_pytest/pytorch_tests/integration_tests/core/back2framework/test_weights_reuse.py new file mode 100644 index 000000000..f571ed8e2 --- /dev/null +++ b/tests_pytest/pytorch_tests/integration_tests/core/back2framework/test_weights_reuse.py @@ -0,0 +1,97 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import re + +import pytest +import torch +from torch import nn + +from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PytorchModel +from tests_pytest.pytorch_tests.torch_test_util.torch_test_mixin import BaseTorchIntegrationTest + + +class TestWeightsReuse(BaseTorchIntegrationTest): + + @staticmethod + def get_data_generator(): + def data_gen(): + yield [torch.rand(1, 3, 5, 5)] + + return data_gen + + def get_model_with_reused_weights(self): + class Model(nn.Module): + + def __init__(self): + super().__init__() + self.shared_conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3) + self.fc1 = nn.Linear(8 * 3 * 3, 10) + self.fc2 = nn.Linear(8 * 3 * 3, 10) + + def forward(self, x): + batch_size = x.size(0) + + x1 = self.shared_conv(x) # shape: (batch_size, 8, height-2, width-2) = (1, 8, 3, 3) + x2 = self.shared_conv(x) # shape: (batch_size, 8, height-2, width-2) = (1, 8, 3, 3) + x1 = x1.view(batch_size, -1) # shape: (batch_size, 8 * 3 * 3) + x2 = x2.view(batch_size, -1) # shape: (batch_size, 8 * 3 * 3) + x1 = self.fc1(x1) # shape: (batch_size, 10) + x2 = self.fc2(x2) # shape: (batch_size, 10) + x = x1 + x2 + return x + + return Model() + + def test_weights_reuse_toposort(self, minimal_tpc): + """ + Test that reused nodes are successfully initiated after their group node was initiated. + Test it with nodes sorted in topological order. + """ + model = self.get_model_with_reused_weights() + data_generator = self.get_data_generator() + graph = self.run_graph_preparation(model=model, datagen=data_generator, tpc=minimal_tpc) + pytorch_model = PytorchModel(graph=graph) + assert len(pytorch_model._reused_nodes) == 1 + + def test_weights_reuse_reversed_toposort(self, minimal_tpc): + """ + Test that reused nodes are successfully initiated after their group node was initiated. + Test it with nodes sorted in reversed topological order. + """ + model = self.get_model_with_reused_weights() + data_generator = self.get_data_generator() + graph = self.run_graph_preparation(model=model, datagen=data_generator, tpc=minimal_tpc) + pytorch_model = PytorchModel(graph=graph) + + pytorch_model.node_sort.reverse() + pytorch_model.reuse_groups = {} + pytorch_model._reused_nodes = [] + pytorch_model._add_all_modules() + assert len(pytorch_model._reused_nodes) == 1 + + def test_reused_only_initialization(self, minimal_tpc): + """ + Test that in case reused nodes are initiated before none-reused nodes, exception is raised. + """ + model = self.get_model_with_reused_weights() + data_generator = self.get_data_generator() + graph = self.run_graph_preparation(model=model, datagen=data_generator, tpc=minimal_tpc) + pytorch_model = PytorchModel(graph=graph) + pytorch_model.reuse_groups = {} + reused_node = pytorch_model._reused_nodes[0] + with pytest.raises(Exception, match=re.escape(f"Reuse group {reused_node.reuse_group} not found for node " + f"{reused_node.name}. Make sure you first call the method with " + f"reused_nodes_only=False")): + pytorch_model._add_modules(reused_nodes_only=True) diff --git a/tests_pytest/pytorch_tests/unit_tests/core/back2framework/test_weights_reuse.py b/tests_pytest/pytorch_tests/unit_tests/core/back2framework/test_weights_reuse.py deleted file mode 100644 index 6aa4f7238..000000000 --- a/tests_pytest/pytorch_tests/unit_tests/core/back2framework/test_weights_reuse.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from model_compression_toolkit.core import QuantizationConfig -from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph, graph_preparation_runner -from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PytorchModel -from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO -from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation -from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \ - AttachTpcToPytorch - -import torch -from torch import nn - - -def data_gen(): - yield [torch.rand(1, 3, 5, 5)] - - -class Model(nn.Module): - def __init__(self): - super().__init__() - self.shared_conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3) - self.fc1 = nn.Linear(8 * 3 * 3, 10) - self.fc2 = nn.Linear(8 * 3 * 3, 10) - - def forward(self, x): - batch_size = x.size(0) - - x1 = self.shared_conv(x) # shape: (batch_size, 8, height-2, width-2) = (1, 8, 3, 3) - x2 = self.shared_conv(x) # shape: (batch_size, 8, height-2, width-2) = (1, 8, 3, 3) - x1 = x1.view(batch_size, -1) # shape: (batch_size, 8 * 3 * 3) - x2 = x2.view(batch_size, -1) # shape: (batch_size, 8 * 3 * 3) - x1 = self.fc1(x1) # shape: (batch_size, 10) - x2 = self.fc2(x2) # shape: (batch_size, 10) - x = x1 + x2 - return x - - -def get_model_graph(model, minimal_tpc): - fw_impl = PytorchImplementation() - fw_info = DEFAULT_PYTORCH_INFO - return graph_preparation_runner(model, - data_gen, - QuantizationConfig(), - fw_info=fw_info, - fw_impl=fw_impl, - fqc=AttachTpcToPytorch().attach(minimal_tpc), - mixed_precision_enable=False, - running_gptq=False) - - -def test_weights_reuse_toposort(minimal_tpc): - """ - Test that reused nodes are successfully initiated after their group node was initiated. - Test it with nodes sorted in topological order. - """ - model = Model() - graph = get_model_graph(model, minimal_tpc) - pytorch_model = PytorchModel(graph=graph) - assert len(pytorch_model._reused_nodes) == 1 - - -def test_weights_reuse_reversed_toposort(minimal_tpc): - """ - Test that reused nodes are successfully initiated after their group node was initiated. - Test it with nodes sorted in reversed topological order. - """ - model = Model() - graph = get_model_graph(model, minimal_tpc) - pytorch_model = PytorchModel(graph=graph) - - pytorch_model.node_sort.reverse() - pytorch_model.reuse_groups = {} - pytorch_model._reused_nodes = [] - pytorch_model._add_all_modules() - assert len(pytorch_model._reused_nodes) == 1 From 6d53fa503873fc79b141a6ac270ddb5432a63c72 Mon Sep 17 00:00:00 2001 From: yarden-sony Date: Thu, 24 Apr 2025 09:59:24 +0300 Subject: [PATCH 7/8] validate weights reuse --- .../core/back2framework/test_weights_reuse.py | 90 ++++++++++++------- 1 file changed, 56 insertions(+), 34 deletions(-) diff --git a/tests_pytest/pytorch_tests/integration_tests/core/back2framework/test_weights_reuse.py b/tests_pytest/pytorch_tests/integration_tests/core/back2framework/test_weights_reuse.py index f571ed8e2..fa88df941 100644 --- a/tests_pytest/pytorch_tests/integration_tests/core/back2framework/test_weights_reuse.py +++ b/tests_pytest/pytorch_tests/integration_tests/core/back2framework/test_weights_reuse.py @@ -17,61 +17,76 @@ import pytest import torch from torch import nn +import torch.testing as ptt from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PytorchModel +from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device from tests_pytest.pytorch_tests.torch_test_util.torch_test_mixin import BaseTorchIntegrationTest -class TestWeightsReuse(BaseTorchIntegrationTest): +def get_data_generator(): + def data_gen(): + yield [torch.rand(1, 3, 5, 5)] + return data_gen + - @staticmethod - def get_data_generator(): - def data_gen(): - yield [torch.rand(1, 3, 5, 5)] +def get_model_with_reused_weights(): + class Model(nn.Module): - return data_gen + def __init__(self): + super().__init__() + self.shared_conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3) + self.another_conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3) + self.fc1 = nn.Linear(8 * 3 * 3, 10) + self.fc2 = nn.Linear(8 * 3 * 3, 10) + self.fc3 = nn.Linear(8 * 3 * 3, 10) - def get_model_with_reused_weights(self): - class Model(nn.Module): + def forward(self, x): + batch_size = x.size(0) - def __init__(self): - super().__init__() - self.shared_conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3) - self.fc1 = nn.Linear(8 * 3 * 3, 10) - self.fc2 = nn.Linear(8 * 3 * 3, 10) + x1 = self.shared_conv(x) # shape: (batch_size, 8, height-2, width-2) = (1, 8, 3, 3) + x2 = self.shared_conv(x) # shape: (batch_size, 8, height-2, width-2) = (1, 8, 3, 3) + x3 = self.another_conv(x) # shape: (batch_size, 8, height-2, width-2) = (1, 8, 3, 3) + x1 = x1.view(batch_size, -1) # shape: (batch_size, 8 * 3 * 3) + x2 = x2.view(batch_size, -1) # shape: (batch_size, 8 * 3 * 3) + x3 = x3.view(batch_size, -1) # shape: (batch_size, 8 * 3 * 3) + x1 = self.fc1(x1) # shape: (batch_size, 10) + x2 = self.fc2(x2) # shape: (batch_size, 10) + x3 = self.fc2(x3) # shape: (batch_size, 10) + x = x1 + x2 + x3 + return x - def forward(self, x): - batch_size = x.size(0) + return Model() - x1 = self.shared_conv(x) # shape: (batch_size, 8, height-2, width-2) = (1, 8, 3, 3) - x2 = self.shared_conv(x) # shape: (batch_size, 8, height-2, width-2) = (1, 8, 3, 3) - x1 = x1.view(batch_size, -1) # shape: (batch_size, 8 * 3 * 3) - x2 = x2.view(batch_size, -1) # shape: (batch_size, 8 * 3 * 3) - x1 = self.fc1(x1) # shape: (batch_size, 10) - x2 = self.fc2(x2) # shape: (batch_size, 10) - x = x1 + x2 - return x - return Model() +class TestWeightsReuse(BaseTorchIntegrationTest): def test_weights_reuse_toposort(self, minimal_tpc): """ - Test that reused nodes are successfully initiated after their group node was initiated. + Test that reused nodes are successfully initiated after their group node was initiated, and that they use the + same weights as their group node. Test it with nodes sorted in topological order. """ - model = self.get_model_with_reused_weights() - data_generator = self.get_data_generator() + model = get_model_with_reused_weights() + data_generator = get_data_generator() graph = self.run_graph_preparation(model=model, datagen=data_generator, tpc=minimal_tpc) pytorch_model = PytorchModel(graph=graph) - assert len(pytorch_model._reused_nodes) == 1 + input_tensor = torch.randn(1, 3, 5, 5).to(device=get_working_device()) + x1 = pytorch_model.shared_conv(input_tensor) + x2 = pytorch_model.shared_conv(input_tensor) + x3 = pytorch_model.another_conv(input_tensor) + ptt.assert_close(x1, x2, msg='Test failed: x1 and x2 do not share the same weights!') + with pytest.raises(AssertionError, match=re.escape('Test failed: x1 and x3 should not share the same weights!')): + ptt.assert_close(x1, x3, msg='Test failed: x1 and x3 should not share the same weights!!') def test_weights_reuse_reversed_toposort(self, minimal_tpc): """ - Test that reused nodes are successfully initiated after their group node was initiated. + Test that reused nodes are successfully initiated after their group node was initiated, and that they use the + same weights as their group node. Test it with nodes sorted in reversed topological order. """ - model = self.get_model_with_reused_weights() - data_generator = self.get_data_generator() + model = get_model_with_reused_weights() + data_generator = get_data_generator() graph = self.run_graph_preparation(model=model, datagen=data_generator, tpc=minimal_tpc) pytorch_model = PytorchModel(graph=graph) @@ -79,14 +94,21 @@ def test_weights_reuse_reversed_toposort(self, minimal_tpc): pytorch_model.reuse_groups = {} pytorch_model._reused_nodes = [] pytorch_model._add_all_modules() - assert len(pytorch_model._reused_nodes) == 1 + + input_tensor = torch.randn(1, 3, 5, 5).to(device=get_working_device()) + x1 = pytorch_model.shared_conv(input_tensor) + x2 = pytorch_model.shared_conv(input_tensor) + x3 = pytorch_model.another_conv(input_tensor) + ptt.assert_close(x1, x2, msg='Test failed: x1 and x2 do not share the same weights!') + with pytest.raises(AssertionError, match=re.escape('Test failed: x1 and x3 should not share the same weights!!')): + ptt.assert_close(x1, x3, msg='Test failed: x1 and x3 should not share the same weights!!') def test_reused_only_initialization(self, minimal_tpc): """ Test that in case reused nodes are initiated before none-reused nodes, exception is raised. """ - model = self.get_model_with_reused_weights() - data_generator = self.get_data_generator() + model = get_model_with_reused_weights() + data_generator = get_data_generator() graph = self.run_graph_preparation(model=model, datagen=data_generator, tpc=minimal_tpc) pytorch_model = PytorchModel(graph=graph) pytorch_model.reuse_groups = {} From e3bec63b99303d8fd076e6f79e3846595c1f1fe6 Mon Sep 17 00:00:00 2001 From: yarden-sony Date: Thu, 24 Apr 2025 17:51:41 +0300 Subject: [PATCH 8/8] changes according to PR --- .../core/back2framework/test_weights_reuse.py | 34 +++++-------------- 1 file changed, 8 insertions(+), 26 deletions(-) diff --git a/tests_pytest/pytorch_tests/integration_tests/core/back2framework/test_weights_reuse.py b/tests_pytest/pytorch_tests/integration_tests/core/back2framework/test_weights_reuse.py index fa88df941..e82fca631 100644 --- a/tests_pytest/pytorch_tests/integration_tests/core/back2framework/test_weights_reuse.py +++ b/tests_pytest/pytorch_tests/integration_tests/core/back2framework/test_weights_reuse.py @@ -39,27 +39,22 @@ def __init__(self): self.another_conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3) self.fc1 = nn.Linear(8 * 3 * 3, 10) self.fc2 = nn.Linear(8 * 3 * 3, 10) - self.fc3 = nn.Linear(8 * 3 * 3, 10) def forward(self, x): - batch_size = x.size(0) - x1 = self.shared_conv(x) # shape: (batch_size, 8, height-2, width-2) = (1, 8, 3, 3) x2 = self.shared_conv(x) # shape: (batch_size, 8, height-2, width-2) = (1, 8, 3, 3) - x3 = self.another_conv(x) # shape: (batch_size, 8, height-2, width-2) = (1, 8, 3, 3) - x1 = x1.view(batch_size, -1) # shape: (batch_size, 8 * 3 * 3) - x2 = x2.view(batch_size, -1) # shape: (batch_size, 8 * 3 * 3) - x3 = x3.view(batch_size, -1) # shape: (batch_size, 8 * 3 * 3) - x1 = self.fc1(x1) # shape: (batch_size, 10) - x2 = self.fc2(x2) # shape: (batch_size, 10) - x3 = self.fc2(x3) # shape: (batch_size, 10) - x = x1 + x2 + x3 + x = x1 + x2 return x return Model() class TestWeightsReuse(BaseTorchIntegrationTest): + """ + Test that reused nodes are always initiated after their group node was initiated. + We test it by validating that building pytorch model back from the graph succeed with no errors + (specifically _add_all_modules method), independently to the graph nodes topological order. + """ def test_weights_reuse_toposort(self, minimal_tpc): """ @@ -71,13 +66,7 @@ def test_weights_reuse_toposort(self, minimal_tpc): data_generator = get_data_generator() graph = self.run_graph_preparation(model=model, datagen=data_generator, tpc=minimal_tpc) pytorch_model = PytorchModel(graph=graph) - input_tensor = torch.randn(1, 3, 5, 5).to(device=get_working_device()) - x1 = pytorch_model.shared_conv(input_tensor) - x2 = pytorch_model.shared_conv(input_tensor) - x3 = pytorch_model.another_conv(input_tensor) - ptt.assert_close(x1, x2, msg='Test failed: x1 and x2 do not share the same weights!') - with pytest.raises(AssertionError, match=re.escape('Test failed: x1 and x3 should not share the same weights!')): - ptt.assert_close(x1, x3, msg='Test failed: x1 and x3 should not share the same weights!!') + assert len(pytorch_model._reused_nodes) == 1 def test_weights_reuse_reversed_toposort(self, minimal_tpc): """ @@ -94,14 +83,7 @@ def test_weights_reuse_reversed_toposort(self, minimal_tpc): pytorch_model.reuse_groups = {} pytorch_model._reused_nodes = [] pytorch_model._add_all_modules() - - input_tensor = torch.randn(1, 3, 5, 5).to(device=get_working_device()) - x1 = pytorch_model.shared_conv(input_tensor) - x2 = pytorch_model.shared_conv(input_tensor) - x3 = pytorch_model.another_conv(input_tensor) - ptt.assert_close(x1, x2, msg='Test failed: x1 and x2 do not share the same weights!') - with pytest.raises(AssertionError, match=re.escape('Test failed: x1 and x3 should not share the same weights!!')): - ptt.assert_close(x1, x3, msg='Test failed: x1 and x3 should not share the same weights!!') + assert len(pytorch_model._reused_nodes) == 1 def test_reused_only_initialization(self, minimal_tpc): """