Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import torch.nn

from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper

from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
from model_compression_toolkit.verify_packages import FOUND_ONNX
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
Expand Down Expand Up @@ -98,15 +100,17 @@ def export(self, output_names=None) -> None:
model_output = self.model(*model_input) if isinstance(model_input, (list, tuple)) else self.model(
model_input)

input_nodes = [n for n in self.model.node_sort if n.type == DummyPlaceHolder]
input_names = [f"input_{i}" for i in range(len(input_nodes))] if len(input_nodes) > 1 else ["input"]
dynamic_axes = {name: {0: 'batch_size'} for name in input_names}
if output_names is None:
# Determine number of outputs and prepare output_names and dynamic_axes
if isinstance(model_output, (list, tuple)):
output_names = [f"output_{i}" for i in range(len(model_output))]
dynamic_axes = {'input': {0: 'batch_size'}}
dynamic_axes.update({name: {0: 'batch_size'} for name in output_names})
else:
output_names = ['output']
dynamic_axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
dynamic_axes.update({'output': {0: 'batch_size'}})
else:
if isinstance(model_output, (list, tuple)):
num_of_outputs = len(model_output)
Expand All @@ -115,17 +119,15 @@ def export(self, output_names=None) -> None:
assert len(output_names) == num_of_outputs, (f"Mismatch between number of requested output names "
f"({output_names}) and model output count "
f"({num_of_outputs}):\n")
dynamic_axes = {'input': {0: 'batch_size'}}
dynamic_axes.update({name: {0: 'batch_size'} for name in output_names})

if hasattr(self.model, 'metadata'):
onnx_bytes = BytesIO()
torch.onnx.export(self.model,
tuple(model_input) if isinstance(model_input, list) else model_input,
onnx_bytes,
opset_version=self._onnx_opset_version,
verbose=False,
input_names=['input'],
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes)
onnx_model = onnx.load_from_string(onnx_bytes.getvalue())
Expand All @@ -137,7 +139,7 @@ def export(self, output_names=None) -> None:
self.save_model_path,
opset_version=self._onnx_opset_version,
verbose=False,
input_names=['input'],
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from model_compression_toolkit.logger import Logger


DEFAULT_ONNX_OPSET_VERSION = 15
DEFAULT_ONNX_OPSET_VERSION = 20


if FOUND_TORCH:
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
import torch
import torch.nn as nn

from model_compression_toolkit.core import QuantizationConfig
from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModel
from model_compression_toolkit.core.pytorch.utils import set_model
from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import \
FakelyQuantONNXPyTorchExporter
from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import DEFAULT_ONNX_OPSET_VERSION
from model_compression_toolkit.exporter.model_wrapper import is_pytorch_layer_exportable
from tests_pytest.pytorch_tests.torch_test_util.torch_test_mixin import BaseTorchIntegrationTest


class SingleOutputModel(nn.Module):
Expand All @@ -42,7 +45,16 @@ def forward(self, x):
return self.linear(x), x, x + 2


