diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index a04b7bd6aa1b..2a8b283a9cf3 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -621,6 +621,7 @@ def _get_config_dict(
         commit_hash = kwargs.pop("_commit_hash", None)
         gguf_file = kwargs.get("gguf_file", None)
+        dduf_entries = kwargs.pop("dduf_entries", None)
 
         if trust_remote_code is True:
             logger.warning(
@@ -634,7 +635,7 @@ def _get_config_dict(
 
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
 
-        is_local = os.path.isdir(pretrained_model_name_or_path)
+        is_local = os.path.isdir(pretrained_model_name_or_path) or dduf_entries is not None
         if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
             # Special case when pretrained_model_name_or_path is a local file
             resolved_config_file = pretrained_model_name_or_path
@@ -644,26 +645,31 @@ def _get_config_dict(
             resolved_config_file = download_url(pretrained_model_name_or_path)
         else:
             configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file
 
             try:
                 # Load from local folder or from cache or download from model Hub and cache
-                resolved_config_file = cached_file(
-                    pretrained_model_name_or_path,
-                    configuration_file,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
-                    local_files_only=local_files_only,
-                    token=token,
-                    user_agent=user_agent,
-                    revision=revision,
-                    subfolder=subfolder,
-                    _commit_hash=commit_hash,
-                )
-                if resolved_config_file is None:
-                    return None, kwargs
-                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+                if dduf_entries:
+                    resolved_config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
+                    commit_hash = extract_commit_hash(dduf_entries[resolved_config_file].dduf_path, commit_hash)
+                else:
+                    resolved_config_file = cached_file(
+                        pretrained_model_name_or_path,
+                        configuration_file,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        proxies=proxies,
+                        resume_download=resume_download,
+                        local_files_only=local_files_only,
+                        token=token,
+                        user_agent=user_agent,
+                        revision=revision,
+                        subfolder=subfolder,
+                        _commit_hash=commit_hash,
+                    )
+                    if resolved_config_file is None:
+                        return None, kwargs
+                    commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
             except EnvironmentError:
                 # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
                 # the original exception.
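For context (not part of the patch), a minimal sketch of the `dduf_entries` mapping these branches index, assuming huggingface_hub's DDUF helpers; whether the keys carry the `pretrained_model_name_or_path` prefix is an assumption here:

```python
# Hypothetical illustration only: build the mapping that _get_config_dict
# later indexes with os.path.join(pretrained_model_name_or_path, "config.json").
from huggingface_hub import read_dduf_file

dduf_entries = read_dduf_file("model.dduf")  # Dict[str, DDUFEntry]
for name, entry in dduf_entries.items():
    # Each entry records its containing archive (entry.dduf_path) and can be
    # read without extraction via entry.read_text() or entry.as_mmap().
    print(name, entry.dduf_path)
```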
@@ -682,7 +688,10 @@ def _get_config_dict(
             config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"]
         else:
             # Load config dict
-            config_dict = cls._dict_from_json_file(resolved_config_file)
+            if dduf_entries:
+                config_dict = json.loads(dduf_entries[resolved_config_file].read_text())
+            else:
+                config_dict = cls._dict_from_json_file(resolved_config_file)
 
             config_dict["_commit_hash"] = commit_hash
     except (json.JSONDecodeError, UnicodeDecodeError):
diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py
index f37f589d5d53..4592f3c9dca2 100644
--- a/src/transformers/convert_slow_tokenizer.py
+++ b/src/transformers/convert_slow_tokenizer.py
@@ -544,17 +544,22 @@ class SpmConverter(Converter):
     SpmExtractor = SentencePieceExtractor
     special_tokens = {}
 
-    def __init__(self, *args):
+    def __init__(self, *args, **kwargs):
         requires_backends(self, "protobuf")
 
         super().__init__(*args)
+        dduf_entries = kwargs.get("dduf_entries", None)
 
         # from .utils import sentencepiece_model_pb2 as model_pb2
         model_pb2 = import_protobuf()
         m = model_pb2.ModelProto()
-        with open(self.original_tokenizer.vocab_file, "rb") as f:
-            m.ParseFromString(f.read())
+        if dduf_entries:
+            with dduf_entries[self.original_tokenizer.vocab_file].as_mmap() as mm:
+                m.ParseFromString(mm)
+        else:
+            with open(self.original_tokenizer.vocab_file, "rb") as f:
+                m.ParseFromString(f.read())
         self.proto = m
 
         if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback:
@@ -1606,7 +1611,7 @@ def converted(self) -> Tokenizer:
 }
 
 
-def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer:
+def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False, **kwargs) -> Tokenizer:
     """
     Utilities to convert a slow tokenizer instance in a fast tokenizer instance.
@@ -1625,7 +1630,7 @@ def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokeni
     tokenizer_class_name = transformer_tokenizer.__class__.__name__
 
     if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS and not from_tiktoken:
         converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
-        return converter_class(transformer_tokenizer).converted()
+        return converter_class(transformer_tokenizer, **kwargs).converted()
     else:
         try:
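The `as_mmap()` branch above works because protobuf's `ParseFromString` accepts any bytes-like object. A standalone sketch of the same pattern against a plain file, with `spiece.model` as a stand-in path:

```python
# Standalone sketch of the SpmConverter change: ParseFromString accepts a
# bytes-like object, so an mmap'ed archive region works like f.read().
import mmap

from transformers.convert_slow_tokenizer import import_protobuf

model_pb2 = import_protobuf()
m = model_pb2.ModelProto()
with open("spiece.model", "rb") as f:  # stand-in for dduf_entries[...].as_mmap()
    with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        m.ParseFromString(mm)
print(len(m.pieces))
```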
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index f349847b1fd7..9d66b368c046 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -131,6 +131,7 @@
     from accelerate.utils.modeling import get_state_dict_from_offload
 
 if is_safetensors_available():
+    import safetensors
     from safetensors import safe_open
     from safetensors.torch import load_file as safe_load_file
     from safetensors.torch import save_file as safe_save_file
@@ -495,21 +496,29 @@ def load_state_dict(
     is_quantized: bool = False,
     map_location: Optional[Union[str, torch.device]] = None,
     weights_only: bool = True,
+    dduf_entries=None,
 ):
     """
     Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
     """
     if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
         # Check format of the archive
-        with safe_open(checkpoint_file, framework="pt") as f:
-            metadata = f.metadata()
-        if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
-            raise OSError(
-                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
-                "you save your model with the `save_pretrained` method."
-            )
-        return safe_load_file(checkpoint_file)
+        if dduf_entries:
+            # TODO: Find a way to only read the metadata instead of deserializing the whole file
+            with dduf_entries[checkpoint_file].as_mmap() as mm:
+                return safetensors.torch.load(mm)
+        else:
+            with safe_open(checkpoint_file, framework="pt") as f:
+                metadata = f.metadata()
+            if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
+                raise OSError(
+                    f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
+                    "you save your model with the `save_pretrained` method."
+                )
+            return safe_load_file(checkpoint_file)
+
+    if dduf_entries:
+        # Raise before the generic torch.load error handling below can swallow this
+        raise ValueError("The DDUF format is not yet supported with PyTorch pickle checkpoints. Please use safetensors.")
 
     try:
         if map_location is None:
             if (
                 (
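`safetensors.torch.load` deserializes from a bytes-like buffer, which is what lets the DDUF branch above feed it an mmap of the archive region directly. A minimal self-contained sketch of that in-memory path:

```python
# Minimal sketch of the in-memory path taken by the DDUF branch above.
import safetensors.torch
import torch

buf = safetensors.torch.save({"w": torch.zeros(2, 2)})  # bytes, as if read from an archive
state_dict = safetensors.torch.load(buf)                # same call as load(mm)
print(state_dict["w"].shape)
```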
@@ -3410,6 +3419,7 @@ def from_pretrained(
         adapter_name = kwargs.pop("adapter_name", "default")
         use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
         generation_config = kwargs.pop("generation_config", None)
+        dduf_entries = kwargs.pop("dduf_entries", None)
         gguf_file = kwargs.pop("gguf_file", None)
         # Cache path to the GGUF file
 
@@ -3450,22 +3460,27 @@ def from_pretrained(
         if commit_hash is None:
             if not isinstance(config, PretrainedConfig):
-                # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
-                resolved_config_file = cached_file(
-                    pretrained_model_name_or_path,
-                    CONFIG_NAME,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    resume_download=resume_download,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    token=token,
-                    revision=revision,
-                    subfolder=subfolder,
-                    _raise_exceptions_for_gated_repo=False,
-                    _raise_exceptions_for_missing_entries=False,
-                    _raise_exceptions_for_connection_errors=False,
-                )
+                if dduf_entries:
+                    # Files are inside a single archive, so we assume the commit hash of the archive is enough
+                    resolved_config_file = next(iter(dduf_entries.values())).dduf_path
+                else:
+                    # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
+                    resolved_config_file = cached_file(
+                        pretrained_model_name_or_path,
+                        CONFIG_NAME,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        token=token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        _raise_exceptions_for_gated_repo=False,
+                        _raise_exceptions_for_missing_entries=False,
+                        _raise_exceptions_for_connection_errors=False,
+                    )
                 commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
             else:
                 commit_hash = getattr(config, "_commit_hash", None)
@@ -3474,16 +3489,21 @@ def from_pretrained(
             _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)
 
             if _adapter_model_path is None:
-                _adapter_model_path = find_adapter_config_file(
-                    pretrained_model_name_or_path,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    resume_download=resume_download,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    _commit_hash=commit_hash,
-                    **adapter_kwargs,
-                )
+                if dduf_entries:
+                    # TODO: use the global var from peft utils instead of hardcoding "adapter_config.json"
+                    if os.path.join(pretrained_model_name_or_path, "adapter_config.json") in dduf_entries:
+                        _adapter_model_path = os.path.join(pretrained_model_name_or_path, "adapter_config.json")
+                else:
+                    _adapter_model_path = find_adapter_config_file(
+                        pretrained_model_name_or_path,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        _commit_hash=commit_hash,
+                        **adapter_kwargs,
+                    )
             if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
                 with open(_adapter_model_path, "r", encoding="utf-8") as f:
                     _adapter_model_path = pretrained_model_name_or_path
@@ -3554,7 +3574,6 @@ def from_pretrained(
         if is_offline_mode() and not local_files_only:
             logger.info("Offline mode: forcing local_files_only=True")
             local_files_only = True
-
         # Load config if we don't provide a configuration
         if not isinstance(config, PretrainedConfig):
             config_path = config if config is not None else pretrained_model_name_or_path
@@ -3571,6 +3590,7 @@ def from_pretrained(
                 subfolder=subfolder,
                 _from_auto=from_auto_class,
                 _from_pipeline=from_pipeline,
+                dduf_entries=dduf_entries,
                 **kwargs,
             )
         else:
@@ -3636,36 +3656,69 @@ def from_pretrained(
                 raise ValueError(
                     "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub."
                 )
-
         if pretrained_model_name_or_path is not None and gguf_file is None:
             pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-            is_local = os.path.isdir(pretrained_model_name_or_path)
+            # Passing dduf_entries means we already know where the files are
+            is_local = os.path.isdir(pretrained_model_name_or_path) or dduf_entries is not None
             if is_local:
-                if from_tf and os.path.isfile(
-                    os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
+                if from_tf and (
+                    os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index"))
+                    or (
+                        dduf_entries
+                        and os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
+                        in dduf_entries
+                    )
                 ):
                     # Load from a TF 1.0 checkpoint in priority if from_tf
                     archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
-                elif from_tf and os.path.isfile(
-                    os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
+                elif from_tf and (
+                    os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME))
+                    or (
+                        dduf_entries
+                        and os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) in dduf_entries
+                    )
                 ):
                     # Load from a TF 2.0 checkpoint in priority if from_tf
                     archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
-                elif from_flax and os.path.isfile(
-                    os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+                elif from_flax and (
+                    os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME))
+                    or (
+                        dduf_entries
+                        and os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) in dduf_entries
+                    )
                 ):
                     # Load from a Flax checkpoint in priority if from_flax
                     archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
-                elif use_safetensors is not False and os.path.isfile(
-                    os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
+                elif use_safetensors is not False and (
+                    os.path.isfile(
+                        os.path.join(
+                            pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
+                        )
+                    )
+                    or (
+                        dduf_entries
+                        and os.path.join(
+                            pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
+                        )
+                        in dduf_entries
+                    )
                 ):
                     # Load from a safetensors checkpoint
                     archive_file = os.path.join(
                         pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
                     )
-                elif use_safetensors is not False and os.path.isfile(
-                    os.path.join(
-                        pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
+                elif use_safetensors is not False and (
+                    os.path.isfile(
+                        os.path.join(
+                            pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
+                        )
+                    )
+                    or (
+                        dduf_entries
+                        and os.path.join(
+                            pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
+                        )
+                        in dduf_entries
                     )
                 ):
                     # Load from a sharded safetensors checkpoint
@@ -3673,15 +3726,33 @@ def from_pretrained(
                     archive_file = os.path.join(
                         pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
                     )
                     is_sharded = True
-                elif not use_safetensors and os.path.isfile(
-                    os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
+                elif not use_safetensors and (
+                    os.path.isfile(
+                        os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
+                    )
+                    or (
+                        dduf_entries
+                        and os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
+                        in dduf_entries
+                    )
                 ):
                     # Load from a PyTorch checkpoint
                     archive_file = os.path.join(
                         pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
                     )
-                elif not use_safetensors and os.path.isfile(
-                    os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
+                elif not use_safetensors and (
+                    os.path.isfile(
+                        os.path.join(
+                            pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+                        )
+                    )
+                    or (
+                        dduf_entries
+                        and os.path.join(
+                            pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+                        )
+                        in dduf_entries
+                    )
                 ):
                     # Load from a sharded PyTorch checkpoint
                     archive_file = os.path.join(
@@ -3692,14 +3763,26 @@ def from_pretrained(
                 elif not use_safetensors and (
                     os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index"))
                     or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME))
+                    or (
+                        dduf_entries
+                        and (
+                            os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
+                            in dduf_entries
+                            or os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) in dduf_entries
+                        )
+                    )
                 ):
                     raise EnvironmentError(
                         f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
                         f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
                         " `from_tf=True` to load this model from those weights."
                     )
-                elif not use_safetensors and os.path.isfile(
-                    os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+                elif not use_safetensors and (
+                    os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME))
+                    or (
+                        dduf_entries
+                        and os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) in dduf_entries
+                    )
                 ):
                     raise EnvironmentError(
                         f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
@@ -3936,6 +4019,7 @@ def from_pretrained(
                     revision=revision,
                     subfolder=subfolder,
                     _commit_hash=commit_hash,
+                    dduf_entries=dduf_entries,
                 )
 
             if (
@@ -3943,8 +4027,14 @@ def from_pretrained(
                 and isinstance(resolved_archive_file, str)
                 and resolved_archive_file.endswith(".safetensors")
            ):
-                with safe_open(resolved_archive_file, framework="pt") as f:
-                    metadata = f.metadata()
+                if dduf_entries:
+                    # TODO: Find a way to better deal with this. We shouldn't have to read the entire file
+                    metadata = {"format": "pt"}
+                    # with dduf_entries[resolved_archive_file].as_mmap() as mm:
+                    #     metadata = safetensors.torch.load(mm).metadata()
+                else:
+                    with safe_open(resolved_archive_file, framework="pt") as f:
+                        metadata = f.metadata()
 
                 if metadata is None:
                     # Assume it's a pytorch checkpoint (introduced for timm checkpoints)
@@ -3972,7 +4062,9 @@ def from_pretrained(
             if from_pt:
                 if not is_sharded and state_dict is None:
                     # Time to load the checkpoint
-                    state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
+                    state_dict = load_state_dict(
+                        resolved_archive_file, weights_only=weights_only, dduf_entries=dduf_entries
+                    )
 
                 # set dtype to instantiate the model under:
                 # 1. If torch_dtype is not None, we use that dtype
@@ -3993,7 +4085,9 @@ def from_pretrained(
                     elif not is_sharded:
                         torch_dtype = get_state_dict_dtype(state_dict)
                     else:
-                        one_state_dict = load_state_dict(resolved_archive_file[0], weights_only=weights_only)
+                        one_state_dict = load_state_dict(
+                            resolved_archive_file[0], weights_only=weights_only, dduf_entries=dduf_entries
+                        )
                         torch_dtype = get_state_dict_dtype(one_state_dict)
                         del one_state_dict  # free CPU memory
                         logger.info(
@@ -4218,6 +4312,7 @@ def from_pretrained(
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 gguf_path=gguf_path,
                 weights_only=weights_only,
+                dduf_entries=dduf_entries,
             )
 
         # make sure token embedding weights are still tied if needed
@@ -4400,6 +4495,7 @@ def _load_pretrained_model(
         keep_in_fp32_modules=None,
         gguf_path=None,
         weights_only=True,
+        dduf_entries=None,
     ):
         is_safetensors = False
         is_quantized = hf_quantizer is not None
@@ -4742,7 +4838,11 @@ def _find_mismatched_keys(
                 ):
                     map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
                 state_dict = load_state_dict(
-                    shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
+                    shard_file,
+                    is_quantized=is_quantized,
+                    map_location=map_location,
+                    weights_only=weights_only,
+                    dduf_entries=dduf_entries,
                 )
 
                 # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
diff --git a/src/transformers/models/clip/tokenization_clip.py b/src/transformers/models/clip/tokenization_clip.py
index 41a73db8c1ec..1d7376ca4ce8 100644
--- a/src/transformers/models/clip/tokenization_clip.py
+++ b/src/transformers/models/clip/tokenization_clip.py
@@ -291,6 +291,8 @@ def __init__(
         pad_token="<|endoftext|>",  # hack to enable padding
         **kwargs,
     ):
+        dduf_entries = kwargs.get("dduf_entries", None)
+
         bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
         eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
         unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
@@ -302,15 +304,21 @@ def __init__(
             logger.info("ftfy or spacy is not installed using custom BasicTokenizer instead of ftfy.")
             self.nlp = BasicTokenizer(strip_accents=False, do_split_on_punc=False)
             self.fix_text = None
-
-        with open(vocab_file, encoding="utf-8") as vocab_handle:
-            self.encoder = json.load(vocab_handle)
+        if dduf_entries:
+            self.encoder = json.loads(dduf_entries[vocab_file].read_text())
+        else:
+            with open(vocab_file, encoding="utf-8") as vocab_handle:
+                self.encoder = json.load(vocab_handle)
         self.decoder = {v: k for k, v in self.encoder.items()}
         self.errors = errors  # how to handle errors in decoding
         self.byte_encoder = bytes_to_unicode()
         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
-        with open(merges_file, encoding="utf-8") as merges_handle:
-            bpe_merges = merges_handle.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
+        if dduf_entries:
+            bpe_merges = dduf_entries[merges_file].read_text().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
+        else:
+            with open(merges_file, encoding="utf-8") as merges_handle:
+                bpe_merges = merges_handle.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
+
         bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
         self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
         self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
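The tokenizer changes mirror the config-path convention: files inside the archive are looked up by joining the model name with the file name. A hedged sketch of the lookup the CLIP branch performs (archive name and key-prefix layout are assumptions, mirroring the `os.path.join` calls above):

```python
# Hedged sketch of the archive lookup convention used by the tokenizer branches.
import json
import os

from huggingface_hub import read_dduf_file

dduf_entries = read_dduf_file("model.dduf")  # hypothetical archive
root = "model.dduf"                          # passed as pretrained_model_name_or_path
vocab_key = os.path.join(root, "vocab.json")
if vocab_key in dduf_entries:
    encoder = json.loads(dduf_entries[vocab_key].read_text())
```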
diff --git a/src/transformers/models/t5/tokenization_t5.py b/src/transformers/models/t5/tokenization_t5.py
index 1e166a78f10d..efd6947abce8 100644
--- a/src/transformers/models/t5/tokenization_t5.py
+++ b/src/transformers/models/t5/tokenization_t5.py
@@ -137,6 +137,7 @@ def __init__(
         add_prefix_space=True,
         **kwargs,
     ) -> None:
+        dduf_entries = kwargs.get("dduf_entries", None)
         pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
         unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
         eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token
@@ -145,9 +146,14 @@ def __init__(
         self.vocab_file = vocab_file
         self._extra_ids = extra_ids
 
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
-        self.sp_model.Load(vocab_file)
+        self.dduf_entries = dduf_entries  # kept so save_vocabulary can copy the file out of the archive
+        if dduf_entries:
+            with dduf_entries[self.vocab_file].as_mmap() as mm:
+                self.sp_model.load_from_serialized_proto(mm)
+        else:
+            self.sp_model.Load(vocab_file)
 
         if additional_special_tokens is not None:
             extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
@@ ... @@
     def can_save_slow_tokenizer(self) -> bool:
-        return os.path.isfile(self.vocab_file) if self.vocab_file else False
+        # TODO: update this. Putting it to True for now
+        # return os.path.isfile(self.vocab_file) if self.vocab_file else False
+        return True
 
     @staticmethod
     def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length):
@@ -170,9 +173,15 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
             logger.info(f"Copy vocab file to {out_vocab_file}")
+        # copyfile doesn't work when the vocab only exists inside an archive, e.g. a DDUF file
+        elif not os.path.isfile(self.vocab_file):
+            with self.dduf_entries[self.vocab_file].as_mmap() as mm:
+                with open(out_vocab_file, "wb") as out_file:
+                    out_file.write(mm)
+            logger.info(f"Copy vocab file to {out_vocab_file}")
 
         return (out_vocab_file,)
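The T5 branch relies on sentencepiece's ability to load a model from raw bytes rather than a filesystem path. A standalone sketch of that call, with a plain file read standing in for the DDUF mmap:

```python
# Sketch of the T5 branch: load a SentencePiece model from serialized bytes
# instead of a filesystem path.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
with open("spiece.model", "rb") as f:  # stand-in for dduf_entries[...].as_mmap()
    sp.load_from_serialized_proto(f.read())
print(sp.get_piece_size())
```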
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index f4e5b9b3aaf3..9133e1717fb9 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1894,6 +1894,7 @@ def from_pretrained(
         from_auto_class = kwargs.pop("_from_auto", False)
         commit_hash = kwargs.pop("_commit_hash", None)
         gguf_file = kwargs.get("gguf_file", None)
+        dduf_entries = kwargs.get("dduf_entries", None)
 
         if use_auth_token is not None:
             warnings.warn(
@@ -1952,36 +1953,51 @@ def from_pretrained(
         if "tokenizer_file" in vocab_files:
             # Try to get the tokenizer config to see if there are versioned tokenizer files.
             fast_tokenizer_file = FULL_TOKENIZER_FILE
-            resolved_config_file = cached_file(
-                pretrained_model_name_or_path,
-                TOKENIZER_CONFIG_FILE,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                token=token,
-                revision=revision,
-                local_files_only=local_files_only,
-                subfolder=subfolder,
-                user_agent=user_agent,
-                _raise_exceptions_for_gated_repo=False,
-                _raise_exceptions_for_missing_entries=False,
-                _raise_exceptions_for_connection_errors=False,
-                _commit_hash=commit_hash,
-            )
-            commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
-            if resolved_config_file is not None:
-                with open(resolved_config_file, encoding="utf-8") as reader:
-                    tokenizer_config = json.load(reader)
-                if "fast_tokenizer_files" in tokenizer_config:
-                    fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
+            if dduf_entries:
+                tokenizer_config = json.loads(
+                    dduf_entries[os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)].read_text()
+                )
+                if "fast_tokenizer_files" in tokenizer_config:
+                    fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
+            else:
+                resolved_config_file = cached_file(
+                    pretrained_model_name_or_path,
+                    TOKENIZER_CONFIG_FILE,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    token=token,
+                    revision=revision,
+                    local_files_only=local_files_only,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                    _raise_exceptions_for_gated_repo=False,
+                    _raise_exceptions_for_missing_entries=False,
+                    _raise_exceptions_for_connection_errors=False,
+                    _commit_hash=commit_hash,
+                )
+                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+                if resolved_config_file is not None:
+                    with open(resolved_config_file, encoding="utf-8") as reader:
+                        tokenizer_config = json.load(reader)
+                    if "fast_tokenizer_files" in tokenizer_config:
+                        fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
             vocab_files["tokenizer_file"] = fast_tokenizer_file
 
         # Get files from url, cache, or disk depending on the case
         resolved_vocab_files = {}
         unresolved_files = []
         for file_id, file_path in vocab_files.items():
-            if file_path is None:
+            if dduf_entries:
+                # We don't necessarily have every file inside the archive
+                if file_path is not None and os.path.join(pretrained_model_name_or_path, file_path) in dduf_entries:
+                    resolved_vocab_files[file_id] = os.path.join(pretrained_model_name_or_path, file_path)
+                else:
+                    resolved_vocab_files[file_id] = None
+            elif file_path is None:
                 resolved_vocab_files[file_id] = None
             elif single_file_id == file_id:
                 if os.path.isfile(file_path):
@@ -2007,6 +2023,10 @@ def from_pretrained(
                     _commit_hash=commit_hash,
                 )
                 commit_hash = extract_commit_hash(resolved_vocab_files[file_id], commit_hash)
+        # if dduf_entries:
+        #     # preload files that are going to be opened in the tokenizer so that we don't
+        #     # need to modify each tokenizer with dduf?
+        #     for file in cls.vocab_files_names:  # e.g. vocab_file, merges ...
 
         if len(unresolved_files) > 0:
             logger.info(
@@ -2066,6 +2086,7 @@ def _from_pretrained(
         # file or if `from_slow` is set to True.
         from_slow = kwargs.get("from_slow", False)
         gguf_file = kwargs.get("gguf_file", None)
+        dduf_entries = kwargs.get("dduf_entries", None)
         has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None
 
         # If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be
@@ -2089,8 +2110,11 @@ def _from_pretrained(
         # Did we saved some inputs and kwargs to reload ?
         tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None)
         if tokenizer_config_file is not None:
-            with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
-                init_kwargs = json.load(tokenizer_config_handle)
+            if dduf_entries:
+                init_kwargs = json.loads(dduf_entries[tokenizer_config_file].read_text())
+            else:
+                with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
+                    init_kwargs = json.load(tokenizer_config_handle)
             # First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
             config_tokenizer_class = init_kwargs.get("tokenizer_class")
             init_kwargs.pop("tokenizer_class", None)
@@ -2106,8 +2130,12 @@ def _from_pretrained(
         # If an independent chat template file exists, it takes priority over template entries in the tokenizer config
         chat_template_file = resolved_vocab_files.pop("chat_template_file", None)
         if chat_template_file is not None:
-            with open(chat_template_file) as chat_template_handle:
-                init_kwargs["chat_template"] = chat_template_handle.read()  # Clobbers any template in the config
+            if dduf_entries:
+                # TODO: check that it works
+                init_kwargs["chat_template"] = dduf_entries[chat_template_file].read_text()
+            else:
+                with open(chat_template_file) as chat_template_handle:
+                    init_kwargs["chat_template"] = chat_template_handle.read()  # Clobbers any template in the config
 
         if not _is_local:
             if "auto_map" in init_kwargs:
@@ -2207,26 +2235,29 @@ def _from_pretrained(
             else:
                 # begin legacy: read the added_tokens_file and update kwargs with special_tokens_map if modified
                 if special_tokens_map_file is not None:
-                    with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
-                        special_tokens_map = json.load(special_tokens_map_handle)
-                    for key, value in special_tokens_map.items():
-                        if key in kwargs and kwargs[key]:
-                            # This value has already been redefined by the kwargs
-                            # We keep this new value and ignore the one stored in the special_tokens_map_file
-                            continue
-                        if isinstance(value, dict):
-                            value["special"] = True
-                            value = AddedToken(**value)
-                        elif key == "additional_special_tokens" and isinstance(value, list):
-                            additional_special_tokens = init_kwargs.pop("additional_special_tokens", []) or []
-                            for token in value:
-                                if isinstance(token, dict):
-                                    token["special"] = True
-                                    token = AddedToken(**token)
-                                if token not in additional_special_tokens:
-                                    additional_special_tokens.append(token)
-                            value = additional_special_tokens
-                        init_kwargs[key] = value
+                    if dduf_entries:
+                        special_tokens_map = json.loads(dduf_entries[special_tokens_map_file].read_text())
+                    else:
+                        with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
+                            special_tokens_map = json.load(special_tokens_map_handle)
+                    for key, value in special_tokens_map.items():
+                        if key in kwargs and kwargs[key]:
+                            # This value has already been redefined by the kwargs
+                            # We keep this new value and ignore the one stored in the special_tokens_map_file
+                            continue
+                        if isinstance(value, dict):
+                            value["special"] = True
+                            value = AddedToken(**value)
+                        elif key == "additional_special_tokens" and isinstance(value, list):
+                            additional_special_tokens = init_kwargs.pop("additional_special_tokens", []) or []
+                            for token in value:
+                                if isinstance(token, dict):
+                                    token["special"] = True
+                                    token = AddedToken(**token)
+                                if token not in additional_special_tokens:
+                                    additional_special_tokens.append(token)
+                            value = additional_special_tokens
+                        init_kwargs[key] = value
 
                 # slow -> slow|fast, legacy: convert the `"added_tokens.json"` file to `added_tokens_decoder`.
                 # this is for legacy purpose. We don't add the tokens after init for efficiency.
@@ -2239,8 +2270,11 @@ def _from_pretrained(
                     else:
                         special_tokens.append(str(init_kwargs[key]))
 
-                with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
-                    added_tok_encoder = json.load(added_tokens_handle)
+                if dduf_entries:
+                    added_tok_encoder = json.loads(dduf_entries[added_tokens_file].read_text())
+                else:
+                    with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
+                        added_tok_encoder = json.load(added_tokens_handle)
                 for str_token, index in added_tok_encoder.items():
                     # if index not in added_tokens_decoder and str_token not in added_tokens_map:
                     special = str_token in special_tokens
@@ -2253,9 +2287,12 @@ def _from_pretrained(
             # if `tokenizer_config.json` is `None`
             if tokenizer_file is not None:
                 # This is for slow so can be done before
-                with open(tokenizer_file, encoding="utf-8") as tokenizer_file_handle:
-                    tokenizer_file_handle = json.load(tokenizer_file_handle)
-                added_tokens = tokenizer_file_handle.pop("added_tokens")
+                if dduf_entries:
+                    tokenizer_file_handle = json.loads(dduf_entries[tokenizer_file].read_text())
+                else:
+                    with open(tokenizer_file, encoding="utf-8") as tokenizer_file_handle:
+                        tokenizer_file_handle = json.load(tokenizer_file_handle)
+                added_tokens = tokenizer_file_handle.pop("added_tokens")
                 for serialized_tokens in added_tokens:
                     idx = serialized_tokens.pop("id")
                     added_tokens_decoder[idx] = AddedToken(**serialized_tokens)
@@ -2482,6 +2519,8 @@ def save_pretrained(
             tokenizer_config.pop("tokenizer_file", None)
         if "device_map" in tokenizer_config:
             tokenizer_config.pop("device_map")
+        if "dduf_entries" in tokenizer_config:
+            tokenizer_config.pop("dduf_entries", None)
 
         with open(tokenizer_config_file, "w", encoding="utf-8") as f:
             out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
diff --git a/src/transformers/tokenization_utils_fast.py b/src/transformers/tokenization_utils_fast.py
index d1353adfd225..c7e690e38fc7 100644
--- a/src/transformers/tokenization_utils_fast.py
+++ b/src/transformers/tokenization_utils_fast.py
@@ -108,7 +108,6 @@ def __init__(self, *args, **kwargs):
                 "Cannot instantiate this tokenizer from a slow version. If it's based on sentencepiece, make sure you "
                 "have sentencepiece installed."
             )
-
         if tokenizer_object is not None:
             fast_tokenizer = copy.deepcopy(tokenizer_object)
         elif fast_tokenizer_file is not None and not from_slow:
@@ -130,7 +129,7 @@ def __init__(self, *args, **kwargs):
         elif self.slow_tokenizer_class is not None and slow_tokenizer is not False:
             # We need to create and convert a slow tokenizer to build the backend
             slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs)
-            fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
+            fast_tokenizer = convert_slow_tokenizer(slow_tokenizer, **kwargs)
         elif not slow_tokenizer:
             # We tried loading a slow_tokenizer with spm and failed, try to load with tiktoken
             self.vocab_file = kwargs.get("vocab_file", None)
diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py
index b7194ec579da..5efbb42d268e 100644
--- a/src/transformers/utils/hub.py
+++ b/src/transformers/utils/hub.py
@@ -1044,6 +1044,7 @@ def get_checkpoint_shard_files(
     revision=None,
     subfolder="",
     _commit_hash=None,
+    dduf_entries=None,
    **deprecated_kwargs,
 ):
     """
@@ -1068,22 +1069,23 @@ def get_checkpoint_shard_files(
         raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
         token = use_auth_token
 
-    if not os.path.isfile(index_filename):
-        raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
-
-    with open(index_filename, "r") as f:
-        index = json.loads(f.read())
+    if dduf_entries:
+        index = json.loads(dduf_entries[index_filename].read_text())
+    else:
+        if not os.path.isfile(index_filename):
+            raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
+        with open(index_filename, "r") as f:
+            index = json.loads(f.read())
 
     shard_filenames = sorted(set(index["weight_map"].values()))
     sharded_metadata = index["metadata"]
     sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
     sharded_metadata["weight_map"] = index["weight_map"].copy()
 
-    # First, let's deal with local folder.
-    if os.path.isdir(pretrained_model_name_or_path):
+    # First, let's deal with a local folder and DDUF archives
+    if os.path.isdir(pretrained_model_name_or_path) or dduf_entries:
         shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
         return shard_filenames, sharded_metadata
-
     # At this stage pretrained_model_name_or_path is a model identifier on the Hub
     cached_filenames = []
     # Check if the model is already cached or not. We only try the last checkpoint, this should cover most cases of
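End to end, this is the usage the diff sketches out (hypothetical until the API is finalized; the key-prefix convention assumes entries are addressed relative to the path passed as `pretrained_model_name_or_path`):

```python
# Hypothetical end-to-end usage of the dduf_entries plumbing in this diff.
from huggingface_hub import read_dduf_file
from transformers import AutoModel, AutoTokenizer

archive = "my_model.dduf"  # assumed archive name
dduf_entries = read_dduf_file(archive)

model = AutoModel.from_pretrained(archive, dduf_entries=dduf_entries)
tokenizer = AutoTokenizer.from_pretrained(archive, dduf_entries=dduf_entries)
```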