Skip to content

Commit cd3f0ee

Browse files
authored
Fix keyword argument passing in pystow.utils.download() (#139)
Closes #138
1 parent 0f18a15 commit cd3f0ee

File tree

2 files changed

+35
-19
lines changed

2 files changed

+35
-19
lines changed

src/pystow/impl.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -279,20 +279,12 @@ def get_rhea_version() -> str:
279279
path, version = self.join(
280280
*subkeys, name=name, version=version, ensure_exists=True, return_version=True
281281
)
282-
_download_kwargs: dict[str, Any] = {}
283-
if version:
284-
_download_kwargs["desc"] = f"Downloading {path.name} v{version}"
285-
download_kwargs = dict(download_kwargs or {})
286-
if version:
287-
if "tqdm_kwargs" not in download_kwargs:
288-
download_kwargs["tqdm_kwargs"] = {}
289-
if "desc" not in download_kwargs["tqdm_kwargs"]:
290-
download_kwargs["tqdm_kwargs"]["desc"] = f"Downloading {path.name} v{version}"
291282
utils.download(
292283
url=url,
293284
path=path,
294285
force=force,
295-
**_download_kwargs,
286+
_version=version,
287+
**(download_kwargs or {}),
296288
)
297289
return path
298290

src/pystow/utils/download.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
from collections.abc import Mapping
99
from functools import partial
1010
from pathlib import Path
11-
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
11+
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, TypedDict
1212
from urllib.request import urlretrieve
1313

1414
import requests
1515
from tqdm import tqdm
16+
from typing_extensions import NotRequired, Unpack
1617

1718
from .hashing import raise_on_digest_mismatch
19+
from ..constants import TimeoutHint
1820

1921
if TYPE_CHECKING:
2022
import botocore.client
@@ -58,7 +60,19 @@ def update_to(
5860
self.update(blocks * block_size - self.n) # will also set self.n = b * bsize
5961

6062

61-
def download(
63+
class RequestKwargs(TypedDict):
64+
"""Keyword arguments for :func:`requests.get`."""
65+
66+
auth: NotRequired[tuple[str, str]]
67+
timeout: NotRequired[TimeoutHint]
68+
allow_redirects: NotRequired[bool]
69+
proxies: NotRequired[dict[str, str]]
70+
verify: NotRequired[bool]
71+
stream: NotRequired[bool]
72+
cert: NotRequired[str | tuple[str, str]]
73+
74+
75+
def download( # noqa:C901
6276
url: str,
6377
path: str | Path,
6478
force: bool = True,
@@ -69,7 +83,8 @@ def download(
6983
hexdigests_strict: bool = False,
7084
progress_bar: bool = True,
7185
tqdm_kwargs: Mapping[str, Any] | None = None,
72-
**kwargs: Any,
86+
_version: str | None = None,
87+
**kwargs: Unpack[RequestKwargs],
7388
) -> None:
7489
"""Download a file from a given URL.
7590
@@ -83,14 +98,15 @@ def download(
8398
pairs.
8499
:param hexdigests_remote: The expected hexdigests as (algorithm_name, url to file
85100
with expected hexdigest) pairs.
86-
:param hexdigests_strict: Set this to false to stop automatically checking for the
101+
:param hexdigests_strict: Set this to ``False`` to stop automatically checking for the
87102
`algorithm(filename)=hash` format
88103
:param progress_bar: Set to true to show a progress bar while downloading
89104
:param tqdm_kwargs: Override the default arguments passed to :class:`tadm.tqdm` when
90105
progress_bar is True.
91-
:param kwargs: The keyword arguments to pass to :func:`urllib.request.urlretrieve`
92-
or to `requests.get` depending on the backend chosen. If using 'requests'
93-
backend, `stream` is set to True by default.
106+
:param kwargs: If using :func:`urllib.request.urlretrieve`, there are no keyword
107+
arguments available. If using ``requests`` as a backend, passes these
108+
to :func:`requests.get`. If using ``requests`` as a backend, ``stream`` is
109+
set to True by default.
94110
95111
:raises Exception: Thrown if an error besides a keyboard interrupt is thrown during
96112
download
@@ -113,13 +129,17 @@ def download(
113129
logger.debug("did not re-download %s from %s", path, url)
114130
return
115131

132+
desc = f"Downloading {path.name}"
133+
if _version:
134+
desc += f" (v{_version})"
135+
116136
_tqdm_kwargs = {
117137
"unit": "B",
118138
"unit_scale": True,
119139
"unit_divisor": 1024,
120140
"miniters": 1,
121141
"disable": not progress_bar,
122-
"desc": f"Downloading {path.name}",
142+
"desc": desc,
123143
"leave": False,
124144
}
125145
if tqdm_kwargs:
@@ -128,9 +148,13 @@ def download(
128148
try:
129149
if backend == "urllib":
130150
logger.info("downloading with urllib from %s to %s", url, path)
151+
if kwargs:
152+
logger.warning(
153+
"no kwargs should be supplied when using urllib, skipping: %s", kwargs
154+
)
131155
with TqdmReportHook(**_tqdm_kwargs) as t:
132156
try:
133-
urlretrieve(url, path, reporthook=t.update_to, **kwargs) # noqa:S310
157+
urlretrieve(url, path, reporthook=t.update_to) # noqa:S310
134158
except urllib.error.URLError as e:
135159
raise DownloadError(backend, url, path, e) from e
136160
elif backend == "requests":

0 commit comments

Comments
 (0)