2222
2323import torch
2424import torch .distributed as dist
25- from huggingface_hub import file_exists
2625from packaging import version
2726from torch import nn
2827
@@ -392,23 +391,20 @@ def load_custom_generate(
392391 Returns:
393392 A callable that can be used to generate text.
394393 """
395- # Does `pretrained_model_name_or_path` have a `custom_generate` subdirectory? If not -> OSError
396- is_local_code = os .path .exists (pretrained_model_name_or_path )
397- has_custom_generate_folder = True
398- if is_local_code :
399- if not os .path .exists (os .path .join (pretrained_model_name_or_path , "custom_generate/generate.py" )):
400- has_custom_generate_folder = False
401- else :
402- if not file_exists (pretrained_model_name_or_path , "custom_generate/generate.py" ):
403- has_custom_generate_folder = False
404-
405- if not has_custom_generate_folder :
394+ # Fetches the generate.py file from the model repo. If it doesn't exist, a file in `.no_exist` cache directory
395+ # is created (preventing future hub requests), and an OSError is raised.
396+ try :
397+ module = get_cached_module_file (
398+ pretrained_model_name_or_path , module_file = "custom_generate/generate.py" , ** kwargs
399+ )
400+ except OSError :
406401 raise OSError (
407402 f"`{ pretrained_model_name_or_path } ` does not contain a `custom_generate` subdirectory with a "
408403 "`generate.py` file, can't load the custom generate function."
409404 )
410405
411406 # Handle opt-in `trust_remote_code` and related exceptions
407+ is_local_code = os .path .exists (pretrained_model_name_or_path )
412408 error_message = (
413409 f"The repository `{ pretrained_model_name_or_path } ` contains custom generation code that will override "
414410 "the default `generate` method."
@@ -425,9 +421,6 @@ def load_custom_generate(
425421 check_python_requirements (
426422 pretrained_model_name_or_path , requirements_file = "custom_generate/requirements.txt" , ** kwargs
427423 )
428- module = get_cached_module_file (
429- pretrained_model_name_or_path , module_file = "custom_generate/generate.py" , ** kwargs
430- )
431424 custom_generate_function = get_class_in_module ("generate" , module )
432425 return custom_generate_function
433426
0 commit comments