Skip to content

Commit 51d2aaf

Browse files
Multiple Inputs Support | Onnx Opset V20 | Add torch2.6 & Remove torch2.2 (#1431)
1 parent 10e70d6 commit 51d2aaf

File tree

11 files changed

+89
-48
lines changed

11 files changed

+89
-48
lines changed

.github/workflows/run_tests_python310_pytorch22.yml renamed to .github/workflows/run_tests_python310_pytorch26.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Python 3.10, Pytorch 2.2
1+
name: Python 3.10, Pytorch 2.6
22
on:
33
workflow_dispatch: # Allow manual triggers
44
schedule:
@@ -16,4 +16,4 @@ jobs:
1616
uses: ./.github/workflows/run_pytorch_tests.yml
1717
with:
1818
python-version: "3.10"
19-
torch-version: "2.2.*"
19+
torch-version: "2.6.*"

.github/workflows/run_tests_python311_pytorch22.yml renamed to .github/workflows/run_tests_python311_pytorch26.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Python 3.11, Pytorch 2.2
1+
name: Python 3.11, Pytorch 2.6
22
on:
33
workflow_dispatch: # Allow manual triggers
44
schedule:
@@ -16,4 +16,4 @@ jobs:
1616
uses: ./.github/workflows/run_pytorch_tests.yml
1717
with:
1818
python-version: "3.11"
19-
torch-version: "2.2.*"
19+
torch-version: "2.6.*"

.github/workflows/run_tests_python312_pytorch22.yml renamed to .github/workflows/run_tests_python312_pytorch26.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Python 3.12, Pytorch 2.2
1+
name: Python 3.12, Pytorch 2.6
22
on:
33
workflow_dispatch: # Allow manual triggers
44
schedule:
@@ -16,4 +16,4 @@ jobs:
1616
uses: ./.github/workflows/run_pytorch_tests.yml
1717
with:
1818
python-version: "3.12"
19-
torch-version: "2.2.*"
19+
torch-version: "2.6.*"

.github/workflows/run_tests_python39_pytorch22.yml renamed to .github/workflows/run_tests_python39_pytorch26.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Python 3.9, Pytorch 2.2
1+
name: Python 3.9, Pytorch 2.6
22
on:
33
workflow_dispatch: # Allow manual triggers
44
schedule:
@@ -16,4 +16,4 @@ jobs:
1616
uses: ./.github/workflows/run_pytorch_tests.yml
1717
with:
1818
python-version: "3.9"
19-
torch-version: "2.2.*"
19+
torch-version: "2.6.*"

README.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ ______________________________________________________________________
1717
<a href="#license">License</a>
1818
</p>
1919
<p align="center">
20-
<a href="https://sony.github.io/model_optimization#prerequisites"><img src="https://img.shields.io/badge/pytorch-2.2%20%7C%202.3%20%7C%202.4%20%7C%202.5-blue" /></a>
20+
<a href="https://sony.github.io/model_optimization#prerequisites"><img src="https://img.shields.io/badge/pytorch-2.3%20%7C%202.4%20%7C%202.5%20%7C%202.6-blue" /></a>
2121
<a href="https://sony.github.io/model_optimization#prerequisites"><img src="https://img.shields.io/badge/tensorflow-2.14%20%7C%202.15-blue" /></a>
2222
<a href="https://sony.github.io/model_optimization#prerequisites"><img src="https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue" /></a>
2323
<a href="https://github.com/sony/model_optimization/releases"><img src="https://img.shields.io/github/v/release/sony/model_optimization" /></a>
@@ -30,7 +30,7 @@ ________________________________________________________________________________
3030

3131
## <div align="center">Getting Started</div>
3232
### Quick Installation
33-
Pip install the model compression toolkit package in a Python>=3.9 environment with PyTorch>=2.1 or Tensorflow>=2.14.
33+
Pip install the model compression toolkit package in a Python>=3.9 environment with PyTorch>=2.3 or Tensorflow>=2.14.
3434
```
3535
pip install model-compression-toolkit
3636
```
@@ -130,12 +130,12 @@ Currently, MCT is being tested on various Python, Pytorch and TensorFlow version
130130
<details id="supported-versions">
131131
<summary>Supported Versions Table</summary>
132132

133-
| | PyTorch 2.2 | PyTorch 2.3 | PyTorch 2.4 | PyTorch 2.5 |
133+
| | PyTorch 2.3 | PyTorch 2.4 | PyTorch 2.5 | PyTorch 2.6 |
134134
|-------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
135-
| Python 3.9 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch22.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch22.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch25.yml) |
136-
| Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch22.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch22.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch25.yml) |
137-
| Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch22.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch22.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch25.yml) |
138-
| Python 3.12 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch22.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch22.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch25.yml) |
135+
| Python 3.9 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch25.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch26.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_pytorch26.yml) |
136+
| Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch25.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch26.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_pytorch26.yml) |
137+
| Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch25.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch26.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch26.yml) |
138+
| Python 3.12 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch25.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch26.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch26.yml) |
139139

140140
| | TensorFlow 2.14 | TensorFlow 2.15 |
141141
|-------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|

model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import torch.nn
1919

