Skip to content

Commit 7228287

Browse files
IrisRainbowNeko and github-actions[bot]
authored
Add low_cpu_mem_usage option to from_single_file to align with from_pretrained (#12114)
* align meta device of from_single_file with from_pretrained
* update docstr
* Apply style fixes

---------

Co-authored-by: IrisRainbowNeko <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 3552279 commit 7228287

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from .. import __version__
2525
from ..quantizers import DiffusersAutoQuantizer
26-
from ..utils import deprecate, is_accelerate_available, logging
26+
from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
2727
from ..utils.torch_utils import empty_device_cache
2828
from .single_file_utils import (
2929
SingleFileComponentError,
@@ -64,6 +64,10 @@
6464

6565
from ..models.modeling_utils import load_model_dict_into_meta
6666

67+
if is_torch_version(">=", "1.9.0") and is_accelerate_available():
68+
_LOW_CPU_MEM_USAGE_DEFAULT = True
69+
else:
70+
_LOW_CPU_MEM_USAGE_DEFAULT = False
6771

6872
SINGLE_FILE_LOADABLE_CLASSES = {
6973
"StableCascadeUNet": {
@@ -236,6 +240,11 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
236240
revision (`str`, *optional*, defaults to `"main"`):
237241
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
238242
allowed by Git.
243+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 and
244+
is_accelerate_available() else `False`): Speed up model loading only loading the pretrained weights and
245+
not initializing the weights. This also tries to not use more than 1x model size in CPU memory
246+
(including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using
247+
an older version of PyTorch, setting this argument to `True` will raise an error.
239248
disable_mmap ('bool', *optional*, defaults to 'False'):
240249
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
241250
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -285,6 +294,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
285294
config_revision = kwargs.pop("config_revision", None)
286295
torch_dtype = kwargs.pop("torch_dtype", None)
287296
quantization_config = kwargs.pop("quantization_config", None)
297+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
288298
device = kwargs.pop("device", None)
289299
disable_mmap = kwargs.pop("disable_mmap", False)
290300

@@ -389,7 +399,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
389399
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
390400
diffusers_model_config.update(model_kwargs)
391401

392-
ctx = init_empty_weights if is_accelerate_available() else nullcontext
402+
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
393403
with ctx():
394404
model = cls.from_config(diffusers_model_config)
395405

@@ -427,7 +437,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
427437
)
428438

429439
device_map = None
430-
if is_accelerate_available():
440+
if low_cpu_mem_usage:
431441
param_device = torch.device(device) if device else torch.device("cpu")
432442
empty_state_dict = model.state_dict()
433443
unexpected_keys = [

0 commit comments

Comments (0)