Skip to content

Commit 2053127

Browse files
author
yarden-sony
committed
minor fix
1 parent e343846 commit 2053127

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,16 @@ def export(self, output_names=None) -> None:
100100
model_output = self.model(*model_input) if isinstance(model_input, (list, tuple)) else self.model(
101101
model_input)
102102

103+
input_names = [n.name for n in self.model.node_sort if n.type == DummyPlaceHolder]
104+
dynamic_axes = {name: {0: 'batch_size'} for name in input_names}
103105
if output_names is None:
104106
# Determine number of outputs and prepare output_names and dynamic_axes
105107
if isinstance(model_output, (list, tuple)):
106108
output_names = [f"output_{i}" for i in range(len(model_output))]
107-
dynamic_axes = {'input': {0: 'batch_size'}}
108109
dynamic_axes.update({name: {0: 'batch_size'} for name in output_names})
109110
else:
110111
output_names = ['output']
111-
dynamic_axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
112+
dynamic_axes.update({'output': {0: 'batch_size'}})
112113
else:
113114
if isinstance(model_output, (list, tuple)):
114115
num_of_outputs = len(model_output)
@@ -117,12 +118,8 @@ def export(self, output_names=None) -> None:
117118
assert len(output_names) == num_of_outputs, (f"Mismatch between number of requested output names "
118119
f"({output_names}) and model output count "
119120
f"({num_of_outputs}):\n")
120-
dynamic_axes = {'input': {0: 'batch_size'}}
121121
dynamic_axes.update({name: {0: 'batch_size'} for name in output_names})
122-
123-
input_names = [n.name for n in self.model.node_sort if n.type == DummyPlaceHolder]
124-
dynamic_axes.update({name: {0: 'batch_size'} for name in input_names})
125-
122+
dynamic_axes.update({"input": {0: 'batch_size'}})
126123
if hasattr(self.model, 'metadata'):
127124
onnx_bytes = BytesIO()
128125
torch.onnx.export(self.model,

0 commit comments

Comments
 (0)