Skip to content

Commit 59be8f2

Browse files
Wauplin and lhoestq authored
Fast download in hf file system (#2143)
* first draft
* Faster downloads in HfFileSystem
* Apply suggestions from code review

  Co-authored-by: Quentin Lhoest <[email protected]>

* speed-up get_file and read
* add tests
* fix test
* fix revision
* fix on windows
* simplified file.read
* Update src/huggingface_hub/hf_file_system.py

  Co-authored-by: Quentin Lhoest <[email protected]>

---------

Co-authored-by: Quentin Lhoest <[email protected]>
1 parent a42c629 commit 59be8f2

File tree

3 files changed

+239
-91
lines changed

3 files changed

+239
-91
lines changed

src/huggingface_hub/file_download.py

Lines changed: 78 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -404,6 +404,7 @@ def http_get(
404404
expected_size: Optional[int] = None,
405405
displayed_filename: Optional[str] = None,
406406
_nb_retries: int = 5,
407+
_tqdm_bar: Optional[tqdm] = None,
407408
) -> None:
408409
"""
409410
Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub.
@@ -483,84 +484,90 @@ def http_get(
483484
)
484485

485486
# Stream file to buffer
486-
with tqdm(
487-
unit="B",
488-
unit_scale=True,
489-
total=total,
490-
initial=resume_size,
491-
desc=displayed_filename,
492-
disable=True if (logger.getEffectiveLevel() == logging.NOTSET) else None,
493-
# ^ set `disable=None` rather than `disable=False` by default to disable progress bar when no TTY attached
494-
# see https://github.com/huggingface/huggingface_hub/pull/2000
495-
) as progress:
496-
if hf_transfer and total is not None and total > 5 * DOWNLOAD_CHUNK_SIZE:
497-
supports_callback = "callback" in inspect.signature(hf_transfer.download).parameters
498-
if not supports_callback:
499-
warnings.warn(
500-
"You are using an outdated version of `hf_transfer`. "
501-
"Consider upgrading to latest version to enable progress bars "
502-
"using `pip install -U hf_transfer`."
503-
)
504-
try:
505-
hf_transfer.download(
506-
url=url,
507-
filename=temp_file.name,
508-
max_files=HF_TRANSFER_CONCURRENCY,
509-
chunk_size=DOWNLOAD_CHUNK_SIZE,
510-
headers=headers,
511-
parallel_failures=3,
512-
max_retries=5,
513-
**({"callback": progress.update} if supports_callback else {}),
514-
)
515-
except Exception as e:
516-
raise RuntimeError(
517-
"An error occurred while downloading using `hf_transfer`. Consider"
518-
" disabling HF_HUB_ENABLE_HF_TRANSFER for better error handling."
519-
) from e
520-
if not supports_callback:
521-
progress.update(total)
522-
if expected_size is not None and expected_size != os.path.getsize(temp_file.name):
523-
raise EnvironmentError(
524-
consistency_error_message.format(
525-
actual_size=os.path.getsize(temp_file.name),
526-
)
527-
)
528-
return
529-
new_resume_size = resume_size
487+
progress = _tqdm_bar
488+
if progress is None:
489+
progress = tqdm(
490+
unit="B",
491+
unit_scale=True,
492+
total=total,
493+
initial=resume_size,
494+
desc=displayed_filename,
495+
disable=True if (logger.getEffectiveLevel() == logging.NOTSET) else None,
496+
# ^ set `disable=None` rather than `disable=False` by default to disable progress bar when no TTY attached
497+
# see https://github.com/huggingface/huggingface_hub/pull/2000
498+
)
499+
500+
if hf_transfer and total is not None and total > 5 * DOWNLOAD_CHUNK_SIZE:
501+
supports_callback = "callback" in inspect.signature(hf_transfer.download).parameters
502+
if not supports_callback:
503+
warnings.warn(
504+
"You are using an outdated version of `hf_transfer`. "
505+
"Consider upgrading to latest version to enable progress bars "
506+
"using `pip install -U hf_transfer`."
507+
)
530508
try:
531-
for chunk in r.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
532-
if chunk: # filter out keep-alive new chunks
533-
progress.update(len(chunk))
534-
temp_file.write(chunk)
535-
new_resume_size += len(chunk)
536-
# Some data has been downloaded from the server so we reset the number of retries.
537-
_nb_retries = 5
538-
except (requests.ConnectionError, requests.ReadTimeout) as e:
539-
# If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely
540-
# a transient error (network outage?). We log a warning message and try to resume the download a few times
541-
# before giving up. Tre retry mechanism is basic but should be enough in most cases.
542-
if _nb_retries <= 0:
543-
logger.warning("Error while downloading from %s: %s\nMax retries exceeded.", url, str(e))
544-
raise
545-
logger.warning("Error while downloading from %s: %s\nTrying to resume download...", url, str(e))
546-
time.sleep(1)
547-
reset_sessions() # In case of SSLError it's best to reset the shared requests.Session objects
548-
return http_get(
509+
hf_transfer.download(
549510
url=url,
550-
temp_file=temp_file,
551-
proxies=proxies,
552-
resume_size=new_resume_size,
553-
headers=initial_headers,
554-
expected_size=expected_size,
555-
_nb_retries=_nb_retries - 1,
511+
filename=temp_file.name,
512+
max_files=HF_TRANSFER_CONCURRENCY,
513+
chunk_size=DOWNLOAD_CHUNK_SIZE,
514+
headers=headers,
515+
parallel_failures=3,
516+
max_retries=5,
517+
**({"callback": progress.update} if supports_callback else {}),
556518
)
557-
558-
if expected_size is not None and expected_size != temp_file.tell():
519+
except Exception as e:
520+
raise RuntimeError(
521+
"An error occurred while downloading using `hf_transfer`. Consider"
522+
" disabling HF_HUB_ENABLE_HF_TRANSFER for better error handling."
523+
) from e
524+
if not supports_callback:
525+
progress.update(total)
526+
if expected_size is not None and expected_size != os.path.getsize(temp_file.name):
559527
raise EnvironmentError(
560528
consistency_error_message.format(
561-
actual_size=temp_file.tell(),
529+
actual_size=os.path.getsize(temp_file.name),
562530
)
563531
)
532+
return
533+
new_resume_size = resume_size
534+
try:
535+
for chunk in r.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
536+
if chunk: # filter out keep-alive new chunks
537+
progress.update(len(chunk))
538+
temp_file.write(chunk)
539+
new_resume_size += len(chunk)
540+
# Some data has been downloaded from the server so we reset the number of retries.
541+
_nb_retries = 5
542+
except (requests.ConnectionError, requests.ReadTimeout) as e:
543+
# If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely
544+
# a transient error (network outage?). We log a warning message and try to resume the download a few times
545+
# before giving up. Tre retry mechanism is basic but should be enough in most cases.
546+
if _nb_retries <= 0:
547+
logger.warning("Error while downloading from %s: %s\nMax retries exceeded.", url, str(e))
548+
raise
549+
logger.warning("Error while downloading from %s: %s\nTrying to resume download...", url, str(e))
550+
time.sleep(1)
551+
reset_sessions() # In case of SSLError it's best to reset the shared requests.Session objects
552+
return http_get(
553+
url=url,
554+
temp_file=temp_file,
555+
proxies=proxies,
556+
resume_size=new_resume_size,
557+
headers=initial_headers,
558+
expected_size=expected_size,
559+
_nb_retries=_nb_retries - 1,
560+
_tqdm_bar=_tqdm_bar,
561+
)
562+
563+
progress.close()
564+
565+
if expected_size is not None and expected_size != temp_file.tell():
566+
raise EnvironmentError(
567+
consistency_error_message.format(
568+
actual_size=temp_file.tell(),
569+
)
570+
)
564571

565572

566573
@validate_hf_hub_args

src/huggingface_hub/hf_file_system.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,24 @@
66
from dataclasses import dataclass, field
77
from datetime import datetime
88
from itertools import chain
9+
from pathlib import Path
910
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
1011
from urllib.parse import quote, unquote
1112

1213
import fsspec
14+
from fsspec.callbacks import _DEFAULT_CALLBACK, NoOpCallback, TqdmCallback
15+
from fsspec.utils import isfilelike
1316
from requests import Response
1417

1518
from ._commit_api import CommitOperationCopy, CommitOperationDelete
16-
from .constants import DEFAULT_REVISION, ENDPOINT, REPO_TYPE_MODEL, REPO_TYPES_MAPPING, REPO_TYPES_URL_PREFIXES
17-
from .file_download import hf_hub_url
19+
from .constants import (
20+
DEFAULT_REVISION,
21+
ENDPOINT,
22+
REPO_TYPE_MODEL,
23+
REPO_TYPES_MAPPING,
24+
REPO_TYPES_URL_PREFIXES,
25+
)
26+
from .file_download import hf_hub_url, http_get
1827
from .hf_api import HfApi, LastCommitInfo, RepoFile
1928
from .utils import (
2029
EntryNotFoundError,
@@ -591,6 +600,58 @@ def url(self, path: str) -> str:
591600
url = url.replace("/resolve/", "/tree/", 1)
592601
return url
593602

603+
def get_file(self, rpath, lpath, callback=_DEFAULT_CALLBACK, outfile=None, **kwargs) -> None:
604+
"""Copy single remote file to local."""
605+
revision = kwargs.get("revision")
606+
unhandled_kwargs = set(kwargs.keys()) - {"revision"}
607+
if not isinstance(callback, (NoOpCallback, TqdmCallback)) or len(unhandled_kwargs) > 0:
608+
# for now, let's not handle custom callbacks
609+
# and let's not handle custom kwargs
610+
return super().get_file(rpath, lpath, callback=callback, outfile=outfile, **kwargs)
611+
612+
# Taken from https://github.com/fsspec/filesystem_spec/blob/47b445ae4c284a82dd15e0287b1ffc410e8fc470/fsspec/spec.py#L883
613+
if isfilelike(lpath):
614+
outfile = lpath
615+
elif self.isdir(rpath):
616+
os.makedirs(lpath, exist_ok=True)
617+
return None
618+
619+
if isinstance(lpath, (str, Path)): # otherwise, let's assume it's a file-like object
620+
os.makedirs(os.path.dirname(lpath), exist_ok=True)
621+
622+
# Open file if not already open
623+
close_file = False
624+
if outfile is None:
625+
outfile = open(lpath, "wb")
626+
close_file = True
627+
initial_pos = outfile.tell()
628+
629+
# Custom implementation of `get_file` to use `http_get`.
630+
resolve_remote_path = self.resolve_path(rpath, revision=revision)
631+
expected_size = self.info(rpath, revision=revision)["size"]
632+
callback.set_size(expected_size)
633+
try:
634+
http_get(
635+
url=hf_hub_url(
636+
repo_id=resolve_remote_path.repo_id,
637+
revision=resolve_remote_path.revision,
638+
filename=resolve_remote_path.path_in_repo,
639+
repo_type=resolve_remote_path.repo_type,
640+
endpoint=self.endpoint,
641+
),
642+
temp_file=outfile,
643+
displayed_filename=rpath,
644+
expected_size=expected_size,
645+
resume_size=0,
646+
headers=self._api._build_hf_headers(),
647+
_tqdm_bar=callback.tqdm if isinstance(callback, TqdmCallback) else None,
648+
)
649+
outfile.seek(initial_pos)
650+
finally:
651+
# Close file only if we opened it ourselves
652+
if close_file:
653+
outfile.close()
654+
594655
@property
595656
def transaction(self):
596657
"""A context within which files are committed together upon exit
@@ -618,6 +679,7 @@ def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None,
618679
raise FileNotFoundError(
619680
f"{e}.\nMake sure the repository and revision exist before writing data."
620681
) from e
682+
raise
621683
super().__init__(fs, self.resolved_path.unresolve(), **kwargs)
622684
self.fs: HfFileSystem
623685

@@ -667,6 +729,18 @@ def _upload_chunk(self, final: bool = False) -> None:
667729
path=self.resolved_path.unresolve(),
668730
)
669731

732+
def read(self, length=-1):
733+
"""Read remote file.
734+
735+
If `length` is not provided or is -1, the entire file is downloaded and read. On POSIX systems and if
736+
`hf_transfer` is not enabled, the file is loaded in memory directly. Otherwise, the file is downloaded to a
737+
temporary file and read from there.
738+
"""
739+
if self.mode == "rb" and (length is None or length == -1) and self.loc == 0:
740+
with self.fs.open(self.path, "rb", block_size=0) as f: # block_size=0 enables fast streaming
741+
return f.read()
742+
return super().read(length)
743+
670744
def url(self) -> str:
671745
return self.fs.url(self.path)
672746

@@ -695,7 +769,7 @@ def __init__(
695769
raise FileNotFoundError(
696770
f"{e}.\nMake sure the repository and revision exist before writing data."
697771
) from e
698-
# avoid an unecessary .info() call to instantiate .details
772+
# avoid an unnecessary .info() call to instantiate .details
699773
self.details = {"name": self.resolved_path.unresolve(), "size": None}
700774
super().__init__(
701775
fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs

0 commit comments

Comments
 (0)