diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index dcb00715d59e..ecccf3c11311 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -23,7 +23,7 @@ from .. import __version__ from ..quantizers import DiffusersAutoQuantizer -from ..utils import deprecate, is_accelerate_available, logging +from ..utils import deprecate, is_accelerate_available, is_torch_version, logging from ..utils.torch_utils import empty_device_cache from .single_file_utils import ( SingleFileComponentError, @@ -64,6 +64,10 @@ from ..models.modeling_utils import load_model_dict_into_meta +if is_torch_version(">=", "1.9.0") and is_accelerate_available(): + _LOW_CPU_MEM_USAGE_DEFAULT = True +else: + _LOW_CPU_MEM_USAGE_DEFAULT = False SINGLE_FILE_LOADABLE_CLASSES = { "StableCascadeUNet": { @@ -236,6 +240,11 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 and + is_accelerate_available() else `False`): Speed up model loading by only loading the pretrained weights and + not initializing the weights. This also tries to not use more than 1x model size in CPU memory + (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using + an older version of PyTorch, setting this argument to `True` will raise an error. disable_mmap ('bool', *optional*, defaults to 'False'): Whether to disable mmap when loading a Safetensors model. This option can perform better when the model is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well. 
@@ -285,6 +294,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = config_revision = kwargs.pop("config_revision", None) torch_dtype = kwargs.pop("torch_dtype", None) quantization_config = kwargs.pop("quantization_config", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) device = kwargs.pop("device", None) disable_mmap = kwargs.pop("disable_mmap", False) @@ -389,7 +399,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs} diffusers_model_config.update(model_kwargs) - ctx = init_empty_weights if is_accelerate_available() else nullcontext + ctx = init_empty_weights if low_cpu_mem_usage else nullcontext with ctx(): model = cls.from_config(diffusers_model_config) @@ -427,7 +437,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = ) device_map = None - if is_accelerate_available(): + if low_cpu_mem_usage: param_device = torch.device(device) if device else torch.device("cpu") empty_state_dict = model.state_dict() unexpected_keys = [