diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 8a05cce209c5..ccc5b11c2a7d 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -290,7 +290,7 @@ def outputs(self) -> List[OutputParam]: def from_pretrained( cls, pretrained_model_name_or_path: str, - trust_remote_code: Optional[bool] = None, + trust_remote_code: bool = False, **kwargs, ): hub_kwargs_names = [ diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 2d9e16f87e47..d9867fb8758d 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -45,6 +45,7 @@ DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES +DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES # Below should be `True` if the current version of `peft` and `transformers` are compatible with # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index 74ed240bf015..674eb65773f0 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -20,7 +20,6 @@ import os import re import shutil -import signal import sys import threading from pathlib import Path @@ -34,6 +33,7 @@ from .. import __version__ from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging +from .constants import DIFFUSERS_DISABLE_REMOTE_CODE logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -159,52 +159,25 @@ def check_imports(filename): return get_relative_imports(filename) -def _raise_timeout_error(signum, frame): - raise ValueError( - "Loading this model requires you to execute custom code contained in the model repository on your local " - "machine. Please set the option `trust_remote_code=True` to permit loading of this model." - ) - - def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code): - if trust_remote_code is None: - if has_remote_code and TIME_OUT_REMOTE_CODE > 0: - prev_sig_handler = None - try: - prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error) - signal.alarm(TIME_OUT_REMOTE_CODE) - while trust_remote_code is None: - answer = input( - f"The repository for {model_name} contains custom code which must be executed to correctly " - f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" - f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n" - f"Do you wish to run the custom code? [y/N] " - ) - if answer.lower() in ["yes", "y", "1"]: - trust_remote_code = True - elif answer.lower() in ["no", "n", "0", ""]: - trust_remote_code = False - signal.alarm(0) - except Exception: - # OS which does not support signal.SIGALRM - raise ValueError( - f"The repository for {model_name} contains custom code which must be executed to correctly " - f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" - f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." - ) - finally: - if prev_sig_handler is not None: - signal.signal(signal.SIGALRM, prev_sig_handler) - signal.alarm(0) - elif has_remote_code: - # For the CI which puts the timeout at 0 - _raise_timeout_error(None, None) + trust_remote_code = trust_remote_code and not DIFFUSERS_DISABLE_REMOTE_CODE + if DIFFUSERS_DISABLE_REMOTE_CODE: + logger.warning( + "Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable. Ignoring `trust_remote_code`." + ) if has_remote_code and not trust_remote_code: - raise ValueError( - f"Loading {model_name} requires you to execute the configuration file in that" - " repo on your local machine. Make sure you have read the code there to avoid malicious use, then" - " set the option `trust_remote_code=True` to remove this error." + error_msg = f"The repository for {model_name} contains custom code. " + error_msg += ( + "Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable." + if DIFFUSERS_DISABLE_REMOTE_CODE + else "Pass `trust_remote_code=True` to allow loading remote code modules." + ) + raise ValueError(error_msg) + + elif has_remote_code and trust_remote_code: + logger.warning( + f"`trust_remote_code` is enabled. Downloading code from {model_name}. Please ensure you trust the contents of this repository" ) return trust_remote_code