|
23 | 23 |
|
24 | 24 | from .. import __version__ |
25 | 25 | from ..quantizers import DiffusersAutoQuantizer |
26 | | -from ..utils import deprecate, is_accelerate_available, logging |
| 26 | +from ..utils import deprecate, is_accelerate_available, logging, is_torch_version |
27 | 27 | from ..utils.torch_utils import empty_device_cache |
28 | 28 | from .single_file_utils import ( |
29 | 29 | SingleFileComponentError, |
|
64 | 64 |
|
65 | 65 | from ..models.modeling_utils import load_model_dict_into_meta |
66 | 66 |
|
# Default for `low_cpu_mem_usage`: meta-device ("empty weights") initialization
# needs torch >= 1.9 plus the accelerate package; otherwise fall back to eager init.
_LOW_CPU_MEM_USAGE_DEFAULT = True if is_torch_version(">=", "1.9.0") and is_accelerate_available() else False
67 | 71 |
|
68 | 72 | SINGLE_FILE_LOADABLE_CLASSES = { |
69 | 73 | "StableCascadeUNet": { |
@@ -285,6 +289,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = |
285 | 289 | config_revision = kwargs.pop("config_revision", None) |
286 | 290 | torch_dtype = kwargs.pop("torch_dtype", None) |
287 | 291 | quantization_config = kwargs.pop("quantization_config", None) |
| 292 | + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) |
288 | 293 | device = kwargs.pop("device", None) |
289 | 294 | disable_mmap = kwargs.pop("disable_mmap", False) |
290 | 295 |
|
@@ -389,7 +394,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = |
389 | 394 | model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs} |
390 | 395 | diffusers_model_config.update(model_kwargs) |
391 | 396 |
|
392 | | - ctx = init_empty_weights if is_accelerate_available() else nullcontext |
| 397 | + ctx = init_empty_weights if low_cpu_mem_usage else nullcontext |
393 | 398 | with ctx(): |
394 | 399 | model = cls.from_config(diffusers_model_config) |
395 | 400 |
|
@@ -427,7 +432,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = |
427 | 432 | ) |
428 | 433 |
|
429 | 434 | device_map = None |
430 | | - if is_accelerate_available(): |
| 435 | + if low_cpu_mem_usage: |
431 | 436 | param_device = torch.device(device) if device else torch.device("cpu") |
432 | 437 | empty_state_dict = model.state_dict() |
433 | 438 | unexpected_keys = [ |
|
0 commit comments