@@ -100,15 +100,16 @@ def export(self, output_names=None) -> None:
100100 model_output = self .model (* model_input ) if isinstance (model_input , (list , tuple )) else self .model (
101101 model_input )
102102
103+ input_names = [n .name for n in self .model .node_sort if n .type == DummyPlaceHolder ]
104+ dynamic_axes = {name : {0 : 'batch_size' } for name in input_names }
103105 if output_names is None :
104106 # Determine number of outputs and prepare output_names and dynamic_axes
105107 if isinstance (model_output , (list , tuple )):
106108 output_names = [f"output_{ i } " for i in range (len (model_output ))]
107- dynamic_axes = {'input' : {0 : 'batch_size' }}
108109 dynamic_axes .update ({name : {0 : 'batch_size' } for name in output_names })
109110 else :
110111 output_names = ['output' ]
111- dynamic_axes = { 'input' : { 0 : 'batch_size' }, ' output' : {0 : 'batch_size' }}
112+ dynamic_axes . update ({ ' output' : {0 : 'batch_size' }})
112113 else :
113114 if isinstance (model_output , (list , tuple )):
114115 num_of_outputs = len (model_output )
@@ -117,12 +118,8 @@ def export(self, output_names=None) -> None:
117118 assert len (output_names ) == num_of_outputs , (f"Mismatch between number of requested output names "
118119 f"({ output_names } ) and model output count "
119120 f"({ num_of_outputs } ):\n " )
120- dynamic_axes = {'input' : {0 : 'batch_size' }}
121121 dynamic_axes .update ({name : {0 : 'batch_size' } for name in output_names })
122-
123- input_names = [n .name for n in self .model .node_sort if n .type == DummyPlaceHolder ]
124- dynamic_axes .update ({name : {0 : 'batch_size' } for name in input_names })
125-
122+ dynamic_axes .update ({"input" : {0 : 'batch_size' }})
126123 if hasattr (self .model , 'metadata' ):
127124 onnx_bytes = BytesIO ()
128125 torch .onnx .export (self .model ,
0 commit comments