Skip to content

Commit e343846

Browse files
author
yarden-sony
committed
dynamic batch axis in input nodes
1 parent 7a6da34 commit e343846

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import torch.nn
1919

2020
from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
21+
22+
from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
2123
from model_compression_toolkit.verify_packages import FOUND_ONNX
2224
from model_compression_toolkit.logger import Logger
2325
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
@@ -118,14 +120,17 @@ def export(self, output_names=None) -> None:
118120
dynamic_axes = {'input': {0: 'batch_size'}}
119121
dynamic_axes.update({name: {0: 'batch_size'} for name in output_names})
120122

123+
input_names = [n.name for n in self.model.node_sort if n.type == DummyPlaceHolder]
124+
dynamic_axes.update({name: {0: 'batch_size'} for name in input_names})
125+
121126
if hasattr(self.model, 'metadata'):
122127
onnx_bytes = BytesIO()
123128
torch.onnx.export(self.model,
124129
tuple(model_input) if isinstance(model_input, list) else model_input,
125130
onnx_bytes,
126131
opset_version=self._onnx_opset_version,
127132
verbose=False,
128-
input_names=['input'],
133+
input_names=input_names,
129134
output_names=output_names,
130135
dynamic_axes=dynamic_axes)
131136
onnx_model = onnx.load_from_string(onnx_bytes.getvalue())
@@ -137,7 +142,7 @@ def export(self, output_names=None) -> None:
137142
self.save_model_path,
138143
opset_version=self._onnx_opset_version,
139144
verbose=False,
140-
input_names=['input'],
145+
input_names=input_names,
141146
output_names=output_names,
142147
dynamic_axes=dynamic_axes)
143148

0 commit comments

Comments
 (0)