12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import contextlib
15
16
import functools
16
17
import requests
17
18
import ssl
19
+ import sys
20
+ import threading
18
21
import time
19
22
import types
20
23
24
+ from pathlib import Path
25
+ from typing import Set
26
+
27
+ from openvino .model_zoo import _common , _concurrency , _reporting
21
28
from openvino .model_zoo .download_engine import cache
22
29
23
30
DOWNLOAD_TIMEOUT = 5 * 60
24
31
32
+
33
+ # There is no evidence that the requests.Session class is thread-safe,
34
+ # so for safety, we use one Session per thread. This class ensures that
35
+ # each thread gets its own Session.
36
+ class ThreadSessionFactory :
37
+ def __init__ (self , exit_stack ):
38
+ self ._lock = threading .Lock ()
39
+ self ._thread_local = threading .local ()
40
+ self ._exit_stack = exit_stack
41
+
42
+ def __call__ (self ):
43
+ try :
44
+ session = self ._thread_local .session
45
+ except AttributeError :
46
+ with self ._lock : # ExitStack might not be thread-safe either
47
+ session = self ._exit_stack .enter_context (requests .Session ())
48
+ self ._thread_local .session = session
49
+ return session
50
+
51
+
25
52
class Downloader :
26
- def __init__ (self , output_dir = None , cache_dir = None , num_attempts = 1 , timeout = DOWNLOAD_TIMEOUT ):
53
+ def __init__ (self , requested_precisions : str = None , output_dir : Path = None ,
54
+ cache_dir : Path = None , num_attempts : int = 1 , timeout : int = DOWNLOAD_TIMEOUT ):
27
55
self .output_dir = output_dir
28
56
self .cache = cache .NullCache () if cache_dir is None else cache .DirCache (cache_dir )
29
57
self .num_attempts = num_attempts
30
58
self .timeout = timeout
59
+ self .requested_precisions = requested_precisions
60
+
61
+ @property
62
+ def requested_precisions (self ) -> Set [str ]:
63
+ return self ._requested_precisions
64
+
65
+ @requested_precisions .setter
66
+ def requested_precisions (self , value : str = None ):
67
+ if value is None :
68
+ _requested_precisions = _common .KNOWN_PRECISIONS
69
+ else :
70
+ _requested_precisions = set (value .split (',' ))
71
+
72
+ unknown_precisions = _requested_precisions - _common .KNOWN_PRECISIONS
73
+ if unknown_precisions :
74
+ sys .exit ('Unknown precisions specified: {}.' .format (', ' .join (sorted (unknown_precisions ))))
75
+
76
+ self ._requested_precisions = _requested_precisions
31
77
32
78
def _process_download (self , reporter , chunk_iterable , size , progress , file ):
33
79
start_time = time .monotonic ()
@@ -72,12 +118,12 @@ def _try_download(self, reporter, file, start_download, size, hasher):
72
118
73
119
try :
74
120
reporter .job_context .check_interrupted ()
75
- chunk_iterable , continue_offset = start_download (offset = progress .size , timeout = self .timeout )
121
+ chunk_iterable , continue_offset = start_download (offset = progress .size , timeout = self .timeout , size = size , checksum = hasher )
76
122
77
123
if continue_offset not in {0 , progress .size }:
78
124
# Somehow we neither restarted nor continued from where we left off.
79
125
# Try to restart.
80
- chunk_iterable , continue_offset = start_download (offset = 0 , timeout = self .timeout )
126
+ chunk_iterable , continue_offset = start_download (offset = 0 , timeout = self .timeout , size = size , checksum = hasher )
81
127
if continue_offset != 0 :
82
128
reporter .log_error ("Remote server refuses to send whole file, aborting" )
83
129
return None
@@ -86,7 +132,7 @@ def _try_download(self, reporter, file, start_download, size, hasher):
86
132
file .seek (0 )
87
133
file .truncate ()
88
134
progress .size = 0
89
- progress .hasher = hasher ()
135
+ progress .hasher = hasher . type ()
90
136
91
137
self ._process_download (reporter , chunk_iterable , size , progress , file )
92
138
@@ -131,6 +177,14 @@ def _try_update_cache(reporter, cache, hash, source):
131
177
except Exception :
132
178
reporter .log_warning ('Failed to update the cache' , exc_info = True )
133
179
180
+ @staticmethod
181
+ def make_reporter (progress_format : str , context = None ):
182
+ if context is None :
183
+ context = _reporting .DirectOutputContext ()
184
+ return _reporting .Reporter (context ,
185
+ enable_human_output = progress_format == 'text' ,
186
+ enable_json_output = progress_format == 'json' )
187
+
134
188
def _try_retrieve (self , reporter , destination , model_file , start_download ):
135
189
destination .parent .mkdir (parents = True , exist_ok = True )
136
190
@@ -142,7 +196,7 @@ def _try_retrieve(self, reporter, destination, model_file, start_download):
142
196
success = False
143
197
144
198
with destination .open ('w+b' ) as f :
145
- actual_hash = self ._try_download (reporter , f , start_download , model_file .size , model_file .checksum . type )
199
+ actual_hash = self ._try_download (reporter , f , start_download , model_file .size , model_file .checksum )
146
200
147
201
if actual_hash and cache .verify_hash (reporter , actual_hash , model_file .checksum .value , destination ):
148
202
self ._try_update_cache (reporter , self .cache , model_file .checksum .value , destination )
@@ -151,7 +205,10 @@ def _try_retrieve(self, reporter, destination, model_file, start_download):
151
205
reporter .print ()
152
206
return success
153
207
154
- def download_model (self , reporter , session_factory , requested_precisions , model , known_precisions ):
208
+ def _download_model (self , reporter , session_factory , model , known_precisions : set = None ):
209
+ if known_precisions is None :
210
+ known_precisions = _common .KNOWN_PRECISIONS
211
+
155
212
session = session_factory ()
156
213
157
214
reporter .print_group_heading ('Downloading {}' , model .name )
@@ -164,7 +221,7 @@ def download_model(self, reporter, session_factory, requested_precisions, model,
164
221
for model_file in model .files :
165
222
if len (model_file .name .parts ) == 2 :
166
223
p = model_file .name .parts [0 ]
167
- if p in known_precisions and p not in requested_precisions :
224
+ if p in known_precisions and p not in self . requested_precisions :
168
225
continue
169
226
170
227
model_file_reporter = reporter .with_event_context (model = model .name , model_file = model_file .name .as_posix ())
@@ -173,7 +230,8 @@ def download_model(self, reporter, session_factory, requested_precisions, model,
173
230
destination = output / model_file .name
174
231
175
232
if not self ._try_retrieve (model_file_reporter , destination , model_file ,
176
- functools .partial (model_file .source .start_download , session , cache .CHUNK_SIZE )):
233
+ functools .partial (model_file .source .start_download , session , cache .CHUNK_SIZE ,
234
+ size = model_file .size , checksum = model_file .checksum )):
177
235
try :
178
236
destination .unlink ()
179
237
except FileNotFoundError :
@@ -198,3 +256,25 @@ def download_model(self, reporter, session_factory, requested_precisions, model,
198
256
reporter .print ()
199
257
200
258
return True
259
+
260
+ def download_model (self , model , reporter , session ):
261
+ if model .model_stages :
262
+ results = []
263
+ for model_stage in model .model_stages :
264
+ results .append (self ._download_model (reporter , session , model_stage ))
265
+ return sum (results ) == len (model .model_stages )
266
+ else :
267
+ return self ._download_model (reporter , session , model )
268
+
269
+ def bulk_download_model (self , models , reporter , jobs : int , progress_format : str ) -> Set [str ]:
270
+ with contextlib .ExitStack () as exit_stack :
271
+ session_factory = ThreadSessionFactory (exit_stack )
272
+ if jobs == 1 :
273
+ results = [self .download_model (model , reporter , session_factory ) for model in models ]
274
+ else :
275
+ results = _concurrency .run_in_parallel (jobs ,
276
+ lambda context , model : self .download_model (model , self .make_reporter (progress_format , context ),
277
+ session_factory ),
278
+ models )
279
+
280
+ return {model .name for model , successful in zip (models , results ) if not successful }
0 commit comments