diff --git a/.github/workflows/run_tests_python310_pytorch22.yml b/.github/workflows/run_tests_python310_pytorch26.yml similarity index 85% rename from .github/workflows/run_tests_python310_pytorch22.yml rename to .github/workflows/run_tests_python310_pytorch26.yml index 7d31ab6bd..955b4e121 100644 --- a/.github/workflows/run_tests_python310_pytorch22.yml +++ b/.github/workflows/run_tests_python310_pytorch26.yml @@ -1,4 +1,4 @@ -name: Python 3.10, Pytorch 2.2 +name: Python 3.10, Pytorch 2.6 on: workflow_dispatch: # Allow manual triggers schedule: @@ -16,4 +16,4 @@ jobs: uses: ./.github/workflows/run_pytorch_tests.yml with: python-version: "3.10" - torch-version: "2.2.*" \ No newline at end of file + torch-version: "2.6.*" \ No newline at end of file diff --git a/.github/workflows/run_tests_python311_pytorch22.yml b/.github/workflows/run_tests_python311_pytorch26.yml similarity index 85% rename from .github/workflows/run_tests_python311_pytorch22.yml rename to .github/workflows/run_tests_python311_pytorch26.yml index 0fbb6e7f5..814c24550 100644 --- a/.github/workflows/run_tests_python311_pytorch22.yml +++ b/.github/workflows/run_tests_python311_pytorch26.yml @@ -1,4 +1,4 @@ -name: Python 3.11, Pytorch 2.2 +name: Python 3.11, Pytorch 2.6 on: workflow_dispatch: # Allow manual triggers schedule: @@ -16,4 +16,4 @@ jobs: uses: ./.github/workflows/run_pytorch_tests.yml with: python-version: "3.11" - torch-version: "2.2.*" \ No newline at end of file + torch-version: "2.6.*" \ No newline at end of file diff --git a/.github/workflows/run_tests_python312_pytorch22.yml b/.github/workflows/run_tests_python312_pytorch26.yml similarity index 85% rename from .github/workflows/run_tests_python312_pytorch22.yml rename to .github/workflows/run_tests_python312_pytorch26.yml index e40fb32b5..db98faf00 100644 --- a/.github/workflows/run_tests_python312_pytorch22.yml +++ b/.github/workflows/run_tests_python312_pytorch26.yml @@ -1,4 +1,4 @@ -name: Python 3.12, Pytorch 2.2 +name: Python 3.12, Pytorch 2.6 on: workflow_dispatch: # Allow manual triggers schedule: @@ -16,4 +16,4 @@ jobs: uses: ./.github/workflows/run_pytorch_tests.yml with: python-version: "3.12" - torch-version: "2.2.*" \ No newline at end of file + torch-version: "2.6.*" \ No newline at end of file diff --git a/.github/workflows/run_tests_python39_pytorch22.yml b/.github/workflows/run_tests_python39_pytorch26.yml similarity index 86% rename from .github/workflows/run_tests_python39_pytorch22.yml rename to .github/workflows/run_tests_python39_pytorch26.yml index a79c4ea69..7dc0732c6 100644 --- a/.github/workflows/run_tests_python39_pytorch22.yml +++ b/.github/workflows/run_tests_python39_pytorch26.yml @@ -1,4 +1,4 @@ -name: Python 3.9, Pytorch 2.2 +name: Python 3.9, Pytorch 2.6 on: workflow_dispatch: # Allow manual triggers schedule: @@ -16,4 +16,4 @@ jobs: uses: ./.github/workflows/run_pytorch_tests.yml with: python-version: "3.9" - torch-version: "2.2.*" \ No newline at end of file + torch-version: "2.6.*" \ No newline at end of file diff --git a/README.md b/README.md index b8710a95c..a6b6ec577 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ ______________________________________________________________________ License

- + @@ -30,7 +30,7 @@ ________________________________________________________________________________ ##

