Commit 8188f2d: Add API to choose output_names when exporting using onnx

1 parent: 6e6350f
4 files changed: 61 additions & 9 deletions


model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py

Lines changed: 7 additions & 2 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Callable
+from typing import Callable, Optional, List
 from io import BytesIO
 
 import torch.nn
@@ -64,11 +64,14 @@ def __init__(self,
         self._use_onnx_custom_quantizer_ops = use_onnx_custom_quantizer_ops
         self._onnx_opset_version = onnx_opset_version
 
-    def export(self, output_names=None) -> None:
+    def export(self, output_names: Optional[List[str]] = None) -> None:
         """
         Convert an exportable (fully-quantized) PyTorch model to a fakely-quant model
         (namely, weights that are in fake-quant format) and fake-quant layers for the activations.
 
+        Args:
+            output_names (Optional[List[str]]): Optional list of output node names for export compatibility.
+
         Returns:
             Fake-quant PyTorch model.
         """
@@ -130,6 +133,8 @@ def export(self, output_names=None) -> None:
             output_names = ['output']
             dynamic_axes.update({'output': {0: 'batch_size'}})
         else:
+            assert isinstance(output_names, list), \
+                f"`output_names` must be a list, but got {type(output_names).__name__}"
             if isinstance(model_output, (list, tuple)):
                 num_of_outputs = len(model_output)
             else:
model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def __init__(self,
                          save_model_path,
                          repr_dataset)
 
-    def export(self) -> None:
+    def export(self, output_names=None) -> None:
         """
         Convert an exportable (fully-quantized) PyTorch model to a fakely-quant model
         (namely, weights that are in fake-quant format) and fake-quant layers for the activations.
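
Note: the TorchScript exporter appears to accept `output_names` only for signature compatibility; the facade below now calls `exporter.export(output_names=output_names)` for every serialization format, and this exporter ignores the argument.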

model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py

Lines changed: 6 additions & 3 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Callable
+from typing import Callable, Optional, List
 from packaging import version
 
 from model_compression_toolkit.verify_packages import FOUND_TORCH
@@ -49,7 +49,8 @@ def pytorch_export_model(model: torch.nn.Module,
                          is_layer_exportable_fn: Callable = is_pytorch_layer_exportable,
                          serialization_format: PytorchExportSerializationFormat = PytorchExportSerializationFormat.ONNX,
                          quantization_format: QuantizationFormat = QuantizationFormat.MCTQ,
-                         onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION) -> None:
+                         onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION,
+                         output_names: Optional[List[str]] = None) -> None:
     """
     Export a PyTorch quantized model to a torchscript or onnx model.
     The model will be saved to the path in save_model_path.
@@ -67,6 +68,8 @@ def pytorch_export_model(model: torch.nn.Module,
             PytorchExportSerializationFormat.ONNX).
         quantization_format: Format of how quantizers are exported (fakely-quant, int8, MCTQ quantizers).
         onnx_opset_version: ONNX opset version to use for exported ONNX model.
+        output_names (Optional[List[str]]): Optional list of output node names for export compatibility.
+            This argument is relevant only when using FakelyQuantONNXPyTorchExporter.
 
     """
     # Ensure 'metadata' is available directly on the model, if present in submodules
@@ -109,7 +112,7 @@ def pytorch_export_model(model: torch.nn.Module,
                          f'Unsupported serialization {serialization_format} was used to export Pytorch model.'
                          f' Please see API for supported formats.')  # pragma: no cover
 
-        exporter.export()
+        exporter.export(output_names=output_names)
 
 else:
     def pytorch_export_model(*args, **kwargs):
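
A hedged usage sketch of the new facade parameter, assuming the facade is exposed as `mct.exporter.pytorch_export_model` (MCT's public export entry point) and that `quantized_model` and `representative_dataset` already exist from a prior MCT quantization run:

# Usage sketch: exporting with a custom ONNX output name via the new parameter.
# `quantized_model` and `representative_dataset` are assumed to come from a
# prior MCT post-training-quantization flow; the save path is a placeholder.
import model_compression_toolkit as mct

mct.exporter.pytorch_export_model(
    model=quantized_model,
    save_model_path='./qmodel.onnx',
    repr_dataset=representative_dataset,
    serialization_format=mct.exporter.PytorchExportSerializationFormat.ONNX,
    output_names=['logits'])  # custom name for the single model output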

tests_pytest/pytorch_tests/e2e_tests/test_exporter.py

Lines changed: 47 additions & 3 deletions
@@ -199,14 +199,28 @@ def _run_mct_qat(self, float_model, rep_dataset, abits, a_qmethod):
         quantized_model = pytorch_quantization_aware_training_finalize_experimental(qat_ready_model)
         return quantized_model
 
-    def _run_exporter(self, quantized_model, rep_dataset, quantization_format):
+    def _run_exporter(self, quantized_model, rep_dataset, quantization_format, output_names=None):
         pytorch_export_model(quantized_model,
                              save_model_path=self.onnx_file,
                              repr_dataset=rep_dataset,
                              serialization_format=PytorchExportSerializationFormat.ONNX,
-                             quantization_format=quantization_format)
+                             quantization_format=quantization_format,
+                             output_names=output_names)
 
-        return onnx_reader(self.onnx_file, quantized_model.linear_activation_holder_quantizer.activation_holder_quantizer)
+        return onnx_reader(self.onnx_file,
+                           quantized_model.linear_activation_holder_quantizer.activation_holder_quantizer)
+
+    def _assert_outputs_names(self, output_names):
+        model = onnx.load(self.onnx_file)
+        exported_output_names = [output.name for output in model.graph.output]
+
+        if output_names is None:
+            if len(exported_output_names) == 1:
+                output_names = ['output']
+            else:
+                output_names = [f"output_{i}" for i in range(len(exported_output_names))]
+        assert all(name in exported_output_names for name in output_names)
+        assert len(output_names) == len(exported_output_names)
 
     def _assert_outputs_match(self, quantized_model, rep_dataset, quantization_format, tol=1e-8):
         pass
@@ -304,6 +318,17 @@ def test_mct_ptq_and_exporter_mctq(self, w_qmethod, abits, a_qmethod, tol):
         self._assert_quant_params_match(quantized_model, onnx_model_dict, a_qmethod, w_qmethod)
         self._assert_outputs_match(quantized_model, self.representative_dataset(1), QuantizationFormat.MCTQ, tol=tol)
 
+    @pytest.mark.parametrize('w_qmethod', [mctq.QuantizationMethod.POWER_OF_TWO])
+    @pytest.mark.parametrize('a_qmethod', [mctq.QuantizationMethod.SYMMETRIC])
+    @pytest.mark.parametrize('abits', [8, 16])
+    @pytest.mark.parametrize('output_names', [None, ['x']])
+    def test_mct_ptq_exporter_mctq_output_names(self, w_qmethod, abits, a_qmethod, output_names):
+        # set_seed(13)
+        quantized_model = self._run_mct(self.get_model(), self.representative_dataset(1), abits, a_qmethod, w_qmethod)
+        onnx_model_dict = self._run_exporter(quantized_model, self.representative_dataset(1), QuantizationFormat.MCTQ,
+                                             output_names=output_names)
+        self._assert_outputs_names(output_names=output_names)
+
     @pytest.mark.parametrize('abits, tol', ([8, 1e-4], [16, 1e-2]))
     def test_mct_ptq_and_exporter_fq(self, abits, tol):
         quantized_model = self._run_mct(self.get_model(), self.representative_dataset(1), abits, mctq.QuantizationMethod.POWER_OF_TWO)
@@ -363,6 +388,25 @@ def forward(self, x):
         self._run_exporter(quantized_model, self.representative_dataset(1), QuantizationFormat.MCTQ)
         self._assert_outputs_match(quantized_model, self.representative_dataset(1), QuantizationFormat.MCTQ)
 
+    @pytest.mark.parametrize('abits', [8, 16])
+    @pytest.mark.parametrize('output_names', [None, ['x', 'y']])
+    def test_multi_output_names_mct_and_exporter_mctq(self, abits, output_names):
+        class MultiOutputModel(torch.nn.Module):
+            def __init__(self, in_channels, out_channels):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_channels, out_channels)
+                self.linear_y = torch.nn.Linear(in_channels, out_channels)
+
+            def forward(self, x):
+                return self.linear(x), self.linear_y(x)
+
+        quantized_model = self._run_mct(MultiOutputModel(self.in_channels, self.out_channels),
+                                        self.representative_dataset(1),
+                                        abits, mctq.QuantizationMethod.POWER_OF_TWO)
+        self._run_exporter(quantized_model, self.representative_dataset(1), QuantizationFormat.MCTQ,
+                           output_names=output_names)
+        self._assert_outputs_names(output_names=output_names)
+
     @pytest.mark.parametrize('abits', [8, 16])
     def test_multi_input_output_mct_and_exporter_mctq(self, abits):
         class MultiInputOutputModel(torch.nn.Module):
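
For reference, the naming behavior asserted by `_assert_outputs_names` can be checked directly with the `onnx` package. A small sketch ('qmodel.onnx' is a placeholder path):

# Sketch: inspect the output names of an exported model with the onnx package.
import onnx

onnx_model = onnx.load('qmodel.onnx')  # placeholder path
print([o.name for o in onnx_model.graph.output])
# With output_names=['x', 'y'] this prints ['x', 'y']; with output_names=None,
# the defaults are 'output' (single output) or 'output_0', 'output_1', ...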
