1818import torch .nn
1919
2020from mct_quantizers import PytorchActivationQuantizationHolder , PytorchQuantizationWrapper
21+
22+ from model_compression_toolkit .core .pytorch .reader .node_holders import DummyPlaceHolder
2123from model_compression_toolkit .verify_packages import FOUND_ONNX
2224from model_compression_toolkit .logger import Logger
2325from model_compression_toolkit .core .pytorch .utils import to_torch_tensor
@@ -118,14 +120,17 @@ def export(self, output_names=None) -> None:
118120 dynamic_axes = {'input' : {0 : 'batch_size' }}
119121 dynamic_axes .update ({name : {0 : 'batch_size' } for name in output_names })
120122
123+ input_names = [n .name for n in self .model .node_sort if n .type == DummyPlaceHolder ]
124+ dynamic_axes .update ({name : {0 : 'batch_size' } for name in input_names })
125+
121126 if hasattr (self .model , 'metadata' ):
122127 onnx_bytes = BytesIO ()
123128 torch .onnx .export (self .model ,
124129 tuple (model_input ) if isinstance (model_input , list ) else model_input ,
125130 onnx_bytes ,
126131 opset_version = self ._onnx_opset_version ,
127132 verbose = False ,
128- input_names = [ 'input' ] ,
133+ input_names = input_names ,
129134 output_names = output_names ,
130135 dynamic_axes = dynamic_axes )
131136 onnx_model = onnx .load_from_string (onnx_bytes .getvalue ())
@@ -137,7 +142,7 @@ def export(self, output_names=None) -> None:
137142 self .save_model_path ,
138143 opset_version = self ._onnx_opset_version ,
139144 verbose = False ,
140- input_names = [ 'input' ] ,
145+ input_names = input_names ,
141146 output_names = output_names ,
142147 dynamic_axes = dynamic_axes )
143148
0 commit comments