Skip to content

Commit c12bb90

Browse files
committed
styling
1 parent 17f914f commit c12bb90

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

mltu/torch/callbacks.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -250,20 +250,21 @@ def on_train_end(self, logs=None):
250250

251251
class Model2onnx(Callback):
252252
"""Converts the model from PyTorch to ONNX format after training."""
253+
253254
def __init__(
254-
self,
255-
saved_model_path: str,
256-
input_shape: tuple,
257-
export_params: bool = True,
258-
opset_version: int = 14,
259-
do_constant_folding: bool = True,
260-
input_names: list = ['input'],
261-
output_names: list = ['output'],
262-
dynamic_axes: dict = {'input': {0 : 'batch_size'},
263-
'output': {0 : 'batch_size'}},
264-
verbose: bool = False,
265-
metadata: dict = None,
266-
) -> None:
255+
self,
256+
saved_model_path: str,
257+
input_shape: tuple,
258+
export_params: bool = True,
259+
opset_version: int = 14,
260+
do_constant_folding: bool = True,
261+
input_names: list = ['input'],
262+
output_names: list = ['output'],
263+
dynamic_axes: dict = {'input': {0: 'batch_size'},
264+
'output': {0: 'batch_size'}},
265+
verbose: bool = False,
266+
metadata: dict = None,
267+
) -> None:
267268
""" Converts the model from PyTorch to ONNX format after training.
268269
269270
Args:

0 commit comments

Comments
 (0)