Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import torch.nn

from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper

from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
from model_compression_toolkit.verify_packages import FOUND_ONNX
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
Expand Down Expand Up @@ -118,14 +120,17 @@ def export(self, output_names=None) -> None:
dynamic_axes = {'input': {0: 'batch_size'}}
dynamic_axes.update({name: {0: 'batch_size'} for name in output_names})

input_names = [n.name for n in self.model.node_sort if n.type == DummyPlaceHolder]
dynamic_axes.update({name: {0: 'batch_size'} for name in input_names})

if hasattr(self.model, 'metadata'):
onnx_bytes = BytesIO()
torch.onnx.export(self.model,
tuple(model_input) if isinstance(model_input, list) else model_input,
onnx_bytes,
opset_version=self._onnx_opset_version,
verbose=False,
input_names=['input'],
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes)
onnx_model = onnx.load_from_string(onnx_bytes.getvalue())
Expand All @@ -137,7 +142,7 @@ def export(self, output_names=None) -> None:
self.save_model_path,
opset_version=self._onnx_opset_version,
verbose=False,
input_names=['input'],
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes)

Expand Down