class TestONNXExporter:
class MultipleInputsModel(nn.Module):
    """Toy two-input model whose forward pass runs both inputs through one shared linear layer."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 5)

    def forward(self, input1, input2):
        # Apply the shared layer to each input separately, then sum elementwise.
        first = self.linear(input1)
        second = self.linear(input2)
        return first + second


class TestONNXExporter(BaseTorchIntegrationTest):
test_input_1 = None
test_expected_1 = ['output']

Expand All @@ -59,27 +71,38 @@ class TestONNXExporter:
test_expected_5 = ("Mismatch between number of requested output names (['out', 'out_11', 'out_22', 'out_33']) and "
"model output count (3):\n")

def representative_data_gen(self, shape=(3, 8, 8), num_inputs=1, batch_size=2, num_iter=1):
for _ in range(num_iter):
yield [torch.randn(batch_size, *shape)] * num_inputs
def representative_data_gen(self, num_inputs=1):
    """Build a representative-dataset callable for the exporter.

    The scraped diff interleaved the removed ``get_exporter`` helper into the
    middle of this method; this is the coherent merged version.

    Args:
        num_inputs: Number of model inputs each yielded batch must contain.

    Returns:
        A zero-argument callable that yields ``num_iter`` batches; each batch is
        a list of ``num_inputs`` references to one random tensor of shape
        ``(batch_size, 3, 8, 8)``.
    """
    batch_size, num_iter, shape = 2, 1, (3, 8, 8)

    def data_gen():
        # NOTE: list multiplication repeats the SAME tensor object for every
        # input, which is fine for shape-driven export tests.
        for _ in range(num_iter):
            yield [torch.randn(batch_size, *shape)] * num_inputs

    return data_gen

def export_model(self, model, save_model_path, output_names, expected_output_names):
exporter = self.get_exporter(model, save_model_path)
def get_pytorch_model(self, model, data_generator, minimal_tpc):
    """Run graph preparation on *model* and wrap the result as a float PyTorch model.

    Args:
        model: The torch.nn.Module under test.
        data_generator: Representative-dataset callable used during preparation.
        minimal_tpc: Target platform capabilities fixture for preparation.

    Returns:
        A FloatPyTorchModel built from the prepared graph.
    """
    prepared_graph = self.run_graph_preparation(model=model,
                                                datagen=data_generator,
                                                tpc=minimal_tpc,
                                                quant_config=QuantizationConfig())
    return FloatPyTorchModel(graph=prepared_graph)

def export_model(self, model, save_model_path, data_generator, output_names=None):
    """Export *model* to an ONNX file and return the re-loaded ONNX model.

    Args:
        model: Exportable (prepared) PyTorch model.
        save_model_path: Destination path of the .onnx file.
        data_generator: Representative-dataset callable for the exporter.
        output_names: Optional explicit output names; None lets the exporter decide.

    Returns:
        The exported model loaded back via ``onnx.load``.
    """
    onnx_exporter = FakelyQuantONNXPyTorchExporter(model,
                                                   is_pytorch_layer_exportable,
                                                   save_model_path,
                                                   data_generator,
                                                   onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION)
    onnx_exporter.export(output_names)

    # The exporter writes to disk; verify a non-empty file really appeared.
    assert save_model_path.exists(), "ONNX file was not created"
    assert save_model_path.stat().st_size > 0, "ONNX file is empty"

    # Load the ONNX model and check outputs
    return onnx.load(str(save_model_path))

def validate_outputs(self, onnx_model, expected_output_names):
outputs = onnx_model.graph.output

# Check number of outputs
Expand All @@ -98,23 +121,35 @@ def export_model(self, model, save_model_path, output_names, expected_output_nam
(MultipleOutputModel(), test_input_3, test_expected_3),
(MultipleOutputModel(), test_input_4, test_expected_4),
])
def test_output_model_name(self, tmp_path, model, output_names, expected_output_names):
def test_output_model_name(self, tmp_path, model, output_names, expected_output_names, minimal_tpc):
save_model_path = tmp_path / "model.onnx"
set_model(model)

self.export_model(model, save_model_path, output_names=output_names,
expected_output_names=expected_output_names)
data_generator = self.representative_data_gen(num_inputs=1)
pytorch_model = self.get_pytorch_model(model, data_generator, minimal_tpc)
onnx_model = self.export_model(pytorch_model, save_model_path, data_generator, output_names=output_names)
self.validate_outputs(onnx_model, expected_output_names)

@pytest.mark.parametrize(
("model", "output_names", "expected_output_names"), [
(MultipleOutputModel(), test_input_5, test_expected_5),
])
def test_wrong_number_output_model_name(self, tmp_path, model, output_names, expected_output_names, minimal_tpc):
    """Exporting with a mismatched number of output names must raise the expected error.

    The scraped diff left the removed pre-merge signature and call site
    interleaved here; this is the coherent merged version.

    NOTE(review): if no exception is raised the test silently passes through
    validate_outputs — consider pytest.raises to assert the failure explicitly.
    """
    save_model_path = tmp_path / "model.onnx"
    set_model(model)

    data_generator = self.representative_data_gen(num_inputs=1)
    pytorch_model = self.get_pytorch_model(model, data_generator, minimal_tpc)
    try:
        onnx_model = self.export_model(pytorch_model, save_model_path, data_generator, output_names=output_names)
        self.validate_outputs(onnx_model, expected_output_names)
    except Exception as e:
        # The exporter's mismatch message must match the parametrized expectation.
        assert expected_output_names == str(e)

def test_multiple_inputs(self, minimal_tpc, tmp_path):
    """
    Verify that a model with multiple inputs is exported to an ONNX file properly and
    that the exported ONNX graph exposes all expected inputs.
    """
    model = MultipleInputsModel()
    data_generator = self.representative_data_gen(num_inputs=2)
    float_model = self.get_pytorch_model(model, data_generator, minimal_tpc)

    output_path = tmp_path / "model.onnx"
    exported = self.export_model(float_model, output_path, data_generator)

    # The ONNX graph should carry one named input per model input, in order.
    exported_input_names = [graph_input.name for graph_input in exported.graph.input]
    assert exported_input_names == ["input1", "input2"]
Loading