Skip to content

Commit bf03f50

Browse files
leramorozova and Julia Kamelina authored
Cvs 62354/adopt download engine for onezoo (#3087)
* add sha256 * more flexible file source processing * fix linter * fix linter * remove whitespaces Co-authored-by: Julia Kamelina <[email protected]> * remove whitespaces Co-authored-by: Julia Kamelina <[email protected]> * remove useless docs Co-authored-by: Julia Kamelina <[email protected]>
1 parent c2796b9 commit bf03f50

File tree

4 files changed

+121
-86
lines changed

4 files changed

+121
-86
lines changed

tools/model_tools/src/openvino/model_zoo/_configuration.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class Model:
5151
def __init__(
5252
self, name, subdirectory, files, postprocessing, mo_args, framework,
5353
description, license_url, precisions, quantization_output_precisions,
54-
task_type, conversion_to_onnx_args, composite_model_name
54+
task_type, conversion_to_onnx_args, converter_to_onnx, composite_model_name
5555
):
5656
self.name = name
5757
self.subdirectory = subdirectory
@@ -65,12 +65,17 @@ def __init__(
6565
self.quantization_output_precisions = quantization_output_precisions
6666
self.task_type = task_type
6767
self.conversion_to_onnx_args = conversion_to_onnx_args
68-
self.converter_to_onnx = _common.KNOWN_FRAMEWORKS[framework]
68+
self.converter_to_onnx = converter_to_onnx
6969
self.composite_model_name = composite_model_name
7070
self.model_stages = {}
7171

7272
@classmethod
73-
def deserialize(cls, model, name, subdirectory, composite_model_name):
73+
def deserialize(cls, model, name, subdirectory, composite_model_name, known_frameworks=None, known_task_types=None):
74+
if known_frameworks is None:
75+
known_frameworks = _common.KNOWN_FRAMEWORKS
76+
if known_task_types is None:
77+
known_task_types = _common.KNOWN_TASK_TYPES
78+
7479
with validation.deserialization_context('In model "{}"'.format(name)):
7580
if not RE_MODEL_NAME.fullmatch(name):
7681
raise validation.DeserializationError('Invalid name, must consist only of letters, digits or ._-')
@@ -93,10 +98,10 @@ def deserialize(cls, model, name, subdirectory, composite_model_name):
9398
postprocessings.append(postprocessing.Postproc.deserialize(postproc))
9499

95100
framework = validation.validate_string_enum('"framework"', model['framework'],
96-
_common.KNOWN_FRAMEWORKS.keys())
101+
known_frameworks.keys())
97102

98103
conversion_to_onnx_args = model.get('conversion_to_onnx_args', None)
99-
if _common.KNOWN_FRAMEWORKS[framework]:
104+
if known_frameworks[framework]:
100105
if not conversion_to_onnx_args:
101106
raise validation.DeserializationError('"conversion_to_onnx_args" is absent. '
102107
'Framework "{}" is supported only by conversion to ONNX.'
@@ -155,11 +160,11 @@ def deserialize(cls, model, name, subdirectory, composite_model_name):
155160
license_url = validation.validate_string('"license"', model['license'])
156161

157162
task_type = validation.validate_string_enum('"task_type"', model['task_type'],
158-
_common.KNOWN_TASK_TYPES)
163+
known_task_types)
159164

160165
return cls(name, subdirectory, files, postprocessings, mo_args, framework,
161166
description, license_url, precisions, quantization_output_precisions,
162-
task_type, conversion_to_onnx_args, composite_model_name)
167+
task_type, conversion_to_onnx_args, known_frameworks[framework], composite_model_name)
163168

