|
25 | 25 | import tempfile |
26 | 26 | from threading import Lock |
27 | 27 | from typing import Optional |
28 | | -from urllib.parse import urljoin |
29 | | -import urllib.request |
| 28 | +from urllib.parse import urljoin, urlparse |
| 29 | + |
| 30 | +import requests |
30 | 31 |
|
31 | 32 | __all__ = ["DownloadManager"] |
32 | 33 |
|
@@ -186,17 +187,23 @@ def _retrieve_data(self, url: str, filename: str, dest: str = None, force: bool |
186 | 187 | str |
187 | 188 | The local path where the file was saved. |
188 | 189 | """ |
189 | | - local_path = "" |
190 | 190 | if dest is None: |
191 | | - dest = tempfile.gettempdir() # Use system temp directory if no destination is provided |
192 | | - local_path = Path(dest) / Path(filename).name |
193 | | - if not force and Path(local_path).is_file(): |
194 | | - return local_path |
195 | | - try: |
196 | | - local_path, _ = urllib.request.urlretrieve(url, filename=local_path) |
197 | | - except urllib.error.HTTPError: |
198 | | - raise FileNotFoundError(f"Failed to download {filename} from {url}, file does not exist.") |
199 | | - return local_path |
| 191 | + dest = tempfile.gettempdir() |
| 192 | + local_path = Path(dest) / Path(filename).name |
| 193 | + |
| 194 | + if not force and local_path.is_file(): |
| 195 | + return str(local_path) |
| 196 | + |
| 197 | + parsed_url = urlparse(url) |
| 198 | + if parsed_url.scheme not in ("http", "https"): |
| 199 | + raise ValueError(f"Unsafe URL scheme: {parsed_url.scheme}") |
| 200 | + |
| 201 | + response = requests.get(url, timeout=60) |
| 202 | + response.raise_for_status() |
| 203 | + |
| 204 | + Path(local_path).write_bytes(response.content) |
| 205 | + |
| 206 | + return str(local_path) |
200 | 207 |
|
201 | 208 |
|
202 | 209 | # Create a singleton instance of DownloadManager |
|
0 commit comments