Skip to content

Commit 7a6da34

Browse files
authored
Export onnx output names list (#1430)
* Add onnx exporter output names
1 parent 8a6fa11 commit 7a6da34

File tree

2 files changed

+148
-8
lines changed

2 files changed

+148
-8
lines changed

model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter
2525
from mct_quantizers import pytorch_quantizers
2626

27-
2827
if FOUND_ONNX:
2928
import onnx
3029
from mct_quantizers.pytorch.metadata import add_onnx_metadata
3130

31+
3232
class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
3333
"""
3434
Exporter for fakely-quant PyTorch models.
@@ -63,7 +63,7 @@ def __init__(self,
6363
self._use_onnx_custom_quantizer_ops = use_onnx_custom_quantizer_ops
6464
self._onnx_opset_version = onnx_opset_version
6565

66-
def export(self) -> None:
66+
def export(self, output_names=None) -> None:
6767
"""
6868
Convert an exportable (fully-quantized) PyTorch model to a fakely-quant model
6969
(namely, weights that are in fake-quant format) and fake-quant layers for the activations.
@@ -95,6 +95,28 @@ def export(self) -> None:
9595
Logger.info(f"Exporting fake-quant onnx model: {self.save_model_path}")
9696

9797
model_input = to_torch_tensor(next(self.repr_dataset()))
98+
model_output = self.model(*model_input) if isinstance(model_input, (list, tuple)) else self.model(
99+
model_input)
100+
101+
if output_names is None:
102+
# Determine number of outputs and prepare output_names and dynamic_axes
103+
if isinstance(model_output, (list, tuple)):
104+
output_names = [f"output_{i}" for i in range(len(model_output))]
105+
dynamic_axes = {'input': {0: 'batch_size'}}
106+
dynamic_axes.update({name: {0: 'batch_size'} for name in output_names})
107+
else:
108+
output_names = ['output']
109+
dynamic_axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
110+
else:
111+
if isinstance(model_output, (list, tuple)):
112+
num_of_outputs = len(model_output)
113+
else:
114+
num_of_outputs = 1
115+
assert len(output_names) == num_of_outputs, (f"Mismatch between number of requested output names "
116+
f"({output_names}) and model output count "
117+
f"({num_of_outputs}):\n")
118+
dynamic_axes = {'input': {0: 'batch_size'}}
119+
dynamic_axes.update({name: {0: 'batch_size'} for name in output_names})
98120

99121
if hasattr(self.model, 'metadata'):
100122
onnx_bytes = BytesIO()
@@ -104,9 +126,8 @@ def export(self) -> None:
104126
opset_version=self._onnx_opset_version,
105127
verbose=False,
106128
input_names=['input'],
107-
output_names=['output'],
108-
dynamic_axes={'input': {0: 'batch_size'},
109-
'output': {0: 'batch_size'}})
129+
output_names=output_names,
130+
dynamic_axes=dynamic_axes)
110131
onnx_model = onnx.load_from_string(onnx_bytes.getvalue())
111132
onnx_model = add_onnx_metadata(onnx_model, self.model.metadata)
112133
onnx.save_model(onnx_model, self.save_model_path)
@@ -117,9 +138,8 @@ def export(self) -> None:
117138
opset_version=self._onnx_opset_version,
118139
verbose=False,
119140
input_names=['input'],
120-
output_names=['output'],
121-
dynamic_axes={'input': {0: 'batch_size'},
122-
'output': {0: 'batch_size'}})
141+
output_names=output_names,
142+
dynamic_axes=dynamic_axes)
123143

124144
for layer in self.model.children():
125145
# Set disable for reuse for weight quantizers if quantizer was reused
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import onnx
16+
import pytest
17+
import torch
18+
import torch.nn as nn
19+
20+
from model_compression_toolkit.core.pytorch.utils import set_model
21+
from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import \
22+
FakelyQuantONNXPyTorchExporter
23+
from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import DEFAULT_ONNX_OPSET_VERSION
24+
from model_compression_toolkit.exporter.model_wrapper import is_pytorch_layer_exportable
25+
26+
27+
class SingleOutputModel(nn.Module):
    """Toy model that returns a single tensor: a Linear(8 -> 5) applied to x."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 5)

    def forward(self, x):
        projected = self.linear(x)
        return projected
34+
35+
36+
class MultipleOutputModel(nn.Module):
    """Toy model that returns three outputs: a Linear(8 -> 5) projection of x,
    x itself, and x shifted by a constant."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 5)

    def forward(self, x):
        projected = self.linear(x)
        shifted = x + 2
        return projected, x, shifted