Getting Started
### Quick Installation -Pip install the model compression toolkit package in a Python>=3.9 environment with PyTorch>=2.1 or Tensorflow>=2.14. +Pip install the model compression toolkit package in a Python>=3.9 environment with PyTorch>=2.3 or Tensorflow>=2.14. ``` pip install model-compression-toolkit ``` @@ -130,12 +130,12 @@ Currently, MCT is being tested on various Python, Pytorch and TensorFlow version
Supported Versions Table -| | PyTorch 2.2 | PyTorch 2.3 | PyTorch 2.4 | PyTorch 2.5 | +| | PyTorch 2.3 | PyTorch 2.4 | PyTorch 2.5 | PyTorch 2.6 | |-------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Python 3.9 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch22.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch22.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch25.yml) | -| Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch22.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch22.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch25.yml) | -| Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch22.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch22.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch25.yml) | -| Python 3.12 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch22.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch22.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch25.yml) | +| Python 3.9 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch25.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch26.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch26.yml) | +| Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch25.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch26.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch26.yml) | +| Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch25.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch26.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch26.yml) | +| Python 3.12 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch25.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch26.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch26.yml) | | | TensorFlow 2.14 | TensorFlow 2.15 | |-------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| diff --git a/model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py b/model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py index e6318e83a..5706a40ff 100644 --- a/model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +++ b/model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py @@ -18,6 +18,8 @@ import torch.nn from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper + +from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder from model_compression_toolkit.verify_packages import FOUND_ONNX from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.pytorch.utils import to_torch_tensor @@ -98,15 +100,17 @@ def export(self, output_names=None) -> None: model_output = self.model(*model_input) if isinstance(model_input, (list, tuple)) else self.model( model_input) + input_nodes = [n for n in self.model.node_sort if n.type == DummyPlaceHolder] + input_names = [f"input_{i}" for i in range(len(input_nodes))] if len(input_nodes) > 1 else ["input"] + dynamic_axes = {name: {0: 'batch_size'} for name in input_names} if output_names is None: # Determine number of outputs and prepare output_names and dynamic_axes if isinstance(model_output, (list, tuple)): output_names = [f"output_{i}" for i in range(len(model_output))] - dynamic_axes = {'input': {0: 'batch_size'}} dynamic_axes.update({name: {0: 'batch_size'} for name in output_names}) else: output_names = ['output'] - dynamic_axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}} + dynamic_axes.update({'output': {0: 'batch_size'}}) else: if isinstance(model_output, (list, tuple)): num_of_outputs = len(model_output) @@ -115,9 +119,7 @@ def export(self, output_names=None) -> None: assert len(output_names) == num_of_outputs, (f"Mismatch between number of requested output names " f"({output_names}) and model output count " f"({num_of_outputs}):\n") - dynamic_axes = {'input': {0: 'batch_size'}} dynamic_axes.update({name: {0: 'batch_size'} for name in output_names}) - if hasattr(self.model, 'metadata'): onnx_bytes = BytesIO() torch.onnx.export(self.model, @@ -125,7 +127,7 @@ def export(self, output_names=None) -> None: onnx_bytes, opset_version=self._onnx_opset_version, verbose=False, - input_names=['input'], + input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes) onnx_model = onnx.load_from_string(onnx_bytes.getvalue()) @@ -137,7 +139,7 @@ def export(self, output_names=None) -> None: self.save_model_path, opset_version=self._onnx_opset_version, verbose=False, - input_names=['input'], + input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes) diff --git a/model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py b/model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py index c74665d0f..e27634a98 100644 --- a/model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +++ b/model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== from typing import Callable +from packaging import version from model_compression_toolkit.verify_packages import FOUND_TORCH from model_compression_toolkit.exporter.model_exporter.fw_agonstic.quantization_format import QuantizationFormat @@ -30,6 +31,9 @@ from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import FakelyQuantTorchScriptPyTorchExporter from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable + if version.parse(torch.__version__) >= version.parse("2.4"): + DEFAULT_ONNX_OPSET_VERSION = 20 + supported_serialization_quantization_export_dict = { PytorchExportSerializationFormat.TORCHSCRIPT: [QuantizationFormat.FAKELY_QUANT], PytorchExportSerializationFormat.ONNX: [QuantizationFormat.FAKELY_QUANT, QuantizationFormat.MCTQ] diff --git a/tests/pytorch_tests/exporter_tests/test_exporting_qat_models.py b/tests/pytorch_tests/exporter_tests/test_exporting_qat_models.py index 6b1b8c53c..d57f88515 100644 --- a/tests/pytorch_tests/exporter_tests/test_exporting_qat_models.py +++ b/tests/pytorch_tests/exporter_tests/test_exporting_qat_models.py @@ -51,7 +51,7 @@ def get_tmp_filepath(self): return tempfile.mkstemp('.pt')[1] def load_exported_model(self, filepath): - return torch.load(filepath) + return torch.load(filepath, weights_only=False) def infer(self, model, images): return model(images) @@ -74,7 +74,7 @@ def export_qat_model(self): # Assert qat_ready can be saved and loaded _qat_ready_model_tmp_filepath = tempfile.mkstemp('.pt')[1] torch.save(self.qat_ready, _qat_ready_model_tmp_filepath) - self.qat_ready = torch.load(_qat_ready_model_tmp_filepath) + self.qat_ready = torch.load(_qat_ready_model_tmp_filepath, weights_only=False) self.final_model = mct.qat.pytorch_quantization_aware_training_finalize_experimental(self.qat_ready) @@ -82,7 +82,7 @@ def export_qat_model(self): self.final_model(images[0]) _final_model_tmp_filepath = tempfile.mkstemp('.pt')[1] torch.save(self.final_model, _final_model_tmp_filepath) - self.final_model = torch.load(_final_model_tmp_filepath) + self.final_model = torch.load(_final_model_tmp_filepath, weights_only=False) self.final_model(images[0]) self.filepath = self.get_tmp_filepath() diff --git a/tests/pytorch_tests/function_tests/test_fully_quantized_exporter.py b/tests/pytorch_tests/function_tests/test_fully_quantized_exporter.py index 3c9b2581c..e35145905 100644 --- a/tests/pytorch_tests/function_tests/test_fully_quantized_exporter.py +++ b/tests/pytorch_tests/function_tests/test_fully_quantized_exporter.py @@ -94,7 +94,7 @@ def test_save_and_load_model(self): print(f"Float Pytorch .pth Model: {float_model_file}") model = copy.deepcopy(self.fully_quantized_mbv2) - model.load_state_dict(torch.load(model_file)) + model.load_state_dict(torch.load(model_file, weights_only=False)) model.eval() model(next(self.representative_data_gen())) diff --git a/tests_pytest/pytorch_tests/integration_tests/exporter/__init__.py b/tests_pytest/pytorch_tests/integration_tests/exporter/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_pytest/pytorch_tests/unit_tests/exporter/fakely_quant_onnx_pytorch_exporter_output_name.py b/tests_pytest/pytorch_tests/integration_tests/exporter/test_onnx_input_output.py similarity index 54% rename from tests_pytest/pytorch_tests/unit_tests/exporter/fakely_quant_onnx_pytorch_exporter_output_name.py rename to tests_pytest/pytorch_tests/integration_tests/exporter/test_onnx_input_output.py index 1f894dc7a..5222298f6 100644 --- a/tests_pytest/pytorch_tests/unit_tests/exporter/fakely_quant_onnx_pytorch_exporter_output_name.py +++ b/tests_pytest/pytorch_tests/integration_tests/exporter/test_onnx_input_output.py @@ -17,11 +17,14 @@ import torch import torch.nn as nn +from model_compression_toolkit.core import QuantizationConfig +from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModel from model_compression_toolkit.core.pytorch.utils import set_model from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import \ FakelyQuantONNXPyTorchExporter from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import DEFAULT_ONNX_OPSET_VERSION from model_compression_toolkit.exporter.model_wrapper import is_pytorch_layer_exportable +from tests_pytest.pytorch_tests.torch_test_util.torch_test_mixin import BaseTorchIntegrationTest class SingleOutputModel(nn.Module): @@ -42,7 +45,16 @@ def forward(self, x): return self.linear(x), x, x + 2 -class TestONNXExporter: +class MultipleInputsModel(nn.Module): + def __init__(self): + super(MultipleInputsModel, self).__init__() + self.linear = nn.Linear(8, 5) + + def forward(self, input1, input2): + return self.linear(input1) + self.linear(input2) + + +class TestONNXExporter(BaseTorchIntegrationTest): test_input_1 = None test_expected_1 = ['output'] @@ -59,27 +71,38 @@ class TestONNXExporter: test_expected_5 = ("Mismatch between number of requested output names (['out', 'out_11', 'out_22', 'out_33']) and " "model output count (3):\n") - def representative_data_gen(self, shape=(3, 8, 8), num_inputs=1, batch_size=2, num_iter=1): - for _ in range(num_iter): - yield [torch.randn(batch_size, *shape)] * num_inputs + def representative_data_gen(self, num_inputs=1): + batch_size, num_iter, shape = 2, 1, (3, 8, 8) + + def data_gen(): + for _ in range(num_iter): + yield [torch.randn(batch_size, *shape)] * num_inputs - def get_exporter(self, model, save_model_path): - return FakelyQuantONNXPyTorchExporter(model, - is_pytorch_layer_exportable, - save_model_path, - self.representative_data_gen, - onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION) + return data_gen - def export_model(self, model, save_model_path, output_names, expected_output_names): - exporter = self.get_exporter(model, save_model_path) + def get_pytorch_model(self, model, data_generator, minimal_tpc): + qc = QuantizationConfig() + graph = self.run_graph_preparation(model=model, datagen=data_generator, tpc=minimal_tpc, + quant_config=qc) + pytorch_model = FloatPyTorchModel(graph=graph) + return pytorch_model + + def export_model(self, model, save_model_path, data_generator, output_names=None): + exporter = FakelyQuantONNXPyTorchExporter(model, + is_pytorch_layer_exportable, + save_model_path, + data_generator, + onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION) exporter.export(output_names) assert save_model_path.exists(), "ONNX file was not created" assert save_model_path.stat().st_size > 0, "ONNX file is empty" - # Load the ONNX model and check outputs onnx_model = onnx.load(str(save_model_path)) + return onnx_model + + def validate_outputs(self, onnx_model, expected_output_names): outputs = onnx_model.graph.output # Check number of outputs @@ -98,23 +121,35 @@ def export_model(self, model, save_model_path, output_names, expected_output_nam (MultipleOutputModel(), test_input_3, test_expected_3), (MultipleOutputModel(), test_input_4, test_expected_4), ]) - def test_output_model_name(self, tmp_path, model, output_names, expected_output_names): + def test_output_model_name(self, tmp_path, model, output_names, expected_output_names, minimal_tpc): save_model_path = tmp_path / "model.onnx" - set_model(model) - - self.export_model(model, save_model_path, output_names=output_names, - expected_output_names=expected_output_names) + data_generator = self.representative_data_gen(num_inputs=1) + pytorch_model = self.get_pytorch_model(model, data_generator, minimal_tpc) + onnx_model = self.export_model(pytorch_model, save_model_path, data_generator, output_names=output_names) + self.validate_outputs(onnx_model, expected_output_names) @pytest.mark.parametrize( ("model", "output_names", "expected_output_names"), [ (MultipleOutputModel(), test_input_5, test_expected_5), ]) - def test_wrong_number_output_model_name(self, tmp_path, model, output_names, expected_output_names): + def test_wrong_number_output_model_name(self, tmp_path, model, output_names, expected_output_names, minimal_tpc): save_model_path = tmp_path / "model.onnx" - set_model(model) - + data_generator = self.representative_data_gen(num_inputs=1) + pytorch_model = self.get_pytorch_model(model, data_generator, minimal_tpc) try: - self.export_model(model, save_model_path, output_names=output_names, - expected_output_names=expected_output_names) + onnx_model = self.export_model(pytorch_model, save_model_path, data_generator, output_names=output_names) + self.validate_outputs(onnx_model, expected_output_names) except Exception as e: assert expected_output_names == str(e) + + def test_multiple_inputs(self, minimal_tpc, tmp_path): + """ + Test that model with multiple inputs is exported to onnx file properly and that the exported onnx model + has all expected inputs. + """ + save_model_path = tmp_path / "model.onnx" + model = MultipleInputsModel() + data_generator = self.representative_data_gen(num_inputs=2) + pytorch_model = self.get_pytorch_model(model, data_generator, minimal_tpc) + onnx_model = self.export_model(pytorch_model, save_model_path, data_generator) + assert [_input.name for _input in onnx_model.graph.input] == ["input_0", "input_1"]