49 changes: 29 additions & 20 deletions src/transformers/configuration_utils.py
@@ -621,6 +621,7 @@ def _get_config_dict(
         commit_hash = kwargs.pop("_commit_hash", None)
 
         gguf_file = kwargs.get("gguf_file", None)
+        dduf_entries = kwargs.pop("dduf_entries", None)
 
         if trust_remote_code is True:
             logger.warning(
@@ -634,7 +635,7 @@ def _get_config_dict(
 
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
 
-        is_local = os.path.isdir(pretrained_model_name_or_path)
+        is_local = os.path.isdir(pretrained_model_name_or_path) or dduf_entries
         if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
             # Special case when pretrained_model_name_or_path is a local file
             resolved_config_file = pretrained_model_name_or_path
@@ -644,26 +645,31 @@ def _get_config_dict(
             resolved_config_file = download_url(pretrained_model_name_or_path)
         else:
             configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file
+
             try:
-                # Load from local folder or from cache or download from model Hub and cache
-                resolved_config_file = cached_file(
-                    pretrained_model_name_or_path,
-                    configuration_file,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
-                    local_files_only=local_files_only,
-                    token=token,
-                    user_agent=user_agent,
-                    revision=revision,
-                    subfolder=subfolder,
-                    _commit_hash=commit_hash,
-                )
-                if resolved_config_file is None:
-                    return None, kwargs
-                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+                if dduf_entries:
+                    resolved_config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
+                    commit_hash = extract_commit_hash(dduf_entries[resolved_config_file].dduf_path, commit_hash)
+                else:
+                    resolved_config_file = cached_file(
+                        pretrained_model_name_or_path,
+                        configuration_file,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        proxies=proxies,
+                        resume_download=resume_download,
+                        local_files_only=local_files_only,
+                        token=token,
+                        user_agent=user_agent,
+                        revision=revision,
+                        subfolder=subfolder,
+                        _commit_hash=commit_hash,
+                    )
+                    if resolved_config_file is None:
+                        return None, kwargs
+                    else:
+                        commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
 
             except EnvironmentError:
                 # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
                 # the original exception.
@@ -682,7 +688,10 @@ def _get_config_dict(
                 config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"]
             else:
                 # Load config dict
-                config_dict = cls._dict_from_json_file(resolved_config_file)
+                if dduf_entries:
+                    config_dict = json.loads(dduf_entries[resolved_config_file].read_text())
+                else:
+                    config_dict = cls._dict_from_json_file(resolved_config_file)
 
             config_dict["_commit_hash"] = commit_hash
         except (json.JSONDecodeError, UnicodeDecodeError):
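
Taken together, these hunks let `_get_config_dict` resolve a configuration from inside a DDUF archive: when `dduf_entries` is passed, the entry key is built with `os.path.join` and the JSON is read straight out of the archive instead of going through `cached_file`. Below is a minimal sketch of the new path, assuming `huggingface_hub`'s DDUF helpers (`read_dduf_file` returning a dict of `DDUFEntry` objects) and assuming the keyword is forwarded unchanged from `from_pretrained` down to `_get_config_dict`; the archive name and the "transformer" folder are illustrative, not prescribed by the diff:

from huggingface_hub import read_dduf_file

from transformers import AutoConfig

# read_dduf_file maps every file in the archive to a DDUFEntry, an
# offset/length view into the single .dduf file; nothing is extracted
# to disk. "model.dduf" is a hypothetical local archive.
dduf_entries = read_dduf_file("model.dduf")

# With dduf_entries present, the patched _get_config_dict joins the
# folder name with "config.json" and reads the entry's text instead of
# calling cached_file.
config = AutoConfig.from_pretrained("transformer", dduf_entries=dduf_entries)
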
15 changes: 10 additions & 5 deletions src/transformers/convert_slow_tokenizer.py
@@ -544,17 +544,22 @@ class SpmConverter(Converter):
     SpmExtractor = SentencePieceExtractor
     special_tokens = {}
 
-    def __init__(self, *args):
+    def __init__(self, *args, **kwargs):
         requires_backends(self, "protobuf")
 
         super().__init__(*args)
+        dduf_entries = kwargs.get("dduf_entries", None)
 
         # from .utils import sentencepiece_model_pb2 as model_pb2
         model_pb2 = import_protobuf()
 
         m = model_pb2.ModelProto()
-        with open(self.original_tokenizer.vocab_file, "rb") as f:
-            m.ParseFromString(f.read())
+        if dduf_entries:
+            with dduf_entries[self.original_tokenizer.vocab_file].as_mmap() as mm:
+                m.ParseFromString(mm)
+        else:
+            with open(self.original_tokenizer.vocab_file, "rb") as f:
+                m.ParseFromString(f.read())
         self.proto = m
 
         if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback:
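
The `as_mmap()` branch above parses the sentencepiece model without copying it out of the archive. A rough sketch of what it relies on, under the assumption that `DDUFEntry.as_mmap()` is a context manager yielding the entry's raw bytes from a memory-mapped `.dduf` file; the archive and entry names are hypothetical:

from huggingface_hub import read_dduf_file

dduf_entries = read_dduf_file("model.dduf")  # hypothetical archive
entry = dduf_entries["tokenizer/spiece.model"]  # hypothetical entry key

# Memory-map the archive and expose this entry's byte range: a buffer
# that protobuf's ParseFromString() accepts without a temp file.
with entry.as_mmap() as mm:
    print(len(mm))
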
@@ -1606,7 +1611,7 @@ def converted(self) -> Tokenizer:
 }
 
 
-def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer:
+def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False, **kwargs) -> Tokenizer:
     """
     Utilities to convert a slow tokenizer instance into a fast tokenizer instance.
 
@@ -1625,7 +1630,7 @@ def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer:
     tokenizer_class_name = transformer_tokenizer.__class__.__name__
     if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS and not from_tiktoken:
         converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
-        return converter_class(transformer_tokenizer).converted()
+        return converter_class(transformer_tokenizer, **kwargs).converted()
 
     else:
         try:
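
These two patches are the transformers half of DDUF support; in practice the `dduf_entries` mapping is produced and threaded through by a DDUF-aware pipeline loader rather than by user code calling `convert_slow_tokenizer` directly. A hedged end-to-end sketch, assuming diffusers' `dduf_file` argument to `DiffusionPipeline.from_pretrained`; the repo id and archive name are illustrative:

from diffusers import DiffusionPipeline

# Fetches a single .dduf archive and loads configs, tokenizers, and
# weights out of it; internally this hands dduf_entries down into the
# transformers code paths patched above.
pipe = DiffusionPipeline.from_pretrained(
    "DDUF/FLUX.1-dev-DDUF",       # illustrative repo id
    dduf_file="FLUX.1-dev.dduf",  # illustrative archive name
)
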