Skip to content

Commit 307a49c

Browse files
author
IrisRainbowNeko
committed
align meta device of from_single_file with from_pretrained
1 parent 03c3f69 commit 307a49c

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from .. import __version__
2525
from ..quantizers import DiffusersAutoQuantizer
26-
from ..utils import deprecate, is_accelerate_available, logging
26+
from ..utils import deprecate, is_accelerate_available, logging, is_torch_version
2727
from ..utils.torch_utils import empty_device_cache
2828
from .single_file_utils import (
2929
SingleFileComponentError,
@@ -64,6 +64,10 @@
6464

6565
from ..models.modeling_utils import load_model_dict_into_meta
6666

67+
if is_torch_version(">=", "1.9.0") and is_accelerate_available():
68+
_LOW_CPU_MEM_USAGE_DEFAULT = True
69+
else:
70+
_LOW_CPU_MEM_USAGE_DEFAULT = False
6771

6872
SINGLE_FILE_LOADABLE_CLASSES = {
6973
"StableCascadeUNet": {
@@ -285,6 +289,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
285289
config_revision = kwargs.pop("config_revision", None)
286290
torch_dtype = kwargs.pop("torch_dtype", None)
287291
quantization_config = kwargs.pop("quantization_config", None)
292+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
288293
device = kwargs.pop("device", None)
289294
disable_mmap = kwargs.pop("disable_mmap", False)
290295

@@ -389,7 +394,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
389394
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
390395
diffusers_model_config.update(model_kwargs)
391396

392-
ctx = init_empty_weights if is_accelerate_available() else nullcontext
397+
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
393398
with ctx():
394399
model = cls.from_config(diffusers_model_config)
395400

@@ -427,7 +432,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
427432
)
428433

429434
device_map = None
430-
if is_accelerate_available():
435+
if low_cpu_mem_usage:
431436
param_device = torch.device(device) if device else torch.device("cpu")
432437
empty_state_dict = model.state_dict()
433438
unexpected_keys = [

0 commit comments

Comments
 (0)