Skip to content

Commit 4f94171

Browse files
schmrlnghanouticelinaWauplin
committed
Pass through additional arguments from HfApi download utils (#3531)
Co-authored-by: célina <[email protected]> Co-authored-by: Lucain <[email protected]>
1 parent 31bccf5 commit 4f94171

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

src/huggingface_hub/hf_api.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5216,6 +5216,7 @@ def hf_hub_download(
52165216
etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
52175217
token: Union[bool, str, None] = None,
52185218
local_files_only: bool = False,
5219+
tqdm_class: Optional[type[base_tqdm]] = None,
52195220
dry_run: Literal[False] = False,
52205221
) -> str: ...
52215222

@@ -5234,6 +5235,7 @@ def hf_hub_download(
52345235
etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
52355236
token: Union[bool, str, None] = None,
52365237
local_files_only: bool = False,
5238+
tqdm_class: Optional[type[base_tqdm]] = None,
52375239
dry_run: Literal[True],
52385240
) -> DryRunFileInfo: ...
52395241

@@ -5252,6 +5254,7 @@ def hf_hub_download(
52525254
etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
52535255
token: Union[bool, str, None] = None,
52545256
local_files_only: bool = False,
5257+
tqdm_class: Optional[type[base_tqdm]] = None,
52555258
dry_run: bool = False,
52565259
) -> Union[str, DryRunFileInfo]:
52575260
"""Download a given file if it's not already present in the local cache.
@@ -5320,6 +5323,11 @@ def hf_hub_download(
53205323
local_files_only (`bool`, *optional*, defaults to `False`):
53215324
If `True`, avoid downloading the file and return the path to the
53225325
local cached file if it exists.
5326+
tqdm_class (`tqdm`, *optional*):
5327+
If provided, overwrites the default behavior for the progress bar. Passed
5328+
argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior.
5329+
Defaults to the custom HF progress bar that can be disabled by setting
5330+
`HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
53235331
dry_run (`bool`, *optional*, defaults to `False`):
53245332
If `True`, perform a dry run without actually downloading the file. Returns a
53255333
[`DryRunFileInfo`] object containing information about what would be downloaded.
@@ -5369,6 +5377,8 @@ def hf_hub_download(
53695377
token=token,
53705378
headers=self.headers,
53715379
local_files_only=local_files_only,
5380+
tqdm_class=tqdm_class,
5381+
dry_run=dry_run,
53725382
)
53735383

53745384
@validate_hf_hub_args
@@ -5388,7 +5398,8 @@ def snapshot_download(
53885398
ignore_patterns: Optional[Union[list[str], str]] = None,
53895399
max_workers: int = 8,
53905400
tqdm_class: Optional[type[base_tqdm]] = None,
5391-
) -> str:
5401+
dry_run: bool = False,
5402+
) -> Union[str, list[DryRunFileInfo]]:
53925403
"""Download repo files.
53935404
53945405
Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from
@@ -5443,9 +5454,14 @@ def snapshot_download(
54435454
Note that the `tqdm_class` is not passed to each individual download.
54445455
Defaults to the custom HF progress bar that can be disabled by setting
54455456
`HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
5457+
dry_run (`bool`, *optional*, defaults to `False`):
5458+
If `True`, perform a dry run without actually downloading the files. Returns a list of
5459+
[`DryRunFileInfo`] objects containing information about what would be downloaded.
54465460
54475461
Returns:
5448-
`str`: folder path of the repo snapshot.
5462+
`str` or list of [`DryRunFileInfo`]:
5463+
- If `dry_run=False`: Folder path of the repo snapshot.
5464+
- If `dry_run=True`: A list of [`DryRunFileInfo`] objects containing download information.
54495465
54505466
Raises:
54515467
[`~utils.RepositoryNotFoundError`]
@@ -5484,6 +5500,8 @@ def snapshot_download(
54845500
ignore_patterns=ignore_patterns,
54855501
max_workers=max_workers,
54865502
tqdm_class=tqdm_class,
5503+
headers=self.headers,
5504+
dry_run=dry_run,
54875505
)
54885506

54895507
def get_safetensors_metadata(

tests/test_hf_api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3359,6 +3359,8 @@ def test_hf_hub_download_alias(self, mock: Mock) -> None:
33593359
etag_timeout=10,
33603360
local_files_only=False,
33613361
headers=None,
3362+
tqdm_class=None,
3363+
dry_run=False,
33623364
)
33633365

33643366
@patch("huggingface_hub._snapshot_download.snapshot_download")
@@ -3385,6 +3387,8 @@ def test_snapshot_download_alias(self, mock: Mock) -> None:
33853387
ignore_patterns=None,
33863388
max_workers=8,
33873389
tqdm_class=None,
3390+
headers=None,
3391+
dry_run=False,
33883392
)
33893393

33903394

0 commit comments

Comments
 (0)