
Commit 57cd26e

add
1 parent c222570 commit 57cd26e

File tree

1 file changed: +23 -12 lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 23 additions & 12 deletions
@@ -22,6 +22,7 @@
 from typing_extensions import Self
 
 from .. import __version__
+from ..models.model_loading_utils import _caching_allocator_warmup, _determine_device_map, _expand_device_map
 from ..quantizers import DiffusersAutoQuantizer
 from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
 from ..utils.torch_utils import empty_device_cache
@@ -297,6 +298,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
+        device_map = kwargs.pop("device_map", None)
 
         user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
         # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
@@ -403,19 +405,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         with ctx():
             model = cls.from_config(diffusers_model_config)
 
-        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
-
-        if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
-            diffusers_format_checkpoint = checkpoint_mapping_fn(
-                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
-            )
-        else:
-            diffusers_format_checkpoint = checkpoint
+        model_state_dict = model.state_dict()
 
-        if not diffusers_format_checkpoint:
-            raise SingleFileComponentError(
-                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
-            )
         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
             (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
@@ -428,6 +419,26 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         else:
             keep_in_fp32_modules = []
 
+        # Now that the model is loaded, we can determine the `device_map`
+        device_map = _determine_device_map(model, device_map, None, torch_dtype, keep_in_fp32_modules, hf_quantizer)
+        if device_map is not None:
+            expanded_device_map = _expand_device_map(device_map, model_state_dict.keys())
+            _caching_allocator_warmup(model, expanded_device_map, torch_dtype, hf_quantizer)
+
+        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
+
+        if _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
+            diffusers_format_checkpoint = checkpoint_mapping_fn(
+                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
+            )
+        else:
+            diffusers_format_checkpoint = checkpoint
+
+        if not diffusers_format_checkpoint:
+            raise SingleFileComponentError(
+                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
+            )
+
         if hf_quantizer is not None:
             hf_quantizer.preprocess_model(
                 model=model,
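
With this change, `from_single_file` pops a `device_map` kwarg and resolves it only after `from_config` has built the model skeleton, so placement can account for `keep_in_fp32_modules` and any active quantizer before the weights are materialized. A minimal usage sketch, assuming a Flux single-file checkpoint and that a plain device string is an accepted `device_map` value here:

```python
import torch
from diffusers import FluxTransformer2DModel

# Hypothetical checkpoint location; substitute any single-file checkpoint.
ckpt = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"

model = FluxTransformer2DModel.from_single_file(
    ckpt,
    torch_dtype=torch.bfloat16,
    device_map="cuda",  # new kwarg from this commit; resolved via _determine_device_map
)
```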

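
The `_caching_allocator_warmup` call pre-reserves accelerator memory before the checkpoint tensors are copied in, so the allocator does not grow its pool one parameter at a time. The following is only a simplified sketch of that idea with hypothetical names, not the actual diffusers implementation:

```python
import torch
from collections import defaultdict

def warmup_allocator_sketch(expanded_device_map, state_dict, dtype):
    # Sum the bytes each CUDA device will eventually hold.
    per_device_bytes = defaultdict(int)
    elem_size = torch.empty((), dtype=dtype).element_size()
    for param_name, device in expanded_device_map.items():
        if param_name in state_dict and torch.device(device).type == "cuda":
            per_device_bytes[device] += state_dict[param_name].numel() * elem_size
    for device, nbytes in per_device_bytes.items():
        # Allocate one large block and drop it; the CUDA caching allocator
        # keeps the reservation, so the later per-parameter copies reuse it.
        _ = torch.empty(nbytes, dtype=torch.uint8, device=device)
```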