43+
44+
45+
class TestONNXExporter:
    """End-to-end tests for FakelyQuantONNXPyTorchExporter's `output_names`
    support: default names, user-supplied names, and the count-mismatch error."""

    # Case 1: single-output model, no names requested -> default 'output'.
    test_input_1 = None
    test_expected_1 = ['output']

    # Case 2: single-output model, explicit name.
    test_input_2 = ['output_2']
    test_expected_2 = ['output_2']

    # Case 3: three-output model, no names requested -> default 'output_<i>'.
    test_input_3 = None
    test_expected_3 = ['output_0', 'output_1', 'output_2']

    # Case 4: three-output model, explicit names.
    test_input_4 = ['out', 'out_11', 'out_22']
    test_expected_4 = ['out', 'out_11', 'out_22']

    # Case 5: four names for a three-output model -> export must fail with
    # exactly this message (raised by the exporter's output-count check).
    test_input_5 = ['out', 'out_11', 'out_22', 'out_33']
    test_expected_5 = ("Mismatch between number of requested output names (['out', 'out_11', 'out_22', 'out_33']) and "
                       "model output count (3):\n")

    def representative_data_gen(self, shape=(3, 8, 8), num_inputs=1, batch_size=2, num_iter=1):
        """Yield `num_iter` batches of `num_inputs` random tensors, each of
        shape (batch_size, *shape), as the exporter's representative dataset."""
        for _ in range(num_iter):
            yield [torch.randn(batch_size, *shape)] * num_inputs

    def get_exporter(self, model, save_model_path):
        """Build a FakelyQuantONNXPyTorchExporter for `model` that writes to
        `save_model_path` using the default opset version."""
        return FakelyQuantONNXPyTorchExporter(model,
                                              is_pytorch_layer_exportable,
                                              save_model_path,
                                              self.representative_data_gen,
                                              onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION)

    def export_model(self, model, save_model_path, output_names, expected_output_names):
        """Export `model` to ONNX and assert the saved graph's output names
        match `expected_output_names` exactly (count, order, spelling)."""
        exporter = self.get_exporter(model, save_model_path)

        exporter.export(output_names)

        assert save_model_path.exists(), "ONNX file was not created"
        assert save_model_path.stat().st_size > 0, "ONNX file is empty"

        # Load the ONNX model and check outputs
        onnx_model = onnx.load(str(save_model_path))
        outputs = onnx_model.graph.output

        # Check number of outputs
        assert len(outputs) == len(
            expected_output_names), f"Expected {len(expected_output_names)} output, but found {len(outputs)}"

        found_output_names = [output.name for output in outputs]
        assert found_output_names == expected_output_names, (
            f"Expected output name '{expected_output_names}' found {found_output_names}"
        )

    @pytest.mark.parametrize(
        ("model", "output_names", "expected_output_names"), [
            (SingleOutputModel(), test_input_1, test_expected_1),
            (SingleOutputModel(), test_input_2, test_expected_2),
            (MultipleOutputModel(), test_input_3, test_expected_3),
            (MultipleOutputModel(), test_input_4, test_expected_4),
        ])
    def test_output_model_name(self, tmp_path, model, output_names, expected_output_names):
        """Exported ONNX graphs carry the requested (or default) output names."""
        save_model_path = tmp_path / "model.onnx"
        set_model(model)

        self.export_model(model, save_model_path, output_names=output_names,
                          expected_output_names=expected_output_names)

    @pytest.mark.parametrize(
        ("model", "output_names", "expected_output_names"), [
            (MultipleOutputModel(), test_input_5, test_expected_5),
        ])
    def test_wrong_number_output_model_name(self, tmp_path, model, output_names, expected_output_names):
        """A name-count mismatch must make export raise with the exact message.

        Bug fix: the original wrapped the call in try/except and asserted only
        inside `except`, so an export that wrongly succeeded made the test pass
        silently. `pytest.raises` fails the test if no exception is raised.
        """
        save_model_path = tmp_path / "model.onnx"
        set_model(model)

        with pytest.raises(Exception) as exc_info:
            self.export_model(model, save_model_path, output_names=output_names,
                              expected_output_names=expected_output_names)
        # The exporter's count check raises with exactly this message.
        assert str(exc_info.value) == expected_output_names

0 commit comments

Comments
 (0)