Skip to content

Commit ef62bb3

Browse files
committed
[model tools]: fix adding to path order
1 parent f38c439 commit ef62bb3

File tree

1 file changed

+51
-32
lines changed

1 file changed

+51
-32
lines changed

tools/model_tools/src/openvino/model_zoo/internal_scripts/pytorch_to_onnx.py

Lines changed: 51 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import importlib
1818
import os
1919
import sys
20+
import copy
2021

2122
from pathlib import Path
2223

@@ -38,6 +39,20 @@
3839
}
3940

4041

42+
class prepend_to_path:
43+
def __init__(self, paths):
44+
self._preprended_paths = paths
45+
self._original_path = None
46+
47+
def __enter__(self):
48+
self._original_path = copy.deepcopy(sys.path)
49+
sys.path = self._preprended_paths + sys.path
50+
51+
def __exit__(self, type, value, traceback):
52+
if self._original_path is not None:
53+
sys.path = self._original_path
54+
55+
4156
def is_sequence(element):
4257
return isinstance(element, (list, tuple))
4358

@@ -46,15 +61,18 @@ def shapes_arg(values):
4661
"""Checks that the argument represents a tensor shape or a sequence of tensor shapes"""
4762
shapes = ast.literal_eval(values)
4863
if not is_sequence(shapes):
49-
raise argparse.ArgumentTypeError('{!r}: must be a sequence'.format(shapes))
64+
raise argparse.ArgumentTypeError(
65+
'{!r}: must be a sequence'.format(shapes))
5066
if not all(is_sequence(shape) for shape in shapes):
5167
shapes = (shapes, )
5268
for shape in shapes:
5369
if not is_sequence(shape):
54-
raise argparse.ArgumentTypeError('{!r}: must be a sequence'.format(shape))
70+
raise argparse.ArgumentTypeError(
71+
'{!r}: must be a sequence'.format(shape))
5572
for value in shape:
5673
if not isinstance(value, int) or value < 0:
57-
raise argparse.ArgumentTypeError('Argument {!r} must be a positive integer'.format(value))
74+
raise argparse.ArgumentTypeError(
75+
'Argument {!r} must be a positive integer'.format(value))
5876
return shapes
5977

6078

@@ -72,7 +90,8 @@ def model_parameter(parameter):
7290
def parse_args():
7391
"""Parse input arguments"""
7492

75-
parser = argparse.ArgumentParser(description='Conversion of pretrained models from PyTorch to ONNX')
93+
parser = argparse.ArgumentParser(
94+
description='Conversion of pretrained models from PyTorch to ONNX')
7695

7796
parser.add_argument('--model-name', type=str, required=True,
7897
help='Model to convert. May be class name or name of constructor function')
@@ -96,40 +115,40 @@ def parse_args():
96115
help='Data type for inputs')
97116
parser.add_argument('--conversion-param', type=model_parameter, default=[], action='append',
98117
help='Additional parameter for export')
99-
parser.add_argument('--opset_version', type=int, default=11, help='The ONNX opset version')
118+
parser.add_argument('--opset_version', type=int,
119+
default=11, help='The ONNX opset version')
100120
return parser.parse_args()
101121

102122

103123
def load_model(model_name, weights, model_paths, module_name, model_params):
104124
"""Import model and load pretrained weights"""
105125

106-
if model_paths:
107-
sys.path.extend(model_paths)
108-
109-
try:
110-
module = importlib.import_module(module_name)
111-
creator = getattr(module, model_name)
112-
model = creator(**model_params)
113-
except ImportError as err:
114-
if model_paths:
115-
print('Module {} in {} doesn\'t exist. Check import path and name'.format(
116-
model_name, os.pathsep.join(model_paths)))
117-
else:
118-
print('Module {} doesn\'t exist. Check if it is installed'.format(model_name))
119-
sys.exit(err)
120-
except AttributeError as err:
121-
print('ERROR: Module {} contains no class or function with name {}!'
122-
.format(module_name, model_name))
123-
sys.exit(err)
124-
125-
try:
126-
if weights:
127-
model.load_state_dict(torch.load(weights, map_location='cpu'))
128-
except RuntimeError as err:
129-
print('ERROR: Weights from {} cannot be loaded for model {}! Check matching between model and weights'.format(
130-
weights, model_name))
131-
sys.exit(err)
132-
return model
126+
with prepend_to_path(model_paths):
127+
try:
128+
module = importlib.import_module(module_name)
129+
creator = getattr(module, model_name)
130+
model = creator(**model_params)
131+
except ImportError as err:
132+
if model_paths:
133+
print('Module {} in {} doesn\'t exist. Check import path and name'.format(
134+
model_name, os.pathsep.join(model_paths)))
135+
else:
136+
print(
137+
'Module {} doesn\'t exist. Check if it is installed'.format(model_name))
138+
sys.exit(err)
139+
except AttributeError as err:
140+
print('ERROR: Module {} contains no class or function with name {}!'
141+
.format(module_name, model_name))
142+
sys.exit(err)
143+
144+
try:
145+
if weights:
146+
model.load_state_dict(torch.load(weights, map_location='cpu'))
147+
except RuntimeError as err:
148+
print('ERROR: Weights from {} cannot be loaded for model {}! Check matching between model and weights'.format(
149+
weights, model_name))
150+
sys.exit(err)
151+
return model
133152

134153

135154
@torch.no_grad()

0 commit comments

Comments
 (0)