164169
class CompositeModel:
165170
def __init__(self, name, subdirectory, task_type, model_stages, description, framework,
@@ -177,23 +182,29 @@ def __init__(self, name, subdirectory, task_type, model_stages, description, fra
177182
self.composite_model_name = composite_model_name
178183

179184
@classmethod
180-
def deserialize(cls, model, name, subdirectory, stages):
185+
def deserialize(cls, model, name, subdirectory, stages, known_frameworks=None, known_task_types=None):
186+
if known_frameworks is None:
187+
known_frameworks = _common.KNOWN_FRAMEWORKS
188+
if known_task_types is None:
189+
known_task_types = _common.KNOWN_TASK_TYPES
190+
181191
with validation.deserialization_context('In model "{}"'.format(name)):
182192
if not RE_MODEL_NAME.fullmatch(name):
183193
raise validation.DeserializationError('Invalid name, must consist only of letters, digits or ._-')
184194

185-
task_type = validation.validate_string_enum('"task_type"', model['task_type'], _common.KNOWN_TASK_TYPES)
195+
task_type = validation.validate_string_enum('"task_type"', model['task_type'], known_task_types)
186196

187197
description = validation.validate_string('"description"', model['description'])
188198

189199
license_url = validation.validate_string('"license"', model['license'])
190200

191201
framework = validation.validate_string_enum('"framework"', model['framework'],
192-
_common.KNOWN_FRAMEWORKS.keys())
202+
known_frameworks)
193203

194204
model_stages = []
195205
for model_subdirectory, model_part in stages.items():
196-
model_stages.append(Model.deserialize(model_part, model_subdirectory.name, model_subdirectory, name))
206+
model_stages.append(Model.deserialize(model_part, model_subdirectory.name, model_subdirectory, name,
207+
known_frameworks=known_frameworks, known_task_types=known_task_types))
197208

198209
quantization_output_precisions = model_stages[0].quantization_output_precisions
199210
precisions = model_stages[0].precisions

tools/model_tools/src/openvino/model_zoo/download_engine/downloader.py

Lines changed: 88 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,68 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import contextlib
1516
import functools
1617
import requests
1718
import ssl
19+
import sys
20+
import threading
1821
import time
1922
import types
2023

24+
from pathlib import Path
25+
from typing import Set
26+
27+
from openvino.model_zoo import _common, _concurrency, _reporting
2128
from openvino.model_zoo.download_engine import cache
2229

2330
DOWNLOAD_TIMEOUT = 5 * 60
2431

32+
33+
# There is no evidence that the requests.Session class is thread-safe,
34+
# so for safety, we use one Session per thread. This class ensures that
35+
# each thread gets its own Session.
36+
class ThreadSessionFactory:
    """Callable that hands each calling thread its own ``requests.Session``.

    A session is created lazily the first time a given thread calls the
    factory and is cached in thread-local storage, so later calls from the
    same thread return the same object. Every created session is entered
    into the supplied exit stack, which therefore owns its cleanup.
    """

    def __init__(self, exit_stack):
        # exit_stack: a contextlib.ExitStack-like object providing
        # enter_context(); sessions are registered on it as they are created.
        self._exit_stack = exit_stack
        self._thread_local = threading.local()
        self._lock = threading.Lock()

    def __call__(self):
        """Return the calling thread's session, creating it on first use."""
        local = self._thread_local
        if not hasattr(local, 'session'):
            # Registration is serialized because the exit stack itself
            # might not be thread-safe.
            with self._lock:
                local.session = self._exit_stack.enter_context(requests.Session())
        return local.session
50+
51+
2552
class Downloader:
26-
def __init__(self, output_dir=None, cache_dir=None, num_attempts=1, timeout=DOWNLOAD_TIMEOUT):
53+
def __init__(self, requested_precisions: str = None, output_dir: Path = None,
54+
cache_dir: Path = None, num_attempts: int = 1, timeout: int = DOWNLOAD_TIMEOUT):
2755
self.output_dir = output_dir
2856
self.cache = cache.NullCache() if cache_dir is None else cache.DirCache(cache_dir)
2957
self.num_attempts = num_attempts
3058
self.timeout = timeout
59+
self.requested_precisions = requested_precisions
60+
61+
@property
def requested_precisions(self) -> Set[str]:
    """Set of precision names this downloader was asked to fetch."""
    return self._requested_precisions

@requested_precisions.setter
def requested_precisions(self, value: str = None):
    """Parse and validate a comma-separated precision list.

    ``None`` selects every entry of ``_common.KNOWN_PRECISIONS``. Otherwise
    ``value`` is split on commas (no whitespace trimming is performed) and
    each item must be a known precision; if any is not, the process exits
    via ``sys.exit`` with a message listing the unknown names in sorted
    order.
    """
    known = _common.KNOWN_PRECISIONS
    requested = known if value is None else set(value.split(','))

    unrecognized = requested - known
    if unrecognized:
        listing = ', '.join(sorted(unrecognized))
        sys.exit('Unknown precisions specified: {}.'.format(listing))

    self._requested_precisions = requested
3177

3278
def _process_download(self, reporter, chunk_iterable, size, progress, file):
3379
start_time = time.monotonic()
@@ -72,12 +118,12 @@ def _try_download(self, reporter, file, start_download, size, hasher):
72118

73119
try:
74120
reporter.job_context.check_interrupted()
75-
chunk_iterable, continue_offset = start_download(offset=progress.size, timeout=self.timeout)
121+
chunk_iterable, continue_offset = start_download(offset=progress.size, timeout=self.timeout, size=size, checksum=hasher)
76122

77123
if continue_offset not in {0, progress.size}:
78124
# Somehow we neither restarted nor continued from where we left off.
79125
# Try to restart.
80-
chunk_iterable, continue_offset = start_download(offset=0, timeout=self.timeout)
126+
chunk_iterable, continue_offset = start_download(offset=0, timeout=self.timeout, size=size, checksum=hasher)
81127
if continue_offset != 0:
82128
reporter.log_error("Remote server refuses to send whole file, aborting")
83129
return None
@@ -86,7 +132,7 @@ def _try_download(self, reporter, file, start_download, size, hasher):
86132
file.seek(0)
87133
file.truncate()
88134
progress.size = 0
89-
progress.hasher = hasher()
135+
progress.hasher = hasher.type()
90136

91137
self._process_download(reporter, chunk_iterable, size, progress, file)
92138

@@ -131,6 +177,14 @@ def _try_update_cache(reporter, cache, hash, source):
131177
except Exception:
132178
reporter.log_warning('Failed to update the cache', exc_info=True)
133179

180+
@staticmethod
def make_reporter(progress_format: str, context=None):
    """Build a ``_reporting.Reporter`` for the given progress format.

    progress_format: ``'text'`` enables human-readable output, ``'json'``
        enables JSON output; any other value enables neither.
    context: output context handed to the reporter; when ``None`` a new
        ``_reporting.DirectOutputContext`` is created.
    """
    output_context = _reporting.DirectOutputContext() if context is None else context
    wants_text = progress_format == 'text'
    wants_json = progress_format == 'json'
    return _reporting.Reporter(output_context,
                               enable_human_output=wants_text,
                               enable_json_output=wants_json)
187+
134188
def _try_retrieve(self, reporter, destination, model_file, start_download):
135189
destination.parent.mkdir(parents=True, exist_ok=True)
136190

@@ -142,7 +196,7 @@ def _try_retrieve(self, reporter, destination, model_file, start_download):
142196
success = False
143197

144198
with destination.open('w+b') as f:
145-
actual_hash = self._try_download(reporter, f, start_download, model_file.size, model_file.checksum.type)
199+
actual_hash = self._try_download(reporter, f, start_download, model_file.size, model_file.checksum)
146200

147201
if actual_hash and cache.verify_hash(reporter, actual_hash, model_file.checksum.value, destination):
148202
self._try_update_cache(reporter, self.cache, model_file.checksum.value, destination)
@@ -151,7 +205,10 @@ def _try_retrieve(self, reporter, destination, model_file, start_download):
151205
reporter.print()
152206
return success
153207

154-
def download_model(self, reporter, session_factory, requested_precisions, model, known_precisions):
208+
def _download_model(self, reporter, session_factory, model, known_precisions: set = None):
209+
if known_precisions is None:
210+
known_precisions = _common.KNOWN_PRECISIONS
211+
155212
session = session_factory()
156213

157214
reporter.print_group_heading('Downloading {}', model.name)
@@ -164,7 +221,7 @@ def download_model(self, reporter, session_factory, requested_precisions, model,
164221
for model_file in model.files:
165222
if len(model_file.name.parts) == 2:
166223
p = model_file.name.parts[0]
167-
if p in known_precisions and p not in requested_precisions:
224+
if p in known_precisions and p not in self.requested_precisions:
168225
continue
169226

170227
model_file_reporter = reporter.with_event_context(model=model.name, model_file=model_file.name.as_posix())
@@ -173,7 +230,8 @@ def download_model(self, reporter, session_factory, requested_precisions, model,
173230
destination = output / model_file.name
174231

175232
if not self._try_retrieve(model_file_reporter, destination, model_file,
176-
functools.partial(model_file.source.start_download, session, cache.CHUNK_SIZE)):
233+
functools.partial(model_file.source.start_download, session, cache.CHUNK_SIZE,
234+
size=model_file.size, checksum=model_file.checksum)):
177235
try:
178236
destination.unlink()
179237
except FileNotFoundError:
@@ -198,3 +256,25 @@ def download_model(self, reporter, session_factory, requested_precisions, model,
198256
reporter.print()
199257

200258
return True
259+
260+
def download_model(self, model, reporter, session):
    """Download a model, or every stage of a composite model.

    model: object exposing ``model_stages``; when that collection is
        non-empty each stage is downloaded in order, otherwise ``model``
        itself is downloaded.
    reporter: reporter forwarded to ``_download_model``.
    session: session factory forwarded to ``_download_model`` (it is
        called there to obtain the actual session).

    Returns the result of ``_download_model`` for a plain model. For a
    composite model, returns True only if every stage reported success.
    """
    if not model.model_stages:
        return self._download_model(reporter, session, model)

    # Build the full list first so that every stage is attempted even after
    # an earlier one fails; all() then checks for overall success.
    stage_results = [self._download_model(reporter, session, stage)
                     for stage in model.model_stages]
    return all(stage_results)
268+
269+
def bulk_download_model(self, models, reporter, jobs: int, progress_format: str) -> Set[str]:
    """Download several models and return the names of those that failed.

    models: collection of models; it is iterated once to download and once
        more to pair each model with its result.
    reporter: reporter used when downloading sequentially (``jobs == 1``).
    jobs: number of parallel jobs; any value other than 1 goes through
        ``_concurrency.run_in_parallel``.
    progress_format: passed to ``make_reporter`` to build a per-job
        reporter in the parallel case.
    """
    with contextlib.ExitStack() as exit_stack:
        session_factory = ThreadSessionFactory(exit_stack)

        def download_in_job(context, model):
            # Each parallel job reports through its own context-bound reporter.
            job_reporter = self.make_reporter(progress_format, context)
            return self.download_model(model, job_reporter, session_factory)

        if jobs != 1:
            results = _concurrency.run_in_parallel(jobs, download_in_job, models)
        else:
            results = [self.download_model(model, reporter, session_factory) for model in models]

    failed_names = set()
    for model, successful in zip(models, results):
        if not successful:
            failed_names.add(model.name)
    return failed_names

tools/model_tools/src/openvino/model_zoo/download_engine/file_source.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(self, url):
6767
def deserialize(cls, source):
6868
return cls(validation.validate_string('"url"', source['url']))
6969

70-
def start_download(self, session, chunk_size, offset, timeout):
70+
def start_download(self, session, chunk_size, offset, timeout, **kwargs):
7171
response = session.get(self.url, stream=True, timeout=timeout,
7272
headers=self.http_range_headers(offset))
7373
response.raise_for_status()
@@ -84,7 +84,7 @@ def __init__(self, id):
8484
def deserialize(cls, source):
8585
return cls(validation.validate_string('"id"', source['id']))
8686

87-
def start_download(self, session, chunk_size, offset, timeout):
87+
def start_download(self, session, chunk_size, offset, timeout, **kwargs):
8888
range_headers = self.http_range_headers(offset)
8989
URL = 'https://docs.google.com/uc?export=download'
9090
response = session.get(URL, params={'id': self.id}, headers=range_headers,

0 commit comments

Comments
 (0)