|
10 | 10 | import requests |
11 | 11 | import torch |
12 | 12 | from PIL import Image |
| 13 | +from requests.adapters import HTTPAdapter |
| 14 | +from urllib3.util.retry import Retry |
13 | 15 |
|
14 | 16 | from swift.utils import get_env_args |
15 | 17 |
|
@@ -105,12 +107,19 @@ def load_file(path: Union[str, bytes, _T]) -> Union[BytesIO, _T]: |
105 | 107 | if isinstance(path, str): |
106 | 108 | path = path.strip() |
107 | 109 | if path.startswith('http'): |
108 | | - request_kwargs = {} |
109 | | - timeout = float(os.getenv('TIMEOUT', '300')) |
110 | | - if timeout > 0: |
111 | | - request_kwargs['timeout'] = timeout |
112 | | - content = requests.get(path, **request_kwargs).content |
113 | | - res = BytesIO(content) |
| 110 | + retries = Retry(total=3, backoff_factor=1, allowed_methods=['GET']) |
| 111 | + with requests.Session() as session: |
| 112 | + session.mount('http://', HTTPAdapter(max_retries=retries)) |
| 113 | + session.mount('https://', HTTPAdapter(max_retries=retries)) |
| 114 | + |
| 115 | + timeout = float(os.getenv('SWIFT_TIMEOUT', '20')) |
| 116 | + request_kwargs = {'timeout': timeout} if timeout > 0 else {} |
| 117 | + |
| 118 | + response = session.get(path, **request_kwargs) |
| 119 | + response.raise_for_status() |
| 120 | + content = response.content |
| 121 | + res = BytesIO(content) |
| 122 | + |
114 | 123 | elif os.path.exists(path) or (not path.startswith('data:') and len(path) <= 200): |
115 | 124 | ROOT_IMAGE_DIR = get_env_args('ROOT_IMAGE_DIR', str, None) |
116 | 125 | if ROOT_IMAGE_DIR is not None and not os.path.exists(path): |
|
0 commit comments