2020
from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
21+
22+
from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
2123
from model_compression_toolkit.verify_packages import FOUND_ONNX
2224
from model_compression_toolkit.logger import Logger
2325
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
@@ -98,15 +100,17 @@ def export(self, output_names=None) -> None:
98100
model_output = self.model(*model_input) if isinstance(model_input, (list, tuple)) else self.model(
99101
model_input)
100102

103+
input_nodes = [n for n in self.model.node_sort if n.type == DummyPlaceHolder]
104+
input_names = [f"input_{i}" for i in range(len(input_nodes))] if len(input_nodes) > 1 else ["input"]
105+
dynamic_axes = {name: {0: 'batch_size'} for name in input_names}
101106
if output_names is None:
102107
# Determine number of outputs and prepare output_names and dynamic_axes
103108
if isinstance(model_output, (list, tuple)):
104109
output_names = [f"output_{i}" for i in range(len(model_output))]
105-
dynamic_axes = {'input': {0: 'batch_size'}}
106110
dynamic_axes.update({name: {0: 'batch_size'} for name in output_names})
107111
else:
108112
output_names = ['output']
109-
dynamic_axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
113+
dynamic_axes.update({'output': {0: 'batch_size'}})
110114
else:
111115
if isinstance(model_output, (list, tuple)):
112116
num_of_outputs = len(model_output)
@@ -115,17 +119,15 @@ def export(self, output_names=None) -> None:
115119
assert len(output_names) == num_of_outputs, (f"Mismatch between number of requested output names "
116120
f"({output_names}) and model output count "
117121
f"({num_of_outputs}):\n")
118-
dynamic_axes = {'input': {0: 'batch_size'}}
119122
dynamic_axes.update({name: {0: 'batch_size'} for name in output_names})
120-
121123
if hasattr(self.model, 'metadata'):
122124
onnx_bytes = BytesIO()
123125
torch.onnx.export(self.model,
124126
tuple(model_input) if isinstance(model_input, list) else model_input,
125127
onnx_bytes,
126128
opset_version=self._onnx_opset_version,
127129
verbose=False,
128-
input_names=['input'],
130+
input_names=input_names,
129131
output_names=output_names,
130132
dynamic_axes=dynamic_axes)
131133
onnx_model = onnx.load_from_string(onnx_bytes.getvalue())
@@ -137,7 +139,7 @@ def export(self, output_names=None) -> None:
137139
self.save_model_path,
138140
opset_version=self._onnx_opset_version,
139141
verbose=False,
140-
input_names=['input'],
142+
input_names=input_names,
141143
output_names=output_names,
142144
dynamic_axes=dynamic_axes)
143145

model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
from typing import Callable
16+
from packaging import version
1617

1718
from model_compression_toolkit.verify_packages import FOUND_TORCH
1819
from model_compression_toolkit.exporter.model_exporter.fw_agonstic.quantization_format import QuantizationFormat
@@ -30,6 +31,9 @@
3031
from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import FakelyQuantTorchScriptPyTorchExporter
3132
from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable
3233

34+
if version.parse(torch.__version__) >= version.parse("2.4"):
35+
DEFAULT_ONNX_OPSET_VERSION = 20
36+
3337
supported_serialization_quantization_export_dict = {
3438
PytorchExportSerializationFormat.TORCHSCRIPT: [QuantizationFormat.FAKELY_QUANT],
3539
PytorchExportSerializationFormat.ONNX: [QuantizationFormat.FAKELY_QUANT, QuantizationFormat.MCTQ]

tests/pytorch_tests/exporter_tests/test_exporting_qat_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def get_tmp_filepath(self):
5151
return tempfile.mkstemp('.pt')[1]
5252

5353
def load_exported_model(self, filepath):
54-
return torch.load(filepath)
54+
return torch.load(filepath, weights_only=False)
5555

5656
def infer(self, model, images):
5757
return model(images)
@@ -74,15 +74,15 @@ def export_qat_model(self):
7474
# Assert qat_ready can be saved and loaded
7575
_qat_ready_model_tmp_filepath = tempfile.mkstemp('.pt')[1]
7676
torch.save(self.qat_ready, _qat_ready_model_tmp_filepath)
77-
self.qat_ready = torch.load(_qat_ready_model_tmp_filepath)
77+
self.qat_ready = torch.load(_qat_ready_model_tmp_filepath, weights_only=False)
7878

7979
self.final_model = mct.qat.pytorch_quantization_aware_training_finalize_experimental(self.qat_ready)
8080

8181
# Assert final_model can be used for inference, can be saved and loaded
8282
self.final_model(images[0])
8383
_final_model_tmp_filepath = tempfile.mkstemp('.pt')[1]
8484
torch.save(self.final_model, _final_model_tmp_filepath)
85-
self.final_model = torch.load(_final_model_tmp_filepath)
85+
self.final_model = torch.load(_final_model_tmp_filepath, weights_only=False)
8686
self.final_model(images[0])
8787

8888
self.filepath = self.get_tmp_filepath()

tests/pytorch_tests/function_tests/test_fully_quantized_exporter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_save_and_load_model(self):
9494
print(f"Float Pytorch .pth Model: {float_model_file}")
9595

9696
model = copy.deepcopy(self.fully_quantized_mbv2)
97-
model.load_state_dict(torch.load(model_file))
97+
model.load_state_dict(torch.load(model_file, weights_only=False))
9898
model.eval()
9999
model(next(self.representative_data_gen()))
100100

tests_pytest/pytorch_tests/integration_tests/exporter/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)