Skip to content

Commit 022b718

Browse files
authored
Merge pull request #482 from nbcsm/retry
support retry for url get
2 parents 411c9f6 + 01ab5b2 commit 022b718

File tree

3 files changed

+29
-9
lines changed

3 files changed

+29
-9
lines changed

setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def initialize_options(self):
3535

3636
def finalize_options(self):
3737
pass
38-
38+
3939
def run(self):
4040
with open(os.path.join(SRC_DIR, 'version.py'), 'w') as f:
4141
f.write(dedent('''
@@ -74,11 +74,11 @@ def run(self):
7474
version=VersionInfo.version,
7575
description='Tensorflow to ONNX converter',
7676
setup_requires=['pytest-runner'],
77-
tests_require=['graphviz', 'requests', 'parameterized', 'pytest', 'pytest-cov', 'pyyaml'],
77+
tests_require=['graphviz', 'parameterized', 'pytest', 'pytest-cov', 'pyyaml'],
7878
cmdclass=cmdclass,
7979
packages=find_packages(),
8080
8181
author_email='[email protected]',
8282
url='https://github.com/onnx/tensorflow-onnx',
83-
install_requires=['numpy>=1.14.1', 'onnx>=1.4.1', 'six']
83+
install_requires=['numpy>=1.14.1', 'onnx>=1.4.1', 'requests', 'six']
8484
)

tests/run_pretrained_models.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import PIL.Image
1919
import numpy as np
20-
import requests
2120
import six
2221
import tensorflow as tf
2322
# contrib ops are registered only when the module is imported, the following import statement is needed,
@@ -119,11 +118,7 @@ def download_file(self):
119118
os.makedirs(dir_name, exist_ok=True)
120119
fpath = os.path.join(dir_name, fname)
121120
if not os.path.exists(fpath):
122-
response = requests.get(url)
123-
if response.status_code not in [200]:
124-
response.raise_for_status()
125-
with open(fpath, "wb") as f:
126-
f.write(response.content)
121+
utils.get_url(url, fpath)
127122
model_path = os.path.join(dir_name, self.local)
128123
if not os.path.exists(model_path):
129124
if ftype == 'tgz':

tf2onnx/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,18 @@
1313
import re
1414
import shutil
1515
import tempfile
16+
17+
import requests
18+
from requests.adapters import HTTPAdapter
19+
from urllib3.util.retry import Retry
1620
import six
1721
import numpy as np
1822
import tensorflow as tf
1923
from tensorflow.core.framework import types_pb2, tensor_pb2
2024
from google.protobuf import text_format
2125
import onnx
2226
from onnx import helper, onnx_pb, defs, numpy_helper
27+
2328
from . import constants
2429

2530
#
@@ -468,3 +473,23 @@ def set_debug_mode(enabled):
468473

469474
def get_max_value(np_dtype):
470475
return np.iinfo(np_dtype).max
476+
477+
478+
def get_url(url, path, max_retries=5):
479+
""" Download url and save to path. """
480+
retries = Retry(total=max_retries, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504])
481+
adapter = HTTPAdapter(max_retries=retries)
482+
session = requests.Session()
483+
session.mount("http://", adapter)
484+
session.mount("https://", adapter)
485+
486+
response = session.get(url, allow_redirects=True)
487+
if response.status_code not in [200]:
488+
response.raise_for_status()
489+
490+
dir_name = os.path.dirname(path)
491+
if dir_name:
492+
os.makedirs(dir_name, exist_ok=True)
493+
494+
with open(path, "wb") as f:
495+
f.write(response.content)

0 commit comments

Comments
 (0)