Skip to content

Commit a80f05d

Browse files
authored
[generate] cache missing custom generate file (#41216)
* cache missing custom generate file * make fixup
1 parent 1f1e93e commit a80f05d

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

src/transformers/generation/utils.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
import torch
2424
import torch.distributed as dist
25-
from huggingface_hub import file_exists
2625
from packaging import version
2726
from torch import nn
2827

@@ -392,23 +391,20 @@ def load_custom_generate(
392391
Returns:
393392
A callable that can be used to generate text.
394393
"""
395-
# Does `pretrained_model_name_or_path` have a `custom_generate` subdirectory? If not -> OSError
396-
is_local_code = os.path.exists(pretrained_model_name_or_path)
397-
has_custom_generate_folder = True
398-
if is_local_code:
399-
if not os.path.exists(os.path.join(pretrained_model_name_or_path, "custom_generate/generate.py")):
400-
has_custom_generate_folder = False
401-
else:
402-
if not file_exists(pretrained_model_name_or_path, "custom_generate/generate.py"):
403-
has_custom_generate_folder = False
404-
405-
if not has_custom_generate_folder:
394+
# Fetches the generate.py file from the model repo. If it doesn't exist, a file in `.no_exist` cache directory
395+
# is created (preventing future hub requests), and an OSError is raised.
396+
try:
397+
module = get_cached_module_file(
398+
pretrained_model_name_or_path, module_file="custom_generate/generate.py", **kwargs
399+
)
400+
except OSError:
406401
raise OSError(
407402
f"`{pretrained_model_name_or_path}` does not contain a `custom_generate` subdirectory with a "
408403
"`generate.py` file, can't load the custom generate function."
409404
)
410405

411406
# Handle opt-in `trust_remote_code` and related exceptions
407+
is_local_code = os.path.exists(pretrained_model_name_or_path)
412408
error_message = (
413409
f"The repository `{pretrained_model_name_or_path}` contains custom generation code that will override "
414410
"the default `generate` method."
@@ -425,9 +421,6 @@ def load_custom_generate(
425421
check_python_requirements(
426422
pretrained_model_name_or_path, requirements_file="custom_generate/requirements.txt", **kwargs
427423
)
428-
module = get_cached_module_file(
429-
pretrained_model_name_or_path, module_file="custom_generate/generate.py", **kwargs
430-
)
431424
custom_generate_function = get_class_in_module("generate", module)
432425
return custom_generate_function
433426

0 commit comments

Comments
 (0)