188 changes: 124 additions & 64 deletions src/huggingface_hub/_commit_api.py
@@ -33,6 +33,7 @@
validate_hf_hub_args,
)
from .utils import tqdm as hf_tqdm
from .utils._runtime import is_xet_available


if TYPE_CHECKING:
@@ -353,7 +354,7 @@ def _warn_on_overwriting_operations(operations: List[CommitOperation]) -> None:


@validate_hf_hub_args
def _upload_lfs_files(
def _upload_files(
*,
additions: List[CommitOperationAdd],
repo_type: str,
@@ -362,6 +363,86 @@
endpoint: Optional[str] = None,
num_threads: int = 5,
revision: Optional[str] = None,
create_pr: Optional[bool] = None,
):
"""
Negotiates per-file transfer (LFS vs Xet) and uploads in batches.
"""
xet_additions: List[CommitOperationAdd] = []
lfs_actions: List[Dict] = []
lfs_oid2addop: Dict[str, CommitOperationAdd] = {}

for chunk in chunk_iterable(additions, chunk_size=UPLOAD_BATCH_MAX_NUM_FILES):
chunk_list = [op for op in chunk]

transfers: List[str] = ["basic", "multipart"]
has_buffered_io_data = any(isinstance(op.path_or_fileobj, io.BufferedIOBase) for op in chunk_list)
if is_xet_available():
if not has_buffered_io_data:
transfers.append("xet")
else:
logger.warning(
"Uploading files as a binary IO buffer is not supported by Xet Storage. "
"Falling back to HTTP upload."
)

actions_chunk, errors_chunk, chosen_transfer = post_lfs_batch_info(
upload_infos=[op.upload_info for op in chunk_list],
repo_id=repo_id,
repo_type=repo_type,
revision=revision,
endpoint=endpoint,
headers=headers,
token=None,
transfers=transfers,
)
if errors_chunk:
message = "\n".join(
[
f"Encountered error for file with OID {err.get('oid')}: `{err.get('error', {}).get('message')}"
for err in errors_chunk
]
)
raise ValueError(f"LFS batch API returned errors:\n{message}")

# Use Xet only when the server chose it and the client actually offered it. If the
# server returns a transfer we did not offer (e.g. "xet" while uploading from a
# BytesIO), fall back to LFS for this chunk.
if chosen_transfer == "xet" and ("xet" in transfers):
xet_additions.extend(chunk_list)
else:
lfs_actions.extend(actions_chunk)
for op in chunk_list:
lfs_oid2addop[op.upload_info.sha256.hex()] = op

if len(lfs_actions) > 0:
_upload_lfs_files(
actions=lfs_actions,
oid2addop=lfs_oid2addop,
headers=headers,
endpoint=endpoint,
num_threads=num_threads,
)

if len(xet_additions) > 0:
_upload_xet_files(
additions=xet_additions,
repo_type=repo_type,
repo_id=repo_id,
headers=headers,
endpoint=endpoint,
revision=revision,
create_pr=create_pr,
)
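
To make the chunk-level negotiation easier to follow, here is a self-contained sketch of the same routing logic. `negotiate` is a hypothetical stand-in for `post_lfs_batch_info` and is not part of the library:

```python
from typing import Callable, Dict, List, Tuple

def split_by_transfer(
    chunks: List[List[str]],
    offer_xet: bool,
    negotiate: Callable[[List[str], List[str]], Tuple[List[Dict], str]],
) -> Tuple[List[str], List[Dict]]:
    """Partition uploads into a Xet set and a list of LFS batch actions."""
    xet_items: List[str] = []
    lfs_actions: List[Dict] = []
    for chunk in chunks:
        # Offer "xet" only when eligible (library installed, no raw IO buffers).
        transfers = ["basic", "multipart"] + (["xet"] if offer_xet else [])
        actions, chosen = negotiate(chunk, transfers)
        # Route to Xet only when the server chose it AND the client offered it;
        # everything else goes through the classic LFS actions.
        if chosen == "xet" and "xet" in transfers:
            xet_items.extend(chunk)
        else:
            lfs_actions.extend(actions)
    return xet_items, lfs_actions

# Example: a fake server that always picks "xet" when offered.
fake = lambda chunk, transfers: ([{"oid": f} for f in chunk], "xet" if "xet" in transfers else "basic")
print(split_by_transfer([["a.bin", "b.bin"]], offer_xet=True, negotiate=fake))
# -> (['a.bin', 'b.bin'], [])
```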


@validate_hf_hub_args
def _upload_lfs_files(
*,
actions: List[Dict],
oid2addop: Dict[str, CommitOperationAdd],
headers: Dict[str, str],
endpoint: Optional[str] = None,
num_threads: int = 5,
):
"""
Uploads files to the Hub using the large file storage protocol, following the upload instructions (`actions`) previously returned by the LFS batch endpoint.
@@ -370,9 +451,21 @@
- LFS Batch API: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md

Args:
actions (`List` of `Dict`):
The upload instructions returned by the LFS batch endpoint for the files to upload.
oid2addop (`Dict[str, CommitOperationAdd]`):
A dictionary mapping the OID of the file to the corresponding `CommitOperationAdd` object.
headers (`Dict[str, str]`):
Headers to use for the request, including authorization headers and user agent.
endpoint (`str`, *optional*):
The endpoint to use for the request. Defaults to `constants.ENDPOINT`.
num_threads (`int`, *optional*):
The number of concurrent threads to use when uploading. Defaults to 5.

Raises:
[`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
If an upload failed for any reason
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
If the server returns malformed responses
@@ -392,50 +485,17 @@
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
If the LFS batch endpoint returned an HTTP error.
"""
# Step 1: retrieve upload instructions from the LFS batch endpoint.
# Upload instructions are retrieved by chunk of 256 files to avoid reaching
# the payload limit.
batch_actions: List[Dict] = []
for chunk in chunk_iterable(additions, chunk_size=UPLOAD_BATCH_MAX_NUM_FILES):
batch_actions_chunk, batch_errors_chunk = post_lfs_batch_info(
upload_infos=[op.upload_info for op in chunk],
repo_id=repo_id,
repo_type=repo_type,
revision=revision,
endpoint=endpoint,
headers=headers,
token=None, # already passed in 'headers'
)

# If at least 1 error, we do not retrieve information for other chunks
if batch_errors_chunk:
message = "\n".join(
[
f"Encountered error for file with OID {err.get('oid')}: `{err.get('error', {}).get('message')}"
for err in batch_errors_chunk
]
)
raise ValueError(f"LFS batch endpoint returned errors:\n{message}")

batch_actions += batch_actions_chunk
oid2addop = {add_op.upload_info.sha256.hex(): add_op for add_op in additions}

# Step 2: ignore files that have already been uploaded
# Filter out files already present upstream
filtered_actions = []
for action in batch_actions:
for action in actions:
if action.get("actions") is None:
logger.debug(
f"Content of file {oid2addop[action['oid']].path_in_repo} is already"
" present upstream - skipping upload."
f"Content of file {oid2addop[action['oid']].path_in_repo} is already present upstream - skipping upload."
)
else:
filtered_actions.append(action)

if len(filtered_actions) == 0:
logger.debug("No LFS files to upload.")
return

# Step 3: upload files concurrently according to these instructions
# Upload according to server-provided actions
def _wrapped_lfs_upload(batch_action) -> None:
try:
operation = oid2addop[batch_action["oid"]]
@@ -576,30 +636,30 @@ def token_refresher() -> Tuple[str, int]:
progress, progress_callback = None, None

try:
for i, chunk in enumerate(chunk_iterable(additions, chunk_size=UPLOAD_BATCH_MAX_NUM_FILES)):
_chunk = [op for op in chunk]

bytes_ops = [op for op in _chunk if isinstance(op.path_or_fileobj, bytes)]
paths_ops = [op for op in _chunk if isinstance(op.path_or_fileobj, (str, Path))]

if len(paths_ops) > 0:
upload_files(
[str(op.path_or_fileobj) for op in paths_ops],
xet_endpoint,
access_token_info,
token_refresher,
progress_callback,
repo_type,
)
if len(bytes_ops) > 0:
upload_bytes(
[op.path_or_fileobj for op in bytes_ops],
xet_endpoint,
access_token_info,
token_refresher,
progress_callback,
repo_type,
)
all_bytes_ops = [op for op in additions if isinstance(op.path_or_fileobj, bytes)]
all_paths_ops = [op for op in additions if isinstance(op.path_or_fileobj, (str, Path))]

if len(all_paths_ops) > 0:
all_paths = [str(op.path_or_fileobj) for op in all_paths_ops]
upload_files(
all_paths,
xet_endpoint,
access_token_info,
token_refresher,
progress_callback,
repo_type,
)

if len(all_bytes_ops) > 0:
all_bytes = [op.path_or_fileobj for op in all_bytes_ops]
upload_bytes(
all_bytes,
xet_endpoint,
access_token_info,
token_refresher,
progress_callback,
repo_type,
)

finally:
if progress is not None:
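The Xet path no longer uploads per 256-file chunk: all additions are partitioned once by payload type and handed to `hf_xet` in at most two calls (`upload_files` for paths, `upload_bytes` for in-memory buffers), presumably so that progress reporting and deduplication can span the whole set. A minimal sketch of that partition, with placeholder file names:

```python
from pathlib import Path
from typing import List, Tuple, Union

Payload = Union[str, Path, bytes]

def split_payloads(payloads: List[Payload]) -> Tuple[List[str], List[bytes]]:
    # Same one-pass partition as the new code: path-like payloads go to
    # upload_files(), in-memory bytes go to upload_bytes().
    paths = [str(p) for p in payloads if isinstance(p, (str, Path))]
    blobs = [p for p in payloads if isinstance(p, bytes)]
    return paths, blobs

paths, blobs = split_payloads(["weights.safetensors", Path("config.json"), b"\x00\x01"])
print(paths, len(blobs))  # ['weights.safetensors', 'config.json'] 1
```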
44 changes: 8 additions & 36 deletions src/huggingface_hub/hf_api.py
@@ -15,7 +15,6 @@
from __future__ import annotations

import inspect
import io
import json
import re
import struct
@@ -46,7 +45,7 @@
Union,
overload,
)
from urllib.parse import quote, unquote
from urllib.parse import quote

import requests
from requests.exceptions import HTTPError
@@ -62,8 +61,7 @@
_fetch_files_to_copy,
_fetch_upload_modes,
_prepare_commit_payload,
_upload_lfs_files,
_upload_xet_files,
_upload_files,
_warn_on_overwriting_operations,
)
from ._inference_endpoints import InferenceEndpoint, InferenceEndpointType
@@ -132,13 +130,8 @@
validate_hf_hub_args,
)
from .utils import tqdm as hf_tqdm
from .utils._auth import (
_get_token_from_environment,
_get_token_from_file,
_get_token_from_google_colab,
)
from .utils._auth import _get_token_from_environment, _get_token_from_file, _get_token_from_google_colab
from .utils._deprecation import _deprecate_arguments, _deprecate_method
from .utils._runtime import is_xet_available
from .utils._typing import CallableT
from .utils.endpoint_helpers import _is_emission_within_threshold

@@ -4502,6 +4495,10 @@ def preupload_lfs_files(
f"Skipped upload for {len(new_lfs_additions) - len(new_lfs_additions_to_upload)} LFS file(s) "
"(ignored by gitignore file)."
)
# If no LFS files remain to upload, keep previous behavior and log explicitly
if len(new_lfs_additions_to_upload) == 0:
logger.debug("No LFS files to upload.")
return
# Prepare upload parameters
upload_kwargs = {
"additions": new_lfs_additions_to_upload,
@@ -4514,32 +4511,7 @@
# PR (i.e. `revision`).
"revision": revision if not create_pr else None,
}
# Upload files using Xet protocol if all of the following are true:
# - xet is enabled for the repo,
# - the files are provided as str or paths objects,
# - the library is installed.
# Otherwise, default back to LFS.
xet_enabled = self.repo_info(
> **Contributor:** all of this to remove this call 😄
repo_id=repo_id,
repo_type=repo_type,
revision=unquote(revision) if revision is not None else revision,
expand="xetEnabled",
token=token,
).xet_enabled
has_buffered_io_data = any(
isinstance(addition.path_or_fileobj, io.BufferedIOBase) for addition in new_lfs_additions_to_upload
)
if xet_enabled and not has_buffered_io_data and is_xet_available():
logger.debug("Uploading files using Xet Storage..")
_upload_xet_files(**upload_kwargs, create_pr=create_pr) # type: ignore [arg-type]
else:
if xet_enabled and is_xet_available():
if has_buffered_io_data:
logger.warning(
"Uploading files as a binary IO buffer is not supported by Xet Storage. "
"Falling back to HTTP upload."
)
_upload_lfs_files(**upload_kwargs, num_threads=num_threads) # type: ignore [arg-type]
_upload_files(**upload_kwargs, num_threads=num_threads, create_pr=create_pr) # type: ignore [arg-type]
for addition in new_lfs_additions_to_upload:
addition._is_uploaded = True
if free_memory:
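For callers the public API is unchanged: `preupload_lfs_files` keeps its signature and now routes through `_upload_files`, which negotiates the transfer internally. A usage sketch under that assumption; the repo id and file names are placeholders:

```python
from huggingface_hub import CommitOperationAdd, HfApi

api = HfApi()
op = CommitOperationAdd(
    path_in_repo="model.safetensors",
    path_or_fileobj="./model.safetensors",  # str/Path payloads stay Xet-eligible
)

# Transfer negotiation (LFS vs. Xet) now happens inside this call.
api.preupload_lfs_files(repo_id="my-user/my-model", additions=[op])

# The commit itself then only references the already-uploaded content.
api.create_commit(
    repo_id="my-user/my-model",
    operations=[op],
    commit_message="Add model weights",
)
```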
15 changes: 11 additions & 4 deletions src/huggingface_hub/lfs.py
@@ -108,7 +108,8 @@ def post_lfs_batch_info(
revision: Optional[str] = None,
endpoint: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
) -> Tuple[List[dict], List[dict]]:
transfers: Optional[List[str]] = None,
) -> Tuple[List[dict], List[dict], Optional[str]]:
"""
Requests the LFS batch endpoint to retrieve upload instructions

@@ -129,9 +130,10 @@
Additional headers to include in the request
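transfers (`List[str]`, *optional*):
The transfer adapters to offer to the server. Defaults to `["basic", "multipart"]` when not provided.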

Returns:
`LfsBatchInfo`: 2-tuple:
`LfsBatchInfo`: 3-tuple:
- First element is the list of upload instructions from the server
- Second element is an list of errors, if any
- Second element is a list of errors, if any
- Third element is the chosen transfer adapter if provided by the server (e.g. "basic", "multipart", "xet")

Raises:
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
@@ -146,7 +148,7 @@
batch_url = f"{endpoint}/{url_prefix}{repo_id}.git/info/lfs/objects/batch"
payload: Dict = {
"operation": "upload",
"transfers": ["basic", "multipart"],
"transfers": transfers if transfers is not None else ["basic", "multipart"],
"objects": [
{
"oid": upload.sha256.hex(),
@@ -172,9 +174,14 @@
if not isinstance(objects, list):
raise ValueError("Malformed response from server")

chosen_transfer = batch_info.get("transfer")
if chosen_transfer is not None and not isinstance(chosen_transfer, str):
chosen_transfer = None

return (
[_validate_batch_actions(obj) for obj in objects if "error" not in obj],
[_validate_batch_error(obj) for obj in objects if "error" in obj],
chosen_transfer,
)
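
For reference, the request this function builds and the response it parses have roughly the following shapes per the git-lfs batch API spec linked above (all values illustrative):

```python
# Request body sent to {endpoint}/{repo}.git/info/lfs/objects/batch:
request = {
    "operation": "upload",
    "transfers": ["basic", "multipart", "xet"],  # adapters the client offers
    "objects": [{"oid": "abc123...", "size": 1048576}],
}

# Response: the server picks ONE adapter for the whole batch.
response = {
    "transfer": "xet",  # surfaced to callers as `chosen_transfer`
    "objects": [
        {"oid": "abc123...", "size": 1048576, "actions": {"upload": {"href": "https://..."}}},
    ],
}

# Mirrors the 3-tuple returned above.
actions = [o for o in response["objects"] if "error" not in o]
errors = [o for o in response["objects"] if "error" in o]
chosen_transfer = response.get("transfer")  # -> "xet"
```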

