@@ -23,7 +23,7 @@
 
 from .. import __version__
 from ..quantizers import DiffusersAutoQuantizer
-from ..utils import deprecate, is_accelerate_available, logging
+from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
 from ..utils.torch_utils import empty_device_cache
 from .single_file_utils import (
     SingleFileComponentError,
@@ -64,6 +64,10 @@
 
 from ..models.modeling_utils import load_model_dict_into_meta
 
+if is_torch_version(">=", "1.9.0") and is_accelerate_available():
+    _LOW_CPU_MEM_USAGE_DEFAULT = True
+else:
+    _LOW_CPU_MEM_USAGE_DEFAULT = False
 
 SINGLE_FILE_LOADABLE_CLASSES = {
    "StableCascadeUNet": {
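
The new module-level default captures the requirement that meta-device loading imposes: accelerate must be installed and torch must be recent enough (>= 1.9.0) to support the meta device. A standalone sketch of an equivalent check, assuming only torch and packaging are installed (the real `is_torch_version` and `is_accelerate_available` helpers live in `diffusers.utils`):

    # Standalone approximation of the gate above; diffusers' own helpers
    # also handle dev/pre-release version strings and import errors.
    import importlib.util

    import torch
    from packaging import version

    def _torch_at_least(minimum: str) -> bool:
        # Compare the installed torch against a minimum version.
        return version.parse(torch.__version__) >= version.parse(minimum)

    def _accelerate_available() -> bool:
        # True when the accelerate package is importable.
        return importlib.util.find_spec("accelerate") is not None

    _LOW_CPU_MEM_USAGE_DEFAULT = _torch_at_least("1.9.0") and _accelerate_available()
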
@@ -236,6 +240,11 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 and
+                is_accelerate_available() else `False`): Speed up model loading by only loading the pretrained
+                weights and not initializing them. This also tries to use no more than 1x the model size in CPU
+                memory (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you
+                are using an older version of PyTorch, setting this argument to `True` will raise an error.
             disable_mmap (`bool`, *optional*, defaults to `False`):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                 is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
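
From the caller's side the new flag is an ordinary keyword argument. A hypothetical call (the checkpoint path is a placeholder; `StableCascadeUNet` comes from the loadable-classes mapping above):

    from diffusers import StableCascadeUNet

    # Placeholder path. Passing low_cpu_mem_usage=False opts out of the
    # meta-device fast path, e.g. when accelerate is not installed or the
    # fully materialized weights are wanted immediately.
    unet = StableCascadeUNet.from_single_file(
        "path/to/stable_cascade_checkpoint.safetensors",
        low_cpu_mem_usage=False,
    )
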
@@ -285,6 +294,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         config_revision = kwargs.pop("config_revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
         quantization_config = kwargs.pop("quantization_config", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
 
@@ -389,7 +399,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
         diffusers_model_config.update(model_kwargs)
 
-        ctx = init_empty_weights if is_accelerate_available() else nullcontext
+        ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
         with ctx():
             model = cls.from_config(diffusers_model_config)
 
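
The changed line above is the heart of the feature: inside `init_empty_weights`, every parameter created in the `with` block lands on the meta device and allocates no real storage, so instantiating the model costs almost no host RAM. A minimal demonstration, assuming accelerate is installed:

    from contextlib import nullcontext

    import torch
    from accelerate import init_empty_weights

    low_cpu_mem_usage = True  # stand-in for the resolved kwarg

    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
    with ctx():
        # No real tensor storage is allocated on the fast path; the weights
        # are materialized later, directly from the checkpoint.
        layer = torch.nn.Linear(4096, 4096)

    print(layer.weight.device)  # "meta" on the fast path, "cpu" otherwise
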
@@ -427,7 +437,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             )
 
         device_map = None
-        if is_accelerate_available():
+        if low_cpu_mem_usage:
             param_device = torch.device(device) if device else torch.device("cpu")
             empty_state_dict = model.state_dict()
             unexpected_keys = [
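
The hunk is truncated here, but the pattern it begins is a plain key comparison: parameters present in the single-file checkpoint that have no counterpart in the meta-initialized model are collected as `unexpected_keys` and reported rather than loaded. A simplified sketch of that comparison (`find_unexpected_keys` is a hypothetical helper, not diffusers API):

    from typing import Dict, List

    import torch

    def find_unexpected_keys(
        model: torch.nn.Module, checkpoint: Dict[str, torch.Tensor]
    ) -> List[str]:
        # Keys in the checkpoint with no matching parameter or buffer in the
        # model are surfaced to the caller instead of being loaded.
        empty_state_dict = model.state_dict()
        return [k for k in checkpoint if k not in empty_state_dict]
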