diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index f809b0e14a1a..1a2b5076bbb7 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -34,7 +34,7 @@ from zipfile import is_zipfile

 import torch
-from huggingface_hub import split_torch_state_dict_into_shards
+from huggingface_hub import DDUFEntry, get_file_explorer, split_torch_state_dict_into_shards
 from packaging import version
 from torch import Tensor, nn
 from torch.nn import CrossEntropyLoss, Identity
@@ -491,33 +491,33 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):

 def load_state_dict(
-    checkpoint_file: Union[str, os.PathLike],
+    checkpoint_file: Union[str, os.PathLike, DDUFEntry],
     is_quantized: bool = False,
     map_location: Optional[Union[str, torch.device]] = None,
     weights_only: bool = True,
-    dduf_entries=None,
 ):
     """
     Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
     """
-    if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
-        # Check format of the archive
-        if dduf_entries:
-            # TODO: Find a way to only open the metadata
-            with dduf_entries[checkpoint_file].as_mmap() as mm:
+    checkpoint = get_file_explorer(checkpoint_file)
+    if not checkpoint.is_file():
+        raise ValueError(
+            f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+            "model. Make sure you have saved the model properly."
+        )
+
+    if checkpoint.file_extension == "safetensors":
+        if is_safetensors_available():
+            with checkpoint.as_mmap() as mm:
                 return safetensors.torch.load(mm)
-        else:
-            with safe_open(checkpoint_file, framework="pt") as f:
-                metadata = f.metadata()
-            if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
-                raise OSError(
-                    f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
-                    "you save your model with the `save_pretrained` method."
-                )
-            return safe_load_file(checkpoint_file)
+        raise ValueError(
+            f"Cannot load the safetensors checkpoint at {checkpoint_file}: safetensors is not installed. "
+            "Install it with `pip install safetensors`."
+        )
+
+    if isinstance(checkpoint_file, DDUFEntry):
+        raise ValueError(
+            "Corrupted DDUF entry: the DDUF format only supports safetensors as the serialization format "
+            f"for model weights, got {checkpoint_file}."
+        )
+
     try:
-        if dduf_entries:
-            raise ValueError("DDUF format is not supported yet with torch format. Please use safetensors")
         if map_location is None:
             if (
                 (
@@ -3444,7 +3444,6 @@ def from_pretrained(
         adapter_name = kwargs.pop("adapter_name", "default")
         use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
         generation_config = kwargs.pop("generation_config", None)
-        dduf_entries = kwargs.pop("dduf_entries", None)
         gguf_file = kwargs.pop("gguf_file", None)
         # Cache path to the GGUF file

@@ -3484,14 +3483,9 @@ def from_pretrained(
             raise ValueError("accelerate is required when loading a GGUF file `pip install accelerate`.")

         if commit_hash is None:
-            if not isinstance(config, PretrainedConfig):
-                if dduf_entries:
-                    # files are in an archive, so I'm assuming the commit hash of the archive is enough.
-                    resolved_config_file = next(iter(dduf_entries.items()))[1].dduf_path
-
-                else:
-                    # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
-                    resolved_config_file = cached_file(
+            commit_hash = getattr(config, "_commit_hash", None)
+            if commit_hash is None:
+                resolved_file = pretrained_model_name_or_path if os.path.isfile(pretrained_model_name_or_path) else cached_file(
                     pretrained_model_name_or_path,
                     CONFIG_NAME,
                     cache_dir=cache_dir,
@@ -3506,33 +3500,28 @@ def from_pretrained(
                     _raise_exceptions_for_missing_entries=False,
                     _raise_exceptions_for_connection_errors=False,
                 )
-                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
-            else:
-                commit_hash = getattr(config, "_commit_hash", None)
+                commit_hash = extract_commit_hash(resolved_file, commit_hash)

+        file_explorer = get_file_explorer(pretrained_model_name_or_path)
         if is_peft_available():
             _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)

             if _adapter_model_path is None:
-                if dduf_entries:
-                    # TODO: use the global var from peft utils
-                    if os.path.join(pretrained_model_name_or_path,"adapter_config.json") in dduf_entries:
-                        _adapter_model_path = os.path.join(pretrained_model_name_or_path, "adapter_config.json")
-                else:
-                    _adapter_model_path = find_adapter_config_file(
-                        pretrained_model_name_or_path,
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        resume_download=resume_download,
-                        proxies=proxies,
-                        local_files_only=local_files_only,
-                        _commit_hash=commit_hash,
-                        **adapter_kwargs,
-                    )
-            if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
-                with open(_adapter_model_path, "r", encoding="utf-8") as f:
+                _adapter_model_path = find_adapter_config_file(
+                    file_explorer,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    _commit_hash=commit_hash,
+                    **adapter_kwargs,
+                )
+            if _adapter_model_path is not None:
+                _adapter_file_explorer = get_file_explorer(_adapter_model_path)
+                if _adapter_file_explorer.is_file():
                     _adapter_model_path = pretrained_model_name_or_path
-                    pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
+                    pretrained_model_name_or_path = json.loads(_adapter_file_explorer.read_text())["base_model_name_or_path"]
         else:
             _adapter_model_path = None

@@ -3615,7 +3604,6 @@ def from_pretrained(
                 subfolder=subfolder,
                 _from_auto=from_auto_class,
                 _from_pipeline=from_pipeline,
-                dduf_entries=dduf_entries,
                 **kwargs,
             )
         else:
@@ -3682,92 +3670,48 @@ def from_pretrained(
                     "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub."
                 )

         if pretrained_model_name_or_path is not None and gguf_file is None:
-            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-            # passing a dduf_entries means that we already knows where the file
-            is_local = os.path.isdir(pretrained_model_name_or_path) or dduf_entries
+            file_explorer = get_file_explorer(pretrained_model_name_or_path)
+            is_local = file_explorer.is_dir()
             if is_local:
-                if from_tf and (
-                    os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index"))
-                    or os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
-                    in dduf_entries
-                ):
+                if from_tf and file_explorer.navigate_to(subfolder, TF_WEIGHTS_NAME + ".index").is_file():
                     # Load from a TF 1.0 checkpoint in priority if from_tf
-                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
+                    archive_file = file_explorer.navigate_to(subfolder, TF_WEIGHTS_NAME + ".index")
                 elif from_tf and (
-                    os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME))
-                    or os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) in dduf_entries
+                    file_explorer.navigate_to(subfolder, TF2_WEIGHTS_NAME).is_file()
                 ):
                     # Load from a TF 2.0 checkpoint in priority if from_tf
-                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
+                    archive_file = file_explorer.navigate_to(subfolder, TF2_WEIGHTS_NAME)
                 elif from_flax and (
-                    os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME))
-                    or os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) in dduf_entries
+                    file_explorer.navigate_to(subfolder, FLAX_WEIGHTS_NAME).is_file()
                 ):
                     # Load from a Flax checkpoint in priority if from_flax
-                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+                    archive_file = file_explorer.navigate_to(subfolder, FLAX_WEIGHTS_NAME)
                 elif use_safetensors is not False and (
-                    os.path.isfile(
-                        os.path.join(
-                            pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
-                        )
-                    )
-                    or os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
-                    in dduf_entries
+                    file_explorer.navigate_to(subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)).is_file()
                 ):
                     # Load from a safetensors checkpoint
-                    archive_file = os.path.join(
-                        pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
-                    )
+                    archive_file = file_explorer.navigate_to(subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
                 elif use_safetensors is not False and (
-                    os.path.isfile(
-                        os.path.join(
-                            pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
-                        )
-                    )
-                    or os.path.join(
-                        pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
-                    )
-                    in dduf_entries
+                    file_explorer.navigate_to(subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)).is_file()
                 ):
                     # Load from a sharded safetensors checkpoint
-                    archive_file = os.path.join(
-                        pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
-                    )
+                    archive_file = file_explorer.navigate_to(subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
                     is_sharded = True
                 elif not use_safetensors and (
-                    os.path.isfile(
-                        os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
-                    )
-                    or os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
-                    in dduf_entries
+                    file_explorer.navigate_to(subfolder, _add_variant(WEIGHTS_NAME, variant)).is_file()
                 ):
                     # Load from a PyTorch checkpoint
-                    archive_file = os.path.join(
-                        pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
-                    )
+                    archive_file = file_explorer.navigate_to(subfolder, _add_variant(WEIGHTS_NAME, variant))
                 elif not use_safetensors and (
-                    os.path.isfile(
-                        os.path.join(
-                            pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
-                        )
-                    )
-                    or os.path.join(
-                        pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
-                    )
-                    in dduf_entries
+                    file_explorer.navigate_to(subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)).is_file()
                 ):
                     # Load from a sharded PyTorch checkpoint
-                    archive_file = os.path.join(
-                        pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
-                    )
+                    archive_file = file_explorer.navigate_to(subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
                     is_sharded = True
                 # At this stage we don't have a weight file so we will raise an error.
                 elif not use_safetensors and (
-                    os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index"))
-                    or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME))
-                    or os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
-                    in dduf_entries
-                    or os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) in dduf_entries
+                    file_explorer.navigate_to(subfolder, TF_WEIGHTS_NAME + ".index").is_file()
+                    or file_explorer.navigate_to(subfolder, TF2_WEIGHTS_NAME).is_file()
                 ):
                     raise EnvironmentError(
                         f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
@@ -3775,8 +3719,7 @@ def from_pretrained(
                         " `from_tf=True` to load this model from those weights."
                     )
                 elif not use_safetensors and (
-                    os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME))
-                    or os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) in dduf_entries
+                    file_explorer.navigate_to(subfolder, FLAX_WEIGHTS_NAME).is_file()
                 ):
                     raise EnvironmentError(
                         f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
@@ -3795,9 +3738,11 @@ def from_pretrained(
                     f" {pretrained_model_name_or_path}."
                 )
             elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
+                # TODO: what would it mean in a DDUF environment?
                 archive_file = pretrained_model_name_or_path
                 is_local = True
             elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
+                # TODO: what would it mean in a DDUF environment?
                 if not from_tf:
                     raise ValueError(
                         f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set "
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index e246bf3094c9..0e013de287ab 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -28,7 +28,6 @@
 from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
 from ...utils import (
     cached_file,
-    extract_commit_hash,
     is_g2p_en_available,
     is_sentencepiece_available,
     is_tokenizers_available,
@@ -706,10 +705,11 @@ def get_tokenizer_config(
     if resolved_config_file is None:
         logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
         return {}
-    commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+    # TODO: handle this correctly
+    # commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+    commit_hash = None

-    with open(resolved_config_file, encoding="utf-8") as reader:
-        result = json.load(reader)
+    result = json.loads(resolved_config_file.read_text())
     result["_commit_hash"] = commit_hash
     return result

diff --git a/src/transformers/models/clip/tokenization_clip.py b/src/transformers/models/clip/tokenization_clip.py
index 9d52fef27cc2..731b1b6ddae4 100644
--- a/src/transformers/models/clip/tokenization_clip.py
+++ b/src/transformers/models/clip/tokenization_clip.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for CLIP."""
-
 import json
 import os
 import unicodedata
@@ -21,6 +20,7 @@ from typing import List, Optional, Tuple

 import regex as re
+from huggingface_hub import get_file_explorer

 from ...tokenization_utils import AddedToken, PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
 from ...utils import logging
@@ -291,8 +291,6 @@ def __init__(
         pad_token="<|endoftext|>",  # hack to enable padding
         **kwargs,
     ):
-        dduf_entries = kwargs.get("dduf_entries", None)
-
         bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
         eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
         unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
@@ -304,20 +302,13 @@ def __init__(
             logger.info("ftfy or spacy is not installed using custom BasicTokenizer instead of ftfy.")
             self.nlp = BasicTokenizer(strip_accents=False, do_split_on_punc=False)
             self.fix_text = None
-        if dduf_entries:
-            self.encoder = json.loads(dduf_entries[vocab_file].read_text())
-        else:
-            with open(vocab_file, encoding="utf-8") as vocab_handle:
-                self.encoder = json.load(vocab_handle)
+
+        self.encoder = json.loads(get_file_explorer(vocab_file).read_text(encoding="utf-8"))
         self.decoder = {v: k for k, v in self.encoder.items()}
         self.errors = errors  # how to handle errors in decoding
         self.byte_encoder = bytes_to_unicode()
         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
-        if dduf_entries:
-            bpe_merges = dduf_entries[merges_file].read_text().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
-        else:
-            with open(merges_file, encoding="utf-8") as merges_handle:
-                bpe_merges = merges_handle.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
+        bpe_merges = get_file_explorer(merges_file).read_text().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
         bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
         self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
diff --git a/src/transformers/models/t5/tokenization_t5.py b/src/transformers/models/t5/tokenization_t5.py
index efd6947abce8..14d7d6ee8c0e 100644
--- a/src/transformers/models/t5/tokenization_t5.py
+++ b/src/transformers/models/t5/tokenization_t5.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

 import sentencepiece as spm
+from huggingface_hub import get_file_explorer

 from ...convert_slow_tokenizer import import_protobuf
 from ...tokenization_utils import PreTrainedTokenizer
@@ -137,7 +138,6 @@ def __init__(
         add_prefix_space=True,
         **kwargs,
     ) -> None:
-        dduf_entries = kwargs.get("dduf_entries", None)
         pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
         unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
         eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token
@@ -147,11 +147,9 @@ def __init__(
         self.vocab_file = vocab_file
         self._extra_ids = extra_ids
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
-        if dduf_entries:
-            with dduf_entries[self.vocab_file].as_mmap() as mm:
-                self.sp_model.load_from_serialized_proto(mm)
-        else:
-            self.sp_model.Load(vocab_file)
+
+        with get_file_explorer(self.vocab_file).as_mmap() as mm:
+            self.sp_model.load_from_serialized_proto(mm)

         if additional_special_tokens is not None:
             extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ ... @@ def from_pretrained(
             if len(cls.vocab_files_names) > 1 and not gguf_file:
                 raise ValueError(
                     f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
@@ -1949,60 +1950,48 @@ def from_pretrained(
             if "tokenizer_file" in vocab_files:
                 # Try to get the tokenizer config to see if there are versioned tokenizer files.
                 fast_tokenizer_file = FULL_TOKENIZER_FILE
-                if dduf_entries:
-                    tokenizer_config = json.loads(
-                        dduf_entries[
-                            os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
-                        ].read_text()
-                    )
+                resolved_config_file = cached_file(
+                    pretrained_model_name_or_path,
+                    TOKENIZER_CONFIG_FILE,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    token=token,
+                    revision=revision,
+                    local_files_only=local_files_only,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                    _raise_exceptions_for_gated_repo=False,
+                    _raise_exceptions_for_missing_entries=False,
+                    _raise_exceptions_for_connection_errors=False,
+                    _commit_hash=commit_hash,
+                )
+                # TODO: resolve that properly
+                # commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+                commit_hash = None
+                if resolved_config_file is not None:
+                    tokenizer_config = json.loads(resolved_config_file.read_text())
                     if "fast_tokenizer_files" in tokenizer_config:
-                        fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
-                else:
-                    resolved_config_file = cached_file(
-                        pretrained_model_name_or_path,
-                        TOKENIZER_CONFIG_FILE,
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        resume_download=resume_download,
-                        proxies=proxies,
-                        token=token,
-                        revision=revision,
-                        local_files_only=local_files_only,
-                        subfolder=subfolder,
-                        user_agent=user_agent,
-                        _raise_exceptions_for_gated_repo=False,
-                        _raise_exceptions_for_missing_entries=False,
-                        _raise_exceptions_for_connection_errors=False,
-                        _commit_hash=commit_hash,
-                    )
-                    commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
-                    if resolved_config_file is not None:
-                        with open(resolved_config_file, encoding="utf-8") as reader:
-                            tokenizer_config = json.load(reader)
-                        if "fast_tokenizer_files" in tokenizer_config:
-                            fast_tokenizer_file = get_fast_tokenizer_file(
-                                tokenizer_config["fast_tokenizer_files"]
-                            )
+                        fast_tokenizer_file = get_fast_tokenizer_file(
+                            tokenizer_config["fast_tokenizer_files"]
+                        )
                 vocab_files["tokenizer_file"] = fast_tokenizer_file

         # Get files from url, cache, or disk depending on the case
         resolved_vocab_files = {}
         unresolved_files = []
         for file_id, file_path in vocab_files.items():
-            if dduf_entries:
-                # We don't necessarily have the file
-                if os.path.join(pretrained_model_name_or_path, file_path) in dduf_entries:
-                    resolved_vocab_files[file_id] = os.path.join(pretrained_model_name_or_path, file_path)
+            if file_path is not None and file_explorer.navigate_to(file_path).is_file():
+                resolved_vocab_files[file_id] = file_explorer.navigate_to(file_path)
             elif file_path is None:
                 resolved_vocab_files[file_id] = None
             elif single_file_id == file_id:
-                if os.path.isfile(file_path):
-                    resolved_vocab_files[file_id] = file_path
-                elif is_remote_url(file_path):
+                if is_remote_url(file_path):
                     resolved_vocab_files[file_id] = download_url(file_path, proxies=proxies)
             else:
                 resolved_vocab_files[file_id] = cached_file(
-                    pretrained_model_name_or_path,
+                    file_explorer,
                     file_path,
                     cache_dir=cache_dir,
                     force_download=force_download,
@@ -2018,7 +2007,10 @@ def from_pretrained(
                     _raise_exceptions_for_connection_errors=False,
                     _commit_hash=commit_hash,
                 )
-                commit_hash = extract_commit_hash(resolved_vocab_files[file_id], commit_hash)
+
+                # TODO: do that properly
+                # commit_hash = extract_commit_hash(resolved_vocab_files[file_id], commit_hash)
+                commit_hash = None

         # if dduf_entries:
         #     # preload files that are going to be opened in the tokenizer so that we don't need to modify each tokenizer with dduf ?
         #     for file in cls.vocab_files_names:
@@ -2106,11 +2098,7 @@ def _from_pretrained(
         # Did we saved some inputs and kwargs to reload ?
         tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None)
         if tokenizer_config_file is not None:
-            if dduf_entries:
-                init_kwargs = json.loads(dduf_entries[tokenizer_config_file].read_text())
-            else:
-                with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
-                    init_kwargs = json.load(tokenizer_config_handle)
+            init_kwargs = json.loads(tokenizer_config_file.read_text())
             # First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
             config_tokenizer_class = init_kwargs.get("tokenizer_class")
             init_kwargs.pop("tokenizer_class", None)
@@ -2126,12 +2114,7 @@ def _from_pretrained(
         # If an independent chat template file exists, it takes priority over template entries in the tokenizer config
         chat_template_file = resolved_vocab_files.pop("chat_template_file", None)
         if chat_template_file is not None:
-            if dduf_entries:
-                # TODO: check that it works
-                init_kwargs["chat_template"] = dduf_entries[chat_template_file].read_text()
-            else:
-                with open(chat_template_file) as chat_template_handle:
-                    init_kwargs["chat_template"] = chat_template_handle.read()  # Clobbers any template in the config
+            init_kwargs["chat_template"] = chat_template_file.read_text()  # Clobbers any template in the config

         if not _is_local:
             if "auto_map" in init_kwargs:
diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py
index c6f7b9335c98..29c3d7ed0f9a 100644
--- a/src/transformers/utils/hub.py
+++ b/src/transformers/utils/hub.py
@@ -14,7 +14,6 @@
 """
 Hub utilities: utilities related to download and cache models
 """
-
 import json
 import os
 import re
@@ -34,12 +33,14 @@
 from huggingface_hub import (
     _CACHED_NO_EXIST,
     CommitOperationAdd,
+    FileExplorer,
     ModelCard,
     ModelCardData,
     constants,
     create_branch,
     create_commit,
     create_repo,
+    get_file_explorer,
     get_hf_file_metadata,
     hf_hub_download,
     hf_hub_url,
@@ -285,7 +286,7 @@ def cached_file(
     _raise_exceptions_for_connection_errors: bool = True,
     _commit_hash: Optional[str] = None,
     **deprecated_kwargs,
-) -> Optional[str]:
+) -> Optional[FileExplorer]:
     """
     Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
@@ -365,11 +366,10 @@ def cached_file(
     if subfolder is None:
         subfolder = ""

-    path_or_repo_id = str(path_or_repo_id)
     full_filename = os.path.join(subfolder, filename)
-    if os.path.isdir(path_or_repo_id):
-        resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
-        if not os.path.isfile(resolved_file):
+    file_explorer = get_file_explorer(path_or_repo_id)
+    if file_explorer.is_dir():
+        resolved_file = file_explorer.navigate_to(subfolder, filename)
+        if not resolved_file.is_file():
             if _raise_exceptions_for_missing_entries and filename not in ["config.json", f"{subfolder}/config.json"]:
                 raise EnvironmentError(
                     f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
@@ -469,7 +469,7 @@ def cached_file(
         raise EnvironmentError(
             f"Incorrect path_or_model_id: '{path_or_repo_id}'. Please provide either the path to a local folder or the repo_id of a model on the Hub."
         ) from e
-    return resolved_file
+    return get_file_explorer(resolved_file)

 # TODO: deprecate `get_file_from_repo` or document it differently?
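
A note for reviewers before the last file: `FileExplorer` and `get_file_explorer` do not exist in `huggingface_hub` at the time of this diff, so their exact API is an assumption. The sketch below is one hypothetical local-filesystem implementation that satisfies every call site used in this diff (`is_file`, `is_dir`, `file_extension`, `navigate_to`, `listdir`, `read_text`, `as_mmap`); a DDUF backend would expose the same surface over a zip archive. `get_file_explorer` is assumed to be idempotent, since call sites pass it both raw paths and existing explorers.

```python
from __future__ import annotations

import mmap
import os
from contextlib import contextmanager
from typing import Iterator, List


class LocalFileExplorer:
    """Hypothetical local-path backend; a DDUF backend would wrap a zip archive instead."""

    def __init__(self, path: str):
        self._path = str(path)

    def __str__(self) -> str:
        return self._path

    @property
    def file_extension(self) -> str:
        # "model.safetensors" -> "safetensors", "vocab.json" -> "json"
        return os.path.splitext(self._path)[1].lstrip(".")

    def is_file(self) -> bool:
        return os.path.isfile(self._path)

    def is_dir(self) -> bool:
        return os.path.isdir(self._path)

    def listdir(self) -> List[str]:
        return os.listdir(self._path)

    def navigate_to(self, *parts: str) -> "LocalFileExplorer":
        # Empty parts (e.g. subfolder == "") are collapsed by os.path.join
        return LocalFileExplorer(os.path.join(self._path, *parts))

    def read_text(self, encoding: str = "utf-8") -> str:
        with open(self._path, encoding=encoding) as f:
            return f.read()

    @contextmanager
    def as_mmap(self) -> Iterator[mmap.mmap]:
        # Read-only memory map so callers can deserialize without copying the file into RAM
        with open(self._path, "rb") as f:
            mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
            try:
                yield mm
            finally:
                mm.close()


def get_file_explorer(path) -> LocalFileExplorer:
    # Idempotent: passing an explorer back in returns it unchanged, which is why
    # call sites above can forward either a path or an explorer (e.g. cached_file(file_explorer, ...)).
    if isinstance(path, LocalFileExplorer):
        return path
    return LocalFileExplorer(path)
```
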
diff --git a/src/transformers/utils/peft_utils.py b/src/transformers/utils/peft_utils.py
index 7efa80e92347..c3e33aff5fa3 100644
--- a/src/transformers/utils/peft_utils.py
+++ b/src/transformers/utils/peft_utils.py
@@ -15,6 +15,7 @@
 import os
 from typing import Dict, Optional, Union

+from huggingface_hub import get_file_explorer
 from packaging import version

 from .hub import cached_file
@@ -77,15 +78,17 @@ def find_adapter_config_file(
             In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
             specify the folder name here.
     """
-    adapter_cached_filename = None
     if model_id is None:
         return None
-    elif os.path.isdir(model_id):
-        list_remote_files = os.listdir(model_id)
+    file_explorer = get_file_explorer(model_id)
+    if file_explorer.is_dir():
+        # 'model_id' is either a local directory or a DDUF file (e.g. a zip-based file)
+        list_remote_files = file_explorer.listdir()
         if ADAPTER_CONFIG_NAME in list_remote_files:
-            adapter_cached_filename = os.path.join(model_id, ADAPTER_CONFIG_NAME)
+            return file_explorer.navigate_to(ADAPTER_CONFIG_NAME)
     else:
-        adapter_cached_filename = cached_file(
+        # otherwise, assume it's a model ID on the Hub (`cached_file` already returns a `FileExplorer`)
+        return cached_file(
             model_id,
             ADAPTER_CONFIG_NAME,
             cache_dir=cache_dir,
@@ -100,9 +103,7 @@ def find_adapter_config_file(
             _raise_exceptions_for_gated_repo=False,
             _raise_exceptions_for_missing_entries=False,
             _raise_exceptions_for_connection_errors=False,
         )
-
-    return adapter_cached_filename


 def check_peft_version(min_version: str) -> None:
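
For context, the end state this refactor aims at: call sites stop branching on `dduf_entries` and go through a single explorer object regardless of backend. A minimal usage sketch mirroring the new `load_state_dict` fast path (path is hypothetical; assumes the interface sketched above and that `safetensors` is installed — `safetensors.torch.load` accepts the mmap'd bytes directly):

```python
import safetensors.torch

# Resolve the checkpoint through the explorer abstraction; the same lines would
# work if `checkpoint` pointed at an entry inside a DDUF archive instead of a local folder.
checkpoint = get_file_explorer("my_model_dir/model.safetensors")  # hypothetical path

if not checkpoint.is_file():
    raise ValueError(f"Unable to locate the file {checkpoint}.")

if checkpoint.file_extension == "safetensors":
    # Zero-copy read: deserialize straight from the memory-mapped file
    with checkpoint.as_mmap() as mm:
        state_dict = safetensors.torch.load(mm)
```
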