@@ -23,7 +23,7 @@

 from .. import __version__
 from ..quantizers import DiffusersAutoQuantizer
-from ..utils import deprecate, is_accelerate_available, logging
+from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
 from ..utils.torch_utils import empty_device_cache
 from .single_file_utils import (
     SingleFileComponentError,

@@ -64,6 +64,10 @@

     from ..models.modeling_utils import load_model_dict_into_meta

+if is_torch_version(">=", "1.9.0") and is_accelerate_available():
+    _LOW_CPU_MEM_USAGE_DEFAULT = True
+else:
+    _LOW_CPU_MEM_USAGE_DEFAULT = False

 SINGLE_FILE_LOADABLE_CLASSES = {
     "StableCascadeUNet": {

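For review context: the four-line default added above is equivalent to a single boolean expression built from the two helpers this PR imports. A minimal sketch, assuming only the public diffusers.utils API:

from diffusers.utils import is_accelerate_available, is_torch_version

# Same default in one line: empty-weight init needs accelerate, and
# meta-device tensors need PyTorch >= 1.9.0.
_LOW_CPU_MEM_USAGE_DEFAULT = is_torch_version(">=", "1.9.0") and is_accelerate_available()
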
@@ -236,6 +240,11 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 and
+                `is_accelerate_available()` else `False`): Speed up model loading by only loading the pretrained
+                weights without initializing them. This also tries to not use more than 1x the model size in CPU
+                memory (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0; on an
+                older version of PyTorch, setting this argument to `True` will raise an error.
             disable_mmap ('bool', *optional*, defaults to 'False'):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                 is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.

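A usage sketch of the new argument from the caller's side; the checkpoint path below is a placeholder for illustration, not something this PR prescribes:

from diffusers import StableCascadeUNet

# Hypothetical checkpoint location; every class registered in
# SINGLE_FILE_LOADABLE_CLASSES accepts the same kwarg.
unet = StableCascadeUNet.from_single_file(
    "https://huggingface.co/<org>/<repo>/blob/main/model.safetensors",  # placeholder
    low_cpu_mem_usage=True,  # or False to force full init + load
)

Passing `low_cpu_mem_usage=False` forces the eager path that previously ran only when accelerate was missing: weights are randomly initialized first and then overwritten by the checkpoint.
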
@@ -285,6 +294,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         config_revision = kwargs.pop("config_revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
         quantization_config = kwargs.pop("quantization_config", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)

@@ -389,7 +399,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
         diffusers_model_config.update(model_kwargs)

-        ctx = init_empty_weights if is_accelerate_available() else nullcontext
+        ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
         with ctx():
             model = cls.from_config(diffusers_model_config)

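The behavioral change in this hunk: empty-weight initialization is now gated on the user-facing flag rather than on accelerate's mere availability. For reviewers unfamiliar with the context manager, a standalone sketch of what init_empty_weights provides (real accelerate API; the toy layer is only illustrative):

import torch
from accelerate import init_empty_weights

with init_empty_weights():
    # Parameters are created on the "meta" device: shape and dtype are
    # tracked, but no storage is allocated and no init kernels run.
    layer = torch.nn.Linear(4096, 4096)

print(layer.weight.device)  # meta
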
@@ -427,7 +437,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         )

         device_map = None
-        if is_accelerate_available():
+        if low_cpu_mem_usage:
             param_device = torch.device(device) if device else torch.device("cpu")
             empty_state_dict = model.state_dict()
             unexpected_keys = [

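Downstream of this branch, the meta-initialized parameters still have to be materialized from the checkpoint, which load_model_dict_into_meta (imported at the top of the file) handles. A rough standalone illustration of the underlying idea, not the library's actual code; note that torch.device("meta") as a context manager and assign=True both require a recent PyTorch:

import torch

with torch.device("meta"):
    layer = torch.nn.Linear(8, 8)  # meta params, no storage yet

# Materialize by assigning real tensors, as a loaded state dict would;
# assign=True swaps out the meta tensors instead of copying into them.
state_dict = {"weight": torch.randn(8, 8), "bias": torch.zeros(8)}
layer.load_state_dict(state_dict, assign=True)
print(layer.weight.device)  # cpu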