17
17
import importlib
18
18
import os
19
19
import sys
20
+ import copy
20
21
21
22
from pathlib import Path
22
23
38
39
}
39
40
40
41
42
+ class prepend_to_path :
43
+ def __init__ (self , paths ):
44
+ self ._preprended_paths = paths
45
+ self ._original_path = None
46
+
47
+ def __enter__ (self ):
48
+ self ._original_path = copy .deepcopy (sys .path )
49
+ sys .path = self ._preprended_paths + sys .path
50
+
51
+ def __exit__ (self , type , value , traceback ):
52
+ if self ._original_path is not None :
53
+ sys .path = self ._original_path
54
+
55
+
41
56
def is_sequence (element ):
42
57
return isinstance (element , (list , tuple ))
43
58
@@ -46,15 +61,18 @@ def shapes_arg(values):
46
61
"""Checks that the argument represents a tensor shape or a sequence of tensor shapes"""
47
62
shapes = ast .literal_eval (values )
48
63
if not is_sequence (shapes ):
49
- raise argparse .ArgumentTypeError ('{!r}: must be a sequence' .format (shapes ))
64
+ raise argparse .ArgumentTypeError (
65
+ '{!r}: must be a sequence' .format (shapes ))
50
66
if not all (is_sequence (shape ) for shape in shapes ):
51
67
shapes = (shapes , )
52
68
for shape in shapes :
53
69
if not is_sequence (shape ):
54
- raise argparse .ArgumentTypeError ('{!r}: must be a sequence' .format (shape ))
70
+ raise argparse .ArgumentTypeError (
71
+ '{!r}: must be a sequence' .format (shape ))
55
72
for value in shape :
56
73
if not isinstance (value , int ) or value < 0 :
57
- raise argparse .ArgumentTypeError ('Argument {!r} must be a positive integer' .format (value ))
74
+ raise argparse .ArgumentTypeError (
75
+ 'Argument {!r} must be a positive integer' .format (value ))
58
76
return shapes
59
77
60
78
@@ -72,7 +90,8 @@ def model_parameter(parameter):
72
90
def parse_args ():
73
91
"""Parse input arguments"""
74
92
75
- parser = argparse .ArgumentParser (description = 'Conversion of pretrained models from PyTorch to ONNX' )
93
+ parser = argparse .ArgumentParser (
94
+ description = 'Conversion of pretrained models from PyTorch to ONNX' )
76
95
77
96
parser .add_argument ('--model-name' , type = str , required = True ,
78
97
help = 'Model to convert. May be class name or name of constructor function' )
@@ -96,40 +115,40 @@ def parse_args():
96
115
help = 'Data type for inputs' )
97
116
parser .add_argument ('--conversion-param' , type = model_parameter , default = [], action = 'append' ,
98
117
help = 'Additional parameter for export' )
99
- parser .add_argument ('--opset_version' , type = int , default = 11 , help = 'The ONNX opset version' )
118
+ parser .add_argument ('--opset_version' , type = int ,
119
+ default = 11 , help = 'The ONNX opset version' )
100
120
return parser .parse_args ()
101
121
102
122
103
123
def load_model (model_name , weights , model_paths , module_name , model_params ):
104
124
"""Import model and load pretrained weights"""
105
125
106
- if model_paths :
107
- sys .path .extend (model_paths )
108
-
109
- try :
110
- module = importlib .import_module (module_name )
111
- creator = getattr (module , model_name )
112
- model = creator (** model_params )
113
- except ImportError as err :
114
- if model_paths :
115
- print ('Module {} in {} doesn\' t exist. Check import path and name' .format (
116
- model_name , os .pathsep .join (model_paths )))
117
- else :
118
- print ('Module {} doesn\' t exist. Check if it is installed' .format (model_name ))
119
- sys .exit (err )
120
- except AttributeError as err :
121
- print ('ERROR: Module {} contains no class or function with name {}!'
122
- .format (module_name , model_name ))
123
- sys .exit (err )
124
-
125
- try :
126
- if weights :
127
- model .load_state_dict (torch .load (weights , map_location = 'cpu' ))
128
- except RuntimeError as err :
129
- print ('ERROR: Weights from {} cannot be loaded for model {}! Check matching between model and weights' .format (
130
- weights , model_name ))
131
- sys .exit (err )
132
- return model
126
+ with prepend_to_path (model_paths ):
127
+ try :
128
+ module = importlib .import_module (module_name )
129
+ creator = getattr (module , model_name )
130
+ model = creator (** model_params )
131
+ except ImportError as err :
132
+ if model_paths :
133
+ print ('Module {} in {} doesn\' t exist. Check import path and name' .format (
134
+ model_name , os .pathsep .join (model_paths )))
135
+ else :
136
+ print (
137
+ 'Module {} doesn\' t exist. Check if it is installed' .format (model_name ))
138
+ sys .exit (err )
139
+ except AttributeError as err :
140
+ print ('ERROR: Module {} contains no class or function with name {}!'
141
+ .format (module_name , model_name ))
142
+ sys .exit (err )
143
+
144
+ try :
145
+ if weights :
146
+ model .load_state_dict (torch .load (weights , map_location = 'cpu' ))
147
+ except RuntimeError as err :
148
+ print ('ERROR: Weights from {} cannot be loaded for model {}! Check matching between model and weights' .format (
149
+ weights , model_name ))
150
+ sys .exit (err )
151
+ return model
133
152
134
153
135
154
@torch .no_grad ()
0 commit comments